├── .gitignore ├── LICENSE ├── README.md ├── adapter_bert ├── __init__.py ├── __main__.py ├── configs │ └── adapter-bert.json ├── convert_tf_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling.py ├── optimization.py ├── tokenization.py └── utils.py ├── analysis ├── Charts.ipynb ├── Language MLP.ipynb └── rob.mplstyle ├── archive_adapters.py ├── archive_bert.py ├── archive_language_mlp.py ├── concat_treebanks.py ├── config ├── archive │ ├── bert-base-multilingual-cased │ │ ├── bert_config.json │ │ └── vocab.txt │ ├── bert-large-cased │ │ ├── bert_config.json │ │ └── vocab.txt │ └── xlm-mlm-100-1280 │ │ └── vocab.json ├── create_vocab.json ├── ud │ ├── feats.json │ └── multilingual │ │ ├── adapter-test.json │ │ ├── udapter-test.json │ │ └── udify-test.json └── udify_base.json ├── create_vocabs.py ├── docs └── model.png ├── find_missing_features.py ├── languages ├── in-langs.txt ├── language_codes.txt ├── letter_codes.json ├── lr-proxy.json └── oov-langs.txt ├── load_adapters.py ├── pg.scripts ├── pg.clean.ckpt.py ├── pg.feat.eval.py ├── pg.predict.py ├── pg.resume.py ├── pg.run.py └── pg.seq.eval.py ├── predict.py ├── requirements.txt ├── scripts ├── concat_ud_data.sh ├── conll18_ud_eval.py ├── download_ud_data.sh ├── evaluate_feats.py ├── generate_tables.py ├── ner_to_ud.py ├── overlap.py ├── seq_eval.py └── split_file_by_lang.py ├── slides ├── SigTyp_Abstract_EMNLP2020.pdf └── UDapter_EMNLP2020.pdf ├── train.py └── udapter ├── __init__.py ├── dataset_readers ├── __init__.py ├── conll18_ud_eval.py ├── lemma_edit.py ├── parser.py └── universal_dependencies.py ├── modules ├── __init__.py ├── bert_adapter.py ├── bert_pretrained.py ├── bucket_iterator_by_languages.py ├── language_emb.py ├── language_mlp.py ├── residual_rnn.py ├── scalar_mix.py ├── text_field_embedder.py ├── token_characters_encoder.py └── xlm_pretrained.py ├── optimizers ├── __init__.py └── ulmfit_sqrt.py ├── predictors ├── __init__.py ├── udapter_predictor.py └── udify_predictor.py ├── udapter_models ├── __init__.py ├── bilinear_matrix_attention.py ├── dependency_decoder.py ├── feedforward.py ├── linear.py ├── tag_decoder.py ├── time_distributed.py └── udapter_model.py ├── udify_models ├── __init__.py ├── dependency_decoder.py ├── tag_decoder.py └── udify_model.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | # Project specific 108 | /.idea/ 109 | /data/ 110 | /logs/ 111 | /tmp/ 112 | notebooks/bertviz/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ahmet Üstün 4 | Copyright (c) 2019 Dan Kondratyuk 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UDapter 2 | 3 | [![MIT License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE) 4 | 5 | UDapter is a multilingual dependency parser that uses "contextual" adapters together with language-typology features for language-specific adaptation. This repository includes the code for "[UDapter: Language Adaptation for Truly Universal Dependency Parsing](https://arxiv.org/abs/2004.14327)". 6 | 7 | [![UDapter Model Architecture](docs/model.png)](https://arxiv.org/pdf/2004.14327.pdf) 8 | 9 | This project is built on [UDify](https://github.com/Hyperparticle/udify) using [AllenNLP](https://allennlp.org/) and [Huggingface Transformers](https://github.com/huggingface/transformers).
The code is tested on Python v3.6. 10 | 11 | ## Getting Started 12 | 13 | Install the Python packages in `requirements.txt`. 14 | ```bash 15 | pip install -r ./requirements.txt 16 | ``` 17 | 18 | After downloading the UD corpus from [universaldependencies.org](https://universaldependencies.org/), please run `scripts/concat_ud_data.sh` with `--add_lang_id` to generate the multilingual UD dataset with language ids. 19 | 20 | ### Training the Model 21 | 22 | Before training, make sure the dataset is downloaded and extracted into the `data` directory and the multilingual 23 | dataset is generated with `scripts/concat_ud_data.sh`. To indicate training and zero-shot languages, use `languages/in-langs.txt` and `languages/oov-langs.txt` respectively. To train the multilingual model, 24 | run the command 25 | 26 | ```bash 27 | python train.py --config config/ud/multilingual/udapter-test.json --name udapter 28 | ``` 29 | 30 | ### Predicting Universal Dependencies from a Trained Model 31 | 32 | To predict UD annotations, one can supply the path to the trained model and an input `conllu`-formatted file with a language id as the last column. To split concatenated treebanks with language id, use `scripts/split_file_by_lang.py`. For prediction: 33 | 34 | ```bash 35 | python predict.py <archive> <input.conllu> <pred.conllu> [--eval_file results.json] 36 | ``` 37 | 38 | ## Citing This Research 39 | 40 | If you use UDapter for your research, please cite this work as: 41 | 42 | ```latex 43 | @inproceedings{ustun-etal-2020-udapter, 44 | title = {{UD}apter: Language Adaptation for Truly {U}niversal {D}ependency Parsing}, 45 | author = {{\"U}st{\"u}n, Ahmet and 46 | Bisazza, Arianna and 47 | Bouma, Gosse and 48 | van Noord, Gertjan}, 49 | booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 50 | month = nov, 51 | year = {2020}, 52 | address = {Online}, 53 | publisher = {Association for Computational Linguistics}, 54 | url = {https://www.aclweb.org/anthology/2020.emnlp-main.180}, 55 | pages = {2302--2315}, 56 | abstract = {Recent advances in multilingual dependency parsing have brought the idea of a truly universal parser closer to reality. However, cross-language interference and restrained model capacity remain major obstacles. To address this, we propose a novel multilingual task adaptation approach based on contextual parameter generation and adapter modules. This approach enables to learn adapters via language embeddings while sharing model parameters across languages. It also allows for an easy but effective integration of existing linguistic typology features into the parsing network. The resulting parser, UDapter, outperforms strong monolingual and multilingual baselines on the majority of both high-resource and low-resource (zero-shot) languages, showing the success of the proposed adaptation approach.
Our in-depth analyses show that soft parameter sharing via typological features is key to this success.}, 57 | } ``` 58 | -------------------------------------------------------------------------------- /adapter_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | 4 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 5 | BertForMaskedLM, BertForNextSentencePrediction, 6 | BertForSequenceClassification, BertForMultipleChoice, 7 | BertForTokenClassification, BertForQuestionAnswering, 8 | load_tf_weights_in_bert) 9 | 10 | from .optimization import BertAdam 11 | 12 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path 13 | -------------------------------------------------------------------------------- /adapter_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models to PyTorch. " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models to PyTorch. " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models to PyTorch. " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /adapter_bert/configs/adapter-bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 119547, 13 | "adapter_initializer_range": 0.0001 14 | } -------------------------------------------------------------------------------- /adapter_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
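# A usage sketch (hypothetical paths; the flags mirror the argparse definitions at the bottom of this file): #   python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt --bert_config_file bert_config.json --pytorch_dump_path pytorch_model.bin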
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import torch 25 | import numpy as np 26 | 27 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /adapter_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except AttributeError: 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 
44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 
130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
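# Note: NamedTemporaryFile is deleted on close, so the flush/seek/copy steps below must all happen inside this with-block.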
205 | with tempfile.NamedTemporaryFile() as temp_file: 206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 207 | 208 | # GET file object 209 | if url.startswith("s3://"): 210 | s3_get(url, temp_file) 211 | else: 212 | http_get(url, temp_file) 213 | 214 | # we are copying the file before closing it, so flush to avoid truncation 215 | temp_file.flush() 216 | # shutil.copyfileobj() starts at the current position, so go to the start 217 | temp_file.seek(0) 218 | 219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 220 | with open(cache_path, 'wb') as cache_file: 221 | shutil.copyfileobj(temp_file, cache_file) 222 | 223 | logger.info("creating metadata file for %s", cache_path) 224 | meta = {'url': url, 'etag': etag} 225 | meta_path = cache_path + '.json' 226 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 227 | json.dump(meta, meta_file) 228 | 229 | logger.info("removing temp file %s", temp_file.name) 230 | 231 | return cache_path 232 | 233 | 234 | def read_set_from_file(filename): 235 | ''' 236 | Extract a de-duped collection (set) of text from a file. 237 | Expected file format is one item per line. 238 | ''' 239 | collection = set() 240 | with open(filename, 'r', encoding='utf-8') as file_: 241 | for line in file_: 242 | collection.add(line.rstrip()) 243 | return collection 244 | 245 | 246 | def get_file_extension(path, dot=True, lower=True): 247 | ext = os.path.splitext(path)[1] 248 | ext = ext if dot else ext[1:] 249 | return ext.lower() if lower else ext 250 | -------------------------------------------------------------------------------- /adapter_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + math.cos(math.pi * x))  # x is a Python float, so use math.cos rather than torch.cos 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. 
Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adam's b1. Default: 0.9 54 | b2: Adam's b2. Default: 0.999 55 | e: Adam's epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways.
139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /adapter_bert/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def to_cpu(state_dict): 5 | new_state_dict = {} 6 | for k, v in state_dict.items(): 7 | if isinstance(v, dict): 8 | new_state_dict[k] = to_cpu(v) 9 | elif isinstance(v, int): 10 | new_state_dict[k] = v 11 | elif isinstance(v, list): 12 | # May need to change this in the future 13 | assert k == "param_groups" 14 | new_state_dict[k] = v 15 | else: 16 | new_state_dict[k] = v.cpu() 17 | return new_state_dict 18 | 19 | 20 | def only_one_of(ls): 21 | return count_bool(ls) == 1 22 | 23 | 24 | def at_most_one_of(ls): 25 | return count_bool(ls) <= 1 26 | 27 | 28 | def count_bool(ls): 29 | return sum([1 if elem else 0 for elem in ls]) 30 | 31 | 32 | def truncate_seq_pair(tokens_a, tokens_b, max_length): 33 | """Truncates a sequence pair in place to the maximum length.""" 34 | 35 | # This is a simple heuristic which will always truncate the longer sequence 36 | # one token at a time. This makes more sense than truncating an equal percent 37 | # of tokens from each, since if one sequence is very short then each token 38 | # that's truncated likely contains more information than a longer sequence. 
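# Worked example: with max_length=6, len(tokens_a)=4 and len(tokens_b)=3, the loop pops one token from tokens_a, leaving 3 + 3 = 6.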
39 | while True: 40 | total_length = len(tokens_a) + len(tokens_b) 41 | if total_length <= max_length: 42 | break 43 | if len(tokens_a) > len(tokens_b): 44 | tokens_a.pop() 45 | else: 46 | tokens_b.pop() 47 | 48 | 49 | def random_sample(ls, size, replace=True): 50 | indices = np.random.choice(range(len(ls)), size=size, replace=replace) 51 | return [ls[i] for i in indices] 52 | -------------------------------------------------------------------------------- /analysis/rob.mplstyle: -------------------------------------------------------------------------------- 1 | patch.linewidth: 0.5 2 | patch.facecolor: 348ABD # blue 3 | patch.edgecolor: EEEEEE 4 | patch.antialiased: True 5 | 6 | font.size: 8.0 7 | 8 | axes.facecolor: f2f2f2 9 | axes.edgecolor: white 10 | axes.linewidth: 1 11 | axes.grid: True 12 | axes.titlesize: medium 13 | axes.labelsize: medium 14 | axes.labelcolor: 404040 15 | axes.axisbelow: True # grid/ticks are below elements (e.g., lines, text) 16 | 17 | #axes.prop_cycle: cycler('color',['1f78b4', '1f78b4']) 18 | axes.prop_cycle: cycler('color',['a6cee3', '1f78b4', '1f78b4']) 19 | #axes.prop_cycle: cycler('color',['a6cee3', 'c6dbef', '1f78b4', '1f78b4']) 20 | #axes.prop_cycle: cycler('color',['1f78b4', 'feb24c', '8c2d04']) 21 | #axes.prop_cycle: cycler('color',['33a02c', 'fee090', 'feb24c', '1f78b4']) 22 | #axes.prop_cycle: cycler('color',['b2df8a', '33a02c', '1f78b4']) 23 | #axes.prop_cycle: cycler('color',['1f78b4', 'feb24c', 'a6cee3', '1f78b4']) 24 | #axes.prop_cycle: cycler('color',['fee090', 'a6761d', '1f78b4']) 25 | #axes.prop_cycle: cycler('color',['darkgreen', 'peru', 'FFB5B8', '348ABD']) 26 | #axes.prop_cycle: cycler('color',['peru', 'darkgreen','988ED5', 'FFB5B8', '348ABD'])#, 'gold', '348ABD', 'slategrey']) 27 | #axes.prop_cycle: cycler('color',['peru', 'darkgreen', 'gold', '348ABD', 'slategrey']) 28 | 29 | 30 | text.color : 404040 31 | legend.frameon : True 32 | legend.fancybox: True 33 | legend.edgecolor : b3b3b3 34 | 35 | xtick.color: 404040 36 | xtick.direction: out 37 | xtick.labelsize: 8 38 | 39 | ytick.color: 404040 40 | ytick.direction: out 41 | ytick.labelsize: 8 42 | 43 | grid.color: white 44 | grid.linestyle: - # solid line 45 | 46 | figure.facecolor: white 47 | figure.edgecolor: 0.5 48 | 49 | legend.facecolor: white 50 | legend.framealpha:1.0 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /archive_adapters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracts an adapter weights archive from an existing model 3 | """ 4 | 5 | import logging 6 | import argparse 7 | 8 | from udapter import util 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("archive_dir", type=str, help="The directory where model.tar.gz resides") 16 | parser.add_argument("--adapter_name", default=None, type=str, help="The name for the output adapter file") 17 | 18 | args = parser.parse_args() 19 | 20 | util.archive_adapter_weights(args.archive_dir, args.adapter_name) -------------------------------------------------------------------------------- /archive_bert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracts a BERT archive from an existing model 3 | """ 4 | 5 | import logging 6 | import argparse 7 | 8 | from udapter import util 9 | 10 | 
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("archive_dir", type=str, help="The directory where model.tar.gz resides") 16 | parser.add_argument("--output_path", default=None, type=str, help="The path to output the file") 17 | 18 | args = parser.parse_args() 19 | 20 | bert_config = "config/archive/bert-base-multilingual-cased/bert_config.json" 21 | util.archive_bert_model(args.archive_dir, bert_config, args.output_path) 22 | -------------------------------------------------------------------------------- /archive_language_mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracts a language MLP weights archive from an existing model 3 | """ 4 | 5 | import logging 6 | import argparse 7 | 8 | from udapter import util 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("archive_dir", type=str, help="The directory where model.tar.gz resides") 16 | 17 | args = parser.parse_args() 18 | 19 | util.archive_language_MLP(args.archive_dir) 20 | -------------------------------------------------------------------------------- /concat_treebanks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Concatenates all treebanks together 3 | """ 4 | 5 | import os 6 | import shutil 7 | import logging 8 | import argparse 9 | 10 | from udapter import util 11 | 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("output_dir", type=str, help="The path to output the concatenated files") 19 | parser.add_argument("--dataset_dir", default="data/ud-treebanks-v2.3", type=str, 20 | help="The path containing all UD treebanks") 21 | parser.add_argument("--treebanks", default=[], type=str, nargs="+", 22 | help="Specify a list of treebanks to use; leave blank to default to all treebanks available") 23 | parser.add_argument("--add_lang_id", dest="lang_id", default=False, action="store_true", 24 | help="Add language id for each token in treebanks") 25 | 26 | args = parser.parse_args() 27 | 28 | treebanks = util.get_ud_treebank_files(args.dataset_dir, args.treebanks) 29 | train, dev, test = list(zip(*[treebanks[k] for k in treebanks])) 30 | 31 | for treebank, name in zip([train, dev, test], ["train.conllu", "dev.conllu", "test.conllu"]): 32 | with open(os.path.join(args.output_dir, name), 'w') as write: 33 | for t in treebank: 34 | if not t: 35 | continue 36 | with open(t, 'r') as read: 37 | if not args.lang_id: 38 | shutil.copyfileobj(read, write) 39 | else: 40 | for line in read: 41 | if line != '\n' and not line.startswith('#'): 42 | form = line.rstrip('\n').split('\t') 43 | #form += ['treebank=' + os.path.basename(t).split('-')[0]] 44 | form += [os.path.basename(t).split('-')[0].split('_')[0]] 45 | write.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(*form)) 46 | else: 47 | write.write(line) 48 | write.close() 49 | -------------------------------------------------------------------------------- /config/archive/bert-base-multilingual-cased/bert_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 119547 19 | } 20 | -------------------------------------------------------------------------------- /config/archive/bert-large-cased/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 28996 19 | } 20 | -------------------------------------------------------------------------------- /config/create_vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "udify_universal_dependencies", 4 | "token_indexers": { 5 | "tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": false 8 | }, 9 | "token_characters": { 10 | "type": "characters" 11 | } 12 | } 13 | }, 14 | "vocabulary": { 15 | "non_padded_namespaces": ["upos", "xpos", "feats", "lemmas", "*tags", "*labels"], 16 | // Add the special "@@UNKNOWN@@" token to handle dev/test set labels that don't appear in training 17 | "tokens_to_add": { 18 | "upos": ["@@UNKNOWN@@"], 19 | "xpos": ["@@UNKNOWN@@"], 20 | "feats": ["@@UNKNOWN@@"], 21 | "lemmas": ["@@UNKNOWN@@"], 22 | "head_tags": ["@@UNKNOWN@@"] 23 | } 24 | }, 25 | "train_file_paths": {}, 26 | "validation_file_paths": {} 27 | } -------------------------------------------------------------------------------- /config/ud/feats.json: -------------------------------------------------------------------------------- 1 | {"Abbr": ["Yes"], "AdjType": ["Attr", "Pred"], "AdpType": ["Circ", "Comadp", "Comprep", "Post", "Prep", "Preppron", "Voc"], "AdvType": ["Cau", "Deg", "Ideoph", "Loc", "Man", "Mod", "Sta", "Tim"], "Agglutination": ["Agl", "Nagl"], "Analyt": ["Yes"], "Animacy": ["Anim", "Hum", "Inan", "NHum", "Nhum"], "Animacy[gram]": ["Anim", "Inan"], "Aspect": ["Cons", "Dur", "DurPerf", "Hab", "Imp", "Inch", "Iter", "Perf", "PerfRapid", "Prog", "ProgRapid", "Prosp", "Rapid", "Res"], "Case": ["Abe", "Abl", "Abs", "Acc", "Add", "Ade", "Advb", "All", "Apr", "Ben", "Car", "Cau", "Cns", "Com", "Comp", "Con", "Dat", "Del", "Dis", "Egr", "Ela", "Equ", "Erg", "Ess", "Gen", "Ill", "Ine", "Ins", "Lat", "Loc", "Mal", "Nom", "NomAcc", "Obl", "Par", "Per", "Prl", "Sub", "Sup", "Tem", "Temp", "Ter", "Tra", "Voc"], "Clitic": ["Add", "AddGA", "AddI", "AddKige", "AddNgA", "AddVok", "Aja", "Ga", "Gak", "Gi", "Go", "Han", "Ja", "Ka", "Kaan", "Ki", "Kin", "Ko", "O", "Pa", "QstA", "S", "So", "Yes"], "Clusivity": ["Ex", "In", "Incl"], "Clusivity[obj]": 
["Ex", "In"], "Clusivity[psor]": ["Ex", "In"], "Clusivity[subj]": ["Ex", "In"], "Compound": ["Yes"], "ConjType": ["Cmpr", "Comp", "Oper", "Pred"], "Connegative": ["Yes"], "Definite": ["2", "Com", "Cons", "Def", "Ind", "Spec"], "Degree": ["Abs", "Cmp", "Dim", "Equ", "Pos", "Sup"], "Deixis": ["Med", "Prox", "Remt"], "DeixisRef": ["1", "2"], "Derivation": ["A", "AdvstO", "Al", "An", "CompMod", "DimKE", "DimNe", "Dimin", "F", "GenAttr", "I", "Ig", "Igdyrji", "Igmoz", "Ik", "Ill", "InchL", "Ine", "Inen", "J", "Ja", "Kaj", "Lainen", "Llinen", "Minen", "Njems", "NomAg", "OkshnOms", "Oma", "Omka", "Omon", "Ord", "OvOms", "Ovt", "OvtOms", "Ozj", "Poss", "PrcPrt1", "PrivMod", "PronGak", "ProprietiveMod", "Shka", "Sjuro", "Sti", "Tar", "TempMod", "Tom", "Ton", "Ttain", "U", "VCar", "VGen", "VerbYks", "Vnoun", "Voc", "Vs", "Wife", "Y", "Ysj", "Zhyk"], "Dialect": ["Connaught", "Munster", "Ulster"], "Distance": ["Dist", "Med", "Prox"], "Echo": ["Ech", "Rdp"], "Emphatic": ["Yes"], "Evident": ["Fh", "Nfh"], "FocusType": ["Compl", "Subj", "Verb"], "Foreign": ["No", "Yes"], "Form": ["Adn", "Aux", "Compl", "Ecl", "Emp", "HPref", "Len", "VF"], "Gender": ["Com", "Fem", "Masc", "Neut", "Unsp"], "Gender[dat]": ["Masc"], "Gender[erg]": ["Fem", "Masc"], "Gender[psor]": ["Fem", "Masc", "Neut"], "HebBinyan": ["HIFIL", "HITPAEL", "HUFAL", "NIFAL", "PAAL", "PIEL", "PUAL"], "HebExistential": ["True"], "HebSource": ["ConvUncertainHead", "ConvUncertainLabel"], "Hyph": ["Yes"], "InfForm": ["1", "2", "3", "Dict", "Incp"], "Mood": ["Cnd", "CndCnj", "CndGen", "CndGenPot", "CndPot", "Cnj", "Cond", "Des", "DesPot", "Gen", "GenNec", "GenNecPot", "GenPot", "GenPotPot", "Imp", "ImpPot", "Ind", "Int", "Inter", "Jus", "Nec", "NecPot", "Opt", "Pot", "PotPot", "Prec", "Proh", "Prs", "Prsc", "Qot", "Sub", "Vol"], "Morph": ["VFin", "VInf", "VPar"], "Mutation": ["AM", "NM", "SM"], "NameType": ["Com", "Geo", "Giv", "Hom", "Nat", "Oth", "Pat", "Pro", "Prs", "Sur"], "NegationType": ["Contrastive"], "NounClass": ["Wol1", "Wol10", "Wol11", "Wol12", "Wol2", "Wol3", "Wol4", "Wol5", "Wol6", "Wol7", "Wol8", "Wol9"], "NounForm": ["Depr"], "NounType": ["Class", "Clf", "Het", "NotSlender", "Slender", "Strong", "Weak"], "NumForm": ["Armenian", "Combi", "Cyril", "Digit", "Letter", "Roman", "Word"], "NumType": ["Appr", "Card", "Coll", "Dist", "Frac", "Mult", "MultDist", "Ord", "OrdMult", "OrdinalSets", "Range", "Sets"], "NumValue": ["1", "2", "3"], "Number": ["Adnum", "Assoc", "Coll", "Count", "Dual", "Pauc", "Plur", "Ptan", "Sing", "Unsp"], "Number[abs]": ["Plur", "Sing"], "Number[dat]": ["Plur", "Sing"], "Number[erg]": ["Plur", "Sing"], "Number[obj]": ["Dual", "Plur", "Sing"], "Number[psed]": ["None", "Sing"], "Number[psor]": ["Dual", "None", "Plur", "Sing"], "Number[subj]": ["Plur", "Sing"], "Orth": ["Alt"], "PartForm": ["Agt", "Neg", "Past", "Pres"], "PartType": ["Ad", "Cmpl", "Comp", "Conseq", "Cop", "Deg", "Disc", "Emp", "Gen", "Inf", "Int", "Mod", "Neg", "Num", "Pat", "Res", "Sub", "Sup", "Vb", "Vbp", "Voc"], "Person": ["0", "1", "2", "3", "4", "Auto"], "Person[abs]": ["1", "2", "3"], "Person[dat]": ["1", "2", "3"], "Person[erg]": ["1", "2", "3"], "Person[obj]": ["1", "2", "3"], "Person[psor]": ["1", "2", "3", "None"], "Person[sdat]": ["3"], "Person[subj]": ["1", "2", "3"], "Polarity": ["Int", "Neg", "Pos"], "Polite": ["Depr", "Form", "Infm"], "Polite[abs]": ["Infm"], "Polite[dat]": ["Infm"], "Polite[erg]": ["Infm"], "Position": ["Postnom", "Prenom"], "Poss": ["Yes"], "Prefix": ["Yes"], "PrepCase": ["Npr", "Pre"], "PrepForm": ["Cmpd"], 
"PronType": ["Add", "Art", "Coll", "Contrastive", "Dem", "Emp", "Exc", "Ind", "Int", "Neg", "Ord", "Prs", "Qnt", "Rcp", "Ref", "Refl", "Rel", "Tot"], "Pun": ["No", "Yes"], "PunctSide": ["Fin", "Ini"], "PunctType": ["Brck", "Bull", "Colo", "Comm", "Dash", "Elip", "Excl", "Hyph", "Ndash", "Peri", "Qest", "Quot", "Semi", "Slsh"], "Reflex": ["No", "Yes"], "Register": ["Form"], "Relative": ["Rel"], "Strength": ["Strong", "Weak"], "Style": ["Arch", "Coll", "Expr", "Form", "Ped", "Rare", "Slng", "Vrnc", "Vulg"], "SubGender": ["Masc1", "Masc2", "Masc3"], "Subcat": ["Ditran", "Int", "IntInd", "Intr", "Prep", "Tran"], "Tense": ["Aor", "Fut", "FutPlan", "Imp", "Past", "PastIter", "PastSimp", "Pqp", "Pres", "PresHab", "Prosp", "Prt", "Prt1", "Prt2"], "Typo": ["Yes"], "Uninflect": ["Yes"], "Valency": ["1", "2"], "Variant": ["Bound", "Full", "Long", "Short", "Uncontr"], "VerbForm": ["Conv", "Cop", "Cov", "Coverb", "Fin", "Gdv", "Ger", "Inf", "Part", "PartFut", "PartPad", "PartPast", "PartPres", "PartPus", "PartRes", "Post", "Prov", "Ser", "Stem", "Sup", "Vnoun"], "VerbType": ["Aux", "Cop", "Mod", "Pas", "Quasi"], "Voice": ["Act", "Auto", "Cau", "CauCau", "CauCauPass", "CauPass", "CauPassRcp", "CauRcp", "Coop", "Mid", "Necess", "Pass", "PassPass", "PassRcp", "PassRfl", "Rcp", "Rfl", "Trans"], "Xtra": ["Junk"]} -------------------------------------------------------------------------------- /config/ud/multilingual/adapter-test.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "use_lang_ids": true, 5 | "use_separate_feats": true, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | }, 11 | "bert": { 12 | "type": "udapter-bert", 13 | "in_lang_list": "languages/in-langs.txt", 14 | "oov_lang_list": "languages/oov-langs.txt", 15 | "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt", 16 | "do_lowercase": false, 17 | "use_starting_offsets": true 18 | } 19 | } 20 | }, 21 | "train_data_path": "data/test/sample.feat.conllu", 22 | "validation_data_path": "data/test/sample.feat.conllu", 23 | "test_data_path": "data/test/sample.feat.conllu", 24 | "vocabulary": { 25 | "directory_path": "data/vocab/sample.feat/vocabulary" 26 | }, 27 | "model": { 28 | "type": "udapter_model", 29 | "pretrained_model": "bert-base-multilingual-cased", 30 | "word_dropout": 0.2, 31 | "layer_dropout": 0.1, 32 | "tasks": ["deps"], 33 | "text_field_embedder": { 34 | "type": "udify_embedder", 35 | "dropout": 0.5, 36 | "allow_unmatched_keys": true, 37 | "embedder_to_indexer_map": { 38 | "bert": ["bert", "bert-offsets", "bert-lang-ids"] 39 | }, 40 | "token_embedders": { 41 | "bert": { 42 | "type": "udapter-bert", 43 | "pretrained_model": "bert-base-multilingual-cased", 44 | "bert_config_file": "adapter_bert/configs/adapter-bert.json", 45 | "bert_requires_grad": false, 46 | "adapters_requires_grad": true, 47 | "dropout": 0.15, 48 | "layer_dropout": 0.1, 49 | "combine_layers": "last", 50 | "use_adapter": true, 51 | "adapter_size": 1024, 52 | "adapter_prediction": "single", 53 | "num_adapters": 1, 54 | "use_language_emb": false, //to use only adapter without parameter-generator/language-embeddings 55 | "in_languages": "languages/in-langs.txt", 56 | "oov_languages": "languages/oov-langs.txt", 57 | "language_emb_size": 32, 58 | "language_emb_dropout": 0.1, 59 | "language_one_hot": false, 60 | "num_language_features": 289, 61 | "language_drop_rate": 0.2, 62 | "language_features": 
"syntax_knn+phonology_knn+inventory_knn", 63 | "num_languages": 5, 64 | "language_emb_from_features": true 65 | } 66 | } 67 | }, 68 | "encoder": { 69 | "type": "pass_through", 70 | "input_dim": 768 71 | }, 72 | "decoders": { 73 | "upos": { 74 | "type": "udapter_tag_decoder", 75 | "encoder": { 76 | "type": "pass_through", 77 | "input_dim": 768 78 | }, 79 | "png_params_dim": 32 80 | }, 81 | "feats": { 82 | "type": "udapter_tag_decoder", 83 | "png_params_dim": 32, 84 | "encoder": { 85 | "type": "pass_through", 86 | "input_dim": 768 87 | }, 88 | "adaptive": false, 89 | "features": [ 90 | "Aspect", 91 | "Gender[psor]" 92 | ] 93 | }, 94 | "lemmas": { 95 | "encoder": { 96 | "type": "pass_through", 97 | "input_dim": 768 98 | }, 99 | "adaptive": false 100 | }, 101 | "deps": { 102 | "type": "udapter_dependency_decoder", 103 | "png_params_dim": 32, 104 | "tag_representation_dim": 256, 105 | "arc_representation_dim": 768, 106 | "encoder": { 107 | "type": "pass_through", 108 | "input_dim": 768 109 | } 110 | } 111 | } 112 | }, 113 | "iterator": { 114 | "type": "lang-bucket", 115 | "sorting_keys": [["tokens", "num_tokens"]], 116 | "biggest_batch_first": true, 117 | "batch_size": 32, 118 | "maximum_samples_per_batch": ["num_tokens", 32 * 100] 119 | }, 120 | "trainer": { 121 | "num_epochs": 80, 122 | "patience": 80, 123 | "num_serialized_models_to_keep": 1, 124 | "should_log_learning_rate": true, 125 | "should_log_parameter_statistics": true, 126 | "summary_interval": 100, 127 | "optimizer": { 128 | "type": "bert_adam", 129 | "b1": 0.9, 130 | "b2": 0.99, 131 | "weight_decay": 0.01, 132 | "lr": 0.001, 133 | "parameter_groups": [ 134 | [["^text_field_embedder.*.bert_model.embeddings", 135 | "^text_field_embedder.*.bert_model.*.attention.self", 136 | "^text_field_embedder.*.bert_model.*.intermediate", 137 | "^text_field_embedder.*.bert_model.*.output.LayerNorm", 138 | "^text_field_embedder.*.bert_model.*.output.dense"], {}], 139 | [["^text_field_embedder.*.adapter", 140 | "^text_field_embedder.*.language_embedder", 141 | "^text_field_embedder.*._scalar_mix", 142 | "^text_field_embedder.*.pooler", 143 | "^scalar_mix", 144 | "^decoders", 145 | "^shared_encoder"], {}] 146 | ] 147 | }, 148 | "learning_rate_scheduler": { 149 | "type": "ulmfit_sqrt", 150 | "model_size": 1, 151 | "warmup_steps": 3000, 152 | "start_step": 3000, 153 | "factor": 5.0, 154 | "gradual_unfreezing": false, 155 | "discriminative_fine_tuning": false, 156 | "decay_factor": 0.05 157 | } 158 | }, 159 | "udify_replace": [ 160 | "dataset_reader.token_indexers", 161 | "model.text_field_embedder", 162 | "model.encoder", 163 | "model.decoders.xpos", 164 | "model.decoders.deps.encoder", 165 | "model.decoders.upos.encoder", 166 | "model.decoders.feats.encoder", 167 | "model.decoders.lemmas.encoder", 168 | "trainer.learning_rate_scheduler", 169 | "trainer.optimizer", 170 | "vocabulary.directory_path" 171 | ] 172 | } 173 | -------------------------------------------------------------------------------- /config/ud/multilingual/udapter-test.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "use_lang_ids": true, 5 | "use_separate_feats": false, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | }, 11 | "bert": { 12 | "type": "udapter-bert", 13 | "in_lang_list": "languages/in-langs.txt", 14 | "oov_lang_list": "languages/oov-langs.txt", 15 | "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt", 16 | 
"do_lowercase": false, 17 | "use_starting_offsets": true 18 | } 19 | } 20 | }, 21 | "train_data_path": "data/test/train.conllu", 22 | "validation_data_path": "data/test/dev.conllu", 23 | "test_data_path": "data/test/test.conllu", 24 | "vocabulary": { 25 | "directory_path": "data/vocab/test/vocabulary" 26 | }, 27 | "model": { 28 | "type": "udapter_model", 29 | "pretrained_model": "bert-base-multilingual-cased", 30 | "word_dropout": 0.2, 31 | "layer_dropout": 0.1, 32 | "tasks": ["deps"], 33 | "text_field_embedder": { 34 | "type": "udify_embedder", 35 | "dropout": 0.5, 36 | "allow_unmatched_keys": true, 37 | "embedder_to_indexer_map": { 38 | "bert": ["bert", "bert-offsets", "bert-lang-ids"] 39 | }, 40 | "token_embedders": { 41 | "bert": { 42 | "type": "udapter-bert", 43 | "pretrained_model": "bert-base-multilingual-cased", 44 | "bert_config_file": "adapter_bert/configs/adapter-bert.json", 45 | "bert_requires_grad": false, 46 | "adapters_requires_grad": true, 47 | "dropout": 0.15, 48 | "layer_dropout": 0.1, 49 | "combine_layers": "last", 50 | "use_adapter": true, 51 | "adapter_size": 256, 52 | "adapter_prediction": "single", 53 | "num_adapters": 1, 54 | "use_language_emb": true, 55 | "in_languages": "languages/in-langs.txt", 56 | "oov_languages": "languages/oov-langs.txt", 57 | "language_emb_size": 32, 58 | "language_emb_dropout": 0.1, 59 | "language_one_hot": false, 60 | "num_language_features": 289, 61 | "language_drop_rate": 0.2, 62 | "language_features": "syntax_knn+phonology_knn+inventory_knn", 63 | "num_languages": 5, 64 | "language_emb_from_features": true 65 | } 66 | } 67 | }, 68 | "encoder": { 69 | "type": "pass_through", 70 | "input_dim": 768 71 | }, 72 | "decoders": { 73 | "upos": { 74 | "type": "udapter_tag_decoder", 75 | "encoder": { 76 | "type": "pass_through", 77 | "input_dim": 768 78 | }, 79 | "png_params_dim": 32 80 | }, 81 | "feats": { 82 | "type": "udapter_tag_decoder", 83 | "png_params_dim": 32, 84 | "encoder": { 85 | "type": "pass_through", 86 | "input_dim": 768 87 | }, 88 | "adaptive": false, 89 | "features": [ 90 | "Aspect", 91 | "Gender[psor]" 92 | ] 93 | }, 94 | "lemmas": { 95 | "encoder": { 96 | "type": "pass_through", 97 | "input_dim": 768 98 | }, 99 | "adaptive": false 100 | }, 101 | "deps": { 102 | "type": "udapter_dependency_decoder", 103 | "png_params_dim": 32, 104 | "tag_representation_dim": 256, 105 | "arc_representation_dim": 768, 106 | "encoder": { 107 | "type": "pass_through", 108 | "input_dim": 768 109 | } 110 | } 111 | } 112 | }, 113 | "iterator": { 114 | "type": "lang-bucket", 115 | "sorting_keys": [["tokens", "num_tokens"]], 116 | "biggest_batch_first": true, 117 | "batch_size": 32, 118 | "maximum_samples_per_batch": ["num_tokens", 32 * 100] 119 | }, 120 | "trainer": { 121 | "num_epochs": 80, 122 | "patience": 80, 123 | "num_serialized_models_to_keep": 1, 124 | "should_log_learning_rate": true, 125 | "should_log_parameter_statistics": true, 126 | "summary_interval": 100, 127 | "optimizer": { 128 | "type": "bert_adam", 129 | "b1": 0.9, 130 | "b2": 0.99, 131 | "weight_decay": 0.01, 132 | "lr": 0.001, 133 | "parameter_groups": [ 134 | [["^text_field_embedder.*.bert_model.embeddings", 135 | "^text_field_embedder.*.bert_model.*.attention.self", 136 | "^text_field_embedder.*.bert_model.*.intermediate", 137 | "^text_field_embedder.*.bert_model.*.output.LayerNorm", 138 | "^text_field_embedder.*.bert_model.*.output.dense"], {}], 139 | [["^text_field_embedder.*.adapter", 140 | "^text_field_embedder.*.language_embedder", 141 | "^text_field_embedder.*._scalar_mix", 
142 | "^text_field_embedder.*.pooler", 143 | "^scalar_mix", 144 | "^decoders", 145 | "^shared_encoder"], {}] 146 | ] 147 | }, 148 | "learning_rate_scheduler": { 149 | "type": "ulmfit_sqrt", 150 | "model_size": 1, 151 | "warmup_steps": 3000, 152 | "start_step": 3000, 153 | "factor": 5.0, 154 | "gradual_unfreezing": false, 155 | "discriminative_fine_tuning": false, 156 | "decay_factor": 0.05 157 | } 158 | }, 159 | "udify_replace": [ 160 | "dataset_reader.token_indexers", 161 | "model.text_field_embedder", 162 | "model.encoder", 163 | "model.decoders.xpos", 164 | "model.decoders.deps.encoder", 165 | "model.decoders.upos.encoder", 166 | "model.decoders.feats.encoder", 167 | "model.decoders.lemmas.encoder", 168 | "trainer.learning_rate_scheduler", 169 | "trainer.optimizer", 170 | "vocabulary.directory_path" 171 | ] 172 | } 173 | -------------------------------------------------------------------------------- /config/ud/multilingual/udify-test.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "use_separate_feats": true, 5 | "token_indexers": { 6 | "tokens": { 7 | "type": "single_id", 8 | "lowercase_tokens": true 9 | }, 10 | "bert": { 11 | "type": "udify-bert-pretrained", 12 | "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt", 13 | "do_lowercase": false, 14 | "use_starting_offsets": true 15 | } 16 | } 17 | }, 18 | "train_data_path": "data/test/sample.feat.conllu", 19 | "validation_data_path": "data/test/sample.feat.conllu", 20 | "test_data_path": "data/test/sample.feat.conllu", 21 | "vocabulary": { 22 | "directory_path": "data/vocab/sample.feat/vocabulary" 23 | }, 24 | "model": { 25 | "type": "udify_model", 26 | "pretrained_model": "bert-base-multilingual-cased", 27 | "word_dropout": 0.2, 28 | "layer_dropout": 0.1, 29 | "tasks": ["deps"], 30 | "text_field_embedder": { 31 | "type": "udify_embedder", 32 | "dropout": 0.5, 33 | "allow_unmatched_keys": true, 34 | "embedder_to_indexer_map": { 35 | "bert": ["bert", "bert-offsets"] 36 | }, 37 | "token_embedders": { 38 | "bert": { 39 | "type": "udify-bert-pretrained", 40 | "pretrained_model": "bert-base-multilingual-cased", 41 | "requires_grad": true, 42 | "dropout": 0.2, 43 | "layer_dropout": 0.1, 44 | "combine_layers": "last" 45 | } 46 | } 47 | }, 48 | "encoder": { 49 | "type": "pass_through", 50 | "input_dim": 768 51 | }, 52 | "decoders": { 53 | "upos": { 54 | "encoder": { 55 | "type": "pass_through", 56 | "input_dim": 768 57 | }, 58 | "adaptive": false 59 | }, 60 | "feats": { 61 | "encoder": { 62 | "type": "pass_through", 63 | "input_dim": 768 64 | }, 65 | "adaptive": false, 66 | "features": [ 67 | "Aspect", 68 | "Gender[psor]" 69 | ] 70 | }, 71 | "lemmas": { 72 | "encoder": { 73 | "type": "pass_through", 74 | "input_dim": 768 75 | }, 76 | "adaptive": false 77 | }, 78 | "deps": { 79 | "tag_representation_dim": 256, 80 | "arc_representation_dim": 768, 81 | "encoder": { 82 | "type": "pass_through", 83 | "input_dim": 768 84 | } 85 | } 86 | } 87 | }, 88 | "iterator": { 89 | "batch_size": 32, 90 | "maximum_samples_per_batch": ["num_tokens", 32 * 100] 91 | }, 92 | "trainer": { 93 | "num_epochs": 80, 94 | "patience": 80, 95 | "num_serialized_models_to_keep": 1, 96 | "should_log_learning_rate": true, 97 | "should_log_parameter_statistics": true, 98 | "summary_interval": 100, 99 | "optimizer": { 100 | "type": "bert_adam", 101 | "b1": 0.9, 102 | "b2": 0.99, 103 | "weight_decay": 0.01, 104 | "lr": 0.001, 105 | "parameter_groups": [ 106 | 
[["^text_field_embedder.*.bert_model.embeddings", 107 | "^text_field_embedder.*.bert_model.*.attention.self", 108 | "^text_field_embedder.*.bert_model.*.intermediate", 109 | "^text_field_embedder.*.bert_model.*.output.LayerNorm", 110 | "^text_field_embedder.*.bert_model.*.output.dense"], {}], 111 | [["^text_field_embedder.*.adapter", 112 | "^text_field_embedder.*.language_embedder", 113 | "^text_field_embedder.*._scalar_mix", 114 | "^text_field_embedder.*.pooler", 115 | "^scalar_mix", 116 | "^decoders", 117 | "^shared_encoder"], {}] 118 | ] 119 | }, 120 | "learning_rate_scheduler": { 121 | "type": "ulmfit_sqrt", 122 | "model_size": 1, 123 | "warmup_steps": 3000, 124 | "start_step": 3000, 125 | "factor": 5.0, 126 | "gradual_unfreezing": true, 127 | "discriminative_fine_tuning": true, 128 | "decay_factor": 0.05 129 | } 130 | }, 131 | "udify_replace": [ 132 | "dataset_reader.token_indexers", 133 | "model.text_field_embedder", 134 | "model.encoder", 135 | "model.decoders.xpos", 136 | "model.decoders.deps.encoder", 137 | "model.decoders.upos.encoder", 138 | "model.decoders.feats.encoder", 139 | "model.decoders.lemmas.encoder", 140 | "trainer.learning_rate_scheduler", 141 | "trainer.optimizer" 142 | ] 143 | } 144 | -------------------------------------------------------------------------------- /config/udify_base.json: -------------------------------------------------------------------------------- 1 | // The base configuration for UD parsing 2 | { 3 | "dataset_reader": { 4 | "type": "udify_universal_dependencies", 5 | "lazy": false, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | }, 11 | "token_characters": { 12 | "type": "characters", 13 | "min_padding_length": 1 14 | } 15 | } 16 | }, 17 | "vocabulary": { 18 | "directory_path": "data/vocab/UD_English-EWT/vocabulary", 19 | "non_padded_namespaces": ["upos", "xpos", "feats", "lemmas", "*tags", "*labels"] 20 | }, 21 | "train_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu", 22 | "validation_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu", 23 | "test_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-test.conllu", 24 | "evaluate_on_test": true, 25 | "model": { 26 | local global_dropout = 0.5, 27 | local label_smoothing = 0.03, 28 | 29 | local word_embedding_dim = 512, 30 | local char_embedding_dim = 256, 31 | local hidden_dim = 512, 32 | 33 | local shared_layers = 2, 34 | local task_layers = 1, 35 | 36 | "type": "udify_model", 37 | "tasks": ["upos", "xpos", "feats", "lemmas", "deps"], 38 | "dropout": global_dropout, 39 | "word_dropout": 0.2, 40 | "text_field_embedder": { 41 | "type": "udify_embedder", 42 | "sum_embeddings": ["tokens", "token_characters"], 43 | "dropout": global_dropout, 44 | "token_embedders": { 45 | "tokens": { 46 | "type": "embedding", 47 | "embedding_dim": word_embedding_dim 48 | }, 49 | "token_characters": { 50 | "type": "udify_character_encoding", 51 | "dropout": global_dropout, 52 | "embedding": { 53 | "embedding_dim": char_embedding_dim 54 | }, 55 | "encoder": { 56 | "type": "lstm", 57 | "input_size": char_embedding_dim, 58 | "hidden_size": char_embedding_dim, 59 | "num_layers": task_layers, 60 | "dropout": 0.0, 61 | "bidirectional": true 62 | } 63 | } 64 | } 65 | }, 66 | "encoder": { 67 | "type": "udify_residual_rnn", 68 | "input_size": hidden_dim, 69 | "hidden_size": hidden_dim, 70 | "num_layers": shared_layers, 71 | "dropout": global_dropout 72 | }, 73 | "decoders": { 74 | "upos": { 75 | "type": 
"udify_tag_decoder", 76 | "task": "upos", 77 | "encoder": { 78 | "type": "udify_residual_rnn", 79 | "input_size": hidden_dim, 80 | "hidden_size": hidden_dim, 81 | "num_layers": task_layers, 82 | "dropout": global_dropout 83 | }, 84 | "label_smoothing": label_smoothing, 85 | "dropout": global_dropout 86 | }, 87 | "xpos": { 88 | "type": "udify_tag_decoder", 89 | "task": "xpos", 90 | "encoder": { 91 | "type": "udify_residual_rnn", 92 | "input_size": hidden_dim, 93 | "hidden_size": hidden_dim, 94 | "num_layers": task_layers, 95 | "dropout": global_dropout 96 | }, 97 | "label_smoothing": label_smoothing, 98 | "dropout": global_dropout 99 | }, 100 | "feats": { 101 | "type": "udify_tag_decoder", 102 | "task": "feats", 103 | "encoder": { 104 | "type": "udify_residual_rnn", 105 | "input_size": hidden_dim, 106 | "hidden_size": hidden_dim, 107 | "num_layers": task_layers, 108 | "dropout": global_dropout 109 | }, 110 | "label_smoothing": label_smoothing, 111 | "dropout": global_dropout 112 | }, 113 | "lemmas": { 114 | "type": "udify_tag_decoder", 115 | "task": "lemmas", 116 | "encoder": { 117 | "type": "udify_residual_rnn", 118 | "input_size": hidden_dim, 119 | "hidden_size": hidden_dim, 120 | "num_layers": task_layers, 121 | "dropout": global_dropout 122 | }, 123 | "label_smoothing": label_smoothing, 124 | "dropout": global_dropout 125 | }, 126 | "deps": { 127 | "type": "udify_dependency_decoder", 128 | "pos_embed_dim": null, 129 | "tag_representation_dim": 128, 130 | "arc_representation_dim": 512, 131 | "dropout": global_dropout, 132 | "encoder": { 133 | "type": "udify_residual_rnn", 134 | "input_size": hidden_dim, 135 | "hidden_size": hidden_dim, 136 | "num_layers": task_layers, 137 | "dropout": global_dropout, 138 | "residual": false 139 | } 140 | } 141 | } 142 | }, 143 | "iterator": { 144 | "type": "bucket", 145 | "batch_size": 32, 146 | "sorting_keys": [["tokens", "num_tokens"]], 147 | "biggest_batch_first": true 148 | // "track_epoch": true 149 | }, 150 | "trainer": { 151 | "optimizer": { 152 | "type": "adam", 153 | "lr": 4e-3, 154 | "betas": [0.9, 0.99] 155 | }, 156 | "learning_rate_scheduler": { 157 | "type": "multi_step", 158 | "milestones": [0, 1, 15, 20, 30, 40], 159 | "gamma": 0.5 160 | }, 161 | "num_epochs": 50, 162 | "patience": 50, 163 | "validation_metric": "+.run/.sum", 164 | "should_log_learning_rate": false, 165 | "should_log_parameter_statistics": false, 166 | "summary_interval": 500, 167 | "num_serialized_models_to_keep": 1, 168 | "grad_norm": 5.0, 169 | "grad_clipping": 10.0, 170 | "cuda_device": 0 171 | }, 172 | // "random_seed": 13370, 173 | // "numpy_seed": 1337, 174 | // "pytorch_seed": 133 175 | } 176 | -------------------------------------------------------------------------------- /create_vocabs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates vocab files for all treebanks in the given directory 3 | """ 4 | 5 | import os 6 | import json 7 | import logging 8 | import argparse 9 | 10 | from allennlp.commands.make_vocab import make_vocab_from_params 11 | from allennlp.common import Params 12 | from allennlp.common.util import import_submodules 13 | 14 | from udapter import util 15 | 16 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 17 | level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--dataset_dir", default="data/ud-treebanks-v2.3", type=str, 22 | help="The path containing all UD treebanks") 23 | 
parser.add_argument("--output_dir", default="data/vocab", type=str, help="The path to save all vocabulary files") 24 | parser.add_argument("--treebanks", default=[], type=str, nargs="+", 25 | help="Specify a list of treebanks to use; leave blank to default to all treebanks available") 26 | parser.add_argument("--params_file", default=None, type=str, help="The path to the vocab params") 27 | args = parser.parse_args() 28 | 29 | import_submodules("udify") 30 | 31 | params_file = util.VOCAB_CONFIG_PATH if not args.params_file else args.params_file 32 | 33 | treebanks = sorted(util.get_ud_treebank_files(args.dataset_dir, args.treebanks).items()) 34 | for treebank, (train_file, dev_file, test_file) in treebanks: 35 | logger.info(f"Creating vocabulary for treebank {treebank}") 36 | 37 | if not train_file: 38 | logger.info(f"No training data for {treebank}, skipping") 39 | continue 40 | 41 | overrides = json.dumps({ 42 | "train_data_path": train_file, 43 | "validation_data_path": dev_file, 44 | "test_data_path": test_file 45 | }) 46 | params = Params.from_file(params_file, overrides) 47 | output_file = os.path.join(args.output_dir, treebank) 48 | 49 | make_vocab_from_params(params, output_file) 50 | -------------------------------------------------------------------------------- /docs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmetustun/udapter/b5438df4a38778f964942200beb3aded15841d65/docs/model.png -------------------------------------------------------------------------------- /find_missing_features.py: -------------------------------------------------------------------------------- 1 | import lang2vec.lang2vec as l2v 2 | 3 | FEATURES = 'syntax_average+phonology_average+inventory_average' 4 | LANGUAGES = 'sv zh ja it ko eu en fi tr hi ru ar he akk aii bho krl koi kpv olo mr mdf sa tl wbp yo gsw am bm be br bxr yue myv fo kk kmr gun pcm ta te hsb cy' 5 | 6 | features_avg = l2v.get_features(LANGUAGES, FEATURES) 7 | 8 | missing_features = open('missing_features.txt', 'w', encoding='utf-8') 9 | for lang, feat in features_avg.items(): 10 | nof = feat.count('--') 11 | missing_features.write('{} : {}\n'.format(lang, nof)) 12 | 13 | missing_features.close() -------------------------------------------------------------------------------- /languages/in-langs.txt: -------------------------------------------------------------------------------- 1 | sv 2 | zh 3 | ja 4 | it 5 | ko 6 | eu 7 | en 8 | fi 9 | tr 10 | hi 11 | ru 12 | ar 13 | he -------------------------------------------------------------------------------- /languages/language_codes.txt: -------------------------------------------------------------------------------- 1 | sv: Swedish 2 | zh: Chinese 3 | ja: Japanese 4 | it: Italian 5 | ko: Korean 6 | eu: Basque 7 | en: English 8 | fi: Finnish 9 | tr: Turkish 10 | hi: Hindi 11 | ru: Russian 12 | ar: Arabic 13 | he: Hebrew 14 | akk: Akkadian 15 | aii: Assyrian 16 | bho: Bhojpuri 17 | krl: Karelian 18 | koi: K.Permyak 19 | kpv: K.Zyrian 20 | olo: Livvi 21 | mr: Marathi 22 | mdf: Moksha 23 | sa: Sanskrit 24 | sms: Skolt Sami 25 | tl: Tagalog 26 | wbp: Warlpiri 27 | yo: Yoruba 28 | gsw: S.German 29 | am: Amharic 30 | bm: Bambara 31 | be: Belarusian 32 | br: Breton 33 | bxr: Buryat 34 | yue: Cantonese 35 | myv: Erzya 36 | fo: Faroese 37 | kk: Kazakh 38 | kmr: Kurmanji 39 | gun: M.Guarani 40 | pcm: Naija 41 | ta: Tamil 42 | te: Telugu 43 | hsb: U.Sorbian 44 | cy: Welsh 
-------------------------------------------------------------------------------- /languages/letter_codes.json: -------------------------------------------------------------------------------- 1 | {"akk": "akk", "aii": "aii", "bho": "bho", "krl":"krl", "koi": "koi", "kpv": "kpv", "olo":"olo", "mr": "mar", "mdf":"mdf", "sa":"san", "sms":"sms", "tl":"tgl", "wbp": "wbp", "yo":"yor", "gsw":"gsw", "am":"amh", "bm":"bam", "be":"bel", "br":"bre", "hsb": "sou", "bxr": "bxm", "yue":"yue", "myv":"myv", "kmr":"kmr", "gun":"gun", "pcm":"pcm", "ta":"tam", "te":"tel", "cy":"cym", "am": "amh", "bs": "bos", "vi": "vie", "wa": "wln", "eu": "eus", "so": "som", "el": "ell", "aa": "aar", "or": "ori", "sm": "smo", "gn": "grn", "mi": "mri", "pi": "pli", "ps": "pus", "ms": "msa", "sa": "san", "ko": "kor", "sd": "snd", "hz": "her", "ks": "kas", "fo": "fao", "iu": "iku", "tg": "tgk", "dz": "dzo", "ar": "ara", "fa": "fas", "es": "spa", "my": "mya", "mg": "mlg", "st": "sot", "gu": "guj", "uk": "ukr", "lv": "lav", "to": "ton", "nv": "nav", "kl": "kal", "ka": "kat", "yi": "yid", "pl": "pol", "ht": "hat", "lu": "lub", "fr": "fra", "ia": "ina", "lt": "lit", "om": "orm", "qu": "que", "no": "nor", "sr": "srp", "br": "bre", "rm": "roh", "io": "ido", "gl": "glg", "nb": "nob", "ng": "ndo", "ts": "tso", "nr": "nbl", "ee": "ewe", "bo": "bod", "mt": "mlt", "ta": "tam", "et": "est", "yo": "yor", "tw": "twi", "sl": "slv", "su": "sun", "gv": "glv", "lo": "lao", "af": "afr", "sg": "sag", "sv": "swe", "ne": "nep", "ie": "ile", "bm": "bam", "sc": "srd", "sw": "swa", "nn": "nno", "ho": "hmo", "ak": "aka", "ab": "abk", "ti": "tir", "fy": "fry", "cr": "cre", "sh": "hbs", "ny": "nya", "uz": "uzb", "as": "asm", "ky": "kir", "av": "ava", "ig": "ibo", "zh": "zho", "tr": "tur", "hu": "hun", "pt": "por", "fj": "fij", "hr": "hrv", "it": "ita", "te": "tel", "rw": "kin", "kk": "kaz", "hy": "hye", "wo": "wol", "jv": "jav", "oc": "oci", "kn": "kan", "cu": "chu", "ln": "lin", "ha": "hau", "ru": "rus", "pa": "pan", "cv": "chv", "ss": "ssw", "ki": "kik", "ga": "gle", "dv": "div", "vo": "vol", "lb": "ltz", "ce": "che", "oj": "oji", "th": "tha", "ff": "ful", "kv": "kom", "tk": "tuk", "kr": "kau", "bg": "bul", "tt": "tat", "ml": "mal", "tl": "tgl", "mr": "mar", "hi": "hin", "ku": "kur", "na": "nau", "li": "lim", "nl": "nld", "nd": "nde", "os": "oss", "la": "lat", "bn": "ben", "kw": "cor", "id": "ind", "ay": "aym", "xh": "xho", "zu": "zul", "cs": "ces", "sn": "sna", "de": "deu", "co": "cos", "sk": "slk", "ug": "uig", "rn": "run", "he": "heb", "ba": "bak", "ro": "ron", "be": "bel", "ca": "cat", "kj": "kua", "ja": "jpn", "ch": "cha", "ik": "ipk", "bi": "bis", "an": "arg", "cy": "cym", "tn": "tsn", "mk": "mkd", "ve": "ven", "eo": "epo", "kg": "kon", "km": "khm", "se": "sme", "ii": "iii", "az": "aze", "en": "eng", "ur": "urd", "za": "zha", "is": "isl", "mh": "mah", "mn": "mon", "sq": "sqi", "lg": "lug", "gd": "gla", "fi": "fin", "ty": "tah", "da": "dan", "si": "sin", "ae": "ave", "alb": "sqi", "arm": "hye", "baq": "eus", "tib": "bod", "bur": "mya", "cze": "ces", "chi": "zho", "wel": "cym", "ger": "deu", "dut": "nld", "gre": "ell", "per": "fas", "fre": "fra", "geo": "kat", "ice": "isl", "mac": "mkd", "mao": "mri", "may": "msa", "rum": "ron", "slo": "slk"} 2 | -------------------------------------------------------------------------------- /languages/lr-proxy.json: -------------------------------------------------------------------------------- 1 | { "akk": "ar", 2 | "aii": "ar", 3 | "am": "ar", 4 | "be": "ru", 5 | "bho": "hi", 6 | "yue": "zh", 7 | "fo": "sv", 
8 | "krl": "fi", 9 | "olo": "fi", 10 | "mr": "hi", 11 | "sa": "hi", 12 | "gsw": "en", 13 | "hsb": "ru", 14 | "kk": "tr"} -------------------------------------------------------------------------------- /languages/oov-langs.txt: -------------------------------------------------------------------------------- 1 | akk 2 | aii 3 | bho 4 | krl 5 | koi 6 | kpv 7 | olo 8 | mr 9 | mdf 10 | sa 11 | tl 12 | wbp 13 | yo 14 | gsw 15 | am 16 | bm 17 | be 18 | br 19 | bxr 20 | yue 21 | myv 22 | fo 23 | kk 24 | kmr 25 | gun 26 | pcm 27 | ta 28 | te 29 | hsb 30 | cy 31 | -------------------------------------------------------------------------------- /load_adapters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracts a adapter weights archive from an existing model 3 | """ 4 | 5 | import logging 6 | import argparse 7 | 8 | from udapter import util 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("archive_dir", type=str, help="The directory where model.tar.gz resides") 16 | parser.add_argument("--adapters_dir", default=None, type=str, help="The directory where adapter.bin.x resides") 17 | parser.add_argument("--output_path", default=None, type=str, help="The path for output pytorch model") 18 | 19 | args = parser.parse_args() 20 | 21 | bert_config = "adapter_bert/configs/adapter-bert.json" 22 | util.load_adapter_weights(args.archive_dir, args.adapters_dir, bert_config, args.output_path) 23 | -------------------------------------------------------------------------------- /pg.scripts/pg.clean.ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | if len(sys.argv) < 2: 5 | print("please specify log dir") 6 | exit(1) 7 | 8 | model_file = 'model.tar.gz' 9 | model_state = 'model_state_epoch_' 10 | training_state = 'training_state_epoch_' 11 | best_ckpt = 'best.th' 12 | save_dir = 'midpoints' 13 | 14 | log_dir = sys.argv[1] 15 | 16 | files_and_dirs = os.listdir(log_dir) 17 | 18 | 19 | def find_last_epoch(log_dir, model_state): 20 | last = 0 21 | for f in os.listdir(log_dir): 22 | if f.startswith(model_state): 23 | ep = int(f.split('.')[0].split(model_state)[1]) 24 | last = ep if ep > last else last 25 | return str(last) + '.th' 26 | 27 | 28 | def clean_dir(log_dir, model_state, training_state): 29 | for f in os.listdir(log_dir): 30 | if f.startswith(model_state) or f.startswith(training_state): 31 | os.remove(os.path.join(log_dir, f)) 32 | 33 | 34 | if model_file in files_and_dirs: 35 | clean_dir(log_dir, model_state, training_state) 36 | else: 37 | if save_dir not in files_and_dirs: 38 | os.makedirs(os.path.join(log_dir, save_dir)) 39 | last = find_last_epoch(log_dir, model_state) 40 | os.replace(os.path.join(log_dir, model_state+last), os.path.join(log_dir, save_dir, model_state+last)) 41 | os.replace(os.path.join(log_dir, training_state + last), os.path.join(log_dir, save_dir, training_state + last)) 42 | os.replace(os.path.join(log_dir, best_ckpt), os.path.join(log_dir, save_dir, best_ckpt.split('.')[0]+'.'+last)) 43 | clean_dir(log_dir, model_state, training_state) 44 | os.replace(os.path.join(log_dir, save_dir, model_state + last), os.path.join(log_dir, model_state + last)) 45 | os.replace(os.path.join(log_dir, save_dir, training_state + last), os.path.join(log_dir, training_state + last)) 
-------------------------------------------------------------------------------- /pg.scripts/pg.feat.eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--model", type=str) 6 | parser.add_argument("--data_dir", type=str) 7 | args = parser.parse_args() 8 | 9 | for file in os.listdir(args.data_dir): 10 | cmd = 'python scripts/evaluate_feats.py ' 11 | if 'test' in file: 12 | cmd += ' --reference ' + os.path.join(args.data_dir, file) 13 | cmd += ' --output ' + os.path.join(args.model, file) 14 | cmd += ' > ' + os.path.join(args.model, file) + '.feat.json' 15 | print(cmd) 16 | -------------------------------------------------------------------------------- /pg.scripts/pg.predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | if len(sys.argv) < 6: 5 | print("please specify model file, test dir and for each command: time(hours), memory(gb), #cpus") 6 | exit(1) 7 | 8 | 9 | def createJob(name, cmd): 10 | print("creating: " + name) 11 | outFile = open(name, 'w') 12 | outFile.write("#!/bin/bash\n") 13 | outFile.write('\n') 14 | outFile.write("#SBATCH --time=" + sys.argv[3] + ":00:00\n") 15 | outFile.write("#SBATCH --nodes=1\n") 16 | outFile.write("#SBATCH --ntasks=1\n") 17 | outFile.write("#SBATCH --mem=" + sys.argv[4] + 'G\n') 18 | outFile.write("#SBATCH --cpus-per-task=" + sys.argv[5] + '\n') 19 | outFile.write("#SBATCH --job-name=" + name + '\n') 20 | outFile.write("#SBATCH --output=" + name + '.log\n') 21 | outFile.write("#SBATCH --partition=gpu\n") 22 | outFile.write("#SBATCH --gres=gpu:v100:1\n") 23 | outFile.write("\n") 24 | outFile.write("module load Python/3.6.4-intel-2018a\n") 25 | outFile.write("module load Anaconda3\n") 26 | outFile.write(". /software/software/Anaconda3/5.3.0/etc/profile.d/conda.sh\n") 27 | outFile.write("conda deactivate\n") 28 | outFile.write("conda activate allennlp\n") 29 | for c in cmd: 30 | outFile.write(c + "\n") 31 | outFile.close() 32 | 33 | 34 | model = sys.argv[1] 35 | test_dir = sys.argv[2] 36 | cmd = [] 37 | ts = os.listdir(test_dir) 38 | for t in ts: 39 | if 'test' in t: 40 | test = os.path.join(test_dir, t) 41 | cmd_str = 'python predict.py ' + model + ' ' + test + ' ' + os.path.join(os.path.split(model)[0], t) + ' --eval_file ' + os.path.join(os.path.split(model)[0], t) + '.json' 42 | cmd.append(cmd_str) 43 | 44 | name = model.split('/')[-3] + '.predict' 45 | createJob(name, cmd) 46 | -------------------------------------------------------------------------------- /pg.scripts/pg.resume.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | if len(sys.argv) < 5: 5 | print("please specify model dir and for each command: time(hours), memory(gb), #cpus") 6 | exit(1) 7 | 8 | ''' 9 | module load Python/3.6.4-intel-2018a 10 | module load Anaconda3 11 | . 
/software/software/Anaconda3/5.3.0/etc/profile.d/conda.sh 12 | conda deactivate 13 | #echo $CONDA_DEFAULT_ENV 14 | conda activate allennlp 15 | #echo $CONDA_DEFAULT_ENV 16 | python train.py --resume logs/udify.baseline.3/2019.10.08_11.05.27 17 | ''' 18 | 19 | 20 | def createJob(name, cmd): 21 | print("creating: " + name) 22 | outFile = open(name, 'w') 23 | outFile.write("#!/bin/bash\n") 24 | outFile.write('\n') 25 | outFile.write("#SBATCH --time=" + sys.argv[2] + ":00:00\n") 26 | outFile.write("#SBATCH --nodes=1\n") 27 | outFile.write("#SBATCH --ntasks=1\n") 28 | outFile.write("#SBATCH --mem=" + sys.argv[3] + 'G\n') 29 | outFile.write("#SBATCH --cpus-per-task=" + sys.argv[4] + '\n') 30 | outFile.write("#SBATCH --job-name=" + name + '\n') 31 | outFile.write("#SBATCH --output=" + name + '.log\n') 32 | outFile.write("#SBATCH --partition=gpu\n") 33 | outFile.write("#SBATCH --gres=gpu:v100:1\n") 34 | outFile.write("\n") 35 | outFile.write("module load Python/3.6.4-intel-2018a\n") 36 | outFile.write("module load Anaconda3\n") 37 | outFile.write(". /software/software/Anaconda3/5.3.0/etc/profile.d/conda.sh\n") 38 | outFile.write("conda deactivate\n") 39 | outFile.write("conda activate allennlp\n") 40 | outFile.write(cmd + "\n") 41 | outFile.close() 42 | 43 | 44 | model = sys.argv[1] 45 | name = os.path.split(os.path.split(model)[0])[1] + '.resume' 46 | cmd = 'python train.py --resume ' + model 47 | createJob(name, cmd) 48 | -------------------------------------------------------------------------------- /pg.scripts/pg.run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | if len(sys.argv) < 5: 5 | print("please specify config file and for each command: time(hours), memory(gb), #cpus") 6 | exit(1) 7 | 8 | 9 | def createJob(name, cmd): 10 | print("creating: " + name) 11 | outFile = open(name, 'w') 12 | outFile.write("#!/bin/bash\n") 13 | outFile.write('\n') 14 | outFile.write("#SBATCH --time=" + sys.argv[2] + ":00:00\n") 15 | outFile.write("#SBATCH --nodes=1\n") 16 | outFile.write("#SBATCH --ntasks=1\n") 17 | outFile.write("#SBATCH --mem=" + sys.argv[3] + 'G\n') 18 | outFile.write("#SBATCH --cpus-per-task=" + sys.argv[4] + '\n') 19 | outFile.write("#SBATCH --job-name=" + name + '\n') 20 | outFile.write("#SBATCH --output=" + name + '.log\n') 21 | outFile.write("#SBATCH --partition=gpu\n") 22 | outFile.write("#SBATCH --gres=gpu:v100:1\n") 23 | outFile.write("\n") 24 | outFile.write("module load Python/3.6.4-intel-2018a\n") 25 | outFile.write("module load Anaconda3\n") 26 | outFile.write(". 
/software/software/Anaconda3/5.3.0/etc/profile.d/conda.sh\n") 27 | outFile.write("conda deactivate\n") 28 | outFile.write("conda activate allennlp\n") 29 | outFile.write(cmd + "\n") 30 | outFile.close() 31 | 32 | 33 | config = sys.argv[1] 34 | name = os.path.split(config)[1].split('.')[0] 35 | cmd = 'python train.py --config ' + config + ' --name ' + name 36 | createJob(name, cmd) 37 | -------------------------------------------------------------------------------- /pg.scripts/pg.seq.eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--model", type=str) 6 | parser.add_argument("--data_dir", type=str) 7 | parser.add_argument("--column", type=int) 8 | args = parser.parse_args() 9 | 10 | for file in os.listdir(args.data_dir): 11 | cmd = 'python scripts/seq_eval.py ' 12 | if 'test' in file: 13 | cmd += ' --gold_file ' + os.path.join(args.data_dir, file) 14 | cmd += ' --pred_file ' + os.path.join(args.model, file) 15 | cmd += ' --out_plot ' + os.path.join(args.model, file) + '.eval.png' 16 | cmd += ' --column ' + str(args.column) 17 | cmd += ' > ' + os.path.join(args.model, file) + '.eval.json' 18 | print(cmd) 19 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict conllu files given a trained model 3 | """ 4 | 5 | import os 6 | import shutil 7 | import logging 8 | import argparse 9 | import tarfile 10 | from pathlib import Path 11 | 12 | from allennlp.common import Params 13 | from allennlp.common.util import import_submodules 14 | from allennlp.models.archival import archive_model 15 | 16 | from udapter import util 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("archive", type=str, help="The archive file") 24 | parser.add_argument("input_file", type=str, help="The input file to predict") 25 | parser.add_argument("pred_file", type=str, help="The output prediction file") 26 | parser.add_argument("--eval_file", default=None, type=str, 27 | help="If set, evaluate the prediction and store it in the given file") 28 | parser.add_argument("--device", default=0, type=int, help="CUDA device number; set to -1 for CPU") 29 | parser.add_argument("--batch_size", default=1, type=int, help="The size of each prediction batch") 30 | parser.add_argument("--lazy", action="store_true", help="Lazy load dataset") 31 | 32 | args = parser.parse_args() 33 | 34 | import_submodules("udapter") 35 | 36 | archive_dir = Path(args.archive).resolve().parent 37 | 38 | if not os.path.isfile(archive_dir / "weights.th"): 39 | with tarfile.open(args.archive) as tar: 40 | tar.extractall(archive_dir) 41 | 42 | config_file = archive_dir / "config.json" 43 | 44 | overrides = {} 45 | if args.device is not None: 46 | overrides["trainer"] = {"cuda_device": args.device} 47 | if args.lazy: 48 | overrides["dataset_reader"] = {"lazy": args.lazy} 49 | configs = [Params(overrides), Params.from_file(config_file)] 50 | params = util.merge_configs(configs) 51 | 52 | if not args.eval_file: 53 | util.predict_model_with_archive("udapter_predictor", params, archive_dir, args.input_file, args.pred_file, 54 | batch_size=args.batch_size) 55 | else: 56 | 
util.predict_and_evaluate_model_with_archive("udapter_predictor", params, archive_dir, args.input_file, 57 | args.pred_file, args.eval_file, batch_size=args.batch_size) 58 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | allennlp==0.8.5 3 | seqeval 4 | six 5 | tqdm 6 | lang2vec -------------------------------------------------------------------------------- /scripts/concat_ud_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_DIR="data/ud-treebanks-v2.3" 4 | 5 | echo "Generating multilingual dataset..." 6 | 7 | mkdir -p "data/ud" 8 | mkdir -p "data/ud/multilingual" 9 | 10 | python concat_treebanks.py data/ud/multilingual --dataset_dir ${DATASET_DIR} --add_lang_id -------------------------------------------------------------------------------- /scripts/conll18_ud_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Runs the CoNLL 2018 Shared Task Evaluation 5 | """ 6 | 7 | from udapter.dataset_readers.conll18_ud_eval import main 8 | 9 | if __name__ == "__main__": 10 | main() 11 | -------------------------------------------------------------------------------- /scripts/download_ud_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Can download UD 2.3 or 2.4 4 | UD_2_3="https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2895/ud-treebanks-v2.3.tgz?sequence=1&isAllowed=y" 5 | UD_2_4="https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2988/ud-treebanks-v2.4.tgz?sequence=4&isAllowed=y" 6 | 7 | DATASET_DIR="data/ud-treebanks-v2.3" 8 | 9 | ARCHIVE="ud_data.tgz" 10 | 11 | 12 | echo "Downloading UD data..." 13 | 14 | curl ${UD_2_3} -o ${ARCHIVE} 15 | 16 | tar -xvzf ${ARCHIVE} -C ./data 17 | mv ${ARCHIVE} ./data 18 | 19 | 20 | echo "Generating multilingual dataset..." 21 | 22 | mkdir -p "data/ud" 23 | mkdir -p "data/ud/multilingual" 24 | 25 | python concat_treebanks.py data/ud/multilingual --dataset_dir ${DATASET_DIR} --add_lang_id 26 | -------------------------------------------------------------------------------- /scripts/evaluate_feats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Evaluation for the SIGMORPHON 2019 shared task, task 2. 3 | 4 | Computes various metrics on input. 5 | 6 | Author: Arya D. 
McCarthy 7 | Last update: 2018-12-21 8 | """ 9 | 10 | import argparse 11 | import logging 12 | 13 | import numpy as np 14 | 15 | from collections import namedtuple 16 | from pathlib import Path 17 | 18 | log = logging.getLogger(Path(__file__).stem) 19 | 20 | 21 | REF_COLUMNS = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC LANG".split() 22 | OUT_COLUMNS = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC".split() 23 | ref_ConlluRow = namedtuple("ConlluRow", REF_COLUMNS) 24 | out_ConlluRow = namedtuple("ConlluRow", OUT_COLUMNS) 25 | SEPARATOR = "|" 26 | 27 | 28 | def distance(str1, str2): 29 | """Simple Levenshtein implementation.""" 30 | m = np.zeros([len(str2)+1, len(str1)+1]) 31 | for x in range(1, len(str2) + 1): 32 | m[x][0] = m[x-1][0] + 1 33 | for y in range(1, len(str1) + 1): 34 | m[0][y] = m[0][y-1] + 1 35 | for x in range(1, len(str2) + 1): 36 | for y in range(1, len(str1) + 1): 37 | if str1[y-1] == str2[x-1]: 38 | dg = 0 39 | else: 40 | dg = 1 41 | m[x][y] = min(m[x-1][y] + 1, m[x][y-1] + 1, m[x-1][y-1] + dg) 42 | return int(m[len(str2)][len(str1)]) 43 | 44 | 45 | def set_equal(str1, str2): 46 | set1 = set(str1.split(SEPARATOR)) 47 | set2 = set(str2.split(SEPARATOR)) 48 | return set1 == set2 49 | 50 | 51 | def manipulate_data(pairs): 52 | log.info("Morph acc, Morph F1") 53 | 54 | count = 0 55 | lemma_acc = 0 56 | lemma_lev = 0 57 | morph_acc = 0 58 | 59 | f1_precision_scores = 0 60 | f1_precision_counts = 0 61 | f1_recall_scores = 0 62 | f1_recall_counts = 0 63 | 64 | for r, o in pairs: 65 | log.debug("{}\t{}\t{}\t{}".format(r.LEMMA, o.LEMMA, r.FEATS, o.FEATS)) 66 | count += 1 67 | lemma_acc += (r.LEMMA == o.LEMMA) 68 | lemma_lev += distance(r.LEMMA, o.LEMMA) 69 | morph_acc += set_equal(r.FEATS, o.FEATS) 70 | 71 | # "_" prediction is also taken into account here !! 72 | r_feats = set(r.FEATS.split(SEPARATOR)) # - {"_"} 73 | o_feats = set(o.FEATS.split(SEPARATOR)) # - {"_"} 74 | 75 | union_size = len(r_feats & o_feats) 76 | reference_size = len(r_feats) 77 | output_size = len(o_feats) 78 | 79 | f1_precision_scores += union_size 80 | f1_recall_scores += union_size 81 | f1_precision_counts += output_size 82 | f1_recall_counts += reference_size 83 | 84 | f1_precision = f1_precision_scores / (f1_precision_counts or 1) 85 | f1_recall = f1_recall_scores / (f1_recall_counts or 1) 86 | f1 = 2 * (f1_precision * f1_recall) / (f1_precision + f1_recall + 1E-20) 87 | 88 | # return (100 * lemma_acc / count, lemma_lev / count, 100 * morph_acc / count, 100 * f1) 89 | return 100 * morph_acc / count, 100 * f1 90 | 91 | 92 | def parse_args(): 93 | """Parse command line arguments.""" 94 | parser = argparse.ArgumentParser(description=__doc__) 95 | parser.add_argument('-r', '--reference', 96 | type=Path, required=True) 97 | parser.add_argument('-o', '--output', 98 | type=Path, required=True) 99 | # Set the verbosity level for the logger. The `-v` option will set it to 100 | # the debug level, while the `-q` will set it to the warning level. 101 | # Otherwise use the info level. 102 | verbosity = parser.add_mutually_exclusive_group() 103 | verbosity.add_argument('-v', '--verbose', action='store_const', 104 | const=logging.DEBUG, default=logging.INFO) 105 | verbosity.add_argument('-q', '--quiet', dest='verbose', 106 | action='store_const', const=logging.WARNING) 107 | return parser.parse_args() 108 | 109 | 110 | def strip_comments(lines): 111 | for line in lines: 112 | if not line.startswith("#"): 113 | if '.' 
not in line.split("\t")[0]: 114 | yield line 115 | 116 | 117 | def read_conllu(file: Path): 118 | with open(file) as f: 119 | yield from strip_comments(f) 120 | 121 | 122 | def input_pairs(reference, output): 123 | for r, o in zip(reference, output): 124 | if r.count("\t") > 0: 125 | r_conllu = ref_ConlluRow._make(r.split("\t")) 126 | o_conllu = out_ConlluRow._make(o.split("\t")) 127 | assert r_conllu.FORM == o_conllu.FORM, (r, o) 128 | yield r_conllu, o_conllu 129 | 130 | 131 | def main(): 132 | args = parse_args() 133 | logging.basicConfig(level=args.verbose) 134 | reference = read_conllu(args.reference) 135 | output = read_conllu(args.output) 136 | results = manipulate_data(input_pairs(reference, output)) 137 | print(*["{0:.2f}".format(v) for v in results], sep='\t') 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /scripts/generate_tables.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | if len(sys.argv) < 2: 6 | print("please specify log dir") 7 | exit(1) 8 | 9 | 10 | las_results = dict() 11 | uas_results = dict() 12 | 13 | prefix = 'test.conllu_' 14 | #suffix = '.json' 15 | suffix = 'feats.json' 16 | log_dir = sys.argv[1] 17 | 18 | for f in os.listdir(log_dir): 19 | if f.startswith(prefix) and f.endswith(suffix): 20 | lang = f.split(prefix)[1].split(suffix)[0] 21 | with open(os.path.join(log_dir,f)) as j: 22 | scores = j.readline().rstrip().split('\t') 23 | acc = scores[0] 24 | f1 = scores[1] 25 | las_results[lang] = acc 26 | uas_results[lang] = f1 27 | #data = json.load(j) 28 | #las = data['LAS']['aligned_accuracy'] 29 | #uas = data['UAS']['aligned_accuracy'] 30 | #las_results[lang] = float(las)*100 31 | #uas_results[lang] = float(uas)*100 32 | 33 | las_results = {k: v for k, v in sorted(las_results.items(), key=lambda item: item[0])} 34 | uas_results = {k: v for k, v in sorted(uas_results.items(), key=lambda item: item[0])} 35 | hr_langs = ['ar', 'eu', 'zh', 'en', 'fi', 'he', 'hi', 'it', 'ja', 'ko', 'ru', 'sv', 'tr'] 36 | 37 | hr_out = open(os.path.join(log_dir, 'hr_results.feats.txt'), 'w') 38 | lr_out = open(os.path.join(log_dir, 'lr_results.feats.txt'), 'w') 39 | for l,r in las_results.items(): 40 | if l in hr_langs: 41 | hr_out.write('{}, {:.2f}, {:.2f}\n'.format(l,r,uas_results[l])) 42 | else: 43 | lr_out.write('{}, {:.2f}, {:.2f}\n'.format(l,r,uas_results[l])) 44 | 45 | hr_out.close() 46 | lr_out.close() 47 | -------------------------------------------------------------------------------- /scripts/ner_to_ud.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--dataset_dir", default="data/ner", type=str) 7 | args = parser.parse_args() 8 | 9 | languages = ['koi', 'olo', 'mr', 'mdf', 'sa', 'tl', 'yo', 'am', 'bm', 'be', 'br', 'bxr', 'zh-yue', 'myv', 'fo', 'kk', 'gn', 'ta', 'te', 'hsb', 'cy'] 10 | 11 | dirs = os.listdir(args.dataset_dir) 12 | for dir in dirs: 13 | file_dir = os.path.join(args.dataset_dir,dir) 14 | files = os.listdir(file_dir) 15 | for f in files: 16 | with open(os.path.join(file_dir, f+'.conllu'), 'w') as write: 17 | with open(os.path.join(file_dir, f), 'r') as read: 18 | for line in read: 19 | idx = 0 20 | if line != '\n' and not line.startswith('#'): 21 | form = line.rstrip('\n').split('\t') 22 | lang,word = form[0].split(':',1) 23 | tag = 
form[1] 24 | anno = [idx, word, '_', tag] 25 | anno += ['99'] * 6 26 | anno += [lang] 27 | write.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(*anno)) 28 | else: 29 | write.write(line) 30 | -------------------------------------------------------------------------------- /scripts/overlap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict # available in Python 2.5 and newer 3 | 4 | 5 | def read_ud(filename, column, exclude=None): 6 | token_count = defaultdict(int) 7 | total_count = 0 8 | for line in open(filename): 9 | line = line.rstrip() 10 | if line.startswith('#') or len(line) == 0: 11 | continue 12 | else: 13 | token = line.split('\t')[column] 14 | if exclude is None or token not in exclude: 15 | token_count[token] += 1 16 | total_count += 1 17 | 18 | return token_count, total_count 19 | 20 | 21 | def main(): 22 | # Parse arguments 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("-S", "--src", type=str) 25 | parser.add_argument("-T", "--tgt", type=str) 26 | parser.add_argument("-C", "--column", type=int, default=1) 27 | parser.add_argument('-E', '--exclude', nargs='*', type=str) 28 | 29 | args = parser.parse_args() 30 | 31 | sc, st = read_ud(args.src, args.column, args.exclude) 32 | print('Source file: {}\nType/Token count: {} / {}'.format(args.src, len(sc), st)) 33 | 34 | tc, tt = read_ud(args.tgt, args.column, args.exclude) 35 | print('Target file: {}\nType/Token count: {} / {}'.format(args.tgt, len(tc), tt)) 36 | 37 | total_overlap = 0 38 | unique_overlap = 0 39 | for word, count in tc.items(): 40 | if word in sc: 41 | total_overlap += tc[word] 42 | unique_overlap += 1 43 | total_ratio = total_overlap / tt * 100 44 | unique_ratio = unique_overlap / len(tc) * 100 45 | print('Type/Token overlap ratio: {:.2f} / {:.2f}'.format(unique_ratio, total_ratio)) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /scripts/seq_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from seqeval.metrics import classification_report 3 | from seqeval.metrics import accuracy_score 4 | 5 | from collections import defaultdict # available in Python 2.5 and newer 6 | from sklearn.metrics import confusion_matrix 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | import pandas as pd 10 | 11 | 12 | def read_conllu(file, column): 13 | fin = open(file) 14 | sentences = [] 15 | sentence = [] 16 | for line in fin: 17 | if line.startswith('#'): 18 | continue 19 | if line is None or line == '\n': 20 | sentences.append(sentence) 21 | sentence = [] 22 | else: 23 | columns = line.rstrip().split('\t') 24 | if not '.' 
in columns[0]: 25 | sentence.append(line.rstrip().split('\t')[column]) 26 | if len(sentence) > 0: 27 | sentences.append(sentence) 28 | return sentences 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--gold_file", type=str) 33 | parser.add_argument("--pred_file", type=str) 34 | parser.add_argument("--out_plot", type=str) 35 | parser.add_argument("--column", type=int, default=5) 36 | args = parser.parse_args() 37 | 38 | y_true = read_conllu(args.gold_file, args.column) 39 | y_pred = read_conllu(args.pred_file, args.column) 40 | 41 | flat_y_true = [item for sublist in y_true for item in sublist] 42 | flat_y_pred = [item for sublist in y_pred for item in sublist] 43 | 44 | assert len(flat_y_true) == len(flat_y_pred) 45 | 46 | print(classification_report(y_true, y_pred, digits=4)) 47 | print(accuracy_score(y_true, y_pred)) 48 | 49 | # Creates a confusion matrix 50 | label_count = defaultdict(int) 51 | for label in flat_y_true: 52 | label_count[label] += 1 53 | 54 | labels = [] 55 | for l,c in label_count.items(): 56 | if c > 20: 57 | labels.append(l) 58 | 59 | cm = confusion_matrix(flat_y_true, flat_y_pred, labels=labels) 60 | cm_df = pd.DataFrame(cm, index=labels, columns=labels) 61 | 62 | plt.figure(figsize=(50, 50)) 63 | sns.heatmap(cm_df, annot=True, cmap="YlGnBu") 64 | plt.ylabel('True label') 65 | plt.xlabel('Predicted label') 66 | plt.savefig(args.out_plot, bbox_inches='tight') 67 | plt.close() 68 | 69 | -------------------------------------------------------------------------------- /scripts/split_file_by_lang.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument("--file", type=str, help="The path containing all UD treebanks") 5 | args = parser.parse_args() 6 | 7 | fin = open(args.file) 8 | lines = [] 9 | lang = None 10 | for line in fin: 11 | if line != '\n' and not line.startswith('#'): 12 | form = line.rstrip('\n').split('\t') 13 | if lang is None: 14 | lang = form[10] 15 | lines.append(line) 16 | elif lang == form[10]: 17 | lines.append(line) 18 | elif lang != form[10]: 19 | fout = open(args.file +'_'+ lang, 'w') 20 | for l in lines: 21 | fout.write(l) 22 | fout.close() 23 | lang = form[10] 24 | lines = [] 25 | lines.append(line) 26 | else: 27 | lines.append(line) 28 | 29 | fout = open(args.file +'_'+ lang, 'w') 30 | for l in lines: 31 | fout.write(l) 32 | fout.close() 33 | -------------------------------------------------------------------------------- /slides/SigTyp_Abstract_EMNLP2020.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmetustun/udapter/b5438df4a38778f964942200beb3aded15841d65/slides/SigTyp_Abstract_EMNLP2020.pdf -------------------------------------------------------------------------------- /slides/UDapter_EMNLP2020.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmetustun/udapter/b5438df4a38778f964942200beb3aded15841d65/slides/UDapter_EMNLP2020.pdf -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script useful for debugging UDify and AllenNLP code 3 | """ 4 | 5 | import os 6 | import copy 7 | import datetime 8 | import logging 9 | import argparse 10 | 11 | from allennlp.common import Params 12 | from allennlp.common.util import 
import_submodules 13 | from allennlp.commands.train import train_model 14 | 15 | from udapter import util 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | level=logging.INFO) 19 | logger = logging.getLogger(__name__) 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--name", default="", type=str, help="Log dir name") 23 | parser.add_argument("--base_config", default="config/udify_base.json", type=str, help="Base configuration file") 24 | parser.add_argument("--config", default=[], type=str, nargs="+", help="Overriding configuration files") 25 | parser.add_argument("--device", default=None, type=int, help="CUDA device; set to -1 for CPU") 26 | parser.add_argument("--resume", type=str, help="Resume training with the given model") 27 | parser.add_argument("--lazy", default=None, action="store_true", help="Lazy load the dataset") 28 | parser.add_argument("--cleanup_archive", action="store_true", help="Delete the model archive") 29 | parser.add_argument("--replace_vocab", action="store_true", help="Create a new vocab and replace the cached one") 30 | parser.add_argument("--archive_bert", action="store_true", help="Archives the finetuned BERT model after training") 31 | parser.add_argument("--predictor", default="udify_predictor", type=str, help="The type of predictor to use") 32 | 33 | args = parser.parse_args() 34 | 35 | log_dir_name = args.name 36 | if not log_dir_name: 37 | file_name = args.config[0] if args.config else args.base_config 38 | log_dir_name = os.path.basename(file_name).split(".")[0] 39 | 40 | configs = [] 41 | 42 | if not args.resume: 43 | serialization_dir = os.path.join("logs", log_dir_name, datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")) 44 | 45 | overrides = {} 46 | if args.device is not None: 47 | overrides["trainer"] = {"cuda_device": args.device} 48 | if args.lazy is not None: 49 | overrides["dataset_reader"] = {"lazy": args.lazy} 50 | configs.append(Params(overrides)) 51 | for config_file in args.config: 52 | configs.append(Params.from_file(config_file)) 53 | configs.append(Params.from_file(args.base_config)) 54 | else: 55 | serialization_dir = args.resume 56 | configs.append(Params.from_file(os.path.join(serialization_dir, "config.json"))) 57 | 58 | train_params = util.merge_configs(configs) 59 | if "vocabulary" in train_params: 60 | # Remove this key to make AllenNLP happy 61 | train_params["vocabulary"].pop("non_padded_namespaces", None) 62 | 63 | predict_params = train_params.duplicate() 64 | 65 | import_submodules("udapter") 66 | 67 | try: 68 | util.cache_vocab(train_params) 69 | train_model(train_params, serialization_dir, recover=bool(args.resume)) 70 | except KeyboardInterrupt: 71 | logger.warning("KeyboardInterrupt, skipping training") 72 | 73 | dev_file = predict_params["validation_data_path"] 74 | test_file = predict_params["test_data_path"] 75 | 76 | dev_pred, dev_eval, test_pred, test_eval = [ 77 | os.path.join(serialization_dir, name) 78 | for name in ["dev.conllu", "dev_results.json", "test.conllu", "test_results.json"] 79 | ] 80 | 81 | if dev_file != test_file: 82 | util.predict_and_evaluate_model(args.predictor, predict_params, serialization_dir, dev_file, dev_pred, dev_eval) 83 | 84 | util.predict_and_evaluate_model(args.predictor, predict_params, serialization_dir, test_file, test_pred, test_eval) 85 | 86 | if args.archive_bert: 87 | bert_config = "config/archive/bert-base-multilingual-cased/bert_config.json" 88 | util.archive_bert_model(serialization_dir, bert_config) 89 | 
90 | util.cleanup_training(serialization_dir, keep_archive=not args.cleanup_archive) 91 | -------------------------------------------------------------------------------- /udapter/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.dataset_readers import * 2 | from udapter.udapter_models import * 3 | from udapter.udify_models import * 4 | from udapter.modules import * 5 | from udapter.optimizers import * 6 | from udapter.predictors import * 7 | from udapter import * 8 | -------------------------------------------------------------------------------- /udapter/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.dataset_readers.universal_dependencies import UniversalDependenciesDatasetReader 2 | -------------------------------------------------------------------------------- /udapter/dataset_readers/lemma_edit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for processing lemmas 3 | 4 | Adopted from UDPipe Future 5 | https://github.com/CoNLL-UD-2018/UDPipe-Future 6 | """ 7 | 8 | 9 | def min_edit_script(source, target, allow_copy=False): 10 | """ 11 | Finds the minimum edit script to transform the source to the target 12 | """ 13 | a = [[(len(source) + len(target) + 1, None)] * (len(target) + 1) for _ in range(len(source) + 1)] 14 | for i in range(0, len(source) + 1): 15 | for j in range(0, len(target) + 1): 16 | if i == 0 and j == 0: 17 | a[i][j] = (0, "") 18 | else: 19 | if allow_copy and i and j and source[i - 1] == target[j - 1] and a[i-1][j-1][0] < a[i][j][0]: 20 | a[i][j] = (a[i-1][j-1][0], a[i-1][j-1][1] + "→") 21 | if i and a[i-1][j][0] < a[i][j][0]: 22 | a[i][j] = (a[i-1][j][0] + 1, a[i-1][j][1] + "-") 23 | if j and a[i][j-1][0] < a[i][j][0]: 24 | a[i][j] = (a[i][j-1][0] + 1, a[i][j-1][1] + "+" + target[j - 1]) 25 | return a[-1][-1][1] 26 | 27 | 28 | def gen_lemma_rule(form, lemma, allow_copy=False): 29 | """ 30 | Generates a lemma rule to transform the source to the target 31 | """ 32 | form = form.lower() 33 | 34 | previous_case = -1 35 | lemma_casing = "" 36 | for i, c in enumerate(lemma): 37 | case = "↑" if c.lower() != c else "↓" 38 | if case != previous_case: 39 | lemma_casing += "{}{}{}".format("¦" if lemma_casing else "", case, i if i <= len(lemma) // 2 else i - len(lemma)) 40 | previous_case = case 41 | lemma = lemma.lower() 42 | 43 | best, best_form, best_lemma = 0, 0, 0 44 | for l in range(len(lemma)): 45 | for f in range(len(form)): 46 | cpl = 0 47 | while f + cpl < len(form) and l + cpl < len(lemma) and form[f + cpl] == lemma[l + cpl]: cpl += 1 48 | if cpl > best: 49 | best = cpl 50 | best_form = f 51 | best_lemma = l 52 | 53 | rule = lemma_casing + ";" 54 | if not best: 55 | rule += "a" + lemma 56 | else: 57 | rule += "d{}¦{}".format( 58 | min_edit_script(form[:best_form], lemma[:best_lemma], allow_copy), 59 | min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy), 60 | ) 61 | return rule 62 | 63 | 64 | def apply_lemma_rule(form, lemma_rule): 65 | """ 66 | Applies the lemma rule to the form to generate the lemma 67 | """ 68 | casing, rule = lemma_rule.split(";", 1) 69 | if rule.startswith("a"): 70 | lemma = rule[1:] 71 | else: 72 | form = form.lower() 73 | rules, rule_sources = rule[1:].split("¦"), [] 74 | assert len(rules) == 2 75 | for rule in rules: 76 | source, i = 0, 0 77 | while i < len(rule): 78 | if rule[i] == "→" or rule[i] == "-": 79 | source += 1 
80 | else: 81 | assert rule[i] == "+" 82 | i += 1 83 | i += 1 84 | rule_sources.append(source) 85 | 86 | try: 87 | lemma, form_offset = "", 0 88 | for i in range(2): 89 | j, offset = 0, (0 if i == 0 else len(form) - rule_sources[1]) 90 | while j < len(rules[i]): 91 | if rules[i][j] == "→": 92 | lemma += form[offset] 93 | offset += 1 94 | elif rules[i][j] == "-": 95 | offset += 1 96 | else: 97 | assert(rules[i][j] == "+") 98 | lemma += rules[i][j + 1] 99 | j += 1 100 | j += 1 101 | if i == 0: 102 | lemma += form[rule_sources[0]: len(form) - rule_sources[1]] 103 | except: 104 | lemma = form 105 | 106 | for rule in casing.split("¦"): 107 | if rule == "↓0": continue # The lemma is lowercased initially 108 | case, offset = rule[0], int(rule[1:]) 109 | lemma = lemma[:offset] + (lemma[offset:].upper() if case == "↑" else lemma[offset:].lower()) 110 | 111 | return lemma 112 | -------------------------------------------------------------------------------- /udapter/dataset_readers/parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parses Universal Dependencies data 3 | """ 4 | 5 | from __future__ import unicode_literals 6 | 7 | import re 8 | from collections import OrderedDict 9 | 10 | DEFAULT_FIELDS = ('id', 'form', 'lemma', 'upostag', 'xpostag', 'feats', 'head', 'deprel', 'deps', 'misc', 'lang') 11 | 12 | deps_pattern = r"\d+:[a-z][a-z_-]*(:[a-z][a-z_-]*)?" 13 | MULTI_DEPS_PATTERN = re.compile(r"^{}(\|{})*$".format(deps_pattern, deps_pattern)) 14 | 15 | 16 | class ParseException(Exception): 17 | pass 18 | 19 | 20 | def parse_token_and_metadata(data, fields=None): 21 | if not data: 22 | raise ParseException("Can't create TokenList, no data sent to constructor.") 23 | 24 | fields = fields or DEFAULT_FIELDS 25 | 26 | tokens = [] 27 | metadata = OrderedDict() 28 | 29 | for line in data.split('\n'): 30 | line = line.strip() 31 | 32 | if not line: 33 | continue 34 | 35 | if line.startswith('#'): 36 | var_name, var_value = parse_comment_line(line) 37 | if var_name: 38 | metadata[var_name] = var_value 39 | else: 40 | tokens.append(parse_line(line, fields=fields)) 41 | 42 | return tokens, metadata 43 | 44 | 45 | def parse_line(line, fields=DEFAULT_FIELDS, parse_feats=True): 46 | line = re.split(r"\t| {2,}", line) 47 | 48 | if len(line) == 1 and " " in line[0]: 49 | raise ParseException("Invalid line format, line must contain either tabs or two spaces.") 50 | 51 | data = OrderedDict() 52 | 53 | for i, field in enumerate(fields): 54 | # Allow parsing CoNNL-U files with fewer columns 55 | if i >= len(line): 56 | break 57 | 58 | if field == "id": 59 | value = parse_id_value(line[i]) 60 | data["multi_id"] = parse_multi_id_value(line[i]) 61 | 62 | elif field == "xpostag": 63 | value = parse_nullable_value(line[i]) 64 | 65 | elif field == "feats": 66 | if parse_feats: 67 | value = parse_dict_value(line[i]) 68 | else: 69 | value = line[i] 70 | 71 | elif field == "head": 72 | value = parse_int_value(line[i]) 73 | 74 | elif field == "deps": 75 | value = parse_paired_list_value(line[i]) 76 | 77 | elif field == "misc": 78 | value = parse_dict_value(line[i]) 79 | 80 | else: 81 | value = line[i] 82 | 83 | data[field] = value 84 | 85 | return data 86 | 87 | 88 | def parse_comment_line(line): 89 | line = line.strip() 90 | if line[0] != '#': 91 | raise ParseException("Invalid comment format, comment must start with '#'") 92 | if '=' not in line: 93 | return None, None 94 | var_name, var_value = line[1:].split('=', 1) 95 | var_name = var_name.strip() 96 | var_value = 
var_value.strip() 97 | return var_name, var_value 98 | 99 | 100 | def parse_int_value(value): 101 | if value == '_': 102 | return None 103 | try: 104 | return int(value) 105 | except ValueError: 106 | return None 107 | 108 | 109 | def parse_id_value(value): 110 | # return value if "-" not in value else None 111 | return value if "-" not in value and "." not in value else None 112 | # TODO: handle special ids with "." 113 | 114 | 115 | def parse_multi_id_value(value): 116 | if len(value.split('-')) == 2: 117 | return value 118 | return None 119 | 120 | 121 | def parse_paired_list_value(value): 122 | if re.match(MULTI_DEPS_PATTERN, value): 123 | return [ 124 | (part.split(":", 1)[1], parse_int_value(part.split(":", 1)[0])) 125 | for part in value.split("|") 126 | ] 127 | 128 | return parse_nullable_value(value) 129 | 130 | 131 | def parse_dict_value(value): 132 | if "=" in value: 133 | return OrderedDict([ 134 | (part.split("=")[0], parse_nullable_value(part.split("=")[1])) 135 | for part in value.split("|") if len(part.split('=')) == 2 136 | ]) 137 | 138 | return parse_nullable_value(value) 139 | 140 | 141 | def parse_nullable_value(value): 142 | if not value or value == "_": 143 | return None 144 | 145 | return value 146 | -------------------------------------------------------------------------------- /udapter/dataset_readers/universal_dependencies.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Dataset Reader for Universal Dependencies, with support for multiword tokens and special handling for NULL "_" tokens 3 | """ 4 | from collections import OrderedDict 5 | from typing import Dict, Tuple, List, Any, Callable 6 | 7 | from overrides import overrides 8 | from udapter.dataset_readers.parser import parse_line, DEFAULT_FIELDS 9 | 10 | from allennlp.common.file_utils import cached_path 11 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 12 | from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField 13 | from allennlp.data.instance import Instance 14 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer 15 | from allennlp.data.tokenizers import Token 16 | 17 | from udapter.dataset_readers.lemma_edit import gen_lemma_rule 18 | 19 | import json 20 | import logging 21 | 22 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | def lazy_parse(text: str, fields: Tuple[str, ...]=DEFAULT_FIELDS): 26 | for sentence in text.split("\n\n"): 27 | if sentence: 28 | # TODO: upgrade conllu library 29 | yield [parse_line(line, fields) 30 | for line in sentence.split("\n") 31 | if line and not line.strip().startswith("#")] 32 | 33 | 34 | @DatasetReader.register("udify_universal_dependencies") 35 | class UniversalDependenciesDatasetReader(DatasetReader): 36 | def __init__(self, 37 | token_indexers: Dict[str, TokenIndexer] = None, 38 | lazy: bool = False, 39 | use_lang_ids: bool = False, 40 | use_separate_feats: bool = False, 41 | ud_feats_schema: str = 'config/ud/feats.json') -> None: 42 | super().__init__(lazy) 43 | self.use_lang_ids = use_lang_ids 44 | self.use_separate_feats = use_separate_feats 45 | with open(ud_feats_schema, "r") as fin: 46 | self.ud_feats_schema = json.load(fin) 47 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 48 | 49 | @overrides 50 | def _read(self, file_path: str): 51 | # if `file_path` is a URL, redirect to the cache 52 | file_path = cached_path(file_path) 53 | 54 | with open(file_path, 'r') as 
conllu_file: 55 | logger.info("Reading UD instances from conllu dataset at: %s", file_path) 56 | 57 | for annotation in lazy_parse(conllu_file.read()): 58 | # CoNLLU annotations sometimes add back in words that have been elided 59 | # in the original sentence; we remove these, as we're just predicting 60 | # dependencies for the original sentence. 61 | # We filter by None here as elided words have a non-integer word id, 62 | # and are replaced with None by the conllu python library. 63 | multiword_tokens = [x for x in annotation if x["multi_id"] is not None] 64 | annotation = [x for x in annotation if x["id"] is not None] 65 | 66 | if len(annotation) == 0: 67 | continue 68 | 69 | def get_field(tag: str, map_fn: Callable[[Any], Any] = None) -> List[Any]: 70 | map_fn = map_fn if map_fn is not None else lambda x: x 71 | return [map_fn(x[tag]) if x[tag] is not None else "_" for x in annotation if tag in x] 72 | 73 | # Extract multiword token rows (not used for prediction, purely for evaluation) 74 | ids = [x["id"] for x in annotation] 75 | multiword_ids = [x["multi_id"] for x in multiword_tokens] 76 | multiword_forms = [x["form"] for x in multiword_tokens] 77 | 78 | words = get_field("form") 79 | lemmas = get_field("lemma") 80 | lemma_rules = [gen_lemma_rule(word, lemma) 81 | if lemma != "_" else "_" 82 | for word, lemma in zip(words, lemmas)] 83 | upos_tags = get_field("upostag") 84 | xpos_tags = get_field("xpostag") 85 | feats = get_field("feats", lambda x: "|".join(k + "=" + v for k, v in x.items()) 86 | if hasattr(x, "items") else "_") 87 | heads = get_field("head") 88 | dep_rels = get_field("deprel") 89 | dependencies = list(zip(dep_rels, heads)) 90 | 91 | langs = get_field("lang") 92 | 93 | yield self.text_to_instance(words, lemmas, lemma_rules, upos_tags, xpos_tags, 94 | feats, get_field("feats"), dependencies, ids, multiword_ids, multiword_forms, langs) 95 | 96 | @overrides 97 | def text_to_instance(self, # type: ignore 98 | words: List[str], 99 | lemmas: List[str] = None, 100 | lemma_rules: List[str] = None, 101 | upos_tags: List[str] = None, 102 | xpos_tags: List[str] = None, 103 | feats: List[str] = None, 104 | separate_feats: List[Dict[str, str]] = None, 105 | dependencies: List[Tuple[str, int]] = None, 106 | ids: List[str] = None, 107 | multiword_ids: List[str] = None, 108 | multiword_forms: List[str] = None, 109 | langs: List[str] = None) -> Instance: 110 | fields: Dict[str, Field] = {} 111 | 112 | if self.use_lang_ids: 113 | # use ent_type_ for lang_ids 114 | tokens = TextField([Token(text=w, ent_type_=l) for w,l in zip(words,langs)], self._token_indexers) 115 | else: 116 | tokens = TextField([Token(w) for w in words], self._token_indexers) 117 | fields["tokens"] = tokens 118 | 119 | names = ["upos", "xpos", "feats", "lemmas", "langs"] 120 | all_tags = [upos_tags, xpos_tags, feats, lemma_rules, langs] 121 | for name, field in zip(names, all_tags): 122 | if field: 123 | fields[name] = SequenceLabelField(field, tokens, label_namespace=name) 124 | 125 | if dependencies is not None: 126 | # We don't want to expand the label namespace with an additional dummy token, so we'll 127 | # always give the 'ROOT_HEAD' token a label of 'root'. 
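# (The relation labels and the integer head positions below go into separate label namespaces, "head_tags" and "head_index_tags", so their vocabularies never mix.)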
128 | fields["head_tags"] = SequenceLabelField([x[0] for x in dependencies], 129 | tokens, 130 | label_namespace="head_tags") 131 | fields["head_indices"] = SequenceLabelField([int(x[1]) for x in dependencies], 132 | tokens, 133 | label_namespace="head_index_tags") 134 | 135 | if self.use_separate_feats: 136 | feature_seq = [] 137 | for feat_set in separate_feats: 138 | dimensions = {dimension.replace('[','_').replace(']','_'): "_" for dimension in self.ud_feats_schema} 139 | 140 | if feat_set != "_": 141 | for dimension in feat_set: 142 | dimensions[dimension.replace('[','_').replace(']','_')] = feat_set[dimension] 143 | 144 | feature_seq.append(dimensions) 145 | 146 | for dimension in self.ud_feats_schema: 147 | d = dimension.replace('[','_').replace(']','_') 148 | labels = [f[d] for f in feature_seq] 149 | fields[d] = SequenceLabelField(labels, tokens, label_namespace=d) 150 | 151 | fields["metadata"] = MetadataField({ 152 | "words": words, 153 | "upos_tags": upos_tags, 154 | "xpos_tags": xpos_tags, 155 | "feats": feats, 156 | "lemmas": lemmas, 157 | "lemma_rules": lemma_rules, 158 | "ids": ids, 159 | "multiword_ids": multiword_ids, 160 | "multiword_forms": multiword_forms, 161 | "langs": langs 162 | }) 163 | 164 | return Instance(fields) 165 | -------------------------------------------------------------------------------- /udapter/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.modules.bert_pretrained import UdifyPretrainedBertEmbedder, WordpieceIndexer, PretrainedBertIndexer, BertEmbedder 2 | from udapter.modules.xlm_pretrained import UdifyPretrainedXLMEmbedder, WordpieceIndexer, PretrainedXLMIndexer, XLMEmbedder 3 | from udapter.modules.residual_rnn import ResidualRNN 4 | from udapter.modules.scalar_mix import ScalarMixWithDropout 5 | from udapter.modules.text_field_embedder import UdifyTextFieldEmbedder 6 | from udapter.modules.token_characters_encoder import UdifyTokenCharactersEncoder 7 | from udapter.modules.bert_adapter import UdifyPretrainedBertEmbedder, WordpieceIndexer, PretrainedBertIndexer, BertEmbedder -------------------------------------------------------------------------------- /udapter/modules/bucket_iterator_by_languages.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from collections import deque 4 | from typing import List, Tuple, Iterable, cast, Dict, Deque 5 | 6 | from overrides import overrides 7 | from collections import defaultdict 8 | 9 | from allennlp.common.checks import ConfigurationError 10 | from allennlp.common.util import add_noise_to_dict_values, lazy_groups_of 11 | from allennlp.data.dataset import Batch 12 | from allennlp.data.instance import Instance 13 | from allennlp.data.iterators.data_iterator import DataIterator 14 | from allennlp.data.vocabulary import Vocabulary 15 | 16 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 17 | 18 | 19 | def group_by_lang(instances: List[Instance]) -> List[Instance]: 20 | d = defaultdict(list) 21 | 22 | for item in instances: 23 | d[item['langs'][0]].append(item) 24 | 25 | return [v for v in d.values()] 26 | 27 | 28 | def sort_by_padding(instances: List[Instance], 29 | sorting_keys: List[Tuple[str, str]], # pylint: disable=invalid-sequence-index 30 | vocab: Vocabulary, 31 | padding_noise: float = 0.0) -> List[Instance]: 32 | """ 33 | Sorts the instances by their padding lengths, using the keys in 34 | ``sorting_keys`` (in the order in which they 
are provided). ``sorting_keys`` is a list of 35 | ``(field_name, padding_key)`` tuples. 36 | """ 37 | instances_with_lengths = [] 38 | for instance in instances: 39 | # Make sure instance is indexed before calling .get_padding 40 | instance.index_fields(vocab) 41 | padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths()) 42 | if padding_noise > 0.0: 43 | noisy_lengths = {} 44 | for field_name, field_lengths in padding_lengths.items(): 45 | noisy_lengths[field_name] = add_noise_to_dict_values(field_lengths, padding_noise) 46 | padding_lengths = noisy_lengths 47 | instance_with_lengths = ([padding_lengths[field_name][padding_key] 48 | for (field_name, padding_key) in sorting_keys], 49 | instance) 50 | instances_with_lengths.append(instance_with_lengths) 51 | instances_with_lengths.sort(key=lambda x: x[0]) 52 | return [instance_with_lengths[-1] for instance_with_lengths in instances_with_lengths] 53 | 54 | 55 | @DataIterator.register("lang-bucket") 56 | class BucketIterator(DataIterator): 57 | """ 58 | An iterator which by default, pads batches with respect to the maximum input lengths `per 59 | batch`. Additionally, you can provide a list of field names and padding keys which the dataset 60 | will be sorted by before doing this batching, causing inputs with similar length to be batched 61 | together, making computation more efficient (as less time is wasted on padded elements of the 62 | batch). 63 | 64 | Parameters 65 | ---------- 66 | sorting_keys : List[Tuple[str, str]] 67 | To bucket inputs into batches, we want to group the instances by padding length, so that we 68 | minimize the amount of padding necessary per batch. In order to do this, we need to know 69 | which fields need what type of padding, and in what order. 70 | 71 | For example, ``[("sentence1", "num_tokens"), ("sentence2", "num_tokens"), ("sentence1", 72 | "num_token_characters")]`` would sort a dataset first by the "num_tokens" of the 73 | "sentence1" field, then by the "num_tokens" of the "sentence2" field, and finally by the 74 | "num_token_characters" of the "sentence1" field. TODO(mattg): we should have some 75 | documentation somewhere that gives the standard padding keys used by different fields. 76 | padding_noise : float, optional (default=.1) 77 | When sorting by padding length, we add a bit of noise to the lengths, so that the sorting 78 | isn't deterministic. This parameter determines how much noise we add, as a percentage of 79 | the actual padding value for each instance. 80 | biggest_batch_first : bool, optional (default=False) 81 | This is largely for testing, to see how large of a batch you can safely use with your GPU. 82 | This will let you try out the largest batch that you have in the data `first`, so that if 83 | you're going to run out of memory, you know it early, instead of waiting through the whole 84 | epoch to find out at the end that you're going to crash. 85 | 86 | Note that if you specify ``max_instances_in_memory``, the first batch will only be the 87 | biggest from among the first "max instances in memory" instances. 88 | batch_size : int, optional, (default = 32) 89 | The size of each batch of instances yielded when calling the iterator. 90 | instances_per_epoch : int, optional, (default = None) 91 | See :class:`BasicIterator`. 92 | max_instances_in_memory : int, optional, (default = None) 93 | See :class:`BasicIterator`. 94 | maximum_samples_per_batch : ``Tuple[str, int]``, (default = None) 95 | See :class:`BasicIterator`. 
96 | skip_smaller_batches : bool, optional, (default = False) 97 | When the number of data samples is not dividable by `batch_size`, 98 | some batches might be smaller than `batch_size`. 99 | If set to `True`, those smaller batches will be discarded. 100 | """ 101 | 102 | def __init__(self, 103 | sorting_keys: List[Tuple[str, str]], 104 | padding_noise: float = 0.1, 105 | biggest_batch_first: bool = False, 106 | batch_size: int = 32, 107 | instances_per_epoch: int = None, 108 | max_instances_in_memory: int = None, 109 | cache_instances: bool = False, 110 | track_epoch: bool = False, 111 | maximum_samples_per_batch: Tuple[str, int] = None, 112 | skip_smaller_batches: bool = False, 113 | shuffle_lang_batches: bool = False) -> None: 114 | self.shuffle_lang_batches = shuffle_lang_batches 115 | if not sorting_keys: 116 | raise ConfigurationError("BucketIterator requires sorting_keys to be specified") 117 | 118 | super().__init__(cache_instances=cache_instances, 119 | track_epoch=track_epoch, 120 | batch_size=batch_size, 121 | instances_per_epoch=instances_per_epoch, 122 | max_instances_in_memory=max_instances_in_memory, 123 | maximum_samples_per_batch=maximum_samples_per_batch) 124 | self._sorting_keys = sorting_keys 125 | self._padding_noise = padding_noise 126 | self._biggest_batch_first = biggest_batch_first 127 | self._skip_smaller_batches = skip_smaller_batches 128 | 129 | ''' 130 | @overrides 131 | def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: 132 | all_batches = [] 133 | instances_by_lang = group_by_lang(instances) 134 | for instances in instances_by_lang: 135 | for instance_list in self._memory_sized_lists(instances): 136 | 137 | instance_list = sort_by_padding(instance_list, 138 | self._sorting_keys, 139 | self.vocab, 140 | self._padding_noise) 141 | 142 | batches = [] 143 | excess: Deque[Instance] = deque() 144 | for batch_instances in lazy_groups_of(iter(instance_list), self._batch_size): 145 | for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances, excess): 146 | if self._skip_smaller_batches and len(possibly_smaller_batches) < self._batch_size: 147 | continue 148 | batches.append(Batch(possibly_smaller_batches)) 149 | if excess and (not self._skip_smaller_batches or len(excess) == self._batch_size): 150 | batches.append(Batch(excess)) 151 | 152 | # TODO(brendanr): Add multi-GPU friendly grouping, i.e. group 153 | # num_gpu batches together, shuffle and then expand the groups. 154 | # This guards against imbalanced batches across GPUs. 155 | move_to_front = self._biggest_batch_first and len(batches) > 1 156 | if move_to_front: 157 | # We'll actually pop the last _two_ batches, because the last one might not be full. 158 | last_batch = batches.pop() 159 | penultimate_batch = batches.pop() 160 | if shuffle: 161 | # NOTE: if shuffle is false, the data will still be in a different order 162 | # because of the bucket sorting. 163 | random.shuffle(batches) 164 | if move_to_front: 165 | batches.insert(0, penultimate_batch) 166 | batches.insert(0, last_batch) 167 | 168 | all_batches.extend(batches) 169 | if shuffle: 170 | # NOTE: if shuffle is false, the data will still be in a different order 171 | # because of the bucket sorting. 
172 | random.shuffle(all_batches) 173 | yield from all_batches 174 | ''' 175 | 176 | @overrides 177 | def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: 178 | all_batches = [] 179 | instances_by_lang = group_by_lang(instances) 180 | if shuffle: 181 | random.shuffle(instances_by_lang) 182 | for instances in instances_by_lang: 183 | for instance_list in self._memory_sized_lists(instances): 184 | 185 | instance_list = sort_by_padding(instance_list, 186 | self._sorting_keys, 187 | self.vocab, 188 | self._padding_noise) 189 | 190 | batches = [] 191 | excess: Deque[Instance] = deque() 192 | for batch_instances in lazy_groups_of(iter(instance_list), self._batch_size): 193 | for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances, excess): 194 | if self._skip_smaller_batches and len(possibly_smaller_batches) < self._batch_size: 195 | continue 196 | batches.append(Batch(possibly_smaller_batches)) 197 | if excess and (not self._skip_smaller_batches or len(excess) == self._batch_size): 198 | batches.append(Batch(excess)) 199 | 200 | # TODO(brendanr): Add multi-GPU friendly grouping, i.e. group 201 | # num_gpu batches together, shuffle and then expand the groups. 202 | # This guards against imbalanced batches across GPUs. 203 | move_to_front = self._biggest_batch_first and len(batches) > 1 204 | if move_to_front: 205 | # We'll actually pop the last _two_ batches, because the last one might not be full. 206 | last_batch = batches.pop() 207 | penultimate_batch = batches.pop() 208 | if shuffle: 209 | # NOTE: if shuffle is false, the data will still be in a different order 210 | # because of the bucket sorting. 211 | random.shuffle(batches) 212 | if move_to_front: 213 | batches.insert(0, penultimate_batch) 214 | batches.insert(0, last_batch) 215 | if not self.shuffle_lang_batches: 216 | yield from batches 217 | else: 218 | all_batches.extend(batches) 219 | if self.shuffle_lang_batches: 220 | random.shuffle(all_batches) 221 | yield from all_batches 222 | -------------------------------------------------------------------------------- /udapter/modules/language_emb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LanguageEmbeddings(nn.Module): 6 | def __init__(self, config): 7 | super(LanguageEmbeddings, self).__init__() 8 | self.config = config 9 | 10 | self.language_emb = nn.Embedding(num_embeddings=config.num_languages, embedding_dim=config.low_rank_dim) 11 | self.dropout = nn.Dropout(config.language_emb_dropout) 12 | 13 | def forward(self, lang_ids): 14 | 15 | if lang_ids < 1000: 16 | lang_emb = self.language_emb(lang_ids.clone().detach()) 17 | lang_emb = self.dropout(lang_emb) 18 | else: 19 | lang_emb = torch.mean(self.language_emb.weight, dim=0) 20 | 21 | return lang_emb.view(-1) 22 | -------------------------------------------------------------------------------- /udapter/modules/language_mlp.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | 4 | import lang2vec.lang2vec as l2v 5 | 6 | from typing import List 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class LanguageMLP(nn.Module): 14 | def __init__(self, config, in_language_list: List[str], oov_language_list: List[str], letter_codes: str): 15 | super(LanguageMLP, self).__init__() 16 | 17 | self.config = config 18 | self.do_onehot = config.one_hot 19 | 
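# Convention in this module: ids below 1000 index languages seen in training (in_language_list), while ids >= 1000 denote out-of-vocabulary languages, resolved via oov_language_list[language_id - 1000] below.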
self.in_language_list = in_language_list 20 | self.oov_language_list = oov_language_list 21 | self.letter_codes = letter_codes 22 | 23 | l2v.LETTER_CODES_FILE = letter_codes 24 | l2v.LETTER_CODES = json.load(open(letter_codes)) 25 | 26 | self.l2v_cache = dict() 27 | self._cache_language_features() 28 | 29 | nl_project = config.nl_project 30 | in_features = len(self.in_language_list) + 1 + config.num_language_features if self.do_onehot else config.num_language_features 31 | self.nonlinear_project = nn.Linear(in_features, nl_project) 32 | self.down_project = nn.Linear(nl_project, config.low_rank_dim) 33 | self.activation = F.relu 34 | self.dropout = nn.Dropout(config.language_emb_dropout) 35 | 36 | def forward(self, lang_ids): 37 | 38 | lang_vector = self._encode_language_ids(lang_ids, self.do_onehot) 39 | 40 | lang_emb = self.nonlinear_project(torch.tensor(lang_vector).to(lang_ids.device)) 41 | lang_emb = self.activation(lang_emb) 42 | lang_emb = self.down_project(lang_emb) 43 | lang_emb = self.dropout(lang_emb) 44 | return lang_emb 45 | 46 | def _encode_language_ids(self, language_id: int, do_onehot: bool = False) -> List[int]: 47 | 48 | # language one-hot vector 49 | # 0th id for UNK language 50 | # drop language_id 51 | one_hot = [0 for i in range(len(self.in_language_list) + 1)] 52 | if (random.random() < self.config.language_drop_rate) and self.training: 53 | one_hot[0] = 1 54 | elif language_id >= 1000: 55 | one_hot[0] = 1 56 | else: 57 | one_hot[language_id + 1] = 1 58 | 59 | # feature vector from lang2vec cache 60 | features = self.l2v_cache[self.in_language_list[language_id] if language_id < 1000 else self.oov_language_list[language_id-1000]] 61 | 62 | return features if not do_onehot else one_hot + features 63 | 64 | def _cache_language_features(self): 65 | 66 | features = dict() 67 | for lang in self.in_language_list + self.oov_language_list: 68 | features[lang] = l2v.get_features(l2v.LETTER_CODES[lang], self.config.language_features)[l2v.LETTER_CODES[lang]] 69 | self.l2v_cache = features 70 | -------------------------------------------------------------------------------- /udapter/modules/residual_rnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a "Residual RNN", which adds the input of the RNNs to the output, which tends to work better than conventional 3 | RNN layers. 4 | """ 5 | 6 | from overrides import overrides 7 | import torch 8 | from allennlp.common.checks import ConfigurationError 9 | from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder 10 | from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper 11 | 12 | 13 | @Seq2SeqEncoder.register("udify_residual_rnn") 14 | class ResidualRNN(Seq2SeqEncoder): 15 | """ 16 | Instead of using AllenNLP's default PyTorch wrapper for seq2seq layers, we would like 17 | to apply intermediate logic between each layer, so we create a new wrapper for each layer, 18 | with residual connections between. 
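    A minimal usage sketch (shapes illustrative only): the forward and backward states of each bidirectional layer are summed, then added residually to that layer's input.

        >>> import torch
        >>> rnn = ResidualRNN(input_size=256, hidden_size=256, num_layers=2)
        >>> out = rnn(torch.randn(4, 10, 256), torch.ones(4, 10))  # -> (4, 10, 256)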
19 | """ 20 | 21 | def __init__(self, 22 | input_size: int, 23 | hidden_size: int, 24 | num_layers: int = 1, 25 | dropout: float = 0.0, 26 | residual: bool = True, 27 | rnn_type: str = "lstm") -> None: 28 | super(ResidualRNN, self).__init__() 29 | 30 | self._input_size = input_size 31 | self._hidden_size = hidden_size 32 | self._dropout = torch.nn.Dropout(p=dropout) 33 | self._residual = residual 34 | 35 | rnn_type = rnn_type.lower() 36 | if rnn_type == "lstm": 37 | rnn_cell = torch.nn.LSTM 38 | elif rnn_type == "gru": 39 | rnn_cell = torch.nn.GRU 40 | else: 41 | raise ConfigurationError(f"Unknown RNN cell type {rnn_type}") 42 | 43 | layers = [] 44 | for layer_index in range(num_layers): 45 | # Use hidden size on later layers so that the first layer projects and all other layers are residual 46 | input_ = input_size if layer_index == 0 else hidden_size 47 | rnn = rnn_cell(input_, hidden_size, bidirectional=True, batch_first=True) 48 | layer = PytorchSeq2SeqWrapper(rnn) 49 | layers.append(layer) 50 | self.add_module("rnn_layer_{}".format(layer_index), layer) 51 | self._layers = layers 52 | 53 | @overrides 54 | def get_input_dim(self) -> int: 55 | return self._input_size 56 | 57 | @overrides 58 | def get_output_dim(self) -> int: 59 | return self._hidden_size 60 | 61 | @overrides 62 | def is_bidirectional(self): 63 | return True 64 | 65 | def forward(self, inputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ 66 | hidden = inputs 67 | for i, layer in enumerate(self._layers): 68 | encoded = layer(hidden, mask) 69 | # Sum the backward and forward states to allow residual connections 70 | encoded = encoded[:, :, :self._hidden_size] + encoded[:, :, self._hidden_size:] 71 | 72 | projecting = i == 0 and self._input_size != self._hidden_size 73 | if self._residual and not projecting: 74 | hidden = hidden + self._dropout(encoded) 75 | else: 76 | hidden = self._dropout(encoded) 77 | 78 | return hidden 79 | -------------------------------------------------------------------------------- /udapter/modules/scalar_mix.py: -------------------------------------------------------------------------------- 1 | """ 2 | The dot-product "Layer Attention" that is applied to the layers of BERT, along with layer dropout to reduce overfitting 3 | """ 4 | 5 | from typing import List 6 | 7 | import torch 8 | from torch.nn import ParameterList, Parameter 9 | 10 | from allennlp.common.checks import ConfigurationError 11 | 12 | 13 | class ScalarMixWithDropout(torch.nn.Module): 14 | """ 15 | Computes a parameterised scalar mixture of N tensors, ``mixture = gamma * sum(s_k * tensor_k)`` 16 | where ``s = softmax(w)``, with ``w`` and ``gamma`` scalar parameters. 17 | 18 | If ``do_layer_norm=True`` then apply layer normalization to each tensor before weighting. 19 | 20 | If ``dropout > 0``, then for each scalar weight, adjust its softmax weight mass to 0 with 21 | the dropout probability (i.e., setting the unnormalized weight to -inf). This effectively 22 | should redistribute dropped probability mass to all other weights. 
23 | """ 24 | def __init__(self, 25 | mixture_size: int, 26 | do_layer_norm: bool = False, 27 | initial_scalar_parameters: List[float] = None, 28 | trainable: bool = True, 29 | dropout: float = None, 30 | dropout_value: float = -1e20) -> None: 31 | super(ScalarMixWithDropout, self).__init__() 32 | self.mixture_size = mixture_size 33 | self.do_layer_norm = do_layer_norm 34 | self.dropout = dropout 35 | 36 | if initial_scalar_parameters is None: 37 | initial_scalar_parameters = [0.0] * mixture_size 38 | elif len(initial_scalar_parameters) != mixture_size: 39 | raise ConfigurationError("Length of initial_scalar_parameters {} differs " 40 | "from mixture_size {}".format( 41 | initial_scalar_parameters, mixture_size)) 42 | 43 | self.scalar_parameters = ParameterList( 44 | [Parameter(torch.FloatTensor([initial_scalar_parameters[i]]), 45 | requires_grad=trainable) for i 46 | in range(mixture_size)]) 47 | self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable) 48 | 49 | if self.dropout: 50 | dropout_mask = torch.zeros(len(self.scalar_parameters)) 51 | dropout_fill = torch.empty(len(self.scalar_parameters)).fill_(dropout_value) 52 | self.register_buffer("dropout_mask", dropout_mask) 53 | self.register_buffer("dropout_fill", dropout_fill) 54 | 55 | def forward(self, tensors: List[torch.Tensor], # pylint: disable=arguments-differ 56 | mask: torch.Tensor = None) -> torch.Tensor: 57 | """ 58 | Compute a weighted average of the ``tensors``. The input tensors an be any shape 59 | with at least two dimensions, but must all be the same shape. 60 | 61 | When ``do_layer_norm=True``, the ``mask`` is required input. If the ``tensors`` are 62 | dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned 63 | ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape 64 | ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``. 65 | 66 | When ``do_layer_norm=False`` the ``mask`` is ignored. 
67 | """ 68 | if len(tensors) != self.mixture_size: 69 | raise ConfigurationError("{} tensors were passed, but the module was initialized to " 70 | "mix {} tensors.".format(len(tensors), self.mixture_size)) 71 | 72 | def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked): 73 | tensor_masked = tensor * broadcast_mask 74 | mean = torch.sum(tensor_masked) / num_elements_not_masked 75 | variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked 76 | return (tensor - mean) / torch.sqrt(variance + 1E-12) 77 | 78 | weights = torch.cat([parameter for parameter in self.scalar_parameters]) 79 | 80 | if self.dropout and self.training: 81 | weights = torch.where(self.dropout_mask.uniform_() > self.dropout, weights, self.dropout_fill) 82 | 83 | normed_weights = torch.nn.functional.softmax(weights, dim=0) 84 | normed_weights = torch.split(normed_weights, split_size_or_sections=1) 85 | 86 | if not self.do_layer_norm: 87 | pieces = [] 88 | for weight, tensor in zip(normed_weights, tensors): 89 | pieces.append(weight * tensor) 90 | return self.gamma * sum(pieces) 91 | 92 | else: 93 | mask_float = mask.float() 94 | broadcast_mask = mask_float.unsqueeze(-1) 95 | input_dim = tensors[0].size(-1) 96 | num_elements_not_masked = torch.sum(mask_float) * input_dim 97 | 98 | pieces = [] 99 | for weight, tensor in zip(normed_weights, tensors): 100 | pieces.append(weight * _do_layer_norm(tensor, 101 | broadcast_mask, num_elements_not_masked)) 102 | return self.gamma * sum(pieces) 103 | -------------------------------------------------------------------------------- /udapter/modules/text_field_embedder.py: -------------------------------------------------------------------------------- 1 | """ 2 | A modification to AllenNLP's TextFieldEmbedder 3 | """ 4 | 5 | from typing import Dict, List, Optional 6 | import warnings 7 | 8 | import torch 9 | from overrides import overrides 10 | from torch.nn.modules.linear import Linear 11 | 12 | from allennlp.common import Params 13 | from allennlp.common.checks import ConfigurationError 14 | from allennlp.data import Vocabulary 15 | from allennlp.modules.text_field_embedders import TextFieldEmbedder 16 | from allennlp.modules.time_distributed import TimeDistributed 17 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder 18 | 19 | 20 | @TextFieldEmbedder.register("udify_embedder") 21 | class UdifyTextFieldEmbedder(TextFieldEmbedder): 22 | """ 23 | This is a ``TextFieldEmbedder`` that, instead of concatenating embeddings together 24 | as in ``BasicTextFieldEmbedder``, sums them together. It also optionally allows 25 | dropout to be applied to its output. 
See AllenNLP's basic_text_field_embedder.py, 26 | https://github.com/allenai/allennlp/blob/a75cb0a9ddedb23801db74629c3a3017dafb375e/ 27 | allennlp/modules/text_field_embedders/ 28 | basic_text_field_embedder.py 29 | """ 30 | 31 | def __init__(self, 32 | token_embedders: Dict[str, TokenEmbedder], 33 | output_dim: Optional[int] = None, 34 | sum_embeddings: List[str] = None, 35 | embedder_to_indexer_map: Dict[str, List[str]] = None, 36 | allow_unmatched_keys: bool = False, 37 | dropout: float = 0.0) -> None: 38 | super(UdifyTextFieldEmbedder, self).__init__() 39 | self._output_dim = output_dim 40 | self._token_embedders = token_embedders 41 | self._embedder_to_indexer_map = embedder_to_indexer_map 42 | for key, embedder in token_embedders.items(): 43 | name = 'token_embedder_%s' % key 44 | self.add_module(name, embedder) 45 | self._allow_unmatched_keys = allow_unmatched_keys 46 | self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else lambda x: x 47 | self._sum_embeddings = sum_embeddings if sum_embeddings is not None else [] 48 | 49 | hidden_dim = 0 50 | for embedder in self._token_embedders.values(): 51 | hidden_dim += embedder.get_output_dim() 52 | 53 | if len(self._sum_embeddings) > 1: 54 | for key in self._sum_embeddings[1:]: 55 | hidden_dim -= self._token_embedders[key].get_output_dim() 56 | 57 | if self._output_dim is None: 58 | self._projection_layer = None 59 | self._output_dim = hidden_dim 60 | else: 61 | self._projection_layer = Linear(hidden_dim, self._output_dim) 62 | 63 | @overrides 64 | def get_output_dim(self) -> int: 65 | return self._output_dim 66 | 67 | def forward(self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims: int = 0) -> torch.Tensor: 68 | embedder_keys = self._token_embedders.keys() 69 | input_keys = text_field_input.keys() 70 | 71 | # Check for unmatched keys 72 | if not self._allow_unmatched_keys: 73 | if embedder_keys < input_keys: 74 | # token embedder keys are a strict subset of text field input keys. 75 | message = (f"Your text field is generating more keys ({list(input_keys)}) " 76 | f"than you have token embedders ({list(embedder_keys)}. " 77 | f"If you are using a token embedder that requires multiple keys " 78 | f"(for example, the OpenAI Transformer embedder or the BERT embedder) " 79 | f"you need to add allow_unmatched_keys = True " 80 | f"(and likely an embedder_to_indexer_map) to your " 81 | f"BasicTextFieldEmbedder configuration. " 82 | f"Otherwise, you should check that there is a 1:1 embedding " 83 | f"between your token indexers and token embedders.") 84 | raise ConfigurationError(message) 85 | 86 | elif self._token_embedders.keys() != text_field_input.keys(): 87 | # some other mismatch 88 | message = "Mismatched token keys: %s and %s" % (str(self._token_embedders.keys()), 89 | str(text_field_input.keys())) 90 | raise ConfigurationError(message) 91 | 92 | def embed(key): 93 | # If we pre-specified a mapping explictly, use that. 94 | if self._embedder_to_indexer_map is not None: 95 | tensors = [text_field_input[indexer_key] for 96 | indexer_key in self._embedder_to_indexer_map[key]] 97 | else: 98 | # otherwise, we assume the mapping between indexers and embedders 99 | # is bijective and just use the key directly. 100 | tensors = [text_field_input[key]] 101 | # Note: need to use getattr here so that the pytorch voodoo 102 | # with submodules works with multiple GPUs. 
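# (These were registered in __init__ via add_module("token_embedder_<key>", ...), which is why the lookup goes through getattr rather than the plain _token_embedders dict.)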
103 | embedder = getattr(self, 'token_embedder_{}'.format(key)) 104 | for _ in range(num_wrapping_dims): 105 | embedder = TimeDistributed(embedder) 106 | token_vectors = embedder(*tensors) 107 | 108 | return token_vectors 109 | 110 | embedded_representations = [] 111 | keys = sorted(embedder_keys) 112 | 113 | sum_embed = [] 114 | for key in self._sum_embeddings: 115 | token_vectors = embed(key) 116 | sum_embed.append(token_vectors) 117 | keys.remove(key) 118 | 119 | if sum_embed: 120 | embedded_representations.append(sum(sum_embed)) 121 | 122 | for key in keys: 123 | token_vectors = embed(key) 124 | embedded_representations.append(token_vectors) 125 | 126 | combined_embeddings = self._dropout(torch.cat(embedded_representations, dim=-1)) 127 | 128 | if self._projection_layer is not None: 129 | combined_embeddings = self._dropout(self._projection_layer(combined_embeddings)) 130 | 131 | return combined_embeddings 132 | 133 | # This is some unusual logic, it needs a custom from_params. 134 | @classmethod 135 | def from_params(cls, vocab: Vocabulary, params: Params) -> 'UdifyTextFieldEmbedder': # type: ignore 136 | # pylint: disable=arguments-differ,bad-super-call 137 | 138 | # The original `from_params` for this class was designed in a way that didn't agree 139 | # with the constructor. The constructor wants a 'token_embedders' parameter that is a 140 | # `Dict[str, TokenEmbedder]`, but the original `from_params` implementation expected those 141 | # key-value pairs to be top-level in the params object. 142 | # 143 | # This breaks our 'configuration wizard' and configuration checks. Hence, going forward, 144 | # the params need a 'token_embedders' key so that they line up with what the constructor wants. 145 | # For now, the old behavior is still supported, but produces a DeprecationWarning. 146 | 147 | embedder_to_indexer_map = params.pop("embedder_to_indexer_map", None) 148 | if embedder_to_indexer_map is not None: 149 | embedder_to_indexer_map = embedder_to_indexer_map.as_dict(quiet=True) 150 | allow_unmatched_keys = params.pop_bool("allow_unmatched_keys", False) 151 | 152 | token_embedder_params = params.pop('token_embedders', None) 153 | 154 | dropout = params.pop_float("dropout", 0.0) 155 | 156 | output_dim = params.pop_int("output_dim", None) 157 | sum_embeddings = params.pop("sum_embeddings", None) 158 | 159 | if token_embedder_params is not None: 160 | # New way: explicitly specified, so use it. 
161 | token_embedders = { 162 | name: TokenEmbedder.from_params(subparams, vocab=vocab) 163 | for name, subparams in token_embedder_params.items() 164 | } 165 | 166 | else: 167 | # Warn that the original behavior is deprecated 168 | warnings.warn(DeprecationWarning("the token embedders for BasicTextFieldEmbedder should now " 169 | "be specified as a dict under the 'token_embedders' key, " 170 | "not as top-level key-value pairs")) 171 | 172 | token_embedders = {} 173 | keys = list(params.keys()) 174 | for key in keys: 175 | embedder_params = params.pop(key) 176 | token_embedders[key] = TokenEmbedder.from_params(vocab=vocab, params=embedder_params) 177 | 178 | params.assert_empty(cls.__name__) 179 | return cls(token_embedders, output_dim, sum_embeddings, embedder_to_indexer_map, allow_unmatched_keys, dropout) 180 | -------------------------------------------------------------------------------- /udapter/modules/token_characters_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | A modification to AllenNLP's TokenCharactersEncoder 3 | """ 4 | 5 | import torch 6 | 7 | from allennlp.common import Params 8 | from allennlp.data.vocabulary import Vocabulary 9 | from allennlp.modules.token_embedders.embedding import Embedding 10 | from allennlp.modules.seq2vec_encoders.seq2vec_encoder import Seq2VecEncoder 11 | from allennlp.modules.time_distributed import TimeDistributed 12 | from allennlp.modules.token_embedders import TokenEmbedder 13 | 14 | 15 | @TokenEmbedder.register("udify_character_encoding") 16 | class UdifyTokenCharactersEncoder(TokenEmbedder): 17 | """ 18 | Like AllenNLP's TokenCharactersEncoder, but applies dropout to the embeddings. 19 | https://github.com/allenai/allennlp/blob/7dbd7d34a2f1390d1ff01f2e9ed6f8bdaaef77eb/ 20 | allennlp/modules/token_embedders/token_characters_encoder.py 21 | """ 22 | def __init__(self, embedding: Embedding, encoder: Seq2VecEncoder, dropout: float = 0.0) -> None: 23 | super(UdifyTokenCharactersEncoder, self).__init__() 24 | self._embedding = TimeDistributed(embedding) 25 | self._encoder = TimeDistributed(encoder) 26 | if dropout > 0: 27 | self._dropout = torch.nn.Dropout(p=dropout) 28 | else: 29 | self._dropout = lambda x: x 30 | 31 | def get_output_dim(self) -> int: 32 | return self._encoder._module.get_output_dim() # pylint: disable=protected-access 33 | 34 | def forward(self, token_characters: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ 35 | mask = (token_characters != 0).long() 36 | return self._encoder(self._dropout(self._embedding(token_characters)), mask) 37 | 38 | # The setdefault requires a custom from_params 39 | @classmethod 40 | def from_params(cls, vocab: Vocabulary, params: Params) -> 'UDTokenCharactersEncoder': # type: ignore 41 | # pylint: disable=arguments-differ 42 | embedding_params: Params = params.pop("embedding") 43 | # Embedding.from_params() uses "tokens" as the default namespace, but we need to change 44 | # that to be "token_characters" by default. 
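# (setdefault only fills in "token_characters" when the config does not specify a vocab_namespace of its own.)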
45 | embedding_params.setdefault("vocab_namespace", "token_characters") 46 | embedding = Embedding.from_params(vocab, embedding_params) 47 | encoder_params: Params = params.pop("encoder") 48 | encoder = Seq2VecEncoder.from_params(encoder_params) 49 | dropout = params.pop_float("dropout", 0.0) 50 | params.assert_empty(cls.__name__) 51 | return cls(embedding, encoder, dropout) -------------------------------------------------------------------------------- /udapter/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.optimizers.ulmfit_sqrt import UlmfitSqrtLR 2 | -------------------------------------------------------------------------------- /udapter/optimizers/ulmfit_sqrt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Special LR scheduler for fine-tuning Transformer networks 3 | """ 4 | 5 | from overrides import overrides 6 | import torch 7 | import logging 8 | 9 | from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @LearningRateScheduler.register("ulmfit_sqrt") 15 | class UlmfitSqrtLR(LearningRateScheduler): 16 | """Implements a combination of ULMFiT (slanted triangular) with Noam sqrt learning rate decay""" 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | model_size: int, 20 | warmup_steps: int, 21 | start_step: int = 0, 22 | factor: float = 100, 23 | steepness: float = 0.5, 24 | last_epoch: int = -1, 25 | gradual_unfreezing: bool = False, 26 | discriminative_fine_tuning: bool = False, 27 | decay_factor: float = 0.38) -> None: 28 | self.warmup_steps = warmup_steps + start_step 29 | self.start_step = start_step 30 | self.factor = factor 31 | self.steepness = steepness 32 | self.model_size = model_size 33 | self.gradual_unfreezing = gradual_unfreezing 34 | self.freezing_current = self.gradual_unfreezing 35 | 36 | if self.gradual_unfreezing: 37 | assert not optimizer.param_groups[-1]["params"], \ 38 | "The default group should be empty." 39 | if self.gradual_unfreezing or discriminative_fine_tuning: 40 | assert len(optimizer.param_groups) > 2, \ 41 | "There should be at least 3 param_groups (2 + empty default group)" \ 42 | " for gradual unfreezing / discriminative fine-tuning to make sense." 43 | 44 | super().__init__(optimizer, last_epoch=last_epoch) 45 | 46 | if discriminative_fine_tuning: 47 | # skip the last param_group if it is has no parameters 48 | exponent = 0 49 | for i in range(len(self.base_values) - 1, -1, -1): 50 | param_group = optimizer.param_groups[i] 51 | if param_group['params']: 52 | param_group['lr'] = self.base_values[i] * decay_factor ** exponent 53 | self.base_values[i] = param_group['lr'] 54 | exponent += 1 55 | 56 | @overrides 57 | def step(self, metric: float = None, epoch: int = None) -> None: 58 | if self.gradual_unfreezing: 59 | # the method is called once when initialising before the 60 | # first epoch (epoch -1) and then always at the end of each 61 | # epoch; so the first time, with epoch id -1, we want to set 62 | # up for epoch #1; the second time, with epoch id 0, 63 | # we want to set up for epoch #2, etc. 64 | num_layers_to_unfreeze = epoch + 2 if epoch > -1 else 1 65 | if num_layers_to_unfreeze >= len(self.optimizer.param_groups) - 1: 66 | logger.info('Gradual unfreezing finished. Training all layers.') 67 | self.freezing_current = False 68 | else: 69 | logger.info(f'Gradual unfreezing. 
Training only the top {num_layers_to_unfreeze} layers.') 70 | for i, param_group in enumerate(reversed(self.optimizer.param_groups)): 71 | for param in param_group["params"]: 72 | # i = 0 is the default group; we care about i > 0 73 | param.requires_grad = bool(i <= num_layers_to_unfreeze) 74 | 75 | def step_batch(self, batch_num_total: int = None) -> None: 76 | if batch_num_total is None: 77 | self.last_epoch += 1 # type: ignore 78 | else: 79 | self.last_epoch = batch_num_total 80 | for param_group, learning_rate in zip(self.optimizer.param_groups, self.get_values()): 81 | param_group['lr'] = learning_rate 82 | 83 | def get_values(self): 84 | if self.freezing_current: 85 | # If parameters are still frozen, keep the base learning rates 86 | return self.base_values 87 | 88 | # This computes the Noam Sqrt LR decay based on the current step: scale = factor * model_size^(-0.5) * min(step^(-steepness), step * warmup_steps^(-steepness - 1)) 89 | step = max(self.last_epoch - self.start_step, 1) 90 | scale = self.factor * (self.model_size ** (-0.5) * 91 | min(step ** (-self.steepness), step * self.warmup_steps ** (-self.steepness - 1))) 92 | return [scale * lr for lr in self.base_values] 93 | -------------------------------------------------------------------------------- /udapter/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.predictors.udify_predictor import UdifyPredictor 2 | -------------------------------------------------------------------------------- /udapter/predictors/udapter_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main UDapter predictor to output CoNLL-U files 3 | """ 4 | 5 | from typing import List 6 | from overrides import overrides 7 | 8 | from allennlp.common.util import JsonDict, sanitize 9 | from allennlp.data import DatasetReader, Instance 10 | from allennlp.models import Model 11 | from allennlp.predictors.predictor import Predictor 12 | 13 | 14 | @Predictor.register("udapter_predictor") 15 | class UdifyPredictor(Predictor): 16 | """ 17 | Predictor for a UDapter model that takes in a sentence and returns 18 | a single set of CoNLL-U annotations for it.
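    A minimal usage sketch (the archive path here is hypothetical):

        >>> from allennlp.models.archival import load_archive
        >>> from allennlp.predictors.predictor import Predictor
        >>> archive = load_archive("logs/udapter/model.tar.gz")
        >>> predictor = Predictor.from_archive(archive, "udapter_predictor")
        >>> predictor.predict("Colorless green ideas sleep furiously")  # -> JsonDict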
19 | """ 20 | def __init__(self, model: Model, dataset_reader: DatasetReader) -> None: 21 | super().__init__(model, dataset_reader) 22 | 23 | def predict(self, sentence: str) -> JsonDict: 24 | return self.predict_json({"sentence": sentence}) 25 | 26 | @overrides 27 | def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]: 28 | if "@@UNKNOWN@@" not in self._model.vocab._token_to_index["lemmas"]: 29 | # Handle cases where the labels are present in the test set but not training set 30 | for instance in instances: 31 | self._predict_unknown(instance) 32 | outputs = self._model.forward_on_instances(instances) 33 | return sanitize(outputs) 34 | 35 | @overrides 36 | def predict_instance(self, instance: Instance) -> JsonDict: 37 | if "@@UNKNOWN@@" not in self._model.vocab._token_to_index["lemmas"]: 38 | # Handle cases where the labels are present in the test set but not training set 39 | self._predict_unknown(instance) 40 | outputs = self._model.forward_on_instance(instance) 41 | return sanitize(outputs) 42 | 43 | def _predict_unknown(self, instance: Instance): 44 | """ 45 | Maps each unknown label in each namespace to a default token 46 | :param instance: the instance containing a list of labels for each namespace 47 | """ 48 | def replace_tokens(instance: Instance, namespace: str, token: str): 49 | if namespace not in instance.fields: 50 | return 51 | 52 | instance.fields[namespace].labels = [label 53 | if label in self._model.vocab._token_to_index[namespace] 54 | else token 55 | for label in instance.fields[namespace].labels] 56 | 57 | replace_tokens(instance, "lemmas", "↓0;d¦") 58 | replace_tokens(instance, "feats", "_") 59 | replace_tokens(instance, "xpos", "_") 60 | replace_tokens(instance, "upos", "NOUN") 61 | replace_tokens(instance, "head_tags", "case") 62 | 63 | @overrides 64 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 65 | """ 66 | Expects JSON that looks like ``{"sentence": "..."}``. 67 | Runs the underlying model, and adds the ``"words"`` to the output. 
68 | """ 69 | sentence = json_dict["sentence"] 70 | tokens = sentence.split() 71 | return self._dataset_reader.text_to_instance(tokens) 72 | 73 | @overrides 74 | def dump_line(self, outputs: JsonDict) -> str: 75 | word_count = len([word for word in outputs["words"]]) 76 | lines = zip(*[outputs[k] if k in outputs else ["_"] * word_count 77 | for k in ["ids", "words", "lemmas", "upos", "xpos", "feats", 78 | "predicted_heads", "predicted_dependencies"]]) 79 | 80 | multiword_map = None 81 | if outputs["multiword_ids"]: 82 | multiword_ids = [[id] + [int(x) for x in id.split("-")] for id in outputs["multiword_ids"]] 83 | multiword_forms = outputs["multiword_forms"] 84 | multiword_map = {start: (id_, form) for (id_, start, end), form in zip(multiword_ids, multiword_forms)} 85 | 86 | output_lines = [] 87 | for i, line in enumerate(lines): 88 | line = [str(l) for l in line] 89 | 90 | # Handle multiword tokens 91 | if multiword_map and i+1 in multiword_map: 92 | id_, form = multiword_map[i+1] 93 | row = f"{id_}\t{form}" + "".join(["\t_"] * 8) 94 | output_lines.append(row) 95 | 96 | row = "\t".join(line) + "".join(["\t_"] * 2) + "\t" + outputs["langs"][i] 97 | output_lines.append(row) 98 | 99 | return "\n".join(output_lines) + "\n\n" 100 | -------------------------------------------------------------------------------- /udapter/predictors/udify_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main UDify predictor to output conllu files 3 | """ 4 | 5 | from typing import List 6 | from overrides import overrides 7 | 8 | from allennlp.common.util import JsonDict, sanitize 9 | from allennlp.data import DatasetReader, Instance 10 | from allennlp.models import Model 11 | from allennlp.predictors.predictor import Predictor 12 | 13 | 14 | @Predictor.register("udify_predictor") 15 | class UdifyPredictor(Predictor): 16 | """ 17 | Predictor for a UDify model that takes in a sentence and returns 18 | a single set conllu annotations for it. 
19 | """ 20 | def __init__(self, model: Model, dataset_reader: DatasetReader) -> None: 21 | super().__init__(model, dataset_reader) 22 | 23 | def predict(self, sentence: str) -> JsonDict: 24 | return self.predict_json({"sentence": sentence}) 25 | 26 | @overrides 27 | def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]: 28 | if "@@UNKNOWN@@" not in self._model.vocab._token_to_index["lemmas"]: 29 | # Handle cases where the labels are present in the test set but not training set 30 | for instance in instances: 31 | self._predict_unknown(instance) 32 | outputs = self._model.forward_on_instances(instances) 33 | return sanitize(outputs) 34 | 35 | @overrides 36 | def predict_instance(self, instance: Instance) -> JsonDict: 37 | if "@@UNKNOWN@@" not in self._model.vocab._token_to_index["lemmas"]: 38 | # Handle cases where the labels are present in the test set but not training set 39 | self._predict_unknown(instance) 40 | outputs = self._model.forward_on_instance(instance) 41 | return sanitize(outputs) 42 | 43 | def _predict_unknown(self, instance: Instance): 44 | """ 45 | Maps each unknown label in each namespace to a default token 46 | :param instance: the instance containing a list of labels for each namespace 47 | """ 48 | def replace_tokens(instance: Instance, namespace: str, token: str): 49 | if namespace not in instance.fields: 50 | return 51 | 52 | instance.fields[namespace].labels = [label 53 | if label in self._model.vocab._token_to_index[namespace] 54 | else token 55 | for label in instance.fields[namespace].labels] 56 | 57 | replace_tokens(instance, "lemmas", "↓0;d¦") 58 | replace_tokens(instance, "feats", "_") 59 | replace_tokens(instance, "xpos", "_") 60 | replace_tokens(instance, "upos", "NOUN") 61 | replace_tokens(instance, "head_tags", "case") 62 | 63 | @overrides 64 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 65 | """ 66 | Expects JSON that looks like ``{"sentence": "..."}``. 67 | Runs the underlying model, and adds the ``"words"`` to the output. 
68 | """ 69 | sentence = json_dict["sentence"] 70 | tokens = sentence.split() 71 | return self._dataset_reader.text_to_instance(tokens) 72 | 73 | @overrides 74 | def dump_line(self, outputs: JsonDict) -> str: 75 | word_count = len([word for word in outputs["words"]]) 76 | lines = zip(*[outputs[k] if k in outputs else ["_"] * word_count 77 | for k in ["ids", "words", "lemmas", "upos", "xpos", "feats", 78 | "predicted_heads", "predicted_dependencies"]]) 79 | 80 | multiword_map = None 81 | if outputs["multiword_ids"]: 82 | multiword_ids = [[id] + [int(x) for x in id.split("-")] for id in outputs["multiword_ids"]] 83 | multiword_forms = outputs["multiword_forms"] 84 | multiword_map = {start: (id_, form) for (id_, start, end), form in zip(multiword_ids, multiword_forms)} 85 | 86 | output_lines = [] 87 | for i, line in enumerate(lines): 88 | line = [str(l) for l in line] 89 | 90 | # Handle multiword tokens 91 | if multiword_map and i+1 in multiword_map: 92 | id_, form = multiword_map[i+1] 93 | row = f"{id_}\t{form}" + "".join(["\t_"] * 8) 94 | output_lines.append(row) 95 | 96 | row = "\t".join(line) + "".join(["\t_"] * 2) 97 | output_lines.append(row) 98 | 99 | return "\n".join(output_lines) + "\n\n" 100 | -------------------------------------------------------------------------------- /udapter/udapter_models/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.udapter_models.udapter_model import UdapterModel 2 | from udapter.udapter_models.dependency_decoder import DependencyDecoder 3 | from udapter.udapter_models.tag_decoder import TagDecoder -------------------------------------------------------------------------------- /udapter/udapter_models/bilinear_matrix_attention.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | 5 | from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention 6 | from allennlp.nn import Activation 7 | 8 | 9 | @MatrixAttention.register("udapter-bilinear") 10 | class BilinearMatrixAttentionWithPGN(MatrixAttention): 11 | """ 12 | Computes attention between two matrices using a bilinear attention function. This function has 13 | a matrix of weights ``W`` and a bias ``b``, and the similarity between the two matrices ``X`` 14 | and ``Y`` is computed as ``X W Y^T + b``. 15 | 16 | Parameters 17 | ---------- 18 | matrix_1_dim : ``int`` 19 | The dimension of the matrix ``X``, described above. This is ``X.size()[-1]`` - the length 20 | of the vector that will go into the similarity computation. We need this so we can build 21 | the weight matrix correctly. 22 | matrix_2_dim : ``int`` 23 | The dimension of the matrix ``Y``, described above. This is ``Y.size()[-1]`` - the length 24 | of the vector that will go into the similarity computation. We need this so we can build 25 | the weight matrix correctly. 26 | activation : ``Activation``, optional (default=linear (i.e. no activation)) 27 | An activation function applied after the ``X W Y^T + b`` calculation. Default is no 28 | activation. 29 | use_input_biases : ``bool``, optional (default = False) 30 | If True, we add biases to the inputs such that the final computation 31 | is equivalent to the original bilinear matrix multiplication plus a 32 | projection of both inputs. 33 | label_dim : ``int``, optional (default = 1) 34 | The number of output classes. 
Typically in an attention setting this will be one, 35 | but this parameter allows this class to function as an equivalent to ``torch.nn.Bilinear`` 36 | for matrices, rather than vectors. 37 | """ 38 | def __init__(self, 39 | in_params: int, 40 | matrix_1_dim: int, 41 | matrix_2_dim: int, 42 | activation: Activation = None, 43 | use_input_biases: bool = False, 44 | label_dim: int = 1) -> None: 45 | super().__init__() 46 | 47 | self.in_params = in_params 48 | 49 | if use_input_biases: 50 | matrix_1_dim += 1 51 | matrix_2_dim += 1 52 | 53 | if label_dim == 1: 54 | self._weight_matrix = Parameter(torch.Tensor(in_params, matrix_1_dim, matrix_2_dim)) 55 | else: 56 | self._weight_matrix = Parameter(torch.Tensor(in_params, label_dim, matrix_1_dim, matrix_2_dim)) 57 | 58 | self._bias = Parameter(torch.Tensor(1)) 59 | self._activation = activation or Activation.by_name('linear')() 60 | self._use_input_biases = use_input_biases 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | torch.nn.init.xavier_uniform_(self._weight_matrix) 65 | self._bias.data.fill_(0) 66 | 67 | @overrides 68 | def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor, pgn_vector: torch.Tensor) -> torch.Tensor: 69 | 70 | if self._use_input_biases: 71 | bias1 = matrix_1.new_ones(matrix_1.size()[:-1] + (1,)) 72 | bias2 = matrix_2.new_ones(matrix_2.size()[:-1] + (1,)) 73 | 74 | matrix_1 = torch.cat([matrix_1, bias1], -1) 75 | matrix_2 = torch.cat([matrix_2, bias2], -1) 76 | 77 | ex_shape = self._weight_matrix.shape 78 | weight = torch.matmul(pgn_vector, self._weight_matrix.view(self.in_params, -1)).view(ex_shape[1:]) 79 | 80 | if weight.dim() == 2: 81 | weight = weight.unsqueeze(0) 82 | intermediate = torch.matmul(matrix_1.unsqueeze(1), weight) 83 | final = torch.matmul(intermediate, matrix_2.unsqueeze(1).transpose(2, 3)) 84 | return self._activation(final.squeeze(1) + self._bias) 85 | -------------------------------------------------------------------------------- /udapter/udapter_models/feedforward.py: -------------------------------------------------------------------------------- 1 | """ 2 | A feed-forward neural network. 3 | """ 4 | from typing import List, Union 5 | 6 | import math 7 | import torch 8 | 9 | from allennlp.common import FromParams 10 | from allennlp.common.checks import ConfigurationError 11 | from allennlp.nn import Activation 12 | 13 | from udapter.udapter_models.linear import LinearWithPGN 14 | 15 | 16 | class FeedForwardWithPGN(torch.nn.Module, FromParams): 17 | """ 18 | This ``Module`` is a feed-forward neural network, just a sequence of ``Linear`` layers with 19 | activation functions in between. 20 | 21 | Parameters 22 | ---------- 23 | input_dim : ``int`` 24 | The dimensionality of the input. We assume the input has shape ``(batch_size, input_dim)``. 25 | num_layers : ``int`` 26 | The number of ``Linear`` layers to apply to the input. 27 | hidden_dims : ``Union[int, List[int]]`` 28 | The output dimension of each of the ``Linear`` layers. If this is a single ``int``, we use 29 | it for all ``Linear`` layers. If it is a ``List[int]``, ``len(hidden_dims)`` must be 30 | ``num_layers``. 31 | activations : ``Union[Callable, List[Callable]]`` 32 | The activation function to use after each ``Linear`` layer. If this is a single function, 33 | we use it after all ``Linear`` layers. If it is a ``List[Callable]``, 34 | ``len(activations)`` must be ``num_layers``. 35 | dropout : ``Union[float, List[float]]``, optional 36 | If given, we will apply this amount of dropout after each layer. 
/udapter/udapter_models/feedforward.py:
--------------------------------------------------------------------------------
1 | """
2 | A feed-forward neural network.
3 | """
4 | from typing import List, Union
5 | 
6 | import math
7 | import torch
8 | 
9 | from allennlp.common import FromParams
10 | from allennlp.common.checks import ConfigurationError
11 | from allennlp.nn import Activation
12 | 
13 | from udapter.udapter_models.linear import LinearWithPGN
14 | 
15 | 
16 | class FeedForwardWithPGN(torch.nn.Module, FromParams):
17 |     """
18 |     This ``Module`` is a feed-forward neural network, just a sequence of ``Linear`` layers with
19 |     activation functions in between. Each layer is a ``LinearWithPGN`` whose weights are generated from a ``param_dim``-dimensional parameter-generator (language) vector.
20 | 
21 |     Parameters
22 |     ----------
23 |     input_dim : ``int``
24 |         The dimensionality of the input. We assume the input has shape ``(batch_size, input_dim)``.
25 |     num_layers : ``int``
26 |         The number of ``Linear`` layers to apply to the input.
27 |     hidden_dims : ``Union[int, List[int]]``
28 |         The output dimension of each of the ``Linear`` layers. If this is a single ``int``, we use
29 |         it for all ``Linear`` layers. If it is a ``List[int]``, ``len(hidden_dims)`` must be
30 |         ``num_layers``.
31 |     activations : ``Union[Activation, List[Activation]]``
32 |         The activation function to use after each ``Linear`` layer. If this is a single function,
33 |         we use it after all ``Linear`` layers. If it is a ``List[Activation]``,
34 |         ``len(activations)`` must be ``num_layers``.
35 |     dropout : ``Union[float, List[float]]``, optional
36 |         If given, we will apply this amount of dropout after each layer. Semantics of ``float``
37 |         versus ``List[float]`` is the same as with other parameters.
38 |     """
39 |     def __init__(self,
40 |                  param_dim: int,
41 |                  input_dim: int,
42 |                  num_layers: int,
43 |                  hidden_dims: Union[int, List[int]],
44 |                  activations: Union[Activation, List[Activation]],
45 |                  dropout: Union[float, List[float]] = 0.0) -> None:
46 | 
47 |         super(FeedForwardWithPGN, self).__init__()
48 |         if not isinstance(hidden_dims, list):
49 |             hidden_dims = [hidden_dims] * num_layers  # type: ignore
50 |         if not isinstance(activations, list):
51 |             activations = [activations] * num_layers  # type: ignore
52 |         if not isinstance(dropout, list):
53 |             dropout = [dropout] * num_layers  # type: ignore
54 |         if len(hidden_dims) != num_layers:
55 |             raise ConfigurationError("len(hidden_dims) (%d) != num_layers (%d)" %
56 |                                      (len(hidden_dims), num_layers))
57 |         if len(activations) != num_layers:
58 |             raise ConfigurationError("len(activations) (%d) != num_layers (%d)" %
59 |                                      (len(activations), num_layers))
60 |         if len(dropout) != num_layers:
61 |             raise ConfigurationError("len(dropout) (%d) != num_layers (%d)" %
62 |                                      (len(dropout), num_layers))
63 |         self._activations = activations
64 |         input_dims = [input_dim] + hidden_dims[:-1]
65 |         linear_layers = []
66 |         param_dims = [param_dim] * num_layers  # one parameter-generator dim per layer, so zip() builds all num_layers layers
67 |         for layer_param_dim, layer_input_dim, layer_output_dim in zip(param_dims, input_dims, hidden_dims):
68 |             linear_layers.append(LinearWithPGN(layer_param_dim, layer_input_dim, layer_output_dim))
69 |         self._linear_layers = torch.nn.ModuleList(linear_layers)
70 |         dropout_layers = [torch.nn.Dropout(p=value) for value in dropout]
71 |         self._dropout = torch.nn.ModuleList(dropout_layers)
72 |         self._output_dim = hidden_dims[-1]
73 |         self.input_dim = input_dim
74 | 
75 |     def get_output_dim(self):
76 |         return self._output_dim
77 | 
78 |     def get_input_dim(self):
79 |         return self.input_dim
80 | 
81 |     def forward(self, inputs: torch.Tensor, pgn_vector: torch.Tensor) -> torch.Tensor:
82 |         # pylint: disable=arguments-differ
83 |         output = inputs
84 |         for layer, activation, dropout in zip(self._linear_layers, self._activations, self._dropout):
85 |             output = dropout(activation(layer(output, pgn_vector)))
86 |         return output
87 | 
88 | 
--------------------------------------------------------------------------------
/udapter/udapter_models/linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parameter import Parameter
6 | from torch.nn import functional as F
7 | from torch.nn import init
8 | 
9 | 
10 | class LinearWithPGN(torch.nn.Module):  # a Linear layer whose weight and bias are generated from a PGN (language) vector
11 | 
12 |     def __init__(self, in_params, in_features, out_features, bias=True):
13 |         super(LinearWithPGN, self).__init__()
14 |         self.in_params = in_params
15 |         self.in_features = in_features
16 |         self.out_features = out_features
17 |         self.weight = nn.Parameter(torch.Tensor(in_params, out_features, in_features))
18 |         if bias:
19 |             self.bias = nn.Parameter(torch.Tensor(in_params, out_features))
20 |         else:
21 |             self.register_parameter('bias', None)
22 |         self.reset_parameters()
23 | 
24 |     def reset_parameters(self):
25 |         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
26 |         if self.bias is not None:
27 |             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
28 |             bound = 1 / math.sqrt(fan_in)
29 |             nn.init.uniform_(self.bias, -bound, bound)
30 | 
31 |     def forward(self, input, pgn_vector):
32 |         project_w = torch.matmul(pgn_vector, self.weight.view(self.in_params, -1)).view(self.out_features, self.in_features)
33 |         project_b = None
34 |         if self.bias is not None:
35 |             project_b = torch.matmul(pgn_vector, self.bias)
36 | 
37 |         return F.linear(input, project_w, project_b)
38 | 
39 |     def extra_repr(self):
40 |         return 'in_params={}, in_features={}, out_features={}, bias={}'.format(
41 |             self.in_params, self.in_features, self.out_features, self.bias is not None
42 |         )
43 | 
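# Usage sketch for LinearWithPGN (illustrative only; the dimensions are
# assumptions): the language vector contracts the (in_params, out, in) weight
# bank into a single per-language projection.
#
#   import torch
#   layer = LinearWithPGN(in_params=32, in_features=768, out_features=256)
#   x = torch.randn(8, 768)   # a batch of token representations
#   lang = torch.randn(32)    # PGN / language embedding vector
#   y = layer(x, lang)        # weight = lang @ W.view(32, -1), so y: (8, 256)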
44 | 
45 | class BilinearWithPGN(nn.Module):
46 |     __constants__ = ['in1_features', 'in2_features', 'out_features', 'bias']
47 | 
48 |     def __init__(self, in_params, in1_features, in2_features, out_features, bias=True):
49 |         super(BilinearWithPGN, self).__init__()
50 |         self.in_params = in_params
51 |         self.in1_features = in1_features
52 |         self.in2_features = in2_features
53 |         self.out_features = out_features
54 |         self.weight = Parameter(torch.Tensor(in_params, out_features, in1_features, in2_features))
55 | 
56 |         if bias:
57 |             self.bias = Parameter(torch.Tensor(in_params, out_features))
58 |         else:
59 |             self.register_parameter('bias', None)
60 |         self.reset_parameters()
61 | 
62 |     def reset_parameters(self):
63 |         bound = 1 / math.sqrt(self.weight.size(1))
64 |         init.uniform_(self.weight, -bound, bound)
65 |         if self.bias is not None:
66 |             init.uniform_(self.bias, -bound, bound)
67 | 
68 |     def forward(self, input1, input2, pgn_vector):
69 |         weight = torch.matmul(pgn_vector, self.weight.view(self.in_params, -1)).view(self.out_features, self.in1_features, self.in2_features)
70 |         bias = None
71 |         if self.bias is not None:
72 |             bias = torch.matmul(pgn_vector, self.bias)
73 |         return F.bilinear(input1, input2, weight, bias)
74 | 
75 |     def extra_repr(self):
76 |         return 'in_params={}, in1_features={}, in2_features={}, out_features={}, bias={}'.format(
77 |             self.in_params, self.in1_features, self.in2_features, self.out_features, self.bias is not None
78 |         )
79 | 
--------------------------------------------------------------------------------
/udapter/udapter_models/tag_decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Decodes sequences of tags, e.g., POS tags, given a list of contextualized word embeddings
3 | """
4 | 
5 | from typing import Optional, Any, Dict, List
6 | from overrides import overrides
7 | 
8 | import numpy
9 | import torch
10 | from torch.nn.modules.adaptive import AdaptiveLogSoftmaxWithLoss
11 | import torch.nn.functional as F
12 | 
13 | from allennlp.data import Vocabulary
14 | from allennlp.modules import TimeDistributed, Seq2SeqEncoder, Embedding, FeedForward
15 | from allennlp.models.model import Model
16 | from allennlp.nn import InitializerApplicator, RegularizerApplicator, Activation
17 | from allennlp.nn.util import sequence_cross_entropy_with_logits
18 | from allennlp.training.metrics import CategoricalAccuracy
19 | 
20 | from udapter.dataset_readers.lemma_edit import apply_lemma_rule
21 | from udapter.udapter_models.linear import LinearWithPGN
22 | from udapter.udapter_models.time_distributed import TimeDistributedPGN
23 | 
24 | 
25 | def sequence_cross_entropy(log_probs: torch.FloatTensor,
26 |                            targets: torch.LongTensor,
27 |                            weights: torch.FloatTensor,
28 |                            average: str = "batch",
29 |                            label_smoothing: float = None) -> torch.FloatTensor:
30 |     if average not in {None, "token", "batch"}:
31 |         raise ValueError(f"Got average {average}, expected one of "
32 |                          "None, 'token', or 'batch'")
33 |     # shape : (batch * sequence_length, num_classes)
34 |     log_probs_flat = log_probs.view(-1, log_probs.size(2))
35 |     # shape : (batch * max_len, 1)
36 |     targets_flat = targets.view(-1, 1).long()
37 | 
38 |     if label_smoothing is not None and label_smoothing > 0.0:
39 |         num_classes = log_probs.size(-1)
40 |         smoothing_value = label_smoothing / num_classes
41 |         # Fill all the correct indices with 1 - smoothing value.
42 |         one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
43 |         smoothed_targets = one_hot_targets + smoothing_value
44 |         negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
45 |         negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
46 |     else:
47 |         # Contribution to the negative log likelihood only comes from the exact indices
48 |         # of the targets, as the target distributions are one-hot. Here we use torch.gather
49 |         # to extract the indices of the num_classes dimension which contribute to the loss.
50 |         # shape : (batch * sequence_length, 1)
51 |         negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
52 |     # shape : (batch, sequence_length)
53 |     negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
54 |     # shape : (batch, sequence_length)
55 |     negative_log_likelihood = negative_log_likelihood * weights.float()
56 | 
57 |     if average == "batch":
58 |         # shape : (batch_size,)
59 |         per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
60 |         num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
61 |         return per_batch_loss.sum() / num_non_empty_sequences
62 |     elif average == "token":
63 |         return negative_log_likelihood.sum() / (weights.sum().float() + 1e-13)
64 |     else:
65 |         # shape : (batch_size,)
66 |         per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
67 |         return per_batch_loss
68 | 
69 | 
70 | @Model.register("udapter_tag_decoder")
71 | class TagDecoder(Model):
72 |     """
73 |     A basic sequence tagger that decodes from inputs of word embeddings
74 |     """
75 |     def __init__(self,
76 |                  vocab: Vocabulary,
77 |                  task: str,
78 |                  encoder: Seq2SeqEncoder,
79 |                  png_params_dim: int,
80 |                  label_smoothing: float = 0.0,
81 |                  dropout: float = 0.0,
82 |                  adaptive: bool = False,
83 |                  features: List[str] = None,
84 |                  initializer: InitializerApplicator = InitializerApplicator(),
85 |                  regularizer: Optional[RegularizerApplicator] = None) -> None:
86 |         super(TagDecoder, self).__init__(vocab, regularizer)
87 | 
88 |         self.dropout = torch.nn.Dropout(p=dropout)
89 | 
90 |         self.task = task
91 |         self.encoder = encoder
92 |         self.output_dim = encoder.get_output_dim()
93 |         self.label_smoothing = label_smoothing
94 |         self.num_classes = self.vocab.get_vocab_size(task)
95 |         self.adaptive = adaptive
96 |         self.features = [f.replace('[','_').replace(']','_') for f in features] if features else []
97 | 
98 |         self.metrics = {
99 |             "acc": CategoricalAccuracy(),
100 |             # "acc3": CategoricalAccuracy(top_k=3)
101 |         }
102 | 
103 |         if self.adaptive:
104 |             # TODO
105 |             adaptive_cutoffs = [round(self.num_classes / 15), 3 * round(self.num_classes / 15)]
106 |             self.task_output = AdaptiveLogSoftmaxWithLoss(self.output_dim,
107 |                                                           self.num_classes,
108 |                                                           cutoffs=adaptive_cutoffs,
109 |                                                           div_value=4.0)
110 |         else:
111 |             self.task_output = TimeDistributedPGN(LinearWithPGN(png_params_dim, self.output_dim, self.num_classes))
112 | 
113 |         self.feature_outputs = torch.nn.ModuleDict()
114 |         self.features_metrics = {}
115 |         for feature in self.features:
116 |             self.feature_outputs[feature] = TimeDistributedPGN(LinearWithPGN(png_params_dim, self.output_dim,
117 |                                                                              vocab.get_vocab_size(feature)))
118 |             self.features_metrics[feature] = {
119 |                 "acc": CategoricalAccuracy(),
120 |             }
121 | 
122 |         initializer(self)
123 | 
124 |     @overrides
125 |     def forward(self,
126 |                 encoded_text: torch.FloatTensor,
127 |                 mask: torch.LongTensor,
128 |                 gold_tags: Dict[str, torch.LongTensor],
129 |                 pgn_vector: torch.Tensor,
130 |                 metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
131 | 
132 |         hidden = encoded_text
133 |         hidden = self.encoder(hidden, mask)
134 | 
135 |         batch_size, sequence_length, _ = hidden.size()
136 |         output_dim = [batch_size, sequence_length, self.num_classes]
137 | 
138 |         loss_fn = self._adaptive_loss if self.adaptive else self._loss
139 | 
140 |         output_dict = loss_fn(hidden, mask, gold_tags[self.task], output_dim, pgn_vector)
141 |         self._features_loss(hidden, mask, gold_tags, output_dict, pgn_vector)
142 | 
143 |         return output_dict
144 | 
145 |     def _adaptive_loss(self, hidden, mask, gold_tags, output_dim, pgn_vector):
146 |         logits = hidden
147 |         reshaped_log_probs = logits.view(-1, logits.size(2))
148 | 
149 |         class_probabilities = self.task_output.log_prob(reshaped_log_probs).view(output_dim)
150 | 
151 |         output_dict = {"logits": logits, "class_probabilities": class_probabilities}
152 | 
153 |         if gold_tags is not None:
154 |             output_dict["loss"] = sequence_cross_entropy(class_probabilities,
155 |                                                          gold_tags,
156 |                                                          mask,
157 |                                                          label_smoothing=self.label_smoothing)
158 |             for metric in self.metrics.values():
159 |                 metric(class_probabilities, gold_tags, mask.float())
160 | 
161 |         return output_dict
162 | 
163 |     def _loss(self, hidden, mask, gold_tags, output_dim, pgn_vector):
164 |         logits = self.task_output(hidden, pgn_vector)
165 |         reshaped_log_probs = logits.view(-1, self.num_classes)
166 |         class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim)
167 | 
168 |         output_dict = {"logits": logits, "class_probabilities": class_probabilities}
169 | 
170 |         if gold_tags is not None:
171 |             output_dict["loss"] = sequence_cross_entropy_with_logits(logits,
172 |                                                                      gold_tags,
173 |                                                                      mask,
174 |                                                                      label_smoothing=self.label_smoothing)
175 |             for metric in self.metrics.values():
176 |                 metric(logits, gold_tags, mask.float())
177 | 
178 |         return output_dict
179 | 
180 |     def _features_loss(self, hidden, mask, gold_tags, output_dict, pgn_vector):
181 |         if gold_tags is None:
182 |             return
183 | 
184 |         for feature in self.features:
185 |             logits = self.feature_outputs[feature](hidden, pgn_vector)
186 |             loss = sequence_cross_entropy_with_logits(logits,
187 |                                                       gold_tags[feature],
188 |                                                       mask,
189 |                                                       label_smoothing=self.label_smoothing)
190 |             loss /= len(self.features)
191 |             output_dict["loss"] += loss
192 | 
193 |             for metric in self.features_metrics[feature].values():
194 |                 metric(logits, gold_tags[feature], mask.float())
195 | 
196 |     @overrides
197 |     def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
198 |         all_words = output_dict["words"]
199 | 
200 |         all_predictions = output_dict["class_probabilities"][self.task].cpu().data.numpy()
201 |         if all_predictions.ndim == 3:
202 |             predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
203 |         else:
204 |             predictions_list = [all_predictions]
205 |         all_tags = []
206 |         for predictions, words in zip(predictions_list, all_words):
207 |             argmax_indices = numpy.argmax(predictions, axis=-1)
208 |             tags = [self.vocab.get_token_from_index(x, namespace=self.task)
209 |                     for x in argmax_indices]
210 | 
211 |             # TODO: specific task
212 |             if self.task == "lemmas":
213 |                 def decode_lemma(word, rule):
214 |                     if rule == "_":
return "_" 216 | if rule == "@@UNKNOWN@@": 217 | return word 218 | return apply_lemma_rule(word, rule) 219 | tags = [decode_lemma(word, rule) for word, rule in zip(words, tags)] 220 | 221 | all_tags.append(tags) 222 | output_dict[self.task] = all_tags 223 | 224 | return output_dict 225 | 226 | @overrides 227 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 228 | main_metrics = { 229 | f".run/{self.task}/{metric_name}": metric.get_metric(reset) 230 | for metric_name, metric in self.metrics.items() 231 | } 232 | 233 | features_metrics = { 234 | f"_run/{self.task}/{feature}/{metric_name}": metric.get_metric(reset) 235 | for feature in self.features 236 | for metric_name, metric in self.features_metrics[feature].items() 237 | } 238 | 239 | return {**main_metrics, **features_metrics} 240 | -------------------------------------------------------------------------------- /udapter/udapter_models/time_distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | A wrapper that unrolls the second (time) dimension of a tensor 3 | into the first (batch) dimension, applies some other ``Module``, 4 | and then rolls the time dimension back up. 5 | """ 6 | 7 | from typing import List 8 | 9 | from overrides import overrides 10 | import torch 11 | 12 | 13 | class TimeDistributedPGN(torch.nn.Module): 14 | """ 15 | Given an input shaped like ``(batch_size, time_steps, [rest])`` and a ``Module`` that takes 16 | inputs like ``(batch_size, [rest])``, ``TimeDistributed`` reshapes the input to be 17 | ``(batch_size * time_steps, [rest])``, applies the contained ``Module``, then reshapes it back. 18 | 19 | Note that while the above gives shapes with ``batch_size`` first, this ``Module`` also works if 20 | ``batch_size`` is second - we always just combine the first two dimensions, then split them. 21 | 22 | It also reshapes keyword arguments unless they are not tensors or their name is specified in 23 | the optional ``pass_through`` iterable. 24 | """ 25 | def __init__(self, module): 26 | super().__init__() 27 | self._module = module 28 | 29 | @overrides 30 | def forward(self, *inputs, pass_through: List[str] = None, **kwargs): 31 | pgn_vector = inputs[-1] 32 | inputs = inputs[:-1] 33 | # pylint: disable=arguments-differ 34 | pass_through = pass_through or [] 35 | 36 | reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs] 37 | reshaped_inputs.append(pgn_vector) 38 | 39 | # Need some input to then get the batch_size and time_steps. 40 | some_input = None 41 | if inputs: 42 | some_input = inputs[-1] 43 | 44 | reshaped_kwargs = {} 45 | for key, value in kwargs.items(): 46 | if isinstance(value, torch.Tensor) and key not in pass_through: 47 | if some_input is None: 48 | some_input = value 49 | 50 | value = self._reshape_tensor(value) 51 | 52 | reshaped_kwargs[key] = value 53 | 54 | reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs) 55 | 56 | if some_input is None: 57 | raise RuntimeError("No input tensor to time-distribute") 58 | 59 | # Now get the output back into the right shape. 
66 |     @staticmethod
67 |     def _reshape_tensor(input_tensor):
68 |         input_size = input_tensor.size()
69 |         if len(input_size) <= 2:
70 |             raise RuntimeError(f"No dimension to distribute: {input_size}")
71 |         # Squash batch_size and time_steps into a single axis; result has shape
72 |         # (batch_size * time_steps, [rest]).
73 |         squashed_shape = [-1] + list(input_size[2:])
74 |         return input_tensor.contiguous().view(*squashed_shape)
75 | 
--------------------------------------------------------------------------------
/udapter/udapter_models/udapter_model.py:
--------------------------------------------------------------------------------
1 | """
2 | The base UDapter model for training and prediction
3 | """
4 | 
5 | from typing import Optional, Any, Dict, List, Tuple
6 | from overrides import overrides
7 | import logging
8 | 
9 | import torch
10 | 
11 | from pytorch_transformers import BertTokenizer, XLMTokenizer
12 | 
13 | from allennlp.common.checks import check_dimensions_match, ConfigurationError
14 | from allennlp.data import Vocabulary
15 | from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
16 | from allennlp.models.model import Model
17 | from allennlp.nn import InitializerApplicator, RegularizerApplicator
18 | from allennlp.nn.util import get_text_field_mask
19 | 
20 | from udapter.modules.language_emb import LanguageEmbeddings
21 | from udapter.modules.language_mlp import LanguageMLP
22 | from udapter.modules.scalar_mix import ScalarMixWithDropout
23 | 
24 | logger = logging.getLogger(__name__)
25 | 
26 | 
27 | @Model.register("udapter_model")
28 | class UdapterModel(Model):
29 |     """
30 |     The UDapter model base class. Applies a sequence of shared encoders before decoding in a multi-task configuration.
31 |     Uses TagDecoder and DependencyDecoder to decode each UD task.
32 |     """
33 | 
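    # Data-flow sketch (an illustrative condensation of forward() below, not
    # new behaviour): each batch carries language ids, and the language
    # embedder maps them to the PGN vector that specializes every decoder:
    #
    #   language_emb = self.language_embedder(lang_id)          # PGN vector
    #   encoded = self.shared_encoder(self.text_field_embedder(tokens), mask)
    #   out = self.decoders["upos"](encoded, mask, gold_tags, language_emb, metadata)
    #
    # so the shared encoder output is decoded with per-language generated weights.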
32 | """ 33 | 34 | def __init__(self, 35 | vocab: Vocabulary, 36 | tasks: List[str], 37 | pretrained_model: str, 38 | text_field_embedder: TextFieldEmbedder, 39 | encoder: Seq2SeqEncoder, 40 | decoders: Dict[str, Model], 41 | post_encoder_embedder: TextFieldEmbedder = None, 42 | dropout: float = 0.0, 43 | word_dropout: float = 0.0, 44 | mix_embedding: int = None, 45 | layer_dropout: int = 0.0, 46 | seperate_lang_emb: bool = False, 47 | initializer: InitializerApplicator = InitializerApplicator(), 48 | regularizer: Optional[RegularizerApplicator] = None) -> None: 49 | super(UdapterModel, self).__init__(vocab, regularizer) 50 | 51 | self.tasks = tasks 52 | self.vocab = vocab 53 | self.text_field_embedder = text_field_embedder 54 | self.post_encoder_embedder = post_encoder_embedder 55 | self.shared_encoder = encoder 56 | self.word_dropout = word_dropout 57 | self.dropout = torch.nn.Dropout(p=dropout) 58 | self.decoders = torch.nn.ModuleDict(decoders) 59 | 60 | self.seperate_lang_emb = seperate_lang_emb 61 | 62 | if not self.seperate_lang_emb: 63 | self.language_embedder = self.text_field_embedder.token_embedder_bert.language_embedder 64 | else: 65 | config = self.text_field_embedder.token_embedder_bert.language_embedder.config 66 | language_emb_from_features = self.text_field_embedder.token_embedder_bert.language_emb_from_features 67 | if language_emb_from_features: 68 | in_lang_list = self.text_field_embedder.token_embedder_bert.language_embedder.in_language_list 69 | oov_lang_list = self.text_field_embedder.token_embedder_bert.language_embedder.oov_language_list 70 | letter_codes = self.text_field_embedder.token_embedder_bert.language_embedder.letter_codes 71 | self.language_embedder = LanguageMLP(config, in_lang_list, oov_lang_list, letter_codes) 72 | else: 73 | self.language_embedder = LanguageEmbeddings(config) 74 | 75 | if 'bert' in pretrained_model: 76 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower_case=False) 77 | elif 'xlm' in pretrained_model: 78 | self.tokenizer = XLMTokenizer.from_pretrained(pretrained_model, do_lower_case=False) 79 | else: 80 | raise ConfigurationError(f"No corresponding pretrained model for tokenizer.") 81 | 82 | if mix_embedding: 83 | self.scalar_mix = torch.nn.ModuleDict({ 84 | task: ScalarMixWithDropout(mix_embedding, 85 | do_layer_norm=False, 86 | dropout=layer_dropout) 87 | for task in self.decoders 88 | }) 89 | else: 90 | self.scalar_mix = None 91 | 92 | self.metrics = {} 93 | 94 | for task in self.tasks: 95 | if task not in self.decoders: 96 | raise ConfigurationError(f"Task {task} has no corresponding decoder. 
Make sure their names match.") 97 | 98 | check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), 99 | "text field embedding dim", "encoder input dim") 100 | 101 | initializer(self) 102 | self._count_params() 103 | 104 | @overrides 105 | def forward(self, 106 | tokens: Dict[str, torch.LongTensor], 107 | metadata: List[Dict[str, Any]] = None, 108 | **kwargs: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: 109 | if "track_epoch" in kwargs: 110 | track_epoch = kwargs.pop("track_epoch") 111 | 112 | gold_tags = kwargs 113 | 114 | if "tokens" in self.tasks: 115 | # Model is predicting tokens, so add them to the gold tags 116 | gold_tags["tokens"] = tokens["tokens"] 117 | 118 | mask = get_text_field_mask(tokens) 119 | self._apply_token_dropout(tokens) 120 | 121 | embedded_text_input = self.text_field_embedder(tokens) 122 | language_emb = self.language_embedder(next(iter(tokens['bert-lang-ids']))).to(embedded_text_input.device) 123 | 124 | if self.post_encoder_embedder: 125 | post_embeddings = self.post_encoder_embedder(tokens) 126 | 127 | encoded_text = self.shared_encoder(embedded_text_input, mask) 128 | 129 | logits = {} 130 | class_probabilities = {} 131 | output_dict = {"logits": logits, 132 | "class_probabilities": class_probabilities} 133 | loss = 0 134 | 135 | # Run through each of the tasks on the shared encoder and save predictions 136 | for task in self.tasks: 137 | if self.scalar_mix: 138 | decoder_input = self.scalar_mix[task](encoded_text, mask) 139 | else: 140 | decoder_input = encoded_text 141 | 142 | if self.post_encoder_embedder: 143 | decoder_input = decoder_input + post_embeddings 144 | 145 | if task == "deps": 146 | tag_logits = logits["upos"] if "upos" in logits else None 147 | pred_output = self.decoders[task](decoder_input, mask, language_emb, tag_logits, 148 | gold_tags, metadata) 149 | for key in ["heads", "head_tags", "arc_loss", "tag_loss", "mask"]: 150 | output_dict[key] = pred_output[key] 151 | else: 152 | pred_output = self.decoders[task](decoder_input, mask, gold_tags, language_emb, metadata) 153 | 154 | logits[task] = pred_output["logits"] 155 | class_probabilities[task] = pred_output["class_probabilities"] 156 | 157 | if task in gold_tags or task == "deps" and "head_tags" in gold_tags: 158 | # Keep track of the loss if we have the gold tags available 159 | loss += pred_output["loss"] 160 | 161 | if gold_tags: 162 | output_dict["loss"] = loss 163 | 164 | if metadata is not None: 165 | output_dict["words"] = [x["words"] for x in metadata] 166 | output_dict["ids"] = [x["ids"] for x in metadata if "ids" in x] 167 | output_dict["multiword_ids"] = [x["multiword_ids"] for x in metadata if "multiword_ids" in x] 168 | output_dict["multiword_forms"] = [x["multiword_forms"] for x in metadata if "multiword_forms" in x] 169 | output_dict["langs"] = [x["langs"] for x in metadata] 170 | 171 | return output_dict 172 | 173 | def _apply_token_dropout(self, tokens): 174 | # Word dropout 175 | if "tokens" in tokens: 176 | oov_token = self.vocab.get_token_index(self.vocab._oov_token) 177 | ignore_tokens = [self.vocab.get_token_index(self.vocab._padding_token)] 178 | tokens["tokens"] = self.token_dropout(tokens["tokens"], 179 | oov_token=oov_token, 180 | padding_tokens=ignore_tokens, 181 | p=self.word_dropout, 182 | training=self.training) 183 | 184 | # BERT token dropout 185 | if "bert" in tokens: 186 | oov_token = self.tokenizer.vocab["[MASK]"] 187 | ignore_tokens = [self.tokenizer.vocab["[PAD]"], self.tokenizer.vocab["[CLS]"], 
self.tokenizer.vocab["[SEP]"]] 188 | tokens["bert"] = self.token_dropout(tokens["bert"], 189 | oov_token=oov_token, 190 | padding_tokens=ignore_tokens, 191 | p=self.word_dropout, 192 | training=self.training) 193 | 194 | @staticmethod 195 | def token_dropout(tokens: torch.LongTensor, 196 | oov_token: int, 197 | padding_tokens: List[int], 198 | p: float = 0.2, 199 | training: float = True) -> torch.LongTensor: 200 | """ 201 | During training, randomly replaces some of the non-padding tokens to a mask token with probability ``p`` 202 | 203 | :param tokens: The current batch of padded sentences with word ids 204 | :param oov_token: The mask token 205 | :param padding_tokens: The tokens for padding the input batch 206 | :param p: The probability a word gets mapped to the unknown token 207 | :param training: Applies the dropout if set to ``True`` 208 | :return: A copy of the input batch with token dropout applied 209 | """ 210 | if training and p > 0: 211 | # Ensure that the tensors run on the same device 212 | device = tokens.device 213 | 214 | # This creates a mask that only considers unpadded tokens for mapping to oov 215 | padding_mask = torch.ones(tokens.size(), dtype=torch.uint8).to(device) 216 | for pad in padding_tokens: 217 | padding_mask &= tokens != pad 218 | 219 | # Create a uniformly random mask selecting either the original words or OOV tokens 220 | dropout_mask = (torch.empty(tokens.size()).uniform_() < p).to(device) 221 | oov_mask = dropout_mask & padding_mask 222 | 223 | oov_fill = torch.empty(tokens.size(), dtype=torch.long).fill_(oov_token).to(device) 224 | 225 | result = torch.where(oov_mask, oov_fill, tokens) 226 | 227 | return result 228 | else: 229 | return tokens 230 | 231 | @overrides 232 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 233 | for task in self.tasks: 234 | self.decoders[task].decode(output_dict) 235 | 236 | return output_dict 237 | 238 | @overrides 239 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 240 | metrics = {name: task_metric 241 | for task in self.tasks 242 | for name, task_metric in self.decoders[task].get_metrics(reset).items()} 243 | 244 | # The "sum" metric summing all tracked metrics keeps a good measure of patience for early stopping and saving 245 | metrics_to_track = {"upos", "xpos", "feats", "lemmas", "LAS", "UAS"} 246 | metrics[".run/.sum"] = sum(metric 247 | for name, metric in metrics.items() 248 | if not name.startswith("_") and set(name.split("/")).intersection(metrics_to_track)) 249 | 250 | return metrics 251 | 252 | def _count_params(self): 253 | self.total_params = sum(p.numel() for p in self.parameters()) 254 | self.total_train_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 255 | 256 | logger.info(f"Total number of parameters: {self.total_params}") 257 | logger.info(f"Total number of trainable parameters: {self.total_train_params}") 258 | -------------------------------------------------------------------------------- /udapter/udify_models/__init__.py: -------------------------------------------------------------------------------- 1 | from udapter.udify_models.udify_model import UdifyModel 2 | from udapter.udify_models.dependency_decoder import DependencyDecoder 3 | from udapter.udify_models.tag_decoder import TagDecoder 4 | -------------------------------------------------------------------------------- /udapter/udify_models/udify_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | The base UDify 
/udapter/udify_models/__init__.py:
--------------------------------------------------------------------------------
1 | from udapter.udify_models.udify_model import UdifyModel
2 | from udapter.udify_models.dependency_decoder import DependencyDecoder
3 | from udapter.udify_models.tag_decoder import TagDecoder
4 | 
--------------------------------------------------------------------------------
/udapter/udify_models/udify_model.py:
--------------------------------------------------------------------------------
1 | """
2 | The base UDify model for training and prediction
3 | """
4 | 
5 | from typing import Optional, Any, Dict, List, Tuple
6 | from overrides import overrides
7 | import logging
8 | 
9 | import torch
10 | 
11 | from pytorch_transformers import BertTokenizer, XLMTokenizer
12 | 
13 | from allennlp.common.checks import check_dimensions_match, ConfigurationError
14 | from allennlp.data import Vocabulary
15 | from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
16 | from allennlp.models.model import Model
17 | from allennlp.nn import InitializerApplicator, RegularizerApplicator
18 | from allennlp.nn.util import get_text_field_mask
19 | 
20 | from udapter.modules.scalar_mix import ScalarMixWithDropout
21 | 
22 | logger = logging.getLogger(__name__)
23 | 
24 | 
25 | @Model.register("udify_model")
26 | class UdifyModel(Model):
27 |     """
28 |     The UDify model base class. Applies a sequence of shared encoders before decoding in a multi-task configuration.
29 |     Uses TagDecoder and DependencyDecoder to decode each UD task.
30 |     """
31 | 
32 |     def __init__(self,
33 |                  vocab: Vocabulary,
34 |                  tasks: List[str],
35 |                  pretrained_model: str,
36 |                  text_field_embedder: TextFieldEmbedder,
37 |                  encoder: Seq2SeqEncoder,
38 |                  decoders: Dict[str, Model],
39 |                  post_encoder_embedder: TextFieldEmbedder = None,
40 |                  dropout: float = 0.0,
41 |                  word_dropout: float = 0.0,
42 |                  mix_embedding: int = None,
43 |                  layer_dropout: float = 0.0,
44 |                  initializer: InitializerApplicator = InitializerApplicator(),
45 |                  regularizer: Optional[RegularizerApplicator] = None) -> None:
46 |         super(UdifyModel, self).__init__(vocab, regularizer)
47 | 
48 |         self.tasks = tasks
49 |         self.vocab = vocab
50 |         self.text_field_embedder = text_field_embedder
51 |         self.post_encoder_embedder = post_encoder_embedder
52 |         self.shared_encoder = encoder
53 |         self.word_dropout = word_dropout
54 |         self.dropout = torch.nn.Dropout(p=dropout)
55 |         self.decoders = torch.nn.ModuleDict(decoders)
56 | 
57 |         if 'bert' in pretrained_model:
58 |             self.tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower_case=False)
59 |         elif 'xlm' in pretrained_model:
60 |             self.tokenizer = XLMTokenizer.from_pretrained(pretrained_model, do_lower_case=False)
61 |         else:
62 |             raise ConfigurationError(f"No tokenizer available for pretrained model '{pretrained_model}'.")
63 | 
64 |         if mix_embedding:
65 |             self.scalar_mix = torch.nn.ModuleDict({
66 |                 task: ScalarMixWithDropout(mix_embedding,
67 |                                            do_layer_norm=False,
68 |                                            dropout=layer_dropout)
69 |                 for task in self.decoders
70 |             })
71 |         else:
72 |             self.scalar_mix = None
73 | 
74 |         self.metrics = {}
75 | 
76 |         for task in self.tasks:
77 |             if task not in self.decoders:
78 |                 raise ConfigurationError(f"Task {task} has no corresponding decoder. Make sure their names match.")
Make sure their names match.") 79 | 80 | check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), 81 | "text field embedding dim", "encoder input dim") 82 | 83 | initializer(self) 84 | self._count_params() 85 | 86 | @overrides 87 | def forward(self, 88 | tokens: Dict[str, torch.LongTensor], 89 | metadata: List[Dict[str, Any]] = None, 90 | **kwargs: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: 91 | if "track_epoch" in kwargs: 92 | track_epoch = kwargs.pop("track_epoch") 93 | 94 | gold_tags = kwargs 95 | 96 | if "tokens" in self.tasks: 97 | # Model is predicting tokens, so add them to the gold tags 98 | gold_tags["tokens"] = tokens["tokens"] 99 | 100 | mask = get_text_field_mask(tokens) 101 | self._apply_token_dropout(tokens) 102 | 103 | embedded_text_input = self.text_field_embedder(tokens) 104 | 105 | if self.post_encoder_embedder: 106 | post_embeddings = self.post_encoder_embedder(tokens) 107 | 108 | encoded_text = self.shared_encoder(embedded_text_input, mask) 109 | 110 | logits = {} 111 | class_probabilities = {} 112 | output_dict = {"logits": logits, 113 | "class_probabilities": class_probabilities} 114 | loss = 0 115 | 116 | # Run through each of the tasks on the shared encoder and save predictions 117 | for task in self.tasks: 118 | if self.scalar_mix: 119 | decoder_input = self.scalar_mix[task](encoded_text, mask) 120 | else: 121 | decoder_input = encoded_text 122 | 123 | if self.post_encoder_embedder: 124 | decoder_input = decoder_input + post_embeddings 125 | 126 | if task == "deps": 127 | tag_logits = logits["upos"] if "upos" in logits else None 128 | pred_output = self.decoders[task](decoder_input, mask, tag_logits, 129 | gold_tags, metadata) 130 | for key in ["heads", "head_tags", "arc_loss", "tag_loss", "mask"]: 131 | output_dict[key] = pred_output[key] 132 | else: 133 | pred_output = self.decoders[task](decoder_input, mask, gold_tags, metadata) 134 | 135 | logits[task] = pred_output["logits"] 136 | class_probabilities[task] = pred_output["class_probabilities"] 137 | 138 | if task in gold_tags or task == "deps" and "head_tags" in gold_tags: 139 | # Keep track of the loss if we have the gold tags available 140 | loss += pred_output["loss"] 141 | 142 | if gold_tags: 143 | output_dict["loss"] = loss 144 | 145 | if metadata is not None: 146 | output_dict["words"] = [x["words"] for x in metadata] 147 | output_dict["ids"] = [x["ids"] for x in metadata if "ids" in x] 148 | output_dict["multiword_ids"] = [x["multiword_ids"] for x in metadata if "multiword_ids" in x] 149 | output_dict["multiword_forms"] = [x["multiword_forms"] for x in metadata if "multiword_forms" in x] 150 | 151 | return output_dict 152 | 153 | def _apply_token_dropout(self, tokens): 154 | # Word dropout 155 | if "tokens" in tokens: 156 | oov_token = self.vocab.get_token_index(self.vocab._oov_token) 157 | ignore_tokens = [self.vocab.get_token_index(self.vocab._padding_token)] 158 | tokens["tokens"] = self.token_dropout(tokens["tokens"], 159 | oov_token=oov_token, 160 | padding_tokens=ignore_tokens, 161 | p=self.word_dropout, 162 | training=self.training) 163 | 164 | # BERT token dropout 165 | if "bert" in tokens: 166 | oov_token = self.tokenizer.vocab["[MASK]"] 167 | ignore_tokens = [self.tokenizer.vocab["[PAD]"], self.tokenizer.vocab["[CLS]"], self.tokenizer.vocab["[SEP]"]] 168 | tokens["bert"] = self.token_dropout(tokens["bert"], 169 | oov_token=oov_token, 170 | padding_tokens=ignore_tokens, 171 | p=self.word_dropout, 172 | training=self.training) 173 | 174 | @staticmethod 
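    # Mask algebra used below (identical to UdapterModel.token_dropout); the
    # values are a hypothetical single-row batch with p = 0.5:
    #
    #   tokens       = [[101,   5,   6,   0]]
    #   padding_mask = [[  0,   1,   1,   0]]   # 1 where not a padding/special id
    #   dropout_mask = [[  1,   1,   0,   1]]   # uniform() < p
    #   oov_mask     = [[  0,   1,   0,   0]]   # AND of the two masks
    #   result       = [[101, 103,   6,   0]]   # 103 = the [MASK] id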
175 |     def token_dropout(tokens: torch.LongTensor,
176 |                       oov_token: int,
177 |                       padding_tokens: List[int],
178 |                       p: float = 0.2,
179 |                       training: bool = True) -> torch.LongTensor:
180 |         """
181 |         During training, randomly replaces some of the non-padding tokens with a mask token with probability ``p``
182 | 
183 |         :param tokens: The current batch of padded sentences with word ids
184 |         :param oov_token: The mask token
185 |         :param padding_tokens: The tokens for padding the input batch
186 |         :param p: The probability a token gets replaced by the mask token
187 |         :param training: Applies the dropout if set to ``True``
188 |         :return: A copy of the input batch with token dropout applied
189 |         """
190 |         if training and p > 0:
191 |             # Ensure that the tensors run on the same device
192 |             device = tokens.device
193 | 
194 |             # This creates a mask that only considers unpadded tokens for mapping to oov
195 |             padding_mask = torch.ones(tokens.size(), dtype=torch.uint8).to(device)
196 |             for pad in padding_tokens:
197 |                 padding_mask &= tokens != pad
198 | 
199 |             # Create a uniformly random mask selecting either the original words or OOV tokens
200 |             dropout_mask = (torch.empty(tokens.size()).uniform_() < p).to(device)
201 |             oov_mask = dropout_mask & padding_mask
202 | 
203 |             oov_fill = torch.empty(tokens.size(), dtype=torch.long).fill_(oov_token).to(device)
204 | 
205 |             result = torch.where(oov_mask, oov_fill, tokens)
206 | 
207 |             return result
208 |         else:
209 |             return tokens
210 | 
211 |     @overrides
212 |     def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
213 |         for task in self.tasks:
214 |             self.decoders[task].decode(output_dict)
215 | 
216 |         return output_dict
217 | 
218 |     @overrides
219 |     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
220 |         metrics = {name: task_metric
221 |                    for task in self.tasks
222 |                    for name, task_metric in self.decoders[task].get_metrics(reset).items()}
223 | 
224 |         # The "sum" metric summing all tracked metrics keeps a good measure of patience for early stopping and saving
225 |         metrics_to_track = {"upos", "xpos", "feats", "lemmas", "LAS", "UAS"}
226 |         metrics[".run/.sum"] = sum(metric
227 |                                    for name, metric in metrics.items()
228 |                                    if not name.startswith("_") and set(name.split("/")).intersection(metrics_to_track))
229 | 
230 |         return metrics
231 | 
232 |     def _count_params(self):
233 | 
234 |         self.total_params = sum(p.numel() for p in self.parameters())
235 |         self.total_train_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
236 | 
237 |         logger.info(f"Total number of parameters: {self.total_params}")
238 |         logger.info(f"Total number of trainable parameters: {self.total_train_params}")
239 | 
--------------------------------------------------------------------------------
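# A standalone restatement of the ".run/.sum" early-stopping metric computed in
# get_metrics above (the metric values here are hypothetical):
#
#   metrics = {".run/upos/acc": 0.95, ".run/deps/LAS": 0.80,
#              "_run/upos/Case/acc": 0.90}
#   tracked = {"upos", "xpos", "feats", "lemmas", "LAS", "UAS"}
#   total = sum(v for k, v in metrics.items()
#               if not k.startswith("_") and set(k.split("/")) & tracked)
#   assert abs(total - 1.75) < 1e-9   # "_run/..." feature metrics are excluded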