├── .gitignore ├── LICENSE.txt ├── README.md ├── bert ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── create_pretraining_data.py ├── download_glue_data.py ├── extract_features.py ├── modeling.py ├── modeling_test.py ├── multilingual.md ├── optimization.py ├── optimization_test.py ├── run_classifier.py ├── run_pretraining.py ├── run_squad.py ├── sample_text.txt ├── tokenization.py └── tokenization_test.py ├── bluebert ├── __init__.py ├── conlleval.py ├── run_bluebert.py ├── run_bluebert_multi_labels.py ├── run_bluebert_ner.py ├── run_bluebert_sts.py └── tf_metrics.py ├── elmo ├── README.md └── elmoft.py ├── mribert ├── README.md ├── lymph_node_vocab.yml └── sequence_classification.py ├── mt-bluebert ├── .gitignore ├── LICENSE ├── README.md ├── mt_bluebert │ ├── __init__.py │ ├── blue_eval.py │ ├── blue_exp_def.py │ ├── blue_inference.py │ ├── blue_metrics.py │ ├── blue_prepro.py │ ├── blue_prepro_std.py │ ├── blue_strip_model.py │ ├── blue_task_def.yml │ ├── blue_train.py │ ├── blue_utils.py │ ├── conlleval.py │ ├── data_utils │ │ ├── __init__.py │ │ ├── gpt2_bpe.py │ │ ├── log_wrapper.py │ │ ├── metrics.py │ │ ├── task_def.py │ │ ├── utils.py │ │ ├── vocab.py │ │ └── xlnet_utils.py │ ├── experiments │ │ ├── __init__.py │ │ ├── common_utils.py │ │ ├── exp_def.py │ │ └── squad │ │ │ ├── __init__.py │ │ │ ├── squad_prepro.py │ │ │ ├── squad_task_def.yml │ │ │ ├── squad_utils.py │ │ │ └── verify_calc_span.py │ ├── module │ │ ├── __init__.py │ │ ├── bert_optim.py │ │ ├── common.py │ │ ├── dropout_wrapper.py │ │ ├── my_optim.py │ │ ├── san.py │ │ ├── similarity.py │ │ └── sub_layers.py │ ├── mt_dnn │ │ ├── __init__.py │ │ ├── batcher.py │ │ ├── inference.py │ │ ├── matcher.py │ │ └── model.py │ └── pmetrics.py ├── requirements.txt └── scripts │ ├── blue_prepro.sh │ ├── convert_tf_to_pt.py │ ├── run_blue_fine_tune.sh │ └── run_blue_mt_dnn.sh ├── requirements.txt └── tokenizer ├── __init__.py └── run_tokenization.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | .idea 3 | .DS_Store 4 | 5 | ### Python template 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | ### JetBrains template 111 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, 
CLion, Android Studio and WebStorm 112 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 113 | 114 | # User-specific stuff 115 | .idea/**/workspace.xml 116 | .idea/**/tasks.xml 117 | .idea/**/dictionaries 118 | .idea/**/shelf 119 | 120 | # Sensitive or high-churn files 121 | .idea/**/dataSources/ 122 | .idea/**/dataSources.ids 123 | .idea/**/dataSources.local.xml 124 | .idea/**/sqlDataSources.xml 125 | .idea/**/dynamic.xml 126 | .idea/**/uiDesigner.xml 127 | 128 | # Gradle 129 | .idea/**/gradle.xml 130 | .idea/**/libraries 131 | 132 | # CMake 133 | cmake-build-debug/ 134 | cmake-build-release/ 135 | 136 | # Mongo Explorer plugin 137 | .idea/**/mongoSettings.xml 138 | 139 | # File-based project format 140 | *.iws 141 | 142 | # IntelliJ 143 | out/ 144 | 145 | # mpeltonen/sbt-idea plugin 146 | .idea_modules/ 147 | 148 | # JIRA plugin 149 | atlassian-ide-plugin.xml 150 | 151 | # Cursive Clojure plugin 152 | .idea/replstate.xml 153 | 154 | # Crashlytics plugin (for Android Studio and IntelliJ) 155 | com_crashlytics_export_strings.xml 156 | crashlytics.properties 157 | crashlytics-build.properties 158 | fabric.properties 159 | 160 | # Editor-based Rest Client 161 | .idea/httpRequests 162 | 163 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | PUBLIC DOMAIN NOTICE 2 | National Center for Biotechnology Information 3 | 4 | This software/database is a "United States Government Work" under the terms of 5 | the United States Copyright Act. It was written as part of the author's 6 | official duties as a United States Government employee and thus cannot be 7 | copyrighted. This software/database is freely available to the public for use. 8 | The National Library of Medicine and the U.S. Government have not placed any 9 | restriction on its use or reproduction. 
10 | 11 | Although all reasonable efforts have been taken to ensure the accuracy and 12 | reliability of the software and data, the NLM and the U.S. Government do not and 13 | cannot warrant the performance or results that may be obtained by using this 14 | software or data. The NLM and the U.S. Government disclaim all warranties, 15 | express or implied, including warranties of performance, merchantability or 16 | fitness for any particular purpose. 17 | 18 | Please cite the author in any work or product based on this material: 19 | 20 | Peng Y, Yan S, Lu Z. Transfer Learning in Biomedical Natural Language 21 | Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets. 22 | In Proceedings of the 2019 Workshop on Biomedical Natural Language Processing 23 | (BioNLP 2019). 2019:58-65. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BlueBERT 2 | 3 | **\*\*\*\*\* New Nov 1st, 2020: BlueBERT can be found at huggingface \*\*\*\*\*** 4 | 5 | **\*\*\*\*\* New Dec 5th, 2019: NCBI_BERT is renamed to BlueBERT \*\*\*\*\*** 6 | 7 | **\*\*\*\*\* New July 11th, 2019: preprocessed PubMed texts \*\*\*\*\*** 8 | 9 | We uploaded the [preprocessed PubMed texts](https://github.com/ncbi-nlp/BlueBERT/blob/master/README.md#pubmed) that were used to pre-train the BlueBERT models. 10 | 11 | ----- 12 | 13 | This repository provides codes and models of BlueBERT, pre-trained on PubMed abstracts and clinical notes ([MIMIC-III](https://mimic.physionet.org/)). Please refer to our paper [Transfer Learning in Biomedical Natural Language Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets](https://arxiv.org/abs/1906.05474) for more details. 
14 | 15 | ## Pre-trained models and benchmark datasets 16 | 17 | The pre-trained BlueBERT weights, vocab, and config files can be downloaded from: 18 | 19 | * [BlueBERT-Base, Uncased, PubMed](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_uncased_L-12_H-768_A-12.zip): This model was pretrained on PubMed abstracts. 20 | * [BlueBERT-Base, Uncased, PubMed+MIMIC-III](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_mimic_uncased_L-12_H-768_A-12.zip): This model was pretrained on PubMed abstracts and MIMIC-III. 21 | * [BlueBERT-Large, Uncased, PubMed](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_uncased_L-24_H-1024_A-16.zip): This model was pretrained on PubMed abstracts. 22 | * [BlueBERT-Large, Uncased, PubMed+MIMIC-III](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_mimic_uncased_L-24_H-1024_A-16.zip): This model was pretrained on PubMed abstracts and MIMIC-III. 23 | 24 | The pre-trained weights can also be found at Huggingface: 25 | 26 | * https://huggingface.co/bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12 27 | * https://huggingface.co/bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12 28 | * https://huggingface.co/bionlp/bluebert_pubmed_uncased_L-24_H-1024_A-16 29 | * https://huggingface.co/bionlp/bluebert_pubmed_mimic_uncased_L-24_H-1024_A-16 30 | 31 | The benchmark datasets can be downloaded from [https://github.com/ncbi-nlp/BLUE_Benchmark](https://github.com/ncbi-nlp/BLUE_Benchmark) 32 | 33 | ## Fine-tuning BlueBERT 34 | 35 | We assume the BlueBERT model has been downloaded at `$BlueBERT_DIR`, and the dataset has been downloaded at `$DATASET_DIR`. 36 | 37 | Add local directory to `$PYTHONPATH` if needed. 
38 | 39 | ```bash 40 | export PYTHONPATH=.:$PYTHONPATH 41 | ``` 42 | 43 | ### Sentence similarity 44 | 45 | ```bash 46 | python bluebert/run_bluebert_sts.py \ 47 | --task_name='sts' \ 48 | --do_train=true \ 49 | --do_eval=false \ 50 | --do_test=true \ 51 | --vocab_file=$BlueBERT_DIR/vocab.txt \ 52 | --bert_config_file=$BlueBERT_DIR/bert_config.json \ 53 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \ 54 | --max_seq_length=128 \ 55 | --num_train_epochs=30.0 \ 56 | --do_lower_case=true \ 57 | --data_dir=$DATASET_DIR \ 58 | --output_dir=$OUTPUT_DIR 59 | ``` 60 | 61 | 62 | ### Named Entity Recognition 63 | 64 | ```bash 65 | python bluebert/run_bluebert_ner.py \ 66 | --do_prepare=true \ 67 | --do_train=true \ 68 | --do_eval=true \ 69 | --do_predict=true \ 70 | --task_name="bc5cdr" \ 71 | --vocab_file=$BlueBERT_DIR/vocab.txt \ 72 | --bert_config_file=$BlueBERT_DIR/bert_config.json \ 73 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \ 74 | --num_train_epochs=30.0 \ 75 | --do_lower_case=true \ 76 | --data_dir=$DATASET_DIR \ 77 | --output_dir=$OUTPUT_DIR 78 | ``` 79 | 80 | The task name can be 81 | 82 | - `bc5cdr`: BC5CDR chemical or disease task 83 | - `clefe`: ShARe/CLEFE task 84 | 85 | ### Relation Extraction 86 | 87 | ```bash 88 | python bluebert/run_bluebert.py \ 89 | --do_train=true \ 90 | --do_eval=false \ 91 | --do_predict=true \ 92 | --task_name="chemprot" \ 93 | --vocab_file=$BlueBERT_DIR/vocab.txt \ 94 | --bert_config_file=$BlueBERT_DIR/bert_config.json \ 95 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \ 96 | --num_train_epochs=10.0 \ 97 | --data_dir=$DATASET_DIR \ 98 | --output_dir=$OUTPUT_DIR \ 99 | --do_lower_case=true 100 | ``` 101 | 102 | The task name can be 103 | 104 | - `chemprot`: BC6 ChemProt task 105 | - `ddi`: DDI 2013 task 106 | - `i2b2_2010`: I2B2 2010 task 107 | 108 | ### Document multilabel classification 109 | 110 | ```bash 111 | python bluebert/run_bluebert_multi_labels.py \ 112 | --task_name="hoc" \ 113 | --do_train=true \ 114 | 
--do_eval=true \ 115 | --do_predict=true \ 116 | --vocab_file=$BlueBERT_DIR/vocab.txt \ 117 | --bert_config_file=$BlueBERT_DIR/bert_config.json \ 118 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \ 119 | --max_seq_length=128 \ 120 | --train_batch_size=4 \ 121 | --learning_rate=2e-5 \ 122 | --num_train_epochs=3 \ 123 | --num_classes=20 \ 124 | --num_aspects=10 \ 125 | --aspect_value_list="0,1" \ 126 | --data_dir=$DATASET_DIR \ 127 | --output_dir=$OUTPUT_DIR 128 | ``` 129 | 130 | ### Inference task 131 | 132 | ```bash 133 | python bluebert/run_bluebert.py \ 134 | --do_train=true \ 135 | --do_eval=false \ 136 | --do_predict=true \ 137 | --task_name="mednli" \ 138 | --vocab_file=$BlueBERT_DIR/vocab.txt \ 139 | --bert_config_file=$BlueBERT_DIR/bert_config.json \ 140 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \ 141 | --num_train_epochs=10.0 \ 142 | --data_dir=$DATASET_DIR \ 143 | --output_dir=$OUTPUT_DIR \ 144 | --do_lower_case=true 145 | ``` 146 | 147 | ## Preprocessed PubMed texts 148 | 149 | We provide [preprocessed PubMed texts](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/pubmed_uncased_sentence_nltk.txt.tar.gz) that were used to pre-train the BlueBERT models. The corpus contains ~4000M words extracted from the [PubMed ASCII code version](https://www.ncbi.nlm.nih.gov/research/bionlp/APIs/BioC-PubMed/). Other operations include 150 | 151 | * lowercasing the text 152 | * removing special chars outside `\x00`-`\x7F` (i.e., non-ASCII characters) 153 | * tokenizing the text using the [NLTK Treebank tokenizer](https://www.nltk.org/_modules/nltk/tokenize/treebank.html) 154 | 155 | Below is a code snippet for more details. 
156 | 157 | ```python 158 | value = value.lower() 159 | value = re.sub(r'[\r\n]+', ' ', value) 160 | value = re.sub(r'[^\x00-\x7F]+', ' ', value) 161 | 162 | tokenized = TreebankWordTokenizer().tokenize(value) 163 | sentence = ' '.join(tokenized) 164 | sentence = re.sub(r"\s's\b", "'s", sentence) 165 | ``` 166 | 167 | ### Pre-training with BERT 168 | 169 | Afterwards, we used the following code to generate pre-training data. Please see https://github.com/google-research/bert for more details. 170 | 171 | ```bash 172 | python bert/create_pretraining_data.py \ 173 | --input_file=pubmed_uncased_sentence_nltk.txt \ 174 | --output_file=pubmed_uncased_sentence_nltk.tfrecord \ 175 | --vocab_file=bert_uncased_L-12_H-768_A-12_vocab.txt \ 176 | --do_lower_case=True \ 177 | --max_seq_length=128 \ 178 | --max_predictions_per_seq=20 \ 179 | --masked_lm_prob=0.15 \ 180 | --random_seed=12345 \ 181 | --dupe_factor=5 182 | ``` 183 | 184 | We used the following code to train the BERT model. Please do not include `init_checkpoint` if you are pre-training from scratch. Please see https://github.com/google-research/bert for more details. 185 | 186 | ```bash 187 | python bert/run_pretraining.py \ 188 | --input_file=pubmed_uncased_sentence_nltk.tfrecord \ 189 | --output_dir=$BlueBERT_DIR \ 190 | --do_train=True \ 191 | --do_eval=True \ 192 | --bert_config_file=$BlueBERT_DIR/bert_config.json \ 193 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \ 194 | --train_batch_size=32 \ 195 | --max_seq_length=128 \ 196 | --max_predictions_per_seq=20 \ 197 | --num_train_steps=20000 \ 198 | --num_warmup_steps=10 \ 199 | --learning_rate=2e-5 200 | ``` 201 | 202 | ## Citing BlueBERT 203 | 204 | * Peng Y, Yan S, Lu Z. [Transfer Learning in Biomedical Natural Language Processing: An 205 | Evaluation of BERT and ELMo on Ten Benchmarking Datasets](https://arxiv.org/abs/1906.05474). In *Proceedings of the Workshop on Biomedical Natural Language Processing (BioNLP)*. 2019. 
206 | 207 | ``` 208 | @InProceedings{peng2019transfer, 209 | author = {Yifan Peng and Shankai Yan and Zhiyong Lu}, 210 | title = {Transfer Learning in Biomedical Natural Language Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets}, 211 | booktitle = {Proceedings of the 2019 Workshop on Biomedical Natural Language Processing (BioNLP 2019)}, 212 | year = {2019}, 213 | pages = {58--65}, 214 | } 215 | ``` 216 | 217 | ## Acknowledgments 218 | 219 | This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of 220 | Medicine and Clinical Center. This work was supported by the National Library of Medicine of the National Institutes of Health under award number K99LM013001-01. 221 | 222 | We are also grateful to the authors of BERT and ELMo for making their data and code publicly available. 223 | 224 | We would like to thank Dr Sun Kim for processing the PubMed texts. 225 | 226 | ## Disclaimer 227 | 228 | This tool shows the results of research conducted in the Computational Biology Branch, NCBI. The information produced 229 | on this website is not intended for direct diagnostic use or medical decision-making without review and oversight 230 | by a clinical professional. Individuals should not change their health behavior solely on the basis of information 231 | produced on this website. NIH does not independently verify the validity or utility of the information produced 232 | by this tool. If you have questions about the information produced on this website, please see a health care 233 | professional. More information about NCBI's disclaimer policy is available. 
234 | -------------------------------------------------------------------------------- /bert/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributions, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to <https://cla.developers.google.com/> to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /bert/download_glue_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | 3 | Note: for legal reasons, we are unable to host MRPC. 4 | You can either use the version hosted by the SentEval team, which is already tokenized, 5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 7 | You should then rename and place specific files in a folder (see below for an example). 
8 | 9 | mkdir MRPC 10 | cabextract MSRParaphraseCorpus.msi -d MRPC 11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 13 | rm MRPC/_* 14 | rm MSRParaphraseCorpus.msi 15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | 25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 26 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 27 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 28 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 29 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 30 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 31 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 32 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 33 | "QNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLI.zip?alt=media&token=c24cad61-f2df-4f04-9ab6-aa576fa829d0', 34 | 
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 35 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 36 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 37 | 38 | MRPC_TRAIN = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_train.txt' 39 | MRPC_TEST = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_test.txt' 40 | 41 | def download_and_extract(task, data_dir): 42 | print("Downloading and extracting %s..." 
% task) 43 | data_file = "%s.zip" % task 44 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 45 | with zipfile.ZipFile(data_file) as zip_ref: 46 | zip_ref.extractall(data_dir) 47 | os.remove(data_file) 48 | print("\tCompleted!") 49 | 50 | def format_mrpc(data_dir, path_to_data): 51 | print("Processing MRPC...") 52 | mrpc_dir = os.path.join(data_dir, "MRPC") 53 | if not os.path.isdir(mrpc_dir): 54 | os.mkdir(mrpc_dir) 55 | if path_to_data: 56 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 57 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 58 | else: 59 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 60 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 61 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 62 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 63 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 64 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 65 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 66 | 67 | dev_ids = [] 68 | with open(os.path.join(mrpc_dir, "dev_ids.tsv")) as ids_fh: 69 | for row in ids_fh: 70 | dev_ids.append(row.strip().split('\t')) 71 | 72 | with open(mrpc_train_file, encoding='utf8') as data_fh, \ 73 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf8') as train_fh, \ 74 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf8') as dev_fh: 75 | header = data_fh.readline() 76 | train_fh.write(header) 77 | dev_fh.write(header) 78 | for row in data_fh: 79 | label, id1, id2, s1, s2 = row.strip().split('\t') 80 | if [id1, id2] in dev_ids: 81 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 82 | else: 83 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 84 | 85 | with open(mrpc_test_file, encoding='utf8') as data_fh, \ 86 | open(os.path.join(mrpc_dir, 
"test.tsv"), 'w', encoding='utf8') as test_fh: 87 | header = data_fh.readline() 88 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 89 | for idx, row in enumerate(data_fh): 90 | label, id1, id2, s1, s2 = row.strip().split('\t') 91 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 92 | print("\tCompleted!") 93 | 94 | def download_diagnostic(data_dir): 95 | print("Downloading and extracting diagnostic...") 96 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 97 | os.mkdir(os.path.join(data_dir, "diagnostic")) 98 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 99 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 100 | print("\tCompleted!") 101 | return 102 | 103 | def get_tasks(task_names): 104 | task_names = task_names.split(',') 105 | if "all" in task_names: 106 | tasks = TASKS 107 | else: 108 | tasks = [] 109 | for task_name in task_names: 110 | assert task_name in TASKS, "Task %s not found!" % task_name 111 | tasks.append(task_name) 112 | return tasks 113 | 114 | def main(arguments): 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 117 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 118 | type=str, default='all') 119 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 120 | type=str, default='') 121 | args = parser.parse_args(arguments) 122 | 123 | if not os.path.isdir(args.data_dir): 124 | os.mkdir(args.data_dir) 125 | tasks = get_tasks(args.tasks) 126 | 127 | for task in tasks: 128 | if task == 'MRPC': 129 | format_mrpc(args.data_dir, args.path_to_mrpc) 130 | elif task == 'diagnostic': 131 | download_diagnostic(args.data_dir) 132 | else: 133 | download_and_extract(task, args.data_dir) 134 | 135 | 136 | if __name__ == '__main__': 137 | 
sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import six 25 | import tensorflow as tf 26 | 27 | from bert import modeling 28 | 29 | 30 | class BertModelTest(tf.test.TestCase): 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = 
vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": 
model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | 
"^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | 
for input_name in op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 | 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want to decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /bert/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text 
should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import six 22 | import tensorflow as tf 23 | import tokenization 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | 
self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /bluebert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/bluebert/__init__.py -------------------------------------------------------------------------------- /bluebert/conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | # add function :evaluate(predicted_label, ori_label): which will not read from file 13 | 14 | import sys 15 | import re 16 | import codecs 17 | from collections import defaultdict, namedtuple 18 | 19 | ANY_SPACE = '' 20 | 21 | 22 | class FormatError(Exception): 23 | pass 24 | 25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 26 | 27 | 28 | class EvalCounts(object): 29 | def __init__(self): 30 | self.correct_chunk = 0 # number of correctly identified chunks 31 | self.correct_tags = 0 # number of correct chunk tags 32 | self.found_correct = 0 # number of chunks in corpus 33 | self.found_guessed = 0 # number of identified chunks 34 | self.token_counter = 0 # token counter (ignores sentence breaks) 35 | 36 | # counts by 
type 37 | self.t_correct_chunk = defaultdict(int) 38 | self.t_found_correct = defaultdict(int) 39 | self.t_found_guessed = defaultdict(int) 40 | 41 | 42 | def parse_args(argv): 43 | import argparse 44 | parser = argparse.ArgumentParser( 45 | description='evaluate tagging results using CoNLL criteria', 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 47 | ) 48 | arg = parser.add_argument 49 | arg('-b', '--boundary', metavar='STR', default='-X-', 50 | help='sentence boundary') 51 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 52 | help='character delimiting items in input') 53 | arg('-o', '--otag', metavar='CHAR', default='O', 54 | help='alternative outside tag') 55 | arg('file', nargs='?', default=None) 56 | return parser.parse_args(argv) 57 | 58 | 59 | def parse_tag(t): 60 | m = re.match(r'^([^-]*)-(.*)$', t) 61 | return m.groups() if m else (t, '') 62 | 63 | 64 | def evaluate(iterable, options=None): 65 | if options is None: 66 | options = parse_args([]) # use defaults 67 | 68 | counts = EvalCounts() 69 | num_features = None # number of features per line 70 | in_correct = False # currently processed chunk is correct until now 71 | last_correct = 'O' # previous chunk tag in corpus 72 | last_correct_type = '' # type of previous chunk tag in corpus 73 | last_guessed = 'O' # previously identified chunk tag 74 | last_guessed_type = '' # type of previously identified chunk tag 75 | 76 | for i, line in enumerate(iterable): 77 | line = line.rstrip('\r\n') 78 | # print(line) 79 | 80 | if options.delimiter == ANY_SPACE: 81 | features = line.split() 82 | else: 83 | features = line.split(options.delimiter) 84 | 85 | if num_features is None: 86 | num_features = len(features) 87 | elif num_features != len(features) and len(features) != 0: 88 | raise FormatError('unexpected number of features: %d (%d) at line %d\n%s' % 89 | (len(features), num_features, i, line)) 90 | 91 | if len(features) == 0 or features[0] == options.boundary: 92 | features =
[options.boundary, 'O', 'O'] 93 | if len(features) < 3: 94 | raise FormatError('unexpected number of features in line %s' % line) 95 | 96 | guessed, guessed_type = parse_tag(features.pop()) 97 | correct, correct_type = parse_tag(features.pop()) 98 | first_item = features.pop(0) 99 | 100 | if first_item == options.boundary: 101 | guessed = 'O' 102 | 103 | end_correct = end_of_chunk(last_correct, correct, 104 | last_correct_type, correct_type) 105 | end_guessed = end_of_chunk(last_guessed, guessed, 106 | last_guessed_type, guessed_type) 107 | start_correct = start_of_chunk(last_correct, correct, 108 | last_correct_type, correct_type) 109 | start_guessed = start_of_chunk(last_guessed, guessed, 110 | last_guessed_type, guessed_type) 111 | 112 | if in_correct: 113 | if (end_correct and end_guessed and 114 | last_guessed_type == last_correct_type): 115 | in_correct = False 116 | counts.correct_chunk += 1 117 | counts.t_correct_chunk[last_correct_type] += 1 118 | elif (end_correct != end_guessed or guessed_type != correct_type): 119 | in_correct = False 120 | 121 | if start_correct and start_guessed and guessed_type == correct_type: 122 | in_correct = True 123 | 124 | if start_correct: 125 | counts.found_correct += 1 126 | counts.t_found_correct[correct_type] += 1 127 | if start_guessed: 128 | counts.found_guessed += 1 129 | counts.t_found_guessed[guessed_type] += 1 130 | if first_item != options.boundary: 131 | if correct == guessed and guessed_type == correct_type: 132 | counts.correct_tags += 1 133 | counts.token_counter += 1 134 | 135 | last_guessed = guessed 136 | last_correct = correct 137 | last_guessed_type = guessed_type 138 | last_correct_type = correct_type 139 | 140 | if in_correct: 141 | counts.correct_chunk += 1 142 | counts.t_correct_chunk[last_correct_type] += 1 143 | 144 | return counts 145 | 146 | 147 | 148 | def uniq(iterable): 149 | seen = set() 150 | return [i for i in iterable if not (i in seen or seen.add(i))] 151 | 152 | 153 | def 
calculate_metrics(correct, guessed, total): 154 | tp, fp, fn = correct, guessed-correct, total-correct 155 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 156 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 157 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 158 | return Metrics(tp, fp, fn, p, r, f) 159 | 160 | 161 | def metrics(counts): 162 | c = counts 163 | overall = calculate_metrics( 164 | c.correct_chunk, c.found_guessed, c.found_correct 165 | ) 166 | by_type = {} 167 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 168 | by_type[t] = calculate_metrics( 169 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 170 | ) 171 | return overall, by_type 172 | 173 | 174 | def report(counts, out=None): 175 | if out is None: 176 | out = sys.stdout 177 | 178 | overall, by_type = metrics(counts) 179 | 180 | c = counts 181 | out.write('processed %d tokens with %d phrases; ' % 182 | (c.token_counter, c.found_correct)) 183 | out.write('found: %d phrases; correct: %d.\n' % 184 | (c.found_guessed, c.correct_chunk)) 185 | 186 | if c.token_counter > 0: 187 | out.write('accuracy: %6.2f%%; ' % 188 | (100.*c.correct_tags/c.token_counter)) 189 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 190 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 191 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 192 | 193 | for i, m in sorted(by_type.items()): 194 | out.write('%17s: ' % i) 195 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 196 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 197 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 198 | 199 | 200 | def report_notprint(counts, out=None): 201 | if out is None: 202 | out = sys.stdout 203 | 204 | overall, by_type = metrics(counts) 205 | 206 | c = counts 207 | final_report = [] 208 | line = [] 209 | line.append('processed %d tokens with %d phrases; ' % 210 | (c.token_counter, c.found_correct)) 211 | line.append('found: %d phrases; correct: %d.\n' % 212 | 
(c.found_guessed, c.correct_chunk)) 213 | final_report.append("".join(line)) 214 | 215 | if c.token_counter > 0: 216 | line = [] 217 | line.append('accuracy: %6.2f%%; ' % 218 | (100.*c.correct_tags/c.token_counter)) 219 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 220 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 221 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 222 | final_report.append("".join(line)) 223 | 224 | for i, m in sorted(by_type.items()): 225 | line = [] 226 | line.append('%17s: ' % i) 227 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 228 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 229 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 230 | final_report.append("".join(line)) 231 | return final_report 232 | 233 | 234 | def end_of_chunk(prev_tag, tag, prev_type, type_): 235 | # check if a chunk ended between the previous and current word 236 | # arguments: previous and current chunk tags, previous and current types 237 | chunk_end = False 238 | 239 | if prev_tag == 'E': chunk_end = True 240 | if prev_tag == 'S': chunk_end = True 241 | 242 | if prev_tag == 'B' and tag == 'B': chunk_end = True 243 | if prev_tag == 'B' and tag == 'S': chunk_end = True 244 | if prev_tag == 'B' and tag == 'O': chunk_end = True 245 | if prev_tag == 'I' and tag == 'B': chunk_end = True 246 | if prev_tag == 'I' and tag == 'S': chunk_end = True 247 | if prev_tag == 'I' and tag == 'O': chunk_end = True 248 | 249 | if prev_tag != 'O' and prev_tag != '.' 
and prev_type != type_: 250 | chunk_end = True 251 | 252 | # these chunks are assumed to have length 1 253 | if prev_tag == ']': chunk_end = True 254 | if prev_tag == '[': chunk_end = True 255 | 256 | return chunk_end 257 | 258 | 259 | def start_of_chunk(prev_tag, tag, prev_type, type_): 260 | # check if a chunk started between the previous and current word 261 | # arguments: previous and current chunk tags, previous and current types 262 | chunk_start = False 263 | 264 | if tag == 'B': chunk_start = True 265 | if tag == 'S': chunk_start = True 266 | 267 | if prev_tag == 'E' and tag == 'E': chunk_start = True 268 | if prev_tag == 'E' and tag == 'I': chunk_start = True 269 | if prev_tag == 'S' and tag == 'E': chunk_start = True 270 | if prev_tag == 'S' and tag == 'I': chunk_start = True 271 | if prev_tag == 'O' and tag == 'E': chunk_start = True 272 | if prev_tag == 'O' and tag == 'I': chunk_start = True 273 | 274 | if tag != 'O' and tag != '.' and prev_type != type_: 275 | chunk_start = True 276 | 277 | # these chunks are assumed to have length 1 278 | if tag == '[': chunk_start = True 279 | if tag == ']': chunk_start = True 280 | 281 | return chunk_start 282 | 283 | 284 | def main(argv): 285 | args = parse_args(argv[1:]) 286 | 287 | if args.file is None: 288 | counts = evaluate(sys.stdin, args) 289 | else: 290 | with open(args.file) as f: 291 | counts = evaluate(f, args) 292 | report(counts) 293 | 294 | 295 | def return_report(input_file): 296 | with open(input_file, "r") as f: 297 | counts = evaluate(f) 298 | return report_notprint(counts) 299 | 300 | if __name__ == '__main__': 301 | # sys.exit(main(sys.argv)) 302 | return_report('/home/pengy6/data/sentence_similarity/data/cdr/test1/wanli_result2/label_test.txt') -------------------------------------------------------------------------------- /bluebert/tf_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiclass 3 | from: 4 | 
https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 5 | 6 | """ 7 | 8 | __author__ = "Guillaume Genthial" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 13 | 14 | 15 | def precision(labels, predictions, num_classes, pos_indices=None, 16 | weights=None, average='micro'): 17 | """Multi-class precision metric for Tensorflow 18 | Parameters 19 | ---------- 20 | labels : Tensor of tf.int32 or tf.int64 21 | The true labels 22 | predictions : Tensor of tf.int32 or tf.int64 23 | The predictions, same shape as labels 24 | num_classes : int 25 | The number of classes 26 | pos_indices : list of int, optional 27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 
40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 
78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. Will not account for class 118 | imbalance. 119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 
122 | beta : int, optional 123 | Weight of precision in harmonic mean 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1. 
+ beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F1 from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 175 | pos_indices : list of int, optional 176 | The indices of the positive classes 177 | beta : int, optional 178 | Weight of precision in harmonic mean 179 | average : str, optional 180 | 'micro', 'macro' or 'weighted' 181 | """ 182 | num_classes = cm.shape[0] 183 | if pos_indices is None: 184 | pos_indices = [i for i in range(num_classes)] 185 | 186 | if average == 'micro': 187 | return pr_re_fbeta(cm, pos_indices, beta) 188 | elif average in {'macro', 'weighted'}: 189 | precisions, recalls, fbetas, n_golds = [], [], [], [] 190 | for idx in pos_indices: 191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 192 | precisions.append(pr) 193 | recalls.append(re) 194 | fbetas.append(fbeta) 195 | cm_mask = np.zeros([num_classes, num_classes]) 196 | cm_mask[idx, :] = 1 197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 198 | 199 | if average == 'macro': 200 | pr = tf.reduce_mean(precisions) 201 | re = tf.reduce_mean(recalls) 202 | fbeta = tf.reduce_mean(fbetas) 203 | return pr, re, fbeta 204 | if average == 'weighted': 205 | n_gold = tf.reduce_sum(n_golds) 206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 207 | pr = safe_div(pr_sum, n_gold) 208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 209 | re = safe_div(re_sum, n_gold) 210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 211 | fbeta = safe_div(fbeta_sum, n_gold) 212 | return pr, re, fbeta 213 | 214 | else: 215 | raise NotImplementedError() 216 | -------------------------------------------------------------------------------- /elmo/README.md: 
-------------------------------------------------------------------------------- 1 | # ELMo Fine-tuning 2 | The python script `elmoft.py` provides utility functions for fine-tuning the [ELMo model](https://allennlp.org/elmo). We used the model pre-trained on PubMed in our paper [Transfer Learning in Biomedical Natural Language Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets](https://arxiv.org/abs/1906.05474). 3 | 4 | ## Pre-trained models and benchmark datasets 5 | Please prepare the pre-trained ELMo model files `options.json` and `weights.hdf5`, and pass their locations via the `--options_path` and `--weights_path` parameters when running the script. The model pre-trained on PubMed can be downloaded from the [ELMo website](https://allennlp.org/elmo). 6 | 7 | The benchmark datasets can be downloaded from [https://github.com/ncbi-nlp/BLUE_Benchmark](https://github.com/ncbi-nlp/BLUE_Benchmark) 8 | 9 | ## Fine-tuning ELMo 10 | We assume the ELMo model has been downloaded at `$ELMO_DIR`, and the dataset has been downloaded at `$DATASET_DIR`.
11 | 12 | ### Sentence similarity 13 | 14 | ```bash 15 | python elmoft.py \ 16 | --task 'clnclsts' \ 17 | --seq2vec 'boe' \ 18 | --options_path $ELMO_DIR/options.json \ 19 | --weights_path $ELMO_DIR/weights.hdf5 \ 20 | --maxlen 256 \ 21 | --fchdim 500 \ 22 | --lr 0.001 \ 23 | --pdrop 0.5 \ 24 | --do_norm --norm_type batch \ 25 | --do_lastdrop \ 26 | --initln \ 27 | --earlystop \ 28 | --epochs 20 \ 29 | --bsize 64 \ 30 | --data_dir=$DATASET_DIR 31 | ``` 32 | 33 | The task can be 34 | 35 | - `clnclsts`: Mayo Clinics clinical sentence similarity task 36 | - `biosses`: Biomedical Summarization Track sentence similarity task 37 | 38 | 39 | ### Named Entity Recognition 40 | 41 | ```bash 42 | python elmoft.py \ 43 | --task 'bc5cdr-chem' \ 44 | --seq2vec 'boe' \ 45 | --options_path $ELMO_DIR/options.json \ 46 | --weights_path $ELMO_DIR/weights.hdf5 \ 47 | --maxlen 128 \ 48 | --fchdim 500 \ 49 | --lr 0.001 \ 50 | --pdrop 0.5 \ 51 | --do_norm --norm_type batch \ 52 | --do_lastdrop \ 53 | --initln \ 54 | --earlystop \ 55 | --epochs 20 \ 56 | --bsize 64 \ 57 | --data_dir=$DATASET_DIR 58 | ``` 59 | 60 | The task can be 61 | 62 | - `bc5cdr-chem`: BC5CDR chemical task 63 | - `bc5cdr-dz`: BC5CDR disease task 64 | - `shareclefe`: ShARe/CLEFE task 65 | 66 | ### Relation Extraction 67 | 68 | ```bash 69 | python elmoft.py \ 70 | --task 'ddi' \ 71 | --seq2vec 'boe' \ 72 | --options_path $ELMO_DIR/options.json \ 73 | --weights_path $ELMO_DIR/weights.hdf5 \ 74 | --maxlen 128 \ 75 | --fchdim 500 \ 76 | --lr 0.001 \ 77 | --pdrop 0.5 \ 78 | --do_norm --norm_type batch \ 79 | --initln \ 80 | --earlystop \ 81 | --epochs 20 \ 82 | --bsize 64 \ 83 | --data_dir=$DATASET_DIR 84 | ``` 85 | 86 | The task name can be 87 | 88 | - `ddi`: DDI 2013 task 89 | - `chemprot`: BC6 ChemProt task 90 | - `i2b2`: I2B2 relation extraction task 91 | 92 | ### Document multilabel classification 93 | 94 | ```bash 95 | python elmoft.py \ 96 | --task 'hoc' \ 97 | --seq2vec 'boe' \ 98 | --options_path
$ELMO_DIR/options.json \ 99 | --weights_path $ELMO_DIR/weights.hdf5 \ 100 | --maxlen 128 \ 101 | --fchdim 500 \ 102 | --lr 0.001 \ 103 | --pdrop 0.5 \ 104 | --do_norm --norm_type batch \ 105 | --initln \ 106 | --earlystop \ 107 | --epochs 50 \ 108 | --bsize 64 \ 109 | --data_dir=$DATASET_DIR 110 | ``` 111 | 112 | ### Inference task 113 | 114 | ```bash 115 | python elmoft.py \ 116 | --task 'mednli' \ 117 | --seq2vec 'boe' \ 118 | --options_path $ELMO_DIR/options.json \ 119 | --weights_path $ELMO_DIR/weights.hdf5 \ 120 | --maxlen 128 \ 121 | --fchdim 500 \ 122 | --lr 0.0005 \ 123 | --pdrop 0.5 \ 124 | --do_norm --norm_type batch \ 125 | --initln \ 126 | --earlystop \ 127 | --epochs 20 \ 128 | --bsize 64 \ 129 | --data_dir=$DATASET_DIR 130 | ``` 131 | 132 | ## GPU Acceleration 133 | If there are no GPU devices, please indicate this using the parameter `-g 0` or `--gpunum 0`. Otherwise, please indicate the index (starting from 0) of the GPU you want to use by setting the parameter `-q 0` or `--gpuq 0`. 134 | -------------------------------------------------------------------------------- /mribert/README.md: -------------------------------------------------------------------------------- 1 | ## Automatic recognition of abdominal lymph nodes from clinical text 2 | 3 | This repository provides codes and models of the BERT model for lymph node detection from MRI reports. 4 | 5 | ## Pre-trained models 6 | 7 | The pre-trained model weights, vocab, and config files can be downloaded from: 8 | 9 | * [mribert](https://github.com/ncbi-nlp/bluebert/releases/tag/lymphnode) 10 | 11 | ## Fine-tuning BERT 12 | 13 | We assume the MriBERT model has been downloaded at `$MriBERT_DIR`.
14 | 15 | ```bash 16 | tstr=$(date +"%FT%H%M%S%N") 17 | text_col="text" 18 | label_col="label" 19 | batch_size=32 20 | train_dataset="train,dev" 21 | val_dataset="dev" 22 | test_dataset="test" 23 | epochs=10 24 | 25 | bert_dir=$MriBERT_DIR 26 | dataset="$MriBERT_DIR/total_data.csv" 27 | model_dir="$MriBERT_DIR/mri_${tstr}" 28 | test_predictions="predictions_mribert.csv" 29 | 30 | # predict new 31 | pred_dataset="$MriBERT_DIR/new_data.csv" 32 | pred_predictions="new_data_predictions.csv" 33 | 34 | export PYTHONPATH=.:$PYTHONPATH 35 | python sequence_classification.py \ 36 | --do_train \ 37 | --do_test \ 38 | --dataset "${dataset}" \ 39 | --output_dir "${model_dir}" \ 40 | --vocab_file $bert_dir/vocab.txt \ 41 | --bert_config_file $bert_dir/bert_config.json \ 42 | --init_checkpoint $bert_dir/mribert_model.ckpt \ 43 | --text_col "${text_col}" \ 44 | --label_col "${label_col}" \ 45 | --batch_size "${batch_size}" \ 46 | --train_dataset "${train_dataset}" \ 47 | --val_dataset "${val_dataset}" \ 48 | --test_dataset "${test_dataset}" \ 49 | --pred_dataset "${pred_dataset}" \ 50 | --epochs ${epochs} \ 51 | --test_predictions ${test_predictions} \ 52 | --pred_predictions ${pred_predictions} 53 | ``` 54 | 55 | 56 | ## Citing MriBert 57 | 58 | Peng Y, Lee S, Elton D, Shen T, Tang YX, Chen Q, Wang S, Zhu Y, Summers RM, Lu Z. 59 | Automatic recognition of abdominal lymph nodes from clinical text. 60 | In Proceedings of the ClinicalNLP Workshop. 2020. 61 | 62 | ## Acknowledgments 63 | 64 | This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of 65 | Medicine and Clinical Center. 66 | This work was supported by the National Library of Medicine of the National Institutes of Health under award number 4R00LM013001. 67 | 68 | ## Disclaimer 69 | 70 | This tool shows the results of research conducted in the Computational Biology Branch, NLM/NCBI.
The information produced 71 | on this website is not intended for direct diagnostic use or medical decision-making without review and oversight 72 | by a clinical professional. Individuals should not change their health behavior solely on the basis of information 73 | produced on this website. NIH does not independently verify the validity or utility of the information produced 74 | by this tool. If you have questions about the information produced on this website, please see a health care 75 | professional. More information about NLM/NCBI's disclaimer policy is available. 76 | -------------------------------------------------------------------------------- /mribert/lymph_node_vocab.yml: -------------------------------------------------------------------------------- 1 | Mediastinal lymph node: 2 | include: 3 | - mediastinum 4 | - mediastinal 5 | Subcarinal lymph node: 6 | include: 7 | - subcarinal 8 | - tracheobronchial 9 | Cardiophrenic lymph node: 10 | include: 11 | - cardiophrenic 12 | - cardiophrenic angle 13 | - pericardiac 14 | Paraesophageal lymph node: 15 | include: 16 | - esophageal 17 | - esophagus 18 | - paraesophageal 19 | - 8 thoracic lymph node 20 | #------------------------------------------------------------------------------- 21 | Retroperitoneal lymph node: 22 | include: 23 | - retroperitoneum 24 | - retroperitoneal 25 | - retro peritoneum 26 | - right paraspinal 27 | - adrenal gland 28 | - adrenal 29 | - nephrectomy bed 30 | - nephrectomy resection bed 31 | - nephrectomy surgical bed 32 | - peripelvic 33 | Retrocrural lymph node: 34 | include: 35 | - retrocrural 36 | Para-aortic lymph node: 37 | include: 38 | - periaortic 39 | - peri-aortic 40 | - paraaortic 41 | - para-aortic 42 | - lateral aortic 43 | - infrarenal abdominal aorta 44 | - left to the aorta 45 | - retroaortic 46 | - retro-aortic 47 | Interaortocaval lymph node: 48 | include: 49 | - aortocaval 50 | - interaortocaval 51 | Retrocaval lymph node: 52 | include: 53 | - retrocaval 54 | 
- postcaval 55 | - posterior to the intrahepatic IVC 56 | Preaortic lymph node: 57 | include: 58 | - preaotic 59 | - preaortic 60 | - anterior to the abdominal aorta 61 | - anterior to the aorta 62 | Paracaval lymph node: 63 | include: 64 | - paracaval 65 | - pericaval 66 | - caval 67 | - vena cava 68 | - cava 69 | - paracardial 70 | - aJCC level 16 71 | - paracardial gastric 72 | - lateral to the IVC 73 | - margin of the intrahepatic IVC 74 | Precaval lymph node: # anterior to the vena cava 75 | include: 76 | - precaval 77 | Paraspinal lymph node: 78 | include: 79 | - paraspinal 80 | #------------------------------------------------------------------------------- 81 | Peritoneal lymph node: 82 | include: 83 | - peritoneal 84 | - peritoneum 85 | Subdiaphragmatic lymph node: 86 | include: 87 | - diaphragmatic 88 | - peri-diaphragmatic 89 | - subdiaphragmatic 90 | Perihepatic lymph node: 91 | include: 92 | - perihepatic 93 | Paraduodenal lymph node: 94 | include: 95 | - paraduodenal 96 | - portion of the duodenum 97 | Hepatic artery lymph node: 98 | include: 99 | - hepatic artery 100 | - hepatic arterial 101 | Periportal lymph node: 102 | include: 103 | - liver hilum 104 | - hepatoportal 105 | - porta hepatis 106 | - periportal 107 | - portal vein 108 | Peripancreatic lymph node: 109 | include: 110 | - pancreatic 111 | - pancreas 112 | - peripancreatic 113 | Portocaval lymph node: 114 | include: 115 | - interportocaval 116 | - portocaval 117 | - portacaval 118 | Perigastric lymph node: 119 | include: 120 | - perigastric 121 | Splenic lymph node: 122 | include: 123 | - spleen 124 | - splenule 125 | - splenic 126 | Celiac lymph node: 127 | include: 128 | - celiac 129 | - paraceliac 130 | Superior mesenteric lymph node: 131 | include: 132 | - SMA 133 | - at the level of the SMA 134 | - superior mesenteric 135 | Mesenteric lymph node: 136 | include: 137 | - mesentery 138 | - mesenteric 139 | - pelvis mesentery 140 | Perigastric lymph node along lesser curvature: 141 | 
include: 142 | - lesser curve 143 | - lesser curvature 144 | Perigastric lymph node along greater curvature: 145 | include: 146 | - greater curve 147 | - greater curvature 148 | Gastrosplenic lymph node: 149 | include: 150 | - Gastrosplenic 151 | Gastrohepatic ligament lymph node: 152 | include: 153 | - gastrohepatic 154 | - gastrohepatic ligament 155 | - gastric 156 | Hepatoduodenal ligament lymph node: 157 | include: 158 | - hepatoduodenal 159 | - hepatoduodenal ligament 160 | Paracolic lymph node: 161 | include: 162 | - paracolic 163 | Pericecal lymph node: 164 | include: 165 | - pericecal 166 | - adjacent to the cecum 167 | - ileocolic 168 | Periportal/peripancreatic lymph node: 169 | include: 170 | - periportal/peripancreatic 171 | #------------------------------------------------------------------------------- 172 | Pelvic lymph node: 173 | include: 174 | - pelvic 175 | - pelvis 176 | Common iliac lymph node: 177 | include: 178 | - common iliac 179 | - iliac 180 | - ilium 181 | External iliac lymph node: 182 | include: 183 | - external iliac 184 | Psoas lymph node: 185 | include: 186 | - psoas 187 | - psoas muscle 188 | Presacral lymph node: 189 | include: 190 | - presacral 191 | Perivesicular lymph node: 192 | include: 193 | - Perivesicular 194 | #------------------------------------------------------------------------------- 195 | Inguinal lymph node: 196 | include: 197 | - groin 198 | - inguinal 199 | - thoracic 200 | -------------------------------------------------------------------------------- /mt-bluebert/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | mtdnn_env 7 | venv_windows 8 | mtdnn_env_apex 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | 
var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | .DS_Store 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # IDE pycharm 112 | .idea/ 113 | 114 | log/ 115 | model/ 116 | submission/ 117 | save/ 118 | book_corpus_test/ 119 | book_corpus_train/ 120 | checkpoints/ 121 | .pt_description_history 122 | .git-credentials 123 | pt_bert/philly 124 | .vs 125 | *.pyproj 126 | pt_bert/checkpoint 127 | */aml_experiments 128 | screenlog.* 129 | data 130 | pt_bert/scripts 131 | pt_bert/model_data 132 | screen* 133 | checkpoint 134 | *.sln 135 | dt_mtl 136 | philly 137 | bert_models 138 | run_baseline* 139 | mt_dnn_models 140 | *pyc 141 | run_test/ 142 | experiments/superglue 143 | 
-------------------------------------------------------------------------------- /mt-bluebert/LICENSE: -------------------------------------------------------------------------------- 1 | PUBLIC DOMAIN NOTICE 2 | National Center for Biotechnology Information 3 | 4 | This software/database is a "United States Government Work" under the terms of 5 | the United States Copyright Act. It was written as part of the author's 6 | official duties as a United States Government employee and thus cannot be 7 | copyrighted. This software/database is freely available to the public for use. 8 | The National Library of Medicine and the U.S. Government have not placed any 9 | restriction on its use or reproduction. 10 | 11 | Although all reasonable efforts have been taken to ensure the accuracy and 12 | reliability of the software and data, the NLM and the U.S. Government do not and 13 | cannot warrant the performance or results that may be obtained by using this 14 | software or data. The NLM and the U.S. Government disclaim all warranties, 15 | express or implied, including warranties of performance, merchantability or 16 | fitness for any particular purpose. 17 | 18 | Please cite the author in any work or product based on this material: 19 | 20 | Peng Y, Chen Q, Lu Z. An Empirical Study of Multi-Task Learning on BERT 21 | for Biomedical Text Mining. In Proceedings of the 2020 Workshop on Biomedical 22 | Natural Language Processing (BioNLP 2020). 2020. -------------------------------------------------------------------------------- /mt-bluebert/README.md: -------------------------------------------------------------------------------- 1 | # Multi-Task Learning on BERT for Biomedical Text Mining 2 | 3 | This repository provides codes and models of the Multi-Task Learning on BERT for Biomedical Text Mining. 4 | The package is based on [`mt-dnn`](https://github.com/namisan/mt-dnn). 
5 | 6 | ## Pre-trained models 7 | 8 | The pre-trained MT-BlueBERT weights, vocab, and config files can be downloaded from: 9 | 10 | * [mt-bluebert-biomedical](https://github.com/yfpeng/mt-bluebert/releases/download/0.1/mt-bluebert-biomedical.pt) 11 | * [mt-bluebert-clinical](https://github.com/yfpeng/mt-bluebert/releases/download/0.1/mt-bluebert-clinical.pt) 12 | 13 | The benchmark datasets can be downloaded from [https://github.com/ncbi-nlp/BLUE_Benchmark](https://github.com/ncbi-nlp/BLUE_Benchmark) 14 | 15 | ## Quick start 16 | 17 | ### Setup Environment 18 | 1. python3.6 19 | 2. install requirements 20 | ```bash 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### Download data 25 | Please download the BLUE Benchmark datasets from https://github.com/ncbi-nlp/BLUE_Benchmark 26 | 27 | 28 | ### Preprocess data 29 | ```bash 30 | bash scripts/blue_prepro.sh 31 | ``` 32 | 33 | ### Train a MT-DNN model 34 | ```bash 35 | bash scripts/run_blue_mt_dnn.sh 36 | ``` 37 | 38 | ### Fine-tune a model 39 | ```bash 40 | bash scripts/run_blue_fine_tune.sh 41 | ``` 42 | 43 | ### Convert Tensorflow BERT model to the MT-DNN format 44 | ```bash 45 | python scripts/convert_tf_to_pt.py --tf_checkpoint_root $SRC_ROOT --pytorch_checkpoint_path $DEST --encoder_type 1 46 | ``` 47 | 48 | ## Citing MT-BLUE 49 | 50 | Peng Y, Chen Q, Lu Z. An Empirical Study of Multi-Task Learning on BERT 51 | for Biomedical Text Mining. In Proceedings of the 2020 Workshop on Biomedical 52 | Natural Language Processing (BioNLP 2020). 2020.
53 | 54 | ``` 55 | @InProceedings{peng2019transfer, 56 | author = {Yifan Peng and Qingyu Chen and Zhiyong Lu}, 57 | title = {An Empirical Study of Multi-Task Learning on BERT for Biomedical Text Mining}, 58 | booktitle = {Proceedings of the 2020 Workshop on Biomedical Natural Language Processing (BioNLP 2020)}, 59 | year = {2020}, 60 | } 61 | ``` 62 | 63 | ## Acknowledgments 64 | 65 | This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of 66 | Medicine. This work was supported by the National Library of Medicine of the National Institutes of Health under award number K99LM013001-01. 67 | 68 | We are also grateful to the authors of BERT and mt-dnn for making their data and code publicly available. 69 | 70 | ## Disclaimer 71 | 72 | This tool shows the results of research conducted in the Computational Biology Branch, NLM/NCBI. The information produced 73 | on this website is not intended for direct diagnostic use or medical decision-making without review and oversight 74 | by a clinical professional. Individuals should not change their health behavior solely on the basis of information 75 | produced on this website. NIH does not independently verify the validity or utility of the information produced 76 | by this tool. If you have questions about the information produced on this website, please see a health care 77 | professional. More information about NLM/NCBI's disclaimer policy is available.
78 | -------------------------------------------------------------------------------- /mt-bluebert/mt_bluebert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/mt-bluebert/mt_bluebert/__init__.py -------------------------------------------------------------------------------- /mt-bluebert/mt_bluebert/blue_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | blue_eval [options] --task_def= --data_dir= --range= --model_dir= 4 | 5 | Options: 6 | --task_def= 7 | --data_dir= 8 | --test_datasets= 9 | --range= 10 | --output= 11 | """ 12 | import collections 13 | import functools 14 | import json 15 | from pathlib import Path 16 | from typing import Dict 17 | 18 | import docopt as docopt 19 | import numpy as np 20 | import pandas as pd 21 | import yaml 22 | 23 | from mt_bluebert.blue_metrics import compute_micro_f1, compute_micro_f1_subindex, compute_seq_f1, compute_pearson 24 | 25 | 26 | class TaskMetric: 27 | def __init__(self, task): 28 | self.task = task 29 | self.epochs = [] 30 | self.scores = [] 31 | 32 | def print_max(self): 33 | index = int(np.argmax(self.scores)) 34 | print('%s: Max at epoch %d: %.3f' % (self.task, self.epochs[index], self.scores[index])) 35 | 36 | 37 | def get_score1(pred_file, task, n_class, rows, golds, metric_func): 38 | with open(pred_file, 'r', encoding='utf8') as fp: 39 | obj = json.load(fp) 40 | 41 | preds = [] 42 | for i, uid in enumerate(obj['uids']): 43 | if uid != rows[i]['uid']: 44 | raise ValueError('{}: {} vs {}'.format(task, uid, rows[i]['uid'])) 45 | if n_class == 1: 46 | pred = obj['scores'][i] 47 | elif n_class > 1: 48 | pred = obj['predictions'][i] 49 | else: 50 | raise KeyError(task) 51 | preds.append(pred) 52 | 53 | score = metric_func(preds, golds) 54 | return score 55 | 56 | 57 | def get_score2(pred_file, task, n_class, rows, 
golds, metric_func): 58 | with open(pred_file, 'r', encoding='utf8') as fp: 59 | objs = [] 60 | for line in fp: 61 | objs.append(json.loads(line)) 62 | 63 | preds = [] 64 | for i, obj in enumerate(objs): 65 | if obj['uid'] != rows[i]['uid']: 66 | raise ValueError('{}: {} vs {}'.format(task, obj['uid'], rows[i]['uid'])) 67 | if n_class == 1: 68 | pred = obj['score'] 69 | elif n_class > 1: 70 | pred = obj['prediction'] 71 | else: 72 | raise KeyError(task) 73 | preds.append(pred) 74 | 75 | score = metric_func(preds, golds) 76 | return score 77 | 78 | 79 | def eval_blue(test_datasets, task_def_path, data_dir, model_dir, epochs): 80 | with open(task_def_path) as fp: 81 | task_def = yaml.safe_load(fp) 82 | 83 | METRIC_FUNC = { 84 | 'biosses': compute_pearson, 85 | 'clinicalsts': compute_pearson, 86 | 'mednli': compute_micro_f1, 87 | 'i2b2-2010-re': functools.partial( 88 | compute_micro_f1_subindex, 89 | subindex=[i for i in range(len(task_def['i2b2-2010-re']['labels']) - 1)]), 90 | 'chemprot': functools.partial( 91 | compute_micro_f1_subindex, 92 | subindex=[i for i in range(len(task_def['chemprot']['labels']) - 1)]), 93 | 'ddi2013-type': functools.partial( 94 | compute_micro_f1_subindex, 95 | subindex=[i for i in range(len(task_def['ddi2013-type']['labels']) - 1)]), 96 | 'shareclefe': functools.partial( 97 | compute_seq_f1, 98 | label_mapper={i: v for i, v in enumerate(task_def['shareclefe']['labels'])}), 99 | 'bc5cdr-disease': functools.partial( 100 | compute_seq_f1, 101 | label_mapper={i: v for i, v in enumerate(task_def['bc5cdr-disease']['labels'])}), 102 | 'bc5cdr-chemical': functools.partial( 103 | compute_seq_f1, 104 | label_mapper={i: v for i, v in enumerate(task_def['bc5cdr-chemical']['labels'])}), 105 | } 106 | 107 | total_scores = collections.OrderedDict() # type: Dict[str, TaskMetric] 108 | for task in test_datasets: 109 | n_class = task_def[task]['n_class'] 110 | 111 | file = data_dir / f'{task}_test.json' 112 | with open(file, 'r', encoding='utf-8') as fp: 
113 | rows = [json.loads(line) for line in fp] 114 | # print('Loaded {} samples'.format(len(rows))) 115 | golds = [row['label'] for row in rows] 116 | 117 | task_metric = TaskMetric(task) 118 | for epoch in epochs: 119 | # pred_file = model_dir / f'{task}_test_scores_{epoch}.json' 120 | # score = get_score1(pred_file, task, n_class, rows, golds, METRIC_FUNC[task]) 121 | # scores.append(score) 122 | 123 | # if task in ('clinicalsts', 'i2b2-2010-re', 'mednli', 'shareclefe'): 124 | pred_file = model_dir / f'{task}_test_scores_{epoch}_2.json' 125 | try: 126 | score = get_score2(pred_file, task, n_class, rows, golds, METRIC_FUNC[task]) 127 | except FileNotFoundError: 128 | pred_file = model_dir / f'{task}_test_scores_{epoch}.json' 129 | score = get_score1(pred_file, task, n_class, rows, golds, METRIC_FUNC[task]) 130 | 131 | task_metric.epochs.append(epoch) 132 | task_metric.scores.append(score) 133 | 134 | total_scores[task] = task_metric 135 | task_metric.print_max() 136 | 137 | # if len(scores2) != 0: 138 | # index = np.argmax(scores2) 139 | # print('%s: Max at epoch %d: %.3f' % (task, epochs[index], scores2[index])) 140 | # index = np.argmin(scores) 141 | # print('%s: Min at epoch %d: %.3f' % (task, epochs[index], scores[index])) 142 | return total_scores 143 | 144 | 145 | def pretty_print(total_scores, epochs, dest=None): 146 | # average 147 | for task_metric in total_scores.values(): 148 | assert len(task_metric.scores) == len(epochs) 149 | avg_scores = [np.average([total_scores[t].scores[i] for t in total_scores.keys()]) 150 | for i, _ in enumerate(epochs)] 151 | index = int(np.argmax(avg_scores)) 152 | print('On average, max at epoch %d: %.3f' % (epochs[index], avg_scores[index])) 153 | for t in total_scores: 154 | print(' %s: At epoch %d: %.3f' % (t, epochs[index], total_scores[t].scores[index])) 155 | 156 | if dest is not None: 157 | table = {'epoch': epochs, 'average': avg_scores} 158 | for t, v in total_scores.items(): 159 | table[t] = v.scores 160 | df = 
class BlueTaskDefs:
    """Per-task settings parsed from a BLUE task-definition YAML file.

    After construction every ``*_map`` attribute is a dict keyed by task name;
    ``encoder_type`` is the single encoder type shared by all tasks.
    """

    def __init__(self, task_def_path):
        """Load ``task_def_path`` and index its entries by task name.

        Raises:
            ValueError: if a task name contains ``'_'`` or if the tasks do not
                all declare the same ``encoder_type``.
            KeyError: if a required key is missing or names an unknown enum
                member.
        """
        with open(task_def_path) as fp:
            # safe_load matches how blue_eval reads the same file; the bundled
            # blue_task_def.yml only uses plain mappings, lists and scalars.
            self.task_def_dic = yaml.safe_load(fp)

        self.label_mapper_map = {}  # type: Dict[str, Vocabulary]
        self.n_class_map = {}  # type: Dict[str, int]
        self.data_format_map = {}
        self.task_type_map = {}
        self.metric_meta_map = {}
        self.enable_san_map = {}
        self.dropout_p_map = {}
        self.split_names_map = {}
        self.encoder_type = None
        for task, task_def in self.task_def_dic.items():
            # Explicit raise instead of assert so the check survives ``python -O``.
            if "_" in task:
                raise ValueError(
                    "task name should not contain '_', current task name: %s" % task)
            self.n_class_map[task] = task_def["n_class"]
            self.data_format_map[task] = DataFormat[task_def["data_format"]]
            self.task_type_map[task] = TaskType[task_def["task_type"]]
            self.metric_meta_map[task] = tuple(
                BlueMetric[metric_name] for metric_name in task_def["metric_meta"])
            self.enable_san_map[task] = task_def["enable_san"]

            # The first task fixes the encoder type; all others must agree.
            encoder_type = EncoderModelType[task_def["encoder_type"]]
            if self.encoder_type is None:
                self.encoder_type = encoder_type
            elif self.encoder_type != encoder_type:
                raise ValueError('The shared encoder has to be the same.')

            if "labels" in task_def:
                label_mapper = Vocabulary(True)
                for label in task_def["labels"]:
                    label_mapper.add(label)
                self.label_mapper_map[task] = label_mapper
            else:
                self.label_mapper_map[task] = None

            # Unlike the other maps, dropout_p_map only has entries for tasks
            # that set ``dropout_p`` explicitly.
            if "dropout_p" in task_def:
                self.dropout_p_map[task] = task_def["dropout_p"]

            self.split_names_map[task] = task_def.get(
                'split_names', ["train", "dev", "test"])

    @property
    def tasks(self) -> Set[str]:
        """Task names, as the keys view of the loaded definitions (file order)."""
        return self.task_def_dic.keys()
def compute_acc(predicts, labels):
    """Return ``accuracy_score`` of ``predicts`` measured against ``labels``."""
    accuracy = accuracy_score(labels, predicts)
    return accuracy


def compute_f1(predicts, labels):
    """Return ``f1_score`` of ``predicts`` measured against ``labels``."""
    f1 = f1_score(labels, predicts)
    return f1


def compute_mcc(predicts, labels):
    """Return ``matthews_corrcoef`` of ``predicts`` measured against ``labels``."""
    mcc = matthews_corrcoef(labels, predicts)
    return mcc


def compute_pearson(predicts, labels):
    """Return the first element of ``pearsonr(labels, predicts)`` (the coefficient)."""
    coefficient, _p_value = pearsonr(labels, predicts)
    return coefficient


def compute_spearman(predicts, labels):
    """Return the first element of ``spearmanr(labels, predicts)`` (the coefficient)."""
    coefficient, _p_value = spearmanr(labels, predicts)
    return coefficient


def compute_auc(predicts, labels):
    """Return ``roc_auc_score`` with ``labels`` as truth and ``predicts`` as scores."""
    return roc_auc_score(labels, predicts)


def compute_micro_f1(predicts, labels):
    """Return the micro-row F1 of ``blue_classification_report`` via ``.item()``."""
    micro_row = blue_classification_report(labels, predicts).micro_row
    return micro_row.f1.item()
def compute_macro_f1_subindex(predicts, labels, subindex):
    """Return the macro-row F1 of the sub-report restricted to ``subindex``.

    The full report is built from ``blue_classification_report(labels,
    predicts)``. If building or reading the sub-report raises, the error and
    the full report text are logged and ``0`` is returned instead.
    """
    full_report = blue_classification_report(labels, predicts)
    try:
        restricted = full_report.sub_report(subindex)
        macro_f1 = restricted.macro_row.f1.item()
    except Exception as e:
        logging.error('%s\n%s', e, full_report.report())
        macro_f1 = 0
    return macro_f1
def calc_metrics(metric_meta, golds, predictions, scores, label_mapper: Vocabulary = None):
    """Compute every metric listed in ``metric_meta``.

    Args:
        metric_meta: iterable of ``BlueMetric`` members to evaluate.
        golds: gold labels.
        predictions: predicted labels; used by the label-based metrics.
        scores: flat list of raw scores; used by the remaining metrics. For
            AUC it must hold exactly two scores per gold label, and only the
            entries at odd positions (1, 3, 5, ...) are scored.
        label_mapper: label vocabulary; read by ``SeqEval`` and, through
            ``len()``, by the ``*WithoutLastOne`` metrics.

    Returns:
        Dict mapping each metric's name to its value.

    Raises:
        ValueError: if AUC is requested and ``scores`` is not twice as long
            as ``golds``.
    """
    metrics = {}
    for mm in metric_meta:
        metric_func = METRIC_FUNC[mm]
        if mm in (BlueMetric.ACC, BlueMetric.F1, BlueMetric.MCC, BlueMetric.MicroF1):
            metric = metric_func(predictions, golds)
        elif mm == BlueMetric.SeqEval:
            metric = metric_func(predictions, golds, label_mapper)
        elif mm in (BlueMetric.MicroF1WithoutLastOne, BlueMetric.MacroF1WithoutLastOne):
            # Every label index except the last one.
            metric = metric_func(predictions, golds,
                                 subindex=list(range(len(label_mapper) - 1)))
        elif mm == BlueMetric.AUC:
            if len(scores) != 2 * len(golds):
                raise ValueError("AUC is only valid for binary classification problem")
            # Slice into a temporary: rebinding ``scores`` here would hand the
            # already-halved list to every metric evaluated after AUC.
            metric = metric_func(scores[1::2], golds)
        else:
            metric = metric_func(scores, golds)
        metrics[mm.name] = metric
    return metrics
3 | 4 | Usage: 5 | blue_prepro [options] --root_dir= --task_def= --datasets= 6 | 7 | Options: 8 | --overwrite 9 | """ 10 | import os 11 | 12 | import docopt 13 | 14 | from mt_bluebert.data_utils.log_wrapper import create_logger 15 | from mt_bluebert.blue_exp_def import BlueTaskDefs 16 | from mt_bluebert.blue_utils import load_sts, load_mednli, \ 17 | load_relation, load_ner, dump_rows 18 | 19 | 20 | def main(args): 21 | root = args['--root_dir'] 22 | assert os.path.exists(root) 23 | 24 | log_file = os.path.join(root, 'blue_prepro.log') 25 | logger = create_logger(__name__, to_disk=True, log_file=log_file) 26 | 27 | task_defs = BlueTaskDefs(args['--task_def']) 28 | 29 | canonical_data_suffix = "canonical_data" 30 | canonical_data_root = os.path.join(root, canonical_data_suffix) 31 | if not os.path.isdir(canonical_data_root): 32 | os.mkdir(canonical_data_root) 33 | 34 | if args['--datasets'] == 'all': 35 | tasks = task_defs.tasks 36 | else: 37 | tasks = args['--datasets'].split(',') 38 | for task in tasks: 39 | logger.info("Task %s" % task) 40 | if task not in task_defs.task_def_dic: 41 | raise KeyError('%s: Cannot process this task' % task) 42 | 43 | if task in ['clinicalsts', 'biosses']: 44 | load = load_sts 45 | elif task == 'mednli': 46 | load = load_mednli 47 | elif task in ('chemprot', 'i2b2-2010-re', 'ddi2013-type'): 48 | load = load_relation 49 | elif task in ('bc5cdr-disease', 'bc5cdr-chemical', 'shareclefe'): 50 | load = load_ner 51 | else: 52 | raise KeyError('%s: Cannot process this task' % task) 53 | 54 | data_format = task_defs.data_format_map[task] 55 | split_names = task_defs.split_names_map[task] 56 | for split_name in split_names: 57 | fin = os.path.join(root, f'{task}/{split_name}.tsv') 58 | fout = os.path.join(canonical_data_root, f'{task}_{split_name}.tsv') 59 | if os.path.exists(fout) and not args['--overwrite']: 60 | logger.warning('%s: Not overwrite %s: %s', task, split_name, fout) 61 | continue 62 | data = load(fin) 63 | logger.info('%s: 
def main():
    """Re-save a checkpoint without its ``scoring_list.*`` parameters.

    Loads the checkpoint, picks the ``ema`` weights when ``config['ema_opt']``
    is positive (otherwise ``state``), drops every parameter whose name starts
    with ``'scoring_list.'``, reduces the config to a fixed set of keys, and
    saves the result.

    Raises:
        SystemExit: with status 1 when the model file does not exist.
    """
    args = docopt.docopt(__doc__)
    print(args)

    # NOTE(review): the usage string names no positional arguments and the
    # same key '' is used for both the input and the output path - confirm
    # the intended argument names.
    if not os.path.exists(args['']):
        logging.error('%s: Cannot find the model', args[''])
        # Stop here; continuing would only fail later inside torch.load.
        raise SystemExit(1)

    map_location = 'cpu' if not torch.cuda.is_available() else None
    state_dict = torch.load(args[''], map_location=map_location)
    config = state_dict['config']
    state = state_dict['ema'] if config['ema_opt'] > 0 else state_dict['state']

    my_state = {k: v for k, v in state.items() if not k.startswith('scoring_list.')}
    kept_config_keys = (
        'vocab_size', 'hidden_size', 'num_hidden_layers', 'num_attention_heads',
        'hidden_act', 'intermediate_size', 'hidden_dropout_prob',
        'attention_probs_dropout_prob', 'max_position_embeddings', 'type_vocab_size',
        'initializer_range',
    )
    my_config = {k: config[k] for k in kept_config_keys}

    torch.save({'state': my_state, 'config': my_config}, args[''])
enable_san: false 13 | metric_meta: 14 | - Pearson 15 | n_class: 1 16 | task_type: Regression 17 | mednli: 18 | data_format: PremiseAndOneHypothesis 19 | encoder_type: BERT 20 | enable_san: true 21 | labels: 22 | - contradiction 23 | - neutral 24 | - entailment 25 | metric_meta: 26 | - MicroF1 27 | n_class: 3 28 | task_type: Classification 29 | chemprot: 30 | data_format: PremiseOnly 31 | encoder_type: BERT 32 | enable_san: false 33 | labels: 34 | - CPR:3 35 | - CPR:4 36 | - CPR:5 37 | - CPR:6 38 | - CPR:9 39 | - 'false' 40 | metric_meta: 41 | - MicroF1WithoutLastOne 42 | n_class: 6 43 | task_type: Classification 44 | ddi2013-type: 45 | data_format: PremiseOnly 46 | encoder_type: BERT 47 | enable_san: false 48 | labels: 49 | - DDI-advise 50 | - DDI-effect 51 | - DDI-int 52 | - DDI-mechanism 53 | - DDI-false 54 | metric_meta: 55 | - MacroF1WithoutLastOne 56 | - MicroF1WithoutLastOne 57 | n_class: 5 58 | task_type: Classification 59 | i2b2-2010-re: 60 | data_format: PremiseOnly 61 | encoder_type: BERT 62 | enable_san: false 63 | labels: 64 | - PIP 65 | - TeCP 66 | - TeRP 67 | - TrAP 68 | - TrCP 69 | - TrIP 70 | - TrNAP 71 | - TrWP 72 | - 'false' 73 | metric_meta: 74 | - MicroF1WithoutLastOne 75 | n_class: 9 76 | task_type: Classification 77 | bc5cdr-disease: 78 | data_format: Sequence 79 | encoder_type: BERT 80 | enable_san: False 81 | labels: 82 | - O 83 | - B-Disease 84 | - I-Disease 85 | - X 86 | - CLS 87 | - SEP 88 | metric_meta: 89 | - SeqEval 90 | n_class: 6 91 | task_type: SequenceLabeling 92 | bc5cdr-chemical: 93 | data_format: Sequence 94 | encoder_type: BERT 95 | enable_san: False 96 | labels: 97 | - O 98 | - B-Chemical 99 | - I-Chemical 100 | - X 101 | - CLS 102 | - SEP 103 | metric_meta: 104 | - SeqEval 105 | n_class: 6 106 | task_type: SequenceLabeling 107 | shareclefe: 108 | data_format: Sequence 109 | encoder_type: BERT 110 | enable_san: False 111 | labels: 112 | - O 113 | - B-Disease 114 | - I-Disease 115 | - X 116 | - CLS 117 | - SEP 118 | 
def _read_tsv_body(file):
    """Return the rows of tab-separated ``file`` after its header line.

    An empty file yields an empty list instead of raising ``StopIteration``.
    """
    with open(file, encoding="utf8") as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip the header when there is one
        return list(reader)


def load_relation(file):
    """Load a relation-extraction TSV with exactly 3 columns per row.

    Returns:
        List of ``{'uid', 'premise', 'label'}`` dicts taken from columns
        0, 1 and 2 respectively.

    Raises:
        ValueError: if a data row does not have exactly 3 columns.
    """
    rows = []
    for i, blocks in enumerate(_read_tsv_body(file)):
        if len(blocks) != 3:
            raise ValueError('%s:%s: number of blocks: %s' % (file, i, len(blocks)))
        rows.append({'uid': blocks[0], 'premise': blocks[1], 'label': blocks[-1]})
    return rows


def load_mednli(file):
    """MEDNLI for classification.

    Each data row must have exactly 4 columns: label, uid, premise,
    hypothesis.

    Raises:
        ValueError: if a data row does not have exactly 4 columns.
    """
    rows = []
    for i, blocks in enumerate(_read_tsv_body(file)):
        if len(blocks) != 4:
            raise ValueError('%s:%s: number of blocks: %s' % (file, i, len(blocks)))
        rows.append({'uid': blocks[1], 'premise': blocks[2],
                     'hypothesis': blocks[3], 'label': blocks[0]})
    return rows


def load_sts(file):
    """Load a sentence-similarity TSV with more than 8 columns per row.

    The last three columns are read as premise, hypothesis and score; the
    ``uid`` is the 0-based index of the data row (an ``int``).

    Raises:
        ValueError: if a data row has 8 columns or fewer.
    """
    rows = []
    for i, blocks in enumerate(_read_tsv_body(file)):
        if len(blocks) <= 8:
            raise ValueError('%s:%s: number of blocks: %s' % (file, i, len(blocks)))
        rows.append({'uid': i, 'premise': blocks[-3],
                     'hypothesis': blocks[-2], 'label': blocks[-1]})
    return rows
def _dump_columns(rows, out_path, columns):
    """Write ``columns`` of every row to ``out_path`` as one tab-separated line.

    Each value is converted with ``str()``; tabs inside a value are replaced
    by a space (with a warning) so they cannot be mistaken for a column
    separator. The rows passed in are left unmodified.
    """
    logger = logging.getLogger(__name__)
    with open(out_path, "w", encoding="utf-8") as out_f:
        for i, row in enumerate(rows):
            cells = []
            for col in columns:
                cell = str(row[col])
                if "\t" in cell:
                    cell = cell.replace('\t', ' ')
                    logger.warning('%s:%s: %s has tab', out_path, i, col)
                cells.append(cell)
            out_f.write('\t'.join(cells) + '\n')


def dump_PremiseOnly(rows, out_path):
    """Write rows as ``uid<TAB>label<TAB>premise`` lines to ``out_path``."""
    _dump_columns(rows, out_path, ("uid", "label", "premise"))


def dump_PremiseAndOneHypothesis(rows, out_path):
    """Write rows as ``uid<TAB>label<TAB>premise<TAB>hypothesis`` lines to ``out_path``."""
    _dump_columns(rows, out_path, ("uid", "label", "premise", "hypothesis"))
def dump_PremiseAndMultiHypothesis(rows, out_path):
    """Write rows as ``uid, label, premise, hypothesis_1, ..., hypothesis_n`` lines.

    Columns are tab-separated. A tab inside any value is replaced by a space
    (with a warning) so it cannot be mistaken for a column separator. Neither
    the rows nor their ``hypothesis`` lists are modified.

    NOTE(review): ``load_data`` in ``data_utils`` reads this format as
    ``uid, ruid, label, premise, hypothesis...``; no ``ruid`` column is
    written here - confirm which layout is intended.
    """
    logger = logging.getLogger(__name__)
    with open(out_path, "w", encoding="utf-8") as out_f:
        for i, row in enumerate(rows):
            cells = []
            for col in ("uid", "label", "premise"):
                cell = str(row[col])
                if "\t" in cell:
                    cell = cell.replace('\t', ' ')
                    logger.warning('%s:%s: %s has tab', out_path, i, col)
                cells.append(cell)
            for one_hypo in row["hypothesis"]:
                hypo = str(one_hypo)
                if "\t" in hypo:
                    hypo = hypo.replace('\t', ' ')
                    logger.warning('%s:%s: hypothesis has tab', out_path, i)
                cells.append(hypo)
            out_f.write('\t'.join(cells) + '\n')
class EvalCounts(object):
    """Mutable tallies filled in while tagged tokens are evaluated."""

    def __init__(self):
        # Totals over the whole input.
        self.correct_chunk = 0   # chunks identified correctly
        self.correct_tags = 0    # tokens whose chunk tag is correct
        self.found_correct = 0   # chunks present in the corpus
        self.found_guessed = 0   # chunks that were identified
        self.token_counter = 0   # tokens seen (sentence breaks not counted)

        # The same three chunk tallies broken down by chunk type; a type that
        # has not been seen yet reads as 0.
        (self.t_correct_chunk,
         self.t_found_correct,
         self.t_found_guessed) = (defaultdict(int) for _ in range(3))
def parse_tag(t):
    """Split a tag at its first ``-`` into ``(chunk_tag, chunk_type)``.

    A tag that does not match the pattern (for instance one without ``-``)
    is returned whole as the chunk tag, with an empty chunk type.
    """
    match = re.match(r'^([^-]*)-(.*)$', t)
    if match is None:
        return (t, '')
    return match.groups()
def uniq(iterable):
    """Return the distinct items of ``iterable`` as a list, first occurrence first."""
    return list(dict.fromkeys(iterable))


def calculate_metrics(correct, guessed, total):
    """Derive tp/fp/fn and precision, recall and F-score from three counts.

    Args:
        correct: number of correctly identified items (true positives).
        guessed: number of identified items.
        total: number of items that should have been identified.

    Returns:
        A ``Metrics`` tuple; any ratio whose denominator is zero is ``0``.
    """
    tp = correct
    fp = guessed - correct
    fn = total - correct
    prec = 1.*tp / (tp + fp) if tp + fp != 0 else 0
    rec = 1.*tp / (tp + fn) if tp + fn != 0 else 0
    fscore = 2 * prec * rec / (prec + rec) if prec + rec != 0 else 0
    return Metrics(tp=tp, fp=fp, fn=fn, prec=prec, rec=rec, fscore=fscore)


def metrics(counts):
    """Return ``(overall, by_type)`` metrics computed from ``counts``.

    ``by_type`` has one entry per chunk type that occurs in either the
    corpus chunks or the guessed chunks.
    """
    overall = calculate_metrics(
        counts.correct_chunk, counts.found_guessed, counts.found_correct
    )
    chunk_types = uniq(list(counts.t_found_correct) + list(counts.t_found_guessed))
    by_type = {
        t: calculate_metrics(
            counts.t_correct_chunk[t], counts.t_found_guessed[t], counts.t_found_correct[t]
        )
        for t in chunk_types
    }
    return overall, by_type
= sys.stdout 177 | 178 | overall, by_type = metrics(counts) 179 | 180 | c = counts 181 | out.write('processed %d tokens with %d phrases; ' % 182 | (c.token_counter, c.found_correct)) 183 | out.write('found: %d phrases; correct: %d.\n' % 184 | (c.found_guessed, c.correct_chunk)) 185 | 186 | if c.token_counter > 0: 187 | out.write('accuracy: %6.2f%%; ' % 188 | (100.*c.correct_tags/c.token_counter)) 189 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 190 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 191 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 192 | 193 | for i, m in sorted(by_type.items()): 194 | out.write('%17s: ' % i) 195 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 196 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 197 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 198 | 199 | 200 | def report_notprint(counts, out=None): 201 | if out is None: 202 | out = sys.stdout 203 | 204 | overall, by_type = metrics(counts) 205 | 206 | c = counts 207 | final_report = [] 208 | line = [] 209 | line.append('processed %d tokens with %d phrases; ' % 210 | (c.token_counter, c.found_correct)) 211 | line.append('found: %d phrases; correct: %d.\n' % 212 | (c.found_guessed, c.correct_chunk)) 213 | final_report.append("".join(line)) 214 | 215 | if c.token_counter > 0: 216 | line = [] 217 | line.append('accuracy: %6.2f%%; ' % 218 | (100.*c.correct_tags/c.token_counter)) 219 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 220 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 221 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 222 | final_report.append("".join(line)) 223 | 224 | for i, m in sorted(by_type.items()): 225 | line = [] 226 | line.append('%17s: ' % i) 227 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 228 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 229 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 230 | final_report.append("".join(line)) 231 | return 
def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Tell whether a chunk ended between the previous and the current word.

    Args:
        prev_tag: chunk tag of the previous word.
        tag: chunk tag of the current word.
        prev_type: chunk type of the previous word.
        type_: chunk type of the current word.
    """
    # 'E' and 'S' always close a chunk; '[' and ']' chunks are assumed to
    # have length 1.
    if prev_tag in ('E', 'S', ']', '['):
        return True

    # A chunk opened or continued by the previous word is closed by a new
    # begin/single tag or by an outside tag.
    if prev_tag in ('B', 'I') and tag in ('B', 'S', 'O'):
        return True

    # A change of type closes the previous chunk, unless the previous word
    # was tagged 'O' or '.'.
    if prev_tag not in ('O', '.') and prev_type != type_:
        return True

    return False
def main(argv):
    """Command-line entry point: evaluate a file (or stdin) and print the report.

    Args:
        argv: full argument vector including the program name, which is
            skipped before parsing.
    """
    args = parse_args(argv[1:])

    if args.file is None:
        counts = evaluate(sys.stdin, args)
    else:
        with open(args.file) as f:
            counts = evaluate(f, args)
    report(counts)


def return_report(input_file):
    """Evaluate ``input_file`` with default options and return the report lines."""
    with open(input_file, "r") as f:
        counts = evaluate(f)
    return report_notprint(counts)


if __name__ == '__main__':
    # Run the CLI. The previous guard evaluated a hard-coded file under one
    # developer's home directory and threw the resulting report away.
    sys.exit(main(sys.argv))
def load_score_file(score_path, n_class):
    """Load a JSON score dump and index it by sample id.

    The file must contain the keys "uids", "predictions" and "scores", where
    "scores" is a flat list holding `n_class` consecutive values per uid.

    :param score_path: path of the UTF-8 encoded JSON file.
    :param n_class: number of score values stored per sample.
    :return: dict mapping each uid to `(prediction, [score_0, ..., score_{n_class-1}])`.
    :raises AssertionError: if the flat score list is not exactly
        `n_class * len(uids)` long.
    """
    # Use a context manager so the handle is closed deterministically.
    with open(score_path, encoding="utf-8") as score_file:
        score_obj = json.load(score_file)

    uids = score_obj["uids"]
    scores = score_obj["scores"]
    # Multiplicative form of the original check: same condition for non-empty
    # input, but no division by zero when the file holds no samples.
    assert len(scores) == n_class * len(uids), \
        "scores column size should equal to sample count or multiple of sample count (for classification problem)"

    sample_id_2_pred_score_seg_dic = {}
    for i, (sample_id, pred) in enumerate(zip(uids, score_obj["predictions"])):
        score_seg = scores[i * n_class: (i + 1) * n_class]
        sample_id_2_pred_score_seg_dic[sample_id] = (pred, score_seg)
    return sample_id_2_pred_score_seg_dic
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    An empty or single-symbol word yields an empty set.
    """
    return set(zip(word, word[1:]))


class Encoder:
    """Byte-level BPE encoder/decoder (GPT-2 style).

    :param encoder: dict mapping BPE token strings to integer ids.
    :param bpe_merges: ordered list of merge pairs; earlier pairs have higher priority.
    :param errors: error policy passed to `bytes.decode` when decoding.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # `regex` (not the stdlib `re`) is required for the \p{L} / \p{N} classes.
        try:
            import regex as re
        except ImportError as err:
            raise ImportError('Please install regex with: pip install regex') from err
        self.re = re

        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the learned merges to `token` and return its space-joined BPE symbols."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize `text` and return the list of BPE token ids."""
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            # Map raw UTF-8 bytes to the printable unicode alphabet used by the vocab.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert a sequence of BPE token ids back to text."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text


def get_encoder(encoder_json_path, vocab_bpe_path):
    """Build an `Encoder` from an encoder JSON file and a BPE merges file.

    :param encoder_json_path: path to the token-to-id JSON mapping.
    :param vocab_bpe_path: path to the merges file; its first line and the
        element after the final newline are skipped.
    """
    # Explicit encoding so loading does not depend on the platform locale.
    with open(encoder_json_path, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
def create_logger(name, silent=False, to_disk=False, log_file=None):
    """Logger wrapper.

    Configures and returns the logger registered under `name`. Calling this
    repeatedly with the same name is safe: a handler is only attached when an
    equivalent one (stdout stream, or file handler on the same path) is not
    already present, so messages are not emitted multiple times.

    :param name: logger name passed to `logging.getLogger`.
    :param silent: when True, do not attach the stdout handler.
    :param to_disk: when True, also log to a file.
    :param log_file: file path used when `to_disk` is set; defaults to a
        UTC-timestamped "<Y-m-d-H-M-S>.log" name in the working directory.
    :return: the configured `logging.Logger`.
    """
    # setup logger
    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)
    log.propagate = False
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
    if not silent:
        # FileHandler subclasses StreamHandler, hence the exact type check.
        has_stdout = any(type(h) is logging.StreamHandler and h.stream is sys.stdout
                         for h in log.handlers)
        if not has_stdout:
            ch = logging.StreamHandler(sys.stdout)
            ch.setLevel(logging.INFO)
            ch.setFormatter(formatter)
            log.addHandler(ch)
    if to_disk:
        import os  # local import: only needed to normalise the log path

        log_file = log_file if log_file is not None else strftime("%Y-%m-%d-%H-%M-%S.log", gmtime())
        target = os.path.abspath(log_file)
        has_file = any(isinstance(h, logging.FileHandler) and h.baseFilename == target
                       for h in log.handlers)
        if not has_file:
            fh = logging.FileHandler(log_file)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            log.addHandler(fh)
    return log
def compute_acc(predicts, labels):
    """Accuracy, scaled to 0-100."""
    return 100.0 * accuracy_score(labels, predicts)


def compute_f1(predicts, labels):
    """F1 score (sklearn default averaging), scaled to 0-100."""
    return 100.0 * f1_score(labels, predicts)


def compute_mcc(predicts, labels):
    """Matthews correlation coefficient, scaled by 100."""
    return 100.0 * matthews_corrcoef(labels, predicts)


def compute_pearson(predicts, labels):
    """Pearson correlation coefficient, scaled by 100."""
    pcof = pearsonr(labels, predicts)[0]
    return 100.0 * pcof


def compute_spearman(predicts, labels):
    """Spearman rank correlation coefficient, scaled by 100."""
    scof = spearmanr(labels, predicts)[0]
    return 100.0 * scof


def compute_auc(predicts, labels):
    """Area under the ROC curve, scaled to 0-100."""
    auc = roc_auc_score(labels, predicts)
    return 100.0 * auc


def compute_seqacc(predicts, labels, label_mapper):
    """Build a seqeval classification report for sequence-labeling output.

    For every sequence the first position is skipped, positions whose gold
    label maps to 'X' are dropped, and the last remaining position is removed.

    :param predicts: iterable of predicted label-id sequences.
    :param labels: iterable of gold label-id sequences, aligned with `predicts`.
    :param label_mapper: indexable mapping from label id to label string.
    :return: the report produced by `classification_report(..., digits=4)`.
    """
    y_true, y_pred = [], []
    for predict, label in zip(predicts, labels):
        gold_tags, pred_tags = [], []
        for j, m in enumerate(predict):
            if j == 0:
                continue
            if label_mapper[label[j]] != 'X':
                gold_tags.append(label_mapper[label[j]])
                pred_tags.append(label_mapper[m])
        # Drop the trailing position; slice deletion is a no-op on an empty
        # list, where `.pop()` would raise IndexError.
        del gold_tags[-1:]
        del pred_tags[-1:]
        y_true.append(gold_tags)
        y_pred.append(pred_tags)
    return classification_report(y_true, y_pred, digits=4)


class Metric(Enum):
    ACC = 0
    F1 = 1
    MCC = 2
    Pearson = 3
    Spearman = 4
    AUC = 5
    SeqEval = 7


METRIC_FUNC = {
    Metric.ACC: compute_acc,
    Metric.F1: compute_f1,
    Metric.MCC: compute_mcc,
    Metric.Pearson: compute_pearson,
    Metric.Spearman: compute_spearman,
    Metric.AUC: compute_auc,
    Metric.SeqEval: compute_seqacc
}


def calc_metrics(metric_meta, golds, predictions, scores, label_mapper=None):
    """Compute every requested metric and return them keyed by metric name.

    Label Mapper is used for NER/POS etc.

    :param metric_meta: iterable of `Metric` members to compute.
    :param golds: gold labels.
    :param predictions: predicted labels (used by ACC, F1, MCC, SeqEval).
    :param scores: raw scores (used by Pearson, Spearman, AUC). For AUC this
        is a flat list with two scores per sample; the second of each pair is used.
    :param label_mapper: id-to-label mapping, only needed for SeqEval.
    :return: dict mapping `Metric.name` to the metric value.
    :raises AssertionError: if AUC is requested and `scores` is not exactly
        twice as long as `golds`.
    """
    metrics = {}
    for mm in metric_meta:
        metric_func = METRIC_FUNC[mm]
        if mm in (Metric.ACC, Metric.F1, Metric.MCC):
            metric = metric_func(predictions, golds)
        elif mm == Metric.SeqEval:
            metric = metric_func(predictions, golds, label_mapper)
        elif mm == Metric.AUC:
            assert len(scores) == 2 * len(golds), "AUC is only valid for binary classification problem"
            # Slice into a temporary: rebinding `scores` here would hand the
            # halved list to every metric evaluated after AUC.
            metric = metric_func(scores[1::2], golds)
        else:
            metric = metric_func(scores, golds)
        metrics[mm.name] = metric
    return metrics
class AverageMeter(object):
    """Tracks the latest value, running sum, count and mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def set_environment(seed, set_cuda=False):
    """Seed the Python, NumPy and PyTorch RNGs (and all CUDA devices on request)."""
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available() and set_cuda:
        torch.cuda.manual_seed_all(seed)


def patch_var(v, cuda=True):
    """Return `v`, moved to the GPU with a non-blocking copy when `cuda` is true."""
    return v.cuda(non_blocking=True) if cuda else v


def get_gpu_memory_map():
    """Return `{gpu_index: used_memory}` as reported by `nvidia-smi`."""
    query = [
        'nvidia-smi', '--query-gpu=memory.used',
        '--format=csv,nounits,noheader'
    ]
    output = subprocess.check_output(query, encoding='utf-8')
    used = [int(entry) for entry in output.strip().split('\n')]
    return dict(enumerate(used))


def get_pip_env():
    """Run `pip freeze` and return its exit status."""
    return subprocess.call(["pip", "freeze"])
PAD = 'PADPAD'
UNK = 'UNKUNK'
STA = 'BOSBOS'
END = 'EOSEOS'

PAD_ID = 0
UNK_ID = 1
STA_ID = 2
END_ID = 3


class Vocabulary(object):
    """Bidirectional token <-> index mapping.

    A non-"neat" vocabulary is pre-populated with the PAD/UNK/STA/END special
    tokens at indices 0-3; a "neat" vocabulary starts empty.
    """
    INIT_LEN = 4

    def __init__(self, neat=False):
        self.neat = neat
        if not neat:
            self.tok2ind = {PAD: PAD_ID, UNK: UNK_ID, STA: STA_ID, END: END_ID}
            self.ind2tok = {PAD_ID: PAD, UNK_ID: UNK, STA_ID: STA, END_ID: END}
        else:
            self.tok2ind = {}
            self.ind2tok = {}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        """Membership test for an index (int) or a token (str)."""
        if isinstance(key, int):
            return key in self.ind2tok
        if isinstance(key, str):
            return key in self.tok2ind
        return False

    def __getitem__(self, key):
        """Look up a token by index or an index by token.

        Missing entries never raise: a neat vocabulary yields -1 (int key) or
        None (str key); otherwise the UNK token / UNK index is returned.
        Keys of any other type yield None.
        """
        if isinstance(key, int):
            return self.ind2tok.get(key, -1 if self.neat else UNK)
        if isinstance(key, str):
            if self.neat:
                return self.tok2ind.get(key, None)
            return self.tok2ind.get(key, self.tok2ind.get(UNK))
        return None

    def __setitem__(self, key, item):
        """Set one direction of the mapping: `vocab[int] = str` or `vocab[str] = int`."""
        if isinstance(key, int) and isinstance(item, str):
            self.ind2tok[key] = item
        elif isinstance(key, str) and isinstance(item, int):
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        """Append `token` with the next free index unless it is already known."""
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def get_vocab_list(self, with_order=True):
        """Return the tokens, either in index order or without the special tokens."""
        if with_order:
            words = [self[k] for k in range(0, len(self))]
        else:
            words = [k for k in self.tok2ind.keys()
                     if k not in {PAD, UNK, STA, END}]
        return words

    def toidx(self, tokens):
        """Map an iterable of tokens to their indices."""
        return [self[tok] for tok in tokens]

    def copy(self):
        """Deep copy

        Both mappings are copied verbatim, so indices assigned through
        `__setitem__` survive; re-adding tokens one by one would renumber them.
        """
        new_vocab = Vocabulary(self.neat)
        new_vocab.tok2ind = dict(self.tok2ind)
        new_vocab.ind2tok = dict(self.ind2tok)
        return new_vocab

    @staticmethod
    def build(words, neat=False):
        """Create a vocabulary containing `words` in iteration order."""
        vocab = Vocabulary(neat)
        for w in words:
            vocab.add(w)
        return vocab
SPIECE_UNDERLINE = '▁'

# The angle-bracket token names had been stripped, leaving nine identical
# empty-string keys that collapsed into one entry, so every *_ID below
# resolved to 8. The names are restored in XLNet's order.
special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def print_(*args):
    """Print the arguments after `printable_text`; list arguments are space-joined."""
    new_args = []
    for arg in args:
        if isinstance(arg, list):
            s = [printable_text(i) for i in arg]
            s = ' '.join(s)
            new_args.append(s)
        else:
            new_args.append(printable_text(arg))
    print(*new_args)


def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
    """Normalize raw text before sentencepiece encoding.

    :param inputs: text to normalize.
    :param lower: lowercase the result.
    :param remove_space: strip the text and collapse whitespace runs to one space.
    :param keep_accents: when False, NFKD-normalize and drop combining marks.
    :return: the normalized text, with `` and '' replaced by a double quote.
    """
    if remove_space:
        outputs = ' '.join(inputs.strip().split())
    else:
        outputs = inputs
    outputs = outputs.replace("``", '"').replace("''", '"')

    if six.PY2 and isinstance(outputs, str):
        outputs = outputs.decode('utf-8')

    if not keep_accents:
        outputs = unicodedata.normalize('NFKD', outputs)
        outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
    if lower:
        outputs = outputs.lower()

    return outputs


def encode_pieces(sp_model, text, return_unicode=True, sample=False):
    """Encode `text` into sentencepiece pieces, splitting "<digit>," endings.

    :param sp_model: sentencepiece model exposing `EncodeAsPieces` and
        `SampleEncodeAsPieces`.
    :param text: text to encode.
    :param return_unicode: used only for py2.
    :param sample: use sampling-based encoding instead of the deterministic one.
    :return: list of pieces.
    """
    # note(zhiliny): in some systems, sentencepiece only accepts str for py2
    if six.PY2 and isinstance(text, unicode):
        text = text.encode('utf-8')

    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    new_pieces = []
    for piece in pieces:
        if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
            # Re-encode the part before the comma and emit the comma separately.
            cur_pieces = sp_model.EncodeAsPieces(
                piece[:-1].replace(SPIECE_UNDERLINE, ''))
            if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                # Re-encoding added a word-start marker the original piece lacked.
                if len(cur_pieces[0]) == 1:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][1:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)

    # note(zhiliny): convert back to unicode for py2
    if six.PY2 and return_unicode:
        ret_pieces = []
        for piece in new_pieces:
            if isinstance(piece, str):
                piece = piece.decode('utf-8')
            ret_pieces.append(piece)
        new_pieces = ret_pieces

    return new_pieces


def encode_ids(sp_model, text, sample=False):
    """Encode `text` and return the sentencepiece ids of the resulting pieces."""
    pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
    ids = [sp_model.PieceToId(piece) for piece in pieces]
    return ids
def _ensure_tab_free(row, values):
    """Raise ValueError if any value would inject a tab into the TSV output."""
    for value in values:
        if "\t" in str(value):
            raise ValueError(
                "field contains a tab character and cannot be written as TSV: %r (row: %r)" % (value, row))


def dump_rows(rows, out_path, data_format):
    """Write `rows` to `out_path` as UTF-8 tab-separated text, one row per line.

    Column layout per format:
      * PremiseOnly / Sequence:       uid, label, premise
      * PremiseAndOneHypothesis:      uid, label, premise, hypothesis
      * PremiseAndMultiHypothesis:    uid, ruid, label, premise, hypothesis...

    :param rows: iterable of dicts holding the keys listed above.
    :param out_path: destination file path (overwritten).
    :param data_format: `DataFormat` member selecting the column layout.
    :raises ValueError: if a checked field contains a tab character, or if
        `data_format` is not one of the supported formats.
    """
    with open(out_path, "w", encoding="utf-8") as out_f:
        for row in rows:
            if data_format == DataFormat.PremiseOnly:
                _ensure_tab_free(row, [row[col] for col in ["uid", "label", "premise"]])
                out_f.write("%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"]))
            elif data_format == DataFormat.PremiseAndOneHypothesis:
                _ensure_tab_free(row, [row[col] for col in ["uid", "label", "premise", "hypothesis"]])
                out_f.write("%s\t%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"], row["hypothesis"]))
            elif data_format == DataFormat.PremiseAndMultiHypothesis:
                _ensure_tab_free(row, [row[col] for col in ["uid", "label", "premise"]])
                hypothesis = row["hypothesis"]
                _ensure_tab_free(row, hypothesis)
                # NOTE(review): "ruid" is written with plain %s formatting and is
                # not tab-checked, as before - confirm the expected on-disk form.
                hypothesis = "\t".join(hypothesis)
                out_f.write("%s\t%s\t%s\t%s\t%s\n" % (row["uid"], row["ruid"], row["label"], row["premise"], hypothesis))
            elif data_format == DataFormat.Sequence:
                _ensure_tab_free(row, [row[col] for col in ["uid", "label", "premise"]])
                out_f.write("%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"]))
            else:
                raise ValueError(data_format)
def normalize_qa_field(s: str, replacement_list):
    """Blank out every occurrence of the given substrings, preserving length.

    Each match is replaced by the same number of spaces so that character
    offsets (answer_start / answer_end) into the string remain valid.
    """
    for replacement in replacement_list:
        s = s.replace(replacement, " " * len(replacement))  # ensure answer_start and answer_end still valid
    return s


END = 'EOSEOS'


def load_data(path, is_train=True, v2_on=False):
    """Read a SQuAD JSON file into canonical premise/hypothesis rows.

    :param path: path to the SQuAD JSON file.
    :param is_train: accepted for interface compatibility; not used.
    :param v2_on: treat the file as SQuAD v2: append the END marker to every
        context and keep unanswerable questions, pointing them at that marker.
    :return: list of dicts with keys 'uid', 'premise', 'hypothesis' and
        'label', where label is "answer_start:::answer_end:::is_impossible:::answer".
    """
    rows = []
    with open(path, encoding="utf8") as f:
        data = json.load(f)['data']
    for article in data:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            if v2_on:
                context = '{} {}'.format(context, END)
            # Length-preserving, so answer offsets computed below stay valid.
            context = normalize_qa_field(context, ["\n", "\t"])
            for qa in paragraph['qas']:
                uid, question = qa['id'], qa['question']
                answers = qa.get('answers', [])
                # used for v2.0
                is_impossible = qa.get('is_impossible', False)
                label = 1 if is_impossible else 0
                if (v2_on and label < 1 and len(answers) < 1) or ((not v2_on) and len(answers) < 1):
                    # detect inconsistent data
                    # * for v2, the row is possible but has no answer
                    # * for v1, all questions should have answer
                    continue
                if len(answers) > 0:
                    answer = answers[0]['text']
                    answer_start = answers[0]['answer_start']
                    answer_end = answer_start + len(answer)
                else:
                    # for questions without answers, give a fake answer
                    answer = END
                    answer_start = len(context) - len(END)
                    answer_end = len(context)
                answer = normalize_qa_field(answer, ["\n", "\t", ":::"])
                question = normalize_qa_field(question, ["\n", "\t"])
                sample = {'uid': uid, 'premise': context, 'hypothesis': question,
                          'label': "%s:::%s:::%s:::%s" % (answer_start, answer_end, label, answer)}
                rows.append(sample)
    return rows


def parse_args():
    """Parse the command line (`--root_dir`, default 'data')."""
    parser = argparse.ArgumentParser(description='Preprocessing SQUAD data.')
    parser.add_argument('--root_dir', type=str, default='data')
    args = parser.parse_args()
    return args


def main(args):
    """Convert the SQuAD v1/v2 train and dev files under `args.root_dir` to TSV."""
    # Local import: dump_rows requires the data format, which this script
    # previously never passed (TypeError on the first dump_rows call).
    from data_utils.task_def import DataFormat

    root = args.root_dir
    assert os.path.exists(root)

    squad_train_path = os.path.join(root, 'squad/train.json')
    squad_dev_path = os.path.join(root, 'squad/dev.json')
    squad_v2_train_path = os.path.join(root, 'squad_v2/train.json')
    squad_v2_dev_path = os.path.join(root, 'squad_v2/dev.json')

    squad_train_data = load_data(squad_train_path)
    squad_dev_data = load_data(squad_dev_path, is_train=False)
    logger.info('Loaded {} squad train samples'.format(len(squad_train_data)))
    logger.info('Loaded {} squad dev samples'.format(len(squad_dev_data)))

    squad_v2_train_data = load_data(squad_v2_train_path, v2_on=True)
    squad_v2_dev_data = load_data(squad_v2_dev_path, is_train=False, v2_on=True)
    logger.info('Loaded {} squad_v2 train samples'.format(len(squad_v2_train_data)))
    logger.info('Loaded {} squad_v2 dev samples'.format(len(squad_v2_dev_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    # Rows carry uid / label / premise / hypothesis, i.e. one hypothesis each.
    data_format = DataFormat.PremiseAndOneHypothesis

    squad_train_fout = os.path.join(canonical_data_root, 'squad_train.tsv')
    squad_dev_fout = os.path.join(canonical_data_root, 'squad_dev.tsv')
    dump_rows(squad_train_data, squad_train_fout, data_format)
    dump_rows(squad_dev_data, squad_dev_fout, data_format)
    logger.info('done with squad')

    squad_v2_train_fout = os.path.join(canonical_data_root, 'squad-v2_train.tsv')
    squad_v2_dev_fout = os.path.join(canonical_data_root, 'squad-v2_dev.tsv')
    dump_rows(squad_v2_train_data, squad_v2_train_fout, data_format)
    dump_rows(squad_v2_dev_data, squad_v2_dev_fout, data_format)
    logger.info('done with squad_v2')


if __name__ == '__main__':
    args = parse_args()
    main(args)
def calc_tokenized_span_range(context, question, answer, answer_start, answer_end, tokenizer, encoderModelType,
                              verbose=False):
    """Convert a character-level answer span into a token-level span.

    The context is tokenized up to `answer_start` and up to `answer_end`; the
    two token counts are the span boundaries.

    :param context: passage text.
    :param question: question text (only used in the diagnostic message).
    :param answer: answer text (only used for the diagnostic check).
    :param answer_start: character offset where the answer starts in `context`.
    :param answer_end: character offset one past the end of the answer.
    :param tokenizer: tokenizer exposing `tokenize` and `basic_tokenizer.tokenize`.
    :param encoderModelType: must be `EncoderModelType.BERT`.
    :param verbose: print a diagnostic when the span's tokens do not rebuild
        the basic-tokenized answer.
    :return: span_start, span_end
    """
    assert encoderModelType == EncoderModelType.BERT
    prefix_tokens = tokenizer.tokenize(context[:answer_start])
    full_tokens = tokenizer.tokenize(context[:answer_end])
    span_start = len(prefix_tokens)
    span_end = len(full_tokens)
    if verbose:
        span_tokens = full_tokens[span_start: span_end]
        recovered_answer = " ".join(span_tokens).replace(" ##", "")
        cleaned_answer = " ".join(tokenizer.basic_tokenizer.tokenize(answer))
        # Explicit comparison instead of assert + except: the diagnostic still
        # appears under `python -O`, and no exception is used as control flow.
        if recovered_answer != cleaned_answer:
            # Clamp the window start: a negative start would slice from the
            # end of the context and show unrelated text.
            window_start = max(0, answer_start - 5)
            print("answer: %s, recovered_answer: %s, question: %s, select:%s ext_select:%s context: %s" % (
                cleaned_answer, recovered_answer, question, context[answer_start:answer_end],
                context[window_start:answer_end + 5], context))
    return span_start, span_end


def parse_squad_label(label):
    """Split a canonical SQuAD label into its components.

    :param label: "answer_start:::answer_end:::is_impossible:::answer".
    :return: answer_start, answer_end, answer, is_impossible
    """
    # Split at most three times so an answer containing ":::" stays intact.
    answer_start, answer_end, is_impossible, answer = label.split(":::", 3)
    answer_start = int(answer_start)
    answer_end = int(answer_end)
    is_impossible = int(is_impossible)
    return answer_start, answer_end, answer, is_impossible
def linear(x):
    """Identity activation: return ``x`` unchanged."""
    return x


def swish(x):
    """Swish activation, ``x * sigmoid(x)``."""
    return x * sigmoid(x)


def bertgelu(x):
    """GELU computed with the error function: ``x * 0.5 * (1 + erf(x / sqrt(2)))``."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gptgelu(x):
    """GELU computed with the tanh approximation."""
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


# default gelu
gelu = bertgelu


def activation(func_a):
    """Activation function wrapper.

    Resolve ``func_a`` (an expression string such as ``'relu'``) to a callable
    visible in this module; fall back to :func:`linear` when it cannot be
    resolved.
    """
    # NOTE(review): ``eval`` executes arbitrary expressions. Only pass trusted
    # configuration values here; consider an explicit name -> function table.
    try:
        f = eval(func_a)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not swallowed by the fallback.
        f = linear
    return f


def init_wrapper(init='xavier_uniform'):
    """Resolve the initializer named by ``init`` to the object of that name.

    Unlike :func:`activation` there is no fallback: an unknown name raises.
    """
    # NOTE(review): same ``eval`` caveat as in ``activation``.
    return eval(init)
class DropoutWrapper(nn.Module):
    """Dropout module that uses one fixed mask per sample for 3-D inputs.

    For a 3-D input a single mask of shape ``(size(0), size(2))`` is sampled
    and broadcast over dimension 1 ("variational" / fixed-mask dropout,
    ref: https://discuss.pytorch.org/t/dropout-for-rnns/633/11).
    Any other input goes through ``F.dropout``.
    """

    def __init__(self, dropout_p=0, enable_vbp=True):
        super(DropoutWrapper, self).__init__()
        # NOTE(review): this flag is stored but ``forward`` never consults it.
        self.enable_variational_dropout = enable_vbp
        self.dropout_p = dropout_p

    def forward(self, x):
        """
        :param x: batch * len * input_size
        :return: ``x`` itself in eval mode or when ``dropout_p == 0``,
            otherwise the masked and rescaled tensor
        """
        if not self.training or self.dropout_p == 0:
            return x

        if len(x.size()) != 3:
            return F.dropout(x, p=self.dropout_p, training=self.training)

        keep_prob = 1 - self.dropout_p
        # Tensor of ones with x's dtype/device, one row per sample.
        ones = x.data.new(x.size(0), x.size(2)).zero_() + 1
        # Kept entries are scaled by 1 / keep_prob.
        mask = 1.0 / keep_prob * torch.bernoulli(keep_prob * ones)
        mask.requires_grad = False
        return mask.unsqueeze(1).expand_as(x) * x
class EMA:
    """Exponential moving average of a model's trainable parameters."""

    def __init__(self, gamma, model):
        """
        :param gamma: weight given to the previous shadow value in ``update``
        :param model: object exposing ``named_parameters()``
        """
        super(EMA, self).__init__()
        self.gamma = gamma
        self.shadow = {}
        self.model = model
        self.setup()

    def setup(self):
        """Initialise the shadow copy from the current trainable parameters."""
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                # detach() so the shadow carries no autograd history.
                self.shadow[name] = para.detach().clone()

    def cuda(self):
        """Move every shadow tensor to the GPU."""
        for k, v in self.shadow.items():
            self.shadow[k] = v.cuda()

    def update(self):
        """Blend the current parameters into the shadow:
        ``shadow = (1 - gamma) * param + gamma * shadow``.
        """
        # no_grad: otherwise each update would chain a new autograd node onto
        # the previous shadow and the graph would grow with every step.
        with torch.no_grad():
            for name, para in self.model.named_parameters():
                if para.requires_grad:
                    self.shadow[name] = (1.0 - self.gamma) * para + self.gamma * self.shadow[name]

    def swap_parameters(self):
        """Exchange the model's parameter data with the shadow data."""
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                temp_data = para.data
                para.data = self.shadow[name].data
                self.shadow[name].data = temp_data

    def state_dict(self):
        """Return the name -> shadow tensor dict (the live dict, not a copy)."""
        return self.shadow


# Adapted from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
# and https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py
def _norm(p, dim):
    """Computes the norm over all dimensions except dim"""
    if dim is None:
        return p.norm()
    elif dim == 0:
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
    elif dim == p.dim() - 1:
        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
        return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
    else:
        # Move ``dim`` to the front, reuse the dim == 0 case, move it back.
        return _norm(p.transpose(0, dim), 0).transpose(0, dim)


def _dummy(*args, **kwargs):
    # We need to replace flatten_parameters with a nothing function
    return


class WeightNorm(torch.nn.Module):
    """Forward pre-hook that rebuilds weights as ``v * (g / ||v||)``.

    ``apply`` replaces each selected parameter ``name`` of a module by the
    two parameters ``name_g`` and ``name_v``; the hook recomputes ``name``
    from them before every forward call.
    """

    def __init__(self, weights, dim):
        """
        :param weights: names of the parameters to re-parameterise
        :param dim: dimension that keeps its own norm (see ``_norm``)
        """
        super(WeightNorm, self).__init__()
        self.weights = weights
        self.dim = dim
        # Handle of the registered forward pre-hook; set by ``apply``.
        self._hook_handle = None

    def compute_weight(self, module, name):
        """Return the weight rebuilt from ``name_g`` and ``name_v``."""
        g = getattr(module, name + '_g')
        v = getattr(module, name + '_v')
        return v * (g / _norm(v, self.dim))

    @staticmethod
    def apply(module, weights, dim):
        """Re-parameterise ``module`` and register the hook.

        :param module: module whose parameters are replaced in place
        :param weights: parameter names, or None for every parameter whose
            name contains ``'weight'``
        :param dim: see ``__init__``
        :return: the ``WeightNorm`` instance registered as pre-hook
        """
        # Terrible temporary solution to an issue regarding compacting weights
        # re: CUDNN RNN
        if issubclass(type(module), torch.nn.RNNBase):
            module.flatten_parameters = _dummy
        if weights is None:  # do for all weight params
            weights = [w for w in module._parameters.keys() if 'weight' in w]
        fn = WeightNorm(weights, dim)
        for name in weights:
            if hasattr(module, name):
                print('Applying weight norm to {} - {}'.format(str(module), name))
                weight = getattr(module, name)
                del module._parameters[name]
                module.register_parameter(
                    name + '_g', Parameter(_norm(weight, dim).data))
                module.register_parameter(name + '_v', Parameter(weight.data))
                # From here on ``name`` is a plain (non-parameter) attribute.
                setattr(module, name, fn.compute_weight(module, name))

        # Keep the handle so ``remove`` can unregister the hook again.
        fn._hook_handle = module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        """Undo ``apply``: restore plain parameters and unregister the hook."""
        for name in self.weights:
            # Fixed: ``name`` was missing from this call (TypeError).
            weight = self.compute_weight(module, name)
            delattr(module, name)
            del module._parameters[name + '_g']
            del module._parameters[name + '_v']
            module.register_parameter(name, Parameter(weight.data))
        # Without this the hook would try to overwrite the restored
        # parameters with plain tensors on the next forward call.
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None

    def __call__(self, module, inputs):
        """Forward pre-hook: refresh every managed weight."""
        for name in self.weights:
            setattr(module, name, self.compute_weight(module, name))


def weight_norm(module, weights=None, dim=0):
    """Apply :class:`WeightNorm` to ``module`` and return the module."""
    WeightNorm.apply(module, weights, dim)
    return module
SMALL_POS_NUM = 1.0e-30


def generate_mask(new_data, dropout_p=0.0, is_training=False):
    """Sample a rescaled keep-mask with at least one kept entry per row.

    :param new_data: 2-D tensor used as scratch space; it is zeroed in place
    :param dropout_p: drop probability; forced to 0 when not training
    :param is_training: whether dropout is active
    :return: tensor shaped like ``new_data`` whose entries are 0 or
        ``1 / (1 - dropout_p)``
    """
    if not is_training:
        dropout_p = 0.0
    new_data = (1 - dropout_p) * (new_data.zero_() + 1)
    for i in range(new_data.size(0)):
        # Keep probability 1 for one random column, so no row is all zeros.
        one = random.randint(0, new_data.size(1) - 1)
        new_data[i][one] = 1
    mask = 1.0 / (1 - dropout_p) * torch.bernoulli(new_data)
    mask.requires_grad = False
    return mask


class Classifier(nn.Module):
    """Linear scorer over the concatenation of two vectors."""

    def __init__(self, x_size, y_size, opt, prefix='decoder', dropout=None):
        """
        :param x_size: feature size of each of the two inputs
        :param y_size: number of output scores
        :param opt: dict read for ``<prefix>_dropout_p``, ``<prefix>_merge_opt``
            and ``<prefix>_weight_norm_on``
        :param prefix: key prefix used when reading ``opt``
        :param dropout: dropout module to use; built from ``opt`` when None
        """
        super(Classifier, self).__init__()
        self.opt = opt
        if dropout is None:
            self.dropout = DropoutWrapper(opt.get('{}_dropout_p'.format(prefix), 0))
        else:
            self.dropout = dropout
        self.merge_opt = opt.get('{}_merge_opt'.format(prefix), 0)
        self.weight_norm_on = opt.get('{}_weight_norm_on'.format(prefix), False)

        # merge_opt == 1 concatenates four x_size-wide pieces, otherwise two.
        if self.merge_opt == 1:
            self.proj = nn.Linear(x_size * 4, y_size)
        else:
            self.proj = nn.Linear(x_size * 2, y_size)

        if self.weight_norm_on:
            self.proj = weight_norm(self.proj)

    def forward(self, x1, x2, mask=None):
        """Score the pair ``(x1, x2)``; ``mask`` is accepted but unused."""
        if self.merge_opt == 1:
            x = torch.cat([x1, x2, (x1 - x2).abs(), x1 * x2], 1)
        else:
            x = torch.cat([x1, x2], 1)
        x = self.dropout(x)
        scores = self.proj(x)
        return scores


class SANClassifier(nn.Module):
    """Implementation of Stochastic Answer Networks for Natural Language Inference, Xiaodong Liu, Kevin Duh and Jianfeng Gao
    https://arxiv.org/abs/1804.07888
    """

    def __init__(self, x_size, h_size, label_size, opt=None, prefix='decoder', dropout=None):
        """
        :param x_size: feature size of the memory ``x``
        :param h_size: hidden size of the recurrent cell
        :param label_size: number of output scores
        :param opt: option dict (``<prefix>_rnn_type``, ``<prefix>_num_turn``,
            ``<prefix>_mem_drop_p``, ``<prefix>_mem_type``,
            ``<prefix>_weight_norm_on``, ``dump_state_on``, ...); None means empty
        :param prefix: key prefix used when reading ``opt``
        :param dropout: shared dropout module; built from ``opt`` when None
        """
        super(SANClassifier, self).__init__()
        # A fresh dict per instance instead of a shared mutable default.
        if opt is None:
            opt = {}
        self.prefix = prefix
        if dropout is None:
            # Fixed: this used ``self.prefix`` before it had been assigned.
            self.dropout = DropoutWrapper(opt.get('{}_dropout_p'.format(prefix), 0))
        else:
            self.dropout = dropout
        self.query_wsum = SelfAttnWrapper(x_size, prefix='mem_cum', opt=opt, dropout=self.dropout)
        self.attn = FlatSimilarityWrapper(x_size, h_size, prefix, opt, self.dropout)
        # e.g. 'gru' -> nn.GRUCell
        self.rnn_type = '{}{}'.format(opt.get('{}_rnn_type'.format(prefix), 'gru').upper(), 'Cell')
        self.rnn = getattr(nn, self.rnn_type)(x_size, h_size)
        self.num_turn = opt.get('{}_num_turn'.format(prefix), 5)
        self.opt = opt
        self.mem_random_drop = opt.get('{}_mem_drop_p'.format(prefix), 0)
        self.mem_type = opt.get('{}_mem_type'.format(prefix), 0)
        self.weight_norm_on = opt.get('{}_weight_norm_on'.format(prefix), False)
        self.label_size = label_size
        self.dump_state = opt.get('dump_state_on', False)
        # Non-trainable tensor; ``forward`` uses it to allocate the turn mask.
        self.alpha = Parameter(torch.zeros(1, 1), requires_grad=False)
        if self.weight_norm_on:
            self.rnn = WN(self.rnn)

        self.classifier = Classifier(x_size, self.label_size, opt, prefix=prefix, dropout=self.dropout)

    def forward(self, x, h0, x_mask=None, h_mask=None):
        """Run ``num_turn`` reasoning steps and return the scores.

        :param x: memory attended over at every turn
        :param h0: input that ``query_wsum`` reduces to the initial state
        :param x_mask: mask passed to the attention module
        :param h_mask: mask passed to ``query_wsum``
        :return: ``scores``, or ``(scores, scores_list)`` when ``dump_state``
            is set; with ``mem_type == 1`` the scores are the log of the
            masked mean of the per-turn softmax outputs, otherwise the
            last turn's raw scores
        """
        h0 = self.query_wsum(h0, h_mask)
        if type(self.rnn) is nn.LSTMCell:
            c0 = h0.new(h0.size()).zero_()
        scores_list = []
        for _ in range(self.num_turn):
            att_scores = self.attn(x, h0, x_mask)
            x_sum = torch.bmm(F.softmax(att_scores, 1).unsqueeze(1), x).squeeze(1)
            scores = self.classifier(x_sum, h0)
            scores_list.append(scores)
            # next turn
            if self.rnn is not None:
                h0 = self.dropout(h0)
                if type(self.rnn) is nn.LSTMCell:
                    h0, c0 = self.rnn(x_sum, (h0, c0))
                else:
                    h0 = self.rnn(x_sum, h0)
        if self.mem_type == 1:
            # One mask column per turn; randomly drops whole turns in training.
            mask = generate_mask(self.alpha.data.new(x.size(0), self.num_turn), self.mem_random_drop, self.training)
            mask = [m.contiguous() for m in torch.unbind(mask, 1)]
            tmp_scores_list = [mask[idx].view(x.size(0), 1).expand_as(inp) * F.softmax(inp, 1)
                               for idx, inp in enumerate(scores_list)]
            scores = torch.stack(tmp_scores_list, 2)
            scores = torch.mean(scores, 2)
            scores = torch.log(scores)
        else:
            scores = scores_list[-1]
        if self.dump_state:
            return scores, scores_list
        else:
            return scores
UNK_ID = 100
BOS_ID = 101


class BatchGen:
    """Iterate over batches of samples and turn each one into padded tensors.

    Iterating yields ``(batch_info, batch_data)`` pairs: ``batch_data`` is a
    list of tensors and ``batch_info`` maps names to positions in that list
    plus per-batch metadata (task id, task type, uids, labels, ...).
    """

    def __init__(self, data, batch_size=32, gpu=True, is_train=True,
                 maxlen=128, dropout_w=0.005,
                 do_batch=True, weighted_on=False,
                 task_id=0,
                 task=None,
                 task_type=TaskType.Classification,
                 data_type=DataFormat.PremiseOnly,
                 soft_label=False,
                 encoder_type=EncoderModelType.BERT):
        """
        :param data: list of sample dicts (already batched if ``do_batch`` is false)
        :param batch_size: number of samples per batch
        :param gpu: pin tensors and move them to CUDA while iterating
        :param is_train: enables shuffling, token dropout and tensor labels
        :param maxlen: stored on the instance; not used by this class's methods
            other than ``load`` (which takes its own argument)
        :param dropout_w: probability of replacing a token id by ``UNK_ID``
        :param do_batch: split ``data`` into batches here
        :param weighted_on: stored on the instance only
        :param task_id: copied into every ``batch_info``
        :param task: accepted but unused
        :param task_type: ``TaskType`` controlling label handling
        :param data_type: ``DataFormat``; pair formats add premise/hypothesis masks
        :param soft_label: attach soft labels (knowledge distillation) if present
        :param encoder_type: ``EncoderModelType``; selects the token padding id
        """
        self.batch_size = batch_size
        self.maxlen = maxlen
        self.is_train = is_train
        self.gpu = gpu
        self.weighted_on = weighted_on
        self.data = data
        self.task_id = task_id
        self.pairwise_size = 1
        self.data_type = data_type
        self.task_type = task_type
        self.encoder_type = encoder_type
        # soft label used for knowledge distillation
        self.soft_label_on = soft_label
        if do_batch:
            if is_train:
                indices = list(range(len(self.data)))
                random.shuffle(indices)
                data = [self.data[i] for i in indices]
            self.data = BatchGen.make_baches(data, batch_size)
        self.offset = 0
        self.dropout_w = dropout_w

    @staticmethod
    def make_baches(data, batch_size=32):
        """Split ``data`` into consecutive chunks of at most ``batch_size``."""
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    @staticmethod
    def load(path, is_train=True, maxlen=128, factor=1.0, task_type=None):
        """Read one JSON sample per line from ``path``.

        In training mode samples whose ``token_id`` is longer than ``maxlen``
        are skipped (for ranking tasks the first two candidates are checked).
        Every sample gets ``sample['factor'] = factor``.

        :return: list of sample dicts
        """
        assert task_type is not None
        with open(path, 'r', encoding='utf-8') as reader:
            data = []
            cnt = 0
            for line in reader:
                sample = json.loads(line)
                sample['factor'] = factor
                cnt += 1
                if is_train:
                    if (task_type == TaskType.Ranking) and (len(sample['token_id'][0]) > maxlen or len(sample['token_id'][1]) > maxlen):
                        continue
                    if (task_type != TaskType.Ranking) and (len(sample['token_id']) > maxlen):
                        continue
                data.append(sample)
            print('Loaded {} samples out of {}'.format(len(data), cnt))
            return data

    def reset(self):
        """Rewind the iterator; in training mode also reshuffle the batch order."""
        if self.is_train:
            indices = list(range(len(self.data)))
            random.shuffle(indices)
            self.data = [self.data[i] for i in indices]
        self.offset = 0

    def __random_select__(self, arr):
        """Replace each element by ``UNK_ID`` with probability ``dropout_w``."""
        if self.dropout_w > 0:
            return [UNK_ID if random.uniform(0, 1) < self.dropout_w else e for e in arr]
        else:
            return arr

    def __len__(self):
        return len(self.data)

    def patch(self, v):
        """Move ``v`` to CUDA with a non-blocking copy."""
        v = v.cuda(non_blocking=True)
        return v

    @staticmethod
    def todevice(v, device):
        """Move ``v`` to ``device``."""
        v = v.to(device)
        return v

    def rebacth(self, batch):
        """Flatten ranking samples into one entry per candidate.

        Also records the candidate count of the last sample in
        ``self.pairwise_size``.
        """
        newbatch = []
        for sample in batch:
            size = len(sample['token_id'])
            self.pairwise_size = size
            assert size == len(sample['type_id'])
            for idx in range(0, size):
                token_id = sample['token_id'][idx]
                type_id = sample['type_id'][idx]
                uid = sample['ruid'][idx]
                olab = sample['olabel'][idx]
                newbatch.append({'uid': uid, 'token_id': token_id, 'type_id': type_id, 'label': sample['label'], 'true_label': olab})
        return newbatch

    def __if_pair__(self, data_type):
        """True for the premise/hypothesis data formats."""
        return data_type in [DataFormat.PremiseAndOneHypothesis, DataFormat.PremiseAndMultiHypothesis]

    def __iter__(self):
        """Yield ``(batch_info, batch_data)`` from the current offset onwards."""
        while self.offset < len(self):
            batch = self.data[self.offset]
            if self.task_type == TaskType.Ranking:
                batch = self.rebacth(batch)

            # prepare model input
            batch_data, batch_info = self._prepare_model_input(batch)
            batch_info['task_id'] = self.task_id  # used for select correct decoding head
            # used to select model inputs (recorded before labels are appended)
            batch_info['input_len'] = len(batch_data)
            # select different loss function and other difference in training and testing
            batch_info['task_type'] = self.task_type
            batch_info['pairwise_size'] = self.pairwise_size  # need for ranking task
            if self.gpu:
                for i, item in enumerate(batch_data):
                    batch_data[i] = self.patch(item.pin_memory())

            # add label
            labels = [sample['label'] for sample in batch]
            if self.is_train:
                # in training model, label is used by Pytorch, so would be tensor
                if self.task_type == TaskType.Regression:
                    batch_data.append(torch.FloatTensor(labels))
                    batch_info['label'] = len(batch_data) - 1
                elif self.task_type in (TaskType.Classification, TaskType.Ranking):
                    batch_data.append(torch.LongTensor(labels))
                    batch_info['label'] = len(batch_data) - 1
                elif self.task_type == TaskType.Span:
                    start = [sample['token_start'] for sample in batch]
                    end = [sample['token_end'] for sample in batch]
                    batch_data.extend([torch.LongTensor(start), torch.LongTensor(end)])
                    batch_info['start'] = len(batch_data) - 2
                    batch_info['end'] = len(batch_data) - 1
                elif self.task_type == TaskType.SequenceLabeling:
                    batch_size = self._get_batch_size(batch)
                    tok_len = self._get_max_len(batch, key='token_id')
                    # -1 marks padded label positions
                    tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)
                    for i, label in enumerate(labels):
                        ll = len(label)
                        tlab[i, : ll] = torch.LongTensor(label)
                    batch_data.append(tlab)
                    batch_info['label'] = len(batch_data) - 1

                # soft label generated by ensemble models for knowledge distillation
                if self.soft_label_on and (batch[0].get('softlabel', None) is not None):
                    assert self.task_type != TaskType.Span  # Span task doesn't support soft label yet.
                    sortlabels = [sample['softlabel'] for sample in batch]
                    sortlabels = torch.FloatTensor(sortlabels)
                    batch_info['soft_label'] = self.patch(sortlabels.pin_memory()) if self.gpu else sortlabels
            else:
                # in test model, label would be used for evaluation
                batch_info['label'] = labels
                if self.task_type == TaskType.Ranking:
                    batch_info['true_label'] = [sample['true_label'] for sample in batch]

            batch_info['uids'] = [sample['uid'] for sample in batch]  # used in scoring
            self.offset += 1
            yield batch_info, batch_data

    def _get_max_len(self, batch, key='token_id'):
        """Longest ``sample[key]`` in the batch."""
        tok_len = max(len(x[key]) for x in batch)
        return tok_len

    def _get_batch_size(self, batch):
        return len(batch)

    def _prepare_model_input(self, batch):
        """Pad a batch into tensors.

        :return: ``(batch_data, batch_info)`` where ``batch_data`` is
            ``[token_ids, type_ids, masks]`` (plus ``premise_masks`` and
            ``hypothesis_masks`` for pair formats) and ``batch_info`` maps
            ``'token_id'``, ``'segment_id'``, ``'mask'`` (and the two mask
            names) to their list positions
        """
        batch_size = self._get_batch_size(batch)
        tok_len = self._get_max_len(batch, key='token_id')
        is_pair = self.__if_pair__(self.data_type)
        # Token padding id is 1 for ROBERTA and 0 otherwise.
        pad_id = 1 if self.encoder_type == EncoderModelType.ROBERTA else 0
        token_ids = torch.LongTensor(batch_size, tok_len).fill_(pad_id)
        type_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
        masks = torch.LongTensor(batch_size, tok_len).fill_(0)
        if is_pair:
            # Count of zeros in type_id when it holds only 0/1 values.
            hypothesis_len = max(len(x['type_id']) - sum(x['type_id']) for x in batch)
            premise_masks = torch.ByteTensor(batch_size, tok_len).fill_(1)
            hypothesis_masks = torch.ByteTensor(batch_size, hypothesis_len).fill_(1)
        for i, sample in enumerate(batch):
            select_len = min(len(sample['token_id']), tok_len)
            tok = sample['token_id']
            if self.is_train:
                tok = self.__random_select__(tok)
            token_ids[i, :select_len] = torch.LongTensor(tok[:select_len])
            type_ids[i, :select_len] = torch.LongTensor(sample['type_id'][:select_len])
            masks[i, :select_len] = 1
            if is_pair:
                hlen = len(sample['type_id']) - sum(sample['type_id'])
                # Slice assignment instead of a per-element Python loop.
                hypothesis_masks[i, :hlen] = 0
                premise_masks[i, hlen:select_len] = 0
        if is_pair:
            batch_info = {
                'token_id': 0,
                'segment_id': 1,
                'mask': 2,
                'premise_mask': 3,
                'hypothesis_mask': 4
            }
            batch_data = [token_ids, type_ids, masks, premise_masks, hypothesis_masks]
        else:
            batch_info = {
                'token_id': 0,
                'segment_id': 1,
                'mask': 2
            }
            batch_data = [token_ids, type_ids, masks]
        return batch_data, batch_info
class LinearPooler(nn.Module):
    """Pool a sequence by passing its first position through Linear + Tanh."""

    def __init__(self, hidden_size):
        super(LinearPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """
        :param hidden_states: tensor indexed as ``[:, 0]`` to pick the first position
        :return: ``tanh(dense(hidden_states[:, 0]))``
        """
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class SANBertNetwork(nn.Module):
    """Shared encoder (BERT or RoBERTa) with one scoring head per task."""

    def __init__(self, opt, bert_config=None):
        """
        :param opt: option dict; keys read here include ``encoder_type``,
            ``init_checkpoint`` (RoBERTa only), ``dump_feature``,
            ``update_bert_opt``, ``answer_opt``, ``task_types``,
            ``label_size`` (comma-separated ints), ``tasks_dropout_p``,
            ``vb_dropout`` and ``init_ratio``
        :param bert_config: accepted but unused; the config is built from ``opt``
        """
        super(SANBertNetwork, self).__init__()
        self.dropout_list = nn.ModuleList()
        self.encoder_type = opt['encoder_type']
        if opt['encoder_type'] == EncoderModelType.ROBERTA:
            from fairseq.models.roberta import RobertaModel
            self.bert = RobertaModel.from_pretrained(opt['init_checkpoint'])
            hidden_size = self.bert.args.encoder_embed_dim
            self.pooler = LinearPooler(hidden_size)
        else:
            self.bert_config = BertConfig.from_dict(opt)
            self.bert = BertModel(self.bert_config)
            hidden_size = self.bert_config.hidden_size

        if opt.get('dump_feature', False):
            # Feature-dump mode: encoder only, no task heads, no re-init.
            self.opt = opt
            return
        if opt['update_bert_opt'] > 0:
            # A positive value freezes the encoder.
            for p in self.bert.parameters():
                p.requires_grad = False
        self.decoder_opt = opt['answer_opt']
        self.task_types = opt["task_types"]
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']

        for task, lab in enumerate(labels):
            decoder_opt = self.decoder_opt[task]
            task_type = self.task_types[task]
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            if task_type == TaskType.Span:
                assert decoder_opt != 1
                # two outputs per token: start and end scores
                out_proj = nn.Linear(hidden_size, 2)
            elif task_type == TaskType.SequenceLabeling:
                out_proj = nn.Linear(hidden_size, lab)
            else:
                if decoder_opt == 1:
                    out_proj = SANClassifier(hidden_size, hidden_size, lab, opt, prefix='answer', dropout=dropout)
                else:
                    out_proj = nn.Linear(hidden_size, lab)
            self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()

    def _my_init(self):
        """Re-initialise Linear/Embedding/BertLayerNorm weights of the whole network."""
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=0.02 * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT pytorch version, which should be a bug.
                # Note that it only affects on training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450)
                # support both old/latest version
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0):
        """Encode the inputs and score them with the head of ``task_id``.

        :return: ``(start_scores, end_scores)`` for span tasks, otherwise the
            logits of the selected head (flattened over positions for
            sequence labeling)
        """
        if self.encoder_type == EncoderModelType.ROBERTA:
            sequence_output = self.bert.extract_features(input_ids)
            pooled_output = self.pooler(sequence_output)
        else:
            all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
            sequence_output = all_encoder_layers[-1]

        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        if task_type == TaskType.Span:
            assert decoder_opt != 1
            sequence_output = self.dropout_list[task_id](sequence_output)
            logits = self.scoring_list[task_id](sequence_output)
            start_scores, end_scores = logits.split(1, dim=-1)
            start_scores = start_scores.squeeze(-1)
            end_scores = end_scores.squeeze(-1)
            return start_scores, end_scores
        elif task_type == TaskType.SequenceLabeling:
            # Fixed: this read ``all_encoder_layers[-1]``, which does not exist on
            # the ROBERTA path; ``sequence_output`` is the same tensor for BERT.
            pooled_output = sequence_output
            pooled_output = self.dropout_list[task_id](pooled_output)
            # Flatten all positions into rows of width hidden_size.
            pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
            logits = self.scoring_list[task_id](pooled_output)
            return logits
        else:
            if decoder_opt == 1:
                # Check for None before touching hyp_mask.
                assert premise_mask is not None
                assert hyp_mask is not None
                max_query = hyp_mask.size(1)
                assert max_query > 0
                hyp_mem = sequence_output[:, :max_query, :]
                logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask)
            else:
                pooled_output = self.dropout_list[task_id](pooled_output)
                logits = self.scoring_list[task_id](pooled_output)
            return logits
def _divide(x, y):
    """Element-wise ``x / y`` that yields 0 where ``y == 0``.

    Returns NaN if the division cannot be carried out (for example when the
    operand shapes are incompatible).
    """
    try:
        # ``float`` instead of ``np.float``: the alias was removed from NumPy, and
        # the resulting AttributeError used to be swallowed, giving NaN everywhere.
        return np.true_divide(x, y, out=np.zeros_like(x, dtype=float), where=y != 0)
    except Exception:
        # Deliberate best-effort: a metric that cannot be computed becomes NaN.
        return np.nan


def tp_tn_fp_fn(confusion_matrix):
    """Per-class TP/TN/FP/FN from a square confusion matrix.

    Row sums minus the diagonal are counted as false negatives and column
    sums minus the diagonal as false positives.

    :return: tp, tn, fp, fn (one entry per class)
    """
    fp = np.sum(confusion_matrix, axis=0) - np.diag(confusion_matrix)
    fn = np.sum(confusion_matrix, axis=1) - np.diag(confusion_matrix)
    tp = np.diag(confusion_matrix)
    tn = np.sum(confusion_matrix) - (fp + fn + tp)
    return tp, tn, fp, fn


class PReportRow:
    """https://en.wikipedia.org/wiki/Precision_and_recall

    One row of a report. Counts (``tp``, ``tn``, ``fp``, ``fn``) default to
    NaN; every derived metric is computed from the counts unless it is passed
    explicitly as a keyword argument.
    """

    def __init__(self, category, **kwargs):
        self.category = category
        self.tp = kwargs.pop('tp', np.nan)
        self.tn = kwargs.pop('tn', np.nan)
        self.fp = kwargs.pop('fp', np.nan)
        self.fn = kwargs.pop('fn', np.nan)
        self.precision = kwargs.pop('precision', _divide(self.tp, self.tp + self.fp))
        self.recall = kwargs.pop('recall', _divide(self.tp, self.tp + self.fn))
        self.f1 = kwargs.pop('f1', _divide(2 * self.precision * self.recall, self.precision + self.recall))
        self.specificity = kwargs.pop('specificity', _divide(self.tn, self.tn + self.fp))
        self.support = kwargs.pop('support', self.tp + self.fn)
        self.accuracy = kwargs.pop('accuracy', _divide(self.tp + self.tn, self.tp + self.tn + self.fp + self.fn))
        self.balanced_accuracy = kwargs.pop('balanced_accuracy', _divide(self.recall + self.specificity, 2))
        # different names
        self.sensitivity = self.recall
        self.positive_predictive_value = self.precision
        self.true_positive_rate = self.recall
        self.true_negative_rate = self.specificity


class PReport:
    """Collection of per-class rows plus their micro and macro summaries."""

    def __init__(self, rows: List[PReportRow]):
        self.rows = rows
        self.micro_row = self.compute_micro()
        self.macro_row = self.compute_macro()

    def compute_micro(self):
        """Row built from the summed counts of all rows."""
        tps = [row.tp for row in self.rows]
        tns = [row.tn for row in self.rows]
        fps = [row.fp for row in self.rows]
        fns = [row.fn for row in self.rows]
        return PReportRow('micro', tp=np.sum(tps), tn=np.sum(tns), fp=np.sum(fps), fn=np.sum(fns))

    def compute_macro(self):
        """Row holding the unweighted averages of precision, recall and F1."""
        ps = [row.precision for row in self.rows]
        rs = [row.recall for row in self.rows]
        fs = [row.f1 for row in self.rows]
        return PReportRow('macro', precision=np.average(ps), recall=np.average(rs), f1=np.average(fs))

    def report(self, digits=3, micro=False, macro=False):
        """Render the rows as a plain-text table.

        :param digits: decimals used for precision, recall and F-score
        :param micro: append the micro row
        :param macro: append the macro row
        :return: table string produced by ``tabulate``
        """
        headers = ['Class', 'TP', 'FP', 'FN',
                   'Precision', 'Recall', 'F-score',
                   'Support']
        float_formatter = ['g'] * 4 + ['.{}f'.format(digits)] * 3 + ['g']

        # Copy: appending to self.rows would add the summary rows again on
        # every call and leak them into sub_report().
        rows = list(self.rows)
        if micro:
            rows.append(self.micro_row)
        if macro:
            rows.append(self.macro_row)

        table = [[r.category, r.tp, r.fp, r.fn, r.precision, r.recall, r.f1, r.support] for r in rows]
        return tabulate(table, showindex=False, headers=headers,
                        tablefmt="plain", floatfmt=float_formatter)

    def sub_report(self, subindex) -> 'PReport':
        """New report restricted to the rows at the positions in ``subindex``."""
        rows = [self.rows[i] for i in subindex]
        return PReport(rows)


def ner_report_conlleval(y_true: List[List[str]], y_pred: List[List[str]]) -> PReport:
    """Build a text report showing the main classification metrics.

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a classifier.

    Returns:
        A PReport with one row per entity type reported by ``conlleval``.
    """
    lines = []
    assert len(y_true) == len(y_pred)
    for t_sen, p_sen in zip(y_true, y_pred):
        assert len(t_sen) == len(p_sen)
        for t_word, p_word in zip(t_sen, p_sen):
            # 'XXX' stands in for the token column.
            lines.append(f'XXX\t{t_word}\t{p_word}\n')
        # blank line closes the sentence
        lines.append('\n')

    counts = conlleval.evaluate(lines)
    overall, by_type = conlleval.metrics(counts)

    rows = []
    for i, m in sorted(by_type.items()):
        rows.append(PReportRow(i, tp=m.tp, fp=m.fp, fn=m.fn))
    return PReport(rows)


def blue_classification_report(y_true, y_pred, *_, **kwargs) -> PReport:
    """Per-class report derived from the confusion matrix.

    Args:
        y_true: (n_sample, )
        y_pred: (n_sample, )
        classes_: optional keyword with the row names; defaults to
            ``0 .. n_classes - 1``
    """
    classes_ = kwargs.get('classes_', None)
    confusion_matrix = metrics.confusion_matrix(y_true, y_pred)

    tp, tn, fp, fn = tp_tn_fp_fn(confusion_matrix)

    if classes_ is None:
        classes_ = list(range(confusion_matrix.shape[0]))

    rows = []
    for i, c in enumerate(classes_):
        rows.append(PReportRow(c, tp=tp[i], tn=tn[i], fp=fp[i], fn=fn[i]))

    report = PReport(rows)
    return report
#! /bin/sh

# Runs the two preprocessing scripts in sequence over the same list of
# datasets. NOTE(review): the experiments/blue/ paths are relative, so the
# script presumably expects a specific working directory; confirm.

# Root directory passed to the first step and the parent of the directory
# passed to the second.
ROOT="blue_data"
# Directory whose vocab.txt is handed to the second step.
BERT_PATH="${ROOT}/bluebert_models/bert_uncased_lower"
# Comma-separated dataset names forwarded unchanged to both steps.
datasets="clinicalsts,biosses,mednli,i2b2-2010-re,chemprot,ddi2013-type,bc5cdr-chemical,bc5cdr-disease,shareclefe"

# Step 1: blue_prepro.py over $ROOT.
python experiments/blue/blue_prepro.py \
  --root_dir $ROOT \
  --task_def experiments/blue/blue_task_def.yml \
  --datasets $datasets \
  --overwrite

# Step 2: blue_prepro_std.py over $ROOT/canonical_data, with lower-casing
# enabled and a maximum sequence length of 128.
python experiments/blue/blue_prepro_std.py \
  --vocab $BERT_PATH/vocab.txt \
  --root_dir $ROOT/canonical_data \
  --task_def experiments/blue/blue_task_def.yml \
  --do_lower_case \
  --max_seq_len 128 \
  --datasets $datasets \
  --overwrite
def model_config(parser):
    """Register the model-related command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible) instance.

    Returns:
        The same parser, so calls can be chained.
    """
    parser.add_argument('--update_bert_opt', default=0, type=int)
    parser.add_argument('--multi_gpu_on', action='store_true')
    parser.add_argument('--mem_cum_type', type=str, default='simple',
                        help='bilinear/simple/defualt')
    parser.add_argument('--answer_num_turn', type=int, default=5)
    parser.add_argument('--answer_mem_drop_p', type=float, default=0.1)
    parser.add_argument('--answer_att_hidden_size', type=int, default=128)
    parser.add_argument('--answer_att_type', type=str, default='bilinear',
                        help='bilinear/simple/defualt')
    parser.add_argument('--answer_rnn_type', type=str, default='gru',
                        help='rnn/gru/lstm')
    parser.add_argument('--answer_sum_att_type', type=str, default='bilinear',
                        help='bilinear/simple/defualt')
    parser.add_argument('--answer_merge_opt', type=int, default=1)
    parser.add_argument('--answer_mem_type', type=int, default=1)
    parser.add_argument('--answer_dropout_p', type=float, default=0.1)
    parser.add_argument('--answer_weight_norm_on', action='store_true')
    parser.add_argument('--dump_state_on', action='store_true')
    parser.add_argument('--answer_opt', type=int, default=0, help='0,1')
    parser.add_argument('--label_size', type=str, default='3')
    parser.add_argument('--mtl_opt', type=int, default=0)
    parser.add_argument('--ratio', type=float, default=0)
    parser.add_argument('--mix_opt', type=int, default=0)
    parser.add_argument('--max_seq_len', type=int, default=512)
    parser.add_argument('--init_ratio', type=float, default=1)
    parser.add_argument('--encoder_type', type=int, default=1)
    return parser


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def train_config(parser):
    """Register the training-related command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible) instance.

    Returns:
        The same parser, so calls can be chained.
    """
    # ``type=bool`` treated every non-empty string (including "False") as
    # True; parse the text explicitly instead.
    parser.add_argument('--cuda', type=_str2bool, default=torch.cuda.is_available(),
                        help='whether to use GPU acceleration.')
    parser.add_argument('--log_per_updates', type=int, default=500)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--batch_size_eval', type=int, default=8)
    parser.add_argument('--optimizer', default='adamax',
                        help='supported optimizer: adamax, sgd, adadelta, adam')
    parser.add_argument('--grad_clipping', type=float, default=0)
    parser.add_argument('--global_grad_clipping', type=float, default=1.0)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--momentum', type=float, default=0)
    parser.add_argument('--warmup', type=float, default=0.1)
    parser.add_argument('--warmup_schedule', type=str, default='warmup_linear')

    # NOTE: the store_false flags below default to True; passing the flag
    # turns the option off.
    parser.add_argument('--vb_dropout', action='store_false')
    parser.add_argument('--dropout_p', type=float, default=0.1)
    parser.add_argument('--tasks_dropout_p', type=float, default=0.1)
    parser.add_argument('--dropout_w', type=float, default=0.000)
    parser.add_argument('--bert_dropout_p', type=float, default=0.1)
    parser.add_argument('--dump_feature', action='store_false')

    # EMA
    parser.add_argument('--ema_opt', type=int, default=0)
    parser.add_argument('--ema_gamma', type=float, default=0.995)

    # scheduler
    parser.add_argument('--have_lr_scheduler', dest='have_lr_scheduler', action='store_false')
    parser.add_argument('--multi_step_lr', type=str, default='10,20,30')
    parser.add_argument('--freeze_layers', type=int, default=-1)
    parser.add_argument('--embedding_opt', type=int, default=0)
    parser.add_argument('--lr_gamma', type=float, default=0.5)
    parser.add_argument('--bert_l2norm', type=float, default=0.0)
    parser.add_argument('--scheduler_type', type=str, default='ms', help='ms/rop/exp')
    parser.add_argument('--output_dir', default='checkpoint')
    parser.add_argument('--seed', type=int, default=2018,
                        help='random seed for data shuffling, embedding init, etc.')
    return parser


def convert(args):
    """Copy the variables of a TensorFlow BERT checkpoint into a SANBertNetwork and save it.

    Reads ``bert_config.json`` and ``bert_model.ckpt`` from
    ``args.tf_checkpoint_root``, builds the model from the merged
    command-line options and BERT config, assigns every retained checkpoint
    variable to the module attribute addressed by its slash-separated name,
    and writes ``{'state': ..., 'config': ...}`` to
    ``args.pytorch_checkpoint_path`` with ``torch.save``.

    Args:
        args: parsed command-line namespace; note that ``vars(args)`` is
            updated in place with the BERT config entries.

    Raises:
        AssertionError: if a checkpoint array and its target differ in shape.
    """
    tf_checkpoint_path = args.tf_checkpoint_root
    bert_config_file = os.path.join(tf_checkpoint_path, 'bert_config.json')
    pytorch_dump_path = args.pytorch_checkpoint_path
    config = BertConfig.from_json_file(bert_config_file)
    opt = vars(args)
    opt.update(config.to_dict())
    model = SANBertNetwork(opt)
    path = os.path.join(tf_checkpoint_path, 'bert_model.ckpt')
    logger.info('Converting TensorFlow checkpoint from {}'.format(path))
    init_vars = tf.train.list_variables(path)
    names = []
    arrays = []

    for name, shape in init_vars:
        logger.info('Loading {} with shape {}'.format(name, shape))
        array = tf.train.load_variable(path, name)
        logger.info('Numpy array shape {}'.format(array.shape))

        # new layer norm var name
        # make sure you use the latest huggingface's new layernorm implementation
        # if you still use beta/gamma, remove the two renames below
        if name.endswith('LayerNorm/beta'):
            name = name[:-14] + 'LayerNorm/bias'
        if name.endswith('LayerNorm/gamma'):
            name = name[:-15] + 'LayerNorm/weight'

        # Variables with these suffixes are not copied into the model.
        if name.endswith('bad_steps'):
            print('bad_steps')
            continue
        if name.endswith('steps'):
            print('step')
            continue
        if name.endswith('step'):
            print('step')
            continue
        if name.endswith('adam_m'):
            print('adam_m')
            continue
        if name.endswith('adam_v'):
            print('adam_v')
            continue
        if name.endswith('loss_scale'):
            print('loss_scale')
            continue
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        flag = False
        if name == 'cls/squad/output_bias':
            name = 'out_proj/bias'
            flag = True
        if name == 'cls/squad/output_weights':
            name = 'out_proj/weight'
            flag = True

        logger.info('Loading {}'.format(name))
        name = name.split('/')
        # NOTE(review): 'redictions' looks like a truncated 'predictions'; confirm.
        if name[0] in ['redictions', 'eq_relationship', 'cls', 'output']:
            logger.info('Skipping')
            continue
        pointer = model
        # Walk the module tree one path component at a time.
        for m_name in name:
            if flag: continue
            # "layer_3" style components address element 3 of attribute "layer".
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        # NOTE(review): the nesting from here on is inferred (this code was
        # received with its indentation lost); ``m_name`` is the last path
        # component at this point. Confirm against the upstream script.
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        elif flag:
            continue
            # NOTE(review): unreachable after ``continue``; the remapped
            # out_proj variables are therefore never loaded.
            pointer = getattr(getattr(pointer, name[0]), name[1])
        try:
            assert tuple(pointer.shape) == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        pointer.data = torch.from_numpy(array)

    nstate_dict = model.state_dict()
    params = {'state':nstate_dict, 'config': config.to_dict()}
    torch.save(params, pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--tf_checkpoint_root', type=str, required=True)
    parser.add_argument('--pytorch_checkpoint_path', type=str, required=True)
    parser = model_config(parser)
    parser = train_config(parser)
    args = parser.parse_args()
    logger.info(args)
    convert(args)
#!/bin/sh

# Fine-tune one BLUE task by launching blue_train.py with that task's
# hyper-parameters.
#
# Arguments:
#   $1  batch size, forwarded as --batch_size
#   $2  value exported as CUDA_VISIBLE_DEVICES
#   $3  task name; must be one of the names in the case statement below

if [ $# -ne 3 ]; then
  echo "usage: $0 <batch_size> <gpu> <task>" >&2
  exit 1
fi

BATCH_SIZE=$1
gpu=$2
task=$3

echo "export CUDA_VISIBLE_DEVICES=${gpu}"
export CUDA_VISIBLE_DEVICES=${gpu}
# Timestamp that makes the output directory name unique per run.
tstr=$(date +"%FT%H%M")

MODEL_NAME="bluebert-mt-dnn4-biomedical-pubmed_adam_answer_opt0_gc1_ggc1_2020-02-16T1213_97_stripped"
ROOT="bionlp2020"
BERT_PATH="$ROOT/bluebert_models/mt_dnn_bluebert_base_cased/$MODEL_NAME.pt"
DATA_DIR="$ROOT/blue_data/canonical_data/bert_uncased_lower"

# Settings shared by every task; a task overrides them only where it differs.
answer_opt=0
optim="adam"
grad_clipping=0
global_grad_clipping=1

# Per-task epochs and learning rate. The former second "mednli" branch
# (lr=5e-5) could never run because the first one always matched, so only
# the effective values are kept.
case "$task" in
  clinicalsts)     epochs=30; lr="5e-5" ;;
  biosses)         epochs=20; lr="5e-6" ;;
  mednli)          epochs=10; lr="1e-5" ;;
  i2b2-2010-re)    epochs=10; lr="2e-5"; grad_clipping=1 ;;
  chemprot)        epochs=20; lr="2e-5" ;;
  ddi2013-type)    epochs=10; lr="5e-5" ;;
  bc5cdr-chemical) epochs=20; lr="5e-5" ;;
  bc5cdr-disease)  epochs=20; lr="6e-5" ;;
  shareclefe)      epochs=20; lr="5e-5" ;;
  *)
    echo "Cannot recognize $task" >&2
    # A bare "exit" returned the status of the preceding echo, i.e. success.
    exit 1
    ;;
esac

train_datasets=$task
test_datasets=$task

model_dir="$ROOT/checkpoints_blue/${MODEL_NAME}_${task}_${optim}_answer_opt${answer_opt}_gc${grad_clipping}_ggc${global_grad_clipping}_${tstr}"
log_file="${model_dir}/log.log"
python experiments/blue/blue_train.py \
  --data_dir "${DATA_DIR}" \
  --init_checkpoint "${BERT_PATH}" \
  --task_def experiments/blue/blue_task_def.yml \
  --batch_size "${BATCH_SIZE}" \
  --epochs "${epochs}" \
  --output_dir "${model_dir}" \
  --log_file "${log_file}" \
  --answer_opt "${answer_opt}" \
  --optimizer "${optim}" \
  --train_datasets "${train_datasets}" \
  --test_datasets "${test_datasets}" \
  --grad_clipping "${grad_clipping}" \
  --global_grad_clipping "${global_grad_clipping}" \
  --learning_rate "${lr}" \
  --max_seq_len 128 \
  --not_save
#!/bin/sh

# Launch multi-task training (blue_train.py with --multi_gpu_on) over the
# listed BLUE datasets.
#
# Arguments:
#   $1  batch size, forwarded as --batch_size
#   $2  value exported as CUDA_VISIBLE_DEVICES

if [ $# -ne 2 ]; then
  echo "usage: $0 <batch_size> <gpu>" >&2
  exit 1
fi

prefix="blue-mt-dnn"
BATCH_SIZE=$1
gpu=$2
echo "export CUDA_VISIBLE_DEVICES=${gpu}"
export CUDA_VISIBLE_DEVICES=${gpu}
# Timestamp that makes the output directory name unique per run.
tstr=$(date +"%FT%H%M")

train_datasets="clinicalsts,mednli,i2b2-2010-re,chemprot,ddi2013-type,bc5cdr-chemical,bc5cdr-disease,shareclefe"
test_datasets="clinicalsts,mednli,i2b2-2010-re,chemprot,ddi2013-type,bc5cdr-chemical,bc5cdr-disease,shareclefe"

ROOT="bionlp2020"
BERT_PATH="$ROOT/bluebert_models/bluebert_base/bluebert_base.pt"
DATA_DIR="$ROOT/blue_data/canonical_data/bert_uncased_lower"

answer_opt=0
optim="adam"
grad_clipping=1
global_grad_clipping=1
lr="5e-5"
epochs=100

model_dir="$ROOT/checkpoints_blue/${prefix}_${optim}_answer_opt${answer_opt}_gc${grad_clipping}_ggc${global_grad_clipping}_${tstr}"
log_file="${model_dir}/log.log"
# To resume from a checkpoint, additionally pass:
#   --model_ckpt ${MODEL_CKPT} --resume
python experiments/blue/blue_train.py \
  --data_dir "${DATA_DIR}" \
  --init_checkpoint "${BERT_PATH}" \
  --task_def experiments/blue/blue_task_def.yml \
  --batch_size "${BATCH_SIZE}" \
  --output_dir "${model_dir}" \
  --log_file "${log_file}" \
  --answer_opt "${answer_opt}" \
  --optimizer "${optim}" \
  --train_datasets "${train_datasets}" \
  --test_datasets "${test_datasets}" \
  --grad_clipping "${grad_clipping}" \
  --global_grad_clipping "${global_grad_clipping}" \
  --learning_rate "${lr}" \
  --multi_gpu_on \
  --epochs "${epochs}" \
  --max_seq_len 128
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/tokenizer/__init__.py --------------------------------------------------------------------------------