├── .gitignore
├── LICENSE.txt
├── README.md
├── bert
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── create_pretraining_data.py
├── download_glue_data.py
├── extract_features.py
├── modeling.py
├── modeling_test.py
├── multilingual.md
├── optimization.py
├── optimization_test.py
├── run_classifier.py
├── run_pretraining.py
├── run_squad.py
├── sample_text.txt
├── tokenization.py
└── tokenization_test.py
├── bluebert
├── __init__.py
├── conlleval.py
├── run_bluebert.py
├── run_bluebert_multi_labels.py
├── run_bluebert_ner.py
├── run_bluebert_sts.py
└── tf_metrics.py
├── elmo
├── README.md
└── elmoft.py
├── mribert
├── README.md
├── lymph_node_vocab.yml
└── sequence_classification.py
├── mt-bluebert
├── .gitignore
├── LICENSE
├── README.md
├── mt_bluebert
│ ├── __init__.py
│ ├── blue_eval.py
│ ├── blue_exp_def.py
│ ├── blue_inference.py
│ ├── blue_metrics.py
│ ├── blue_prepro.py
│ ├── blue_prepro_std.py
│ ├── blue_strip_model.py
│ ├── blue_task_def.yml
│ ├── blue_train.py
│ ├── blue_utils.py
│ ├── conlleval.py
│ ├── data_utils
│ │ ├── __init__.py
│ │ ├── gpt2_bpe.py
│ │ ├── log_wrapper.py
│ │ ├── metrics.py
│ │ ├── task_def.py
│ │ ├── utils.py
│ │ ├── vocab.py
│ │ └── xlnet_utils.py
│ ├── experiments
│ │ ├── __init__.py
│ │ ├── common_utils.py
│ │ ├── exp_def.py
│ │ └── squad
│ │ │ ├── __init__.py
│ │ │ ├── squad_prepro.py
│ │ │ ├── squad_task_def.yml
│ │ │ ├── squad_utils.py
│ │ │ └── verify_calc_span.py
│ ├── module
│ │ ├── __init__.py
│ │ ├── bert_optim.py
│ │ ├── common.py
│ │ ├── dropout_wrapper.py
│ │ ├── my_optim.py
│ │ ├── san.py
│ │ ├── similarity.py
│ │ └── sub_layers.py
│ ├── mt_dnn
│ │ ├── __init__.py
│ │ ├── batcher.py
│ │ ├── inference.py
│ │ ├── matcher.py
│ │ └── model.py
│ └── pmetrics.py
├── requirements.txt
└── scripts
│ ├── blue_prepro.sh
│ ├── convert_tf_to_pt.py
│ ├── run_blue_fine_tune.sh
│ └── run_blue_mt_dnn.sh
├── requirements.txt
└── tokenizer
├── __init__.py
└── run_tokenization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | .idea
3 | .DS_Store
4 |
5 | ### Python template
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 | ### JetBrains template
111 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
112 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
113 |
114 | # User-specific stuff
115 | .idea/**/workspace.xml
116 | .idea/**/tasks.xml
117 | .idea/**/dictionaries
118 | .idea/**/shelf
119 |
120 | # Sensitive or high-churn files
121 | .idea/**/dataSources/
122 | .idea/**/dataSources.ids
123 | .idea/**/dataSources.local.xml
124 | .idea/**/sqlDataSources.xml
125 | .idea/**/dynamic.xml
126 | .idea/**/uiDesigner.xml
127 |
128 | # Gradle
129 | .idea/**/gradle.xml
130 | .idea/**/libraries
131 |
132 | # CMake
133 | cmake-build-debug/
134 | cmake-build-release/
135 |
136 | # Mongo Explorer plugin
137 | .idea/**/mongoSettings.xml
138 |
139 | # File-based project format
140 | *.iws
141 |
142 | # IntelliJ
143 | out/
144 |
145 | # mpeltonen/sbt-idea plugin
146 | .idea_modules/
147 |
148 | # JIRA plugin
149 | atlassian-ide-plugin.xml
150 |
151 | # Cursive Clojure plugin
152 | .idea/replstate.xml
153 |
154 | # Crashlytics plugin (for Android Studio and IntelliJ)
155 | com_crashlytics_export_strings.xml
156 | crashlytics.properties
157 | crashlytics-build.properties
158 | fabric.properties
159 |
160 | # Editor-based Rest Client
161 | .idea/httpRequests
162 |
163 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | PUBLIC DOMAIN NOTICE
2 | National Center for Biotechnology Information
3 |
4 | This software/database is a "United States Government Work" under the terms of
5 | the United States Copyright Act. It was written as part of the author's
6 | official duties as a United States Government employee and thus cannot be
7 | copyrighted. This software/database is freely available to the public for use.
8 | The National Library of Medicine and the U.S. Government have not placed any
9 | restriction on its use or reproduction.
10 |
11 | Although all reasonable efforts have been taken to ensure the accuracy and
12 | reliability of the software and data, the NLM and the U.S. Government do not and
13 | cannot warrant the performance or results that may be obtained by using this
14 | software or data. The NLM and the U.S. Government disclaim all warranties,
15 | express or implied, including warranties of performance, merchantability or
16 | fitness for any particular purpose.
17 |
18 | Please cite the author in any work or product based on this material:
19 |
20 | Peng Y, Yan S, Lu Z. Transfer Learning in Biomedical Natural Language
21 | Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets.
22 | In Proceedings of the 2019 Workshop on Biomedical Natural Language Processing
23 | (BioNLP 2019). 2019:58-65.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BlueBERT
2 |
3 | **\*\*\*\*\* New Nov 1st, 2020: BlueBERT can be found at huggingface \*\*\*\*\***
4 |
5 | **\*\*\*\*\* New Dec 5th, 2019: NCBI_BERT is renamed to BlueBERT \*\*\*\*\***
6 |
7 | **\*\*\*\*\* New July 11th, 2019: preprocessed PubMed texts \*\*\*\*\***
8 |
9 | We uploaded the [preprocessed PubMed texts](https://github.com/ncbi-nlp/BlueBERT/blob/master/README.md#pubmed) that were used to pre-train the BlueBERT models.
10 |
11 | -----
12 |
13 | This repository provides code and models of BlueBERT, pre-trained on PubMed abstracts and clinical notes ([MIMIC-III](https://mimic.physionet.org/)). Please refer to our paper [Transfer Learning in Biomedical Natural Language Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets](https://arxiv.org/abs/1906.05474) for more details.
14 |
15 | ## Pre-trained models and benchmark datasets
16 |
17 | The pre-trained BlueBERT weights, vocab, and config files can be downloaded from:
18 |
19 | * [BlueBERT-Base, Uncased, PubMed](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_uncased_L-12_H-768_A-12.zip): This model was pretrained on PubMed abstracts.
20 | * [BlueBERT-Base, Uncased, PubMed+MIMIC-III](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_mimic_uncased_L-12_H-768_A-12.zip): This model was pretrained on PubMed abstracts and MIMIC-III.
21 | * [BlueBERT-Large, Uncased, PubMed](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_uncased_L-24_H-1024_A-16.zip): This model was pretrained on PubMed abstracts.
22 | * [BlueBERT-Large, Uncased, PubMed+MIMIC-III](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/NCBI_BERT_pubmed_mimic_uncased_L-24_H-1024_A-16.zip): This model was pretrained on PubMed abstracts and MIMIC-III.
23 |
24 | The pre-trained weights can also be found at Huggingface:
25 |
26 | * https://huggingface.co/bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12
27 | * https://huggingface.co/bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12
28 | * https://huggingface.co/bionlp/bluebert_pubmed_uncased_L-24_H-1024_A-16
29 | * https://huggingface.co/bionlp/bluebert_pubmed_mimic_uncased_L-24_H-1024_A-16
30 |
31 | The benchmark datasets can be downloaded from [https://github.com/ncbi-nlp/BLUE_Benchmark](https://github.com/ncbi-nlp/BLUE_Benchmark)
32 |
33 | ## Fine-tuning BlueBERT
34 |
35 | We assume the BlueBERT model has been downloaded at `$BlueBERT_DIR`, and the dataset has been downloaded at `$DATASET_DIR`.
36 |
37 | Add local directory to `$PYTHONPATH` if needed.
38 |
39 | ```bash
40 | export PYTHONPATH=.:$PYTHONPATH
41 | ```
42 |
43 | ### Sentence similarity
44 |
45 | ```bash
46 | python bluebert/run_bluebert_sts.py \
47 | --task_name='sts' \
48 | --do_train=true \
49 | --do_eval=false \
50 | --do_test=true \
51 | --vocab_file=$BlueBERT_DIR/vocab.txt \
52 | --bert_config_file=$BlueBERT_DIR/bert_config.json \
53 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \
54 | --max_seq_length=128 \
55 | --num_train_epochs=30.0 \
56 | --do_lower_case=true \
57 | --data_dir=$DATASET_DIR \
58 | --output_dir=$OUTPUT_DIR
59 | ```
60 |
61 |
62 | ### Named Entity Recognition
63 |
64 | ```bash
65 | python bluebert/run_bluebert_ner.py \
66 | --do_prepare=true \
67 | --do_train=true \
68 | --do_eval=true \
69 | --do_predict=true \
70 | --task_name="bc5cdr" \
71 | --vocab_file=$BlueBERT_DIR/vocab.txt \
72 | --bert_config_file=$BlueBERT_DIR/bert_config.json \
73 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \
74 | --num_train_epochs=30.0 \
75 | --do_lower_case=true \
76 | --data_dir=$DATASET_DIR \
77 | --output_dir=$OUTPUT_DIR
78 | ```
79 |
80 | The task name can be
81 |
82 | - `bc5cdr`: BC5CDR chemical or disease task
83 | - `clefe`: ShARe/CLEFE task
84 |
85 | ### Relation Extraction
86 |
87 | ```bash
88 | python bluebert/run_bluebert.py \
89 | --do_train=true \
90 | --do_eval=false \
91 | --do_predict=true \
92 | --task_name="chemprot" \
93 | --vocab_file=$BlueBERT_DIR/vocab.txt \
94 | --bert_config_file=$BlueBERT_DIR/bert_config.json \
95 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \
96 | --num_train_epochs=10.0 \
97 | --data_dir=$DATASET_DIR \
98 | --output_dir=$OUTPUT_DIR \
99 | --do_lower_case=true
100 | ```
101 |
102 | The task name can be
103 |
104 | - `chemprot`: BC6 ChemProt task
105 | - `ddi`: DDI 2013 task
106 | - `i2b2_2010`: I2B2 2010 task
107 |
108 | ### Document multilabel classification
109 |
110 | ```bash
111 | python bluebert/run_bluebert_multi_labels.py \
112 | --task_name="hoc" \
113 | --do_train=true \
114 | --do_eval=true \
115 | --do_predict=true \
116 | --vocab_file=$BlueBERT_DIR/vocab.txt \
117 | --bert_config_file=$BlueBERT_DIR/bert_config.json \
118 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \
119 | --max_seq_length=128 \
120 | --train_batch_size=4 \
121 | --learning_rate=2e-5 \
122 | --num_train_epochs=3 \
123 | --num_classes=20 \
124 | --num_aspects=10 \
125 | --aspect_value_list="0,1" \
126 | --data_dir=$DATASET_DIR \
127 | --output_dir=$OUTPUT_DIR
128 | ```
129 |
130 | ### Inference task
131 |
132 | ```bash
133 | python bluebert/run_bluebert.py \
134 | --do_train=true \
135 | --do_eval=false \
136 | --do_predict=true \
137 | --task_name="mednli" \
138 | --vocab_file=$BlueBERT_DIR/vocab.txt \
139 | --bert_config_file=$BlueBERT_DIR/bert_config.json \
140 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \
141 | --num_train_epochs=10.0 \
142 | --data_dir=$DATASET_DIR \
143 | --output_dir=$OUTPUT_DIR \
144 | --do_lower_case=true
145 | ```
146 |
147 | ## Preprocessed PubMed texts
148 |
149 | We provide [preprocessed PubMed texts](https://ftp.ncbi.nlm.nih.gov/pub/lu/Suppl/NCBI-BERT/pubmed_uncased_sentence_nltk.txt.tar.gz) that were used to pre-train the BlueBERT models. The corpus contains ~4000M words extracted from the [PubMed ASCII code version](https://www.ncbi.nlm.nih.gov/research/bionlp/APIs/BioC-PubMed/). Other operations include
150 |
151 | * lowercasing the text
152 | * removing non-ASCII characters (anything outside `\x00`-`\x7F`)
153 | * tokenizing the text using the [NLTK Treebank tokenizer](https://www.nltk.org/_modules/nltk/tokenize/treebank.html)
154 |
155 | Below is a code snippet for more details.
156 |
157 | ```python
158 | value = value.lower()
159 | value = re.sub(r'[\r\n]+', ' ', value)
160 | value = re.sub(r'[^\x00-\x7F]+', ' ', value)
161 |
162 | tokenized = TreebankWordTokenizer().tokenize(value)
163 | sentence = ' '.join(tokenized)
164 | sentence = re.sub(r"\s's\b", "'s", sentence)
165 | ```
166 |
167 | ### Pre-training with BERT
168 |
169 | Afterwards, we used the following code to generate pre-training data. Please see https://github.com/google-research/bert for more details.
170 |
171 | ```bash
172 | python bert/create_pretraining_data.py \
173 | --input_file=pubmed_uncased_sentence_nltk.txt \
174 | --output_file=pubmed_uncased_sentence_nltk.tfrecord \
175 | --vocab_file=bert_uncased_L-12_H-768_A-12_vocab.txt \
176 | --do_lower_case=True \
177 | --max_seq_length=128 \
178 | --max_predictions_per_seq=20 \
179 | --masked_lm_prob=0.15 \
180 | --random_seed=12345 \
181 | --dupe_factor=5
182 | ```
183 |
184 | We used the following code to train the BERT model. Please do not include `init_checkpoint` if you are pre-training from scratch. Please see https://github.com/google-research/bert for more details.
185 |
186 | ```bash
187 | python bert/run_pretraining.py \
188 | --input_file=pubmed_uncased_sentence_nltk.tfrecord \
189 | --output_dir=$BlueBERT_DIR \
190 | --do_train=True \
191 | --do_eval=True \
192 | --bert_config_file=$BlueBERT_DIR/bert_config.json \
193 | --init_checkpoint=$BlueBERT_DIR/bert_model.ckpt \
194 | --train_batch_size=32 \
195 | --max_seq_length=128 \
196 | --max_predictions_per_seq=20 \
197 | --num_train_steps=20000 \
198 | --num_warmup_steps=10 \
199 | --learning_rate=2e-5
200 | ```
201 |
202 | ## Citing BlueBERT
203 |
204 | * Peng Y, Yan S, Lu Z. [Transfer Learning in Biomedical Natural Language Processing: An
205 | Evaluation of BERT and ELMo on Ten Benchmarking Datasets](https://arxiv.org/abs/1906.05474). In *Proceedings of the Workshop on Biomedical Natural Language Processing (BioNLP)*. 2019.
206 |
207 | ```
208 | @InProceedings{peng2019transfer,
209 | author = {Yifan Peng and Shankai Yan and Zhiyong Lu},
210 | title = {Transfer Learning in Biomedical Natural Language Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets},
211 | booktitle = {Proceedings of the 2019 Workshop on Biomedical Natural Language Processing (BioNLP 2019)},
212 | year = {2019},
213 | pages = {58--65},
214 | }
215 | ```
216 |
217 | ## Acknowledgments
218 |
219 | This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of
220 | Medicine and Clinical Center. This work was supported by the National Library of Medicine of the National Institutes of Health under award number K99LM013001-01.
221 |
222 | We are also grateful to the authors of BERT and ELMo for making their data and code publicly available.
223 |
224 | We would like to thank Dr Sun Kim for processing the PubMed texts.
225 |
226 | ## Disclaimer
227 |
228 | This tool shows the results of research conducted in the Computational Biology Branch, NCBI. The information produced
229 | on this website is not intended for direct diagnostic use or medical decision-making without review and oversight
230 | by a clinical professional. Individuals should not change their health behavior solely on the basis of information
231 | produced on this website. NIH does not independently verify the validity or utility of the information produced
232 | by this tool. If you have questions about the information produced on this website, please see a health care
233 | professional. More information about NCBI's disclaimer policy is available.
234 |
--------------------------------------------------------------------------------
/bert/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | BERT needs to maintain permanent compatibility with the pre-trained model files,
4 | so we do not plan to make any major changes to this library (other than what was
5 | promised in the README). However, we can accept small patches related to
6 | re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/bert/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/bert/download_glue_data.py:
--------------------------------------------------------------------------------
1 | ''' Script for downloading all GLUE data.
2 |
3 | Note: for legal reasons, we are unable to host MRPC.
4 | You can either use the version hosted by the SentEval team, which is already tokenized,
5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
7 | You should then rename and place specific files in a folder (see below for an example).
8 |
9 | mkdir MRPC
10 | cabextract MSRParaphraseCorpus.msi -d MRPC
11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
13 | rm MRPC/_*
14 | rm MSRParaphraseCorpus.msi
15 | '''
16 |
17 | import os
18 | import sys
19 | import shutil
20 | import argparse
21 | import tempfile
22 | import urllib.request
23 | import zipfile
24 |
# Every task this script knows how to fetch. get_tasks() returns this list
# when "all" is requested, and main() then processes tasks in this order.
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
# Download URL per task. Most entries are zip archives that
# download_and_extract() unpacks into the data directory. Two exceptions:
#   - "MRPC" points at the dev-split ID list (a .tsv); format_mrpc() uses it
#     to split the raw MRPC training file into train/dev.
#   - "diagnostic" is a single .tsv saved as-is by download_diagnostic().
# NOTE(review): these are tokenized/signed hosting URLs; if downloads start
# failing, check whether they have been revoked or have expired.
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
             "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
             "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
             "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
             "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
             "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
             "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
             "QNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLI.zip?alt=media&token=c24cad61-f2df-4f04-9ab6-aa576fa829d0',
             "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
             "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
             "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}

# Raw MRPC train/test files from the SentEval mirror. format_mrpc() downloads
# these only when no local --path_to_mrpc directory is supplied.
MRPC_TRAIN = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_test.txt'
40 |
def download_and_extract(task, data_dir):
    """Download the zip archive for ``task`` and extract it into ``data_dir``.

    The archive is downloaded to a unique temporary file rather than
    ``<task>.zip`` in the current working directory. It is always deleted
    afterwards, even when the download or the extraction raises.

    Args:
        task: Key of TASK2PATH naming a task distributed as a zip archive.
        data_dir: Directory into which the archive contents are extracted.
    """
    print("Downloading and extracting %s..." % task)
    # mkstemp creates the file securely; close our handle immediately so
    # urlretrieve can reopen the path for writing (required on Windows).
    fd, data_file = tempfile.mkstemp(prefix="%s_" % task, suffix=".zip")
    os.close(fd)
    try:
        urllib.request.urlretrieve(TASK2PATH[task], data_file)
        with zipfile.ZipFile(data_file) as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Never leave a partial or corrupt archive behind.
        os.remove(data_file)
    print("\tCompleted!")
49 |
def format_mrpc(data_dir, path_to_data):
    """Build MRPC ``train.tsv``/``dev.tsv``/``test.tsv`` under ``<data_dir>/MRPC``.

    The raw Microsoft Research Paraphrase Corpus files are read from
    ``path_to_data`` when it is given. Otherwise they are downloaded from
    MRPC_TRAIN / MRPC_TEST into the MRPC directory. The dev-split ID list at
    TASK2PATH["MRPC"] is always downloaded. Raw training rows whose
    ``(id1, id2)`` pair appears in that list go to ``dev.tsv``; all other
    rows go to ``train.tsv``.

    Args:
        data_dir: Root output directory; an ``MRPC`` subdirectory is created.
        path_to_data: Directory containing ``msr_paraphrase_train.txt`` and
            ``msr_paraphrase_test.txt``, or a falsy value to download them.
    """
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    dev_ids_file = os.path.join(mrpc_dir, "dev_ids.tsv")
    urllib.request.urlretrieve(TASK2PATH["MRPC"], dev_ids_file)

    # A set of (id1, id2) tuples gives O(1) membership per training row; the
    # previous list-of-lists lookup scanned every dev pair for every row.
    with open(dev_ids_file, encoding='utf8') as ids_fh:
        dev_ids = {tuple(row.strip().split('\t')) for row in ids_fh}

    with open(mrpc_train_file, encoding='utf8') as data_fh, \
            open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf8') as train_fh, \
            open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf8') as dev_fh:
        # The raw header line is copied unchanged into both outputs.
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            if not row.strip():
                continue  # Tolerate blank (e.g. trailing) lines.
            label, id1, id2, s1, s2 = row.strip().split('\t')
            out_fh = dev_fh if (id1, id2) in dev_ids else train_fh
            out_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    with open(mrpc_test_file, encoding='utf8') as data_fh, \
            open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf8') as test_fh:
        data_fh.readline()  # Discard the raw header; write our own below.
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        # Filter blank lines before enumerating so indices stay contiguous.
        rows = (row for row in data_fh if row.strip())
        for idx, row in enumerate(rows):
            _label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")
93 |
def download_diagnostic(data_dir):
    """Fetch the GLUE diagnostic set to ``<data_dir>/diagnostic/diagnostic.tsv``.

    The file is a plain TSV, so it is saved as-is with no extraction step.
    The ``diagnostic`` subdirectory is created if it does not exist yet.
    """
    print("Downloading and extracting diagnostic...")
    diag_dir = os.path.join(data_dir, "diagnostic")
    if not os.path.isdir(diag_dir):
        os.mkdir(diag_dir)
    target = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], target)
    print("\tCompleted!")
102 |
def get_tasks(task_names):
    """Resolve a comma-separated task string into a list of task names.

    Args:
        task_names: Comma-separated names drawn from TASKS (e.g.
            ``"CoLA,SST"``), or any list containing ``"all"`` to select
            every task. Whitespace around names is ignored, and empty
            entries (e.g. from a trailing comma) are skipped.

    Returns:
        A new list of task names. When ``"all"`` is requested this is a copy
        of TASKS, so callers cannot mutate the module-level constant.

    Raises:
        ValueError: If a requested name is not in TASKS. An explicit raise is
            used instead of ``assert``, which is stripped under ``python -O``.
    """
    names = [name.strip() for name in task_names.split(',') if name.strip()]
    if "all" in names:
        return list(TASKS)
    for name in names:
        if name not in TASKS:
            raise ValueError("Task %s not found!" % name)
    return names
113 |
def main(arguments):
    """Parse command-line arguments and download the requested GLUE tasks.

    Args:
        arguments: Command-line argument list excluding the program name,
            e.g. ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    # format_mrpc() reads msr_paraphrase_test.txt (the help text previously
    # named a non-existent "msr_paraphrase_text.txt").
    parser.add_argument('--path_to_mrpc',
                        help='path to directory containing extracted MRPC data, '
                             'msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    # makedirs also creates missing parent directories, where os.mkdir would
    # fail for a nested --data_dir such as "data/glue".
    os.makedirs(args.data_dir, exist_ok=True)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
--------------------------------------------------------------------------------
/bert/modeling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import json
21 | import random
22 | import re
23 |
24 | import six
25 | import tensorflow as tf
26 |
27 | from bert import modeling
28 |
29 |
30 | class BertModelTest(tf.test.TestCase):
31 | class BertModelTester(object):
32 |
33 | def __init__(self,
34 | parent,
35 | batch_size=13,
36 | seq_length=7,
37 | is_training=True,
38 | use_input_mask=True,
39 | use_token_type_ids=True,
40 | vocab_size=99,
41 | hidden_size=32,
42 | num_hidden_layers=5,
43 | num_attention_heads=4,
44 | intermediate_size=37,
45 | hidden_act="gelu",
46 | hidden_dropout_prob=0.1,
47 | attention_probs_dropout_prob=0.1,
48 | max_position_embeddings=512,
49 | type_vocab_size=16,
50 | initializer_range=0.02,
51 | scope=None):
52 | self.parent = parent
53 | self.batch_size = batch_size
54 | self.seq_length = seq_length
55 | self.is_training = is_training
56 | self.use_input_mask = use_input_mask
57 | self.use_token_type_ids = use_token_type_ids
58 | self.vocab_size = vocab_size
59 | self.hidden_size = hidden_size
60 | self.num_hidden_layers = num_hidden_layers
61 | self.num_attention_heads = num_attention_heads
62 | self.intermediate_size = intermediate_size
63 | self.hidden_act = hidden_act
64 | self.hidden_dropout_prob = hidden_dropout_prob
65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
66 | self.max_position_embeddings = max_position_embeddings
67 | self.type_vocab_size = type_vocab_size
68 | self.initializer_range = initializer_range
69 | self.scope = scope
70 |
71 | def create_model(self):
72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
73 | self.vocab_size)
74 |
75 | input_mask = None
76 | if self.use_input_mask:
77 | input_mask = BertModelTest.ids_tensor(
78 | [self.batch_size, self.seq_length], vocab_size=2)
79 |
80 | token_type_ids = None
81 | if self.use_token_type_ids:
82 | token_type_ids = BertModelTest.ids_tensor(
83 | [self.batch_size, self.seq_length], self.type_vocab_size)
84 |
85 | config = modeling.BertConfig(
86 | vocab_size=self.vocab_size,
87 | hidden_size=self.hidden_size,
88 | num_hidden_layers=self.num_hidden_layers,
89 | num_attention_heads=self.num_attention_heads,
90 | intermediate_size=self.intermediate_size,
91 | hidden_act=self.hidden_act,
92 | hidden_dropout_prob=self.hidden_dropout_prob,
93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob,
94 | max_position_embeddings=self.max_position_embeddings,
95 | type_vocab_size=self.type_vocab_size,
96 | initializer_range=self.initializer_range)
97 |
98 | model = modeling.BertModel(
99 | config=config,
100 | is_training=self.is_training,
101 | input_ids=input_ids,
102 | input_mask=input_mask,
103 | token_type_ids=token_type_ids,
104 | scope=self.scope)
105 |
106 | outputs = {
107 | "embedding_output": model.get_embedding_output(),
108 | "sequence_output": model.get_sequence_output(),
109 | "pooled_output": model.get_pooled_output(),
110 | "all_encoder_layers": model.get_all_encoder_layers(),
111 | }
112 | return outputs
113 |
114 | def check_output(self, result):
115 | self.parent.assertAllEqual(
116 | result["embedding_output"].shape,
117 | [self.batch_size, self.seq_length, self.hidden_size])
118 |
119 | self.parent.assertAllEqual(
120 | result["sequence_output"].shape,
121 | [self.batch_size, self.seq_length, self.hidden_size])
122 |
123 | self.parent.assertAllEqual(result["pooled_output"].shape,
124 | [self.batch_size, self.hidden_size])
125 |
126 | def test_default(self):
127 | self.run_tester(BertModelTest.BertModelTester(self))
128 |
129 | def test_config_to_json_string(self):
130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37)
131 | obj = json.loads(config.to_json_string())
132 | self.assertEqual(obj["vocab_size"], 99)
133 | self.assertEqual(obj["hidden_size"], 37)
134 |
def run_tester(self, tester):
  """Builds the tester's model, evaluates it once and validates the result.

  After checking the output values via `tester.check_output`, also asserts
  that every op in the session graph is reachable from the init op and the
  model outputs.
  """
  with self.test_session() as sess:
    model_ops = tester.create_model()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    tester.check_output(sess.run(model_ops))

    self.assert_all_tensors_reachable(sess, [init_op, model_ops])
145 |
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
  """Returns a constant int32 tensor of `shape` with random ids.

  Each id is drawn with `rng.randint(0, vocab_size - 1)`, i.e. uniformly from
  the closed range [0, vocab_size - 1].

  Args:
    shape: list of ints, the tensor shape.
    vocab_size: number of distinct ids.
    rng: optional `random.Random`; a fresh unseeded one is used if None.
    name: optional op name for the constant.
  """
  generator = random.Random() if rng is None else rng

  num_values = 1
  for dim in shape:
    num_values *= dim

  values = [generator.randint(0, vocab_size - 1) for _ in range(num_values)]
  return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
161 |
def assert_all_tensors_reachable(self, sess, outputs):
  """Fails the test if `sess.graph` has ops unreachable from `outputs`.

  Ops whose names match any of the allow-listed regexes below are ignored.
  The failure message lists the names of the offending ops.
  """
  ignore_regexes = [
      re.compile(pattern) for pattern in (
          "^.*/assert_less_equal/.*$",
          "^.*/dilation_rate$",
          "^.*/Tensordot/concat$",
          "^.*/Tensordot/concat/axis$",
          "^testing/.*$",
      )
  ]

  unreachable = [
      op for op in self.get_unreachable_ops(sess.graph, outputs)
      if not any(regex.match(op.name) is not None for regex in ignore_regexes)
  ]

  self.assertEqual(
      len(unreachable), 0, "The following ops are unreachable: %s" %
      (" ".join([op.name for op in unreachable])))
192 |
@classmethod
def get_unreachable_ops(cls, graph, outputs):
  """Finds the ops in `graph` that touch a tensor unreachable from `outputs`.

  Walks backwards from every tensor in (possibly nested) `outputs`. For each
  reached tensor, all input and output tensors of the ops producing it are
  reached as well. Tensors joined by an "Assign" op (its inputs and outputs)
  are treated as mutually reachable, so variables updated by an assignment
  are not reported.

  Args:
    graph: graph to inspect; only `get_operations()` is used.
    outputs: a tensor, or a nested list/tuple/dict of tensors.

  Returns:
    List of ops, in `graph.get_operations()` order, that have at least one
    input or output tensor that was not reached.
  """
  outputs = cls.flatten_recursive(outputs)
  output_to_op = collections.defaultdict(list)
  op_to_all = collections.defaultdict(list)
  assign_out_to_in = collections.defaultdict(list)

  for op in graph.get_operations():
    for x in op.inputs:
      op_to_all[op.name].append(x.name)
    for y in op.outputs:
      output_to_op[y.name].append(op.name)
      op_to_all[op.name].append(y.name)
    if str(op.type) == "Assign":
      for y in op.outputs:
        for x in op.inputs:
          assign_out_to_in[y.name].append(x.name)

  assign_groups = collections.defaultdict(list)
  for out_name, name_group in assign_out_to_in.items():
    for n1 in name_group:
      assign_groups[n1].append(out_name)
      for n2 in name_group:
        if n1 != n2:
          assign_groups[n1].append(n2)

  # Iterative DFS. Pushes are filtered against the `seen_tensors` set rather
  # than by scanning the pending stack (an O(len(stack)) list search per push,
  # which made the walk quadratic on large graphs). A name may be pushed more
  # than once now, but the check at pop time keeps the result identical.
  seen_tensors = set()
  stack = [x.name for x in outputs]
  while stack:
    name = stack.pop()
    if name in seen_tensors:
      continue
    seen_tensors.add(name)

    # `.get` avoids inserting keys into the defaultdicts during the walk.
    for op_name in output_to_op.get(name, ()):
      for neighbor in op_to_all.get(op_name, ()):
        if neighbor not in seen_tensors:
          stack.append(neighbor)

    for neighbor in assign_groups.get(name, ()):
      if neighbor not in seen_tensors:
        stack.append(neighbor)

  unreachable_ops = []
  for op in graph.get_operations():
    all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
    if any(name not in seen_tensors for name in all_names):
      unreachable_ops.append(op)
  return unreachable_ops
255 |
@classmethod
def flatten_recursive(cls, item):
  """Flattens nested lists, tuples and dict values into one flat list.

  Dict keys are dropped; values are visited in the dict's iteration order.
  Anything that is not a list, tuple or dict is returned as a single leaf.
  """
  if isinstance(item, (list, tuple)):
    children = list(item)
  elif isinstance(item, dict):
    children = list(item.values())
  else:
    return [item]

  leaves = []
  for child in children:
    leaves.extend(cls.flatten_recursive(child))
  return leaves
274 |
275 |
if __name__ == "__main__":
  # Hand off to TensorFlow's test runner when executed as a script.
  tf.test.main()
278 |
--------------------------------------------------------------------------------
/bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
  """Creates the training op: AdamWeightDecay with warmup + linear decay.

  Args:
    loss: scalar loss tensor to minimize.
    init_lr: learning rate at step 0 of the linear decay schedule.
    num_train_steps: step at which the decayed learning rate reaches 0.0.
    num_warmup_steps: if truthy, for steps below this value the learning rate
      is `global_step / num_warmup_steps * init_lr` instead.
    use_tpu: if true, wrap the optimizer in `tf.contrib.tpu.CrossShardOptimizer`.

  Returns:
    An op that applies one clipped gradient update and increments the
    global step.
  """
  global_step = tf.train.get_or_create_global_step()

  # Linear (power=1.0) decay from init_lr down to 0.0 over num_train_steps.
  lr = tf.train.polynomial_decay(
      tf.constant(value=init_lr, shape=[], dtype=tf.float32),
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

  if num_warmup_steps:
    step_i32 = tf.cast(global_step, tf.int32)
    warmup_i32 = tf.constant(num_warmup_steps, dtype=tf.int32)
    warmup_fraction = (tf.cast(step_i32, tf.float32) /
                       tf.cast(warmup_i32, tf.float32))
    warmup_lr = init_lr * warmup_fraction
    # 1.0 while global_step < num_warmup_steps, 0.0 afterwards; used to
    # select between the warmup and the decayed rate without tf.cond.
    in_warmup = tf.cast(step_i32 < warmup_i32, tf.float32)
    lr = (1.0 - in_warmup) * lr + in_warmup * warmup_lr

  # Use this optimizer for fine-tuning, since BERT was pre-trained with it.
  # The Adam m/v slot variables are NOT loaded from init_checkpoint.
  optimizer = AdamWeightDecayOptimizer(
      learning_rate=lr,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  if use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  tvars = tf.trainable_variables()
  # Global-norm clipping at 1.0, matching pre-training.
  clipped_grads, _ = tf.clip_by_global_norm(
      tf.gradients(loss, tvars), clip_norm=1.0)
  train_op = optimizer.apply_gradients(
      zip(clipped_grads, tvars), global_step=global_step)

  # AdamWeightDecayOptimizer.apply_gradients does not touch `global_step`, so
  # it is incremented here. Drop this if you switch to an optimizer that
  # increments the step itself.
  return tf.group(train_op, [global_step.assign(global_step + 1)])
85 |
86 |
class AdamWeightDecayOptimizer(tf.train.Optimizer):
  """Adam with decoupled ("correct") L2 weight decay.

  As implemented here:
    * the moment estimates m/v are not bias-corrected;
    * weight decay is added to the update *after* the Adam normalization, so
      it never flows through m/v;
    * `apply_gradients` does not increment `global_step`.
  """

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               name="AdamWeightDecayOptimizer"):
    """Stores the optimizer hyperparameters.

    Args:
      learning_rate: float or scalar tensor multiplying every update.
      weight_decay_rate: decoupled decay coefficient; falsy disables decay.
      beta_1: decay rate of the first-moment estimate m.
      beta_2: decay rate of the second-moment estimate v.
      epsilon: added to sqrt(v) to avoid division by zero.
      exclude_from_weight_decay: optional list of regex patterns; a variable
        whose name matches any of them (via `re.search`) is not decayed.
      name: optimizer name passed to the base class.
    """
    # First positional base-class argument is `use_locking`.
    super(AdamWeightDecayOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class. `global_step` is accepted but not incremented."""
    assignments = []
    for grad, param in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      # Non-trainable slot variables named "<param>/adam_m" and "<param>/adam_v".
      m = tf.get_variable(
          name=param_name + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=param_name + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Standard Adam moment updates.
      next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)
      next_v = (tf.multiply(self.beta_2, v) +
                tf.multiply(1.0 - self.beta_2, tf.square(grad)))
      step_direction = next_m / (tf.sqrt(next_v) + self.epsilon)

      # Adding the squared weights to the loss is *not* the right way to do L2
      # regularization with Adam, because the penalty would then interact with
      # m and v. Instead the decay term is added directly to the update, which
      # matches adding the squared weights to the loss under plain
      # (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        step_direction += self.weight_decay_rate * param

      next_param = param - self.learning_rate * step_direction

      assignments += [param.assign(next_param),
                      m.assign(next_m),
                      v.assign(next_v)]
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Returns True if weight decay should be applied to `param_name`."""
    if not self.weight_decay_rate:
      return False
    patterns = self.exclude_from_weight_decay or []
    return not any(re.search(p, param_name) is not None for p in patterns)

  def _get_variable_name(self, param_name):
    """Strips a trailing ":<digits>" output index from a tensor name."""
    match = re.match("^(.*):\\d+$", param_name)
    return param_name if match is None else match.group(1)
175 |
--------------------------------------------------------------------------------
/bert/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import optimization
20 | import tensorflow as tf
21 |
22 |
class OptimizationTest(tf.test.TestCase):
  """Tests for the optimization module."""

  def test_adam(self):
    """100 AdamWeightDecay steps should move `w` onto the target `x`."""
    target = [0.4, 0.2, -0.5]
    with self.test_session() as sess:
      w = tf.get_variable(
          "w",
          shape=[3],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      x = tf.constant(target)
      loss = tf.reduce_mean(tf.square(x - w))
      tvars = tf.trainable_variables()
      grads = tf.gradients(loss, tvars)
      global_step = tf.train.get_or_create_global_step()
      optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
      train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)

      sess.run(tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer()))
      for _ in range(100):
        sess.run(train_op)

      w_np = sess.run(w)
      self.assertAllClose(w_np.flat, target, rtol=1e-2, atol=1e-2)
45 |
46 |
if __name__ == "__main__":
  # Hand off to TensorFlow's test runner when executed as a script.
  tf.test.main()
49 |
--------------------------------------------------------------------------------
/bert/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is in the public domain and was randomly selected from Project Gutenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
8 | "Cass" Beard had risen early that morning, but not with a view to discovery.
9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
11 | This was nearly opposite.
12 | Mr. Cassius crossed the highway, and stopped suddenly.
13 | Something glittered in the nearest red pool before him.
14 | Gold, surely!
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
17 | Like most of his fellow gold-seekers, Cass was superstitious.
18 |
19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
21 | From my youth I felt in me a soul above the matter-entangled herd.
22 | She revealed to me the glorious fact, that I am a spark of Divinity itself.
23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
24 | There is a philosophic pleasure in opening one's treasures to the modest young.
25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
31 | At last they reached the quay at the opposite end of the street;
32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
34 |
--------------------------------------------------------------------------------
/bert/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import tempfile
21 | import six
22 | import tensorflow as tf
23 | import tokenization
24 |
25 |
class TokenizationTest(tf.test.TestCase):
  """Unit tests for the tokenization module."""

  def test_full_tokenizer(self):
    """End-to-end FullTokenizer check using a vocab file on disk."""
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    vocab_text = "".join([x + "\n" for x in vocab_tokens])

    # delete=False lets FullTokenizer reopen the closed file by name, which
    # makes removing it our job. Do that in `finally`: previously an
    # exception in write() or FullTokenizer() left the temp file behind.
    vocab_writer = tempfile.NamedTemporaryFile(delete=False)
    vocab_file = vocab_writer.name
    try:
      with vocab_writer:
        if six.PY2:
          vocab_writer.write(vocab_text)
        else:
          # The file is opened in binary mode, so encode on Python 3.
          vocab_writer.write(vocab_text.encode("utf-8"))
      tokenizer = tokenization.FullTokenizer(vocab_file)
    finally:
      os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

  def test_chinese(self):
    """CJK characters are split into individual tokens."""
    tokenizer = tokenization.BasicTokenizer()

    self.assertAllEqual(
        tokenizer.tokenize(u"ah\u535A\u63A8zz"),
        [u"ah", u"\u535A", u"\u63A8", u"zz"])

  def test_basic_tokenizer_lower(self):
    """Lower-casing also strips the accent from \u00E9."""
    tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello", "!", "how", "are", "you", "?"])
    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

  def test_basic_tokenizer_no_lower(self):
    """Without lower-casing, the original case is preserved."""
    tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["HeLLo", "!", "how", "Are", "yoU", "?"])

  def test_wordpiece_tokenizer(self):
    """Greedy word-piece splitting; unsplittable words map to [UNK]."""
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing"
    ]
    vocab = {token: i for (i, token) in enumerate(vocab_tokens)}
    tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

    self.assertAllEqual(tokenizer.tokenize(""), [])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running"),
        ["un", "##want", "##ed", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

  def test_convert_tokens_to_ids(self):
    """Tokens map to their index in the vocab."""
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing"
    ]
    vocab = {token: i for (i, token) in enumerate(vocab_tokens)}

    self.assertAllEqual(
        tokenization.convert_tokens_to_ids(
            vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])

  def test_is_whitespace(self):
    self.assertTrue(tokenization._is_whitespace(u" "))
    self.assertTrue(tokenization._is_whitespace(u"\t"))
    self.assertTrue(tokenization._is_whitespace(u"\r"))
    self.assertTrue(tokenization._is_whitespace(u"\n"))
    self.assertTrue(tokenization._is_whitespace(u"\u00A0"))

    self.assertFalse(tokenization._is_whitespace(u"A"))
    self.assertFalse(tokenization._is_whitespace(u"-"))

  def test_is_control(self):
    self.assertTrue(tokenization._is_control(u"\u0005"))

    self.assertFalse(tokenization._is_control(u"A"))
    self.assertFalse(tokenization._is_control(u" "))
    self.assertFalse(tokenization._is_control(u"\t"))
    self.assertFalse(tokenization._is_control(u"\r"))

  def test_is_punctuation(self):
    self.assertTrue(tokenization._is_punctuation(u"-"))
    self.assertTrue(tokenization._is_punctuation(u"$"))
    self.assertTrue(tokenization._is_punctuation(u"`"))
    self.assertTrue(tokenization._is_punctuation(u"."))

    self.assertFalse(tokenization._is_punctuation(u"A"))
    self.assertFalse(tokenization._is_punctuation(u" "))
133 |
134 |
if __name__ == "__main__":
  # Hand off to TensorFlow's test runner when executed as a script.
  tf.test.main()
137 |
--------------------------------------------------------------------------------
/bluebert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/bluebert/__init__.py
--------------------------------------------------------------------------------
/bluebert/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | # Additions: evaluate() accepts any iterable of lines (it does not read from a file itself), and report_notprint() returns the report as a list of strings instead of printing it.
13 |
14 | import sys
15 | import re
16 | import codecs
17 | from collections import defaultdict, namedtuple
18 |
# Delimiter sentinel: when options.delimiter equals this, evaluate() splits
# each line on runs of whitespace (str.split()) instead of a fixed string.
ANY_SPACE = ''


class FormatError(Exception):
    """Raised by evaluate() when an input line has an unexpected field count."""


# Chunk-level counts (true/false positives, false negatives) and the derived
# precision, recall and F-score, as returned by calculate_metrics().
Metrics = namedtuple('Metrics', ['tp', 'fp', 'fn', 'prec', 'rec', 'fscore'])
26 |
27 |
class EvalCounts(object):
    """Counters accumulated by evaluate() over an input stream."""

    def __init__(self):
        # Totals over all chunk types.
        self.correct_chunk = 0   # chunks matched in both boundaries and type
        self.correct_tags = 0    # tokens whose guessed tag equals the gold tag
        self.found_correct = 0   # chunks in the gold annotation
        self.found_guessed = 0   # chunks in the predictions
        self.token_counter = 0   # tokens seen, sentence-boundary lines excluded

        # The same chunk counts, keyed by chunk type.
        self.t_correct_chunk = defaultdict(int)
        self.t_found_correct = defaultdict(int)
        self.t_found_guessed = defaultdict(int)
40 |
41 |
def parse_args(argv):
    """Parses conlleval command-line options from the argument list `argv`."""
    import argparse

    parser = argparse.ArgumentParser(
        description='evaluate tagging results using CoNLL criteria',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument
    add('-b', '--boundary', metavar='STR', default='-X-',
        help='sentence boundary')
    add('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
        help='character delimiting items in input')
    add('-o', '--otag', metavar='CHAR', default='O',
        help='alternative outside tag')
    add('file', nargs='?', default=None)
    return parser.parse_args(argv)
57 |
58 |
def parse_tag(t):
    """Splits a tag at its first '-' into (prefix, type).

    For example 'B-PER' -> ('B', 'PER'); a tag without '-' such as 'O'
    yields (t, '').
    """
    match = re.match(r'^([^-]*)-(.*)$', t)
    if match is None:
        return (t, '')
    return match.groups()
62 |
63 |
def evaluate(iterable, options=None):
    """Accumulates CoNLL chunk statistics over the lines of `iterable`.

    Each non-empty line holds fields separated by whitespace (or by
    `options.delimiter`); the last two fields are the gold tag and the guessed
    tag, e.g. "word POS B-PER I-PER". Empty lines, and lines whose first field
    equals `options.boundary`, are sentence boundaries.

    Args:
        iterable: iterable of text lines, e.g. an open file.
        options: namespace from parse_args(); defaults are used when None.

    Returns:
        An EvalCounts holding the accumulated counts.

    Raises:
        FormatError: if a non-empty line has a different number of fields than
            the first non-empty line, or has fewer than 3 fields.
    """
    if options is None:
        options = parse_args([])  # use defaults

    counts = EvalCounts()
    num_features = None     # number of fields per non-empty line
    in_correct = False      # current chunk has matched so far
    last_correct = 'O'      # previous gold chunk tag
    last_correct_type = ''  # type of the previous gold chunk tag
    last_guessed = 'O'      # previous guessed chunk tag
    last_guessed_type = ''  # type of the previous guessed chunk tag

    for i, line in enumerate(iterable):
        line = line.rstrip('\r\n')

        if options.delimiter == ANY_SPACE:
            features = line.split()
        elif line:
            features = line.split(options.delimiter)
        else:
            # ''.split(delim) returns [''], which used to be rejected as a
            # 1-field line; treat a blank line as a sentence boundary, as the
            # whitespace mode (and conlleval.pl) do.
            features = []

        # Take the expected field count from the first NON-empty line. Before,
        # a leading blank line fixed it at 0 and every data line was rejected.
        if features:
            if num_features is None:
                num_features = len(features)
            elif num_features != len(features):
                raise FormatError('unexpected number of features: %d (%d) at line %d\n%s' %
                                  (len(features), num_features, i, line))

        if len(features) == 0 or features[0] == options.boundary:
            features = [options.boundary, 'O', 'O']
        if len(features) < 3:
            raise FormatError('unexpected number of features in line %s' % line)

        guessed, guessed_type = parse_tag(features.pop())
        correct, correct_type = parse_tag(features.pop())
        first_item = features.pop(0)

        if first_item == options.boundary:
            guessed = 'O'

        end_correct = end_of_chunk(last_correct, correct,
                                   last_correct_type, correct_type)
        end_guessed = end_of_chunk(last_guessed, guessed,
                                   last_guessed_type, guessed_type)
        start_correct = start_of_chunk(last_correct, correct,
                                       last_correct_type, correct_type)
        start_guessed = start_of_chunk(last_guessed, guessed,
                                       last_guessed_type, guessed_type)

        if in_correct:
            if (end_correct and end_guessed and
                    last_guessed_type == last_correct_type):
                # Both chunks closed together with the same type: a match.
                in_correct = False
                counts.correct_chunk += 1
                counts.t_correct_chunk[last_correct_type] += 1
            elif (end_correct != end_guessed or guessed_type != correct_type):
                in_correct = False

        if start_correct and start_guessed and guessed_type == correct_type:
            in_correct = True

        if start_correct:
            counts.found_correct += 1
            counts.t_found_correct[correct_type] += 1
        if start_guessed:
            counts.found_guessed += 1
            counts.t_found_guessed[guessed_type] += 1
        if first_item != options.boundary:
            if correct == guessed and guessed_type == correct_type:
                counts.correct_tags += 1
            counts.token_counter += 1

        last_guessed = guessed
        last_correct = correct
        last_guessed_type = guessed_type
        last_correct_type = correct_type

    # A chunk still open and matching at end of input counts as correct.
    if in_correct:
        counts.correct_chunk += 1
        counts.t_correct_chunk[last_correct_type] += 1

    return counts
145 |
146 |
147 |
def uniq(iterable):
    """Returns the items of `iterable` without duplicates, in first-seen order."""
    seen = set()
    unique_items = []
    for item in iterable:
        if item in seen:
            continue
        seen.add(item)
        unique_items.append(item)
    return unique_items
151 |
152 |
def calculate_metrics(correct, guessed, total):
    """Computes precision, recall and F1 from chunk counts.

    Args:
        correct: number of correctly identified chunks (true positives).
        guessed: number of chunks in the predictions.
        total: number of chunks in the gold annotation.

    Returns:
        Metrics(tp, fp, fn, prec, rec, fscore). A ratio whose denominator is
        zero is reported as 0.
    """
    tp = correct
    fp = guessed - correct
    fn = total - correct
    prec = 1. * tp / (tp + fp) if tp + fp != 0 else 0
    rec = 1. * tp / (tp + fn) if tp + fn != 0 else 0
    fscore = 2 * prec * rec / (prec + rec) if prec + rec != 0 else 0
    return Metrics(tp, fp, fn, prec, rec, fscore)
159 |
160 |
def metrics(counts):
    """Returns (overall Metrics, {chunk_type: Metrics}) for an EvalCounts."""
    overall = calculate_metrics(
        counts.correct_chunk, counts.found_guessed, counts.found_correct)

    # Every type that appears in either the gold data or the predictions.
    chunk_types = uniq(list(counts.t_found_correct) +
                       list(counts.t_found_guessed))
    by_type = {
        t: calculate_metrics(counts.t_correct_chunk[t],
                             counts.t_found_guessed[t],
                             counts.t_found_correct[t])
        for t in chunk_types
    }
    return overall, by_type
172 |
173 |
def report(counts, out=None):
    """Writes a conlleval-style summary of `counts` to `out` (default stdout)."""
    out = sys.stdout if out is None else out

    overall, by_type = metrics(counts)
    c = counts

    out.write('processed %d tokens with %d phrases; ' %
              (c.token_counter, c.found_correct))
    out.write('found: %d phrases; correct: %d.\n' %
              (c.found_guessed, c.correct_chunk))

    if c.token_counter > 0:
        out.write('accuracy: %6.2f%%; ' %
                  (100. * c.correct_tags / c.token_counter))
        out.write('precision: %6.2f%%; ' % (100. * overall.prec))
        out.write('recall: %6.2f%%; ' % (100. * overall.rec))
        out.write('FB1: %6.2f\n' % (100. * overall.fscore))

    # One line per chunk type, sorted by type name.
    for chunk_type, m in sorted(by_type.items()):
        out.write('%17s: ' % chunk_type)
        out.write('precision: %6.2f%%; ' % (100. * m.prec))
        out.write('recall: %6.2f%%; ' % (100. * m.rec))
        out.write('FB1: %6.2f  %d\n' % (100. * m.fscore,
                                        c.t_found_guessed[chunk_type]))
198 |
199 |
def report_notprint(counts, out=None):
    """Returns the report() summary as a list of lines instead of writing it.

    `out` is accepted for signature compatibility with report() and is not
    used.
    """
    overall, by_type = metrics(counts)
    c = counts

    final_report = [''.join([
        'processed %d tokens with %d phrases; ' %
        (c.token_counter, c.found_correct),
        'found: %d phrases; correct: %d.\n' %
        (c.found_guessed, c.correct_chunk),
    ])]

    if c.token_counter > 0:
        final_report.append(''.join([
            'accuracy: %6.2f%%; ' % (100. * c.correct_tags / c.token_counter),
            'precision: %6.2f%%; ' % (100. * overall.prec),
            'recall: %6.2f%%; ' % (100. * overall.rec),
            'FB1: %6.2f\n' % (100. * overall.fscore),
        ]))

    # One line per chunk type, sorted by type name.
    for chunk_type, m in sorted(by_type.items()):
        final_report.append(''.join([
            '%17s: ' % chunk_type,
            'precision: %6.2f%%; ' % (100. * m.prec),
            'recall: %6.2f%%; ' % (100. * m.rec),
            'FB1: %6.2f  %d\n' % (100. * m.fscore,
                                  c.t_found_guessed[chunk_type]),
        ]))
    return final_report
232 |
233 |
def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Returns True if a chunk ended between the previous and current word.

    Args:
        prev_tag, tag: chunk tags (e.g. 'B', 'I', 'E', 'S', 'O') of the
            previous and current word.
        prev_type, type_: chunk types of the previous and current word.
    """
    # E and S close their chunk; '[' and ']' chunks are assumed length 1.
    if prev_tag in ('E', 'S', '[', ']'):
        return True
    # An open B/I chunk closes when a new chunk begins or we leave chunks.
    if prev_tag in ('B', 'I') and tag in ('B', 'S', 'O'):
        return True
    # A type change closes any chunk other than O / '.'.
    return prev_tag not in ('O', '.') and prev_type != type_
257 |
258 |
def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk started between the previous and the current word.

    Args:
        prev_tag: chunk tag of the previous word ('B', 'I', 'E', 'S', 'O',
            '.', '[' or ']').
        tag: chunk tag of the current word.
        prev_type: chunk type of the previous word.
        type_: chunk type of the current word.

    Returns:
        bool: True when the current word opens a chunk.
    """
    # B and S always open a chunk; '[' and ']' mark length-1 chunks.
    if tag in ('B', 'S', '[', ']'):
        return True
    # An I or E right after E, S or O (no chunk left open) starts a new chunk.
    if prev_tag in ('E', 'S', 'O') and tag in ('E', 'I'):
        return True
    # A change of type starts a new chunk unless the current word is outside.
    if tag not in ('O', '.') and prev_type != type_:
        return True
    return False
282 |
283 |
def main(argv):
    """Command-line entry point: evaluate a file (or stdin) and print the report.

    Args:
        argv: full argument vector including the program name; ``argv[0]`` is
            skipped before parsing.
    """
    args = parse_args(argv[1:])

    if args.file is not None:
        with open(args.file) as stream:
            counts = evaluate(stream, args)
    else:
        # No file given: read the tagged data from standard input.
        counts = evaluate(sys.stdin, args)

    report(counts)
293 |
294 |
def return_report(input_file):
    """Evaluate ``input_file`` and return the report lines instead of printing them.

    Args:
        input_file: path to the file to evaluate; it is opened in text mode
            and passed to ``evaluate`` with default options.

    Returns:
        list of str: the lines produced by ``report_notprint``.
    """
    with open(input_file, mode="r") as handle:
        chunk_counts = evaluate(handle)
    return report_notprint(chunk_counts)
299 |
if __name__ == '__main__':
    # Command-line use: parse arguments, evaluate the given file (or stdin)
    # and print the report.
    sys.exit(main(sys.argv))
--------------------------------------------------------------------------------
/bluebert/tf_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Multiclass
3 | from:
4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py
5 |
6 | """
7 |
8 | __author__ = "Guillaume Genthial"
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix
13 |
14 |
def precision(labels, predictions, num_classes, pos_indices=None,
              weights=None, average='micro'):
    """Streaming multi-class precision metric for TensorFlow.

    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The gold labels.
    predictions : Tensor of tf.int32 or tf.int64
        The predicted labels, same shape as ``labels``.
    num_classes : int
        The number of classes.
    pos_indices : list of int, optional
        Classes treated as positive; defaults to all classes.
    weights : Tensor, optional
        Mask/weights compatible in shape with ``labels``; forwarded to
        ``_streaming_confusion_matrix``.
    average : str, optional
        'micro' pools true/false positives over ``pos_indices``; 'macro'
        takes the unweighted mean of per-class scores; 'weighted' weights
        per-class scores by their number of gold labels.

    Returns
    -------
    tuple of (scalar float Tensor, update_op)
        Precision computed from the accumulated confusion matrix, and the
        same metric computed from the confusion matrix's update op
        (evaluating it is expected to update the streaming accumulator).
    """
    total_cm, update_cm = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    value = metrics_from_confusion_matrix(
        total_cm, pos_indices, average=average)[0]
    update_op = metrics_from_confusion_matrix(
        update_cm, pos_indices, average=average)[0]
    return value, update_op
51 |
52 |
def recall(labels, predictions, num_classes, pos_indices=None, weights=None,
           average='micro'):
    """Streaming multi-class recall metric for TensorFlow.

    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The gold labels.
    predictions : Tensor of tf.int32 or tf.int64
        The predicted labels, same shape as ``labels``.
    num_classes : int
        The number of classes.
    pos_indices : list of int, optional
        Classes treated as positive; defaults to all classes.
    weights : Tensor, optional
        Mask/weights compatible in shape with ``labels``; forwarded to
        ``_streaming_confusion_matrix``.
    average : str, optional
        'micro' pools true positives / false negatives over ``pos_indices``;
        'macro' takes the unweighted mean of per-class scores; 'weighted'
        weights per-class scores by their number of gold labels.

    Returns
    -------
    tuple of (scalar float Tensor, update_op)
        Recall computed from the accumulated confusion matrix, and the same
        metric computed from the confusion matrix's update op (evaluating it
        is expected to update the streaming accumulator).
    """
    total_cm, update_cm = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    value = metrics_from_confusion_matrix(
        total_cm, pos_indices, average=average)[1]
    update_op = metrics_from_confusion_matrix(
        update_cm, pos_indices, average=average)[1]
    return value, update_op
89 |
90 |
def f1(labels, predictions, num_classes, pos_indices=None, weights=None,
       average='micro'):
    """Streaming multi-class F1 metric: ``fbeta`` with its default ``beta=1``.

    Takes the same arguments as ``fbeta`` except ``beta`` and returns its
    ``(scalar float Tensor, update_op)`` tuple unchanged.
    """
    return fbeta(labels, predictions, num_classes,
                 pos_indices=pos_indices, weights=weights, average=average)
95 |
96 |
def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
          average='micro', beta=1):
    """Streaming multi-class F-beta metric for TensorFlow.

    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The gold labels.
    predictions : Tensor of tf.int32 or tf.int64
        The predicted labels, same shape as ``labels``.
    num_classes : int
        The number of classes.
    pos_indices : list of int, optional
        Classes treated as positive; defaults to all classes.
    weights : Tensor, optional
        Mask/weights compatible in shape with ``labels``; forwarded to
        ``_streaming_confusion_matrix``.
    average : str, optional
        'micro', 'macro' or 'weighted' (see ``precision``).
    beta : int or float, optional
        F-beta trade-off: F = (1 + beta**2) * P * R / (beta**2 * P + R);
        beta > 1 favours recall, beta < 1 favours precision.

    Returns
    -------
    tuple of (scalar float Tensor, update_op)
        F-beta computed from the accumulated confusion matrix, and the same
        metric computed from the confusion matrix's update op.
    """
    total_cm, update_cm = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    # Named `score` rather than `fbeta` so the function name is not shadowed.
    score = metrics_from_confusion_matrix(
        total_cm, pos_indices, average=average, beta=beta)[2]
    update_op = metrics_from_confusion_matrix(
        update_cm, pos_indices, average=average, beta=beta)[2]
    return score, update_op
135 |
136 |
def safe_div(numerator, denominator):
    """Return ``numerator / denominator`` as float32, with 0 wherever the denominator is 0.

    Both inputs are cast to float32. The zero tensor is shaped like the
    numerator, so the two inputs are expected to have matching shapes
    (callers in this module pass scalars).
    """
    num = tf.cast(numerator, tf.float32)
    den = tf.cast(denominator, tf.float32)
    zero = tf.zeros_like(num, dtype=num.dtype)
    # The division is still evaluated everywhere; tf.where just discards it
    # at positions where the denominator is zero.
    return tf.where(tf.equal(den, zero), zero, num / den)
143 |
144 |
def pr_re_fbeta(cm, pos_indices, beta=1):
    """Compute pooled precision, recall and F-beta over ``pos_indices``.

    Parameters
    ----------
    cm : tf.Tensor of shape (num_classes, num_classes)
        Confusion matrix; the masks below treat rows as gold classes and
        columns as predicted classes.
    pos_indices : list of int
        Classes counted as positive. Other classes contribute no true
        positives, and their predicted columns / gold rows are excluded from
        the denominators.
    beta : int or float, optional
        F-beta trade-off (beta > 1 favours recall).

    Returns
    -------
    tuple of (precision, recall, fbeta) scalar float Tensors, each 0 when its
    denominator is 0.
    """
    n = cm.shape[0]
    negatives = [k for k in range(n) if k not in pos_indices]

    # Zero the diagonal cells of negative classes; only the diagonal is read.
    tp_mask = np.ones([n, n])
    tp_mask[negatives, negatives] = 0
    # Keep only columns (predicted class) of positive classes.
    pred_mask = np.ones([n, n])
    pred_mask[:, negatives] = 0
    # Keep only rows (gold class) of positive classes.
    gold_mask = np.ones([n, n])
    gold_mask[negatives, :] = 0

    true_pos = tf.reduce_sum(tf.diag_part(cm * tp_mask))
    n_pred = tf.reduce_sum(cm * pred_mask)
    n_gold = tf.reduce_sum(cm * gold_mask)

    prec = safe_div(true_pos, n_pred)
    rec = safe_div(true_pos, n_gold)
    score = safe_div((1. + beta**2) * prec * rec, beta**2 * prec + rec)
    return prec, rec, score
166 |
167 |
def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro',
                                  beta=1):
    """Precision, recall and F-beta from a confusion matrix.

    Parameters
    ----------
    cm : tf.Tensor of shape (num_classes, num_classes)
        The (streaming) confusion matrix, rows = gold, columns = predicted.
    pos_indices : list of int, optional
        Classes treated as positive; defaults to all classes.
    average : str, optional
        'micro': pool counts over ``pos_indices`` (see ``pr_re_fbeta``).
        'macro': unweighted mean of the per-class metrics.
        'weighted': per-class metrics weighted by each class's gold count.
    beta : int or float, optional
        F-beta trade-off (beta > 1 favours recall).

    Returns
    -------
    tuple of (precision, recall, fbeta) scalar Tensors.

    Raises
    ------
    NotImplementedError
        If ``average`` is not one of the three supported modes.
    """
    num_classes = cm.shape[0]
    if pos_indices is None:
        pos_indices = list(range(num_classes))

    if average == 'micro':
        return pr_re_fbeta(cm, pos_indices, beta)
    if average not in {'macro', 'weighted'}:
        raise NotImplementedError()

    precisions, recalls, fbetas, n_golds = [], [], [], []
    for idx in pos_indices:
        prec, rec, score = pr_re_fbeta(cm, [idx], beta)
        precisions.append(prec)
        recalls.append(rec)
        fbetas.append(score)
        # Gold count of class `idx` = sum of its row.
        row_mask = np.zeros([num_classes, num_classes])
        row_mask[idx, :] = 1
        n_golds.append(tf.cast(tf.reduce_sum(cm * row_mask), tf.float32))

    if average == 'macro':
        return (tf.reduce_mean(precisions), tf.reduce_mean(recalls),
                tf.reduce_mean(fbetas))

    # 'weighted'
    total_gold = tf.reduce_sum(n_golds)

    def weighted(values):
        # Gold-count-weighted mean of per-class values; 0 if no gold labels.
        return safe_div(sum(v * n for v, n in zip(values, n_golds)), total_gold)

    return weighted(precisions), weighted(recalls), weighted(fbetas)
216 |
--------------------------------------------------------------------------------
/elmo/README.md:
--------------------------------------------------------------------------------
1 | # ELMo Fine-tuning
2 | The python script `elmoft.py` provides utility functions for fine-tuning the [ELMo model](https://allennlp.org/elmo). We used the model pre-trained on PubMed in our paper [Transfer Learning in Biomedical Natural Language Processing: An Evaluation of BERT and ELMo on Ten Benchmarking Datasets](https://arxiv.org/abs/1906.05474).
3 |
4 | ## Pre-trained models and benchmark datasets
5 | Please prepare the pre-trained ELMo model files `options.json` and `weights.hdf5`, and pass their locations with the `--options_path` and `--weights_path` parameters when running the script. The model pre-trained on PubMed can be downloaded from the [ELMo website](https://allennlp.org/elmo).
6 |
7 | The benchmark datasets can be downloaded from [https://github.com/ncbi-nlp/BLUE_Benchmark](https://github.com/ncbi-nlp/BLUE_Benchmark)
8 |
9 | ## Fine-tuning ELMo
10 | We assume the ELMo model has been downloaded at `$ELMO_DIR`, and the dataset has been downloaded at `$DATASET_DIR`.
11 |
12 | ### Sentence similarity
13 |
14 | ```bash
15 | python elmoft.py \
16 | --task 'clnclsts' \
17 | --seq2vec 'boe' \
18 | --options_path $ELMO_DIR/options.json \
19 | --weights_path $ELMO_DIR/weights.hdf5 \
20 | --maxlen 256 \
21 | --fchdim 500 \
22 | --lr 0.001 \
23 | --pdrop 0.5 \
24 | --do_norm --norm_type batch \
25 | --do_lastdrop \
26 | --initln \
27 | --earlystop \
28 | --epochs 20 \
29 | --bsize 64 \
30 | --data_dir=$DATASET_DIR
31 | ```
32 |
33 | The task can be
34 |
35 | - `clnclsts`: Mayo Clinic clinical sentence similarity task
36 | - `biosses`: Biomedical Summarization Track sentence similarity task
37 |
38 |
39 | ### Named Entity Recognition
40 |
41 | ```bash
42 | python elmoft.py \
43 | --task 'bc5cdr-chem' \
44 | --seq2vec 'boe' \
45 | --options_path $ELMO_DIR/options.json \
46 | --weights_path $ELMO_DIR/weights.hdf5 \
47 | --maxlen 128 \
48 | --fchdim 500 \
49 | --lr 0.001 \
50 | --pdrop 0.5 \
51 | --do_norm --norm_type batch \
52 | --do_lastdrop \
53 | --initln \
54 | --earlystop \
55 | --epochs 20 \
56 | --bsize 64 \
57 | --data_dir=$DATASET_DIR
58 | ```
59 |
60 | The task can be
61 |
62 | - `bc5cdr-chem`: BC5CDR chemical task
63 | - `bc5cdr-dz`: BC5CDR disease task
64 | - `shareclefe`: ShARe/CLEFE task
65 |
66 | ### Relation Extraction
67 |
68 | ```bash
69 | python elmoft.py \
70 | --task 'ddi' \
71 | --seq2vec 'boe' \
72 | --options_path $ELMO_DIR/options.json \
73 | --weights_path $ELMO_DIR/weights.hdf5 \
74 | --maxlen 128 \
75 | --fchdim 500 \
76 | --lr 0.001 \
77 | --pdrop 0.5 \
78 | --do_norm --norm_type batch \
79 | --initln \
80 | --earlystop \
81 | --epochs 20 \
82 | --bsize 64 \
83 | --data_dir=$DATASET_DIR
84 | ```
85 |
86 | The task name can be
87 |
88 | - `ddi`: DDI 2013 task
89 | - `chemprot`: BC6 ChemProt task
90 | - `i2b2`: I2B2 relation extraction task
91 |
92 | ### Document multilabel classification
93 |
94 | ```bash
95 | python elmoft.py \
96 | --task 'hoc' \
97 | --seq2vec 'boe' \
98 | --options_path $ELMO_DIR/options.json \
99 | --weights_path $ELMO_DIR/weights.hdf5 \
100 | --maxlen 128 \
101 | --fchdim 500 \
102 | --lr 0.001 \
103 | --pdrop 0.5 \
104 | --do_norm --norm_type batch \
105 | --initln \
106 | --earlystop \
107 | --epochs 50 \
108 | --bsize 64 \
109 | --data_dir=$DATASET_DIR
110 | ```
111 |
112 | ### Inference task
113 |
114 | ```bash
115 | python elmoft.py \
116 | --task 'mednli' \
117 | --seq2vec 'boe' \
118 | --options_path $ELMO_DIR/options.json \
119 | --weights_path $ELMO_DIR/weights.hdf5 \
120 | --maxlen 128 \
121 | --fchdim 500 \
122 | --lr 0.0005 \
123 | --pdrop 0.5 \
124 | --do_norm --norm_type batch \
125 | --initln \
126 | --earlystop \
127 | --epochs 20 \
128 | --bsize 64 \
129 | --data_dir=$DATASET_DIR
130 | ```
131 |
132 | ## GPU Acceleration
133 | If there are no GPU devices, indicate this with the parameter `-g 0` or `--gpunum 0`. Otherwise, specify the index (starting from 0) of the GPU you want to use with the parameter `-q 0` or `--gpuq 0`.
134 |
--------------------------------------------------------------------------------
/mribert/README.md:
--------------------------------------------------------------------------------
1 | ## Automatic recognition of abdominal lymph nodes from clinical text
2 |
3 | This repository provides the code and models of the BERT model for lymph node detection from MRI reports.
4 |
5 | ## Pre-trained models
6 |
7 | The pre-trained model weights, vocab, and config files can be downloaded from:
8 |
9 | * [mribert](https://github.com/ncbi-nlp/bluebert/releases/tag/lymphnode)
10 |
11 | ## Fine-tuning BERT
12 |
13 | We assume the MriBERT model has been downloaded at `$MriBERT_DIR`.
14 |
15 | ```bash
16 | tstr=$(date +"%FT%H%M%S%N")
17 | text_col="text"
18 | label_col="label"
19 | batch_size=32
20 | train_dataset="train,dev"
21 | val_dataset="dev"
22 | test_dataset="test"
23 | epochs=10
24 |
25 | bert_dir=$MriBERT_DIR
26 | dataset="$MriBERT_DIR/total_data.csv"
27 | model_dir="$MriBERT_DIR/mri_${tstr}"
28 | test_predictions="predictions_mribert.csv"
29 |
30 | # predict new
31 | pred_dataset="$MriBERT_DIR/new_data.csv"
32 | pred_predictions="new_data_predictions.csv"
33 |
34 | export PYTHONPATH=.:$PYTHONPATH
35 | python sequence_classification.py \
36 | --do_train \
37 | --do_test \
38 | --dataset "${dataset}" \
39 | --output_dir "${model_dir}" \
40 | --vocab_file $bert_dir/vocab.txt \
41 | --bert_config_file $bert_dir/bert_config.json \
42 | --init_checkpoint $bert_dir/mribert_model.ckpt \
43 | --text_col "${text_col}" \
44 | --label_col "${label_col}" \
45 | --batch_size "${batch_size}" \
46 | --train_dataset "${train_dataset}" \
47 | --val_dataset "${val_dataset}" \
48 | --test_dataset "${test_dataset}" \
49 | --pred_dataset "${pred_dataset}" \
50 | --epochs ${epochs} \
51 | --test_predictions ${test_predictions} \
52 | --pred_predictions ${pred_predictions}
53 | ```
54 |
55 |
56 | ## Citing MriBert
57 |
58 | Peng Y, Lee S, Elton D, Shen T, Tang YX, Chen Q, Wang S, Zhu Y, Summers RM, Lu Z.
59 | Automatic recognition of abdominal lymph nodes from clinical text.
60 | In Proceedings of the ClinicalNLP Workshop. 2020.
61 |
62 | ## Acknowledgments
63 |
64 | This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of
65 | Medicine and Clinical Center.
66 | This work was supported by the National Library of Medicine of the National Institutes of Health under award number 4R00LM013001.
67 |
68 | ## Disclaimer
69 |
70 | This tool shows the results of research conducted in the Computational Biology Branch, NLM/NCBI. The information produced
71 | on this website is not intended for direct diagnostic use or medical decision-making without review and oversight
72 | by a clinical professional. Individuals should not change their health behavior solely on the basis of information
73 | produced on this website. NIH does not independently verify the validity or utility of the information produced
74 | by this tool. If you have questions about the information produced on this website, please see a health care
75 | professional. More information about NLM/NCBI's disclaimer policy is available.
76 |
--------------------------------------------------------------------------------
/mribert/lymph_node_vocab.yml:
--------------------------------------------------------------------------------
1 | Mediastinal lymph node:
2 | include:
3 | - mediastinum
4 | - mediastinal
5 | Subcarinal lymph node:
6 | include:
7 | - subcarinal
8 | - tracheobronchial
9 | Cardiophrenic lymph node:
10 | include:
11 | - cardiophrenic
12 | - cardiophrenic angle
13 | - pericardiac
14 | Paraesophageal lymph node:
15 | include:
16 | - esophageal
17 | - esophagus
18 | - paraesophageal
19 | - 8 thoracic lymph node
20 | #-------------------------------------------------------------------------------
21 | Retroperitoneal lymph node:
22 | include:
23 | - retroperitoneum
24 | - retroperitoneal
25 | - retro peritoneum
26 | - right paraspinal
27 | - adrenal gland
28 | - adrenal
29 | - nephrectomy bed
30 | - nephrectomy resection bed
31 | - nephrectomy surgical bed
32 | - peripelvic
33 | Retrocrural lymph node:
34 | include:
35 | - retrocrural
36 | Para-aortic lymph node:
37 | include:
38 | - periaortic
39 | - peri-aortic
40 | - paraaortic
41 | - para-aortic
42 | - lateral aortic
43 | - infrarenal abdominal aorta
44 | - left to the aorta
45 | - retroaortic
46 | - retro-aortic
47 | Interaortocaval lymph node:
48 | include:
49 | - aortocaval
50 | - interaortocaval
51 | Retrocaval lymph node:
52 | include:
53 | - retrocaval
54 | - postcaval
55 | - posterior to the intrahepatic IVC
56 | Preaortic lymph node:
57 | include:
58 | - preaotic
59 | - preaortic
60 | - anterior to the abdominal aorta
61 | - anterior to the aorta
62 | Paracaval lymph node:
63 | include:
64 | - paracaval
65 | - pericaval
66 | - caval
67 | - vena cava
68 | - cava
69 | - paracardial
70 | - aJCC level 16
71 | - paracardial gastric
72 | - lateral to the IVC
73 | - margin of the intrahepatic IVC
74 | Precaval lymph node: # anterior to the vena cava
75 | include:
76 | - precaval
77 | Paraspinal lymph node:
78 | include:
79 | - paraspinal
80 | #-------------------------------------------------------------------------------
81 | Peritoneal lymph node:
82 | include:
83 | - peritoneal
84 | - peritoneum
85 | Subdiaphragmatic lymph node:
86 | include:
87 | - diaphragmatic
88 | - peri-diaphragmatic
89 | - subdiaphragmatic
90 | Perihepatic lymph node:
91 | include:
92 | - perihepatic
93 | Paraduodenal lymph node:
94 | include:
95 | - paraduodenal
96 | - portion of the duodenum
97 | Hepatic artery lymph node:
98 | include:
99 | - hepatic artery
100 | - hepatic arterial
101 | Periportal lymph node:
102 | include:
103 | - liver hilum
104 | - hepatoportal
105 | - porta hepatis
106 | - periportal
107 | - portal vein
108 | Peripancreatic lymph node:
109 | include:
110 | - pancreatic
111 | - pancreas
112 | - peripancreatic
113 | Portocaval lymph node:
114 | include:
115 | - interportocaval
116 | - portocaval
117 | - portacaval
118 | Perigastric lymph node:
119 | include:
120 | - perigastric
121 | Splenic lymph node:
122 | include:
123 | - spleen
124 | - splenule
125 | - splenic
126 | Celiac lymph node:
127 | include:
128 | - celiac
129 | - paraceliac
130 | Superior mesenteric lymph node:
131 | include:
132 | - SMA
133 | - at the level of the SMA
134 | - superior mesenteric
135 | Mesenteric lymph node:
136 | include:
137 | - mesentery
138 | - mesenteric
139 | - pelvis mesentery
140 | Perigastric lymph node along lesser curvature:
141 | include:
142 | - lesser curve
143 | - lesser curvature
144 | Perigastric lymph node along greater curvature:
145 | include:
146 | - greater curve
147 | - greater curvature
148 | Gastrosplenic lymph node:
149 | include:
150 | - Gastrosplenic
151 | Gastrohepatic ligament lymph node:
152 | include:
153 | - gastrohepatic
154 | - gastrohepatic ligament
155 | - gastric
156 | Hepatoduodenal ligament lymph node:
157 | include:
158 | - hepatoduodenal
159 | - hepatoduodenal ligament
160 | Paracolic lymph node:
161 | include:
162 | - paracolic
163 | Pericecal lymph node:
164 | include:
165 | - pericecal
166 | - adjacent to the cecum
167 | - ileocolic
168 | Periportal/peripancreatic lymph node:
169 | include:
170 | - periportal/peripancreatic
171 | #-------------------------------------------------------------------------------
172 | Pelvic lymph node:
173 | include:
174 | - pelvic
175 | - pelvis
176 | Common iliac lymph node:
177 | include:
178 | - common iliac
179 | - iliac
180 | - ilium
181 | External iliac lymph node:
182 | include:
183 | - external iliac
184 | Psoas lymph node:
185 | include:
186 | - psoas
187 | - psoas muscle
188 | Presacral lymph node:
189 | include:
190 | - presacral
191 | Perivesicular lymph node:
192 | include:
193 | - Perivesicular
194 | #-------------------------------------------------------------------------------
195 | Inguinal lymph node:
196 | include:
197 | - groin
198 | - inguinal
199 | - thoracic
200 |
--------------------------------------------------------------------------------
/mt-bluebert/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | mtdnn_env
7 | venv_windows
8 | mtdnn_env_apex
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | .DS_Store
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | # IDE pycharm
112 | .idea/
113 |
114 | log/
115 | model/
116 | submission/
117 | save/
118 | book_corpus_test/
119 | book_corpus_train/
120 | checkpoints/
121 | .pt_description_history
122 | .git-credentials
123 | pt_bert/philly
124 | .vs
125 | *.pyproj
126 | pt_bert/checkpoint
127 | */aml_experiments
128 | screenlog.*
129 | data
130 | pt_bert/scripts
131 | pt_bert/model_data
132 | screen*
133 | checkpoint
134 | *.sln
135 | dt_mtl
136 | philly
137 | bert_models
138 | run_baseline*
139 | mt_dnn_models
140 | *pyc
141 | run_test/
142 | experiments/superglue
143 |
--------------------------------------------------------------------------------
/mt-bluebert/LICENSE:
--------------------------------------------------------------------------------
1 | PUBLIC DOMAIN NOTICE
2 | National Center for Biotechnology Information
3 |
4 | This software/database is a "United States Government Work" under the terms of
5 | the United States Copyright Act. It was written as part of the author's
6 | official duties as a United States Government employee and thus cannot be
7 | copyrighted. This software/database is freely available to the public for use.
8 | The National Library of Medicine and the U.S. Government have not placed any
9 | restriction on its use or reproduction.
10 |
11 | Although all reasonable efforts have been taken to ensure the accuracy and
12 | reliability of the software and data, the NLM and the U.S. Government do not and
13 | cannot warrant the performance or results that may be obtained by using this
14 | software or data. The NLM and the U.S. Government disclaim all warranties,
15 | express or implied, including warranties of performance, merchantability or
16 | fitness for any particular purpose.
17 |
18 | Please cite the author in any work or product based on this material:
19 |
20 | Peng Y, Chen Q, Lu Z. An Empirical Study of Multi-Task Learning on BERT
21 | for Biomedical Text Mining. In Proceedings of the 2020 Workshop on Biomedical
22 | Natural Language Processing (BioNLP 2020). 2020.
--------------------------------------------------------------------------------
/mt-bluebert/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Task Learning on BERT for Biomedical Text Mining
2 |
3 | This repository provides the code and models for Multi-Task Learning on BERT for Biomedical Text Mining.
4 | The package is based on [`mt-dnn`](https://github.com/namisan/mt-dnn).
5 |
6 | ## Pre-trained models
7 |
8 | The pre-trained MT-BlueBERT weights, vocab, and config files can be downloaded from:
9 |
10 | * [mt-bluebert-biomedical](https://github.com/yfpeng/mt-bluebert/releases/download/0.1/mt-bluebert-biomedical.pt)
11 | * [mt-bluebert-clinical](https://github.com/yfpeng/mt-bluebert/releases/download/0.1/mt-bluebert-clinical.pt)
12 |
13 | The benchmark datasets can be downloaded from [https://github.com/ncbi-nlp/BLUE_Benchmark](https://github.com/ncbi-nlp/BLUE_Benchmark)
14 |
15 | ## Quick start
16 |
17 | ### Setup Environment
18 | 1. python3.6
19 | 2. install requirements
20 | ```bash
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ### Download data
25 | Please download the BLUE benchmark from: https://github.com/ncbi-nlp/BLUE_Benchmark
26 |
27 |
28 | ### Preprocess data
29 | ```bash
30 | bash scripts/blue_prepro.sh
31 | ```
32 |
33 | ### Train an MT-DNN model
34 | ```bash
35 | bash scripts/run_blue_mt_dnn.sh
36 | ```
37 |
38 | ### Fine-tune a model
39 | ```bash
40 | bash scripts/run_blue_fine_tune.sh
41 | ```
42 |
43 | ### Convert Tensorflow BERT model to the MT-DNN format
44 | ```bash
45 | python scripts/convert_tf_to_pt.py --tf_checkpoint_root $SRC_ROOT --pytorch_checkpoint_path $DEST --encoder_type 1
46 | ```
47 |
48 | ## Citing MT-BLUE
49 |
50 | Peng Y, Chen Q, Lu Z. An Empirical Study of Multi-Task Learning on BERT
51 | for Biomedical Text Mining. In Proceedings of the 2020 Workshop on Biomedical
52 | Natural Language Processing (BioNLP 2020). 2020.
53 |
54 | ```
55 | @InProceedings{peng2019transfer,
56 | author = {Yifan Peng and Qingyu Chen and Zhiyong Lu},
57 | title = {An Empirical Study of Multi-Task Learning on BERT for Biomedical Text Mining},
58 | booktitle = {Proceedings of the 2020 Workshop on Biomedical Natural Language Processing (BioNLP 2020)},
59 | year = {2020},
60 | }
61 | ```
62 |
63 | ## Acknowledgments
64 |
65 | This work was supported by the Intramural Research Programs of the National Institutes of Health, National Library of
66 | Medicine. This work was supported by the National Library of Medicine of the National Institutes of Health under award number K99LM013001-01.
67 |
68 | We are also grateful to the authors of BERT and mt-dnn to make the data and codes publicly available.
69 |
70 | ## Disclaimer
71 |
72 | This tool shows the results of research conducted in the Computational Biology Branch, NLM/NCBI. The information produced
73 | on this website is not intended for direct diagnostic use or medical decision-making without review and oversight
74 | by a clinical professional. Individuals should not change their health behavior solely on the basis of information
75 | produced on this website. NIH does not independently verify the validity or utility of the information produced
76 | by this tool. If you have questions about the information produced on this website, please see a health care
77 | professional. More information about NLM/NCBI's disclaimer policy is available.
78 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/mt-bluebert/mt_bluebert/__init__.py
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | blue_eval [options] --task_def= --data_dir= --range= --model_dir=
4 |
5 | Options:
6 | --task_def=
7 | --data_dir=
8 | --test_datasets=
9 | --range=
10 | --output=
11 | """
12 | import collections
13 | import functools
14 | import json
15 | from pathlib import Path
16 | from typing import Dict
17 |
18 | import docopt as docopt
19 | import numpy as np
20 | import pandas as pd
21 | import yaml
22 |
23 | from mt_bluebert.blue_metrics import compute_micro_f1, compute_micro_f1_subindex, compute_seq_f1, compute_pearson
24 |
25 |
class TaskMetric:
    """Per-epoch scores of one task, with a helper to report the best epoch.

    Attributes:
        task: task name, used as the prefix of printed lines.
        epochs: epoch identifiers, index-aligned with ``scores``
            (presumably appended to by the caller).
        scores: metric values, index-aligned with ``epochs``.
    """

    def __init__(self, task):
        self.task = task
        self.epochs = []
        self.scores = []

    def print_max(self):
        """Print the epoch with the highest score (the first one on ties).

        When no score has been recorded, prints a notice instead of letting
        ``np.argmax`` raise on an empty sequence.
        """
        if not self.scores:
            print('%s: no scores recorded' % self.task)
            return
        index = int(np.argmax(self.scores))
        print('%s: Max at epoch %d: %.3f' % (self.task, self.epochs[index], self.scores[index]))
35 |
36 |
def get_score1(pred_file, task, n_class, rows, golds, metric_func):
    """Score a single-JSON-object prediction file against gold labels.

    The file holds one JSON object with a ``uids`` list and index-aligned
    ``scores`` (read when ``n_class == 1``) or ``predictions`` (read when
    ``n_class > 1``) lists, in the same order as ``rows``.

    Args:
        pred_file: path to the JSON prediction file.
        task: task name, used in error messages.
        n_class: number of output classes for the task.
        rows: gold rows (dicts with at least a ``'uid'`` key), in file order.
        golds: gold labels, passed through to ``metric_func``.
        metric_func: callable ``metric_func(preds, golds)`` returning the score.

    Returns:
        Whatever ``metric_func(preds, golds)`` returns.

    Raises:
        ValueError: if the number of predictions differs from ``len(rows)``,
            or a uid does not match the corresponding row's uid.
        KeyError: if ``n_class`` is less than 1.
    """
    with open(pred_file, 'r', encoding='utf8') as fp:
        obj = json.load(fp)

    uids = obj['uids']
    # Without this check, extra gold rows would silently leave preds shorter
    # than golds when metric_func is called.
    if len(uids) != len(rows):
        raise ValueError('{}: {} predictions vs {} gold rows'.format(task, len(uids), len(rows)))

    if n_class == 1:
        values = obj['scores']
    elif n_class > 1:
        values = obj['predictions']
    else:
        raise KeyError(task)

    preds = []
    for i, uid in enumerate(uids):
        if uid != rows[i]['uid']:
            raise ValueError('{}: {} vs {}'.format(task, uid, rows[i]['uid']))
        preds.append(values[i])

    return metric_func(preds, golds)
55 |
56 |
def get_score2(pred_file, task, n_class, rows, golds, metric_func):
    """Score a JSON-lines prediction file against gold labels.

    Each non-blank line of ``pred_file`` is a JSON object with a ``uid`` and
    a ``score`` (read when ``n_class == 1``) or a ``prediction`` (read when
    ``n_class > 1``). Lines must be in the same order as ``rows``.

    Args:
        pred_file: path to the JSON-lines prediction file.
        task: task name, used in error messages.
        n_class: number of output classes for the task.
        rows: gold rows (dicts with at least a ``'uid'`` key), in file order.
        golds: gold labels, passed through to ``metric_func``.
        metric_func: callable ``metric_func(preds, golds)`` returning the score.

    Returns:
        Whatever ``metric_func(preds, golds)`` returns.

    Raises:
        ValueError: if the number of predictions differs from ``len(rows)``,
            or a uid does not match the corresponding row's uid.
        KeyError: if ``n_class`` is less than 1.
    """
    with open(pred_file, 'r', encoding='utf8') as fp:
        # Blank lines (e.g. a trailing empty line) would crash json.loads and
        # break the positional alignment with `rows`, so skip them.
        objs = [json.loads(line) for line in fp if line.strip()]

    if len(objs) != len(rows):
        raise ValueError('{}: {} predictions vs {} gold rows'.format(task, len(objs), len(rows)))

    if n_class == 1:
        key = 'score'
    elif n_class > 1:
        key = 'prediction'
    else:
        raise KeyError(task)

    preds = []
    for obj, row in zip(objs, rows):
        if obj['uid'] != row['uid']:
            raise ValueError('{}: {} vs {}'.format(task, obj['uid'], row['uid']))
        preds.append(obj[key])

    return metric_func(preds, golds)
77 |
78 |
def _blue_metric_func(task, task_def):
    """Return the scoring callable ``f(preds, golds)`` for ``task``.

    Built on demand so that ``task_def`` only needs entries for the tasks that
    are actually evaluated.

    Raises:
        KeyError: if no metric is defined for ``task`` or a label-based task is
            missing from ``task_def``.
    """
    if task in ('biosses', 'clinicalsts'):
        return compute_pearson
    if task == 'mednli':
        return compute_micro_f1
    if task in ('i2b2-2010-re', 'chemprot', 'ddi2013-type'):
        # Score every label except the last one (presumably the negative
        # 'false'-style class listed last in the task definition).
        return functools.partial(
            compute_micro_f1_subindex,
            subindex=list(range(len(task_def[task]['labels']) - 1)))
    if task in ('shareclefe', 'bc5cdr-disease', 'bc5cdr-chemical'):
        return functools.partial(
            compute_seq_f1,
            label_mapper=dict(enumerate(task_def[task]['labels'])))
    raise KeyError(task)


def eval_blue(test_datasets, task_def_path, data_dir, model_dir, epochs):
    """Score saved test predictions of every task at every epoch.

    For each task the gold rows are read (JSON lines) from
    ``data_dir / f'{task}_test.json'``. For each epoch the predictions are read
    from ``model_dir / f'{task}_test_scores_{epoch}_2.json'`` (JSON lines, see
    ``get_score2``) or, when that file does not exist, from
    ``model_dir / f'{task}_test_scores_{epoch}.json'`` (see ``get_score1``).
    The best epoch of each task is printed.

    Args:
        test_datasets: task names to evaluate.
        task_def_path: path of the YAML task definition file.
        data_dir: directory of the gold test files (str or Path).
        model_dir: directory of the prediction files (str or Path).
        epochs: epoch numbers to evaluate.

    Returns:
        OrderedDict mapping task name to its TaskMetric.

    Raises:
        KeyError: if a task has no metric defined here or is missing from the
            task definition file.
    """
    with open(task_def_path, encoding='utf-8') as fp:
        task_def = yaml.safe_load(fp)

    data_dir = Path(data_dir)
    model_dir = Path(model_dir)

    total_scores = collections.OrderedDict()  # type: Dict[str, TaskMetric]
    for task in test_datasets:
        n_class = task_def[task]['n_class']
        metric_func = _blue_metric_func(task, task_def)

        with open(data_dir / f'{task}_test.json', 'r', encoding='utf-8') as fp:
            rows = [json.loads(line) for line in fp]
        golds = [row['label'] for row in rows]

        task_metric = TaskMetric(task)
        for epoch in epochs:
            # Prefer the JSON-lines dump; fall back to the single-object dump.
            try:
                score = get_score2(model_dir / f'{task}_test_scores_{epoch}_2.json',
                                   task, n_class, rows, golds, metric_func)
            except FileNotFoundError:
                score = get_score1(model_dir / f'{task}_test_scores_{epoch}.json',
                                   task, n_class, rows, golds, metric_func)
            task_metric.epochs.append(epoch)
            task_metric.scores.append(score)

        total_scores[task] = task_metric
        task_metric.print_max()
    return total_scores
143 |
144 |
def pretty_print(total_scores, epochs, dest=None):
    """Report the epoch with the best score averaged over all tasks.

    Prints that epoch's average and per-task scores. If ``dest`` is given,
    also writes a CSV with one row per epoch and the columns ``epoch``,
    ``average`` and one column per task.

    Args:
        total_scores: mapping task name -> object with a ``scores`` list that
            holds one score per entry of ``epochs`` (as built by ``eval_blue``).
        epochs: epoch numbers aligned with each task's ``scores``.
        dest: optional CSV path or writable buffer for ``DataFrame.to_csv``.

    Raises:
        ValueError: if there are no tasks or no epochs, or a task does not have
            exactly one score per epoch.
    """
    if not total_scores or not epochs:
        raise ValueError('pretty_print needs at least one task and one epoch')
    # Explicit check instead of assert, which ``python -O`` removes.
    for task, task_metric in total_scores.items():
        if len(task_metric.scores) != len(epochs):
            raise ValueError('%s: %d scores for %d epochs'
                             % (task, len(task_metric.scores), len(epochs)))

    # rows = tasks, columns = epochs; the column mean is the per-epoch average
    score_matrix = np.array([m.scores for m in total_scores.values()], dtype=float)
    avg_scores = score_matrix.mean(axis=0).tolist()
    index = int(np.argmax(avg_scores))
    print('On average, max at epoch %d: %.3f' % (epochs[index], avg_scores[index]))
    for t, task_metric in total_scores.items():
        print(' %s: At epoch %d: %.3f' % (t, epochs[index], task_metric.scores[index]))

    if dest is not None:
        table = {'epoch': epochs, 'average': avg_scores}
        for t, task_metric in total_scores.items():
            table[t] = task_metric.scores
        # index=False: the ``index`` parameter of to_csv is a boolean
        pd.DataFrame.from_dict(table).to_csv(dest, index=False)
162 |
163 |
def main():
    """Command-line entry point: parse the docopt usage and report BLUE scores."""
    args = docopt.docopt(__doc__)
    print(args)

    task_def_path, model_dir, data_dir = (
        Path(args['--task_def']),
        Path(args['--model_dir']),
        Path(args['--data_dir']),
    )

    # "--range=a,b" selects epochs a .. b-1
    range_tokens = args['--range'].split(',')
    first_epoch, stop_epoch = int(range_tokens[0]), int(range_tokens[1])
    epochs = list(range(first_epoch, stop_epoch))

    requested = args['--test_datasets']
    if requested is not None:
        test_datasets = requested.split(',')
    else:
        test_datasets = ['clinicalsts', 'i2b2-2010-re', 'mednli', 'chemprot', 'ddi2013-type',
                         'bc5cdr-chemical', 'bc5cdr-disease', 'shareclefe']

    pretty_print(eval_blue(test_datasets, task_def_path, data_dir, model_dir, epochs),
                 epochs, args['--output'])


if __name__ == '__main__':
    main()
188 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_exp_def.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Set
2 |
3 | import yaml
4 |
5 | from mt_bluebert.data_utils.task_def import TaskType, DataFormat, EncoderModelType
6 | from mt_bluebert.data_utils.vocab import Vocabulary
7 | from mt_bluebert.blue_metrics import BlueMetric
8 |
9 |
class BlueTaskDefs:
    """Task settings read from a BLUE task-definition YAML file.

    Every top-level key of the file is a task name. The per-task fields are
    exposed as dictionaries keyed by task name (``n_class_map``,
    ``data_format_map``, ...). All tasks must declare the same encoder type.

    ``__init__`` raises ValueError if a task name contains ``'_'`` or the tasks
    declare different encoder types, and KeyError if a required field is
    missing or names an unknown enum member.
    """

    def __init__(self, task_def_path):
        # safe_load suffices for this plain-data file and, unlike FullLoader,
        # never builds Python objects from YAML tags (blue_eval.py also reads
        # task definitions with safe_load). An empty file yields no tasks.
        with open(task_def_path, encoding='utf-8') as fp:
            self.task_def_dic = yaml.safe_load(fp) or {}

        self.label_mapper_map = {}  # type: Dict[str, Vocabulary]
        self.n_class_map = {}  # type: Dict[str, int]
        self.data_format_map = {}
        self.task_type_map = {}
        self.metric_meta_map = {}
        self.enable_san_map = {}
        self.dropout_p_map = {}  # only tasks that set dropout_p get an entry
        self.split_names_map = {}
        self.encoder_type = None
        for task, task_def in self.task_def_dic.items():
            # Explicit check instead of assert, which ``python -O`` removes.
            if "_" in task:
                raise ValueError("task name should not contain '_', current task name: %s" % task)
            self.n_class_map[task] = task_def["n_class"]
            self.data_format_map[task] = DataFormat[task_def["data_format"]]
            self.task_type_map[task] = TaskType[task_def["task_type"]]
            self.metric_meta_map[task] = tuple(BlueMetric[metric_name] for metric_name in task_def["metric_meta"])
            self.enable_san_map[task] = task_def["enable_san"]

            encoder_type = EncoderModelType[task_def["encoder_type"]]
            if self.encoder_type is None:
                self.encoder_type = encoder_type
            elif self.encoder_type != encoder_type:
                raise ValueError('The shared encoder has to be the same.')

            if "labels" in task_def:
                label_mapper = Vocabulary(True)
                for label in task_def["labels"]:
                    label_mapper.add(label)
                self.label_mapper_map[task] = label_mapper
            else:
                self.label_mapper_map[task] = None

            if "dropout_p" in task_def:
                self.dropout_p_map[task] = task_def["dropout_p"]

            self.split_names_map[task] = task_def.get('split_names', ["train", "dev", "test"])

    @property
    def tasks(self) -> Set[str]:
        """Names of all tasks defined in the file."""
        return self.task_def_dic.keys()
56 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_inference.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from data_utils.vocab import Vocabulary
4 | from mt_bluebert.blue_metrics import calc_metrics, BlueMetric
5 |
6 |
def eval_model(model, data,
               metric_meta: Tuple[BlueMetric],
               use_cuda: bool=True,
               with_label: bool=True,
               label_mapper: Vocabulary=None):
    """Run ``model.predict`` over every batch of ``data`` and optionally score it.

    ``data`` is rewound with ``data.reset()`` and then iterated as
    ``(batch_meta, batch_data)`` pairs. ``model.predict`` returns
    ``(scores, predictions, golds)`` for each batch.

    Args:
        model: object exposing ``predict`` (and ``cuda`` when ``use_cuda``).
        data: resettable iterable of batches.
        metric_meta: metrics to compute, forwarded to ``calc_metrics``.
        use_cuda: call ``model.cuda()`` before predicting.
        with_label: compute metrics; when False, ``metrics`` is an empty dict.
        label_mapper: label vocabulary forwarded to ``calc_metrics``.

    Returns:
        ``(metrics, predictions, scores, golds, ids)``. The lists are
        concatenated over all batches; ``ids`` come from ``batch_meta['uids']``.
    """
    data.reset()
    if use_cuda:
        model.cuda()

    predictions, golds, scores, ids = [], [], [], []
    for batch_meta, batch_data in data:
        batch_scores, batch_preds, batch_golds = model.predict(batch_meta, batch_data)
        predictions.extend(batch_preds)
        golds.extend(batch_golds)
        scores.extend(batch_scores)
        ids.extend(batch_meta['uids'])

    if not with_label:
        return {}, predictions, scores, golds, ids
    metrics = calc_metrics(metric_meta, golds, predictions, scores, label_mapper)
    return metrics, predictions, scores, golds, ids
29 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | # modified by: Yifan Peng
3 | import logging
4 | from enum import Enum
5 |
6 | from scipy.stats import pearsonr, spearmanr
7 | from sklearn.metrics import accuracy_score, f1_score
8 | from sklearn.metrics import matthews_corrcoef
9 | from sklearn.metrics import roc_auc_score
10 |
11 | from mt_bluebert.data_utils.vocab import Vocabulary
12 | from mt_bluebert.pmetrics import blue_classification_report, ner_report_conlleval
13 |
14 |
def compute_acc(predicts, labels):
    """Accuracy of ``predicts`` against the gold ``labels`` (sklearn ``accuracy_score``)."""
    return accuracy_score(y_true=labels, y_pred=predicts)
17 |
18 |
def compute_f1(predicts, labels):
    """F1 of ``predicts`` against ``labels`` with sklearn's default averaging."""
    return f1_score(y_true=labels, y_pred=predicts)
21 |
22 |
def compute_mcc(predicts, labels):
    """Matthews correlation coefficient of ``predicts`` against ``labels``."""
    return matthews_corrcoef(y_true=labels, y_pred=predicts)
25 |
26 |
def compute_pearson(predicts, labels):
    """Pearson correlation coefficient between ``labels`` and ``predicts`` (p-value dropped)."""
    return pearsonr(labels, predicts)[0]
30 |
31 |
def compute_spearman(predicts, labels):
    """Spearman rank correlation between ``labels`` and ``predicts`` (p-value dropped)."""
    return spearmanr(labels, predicts)[0]
35 |
36 |
def compute_auc(predicts, labels):
    """ROC AUC of the scores ``predicts`` for the gold ``labels``."""
    return roc_auc_score(y_true=labels, y_score=predicts)
40 |
41 |
def compute_micro_f1(predicts, labels):
    """F1 of the micro-average row of the BLUE classification report."""
    return blue_classification_report(labels, predicts).micro_row.f1.item()
45 |
46 |
def _subindex_f1(predicts, labels, subindex, row_name):
    """F1 of ``row_name`` ('micro_row' or 'macro_row') of the report restricted to ``subindex``.

    Deliberately best effort: if the restricted report cannot be computed, the
    error is logged together with the full report and its traceback, and 0 is
    returned instead of raising.
    """
    report = blue_classification_report(labels, predicts)
    try:
        sub_report = report.sub_report(subindex)
        return getattr(sub_report, row_name).f1.item()
    except Exception as e:
        # exc_info keeps the traceback that the bare message used to drop.
        logging.error('%s\n%s', e, report.report(), exc_info=True)
        return 0


def compute_micro_f1_subindex(predicts, labels, subindex):
    """Micro-average F1 over the label indices in ``subindex`` (0 on failure)."""
    return _subindex_f1(predicts, labels, subindex, 'micro_row')


def compute_macro_f1_subindex(predicts, labels, subindex):
    """Macro-average F1 over the label indices in ``subindex`` (0 on failure)."""
    return _subindex_f1(predicts, labels, subindex, 'macro_row')
64 |
65 |
def compute_seq_f1(predicts, labels, label_mapper):
    """Micro F1 of sequence-labelling predictions via ``ner_report_conlleval``.

    Every sequence is mapped to tag names with ``label_mapper`` and trimmed
    before scoring:

    - position 0 is dropped;
    - positions whose gold tag is ``'X'`` are dropped;
    - the last remaining position is dropped.

    With the BLUE NER task definitions these are presumably the [CLS] tag, the
    word-piece continuations and the [SEP] tag; confirm this against the
    preprocessing.

    Args:
        predicts: predicted label ids, one sequence per example.
        labels: gold label ids, aligned with ``predicts``.
        label_mapper: mapping from label id to tag name (a dict or a
            Vocabulary-like object).

    Returns:
        The micro-average F1 of the conlleval-style report.
    """
    y_true, y_pred = [], []
    for predict, label in zip(predicts, labels):
        gold_tags, pred_tags = [], []
        # zip stops at the shorter sequence, so a prediction longer than its
        # gold sequence no longer raises IndexError.
        for pred_id, gold_id in zip(predict[1:], label[1:]):
            gold_tag = label_mapper[gold_id]
            if gold_tag != 'X':
                gold_tags.append(gold_tag)
                pred_tags.append(label_mapper[pred_id])
        # Drop the final kept position; guarded so an all-'X' or one-token
        # sequence no longer raises IndexError on pop().
        if gold_tags:
            gold_tags.pop()
            pred_tags.pop()
        y_true.append(gold_tags)
        y_pred.append(pred_tags)
    report = ner_report_conlleval(y_true, y_pred)
    return report.micro_row.f1.item()
115 |
116 |
class BlueMetric(Enum):
    """Evaluation metrics that a task definition can list under ``metric_meta``.

    Members are looked up by name (e.g. ``BlueMetric['MicroF1']``) and mapped
    to their scoring functions in ``METRIC_FUNC``.
    """
    ACC = 0
    F1 = 1
    MCC = 2
    Pearson = 3
    Spearman = 4
    AUC = 5
    # value 6 is not used
    SeqEval = 7
    MicroF1 = 8
    MicroF1WithoutLastOne = 9  # micro F1 over all labels except the last one
    MacroF1WithoutLastOne = 10  # macro F1 over all labels except the last one
128 |
129 |
# Scoring function for each BlueMetric. The functions take different extra
# arguments; calc_metrics decides what each one receives.
METRIC_FUNC = {
    BlueMetric.ACC: compute_acc,
    BlueMetric.F1: compute_f1,
    BlueMetric.MCC: compute_mcc,
    BlueMetric.Pearson: compute_pearson,
    BlueMetric.Spearman: compute_spearman,
    BlueMetric.AUC: compute_auc,
    BlueMetric.SeqEval: compute_seq_f1,
    BlueMetric.MicroF1: compute_micro_f1,
    BlueMetric.MicroF1WithoutLastOne: compute_micro_f1_subindex,
    BlueMetric.MacroF1WithoutLastOne: compute_macro_f1_subindex
}
142 |
143 |
def calc_metrics(metric_meta, golds, predictions, scores, label_mapper: Vocabulary = None):
    """Compute every metric listed in ``metric_meta``.

    Args:
        metric_meta: iterable of BlueMetric members.
        golds: gold labels.
        predictions: predicted labels aligned with ``golds``.
        scores: raw model scores. Pearson and Spearman receive them as-is.
            AUC expects exactly two values per example (flattened) and uses
            the second value of each pair.
        label_mapper: label vocabulary; required by SeqEval and the
            ``*WithoutLastOne`` metrics.

    Returns:
        dict mapping metric name (e.g. ``'MicroF1'``) to its value.

    Raises:
        ValueError: if AUC is requested and ``scores`` does not hold exactly
            two values per example.
    """
    metrics = {}
    for mm in metric_meta:
        metric_func = METRIC_FUNC[mm]
        if mm in (BlueMetric.ACC, BlueMetric.F1, BlueMetric.MCC, BlueMetric.MicroF1):
            metric = metric_func(predictions, golds)
        elif mm == BlueMetric.SeqEval:
            metric = metric_func(predictions, golds, label_mapper)
        elif mm in (BlueMetric.MicroF1WithoutLastOne, BlueMetric.MacroF1WithoutLastOne):
            # score every label except the last one
            metric = metric_func(predictions, golds, subindex=list(range(len(label_mapper) - 1)))
        elif mm == BlueMetric.AUC:
            if len(scores) != 2 * len(golds):
                raise ValueError("AUC is only valid for binary classification problem")
            # Slice into the call instead of rebinding ``scores``: rebinding
            # handed the halved list to every metric computed after AUC.
            metric = metric_func(scores[1::2], golds)
        else:
            metric = metric_func(scores, golds)
        metrics[mm.name] = metric
    return metrics
164 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_prepro.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocessing BLUE dataset.
3 |
4 | Usage:
5 | blue_prepro [options] --root_dir= --task_def= --datasets=
6 |
7 | Options:
8 | --overwrite
9 | """
10 | import os
11 |
12 | import docopt
13 |
14 | from mt_bluebert.data_utils.log_wrapper import create_logger
15 | from mt_bluebert.blue_exp_def import BlueTaskDefs
16 | from mt_bluebert.blue_utils import load_sts, load_mednli, \
17 | load_relation, load_ner, dump_rows
18 |
19 |
def main(args):
    """Convert raw BLUE split files into the canonical TSV layout.

    For every selected task and each of its split names, reads
    ``{root}/{task}/{split}.tsv`` with the task's loader and writes
    ``{root}/canonical_data/{task}_{split}.tsv`` via ``dump_rows``. Existing
    outputs are kept unless ``--overwrite`` is given.

    Args:
        args: docopt dictionary with ``--root_dir``, ``--task_def``,
            ``--datasets`` (comma-separated task names, or ``all``) and
            ``--overwrite``.

    Raises:
        FileNotFoundError: if ``--root_dir`` does not exist.
        KeyError: if a requested task is not in the task definitions or has no
            loader here.
    """
    root = args['--root_dir']
    # Explicit check instead of assert, which ``python -O`` removes.
    if not os.path.exists(root):
        raise FileNotFoundError('%s: root directory does not exist' % root)

    log_file = os.path.join(root, 'blue_prepro.log')
    logger = create_logger(__name__, to_disk=True, log_file=log_file)

    task_defs = BlueTaskDefs(args['--task_def'])

    canonical_data_root = os.path.join(root, "canonical_data")
    os.makedirs(canonical_data_root, exist_ok=True)

    if args['--datasets'] == 'all':
        tasks = task_defs.tasks
    else:
        tasks = args['--datasets'].split(',')

    # raw-file parser for each supported task
    loaders = {
        'clinicalsts': load_sts,
        'biosses': load_sts,
        'mednli': load_mednli,
        'chemprot': load_relation,
        'i2b2-2010-re': load_relation,
        'ddi2013-type': load_relation,
        'bc5cdr-disease': load_ner,
        'bc5cdr-chemical': load_ner,
        'shareclefe': load_ner,
    }

    for task in tasks:
        logger.info("Task %s", task)
        if task not in task_defs.task_def_dic or task not in loaders:
            raise KeyError('%s: Cannot process this task' % task)
        load = loaders[task]

        data_format = task_defs.data_format_map[task]
        for split_name in task_defs.split_names_map[task]:
            fin = os.path.join(root, f'{task}/{split_name}.tsv')
            fout = os.path.join(canonical_data_root, f'{task}_{split_name}.tsv')
            if os.path.exists(fout) and not args['--overwrite']:
                logger.warning('%s: Not overwrite %s: %s', task, split_name, fout)
                continue
            data = load(fin)
            logger.info('%s: Loaded %s %s samples', task, len(data), split_name)
            dump_rows(data, fout, data_format)

        logger.info('%s: Done', task)


if __name__ == '__main__':
    args = docopt.docopt(__doc__)
    main(args)
72 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_strip_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | blue_strip_model
4 | """
5 | import logging
6 | import os
7 |
8 | import docopt
9 | import torch
10 |
11 |
def main():
    """Strip a trained checkpoint down to its encoder weights and config.

    The steps are:

    - load the checkpoint;
    - take the ``ema`` weights when ``config['ema_opt'] > 0``, else the
      ``state`` weights;
    - drop every ``scoring_list.*`` parameter (presumably the task-specific
      output layers);
    - keep only the config keys listed below;
    - save ``{'state': ..., 'config': ...}``.

    Exits with status 1 if the input checkpoint does not exist.
    """
    args = docopt.docopt(__doc__)
    print(args)

    # NOTE(review): the input checkpoint and the output path are both read from
    # args[''] here, so the output overwrites the input; confirm that the
    # usage string defines two distinct positional arguments.
    if not os.path.exists(args['']):
        logging.error('%s: Cannot find the model', args[''])
        # Stop here instead of continuing into a torch.load that cannot succeed.
        raise SystemExit(1)

    map_location = 'cpu' if not torch.cuda.is_available() else None
    state_dict = torch.load(args[''], map_location=map_location)
    config = state_dict['config']
    state = state_dict['ema'] if config['ema_opt'] > 0 else state_dict['state']

    my_state = {k: v for k, v in state.items() if not k.startswith('scoring_list.')}
    kept_config_keys = ('vocab_size', 'hidden_size', 'num_hidden_layers', 'num_attention_heads',
                        'hidden_act', 'intermediate_size', 'hidden_dropout_prob',
                        'attention_probs_dropout_prob', 'max_position_embeddings', 'type_vocab_size',
                        'initializer_range')
    my_config = {k: config[k] for k in kept_config_keys}

    torch.save({'state': my_state, 'config': my_config}, args[''])


if __name__ == '__main__':
    main()
38 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_task_def.yml:
--------------------------------------------------------------------------------
1 | clinicalsts:
2 | data_format: PremiseAndOneHypothesis
3 | encoder_type: BERT
4 | enable_san: false
5 | metric_meta:
6 | - Pearson
7 | n_class: 1
8 | task_type: Regression
9 | biosses:
10 | data_format: PremiseAndOneHypothesis
11 | encoder_type: BERT
12 | enable_san: false
13 | metric_meta:
14 | - Pearson
15 | n_class: 1
16 | task_type: Regression
17 | mednli:
18 | data_format: PremiseAndOneHypothesis
19 | encoder_type: BERT
20 | enable_san: true
21 | labels:
22 | - contradiction
23 | - neutral
24 | - entailment
25 | metric_meta:
26 | - MicroF1
27 | n_class: 3
28 | task_type: Classification
29 | chemprot:
30 | data_format: PremiseOnly
31 | encoder_type: BERT
32 | enable_san: false
33 | labels:
34 | - CPR:3
35 | - CPR:4
36 | - CPR:5
37 | - CPR:6
38 | - CPR:9
39 | - 'false'
40 | metric_meta:
41 | - MicroF1WithoutLastOne
42 | n_class: 6
43 | task_type: Classification
44 | ddi2013-type:
45 | data_format: PremiseOnly
46 | encoder_type: BERT
47 | enable_san: false
48 | labels:
49 | - DDI-advise
50 | - DDI-effect
51 | - DDI-int
52 | - DDI-mechanism
53 | - DDI-false
54 | metric_meta:
55 | - MacroF1WithoutLastOne
56 | - MicroF1WithoutLastOne
57 | n_class: 5
58 | task_type: Classification
59 | i2b2-2010-re:
60 | data_format: PremiseOnly
61 | encoder_type: BERT
62 | enable_san: false
63 | labels:
64 | - PIP
65 | - TeCP
66 | - TeRP
67 | - TrAP
68 | - TrCP
69 | - TrIP
70 | - TrNAP
71 | - TrWP
72 | - 'false'
73 | metric_meta:
74 | - MicroF1WithoutLastOne
75 | n_class: 9
76 | task_type: Classification
77 | bc5cdr-disease:
78 | data_format: Sequence
79 | encoder_type: BERT
80 | enable_san: False
81 | labels:
82 | - O
83 | - B-Disease
84 | - I-Disease
85 | - X
86 | - CLS
87 | - SEP
88 | metric_meta:
89 | - SeqEval
90 | n_class: 6
91 | task_type: SequenceLabeling
92 | bc5cdr-chemical:
93 | data_format: Sequence
94 | encoder_type: BERT
95 | enable_san: False
96 | labels:
97 | - O
98 | - B-Chemical
99 | - I-Chemical
100 | - X
101 | - CLS
102 | - SEP
103 | metric_meta:
104 | - SeqEval
105 | n_class: 6
106 | task_type: SequenceLabeling
107 | shareclefe:
108 | data_format: Sequence
109 | encoder_type: BERT
110 | enable_san: False
111 | labels:
112 | - O
113 | - B-Disease
114 | - I-Disease
115 | - X
116 | - CLS
117 | - SEP
118 | metric_meta:
119 | - SeqEval
120 | n_class: 6
121 | task_type: SequenceLabeling
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/blue_utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import logging
4 |
5 | from mt_bluebert.data_utils import DataFormat
6 |
7 |
def _read_tsv(file):
    """Return the data lines of a tab-separated file as lists of fields.

    The first line is a header and is skipped; an empty file yields [].
    """
    # newline='' as required by the csv module, so quoted fields containing
    # line breaks are parsed correctly.
    with open(file, encoding="utf8", newline='') as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # header
        return list(reader)


def load_relation(file):
    """Load a relation-classification TSV (header, then ``uid, premise, label``).

    Returns:
        list of dicts with keys ``uid``, ``premise`` and ``label``.

    Raises:
        ValueError: if a data line does not have exactly 3 fields.
    """
    rows = []
    for i, blocks in enumerate(_read_tsv(file)):
        if len(blocks) != 3:
            raise ValueError('%s:%s: number of blocks: %s' % (file, i, len(blocks)))
        rows.append({'uid': blocks[0], 'premise': blocks[1], 'label': blocks[-1]})
    return rows


def load_mednli(file):
    """MEDNLI for classification.

    Reads a TSV (header, then ``label, uid, premise, hypothesis``).

    Returns:
        list of dicts with keys ``uid``, ``premise``, ``hypothesis`` and ``label``.

    Raises:
        ValueError: if a data line does not have exactly 4 fields.
    """
    rows = []
    for i, blocks in enumerate(_read_tsv(file)):
        if len(blocks) != 4:
            raise ValueError('%s:%s: number of blocks: %s' % (file, i, len(blocks)))
        rows.append({'uid': blocks[1], 'premise': blocks[2], 'hypothesis': blocks[3], 'label': blocks[0]})
    return rows


def load_sts(file):
    """Load a sentence-similarity TSV whose last three fields are ``premise, hypothesis, score``.

    The uid is the 0-based index of the data line; the score is kept as a string.

    Raises:
        ValueError: if a data line has 8 fields or fewer.
    """
    rows = []
    for i, blocks in enumerate(_read_tsv(file)):
        if len(blocks) <= 8:
            raise ValueError('%s:%s: number of blocks: %s' % (file, i, len(blocks)))
        rows.append({'uid': i, 'premise': blocks[-3], 'hypothesis': blocks[-2], 'label': blocks[-1]})
    return rows
49 |
50 |
def _ner_row(file, uid, sentence, label, offset):
    """Build one load_ner row; a sentence must have seen a document id."""
    if uid is None:
        raise ValueError('%s: sentence without a document id: %s' % (file, ' '.join(sentence)))
    return {'uid': uid, 'premise': sentence, 'label': label, 'offset': offset}


def load_ner(file, sep='\t'):
    """Load a token-per-line NER file into one row per sentence.

    Each non-blank line holds four ``sep``-separated fields: token, document
    id (``'-'`` when absent), character offset and tag. Blank lines end a
    sentence. The sentence uid is ``'<doc id>.<offset>'`` taken from the last
    line of the sentence whose document id is not ``'-'``.

    Returns:
        list of dicts with keys ``uid``, ``premise`` (tokens), ``label`` (tags)
        and ``offset`` (``'start;end'`` strings with ``end = start + len(token)``).

    Raises:
        ValueError: if a line does not have 4 fields, an offset is not an
            integer, or a sentence has no line carrying a document id.
    """
    rows = []
    sentence, label, offset = [], [], []
    uid = None
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if not line:
                if sentence:
                    rows.append(_ner_row(file, uid, sentence, label, offset))
                    sentence, label, offset = [], [], []
                    uid = None
                continue
            splits = line.split(sep)
            if len(splits) != 4:
                raise ValueError('%s: expected 4 fields, got %d: %r' % (file, len(splits), line))
            token, doc_id, raw_start, tag = splits
            start = int(raw_start)
            sentence.append(token)
            offset.append('{};{}'.format(start, start + len(token)))
            label.append(tag)
            if doc_id != '-':
                uid = doc_id + '.' + raw_start
    if sentence:
        rows.append(_ner_row(file, uid, sentence, label, offset))
    return rows
82 |
83 |
def _clean_field(value, out_path, index, col, logger):
    """Return ``str(value)`` with tabs replaced by spaces, warning when one was found."""
    text = str(value)
    if "\t" in text:
        logger.warning('%s:%s: %s has tab', out_path, index, col)
        text = text.replace('\t', ' ')
    return text


def _write_flat_rows(rows, out_path, columns):
    """Write one tab-separated line per row holding the given scalar ``columns``."""
    logger = logging.getLogger(__name__)
    with open(out_path, "w", encoding="utf-8") as out_f:
        for i, row in enumerate(rows):
            fields = [_clean_field(row[col], out_path, i, col, logger) for col in columns]
            out_f.write('\t'.join(fields) + '\n')


def dump_PremiseOnly(rows, out_path):
    """Write ``uid``, ``label`` and ``premise`` of each row as one TSV line.

    Tabs inside a field are replaced by spaces (with a warning); ``rows`` is
    left unmodified.
    """
    _write_flat_rows(rows, out_path, ["uid", "label", "premise"])


def dump_PremiseAndOneHypothesis(rows, out_path):
    """Write ``uid``, ``label``, ``premise`` and ``hypothesis`` of each row as one TSV line.

    Tabs inside a field are replaced by spaces (with a warning); ``rows`` is
    left unmodified.
    """
    _write_flat_rows(rows, out_path, ["uid", "label", "premise", "hypothesis"])


def dump_Sequence(rows, out_path):
    """Write ``uid`` and the JSON-encoded ``label``, ``premise`` and ``offset`` lists.

    Tabs inside string tokens are replaced by spaces (with a warning);
    ``rows`` is left unmodified.
    """
    logger = logging.getLogger(__name__)
    with open(out_path, "w", encoding="utf-8") as out_f:
        for i, row in enumerate(rows):
            fields = [_clean_field(row['uid'], out_path, i, 'uid', logger)]
            for col in ["label", "premise", "offset"]:
                tokens = []
                for token in row[col]:
                    if isinstance(token, str) and "\t" in token:
                        logger.warning('%s:%s: %s has tab', out_path, i, col)
                        token = token.replace('\t', ' ')
                    tokens.append(token)
                fields.append(json.dumps(tokens))
            out_f.write('\t'.join(fields) + '\n')


def dump_PremiseAndMultiHypothesis(rows, out_path):
    """Write ``uid``, ``label``, ``premise`` and then every hypothesis as one TSV line.

    Tabs inside a field are replaced by spaces (with a warning); ``rows`` is
    left unmodified.
    """
    logger = logging.getLogger(__name__)
    with open(out_path, "w", encoding="utf-8") as out_f:
        for i, row in enumerate(rows):
            fields = [_clean_field(row[col], out_path, i, col, logger)
                      for col in ["uid", "label", "premise"]]
            hypotheses = [_clean_field(h, out_path, i, 'hypothesis', logger) for h in row["hypothesis"]]
            # One field per hypothesis; an empty list still yields a trailing tab.
            fields.append("\t".join(hypotheses))
            out_f.write('\t'.join(fields) + '\n')
145 |
def dump_rows(rows, out_path, data_format: DataFormat):
    """Write ``rows`` to ``out_path`` as tab-separated lines laid out for ``data_format``.

    PremiseOnly, PremiseAndOneHypothesis, PremiseAndMultiHypothesis and
    Sequence are delegated to the ``dump_*`` function of the same name.

    Raises:
        ValueError: for any other ``data_format``.
    """
    writers = (
        (DataFormat.PremiseOnly, dump_PremiseOnly),
        (DataFormat.PremiseAndOneHypothesis, dump_PremiseAndOneHypothesis),
        (DataFormat.PremiseAndMultiHypothesis, dump_PremiseAndMultiHypothesis),
        (DataFormat.Sequence, dump_Sequence),
    )
    for fmt, writer in writers:
        if data_format == fmt:
            writer(rows, out_path)
            return
    raise ValueError(data_format)
160 |
161 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/conlleval.py:
--------------------------------------------------------------------------------
1 | # Python version of the evaluation script from CoNLL'00-
2 | # Originates from: https://github.com/spyysalo/conlleval.py
3 |
4 |
5 | # Intentional differences:
6 | # - accept any space as delimiter by default
7 | # - optional file argument (default STDIN)
8 | # - option to set boundary (-b argument)
9 | # - LaTeX output (-l argument) not supported
10 | # - raw tags (-r argument) not supported
11 |
12 | # add function :evaluate(predicted_label, ori_label): which will not read from file
13 |
14 | import sys
15 | import re
16 | import codecs
17 | from collections import defaultdict, namedtuple
18 |
# Sentinel delimiter meaning "split on any run of whitespace": evaluate()
# compares options.delimiter against it and then calls line.split() with no
# argument.
ANY_SPACE = ''


class FormatError(Exception):
    """Raised by evaluate() when input lines have an unexpected number of fields."""
    pass

# Scores for one chunk type (or overall): true positives, false positives,
# false negatives, precision, recall and F-score.
Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
26 |
27 |
class EvalCounts(object):
    """Counters accumulated by ``evaluate`` and consumed by ``metrics``/``report``."""

    def __init__(self):
        # corpus-wide totals
        self.token_counter = 0   # tokens seen, sentence boundaries excluded
        self.correct_tags = 0    # tokens whose guessed tag equals the gold tag
        self.found_correct = 0   # chunks in the gold annotation
        self.found_guessed = 0   # chunks in the prediction
        self.correct_chunk = 0   # predicted chunks that match a gold chunk exactly

        # the chunk counters above, broken down by chunk type
        self.t_found_correct = defaultdict(int)
        self.t_found_guessed = defaultdict(int)
        self.t_correct_chunk = defaultdict(int)
40 |
41 |
def parse_args(argv):
    """Parse conlleval command-line options from the list of strings ``argv``."""
    import argparse

    parser = argparse.ArgumentParser(
        description='evaluate tagging results using CoNLL criteria',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    option_specs = (
        (('-b', '--boundary'),
         dict(metavar='STR', default='-X-', help='sentence boundary')),
        (('-d', '--delimiter'),
         dict(metavar='CHAR', default=ANY_SPACE, help='character delimiting items in input')),
        (('-o', '--otag'),
         dict(metavar='CHAR', default='O', help='alternative outside tag')),
        (('file',),
         dict(nargs='?', default=None)),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args(argv)
57 |
58 |
def parse_tag(t):
    """Split a tag like ``'B-PER'`` at its first hyphen into ``('B', 'PER')``.

    A tag that does not match (e.g. ``'O'``) is returned as ``(t, '')``.
    """
    found = re.match(r'^([^-]*)-(.*)$', t)
    if found is None:
        return (t, '')
    return found.groups()
62 |
63 |
def evaluate(iterable, options=None):
    """Accumulate CoNLL chunk and tag counts over tagged lines.

    Each line holds fields separated by whitespace (or by ``options.delimiter``
    when it is not ANY_SPACE). The last field is the guessed tag and the
    second-to-last the gold tag; tags are split by ``parse_tag`` into a prefix
    and a type (e.g. ``B-PER``). Empty lines and lines whose first field equals
    ``options.boundary`` act as sentence boundaries.

    Args:
        iterable: lines of text (for example an open file).
        options: parsed options as returned by ``parse_args``; the defaults are
            used when None.

    Returns:
        EvalCounts holding the accumulated counters.

    Raises:
        FormatError: if a non-empty line has a different number of fields than
            the first line, or fewer than 3 fields.
    """
    if options is None:
        options = parse_args([])  # use defaults

    counts = EvalCounts()
    num_features = None  # number of features per line
    in_correct = False  # currently processed chunks is correct until now
    last_correct = 'O'  # previous chunk tag in corpus
    last_correct_type = ''  # type of previous chunk tag in corpus
    last_guessed = 'O'  # previously identified chunk tag
    last_guessed_type = ''  # type of previously identified chunk tag

    for i, line in enumerate(iterable):
        line = line.rstrip('\r\n')

        if options.delimiter == ANY_SPACE:
            features = line.split()
        else:
            features = line.split(options.delimiter)

        # The first line fixes the expected field count; empty lines are exempt.
        if num_features is None:
            num_features = len(features)
        elif num_features != len(features) and len(features) != 0:
            raise FormatError('unexpected number of features: %d (%d) at line %d\n%s' %
                              (len(features), num_features, i, line))

        # Blank and boundary lines become a synthetic all-'O' boundary row.
        if len(features) == 0 or features[0] == options.boundary:
            features = [options.boundary, 'O', 'O']
        if len(features) < 3:
            raise FormatError('unexpected number of features in line %s' % line)

        # last field: guessed tag; second to last: gold tag; first: the token
        guessed, guessed_type = parse_tag(features.pop())
        correct, correct_type = parse_tag(features.pop())
        first_item = features.pop(0)

        if first_item == options.boundary:
            guessed = 'O'

        end_correct = end_of_chunk(last_correct, correct,
                                   last_correct_type, correct_type)
        end_guessed = end_of_chunk(last_guessed, guessed,
                                   last_guessed_type, guessed_type)
        start_correct = start_of_chunk(last_correct, correct,
                                       last_correct_type, correct_type)
        start_guessed = start_of_chunk(last_guessed, guessed,
                                       last_guessed_type, guessed_type)

        # An open chunk is counted as correct only if the gold and guessed
        # chunks end at the same point with the same type; any divergence
        # before that cancels it.
        if in_correct:
            if (end_correct and end_guessed and
                    last_guessed_type == last_correct_type):
                in_correct = False
                counts.correct_chunk += 1
                counts.t_correct_chunk[last_correct_type] += 1
            elif (end_correct != end_guessed or guessed_type != correct_type):
                in_correct = False

        # A candidate correct chunk opens when both start here with the same type.
        if start_correct and start_guessed and guessed_type == correct_type:
            in_correct = True

        if start_correct:
            counts.found_correct += 1
            counts.t_found_correct[correct_type] += 1
        if start_guessed:
            counts.found_guessed += 1
            counts.t_found_guessed[guessed_type] += 1
        # Token-level accuracy ignores boundary rows.
        if first_item != options.boundary:
            if correct == guessed and guessed_type == correct_type:
                counts.correct_tags += 1
            counts.token_counter += 1

        last_guessed = guessed
        last_correct = correct
        last_guessed_type = guessed_type
        last_correct_type = correct_type

    # A chunk still open at the end of input is closed and counted.
    if in_correct:
        counts.correct_chunk += 1
        counts.t_correct_chunk[last_correct_type] += 1

    return counts
145 |
146 |
147 |
def uniq(iterable):
    """Return the distinct items of ``iterable`` as a list, in first-seen order."""
    seen = set()
    ordered = []
    for item in iterable:
        if item in seen:
            continue
        seen.add(item)
        ordered.append(item)
    return ordered
151 |
152 |
def calculate_metrics(correct, guessed, total):
    """Build a Metrics tuple from chunk counts.

    ``correct`` is the number of matching chunks, ``guessed`` the number of
    predicted chunks and ``total`` the number of gold chunks. A ratio whose
    denominator is zero is reported as 0.
    """
    tp = correct
    fp = guessed - correct
    fn = total - correct
    p = 1. * tp / (tp + fp) if tp + fp != 0 else 0
    r = 1. * tp / (tp + fn) if tp + fn != 0 else 0
    f = 2 * p * r / (p + r) if p + r != 0 else 0
    return Metrics(tp, fp, fn, p, r, f)
159 |
160 |
def metrics(counts):
    """Return ``(overall, by_type)`` scores for an EvalCounts.

    ``by_type`` maps every chunk type seen in the gold data or the prediction
    (in first-seen order) to its Metrics.
    """
    overall = calculate_metrics(counts.correct_chunk, counts.found_guessed, counts.found_correct)
    chunk_types = uniq(list(counts.t_found_correct) + list(counts.t_found_guessed))
    by_type = {
        chunk_type: calculate_metrics(counts.t_correct_chunk[chunk_type],
                                      counts.t_found_guessed[chunk_type],
                                      counts.t_found_correct[chunk_type])
        for chunk_type in chunk_types
    }
    return overall, by_type
172 |
173 |
def report(counts, out=None):
    """Write the conlleval summary for ``counts`` to ``out`` (``sys.stdout`` by default).

    The summary contains the phrase counts, then the overall
    accuracy/precision/recall/FB1 when at least one token was counted, then
    one line per chunk type in sorted order.
    """
    if out is None:
        out = sys.stdout

    overall, by_type = metrics(counts)
    write = out.write

    write('processed %d tokens with %d phrases; ' % (counts.token_counter, counts.found_correct))
    write('found: %d phrases; correct: %d.\n' % (counts.found_guessed, counts.correct_chunk))

    if counts.token_counter > 0:
        write('accuracy: %6.2f%%; ' % (100. * counts.correct_tags / counts.token_counter))
        write('precision: %6.2f%%; ' % (100. * overall.prec))
        write('recall: %6.2f%%; ' % (100. * overall.rec))
        write('FB1: %6.2f\n' % (100. * overall.fscore))

    for chunk_type, type_metrics in sorted(by_type.items()):
        write('%17s: ' % chunk_type)
        write('precision: %6.2f%%; ' % (100. * type_metrics.prec))
        write('recall: %6.2f%%; ' % (100. * type_metrics.rec))
        write('FB1: %6.2f %d\n' % (100. * type_metrics.fscore, counts.t_found_guessed[chunk_type]))
198 |
199 |
def report_notprint(counts, out=None):
    """Build the conlleval summary for *counts* as a list of text lines.

    Each entry ends with a newline: one "processed/found" line, an overall
    accuracy/precision/recall/FB1 line (only when any token was processed),
    then one line per chunk type in sorted order.
    ``out`` is accepted for signature parity with report() and is not used.
    """
    overall, by_type = metrics(counts)
    c = counts

    lines = [
        'processed %d tokens with %d phrases; ' % (c.token_counter, c.found_correct)
        + 'found: %d phrases; correct: %d.\n' % (c.found_guessed, c.correct_chunk)
    ]

    if c.token_counter > 0:
        lines.append(
            'accuracy: %6.2f%%; ' % (100. * c.correct_tags / c.token_counter)
            + 'precision: %6.2f%%; ' % (100. * overall.prec)
            + 'recall: %6.2f%%; ' % (100. * overall.rec)
            + 'FB1: %6.2f\n' % (100. * overall.fscore)
        )

    for type_name, type_metrics in sorted(by_type.items()):
        lines.append(
            '%17s: ' % type_name
            + 'precision: %6.2f%%; ' % (100. * type_metrics.prec)
            + 'recall: %6.2f%%; ' % (100. * type_metrics.rec)
            + 'FB1: %6.2f %d\n' % (100. * type_metrics.fscore, c.t_found_guessed[type_name])
        )
    return lines
232 |
233 |
def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk ended between the previous word and the current one.

    prev_tag / tag are chunk-position prefixes ('B', 'I', 'E', 'S', 'O', '.',
    '[' or ']'); prev_type / type_ are the corresponding chunk types.
    """
    # E and S always close their chunk; '[' and ']' chunks have length 1.
    if prev_tag in ('E', 'S', '[', ']'):
        return True
    # An open B/I chunk closes when the next word begins a chunk or is outside.
    if prev_tag in ('B', 'I') and tag in ('B', 'S', 'O'):
        return True
    # A change of type closes any chunk that is not O or '.'.
    if prev_tag not in ('O', '.') and prev_type != type_:
        return True
    return False
257 |
258 |
def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk started between the previous word and the current one.

    prev_tag / tag are chunk-position prefixes ('B', 'I', 'E', 'S', 'O', '.',
    '[' or ']'); prev_type / type_ are the corresponding chunk types.
    """
    # B and S always open a chunk; '[' and ']' chunks have length 1.
    if tag in ('B', 'S', '[', ']'):
        return True
    # I or E directly after a closed chunk (E/S) or outside (O) opens a new one.
    if prev_tag in ('E', 'S', 'O') and tag in ('E', 'I'):
        return True
    # A change of type opens a new chunk unless the current word is O or '.'.
    if tag not in ('O', '.') and prev_type != type_:
        return True
    return False
282 |
283 |
def main(argv):
    """Command-line entry point: evaluate ``args.file`` (or stdin) and print the report."""
    args = parse_args(argv[1:])
    if args.file is not None:
        with open(args.file) as f:
            counts = evaluate(f, args)
    else:
        counts = evaluate(sys.stdin, args)
    report(counts)
293 |
294 |
def return_report(input_file):
    """Evaluate *input_file* with default options and return the report lines without printing."""
    with open(input_file, "r") as handle:
        chunk_counts = evaluate(handle)
    return report_notprint(chunk_counts)
299 |
if __name__ == '__main__':
    # Run as a command-line tool. Evaluating a hard-coded, developer-specific
    # path here made the script unusable for anyone else and ignored its CLI.
    sys.exit(main(sys.argv))
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 |
5 | from mt_bluebert.data_utils.task_def import TaskType, DataFormat
6 |
7 |
def load_data(file_path, data_format, task_type, label_dict=None):
    """
    Load a canonical tab-separated data file into a list of row dicts.

    :param file_path: path to a UTF-8 TSV file, one sample per line
    :param data_format: DataFormat describing the columns:
        PremiseOnly: uid, label, premise
        PremiseAndOneHypothesis: uid, label, premise, hypothesis
        PremiseAndMultiHypothesis: uid, comma-separated ruids, label, premise,
            then at least two hypothesis columns
    :param task_type: TaskType deciding how the label column is converted:
        Classification -> int (or label_dict lookup), Regression -> float,
        Ranking -> index of the best label (raw labels kept under "olabel");
        any other task type keeps the raw label string
    :param label_dict: map string label to numbers.
        only valid for Classification task or ranking task.
        For ranking task, better label should have large number
    :return: list of dicts with keys "uid", "label", "premise" and, depending
        on the format, "hypothesis" / "ruid" (plus "olabel" for ranking)
    :raises ValueError: if data_format is not one of the three formats above
    """
    if task_type == TaskType.Ranking:
        assert data_format == DataFormat.PremiseAndMultiHypothesis

    rows = []
    # The context manager closes the file deterministically; a bare open() in
    # the for-statement leaks the handle until garbage collection.
    with open(file_path, encoding="utf-8") as reader:
        for line in reader:
            fields = line.strip("\n").split("\t")
            if data_format == DataFormat.PremiseOnly:
                assert len(fields) == 3
                row = {"uid": fields[0], "label": fields[1], "premise": fields[2]}
            elif data_format == DataFormat.PremiseAndOneHypothesis:
                assert len(fields) == 4
                row = {"uid": fields[0], "label": fields[1], "premise": fields[2], "hypothesis": fields[3]}
            elif data_format == DataFormat.PremiseAndMultiHypothesis:
                assert len(fields) > 5
                row = {"uid": fields[0], "ruid": fields[1].split(","), "label": fields[2], "premise": fields[3],
                       "hypothesis": fields[4:]}
            else:
                raise ValueError(data_format)

            if task_type == TaskType.Classification:
                if label_dict is not None:
                    row["label"] = label_dict[row["label"]]
                else:
                    row["label"] = int(row["label"])
            elif task_type == TaskType.Regression:
                row["label"] = float(row["label"])
            elif task_type == TaskType.Ranking:
                labels = row["label"].split(",")
                if label_dict is not None:
                    labels = [label_dict[label] for label in labels]
                else:
                    labels = [float(label) for label in labels]
                row["label"] = int(np.argmax(labels))
                row["olabel"] = labels

            rows.append(row)
    return rows
55 |
56 |
def load_score_file(score_path, n_class):
    """Load a prediction-score JSON file and index it by sample id.

    The JSON object must hold "uids", "predictions" and a flat "scores" list
    with exactly n_class scores per uid, laid out in uid order.

    :param score_path: path to the UTF-8 JSON file
    :param n_class: number of scores stored per sample
    :return: dict mapping uid -> (prediction, list of that uid's n_class scores)
    :raises AssertionError: if len(scores) != len(uids) * n_class
    """
    # Close the file deterministically instead of leaking the handle.
    with open(score_path, encoding="utf-8") as reader:
        score_obj = json.load(reader)
    uids = score_obj["uids"]
    scores = score_obj["scores"]
    # Same condition as "divisible by and quotient equal to n_class", but it no
    # longer divides by zero when the file contains no samples.
    assert len(scores) == len(uids) * n_class, \
        "scores column size should equal to sample count or multiple of sample count (for classification problem)"

    score_segs = [scores[i * n_class: (i + 1) * n_class] for i in range(len(uids))]
    return {
        sample_id: (pred, score_seg)
        for sample_id, pred, score_seg in zip(uids, score_obj["predictions"], score_segs)
    }
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/gpt2_bpe.py:
--------------------------------------------------------------------------------
1 | """
2 | Byte pair encoding utilities from GPT-2.
3 |
4 | Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
5 | Original license: MIT
6 | """
7 |
8 | from functools import lru_cache
9 | import json
10 |
11 |
@lru_cache()
def bytes_to_unicode():
    """
    Map every byte value 0-255 to a printable unicode character.

    Byte-level BPE works on unicode strings, and a reversible byte <-> char table
    avoids needing thousands of extra vocabulary entries to cover raw bytes.
    Bytes that are already printable (and not whitespace/control characters the
    BPE code chokes on) map to themselves; every remaining byte is assigned, in
    increasing byte order, the next code point from 256 upward.
    The result is cached, so every call returns the same dict object.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    mapping = {byte: chr(byte) for byte in printable}
    offset = 0
    for byte in range(2 ** 8):
        if byte not in mapping:
            mapping[byte] = chr(2 ** 8 + offset)
            offset += 1
    return mapping
33 |
def get_pairs(word):
    """Return the set of adjacent symbol pairs in *word*.

    *word* is a sequence of symbols (variable-length strings), e.g. a tuple.
    A word with fewer than two symbols has no pairs and yields an empty set;
    an empty word used to raise IndexError.
    """
    return set(zip(word, word[1:]))
44 |
class Encoder:
    """GPT-2 byte-level BPE encoder / decoder.

    encoder:    dict mapping BPE token string -> integer id.
    bpe_merges: ordered merge pairs; earlier pairs are merged first.
    errors:     error handler passed to ``bytes.decode`` in decode().
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Lower rank == earlier in the merge list == higher priority.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        try:
            import regex as re
            self.re = re
        except ImportError:
            raise ImportError('Please install regex with: pip install regex')

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply BPE merges to *token*; return its sub-tokens joined by single spaces.

        Results are memoised in self.cache (single-symbol tokens are returned
        unchanged without caching).
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Merge the highest-priority (lowest-rank) pair present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    # (A bare `except:` here used to swallow any exception,
                    # including KeyboardInterrupt.)
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode *text* into a list of BPE token ids."""
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            # Map the token's UTF-8 bytes onto printable characters before BPE.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode BPE token ids back into text (invalid UTF-8 handled per self.errors)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
117 |
def get_encoder(encoder_json_path, vocab_bpe_path):
    """Build an Encoder from GPT-2's ``encoder.json`` (token -> id) and ``vocab.bpe`` (merge list).

    Both files are read as UTF-8. The vocabulary keys are built from the
    non-ASCII characters of bytes_to_unicode(), so reading the JSON with the
    platform default encoding (e.g. cp1252 on Windows) could corrupt them.
    """
    with open(encoder_json_path, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # [1:-1] drops the first line (presumably a version header) and the last
    # element (presumably the empty string after the trailing newline).
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
128 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/log_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import logging
3 | from time import gmtime, strftime
4 | import sys
5 |
def create_logger(name, silent=False, to_disk=False, log_file=None):
    """Return the logger *name* configured at DEBUG level with no propagation.

    Unless *silent*, an INFO-level handler writing to sys.stdout is attached.
    With *to_disk*, a DEBUG-level file handler is attached as well, writing to
    *log_file* or, if that is None, to a UTC-timestamped "<Y-m-d-H-M-S>.log".
    Note: each call adds new handlers to the (shared) named logger.
    """
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    def _attach(handler, level):
        # Common setup for every handler this function installs.
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    if not silent:
        _attach(logging.StreamHandler(sys.stdout), logging.INFO)
    if to_disk:
        if log_file is None:
            log_file = strftime("%Y-%m-%d-%H-%M-%S.log", gmtime())
        _attach(logging.FileHandler(log_file), logging.DEBUG)
    return logger
26 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | from enum import Enum
3 |
4 | from sklearn.metrics import matthews_corrcoef
5 | from sklearn.metrics import accuracy_score, f1_score
6 | from sklearn.metrics import roc_auc_score
7 | from scipy.stats import pearsonr, spearmanr
8 | from seqeval.metrics import classification_report
9 |
def compute_acc(predicts, labels):
    """Accuracy of *predicts* against gold *labels*, as a percentage."""
    accuracy = accuracy_score(labels, predicts)
    return accuracy * 100.0
12 |
def compute_f1(predicts, labels):
    """sklearn ``f1_score`` (default averaging) of *predicts* vs *labels*, as a percentage."""
    f1 = f1_score(labels, predicts)
    return f1 * 100.0
15 |
def compute_mcc(predicts, labels):
    """Matthews correlation coefficient of *predicts* vs *labels*, scaled by 100."""
    mcc = matthews_corrcoef(labels, predicts)
    return mcc * 100.0
18 |
def compute_pearson(predicts, labels):
    """Pearson correlation between *labels* and *predicts*, scaled by 100 (p-value discarded)."""
    correlation, _p_value = pearsonr(labels, predicts)
    return correlation * 100.0
22 |
def compute_spearman(predicts, labels):
    """Spearman rank correlation between *labels* and *predicts*, scaled by 100 (p-value discarded)."""
    correlation, _p_value = spearmanr(labels, predicts)
    return correlation * 100.0
26 |
def compute_auc(predicts, labels):
    """ROC AUC of the scores *predicts* against *labels*, as a percentage."""
    return roc_auc_score(labels, predicts) * 100.0
30 |
def compute_seqacc(predicts, labels, label_mapper):
    """Return a seqeval ``classification_report`` (4 digits) for tag-id sequences.

    For each (predict, label) pair, label_mapper turns ids into tag strings.
    Position 0 is dropped, positions whose gold tag is 'X' are dropped, and
    the last remaining position is dropped before scoring.
    NOTE(review): positions 0 / last presumably correspond to [CLS] / [SEP]
    and 'X' to sub-word continuations — confirm with the data pipeline.
    A sentence left empty after filtering raises IndexError on pop().
    """
    gold_seqs = []
    pred_seqs = []
    for predict, label in zip(predicts, labels):
        gold_tags = []
        pred_tags = []
        for pos, pred_id in enumerate(predict):
            if pos == 0:
                continue
            gold_tag = label_mapper[label[pos]]
            if gold_tag == 'X':
                continue
            gold_tags.append(gold_tag)
            pred_tags.append(label_mapper[pred_id])
        gold_tags.pop()
        pred_tags.pop()
        gold_seqs.append(gold_tags)
        pred_seqs.append(pred_tags)
    return classification_report(gold_seqs, pred_seqs, digits=4)
50 |
class Metric(Enum):
    """Evaluation metrics that a task can list in its ``metric_meta``.

    Members are looked up by name (``Metric[name]``) when task definitions are
    parsed, and ``member.name`` is the key under which calc_metrics() stores
    each result. No member uses the value 6.
    """
    ACC = 0
    F1 = 1
    MCC = 2
    Pearson = 3
    Spearman = 4
    AUC = 5
    SeqEval = 7
59 |
60 |
61 |
# Dispatch table used by calc_metrics(): Metric member -> scoring function.
# Every function returns its score multiplied by 100, except compute_seqacc,
# which returns a seqeval classification report.
METRIC_FUNC = {
    Metric.ACC: compute_acc,
    Metric.F1: compute_f1,
    Metric.MCC: compute_mcc,
    Metric.Pearson: compute_pearson,
    Metric.Spearman: compute_spearman,
    Metric.AUC: compute_auc,
    Metric.SeqEval: compute_seqacc
}
71 |
72 |
def calc_metrics(metric_meta, golds, predictions, scores, label_mapper=None):
    """Compute each metric in *metric_meta*; return {metric name: value}.

    Label Mapper is used for NER/POS etc.
    ACC / F1 / MCC compare *predictions* with *golds*; SeqEval does so through
    *label_mapper*; Pearson / Spearman / AUC use *scores*. AUC requires two
    scores per gold label and uses the second of each pair (presumably the
    positive-class score — confirm against the model's score layout).
    TODO: a better refactor, by xiaodl
    """
    metrics = {}
    for mm in metric_meta:
        metric_func = METRIC_FUNC[mm]
        if mm in (Metric.ACC, Metric.F1, Metric.MCC):
            metric = metric_func(predictions, golds)
        elif mm == Metric.SeqEval:
            metric = metric_func(predictions, golds, label_mapper)
        elif mm == Metric.AUC:
            assert len(scores) == 2 * len(golds), "AUC is only valid for binary classification problem"
            # Slice into a temporary instead of rebinding `scores`: the rebinding
            # leaked the halved list into every later metric in this loop.
            metric = metric_func(scores[1::2], golds)
        else:
            metric = metric_func(scores, golds)
        metrics[mm.name] = metric
    return metrics
92 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/task_def.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 |
3 | from enum import IntEnum
4 |
5 |
class TaskType(IntEnum):
    """Kind of prediction a task makes; selected by name via ``task_type`` in the task-definition YAML."""
    Classification = 1  # load_data: label -> int, or mapped through label_dict
    Regression = 2  # load_data: label -> float
    Ranking = 3  # load_data: comma-separated labels; label = index of the best one
    Span = 4
    SequenceLabeling = 5
12 |
13 |
class DataFormat(IntEnum):
    """Column layout of the canonical tab-separated data files."""
    PremiseOnly = 1  # uid, label, premise
    PremiseAndOneHypothesis = 2  # uid, label, premise, hypothesis
    PremiseAndMultiHypothesis = 3  # uid, ruid(s), label, premise, hypothesis_1 ... hypothesis_n
    Sequence = 4  # written as uid, label, premise by dump_rows; load_data rejects it
19 |
20 |
class EncoderModelType(IntEnum):
    """Shared encoder architecture, selected via ``encoder_type``; all tasks in one definition file must agree."""
    BERT = 1
    ROBERTA = 2
    XLNET = 3
25 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import random
3 | import torch
4 | import numpy
5 | import subprocess
6 |
class AverageMeter(object):
    """Keep the most recent value along with a running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as observed *n* times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
23 |
def set_environment(seed, set_cuda=False):
    """Seed Python's, NumPy's and PyTorch's RNGs with *seed* for reproducibility.

    All CUDA devices are seeded as well when one is available and *set_cuda* is true.
    """
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available() and set_cuda:
        torch.cuda.manual_seed_all(seed)
30 |
def patch_var(v, cuda=True):
    """Return ``v.cuda(non_blocking=True)`` when *cuda* is true, otherwise *v* unchanged."""
    return v.cuda(non_blocking=True) if cuda else v
35 |
def get_gpu_memory_map():
    """Return {gpu_index: memory_used} parsed from ``nvidia-smi --query-gpu=memory.used``.

    Values are the integers nvidia-smi prints with ``nounits`` (presumably MiB —
    confirm against the installed driver). Raises if nvidia-smi is missing or fails.
    """
    query = ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader']
    output = subprocess.check_output(query, encoding='utf-8')
    used_per_gpu = [int(row) for row in output.strip().split('\n')]
    return dict(enumerate(used_per_gpu))
45 |
def get_pip_env():
    """Run ``pip freeze`` and return its exit status.

    NOTE: the package list goes straight to this process's stdout; only the
    integer return code is returned, not the freeze text.
    """
    completed = subprocess.run(["pip", "freeze"])
    return completed.returncode
49 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/vocab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import tqdm
3 | import unicodedata
4 |
# Reserved tokens and their fixed ids in a non-"neat" Vocabulary.
PAD = 'PADPAD'
UNK = 'UNKUNK'
STA = 'BOSBOS'
END = 'EOSEOS'

PAD_ID = 0
UNK_ID = 1
STA_ID = 2
END_ID = 3


class Vocabulary(object):
    """Bidirectional token <-> index mapping.

    With ``neat=False`` the vocabulary starts with the four reserved tokens
    (PAD, UNK, STA, END at ids 0-3) and unknown lookups fall back to UNK.
    With ``neat=True`` it starts empty; unknown tokens map to None and
    unknown ids to -1.
    """
    INIT_LEN = 4

    def __init__(self, neat=False):
        self.neat = neat
        if not neat:
            self.tok2ind = {PAD: PAD_ID, UNK: UNK_ID, STA: STA_ID, END: END_ID}
            self.ind2tok = {PAD_ID: PAD, UNK_ID: UNK, STA_ID: STA, END_ID: END}
        else:
            self.tok2ind = {}
            self.ind2tok = {}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        """Iterate over tokens in insertion order."""
        return iter(self.tok2ind)

    def __contains__(self, key):
        """int keys are checked as ids, str keys as tokens; any other type is never contained."""
        # Exact type checks (not isinstance) so that e.g. bools are not treated as ids.
        if type(key) == int:
            return key in self.ind2tok
        if type(key) == str:
            return key in self.tok2ind
        return False

    def __getitem__(self, key):
        """Look up a token by id (int) or an id by token (str); other key types give None."""
        if type(key) == int:
            return self.ind2tok.get(key, -1) if self.neat else self.ind2tok.get(key, UNK)
        if type(key) == str:
            return self.tok2ind.get(key, None) if self.neat else self.tok2ind.get(key, self.tok2ind.get(UNK))
        return None

    def __setitem__(self, key, item):
        """Set one direction of the mapping: vocab[id] = token or vocab[token] = id."""
        if type(key) == int and type(item) == str:
            self.ind2tok[key] = item
        elif type(key) == str and type(item) == int:
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        """Append *token* with the next id (current size) unless it is already present."""
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def get_vocab_list(self, with_order=True):
        """Return tokens ordered by id, or (with_order=False) all non-reserved tokens in insertion order."""
        if with_order:
            words = [self[k] for k in range(0, len(self))]
        else:
            words = [k for k in self.tok2ind.keys()
                     if k not in {PAD, UNK, STA, END}]
        return words

    def toidx(self, tokens):
        """Map each token in *tokens* to its id (unknown tokens handled as in __getitem__)."""
        return [self[tok] for tok in tokens]

    def copy(self):
        """Return a new Vocabulary holding the same tokens.

        Tokens are re-added in insertion order, so ids are reassigned
        sequentially and match the original only when its ids were contiguous.
        """
        new_vocab = Vocabulary(self.neat)
        for w in self:
            new_vocab.add(w)
        return new_vocab

    @staticmethod
    def build(words, neat=False):
        """Create a Vocabulary containing *words*, added in iteration order.

        A staticmethod: the factory takes no instance, and without the decorator
        calling it on an instance would pass the instance as *words*.
        """
        vocab = Vocabulary(neat)
        for w in words:
            vocab.add(w)
        return vocab


# The factory is also available as a plain module-level function.
build = Vocabulary.build
81 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/data_utils/xlnet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Texar Authors. All Rights Reserved.
2 | # Copyright (c) Microsoft.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import unicodedata
21 | import six
22 | from functools import partial
23 |
24 |
SPIECE_UNDERLINE = '▁'

# XLNet's reserved SentencePiece symbols and their fixed ids. The keys had been
# reduced to empty strings, which collapsed the dict to a single entry and made
# every *_ID below equal to 8.
special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
45 |
46 |
def printable_text(text):
    """Return *text* as the native ``str`` type, suitable for print or ``tf.logging``.

    On Python 3, bytes are decoded as UTF-8 (undecodable bytes dropped); on
    Python 2, unicode is encoded to UTF-8. Raises ValueError for other types.
    """
    if six.PY3:
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        if isinstance(text, str):
            return text
        raise ValueError("Unsupported string type: %s" % (type(text)))
    if six.PY2:
        if isinstance(text, unicode):
            return text.encode("utf-8")
        if isinstance(text, str):
            return text
        raise ValueError("Unsupported string type: %s" % (type(text)))
    raise ValueError("Not running on Python2 or Python 3?")
68 |
69 |
def print_(*args):
    """print() every argument through printable_text(); list arguments are space-joined first."""
    printable = []
    for arg in args:
        if isinstance(arg, list):
            printable.append(' '.join([printable_text(item) for item in arg]))
        else:
            printable.append(printable_text(arg))
    print(*printable)
80 |
81 |
def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
    """Normalize raw text before SentencePiece encoding.

    - remove_space: strip the ends and collapse internal whitespace runs to one space.
    - LaTeX-style quotes ``...'' are turned into plain double quotes.
    - On Python 2, byte strings are decoded as UTF-8.
    - Unless keep_accents: NFKD-decompose and drop combining marks.
    - lower: lowercase the result.
    """
    text = ' '.join(inputs.strip().split()) if remove_space else inputs
    text = text.replace("``", '"').replace("''", '"')

    if six.PY2 and isinstance(text, str):
        text = text.decode('utf-8')

    if not keep_accents:
        decomposed = unicodedata.normalize('NFKD', text)
        text = ''.join([ch for ch in decomposed if not unicodedata.combining(ch)])
    if lower:
        text = text.lower()

    return text
99 |
100 |
def encode_pieces(sp_model, text, return_unicode=True, sample=False):
    """Split *text* into SentencePiece pieces, re-splitting "<digit>," pieces.

    :param sp_model: SentencePiece processor providing EncodeAsPieces (and
        SampleEncodeAsPieces when sample=True).
    :param text: input text.
    :param return_unicode: Python 2 only — convert byte-string pieces back to unicode.
    :param sample: use SampleEncodeAsPieces(text, 64, 0.1) instead of
        deterministic encoding.
    :return: list of pieces.
    """
    # return_unicode is used only for py2

    # note(zhiliny): in some systems, sentencepiece only accepts str for py2
    if six.PY2 and isinstance(text, unicode):
        text = text.encode('utf-8')

    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    new_pieces = []
    for piece in pieces:
        # A piece ending in "<digit>," (e.g. "▁1990,") is re-encoded without the
        # comma (and without word-boundary markers); the comma is then appended
        # as a separate piece.
        if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
            cur_pieces = sp_model.EncodeAsPieces(
                piece[:-1].replace(SPIECE_UNDERLINE, ''))
            # Re-encoding may introduce a leading boundary marker the original
            # piece lacked: drop the marker, or the whole piece if it is only the marker.
            if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                if len(cur_pieces[0]) == 1:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][1:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)

    # note(zhiliny): convert back to unicode for py2
    if six.PY2 and return_unicode:
        ret_pieces = []
        for piece in new_pieces:
            if isinstance(piece, str):
                piece = piece.decode('utf-8')
            ret_pieces.append(piece)
        new_pieces = ret_pieces

    return new_pieces
137 |
138 |
def encode_ids(sp_model, text, sample=False):
    """Encode *text* to SentencePiece ids via encode_pieces() (pieces kept as bytes on Python 2)."""
    return [
        sp_model.PieceToId(piece)
        for piece in encode_pieces(sp_model, text, return_unicode=False, sample=sample)
    ]
143 |
144 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/mt-bluebert/mt_bluebert/experiments/__init__.py
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/common_utils.py:
--------------------------------------------------------------------------------
1 | from data_utils import DataFormat
2 |
3 |
def _assert_no_tabs(row, cols):
    """Raise ValueError if str(row[col]) contains a tab, the output's column separator."""
    for col in cols:
        if "\t" in str(row[col]):
            raise ValueError("Field %r contains a tab character: %r" % (col, row[col]))


def dump_rows(rows, out_path, data_format):
    """
    Write rows to *out_path* as a UTF-8 tab-separated file, one row per line.

    Column order by data_format:
        PremiseOnly / Sequence: uid, label, premise
        PremiseAndOneHypothesis: uid, label, premise, hypothesis
        PremiseAndMultiHypothesis: uid, ruid, label, premise, hypothesis_1 ... hypothesis_n
    :param rows: iterable of dicts holding the keys listed above
    :param out_path: destination file (overwritten)
    :param data_format: a DataFormat member
    :raises ValueError: if a checked field contains a tab (it would shift the
        following columns), or if data_format is unsupported
    """
    with open(out_path, "w", encoding="utf-8") as out_f:
        for row in rows:
            if data_format == DataFormat.PremiseOnly:
                _assert_no_tabs(row, ["uid", "label", "premise"])
                out_f.write("%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"]))
            elif data_format == DataFormat.PremiseAndOneHypothesis:
                _assert_no_tabs(row, ["uid", "label", "premise", "hypothesis"])
                out_f.write("%s\t%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"], row["hypothesis"]))
            elif data_format == DataFormat.PremiseAndMultiHypothesis:
                _assert_no_tabs(row, ["uid", "label", "premise"])
                hypothesis = row["hypothesis"]
                for one_hypo in hypothesis:
                    if "\t" in str(one_hypo):
                        raise ValueError("A hypothesis contains a tab character: %r" % (one_hypo,))
                hypothesis = "\t".join(hypothesis)
                out_f.write("%s\t%s\t%s\t%s\t%s\n" % (row["uid"], row["ruid"], row["label"], row["premise"], hypothesis))
            elif data_format == DataFormat.Sequence:
                _assert_no_tabs(row, ["uid", "label", "premise"])
                out_f.write("%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"]))
            else:
                raise ValueError(data_format)
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/exp_def.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from mt_bluebert.data_utils.vocab import Vocabulary
3 | from mt_bluebert.data_utils.task_def import TaskType, DataFormat, EncoderModelType
4 | from mt_bluebert.data_utils.metrics import Metric
5 |
class TaskDefs:
    """Parse a task-definition YAML file into per-task lookup tables.

    Each top-level YAML key is a task name (must not contain '_') whose entry
    must provide n_class, data_format, task_type, metric_meta, enable_san and
    encoder_type, and may provide labels and dropout_p. All tasks must use the
    same encoder_type.
    """
    def __init__(self, task_def_path):
        # Close the YAML file deterministically instead of leaking the handle.
        with open(task_def_path, encoding="utf-8") as task_def_file:
            self._task_def_dic = yaml.safe_load(task_def_file)
        global_map = {}
        n_class_map = {}
        data_type_map = {}
        task_type_map = {}
        metric_meta_map = {}
        enable_san_map = {}
        dropout_p_map = {}
        uniq_encoderType = set()
        for task, task_def in self._task_def_dic.items():
            assert "_" not in task, "task name should not contain '_', current task name: %s" % task
            n_class_map[task] = task_def["n_class"]
            data_format = DataFormat[task_def["data_format"]]
            data_type_map[task] = data_format
            task_type_map[task] = TaskType[task_def["task_type"]]
            metric_meta_map[task] = tuple(Metric[metric_name] for metric_name in task_def["metric_meta"])
            enable_san_map[task] = task_def["enable_san"]
            uniq_encoderType.add(EncoderModelType[task_def["encoder_type"]])
            if "labels" in task_def:
                # Label strings -> ids in listed order (neat vocabulary: no reserved tokens).
                labels = task_def["labels"]
                label_mapper = Vocabulary(True)
                for label in labels:
                    label_mapper.add(label)
                global_map[task] = label_mapper
            if "dropout_p" in task_def:
                dropout_p_map[task] = task_def["dropout_p"]

        assert len(uniq_encoderType) == 1, 'The shared encoder has to be the same.'
        self.global_map = global_map
        self.n_class_map = n_class_map
        self.data_type_map = data_type_map
        self.task_type_map = task_type_map
        self.metric_meta_map = metric_meta_map
        self.enable_san_map = enable_san_map
        self.dropout_p_map = dropout_p_map
        self.encoderType = uniq_encoderType.pop()
45 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/squad/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/mt-bluebert/mt_bluebert/experiments/squad/__init__.py
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/squad/squad_prepro.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from sys import path
4 | import json
5 | path.append(os.getcwd())
6 | from data_utils.log_wrapper import create_logger
7 | from experiments.common_utils import dump_rows
8 |
# Module-level logger. to_disk=True makes create_logger attach a FileHandler for
# 'squad_prepro.log' (a relative path, i.e. resolved against the working
# directory) when this module is imported.
logger = create_logger(__name__, to_disk=True, log_file='squad_prepro.log')
10 |
def normalize_qa_field(s: str, replacement_list):
    """Replace every occurrence of each string in *replacement_list* with as many spaces.

    The text length is unchanged, so character offsets such as answer_start /
    answer_end stay valid.
    """
    for replacement in replacement_list:
        s = s.replace(replacement, " " * len(replacement))  # ensure answer_start and answer_end still valid
    return s

END = 'EOSEOS'


def load_data(path, is_train=True, v2_on=False):
    """Flatten a SQuAD-format JSON file into canonical rows.

    :param path: SQuAD v1.1 / v2.0 JSON file (UTF-8)
    :param is_train: accepted for API compatibility; currently unused
    :param v2_on: SQuAD 2.0 mode — END is appended to every context so that
        unanswerable questions get a fake answer pointing at it
    :return: list of dicts with keys "uid", "premise" (context), "hypothesis"
        (question) and "label" = "answer_start:::answer_end:::is_impossible:::answer"
    Questions without answers are skipped unless v2_on and is_impossible.
    """
    rows = []
    with open(path, encoding="utf8") as f:
        data = json.load(f)['data']
    for article in data:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            if v2_on:
                context = '{} {}'.format(context, END)
            # Normalize once per paragraph rather than once per question: the
            # replacement is length-preserving and idempotent, so results match.
            context = normalize_qa_field(context, ["\n", "\t"])
            for qa in paragraph['qas']:
                uid, question = qa['id'], qa['question']
                answers = qa.get('answers', [])
                # used for v2.0
                is_impossible = qa.get('is_impossible', False)
                label = 1 if is_impossible else 0
                if (v2_on and label < 1 and len(answers) < 1) or ((not v2_on) and len(answers) < 1):
                    # detect inconsistent data
                    # * for v2, the row is possible but has no answer
                    # * for v1, all questions should have answer
                    continue
                if len(answers) > 0:
                    answer = answers[0]['text']
                    answer_start = answers[0]['answer_start']
                    answer_end = answer_start + len(answer)
                else:
                    # for questions without answers, give a fake answer
                    answer = END
                    answer_start = len(context) - len(END)
                    answer_end = len(context)
                answer = normalize_qa_field(answer, ["\n", "\t", ":::"])
                question = normalize_qa_field(question, ["\n", "\t"])
                sample = {'uid': uid, 'premise': context, 'hypothesis': question,
                          'label': "%s:::%s:::%s:::%s" % (answer_start, answer_end, label, answer)}
                rows.append(sample)
    return rows
53 |
def parse_args():
    """Parse the command line; ``--root_dir`` (default 'data') is the only option."""
    arg_parser = argparse.ArgumentParser(description='Preprocessing SQUAD data.')
    arg_parser.add_argument('--root_dir', type=str, default='data')
    return arg_parser.parse_args()
59 |
def main(args):
    """Convert SQuAD v1.1 and v2.0 train/dev JSON under ``args.root_dir`` to canonical TSV.

    Reads <root>/squad/{train,dev}.json and <root>/squad_v2/{train,dev}.json and
    writes squad_{train,dev}.tsv and squad-v2_{train,dev}.tsv into
    <root>/canonical_data (created if missing).
    """
    # dump_rows requires an explicit data_format; import it the same way
    # experiments/common_utils.py does so the enum class is shared.
    from data_utils import DataFormat

    root = args.root_dir
    assert os.path.exists(root)

    squad_train_path = os.path.join(root, 'squad/train.json')
    squad_dev_path = os.path.join(root, 'squad/dev.json')
    squad_v2_train_path = os.path.join(root, 'squad_v2/train.json')
    squad_v2_dev_path = os.path.join(root, 'squad_v2/dev.json')

    squad_train_data = load_data(squad_train_path)
    squad_dev_data = load_data(squad_dev_path, is_train=False)
    logger.info('Loaded {} squad train samples'.format(len(squad_train_data)))
    logger.info('Loaded {} squad dev samples'.format(len(squad_dev_data)))

    squad_v2_train_data = load_data(squad_v2_train_path, v2_on=True)
    squad_v2_dev_data = load_data(squad_v2_dev_path, is_train=False, v2_on=True)
    logger.info('Loaded {} squad_v2 train samples'.format(len(squad_v2_train_data)))
    logger.info('Loaded {} squad_v2 dev samples'.format(len(squad_v2_dev_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    os.makedirs(canonical_data_root, exist_ok=True)

    # Rows carry uid / premise / hypothesis / label, and squad_task_def.yml
    # declares both tasks as PremiseAndOneHypothesis. Omitting this argument
    # made every dump_rows call fail with TypeError.
    data_format = DataFormat.PremiseAndOneHypothesis

    squad_train_fout = os.path.join(canonical_data_root, 'squad_train.tsv')
    squad_dev_fout = os.path.join(canonical_data_root, 'squad_dev.tsv')
    dump_rows(squad_train_data, squad_train_fout, data_format)
    dump_rows(squad_dev_data, squad_dev_fout, data_format)
    logger.info('done with squad')

    squad_v2_train_fout = os.path.join(canonical_data_root, 'squad-v2_train.tsv')
    squad_v2_dev_fout = os.path.join(canonical_data_root, 'squad-v2_dev.tsv')
    dump_rows(squad_v2_train_data, squad_v2_train_fout, data_format)
    dump_rows(squad_v2_dev_data, squad_v2_dev_fout, data_format)
    logger.info('done with squad_v2')
95 |
96 |
97 |
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the JSON -> TSV conversion.
    args = parse_args()
    main(args)
101 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/squad/squad_task_def.yml:
--------------------------------------------------------------------------------
1 | squad-v2:
2 | data_format: PremiseAndOneHypothesis
3 | encoder_type: BERT
4 | dropout_p: 0.05
5 | enable_san: false
6 | metric_meta:
7 | - ACC
8 | - MCC
9 | n_class: 1
10 | task_type: Span
11 | split_names:
12 | - train
13 | - dev
14 | squad:
15 | data_format: PremiseAndOneHypothesis
16 | encoder_type: BERT
17 | dropout_p: 0.05
18 | enable_san: false
19 | metric_meta:
20 | - ACC
21 | - MCC
22 | n_class: 1
23 | task_type: Span
24 | split_names:
25 | - train
26 | - dev
27 |
28 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/squad/squad_utils.py:
--------------------------------------------------------------------------------
1 | from data_utils.task_def import EncoderModelType
2 |
3 |
def calc_tokenized_span_range(context, question, answer, answer_start, answer_end, tokenizer, encoderModelType,
                              verbose=False):
    """Map a character-level answer span onto word-piece token indices.

    The context is tokenized twice: once up to ``answer_start`` and once up to
    ``answer_end``. The lengths of those two token lists give the start
    (inclusive) and end (exclusive) token positions of the answer.

    :param context: passage text.
    :param question: question text (only used in the verbose diagnostic).
    :param answer: answer text (only used in the verbose consistency check).
    :param answer_start: character offset where the answer begins in ``context``.
    :param answer_end: character offset just past the answer in ``context``.
    :param tokenizer: tokenizer exposing ``tokenize`` and ``basic_tokenizer.tokenize``.
    :param encoderModelType: must be ``EncoderModelType.BERT``.
    :param verbose: when true, print a diagnostic if the recovered token span
        does not reproduce the basic-tokenized answer.
    :return: span_start, span_end
    """
    assert encoderModelType == EncoderModelType.BERT
    prefix_tokens = tokenizer.tokenize(context[:answer_start])
    full_tokens = tokenizer.tokenize(context[:answer_end])
    span_start = len(prefix_tokens)
    span_end = len(full_tokens)
    if verbose:
        # Only needed for the diagnostic, so skip this extra tokenization otherwise.
        span_tokens = full_tokens[span_start:span_end]
        recovered_answer = " ".join(span_tokens).replace(" ##", "")
        cleaned_answer = " ".join(tokenizer.basic_tokenizer.tokenize(answer))
        # Explicit comparison instead of ``assert`` so the check still runs under ``python -O``.
        if recovered_answer != cleaned_answer:
            # Clamp at 0: a negative slice start would wrap around to the end of the context.
            window_start = max(answer_start - 5, 0)
            print("answer: %s, recovered_answer: %s, question: %s, select:%s ext_select:%s context: %s" % (
                cleaned_answer, recovered_answer, question, context[answer_start:answer_end],
                context[window_start:answer_end + 5], context))
    return span_start, span_end
37 |
38 |
def parse_squad_label(label):
    """Split a canonical SQuAD label string into its fields.

    Labels look like ``"answer_start:::answer_end:::is_impossible:::answer"``.
    At most three splits are made, so an answer that itself contains
    ``":::"`` is returned intact instead of breaking the unpacking.

    :param label: the encoded label string.
    :return: answer_start, answer_end, answer, is_impossible
        (``answer`` is a str; the other three are ints)
    """
    answer_start, answer_end, is_impossible, answer = label.split(":::", 3)
    return int(answer_start), int(answer_end), answer, int(is_impossible)
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/experiments/squad/verify_calc_span.py:
--------------------------------------------------------------------------------
1 | from pytorch_pretrained_bert import BertTokenizer
2 | from data_utils.task_def import EncoderModelType
3 | from experiments.squad.squad_utils import calc_tokenized_span_range, parse_squad_label
4 |
# Sanity-check script: re-derive the token span of every SQuAD v2 training row
# and let calc_tokenized_span_range print rows whose span does not round-trip.
model = "bert-base-uncased"
do_lower_case = True
tokenizer = BertTokenizer.from_pretrained(model, do_lower_case=do_lower_case)

# squad_prepro.py writes 'squad-v2_train.tsv' (hyphen, not underscore). Forward
# slashes work on both POSIX and Windows, unlike the old backslash-only path.
# The ``with`` block closes the file when the loop finishes.
with open("data/canonical_data/squad-v2_train.tsv", encoding="utf-8") as reader:
    for no, line in enumerate(reader):
        if no % 1000 == 0:
            print(no)
        uid, label, context, question = line.strip().split("\t")
        answer_start, answer_end, answer, is_impossible = parse_squad_label(label)
        calc_tokenized_span_range(context, question, answer, answer_start, answer_end, tokenizer, EncoderModelType.BERT,
                                  verbose=True)
16 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/mt-bluebert/mt_bluebert/module/__init__.py
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/module/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import torch
3 | import math
4 | from torch.nn.functional import tanh, relu, prelu, leaky_relu, sigmoid, elu, selu
5 | from torch.nn.init import uniform, normal, eye, xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal, orthogonal
6 |
def linear(x):
    """Identity activation: hand the input back untouched."""
    result = x
    return result
9 |
def swish(x):
    """Swish activation: the input scaled by its own sigmoid, ``x * sigmoid(x)``."""
    gate = sigmoid(x)
    return x * gate
12 |
def bertgelu(x):
    """GELU in its exact erf form (BERT variant): ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    cdf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return 0.5 * x * cdf
15 |
def gptgelu(x):
    """GELU via the tanh approximation (GPT variant)."""
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
18 |
# Default gelu: alias of the erf-based ``bertgelu`` (``gptgelu`` is the tanh approximation).
gelu = bertgelu
21 |
def activation(func_a):
    """Activation function wrapper: resolve an activation from its name.

    ``func_a`` is evaluated as a Python expression in this module's namespace,
    so names such as ``'relu'``, ``'tanh'``, ``'gelu'`` or ``'swish'`` resolve
    to the functions imported or defined here. Any expression that fails to
    evaluate falls back to ``linear`` (identity).

    NOTE(review): ``eval`` executes arbitrary code -- only pass trusted
    configuration strings here.
    """
    try:
        f = eval(func_a)
    except Exception:
        # Narrowed from a bare ``except``: unknown names still fall back to the
        # identity, but KeyboardInterrupt / SystemExit are no longer swallowed.
        f = linear
    return f
30 |
def init_wrapper(init='xavier_uniform'):
    """Return the initializer named by ``init``.

    The string is evaluated in this module's namespace, so it resolves the
    ``torch.nn.init`` functions imported at the top of the file (for example
    ``'xavier_uniform'``, ``'kaiming_normal'``, ``'orthogonal'``). Errors from
    ``eval`` propagate unchanged.

    NOTE(review): ``eval`` executes arbitrary code -- only pass trusted strings.
    """
    initializer = eval(init)
    return initializer
33 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/module/dropout_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
class DropoutWrapper(nn.Module):
    """Dropout wrapper that supports fixed-mask (variational) dropout.

    With variational dropout enabled, a 3-D input gets one Bernoulli mask per
    (batch, feature) pair, reused at every position along dim 1.
    ref: https://discuss.pytorch.org/t/dropout-for-rnns/633/11
    """
    def __init__(self, dropout_p=0, enable_vbp=True):
        """
        :param dropout_p: probability of dropping an element (or a feature channel).
        :param enable_vbp: use the fixed-mask (variational) dropout for 3-D inputs.
            When False, plain ``F.dropout`` is used for every input rank.
        """
        super(DropoutWrapper, self).__init__()
        self.enable_variational_dropout = enable_vbp
        self.dropout_p = dropout_p

    def forward(self, x):
        """
        :param x: batch * len * input_size for the variational path; any shape otherwise.
        :return: ``x`` itself in eval mode or when ``dropout_p == 0``; otherwise the
            dropped-out tensor, scaled by ``1 / (1 - dropout_p)``.
        """
        if not self.training or self.dropout_p == 0:
            return x

        # Bug fix: ``enable_vbp`` used to be stored but ignored, so 3-D inputs
        # always got the variational mask.
        if self.enable_variational_dropout and x.dim() == 3:
            keep_p = 1 - self.dropout_p
            # One draw per (batch, feature), broadcast over dim 1. new_full() gives
            # the mask x's dtype/device and no autograd history.
            mask = 1.0 / keep_p * torch.bernoulli(x.new_full((x.size(0), x.size(2)), keep_p))
            return mask.unsqueeze(1).expand_as(x) * x
        return F.dropout(x, p=self.dropout_p, training=self.training)
31 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/module/my_optim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | from copy import deepcopy
3 | import torch
4 | from torch.nn import Parameter
5 | from functools import wraps
6 |
class EMA:
    """Exponential moving average ("shadow" copy) of a model's trainable parameters.

    After each ``update()`` the shadow of every parameter with ``requires_grad``
    becomes ``(1 - gamma) * param + gamma * shadow``.
    """
    def __init__(self, gamma, model):
        """
        :param gamma: weight kept on the previous shadow value at each update.
        :param model: module whose ``named_parameters()`` are tracked.
        """
        super(EMA, self).__init__()
        self.gamma = gamma
        self.shadow = {}
        self.model = model
        self.setup()

    def setup(self):
        """Initialise the shadow dict with detached copies of the trainable parameters."""
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                # detach(): the shadow must not be part of the autograd graph.
                self.shadow[name] = para.detach().clone()

    def cuda(self):
        """Move every shadow tensor to the current CUDA device."""
        for k, v in self.shadow.items():
            self.shadow[k] = v.cuda()

    def update(self):
        """Blend the current parameter values into the shadow copies."""
        # Without no_grad, each update adds an autograd node that references the
        # previous shadow, so the graph (and memory) grows with every step.
        with torch.no_grad():
            for name, para in self.model.named_parameters():
                if para.requires_grad:
                    self.shadow[name] = (1.0 - self.gamma) * para + self.gamma * self.shadow[name]

    def swap_parameters(self):
        """Exchange parameter data with the shadow data (call twice to restore)."""
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                temp_data = para.data
                para.data = self.shadow[name].data
                self.shadow[name].data = temp_data

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow
37 |
38 |
39 | # Adapted from
40 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
41 | # and https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py
42 | def _norm(p, dim):
43 | """Computes the norm over all dimensions except dim"""
44 | if dim is None:
45 | return p.norm()
46 | elif dim == 0:
47 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
48 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
49 | elif dim == p.dim() - 1:
50 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
51 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
52 | else:
53 | return _norm(p.transpose(0, dim), 0).transpose(0, dim)
54 |
55 |
56 | def _dummy(*args, **kwargs):
57 | # We need to replace flatten_parameters with a nothing function
58 | return
59 |
60 |
class WeightNorm(torch.nn.Module):
    """Forward pre-hook implementing weight normalisation ``w = v * g / ||v||``.

    ``apply`` replaces each named weight parameter ``w`` of a module with a
    magnitude parameter ``w_g`` and a direction parameter ``w_v``. Before every
    forward pass the hook recomputes ``w`` from them. ``remove`` undoes this.
    """

    def __init__(self, weights, dim):
        """
        :param weights: names of the weight parameters to reparameterise.
        :param dim: dimension kept when computing the norm (``None`` = whole tensor).
        """
        super(WeightNorm, self).__init__()
        self.weights = weights
        self.dim = dim

    def compute_weight(self, module, name):
        """Return ``v * (g / ||v||)`` for weight ``name`` of ``module``."""
        g = getattr(module, name + '_g')
        v = getattr(module, name + '_v')
        return v * (g / _norm(v, self.dim))

    @staticmethod
    def apply(module, weights, dim):
        """Reparameterise ``weights`` of ``module`` and register the hook.

        :param weights: parameter names, or ``None`` for every parameter whose
            name contains ``'weight'``.
        :return: the ``WeightNorm`` hook instance.
        """
        # Terrible temporary solution to an issue regarding compacting weights
        # re: CUDNN RNN
        if issubclass(type(module), torch.nn.RNNBase):
            module.flatten_parameters = _dummy
        if weights is None:  # do for all weight params
            weights = [w for w in module._parameters.keys() if 'weight' in w]
        fn = WeightNorm(weights, dim)
        for name in weights:
            if hasattr(module, name):
                print('Applying weight norm to {} - {}'.format(str(module), name))
                weight = getattr(module, name)
                del module._parameters[name]
                module.register_parameter(
                    name + '_g', Parameter(_norm(weight, dim).data))
                module.register_parameter(name + '_v', Parameter(weight.data))
                # Plain (non-Parameter) attribute, refreshed by __call__ before each forward.
                setattr(module, name, fn.compute_weight(module, name))

        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        """Undo ``apply``: restore plain weight parameters and detach this hook."""
        for name in self.weights:
            # Bug fix: compute_weight needs the weight name (previously a TypeError).
            weight = self.compute_weight(module, name)
            delattr(module, name)
            del module._parameters[name + '_g']
            del module._parameters[name + '_v']
            module.register_parameter(name, Parameter(weight.data))
        # Detach the pre-hook. Otherwise the next forward would look up the
        # deleted ``*_g`` / ``*_v`` parameters and fail.
        for key, hook in list(module._forward_pre_hooks.items()):
            if hook is self:
                del module._forward_pre_hooks[key]

    def __call__(self, module, inputs):
        """Forward pre-hook: recompute every normalised weight from its g/v parameters."""
        for name in self.weights:
            setattr(module, name, self.compute_weight(module, name))
107 |
108 |
def weight_norm(module, weights=None, dim=0):
    """Apply :class:`WeightNorm` to ``module`` in place and return the same module.

    :param module: module whose weight parameters are reparameterised.
    :param weights: names of the parameters to normalise. ``None`` selects every
        parameter whose name contains ``'weight'``.
    :param dim: dimension kept when computing the norm.
    :return: ``module`` itself, so calls can be chained.
    """
    WeightNorm.apply(module, weights=weights, dim=dim)
    return module
112 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/module/san.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import torch
3 | import random
4 | import torch.nn as nn
5 | from torch.nn.utils import weight_norm
6 | from torch.nn.parameter import Parameter
7 | import torch.nn.functional as F
8 | from mt_bluebert.module.dropout_wrapper import DropoutWrapper
9 | from mt_bluebert.module.similarity import FlatSimilarityWrapper, SelfAttnWrapper
10 | from mt_bluebert.module.my_optim import weight_norm as WN
11 |
# Tiny positive constant (1e-30).
# NOTE(review): not referenced anywhere in this module; presumably meant as an
# epsilon guard (e.g. before the torch.log in SANClassifier.forward). Confirm or remove.
SMALL_POS_NUM=1.0e-30
13 |
def generate_mask(new_data, dropout_p=0.0, is_training=False):
    """Sample a per-turn dropout mask for the SAN answer module.

    ``new_data`` (rows x turns) only provides shape, dtype and device. Note that
    it is zeroed in place. Each entry is kept with probability ``1 - dropout_p``
    (always kept outside training), except that one randomly chosen entry per
    row is always kept. Kept entries are scaled by ``1 / (1 - dropout_p)``.
    """
    p = dropout_p if is_training else 0.0
    keep = 1 - p
    probs = keep * (new_data.zero_() + 1)
    n_rows, n_turns = probs.size(0), probs.size(1)
    for row in range(n_rows):
        # Probability 1 at one random position, so every row keeps at least one turn.
        probs[row][random.randint(0, n_turns - 1)] = 1
    mask = 1.0 / keep * torch.bernoulli(probs)
    mask.requires_grad = False
    return mask
23 |
24 |
class Classifier(nn.Module):
    """Linear scoring head over a merged pair of vectors.

    Options are read from ``opt`` under ``'<prefix>_...'`` keys:
    ``dropout_p``, ``merge_opt`` (1 = concatenate ``[x1, x2, |x1 - x2|, x1 * x2]``,
    otherwise ``[x1, x2]``) and ``weight_norm_on``.
    """
    def __init__(self, x_size, y_size, opt, prefix='decoder', dropout=None):
        super(Classifier, self).__init__()
        self.opt = opt
        self.dropout = (DropoutWrapper(opt.get('{}_dropout_p'.format(prefix), 0))
                        if dropout is None else dropout)
        self.merge_opt = opt.get('{}_merge_opt'.format(prefix), 0)
        self.weight_norm_on = opt.get('{}_weight_norm_on'.format(prefix), False)

        n_parts = 4 if self.merge_opt == 1 else 2
        projection = nn.Linear(x_size * n_parts, y_size)
        self.proj = weight_norm(projection) if self.weight_norm_on else projection

    def forward(self, x1, x2, mask=None):
        """Score the pair; ``mask`` is accepted for interface compatibility and ignored."""
        parts = [x1, x2]
        if self.merge_opt == 1:
            parts += [(x1 - x2).abs(), x1 * x2]
        merged = self.dropout(torch.cat(parts, 1))
        return self.proj(merged)
52 |
class SANClassifier(nn.Module):
    """Implementation of Stochastic Answer Networks for Natural Language Inference, Xiaodong Liu, Kevin Duh and Jianfeng Gao
    https://arxiv.org/abs/1804.07888

    Runs ``num_turn`` reasoning steps. Each turn attends over ``x`` with the
    current state, scores ``(x_sum, h0)`` with a :class:`Classifier`, then
    advances the state with an RNN cell.
    """
    def __init__(self, x_size, h_size, label_size, opt=None, prefix='decoder', dropout=None):
        super(SANClassifier, self).__init__()
        # ``None`` instead of a shared mutable ``{}`` default.
        opt = {} if opt is None else opt
        self.prefix = prefix
        if dropout is None:
            # Bug fix: this used to read ``self.prefix`` before it was assigned,
            # raising AttributeError whenever no dropout module was passed in.
            self.dropout = DropoutWrapper(opt.get('{}_dropout_p'.format(prefix), 0))
        else:
            self.dropout = dropout
        self.query_wsum = SelfAttnWrapper(x_size, prefix='mem_cum', opt=opt, dropout=self.dropout)
        self.attn = FlatSimilarityWrapper(x_size, h_size, prefix, opt, self.dropout)
        # e.g. 'gru' -> nn.GRUCell, 'lstm' -> nn.LSTMCell
        self.rnn_type = '{}{}'.format(opt.get('{}_rnn_type'.format(prefix), 'gru').upper(), 'Cell')
        self.rnn = getattr(nn, self.rnn_type)(x_size, h_size)
        self.num_turn = opt.get('{}_num_turn'.format(prefix), 5)
        self.opt = opt
        self.mem_random_drop = opt.get('{}_mem_drop_p'.format(prefix), 0)
        self.mem_type = opt.get('{}_mem_type'.format(prefix), 0)
        self.weight_norm_on = opt.get('{}_weight_norm_on'.format(prefix), False)
        self.label_size = label_size
        self.dump_state = opt.get('dump_state_on', False)
        # Never trained; only used as a dtype/device template for generate_mask.
        self.alpha = Parameter(torch.zeros(1, 1), requires_grad=False)
        if self.weight_norm_on:
            self.rnn = WN(self.rnn)

        self.classifier = Classifier(x_size, self.label_size, opt, prefix=prefix, dropout=self.dropout)

    def forward(self, x, h0, x_mask=None, h_mask=None):
        """
        :param x: memory tensor; the ``torch.bmm`` below requires it to be 3-D,
            presumably (batch, len, x_size).
        :param h0: query memory, summarised into the initial state by ``query_wsum``.
        :param x_mask: mask for ``x``, passed to the attention.
        :param h_mask: mask for ``h0``, passed to ``query_wsum``.
        :return: scores, or ``(scores, scores_list)`` when ``dump_state_on`` is set.
            With ``mem_type == 1`` the scores are the log of the (turn-dropout-masked)
            mean of per-turn softmaxes; otherwise the raw scores of the last turn.
        """
        h0 = self.query_wsum(h0, h_mask)
        if type(self.rnn) is nn.LSTMCell:
            c0 = h0.new(h0.size()).zero_()
        scores_list = []
        for turn in range(self.num_turn):
            att_scores = self.attn(x, h0, x_mask)
            x_sum = torch.bmm(F.softmax(att_scores, 1).unsqueeze(1), x).squeeze(1)
            scores = self.classifier(x_sum, h0)
            scores_list.append(scores)
            # next turn
            if self.rnn is not None:
                h0 = self.dropout(h0)
                if type(self.rnn) is nn.LSTMCell:
                    h0, c0 = self.rnn(x_sum, (h0, c0))
                else:
                    h0 = self.rnn(x_sum, h0)
        if self.mem_type == 1:
            mask = generate_mask(self.alpha.data.new(x.size(0), self.num_turn), self.mem_random_drop, self.training)
            mask = [m.contiguous() for m in torch.unbind(mask, 1)]
            tmp_scores_list = [mask[idx].view(x.size(0), 1).expand_as(inp) * F.softmax(inp, 1) for idx, inp in enumerate(scores_list)]
            scores = torch.stack(tmp_scores_list, 2)
            scores = torch.mean(scores, 2)
            scores = torch.log(scores)
        else:
            scores = scores_list[-1]
        if self.dump_state:
            return scores, scores_list
        else:
            return scores
111 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/module/sub_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.parameter import Parameter
6 |
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis of a 3-D input.

    Computes ``(x - mean) / (std + eps) * alpha + beta``. Mean and std are taken
    over dim 2 (``torch.std``'s default estimator), and ``alpha`` / ``beta`` are
    learned per-feature gain and bias.
    ref: https://github.com/pytorch/pytorch/issues/1959
         https://arxiv.org/pdf/1607.06450.pdf
    """
    def __init__(self, hidden_size, eps=1e-4):
        super(LayerNorm, self).__init__()
        shape = (1, 1, hidden_size)
        self.alpha = Parameter(torch.ones(*shape))   # gain g
        self.beta = Parameter(torch.zeros(*shape))   # bias b
        self.eps = eps

    def forward(self, x):
        """
        :param x: batch * len * input_size
        :return: normalised tensor of the same shape
        """
        centred = x - torch.mean(x, 2, keepdim=True).expand_as(x)
        scale = torch.std(x, 2, keepdim=True).expand_as(x) + self.eps
        return centred / scale * self.alpha.expand_as(x) + self.beta.expand_as(x)
27 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/mt_dnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/mt-bluebert/mt_bluebert/mt_dnn/__init__.py
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/mt_dnn/batcher.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Microsoft. All rights reserved.
3 | import sys
4 | import json
5 | import torch
6 | import random
7 | from shutil import copyfile
8 | from mt_bluebert.data_utils.task_def import TaskType, DataFormat
9 | from mt_bluebert.data_utils.task_def import EncoderModelType
10 |
# Token id substituted for word dropout in BatchGen.__random_select__
# (presumably the BERT vocab's [UNK] id -- confirm against the vocab in use).
UNK_ID=100
# NOTE(review): not referenced in this module; presumably the BERT [CLS] id -- confirm.
BOS_ID=101
13 |
class BatchGen:
    """Iterate over pre-tokenised samples in mini-batches of model-ready tensors.

    Each iteration yields ``(batch_info, batch_data)``. ``batch_data`` is a list
    of tensors, and ``batch_info`` maps logical names (``'token_id'``,
    ``'segment_id'``, ``'mask'``, ``'label'``, ...) to positions in that list,
    plus metadata such as ``task_id``, ``task_type`` and ``uids``.
    """
    def __init__(self, data, batch_size=32, gpu=True, is_train=True,
                 maxlen=128, dropout_w=0.005,
                 do_batch=True, weighted_on=False,
                 task_id=0,
                 task=None,
                 task_type=TaskType.Classification,
                 data_type=DataFormat.PremiseOnly,
                 soft_label=False,
                 encoder_type=EncoderModelType.BERT):
        """
        :param data: list of sample dicts (e.g. from ``BatchGen.load``).
        :param batch_size: samples per batch.
        :param gpu: move input tensors to CUDA in ``__iter__``.
        :param is_train: shuffle, apply word dropout, and build label tensors.
        :param maxlen: stored only; length filtering happens in ``load``.
        :param dropout_w: probability of replacing a token id with ``UNK_ID`` in training.
        :param do_batch: split ``data`` into consecutive batches of ``batch_size``.
        :param weighted_on: stored; not used inside this class.
        :param task_id: id reported in ``batch_info['task_id']``.
        :param task: unused.
        :param task_type: ``TaskType`` that selects how labels are built.
        :param data_type: ``DataFormat``; pair formats also get premise/hypothesis masks.
        :param soft_label: attach ``softlabel`` targets (knowledge distillation) when present.
        :param encoder_type: ``EncoderModelType``; RoBERTa pads token ids with 1 instead of 0.
        """
        self.batch_size = batch_size
        self.maxlen = maxlen
        self.is_train = is_train
        self.gpu = gpu
        self.weighted_on = weighted_on
        self.data = data
        self.task_id = task_id
        self.pairwise_size = 1
        self.data_type = data_type
        self.task_type = task_type
        self.encoder_type = encoder_type
        # soft label used for knowledge distillation
        self.soft_label_on = soft_label
        if do_batch:
            if is_train:
                # Shuffle samples once before grouping them into batches.
                indices = list(range(len(self.data)))
                random.shuffle(indices)
                data = [self.data[i] for i in indices]
            self.data = BatchGen.make_baches(data, batch_size)
        self.offset = 0
        self.dropout_w = dropout_w

    @staticmethod
    def make_baches(data, batch_size=32):
        """Split ``data`` into consecutive chunks of ``batch_size`` (the last may be shorter)."""
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    @staticmethod
    def load(path, is_train=True, maxlen=128, factor=1.0, task_type=None):
        """Read a JSON-lines file of samples.

        Every sample gets ``sample['factor'] = factor``. In training, samples
        longer than ``maxlen`` tokens are dropped; for ranking tasks, a sample
        is dropped if either of its first two candidates is too long.
        """
        assert task_type is not None
        data = []
        cnt = 0
        with open(path, 'r', encoding='utf-8') as reader:
            for line in reader:
                sample = json.loads(line)
                sample['factor'] = factor
                cnt += 1
                if is_train:
                    if (task_type == TaskType.Ranking) and (len(sample['token_id'][0]) > maxlen or len(sample['token_id'][1]) > maxlen):
                        continue
                    if (task_type != TaskType.Ranking) and (len(sample['token_id']) > maxlen):
                        continue
                data.append(sample)
        print('Loaded {} samples out of {}'.format(len(data), cnt))
        return data

    def reset(self):
        """Rewind to the first batch, reshuffling batch order in training mode."""
        if self.is_train:
            indices = list(range(len(self.data)))
            random.shuffle(indices)
            self.data = [self.data[i] for i in indices]
        self.offset = 0

    def __random_select__(self, arr):
        """Word dropout: replace each id with ``UNK_ID`` with probability ``dropout_w``."""
        if self.dropout_w > 0:
            return [UNK_ID if random.uniform(0, 1) < self.dropout_w else e for e in arr]
        else:
            return arr

    def __len__(self):
        return len(self.data)

    def patch(self, v):
        """Copy ``v`` to the current CUDA device (non-blocking)."""
        v = v.cuda(non_blocking=True)
        return v

    @staticmethod
    def todevice(v, device):
        v = v.to(device)
        return v

    def rebacth(self, batch):
        """Flatten ranking samples into one row per candidate.

        Each sample holds ``size`` candidate sequences. Every emitted row copies
        the sample-level ``label`` and stores the candidate's ``olabel`` as
        ``true_label``. The candidate count is recorded in ``self.pairwise_size``.
        """
        newbatch = []
        for sample in batch:
            size = len(sample['token_id'])
            self.pairwise_size = size
            assert size == len(sample['type_id'])
            for idx in range(0, size):
                token_id = sample['token_id'][idx]
                type_id = sample['type_id'][idx]
                uid = sample['ruid'][idx]
                olab = sample['olabel'][idx]
                newbatch.append({'uid': uid, 'token_id': token_id, 'type_id': type_id, 'label': sample['label'], 'true_label': olab})
        return newbatch

    def __if_pair__(self, data_type):
        return data_type in [DataFormat.PremiseAndOneHypothesis, DataFormat.PremiseAndMultiHypothesis]

    def __iter__(self):
        while self.offset < len(self):
            batch = self.data[self.offset]
            if self.task_type == TaskType.Ranking:
                batch = self.rebacth(batch)

            # prepare model input
            batch_data, batch_info = self._prepare_model_input(batch)
            batch_info['task_id'] = self.task_id  # used for select correct decoding head
            batch_info['input_len'] = len(batch_data)  # used to select model inputs
            # select different loss function and other difference in training and testing
            batch_info['task_type'] = self.task_type
            batch_info['pairwise_size'] = self.pairwise_size  # need for ranking task
            if self.gpu:
                for i, item in enumerate(batch_data):
                    batch_data[i] = self.patch(item.pin_memory())

            # add label. Label tensors are appended after the GPU transfer
            # above, so they stay on the CPU here.
            labels = [sample['label'] for sample in batch]
            if self.is_train:
                # in training model, label is used by Pytorch, so would be tensor
                if self.task_type == TaskType.Regression:
                    batch_data.append(torch.FloatTensor(labels))
                    batch_info['label'] = len(batch_data) - 1
                elif self.task_type in (TaskType.Classification, TaskType.Ranking):
                    batch_data.append(torch.LongTensor(labels))
                    batch_info['label'] = len(batch_data) - 1
                elif self.task_type == TaskType.Span:
                    start = [sample['token_start'] for sample in batch]
                    end = [sample['token_end'] for sample in batch]
                    batch_data.extend([torch.LongTensor(start), torch.LongTensor(end)])
                    batch_info['start'] = len(batch_data) - 2
                    batch_info['end'] = len(batch_data) - 1
                elif self.task_type == TaskType.SequenceLabeling:
                    batch_size = self._get_batch_size(batch)
                    tok_len = self._get_max_len(batch, key='token_id')
                    # -1 marks padded positions with no label.
                    tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)
                    for i, label in enumerate(labels):
                        ll = len(label)
                        tlab[i, : ll] = torch.LongTensor(label)
                    batch_data.append(tlab)
                    batch_info['label'] = len(batch_data) - 1

                # soft label generated by ensemble models for knowledge distillation
                if self.soft_label_on and (batch[0].get('softlabel', None) is not None):
                    assert self.task_type != TaskType.Span  # Span task doesn't support soft label yet.
                    sortlabels = [sample['softlabel'] for sample in batch]
                    sortlabels = torch.FloatTensor(sortlabels)
                    batch_info['soft_label'] = self.patch(sortlabels.pin_memory()) if self.gpu else sortlabels
            else:
                # in test model, label would be used for evaluation
                batch_info['label'] = labels
                if self.task_type == TaskType.Ranking:
                    batch_info['true_label'] = [sample['true_label'] for sample in batch]

            batch_info['uids'] = [sample['uid'] for sample in batch]  # used in scoring
            self.offset += 1
            yield batch_info, batch_data

    def _get_max_len(self, batch, key='token_id'):
        tok_len = max(len(x[key]) for x in batch)
        return tok_len

    def _get_batch_size(self, batch):
        return len(batch)

    def _prepare_model_input(self, batch):
        """Pad a batch into ``token_ids``, ``type_ids`` and ``masks`` tensors.

        ``masks`` is 1 on real tokens and 0 on padding. For pair data formats,
        two byte masks are added:
        - ``hypothesis_masks``: 0 on each sample's first ``hlen`` positions
          (its type_id == 0 tokens), 1 elsewhere.
        - ``premise_masks``: 0 on positions ``hlen .. len-1``, 1 elsewhere
          (padding included).
        :return: ``(batch_data, batch_info)``
        """
        batch_size = self._get_batch_size(batch)
        tok_len = self._get_max_len(batch, key='token_id')
        hypothesis_len = max(len(x['type_id']) - sum(x['type_id']) for x in batch)
        if self.encoder_type == EncoderModelType.ROBERTA:
            token_ids = torch.LongTensor(batch_size, tok_len).fill_(1)
            type_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
            masks = torch.LongTensor(batch_size, tok_len).fill_(0)
        else:
            token_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
            type_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
            masks = torch.LongTensor(batch_size, tok_len).fill_(0)
        is_pair = self.__if_pair__(self.data_type)
        if is_pair:
            premise_masks = torch.ByteTensor(batch_size, tok_len).fill_(1)
            hypothesis_masks = torch.ByteTensor(batch_size, hypothesis_len).fill_(1)
        for i, sample in enumerate(batch):
            select_len = min(len(sample['token_id']), tok_len)
            tok = sample['token_id']
            if self.is_train:
                tok = self.__random_select__(tok)
            token_ids[i, :select_len] = torch.LongTensor(tok[:select_len])
            type_ids[i, :select_len] = torch.LongTensor(sample['type_id'][:select_len])
            # Slice assignments replace per-element loops and throwaway list tensors.
            masks[i, :select_len] = 1
            if is_pair:
                # hlen = number of type_id == 0 positions in this sample.
                hlen = len(sample['type_id']) - sum(sample['type_id'])
                hypothesis_masks[i, :hlen] = 0
                premise_masks[i, hlen:select_len] = 0
        if is_pair:
            batch_info = {
                'token_id': 0,
                'segment_id': 1,
                'mask': 2,
                'premise_mask': 3,
                'hypothesis_mask': 4
            }
            batch_data = [token_ids, type_ids, masks, premise_masks, hypothesis_masks]
        else:
            batch_info = {
                'token_id': 0,
                'segment_id': 1,
                'mask': 2
            }
            batch_data = [token_ids, type_ids, masks]
        return batch_data, batch_info
222 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/mt_dnn/inference.py:
--------------------------------------------------------------------------------
1 | from mt_bluebert.data_utils.metrics import calc_metrics
2 |
def eval_model(model, data, metric_meta, use_cuda=True, with_label=True, label_mapper=None):
    """Run ``model.predict`` over every batch of ``data`` and gather the outputs.

    :param model: exposes ``predict(batch_meta, batch_data)`` returning
        ``(scores, predictions, golds)`` for one batch, and ``cuda()`` if ``use_cuda``.
    :param data: batch generator with ``reset()`` that yields ``(batch_meta, batch_data)``.
        Each ``batch_meta`` must carry ``'uids'``.
    :param metric_meta: forwarded to ``calc_metrics``.
    :param use_cuda: call ``model.cuda()`` before evaluating.
    :param with_label: compute metrics against the gold labels; otherwise ``metrics`` is ``{}``.
    :param label_mapper: forwarded to ``calc_metrics``.
    :return: ``(metrics, predictions, scores, golds, ids)``
    """
    data.reset()
    if use_cuda:
        model.cuda()
    predictions, golds, scores, ids = [], [], [], []
    for batch_meta, batch_data in data:
        batch_scores, batch_preds, batch_golds = model.predict(batch_meta, batch_data)
        predictions += batch_preds
        golds += batch_golds
        scores += batch_scores
        ids += batch_meta['uids']
    metrics = calc_metrics(metric_meta, golds, predictions, scores, label_mapper) if with_label else {}
    return metrics, predictions, scores, golds, ids
21 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/mt_dnn/matcher.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Microsoft. All rights reserved.
3 | import torch.nn as nn
4 | from pytorch_pretrained_bert.modeling import BertConfig, BertLayerNorm, BertModel
5 |
6 | from mt_bluebert.module.dropout_wrapper import DropoutWrapper
7 | from mt_bluebert.module.san import SANClassifier
8 | from mt_bluebert.data_utils.task_def import EncoderModelType, TaskType
9 |
10 |
class LinearPooler(nn.Module):
    """Pool a sequence by passing its first position through ``Linear`` + ``tanh``.

    SANBertNetwork uses it to produce a pooled vector on the RoBERTa encoder path.
    """
    def __init__(self, hidden_size):
        super(LinearPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """
        :param hidden_states: assumed (batch, seq_len, hidden_size); only index 0 of dim 1 is used.
        :return: (batch, hidden_size) pooled vector.
        """
        return self.activation(self.dense(hidden_states[:, 0]))
22 |
class SANBertNetwork(nn.Module):
    """Shared BERT or RoBERTa encoder with one output head per task.

    Per-task heads are selected by ``task_types`` and ``answer_opt``:
    - Span tasks use a 2-way linear layer that yields start/end scores.
    - SequenceLabeling tasks use a per-token linear classifier.
    - Other tasks use a SANClassifier (``answer_opt == 1``) or a linear layer
      over the pooled output.
    """
    def __init__(self, opt, bert_config=None):
        """
        :param opt: configuration dict (encoder settings and per-task options).
        :param bert_config: unused; kept for interface compatibility.
        """
        super(SANBertNetwork, self).__init__()
        self.dropout_list = nn.ModuleList()
        self.encoder_type = opt['encoder_type']
        if opt['encoder_type'] == EncoderModelType.ROBERTA:
            from fairseq.models.roberta import RobertaModel
            self.bert = RobertaModel.from_pretrained(opt['init_checkpoint'])
            hidden_size = self.bert.args.encoder_embed_dim
            self.pooler = LinearPooler(hidden_size)
        else:
            self.bert_config = BertConfig.from_dict(opt)
            self.bert = BertModel(self.bert_config)
            hidden_size = self.bert_config.hidden_size

        if opt.get('dump_feature', False):
            # Feature-dump mode: encoder only, no task heads and no custom init.
            self.opt = opt
            return
        if opt['update_bert_opt'] > 0:
            # Freeze the encoder; only the task heads are trained.
            for p in self.bert.parameters():
                p.requires_grad = False
        self.decoder_opt = opt['answer_opt']
        self.task_types = opt["task_types"]
        self.scoring_list = nn.ModuleList()
        labels = [int(ls) for ls in opt['label_size'].split(',')]
        task_dropout_p = opt['tasks_dropout_p']

        for task, lab in enumerate(labels):
            decoder_opt = self.decoder_opt[task]
            task_type = self.task_types[task]
            dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
            self.dropout_list.append(dropout)
            if task_type == TaskType.Span:
                assert decoder_opt != 1
                out_proj = nn.Linear(hidden_size, 2)
            elif task_type == TaskType.SequenceLabeling:
                out_proj = nn.Linear(hidden_size, lab)
            else:
                if decoder_opt == 1:
                    out_proj = SANClassifier(hidden_size, hidden_size, lab, opt, prefix='answer', dropout=dropout)
                else:
                    out_proj = nn.Linear(hidden_size, lab)
            self.scoring_list.append(out_proj)

        self.opt = opt
        self._my_init()

    def _my_init(self):
        """Re-initialise Linear/Embedding weights and layer norms of every submodule.

        NOTE(review): ``self.apply`` also reaches ``self.bert``. On the RoBERTa
        path this may overwrite the weights loaded by ``from_pretrained`` (if
        its layers are nn.Linear / nn.Embedding); confirm that weights are
        reloaded afterwards.
        """
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=0.02 * self.opt['init_ratio'])
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT pytorch version, which should be a bug.
                # Note that it only affects on training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450)
                # support both old/latest version
                if 'beta' in dir(module) and 'gamma' in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            # Guard: nn.Linear(..., bias=False) has ``bias is None``.
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0):
        """Encode the input and score it with the head of ``task_id``.

        :return: ``(start_scores, end_scores)`` for Span tasks; logits of shape
            (batch * seq_len, n_labels) for SequenceLabeling; task logits otherwise.
        """
        if self.encoder_type == EncoderModelType.ROBERTA:
            sequence_output = self.bert.extract_features(input_ids)
            pooled_output = self.pooler(sequence_output)
        else:
            all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
            sequence_output = all_encoder_layers[-1]

        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        if task_type == TaskType.Span:
            assert decoder_opt != 1
            sequence_output = self.dropout_list[task_id](sequence_output)
            logits = self.scoring_list[task_id](sequence_output)
            start_scores, end_scores = logits.split(1, dim=-1)
            start_scores = start_scores.squeeze(-1)
            end_scores = end_scores.squeeze(-1)
            return start_scores, end_scores
        elif task_type == TaskType.SequenceLabeling:
            # Bug fix: use sequence_output, which both encoder paths define.
            # ``all_encoder_layers`` does not exist on the RoBERTa path (NameError).
            pooled_output = self.dropout_list[task_id](sequence_output)
            pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
            logits = self.scoring_list[task_id](pooled_output)
            return logits
        else:
            if decoder_opt == 1:
                # Validate the masks before use (hyp_mask.size() used to run first).
                assert premise_mask is not None
                assert hyp_mask is not None
                max_query = hyp_mask.size(1)
                assert max_query > 0
                hyp_mem = sequence_output[:, :max_query, :]
                logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask)
            else:
                pooled_output = self.dropout_list[task_id](pooled_output)
                logits = self.scoring_list[task_id](pooled_output)
            return logits
128 |
--------------------------------------------------------------------------------
/mt-bluebert/mt_bluebert/pmetrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2018, Yifan Peng
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification,
6 | are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice, this
12 | list of conditions and the following disclaimer in the documentation and/or
13 | other materials provided with the distribution.
14 |
15 | * Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | """
30 | from typing import List
31 |
32 | import numpy as np
33 | from sklearn import metrics
34 | from tabulate import tabulate
35 | from mt_bluebert import conlleval
36 |
37 |
38 | def _divide(x, y):
39 | try:
40 | return np.true_divide(x, y, out=np.zeros_like(x, dtype=np.float), where=y != 0)
41 | except:
42 | return np.nan
43 |
44 |
def tp_tn_fp_fn(confusion_matrix):
    """Split a square confusion matrix into per-class TP, TN, FP and FN counts.

    Rows are taken as true classes and columns as predicted classes. A class's
    false positives are therefore its column total minus the diagonal, and its
    false negatives are its row total minus the diagonal.

    Returns:
        A ``(tp, tn, fp, fn)`` tuple of per-class arrays.
    """
    true_pos = np.diag(confusion_matrix)
    predicted_totals = np.sum(confusion_matrix, axis=0)
    actual_totals = np.sum(confusion_matrix, axis=1)
    false_pos = predicted_totals - true_pos
    false_neg = actual_totals - true_pos
    # Everything that is neither TP, FP nor FN for a class is a true negative for it.
    true_neg = np.sum(confusion_matrix) - (false_pos + false_neg + true_pos)
    return true_pos, true_neg, false_pos, false_neg
51 |
52 |
class PReportRow:
    """Precision/recall statistics for one category.

    See https://en.wikipedia.org/wiki/Precision_and_recall.

    The confusion counts ``tp``/``tn``/``fp``/``fn`` default to NaN. Every derived
    statistic (precision, recall, f1, specificity, support, accuracy,
    balanced_accuracy) may be passed explicitly as a keyword argument; otherwise
    it is computed from the counts. Unrecognised keyword arguments are ignored.
    """

    def __init__(self, category, **kwargs):
        supplied = kwargs.get
        self.category = category

        # Raw confusion counts.
        self.tp = supplied('tp', np.nan)
        self.tn = supplied('tn', np.nan)
        self.fp = supplied('fp', np.nan)
        self.fn = supplied('fn', np.nan)
        tp, tn, fp, fn = self.tp, self.tn, self.fp, self.fn

        # Derived statistics; an explicitly supplied value always wins.
        self.precision = supplied('precision', _divide(tp, tp + fp))
        self.recall = supplied('recall', _divide(tp, tp + fn))
        prec, rec = self.precision, self.recall
        self.f1 = supplied('f1', _divide(2 * prec * rec, prec + rec))
        self.specificity = supplied('specificity', _divide(tn, tn + fp))
        self.support = supplied('support', tp + fn)
        self.accuracy = supplied('accuracy', _divide(tp + tn, tp + tn + fp + fn))
        self.balanced_accuracy = supplied('balanced_accuracy', _divide(rec + self.specificity, 2))

        # Aliases under alternative names.
        self.sensitivity = rec
        self.positive_predictive_value = prec
        self.true_positive_rate = rec
        self.true_negative_rate = self.specificity
74 |
75 |
class PReport:
    """Per-category precision/recall report with micro- and macro-averaged summary rows."""

    def __init__(self, rows: List[PReportRow]):
        # The list is stored by reference; the methods below never mutate it.
        self.rows = rows
        self.micro_row = self.compute_micro()
        self.macro_row = self.compute_macro()

    def compute_micro(self):
        """Return a 'micro' row whose metrics are derived from the summed TP/TN/FP/FN counts."""
        tps = [row.tp for row in self.rows]
        tns = [row.tn for row in self.rows]
        fps = [row.fp for row in self.rows]
        fns = [row.fn for row in self.rows]
        return PReportRow('micro', tp=np.sum(tps), tn=np.sum(tns), fp=np.sum(fps), fn=np.sum(fns))

    def compute_macro(self):
        """Return a 'macro' row with the unweighted mean precision, recall and F1 of all rows.

        The row's counts stay NaN. A NaN metric in any row makes that mean NaN.
        """
        ps = [row.precision for row in self.rows]
        rs = [row.recall for row in self.rows]
        fs = [row.f1 for row in self.rows]
        return PReportRow('macro', precision=np.average(ps), recall=np.average(rs), f1=np.average(fs))

    def report(self, digits=3, micro=False, macro=False):
        """Render the report as a plain-text table.

        Args:
            digits: decimal places used for precision, recall and F-score.
            micro: if True, append the micro-averaged row.
            macro: if True, append the macro-averaged row.

        Returns:
            The table as a string, formatted by ``tabulate``.
        """
        headers = ['Class', 'TP', 'FP', 'FN',
                   'Precision', 'Recall', 'F-score',
                   'Support']
        float_formatter = ['g'] * 4 + ['.{}f'.format(digits)] * 3 + ['g']

        # Work on a copy. Appending to self.rows directly made every call with
        # micro/macro grow the report (duplicate summary rows on repeated calls)
        # and leaked summary rows into sub_report().
        rows = list(self.rows)
        if micro:
            rows.append(self.micro_row)
        if macro:
            rows.append(self.macro_row)

        table = [[r.category, r.tp, r.fp, r.fn, r.precision, r.recall, r.f1, r.support] for r in rows]
        return tabulate(table, showindex=False, headers=headers,
                        tablefmt="plain", floatfmt=float_formatter)

    def sub_report(self, subindex) -> 'PReport':
        """Return a new PReport built from the rows at the given indices (summaries recomputed)."""
        return PReport([self.rows[i] for i in subindex])
114 |
115 |
def ner_report_conlleval(y_true: List[List[str]], y_pred: List[List[str]]) -> PReport:
    """Score tag sequences with conlleval and return a per-type report.

    Each (gold, predicted) tag pair becomes a tab-separated line preceded by a
    constant placeholder token column. Sentences are separated by blank lines.

    Args:
        y_true: one list of gold tags per sentence.
        y_pred: one list of predicted tags per sentence, aligned with ``y_true``.

    Returns:
        A PReport with one row per type returned by ``conlleval.metrics``,
        sorted by type name. Only TP/FP/FN are filled in.
    """
    assert len(y_true) == len(y_pred)
    lines = []
    for gold_tags, pred_tags in zip(y_true, y_pred):
        assert len(gold_tags) == len(pred_tags)
        lines.extend(f'XXX\t{gold}\t{pred}\n' for gold, pred in zip(gold_tags, pred_tags))
        lines.append('\n')

    counts = conlleval.evaluate(lines)
    _overall, by_type = conlleval.metrics(counts)

    per_type_rows = [PReportRow(entity, tp=stats.tp, fp=stats.fp, fn=stats.fn)
                     for entity, stats in sorted(by_type.items())]
    return PReport(per_type_rows)
138 |
139 |
def blue_classification_report(y_true, y_pred, *_, **kwargs) -> PReport:
    """Build a per-class PReport from flat label vectors.

    Args:
        y_true: gold labels, shape (n_sample,).
        y_pred: predicted labels, shape (n_sample,).
        *_: extra positional arguments are accepted and ignored.
        classes_: optional keyword giving one display name per confusion-matrix
            index. Defaults to the integer indices.
            NOTE(review): sklearn orders the matrix by the sorted labels seen in
            y_true/y_pred. If classes_ lists more names than labels observed,
            indexing the per-class counts raises IndexError.

    Returns:
        A PReport with one row per class.
    """
    class_names = kwargs.get('classes_', None)
    matrix = metrics.confusion_matrix(y_true, y_pred)
    tp, tn, fp, fn = tp_tn_fp_fn(matrix)

    if class_names is None:
        class_names = list(range(matrix.shape[0]))

    return PReport([
        PReportRow(name, tp=tp[idx], tn=tn[idx], fp=fp[idx], fn=fn[idx])
        for idx, name in enumerate(class_names)
    ])
160 |
161 |
--------------------------------------------------------------------------------
/mt-bluebert/requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://download.pytorch.org/whl/torch_stable.html
2 |
3 | numpy
4 | torch==1.1.0
5 | tqdm
6 | colorlog
7 | boto3
8 | pytorch-pretrained-bert==v0.6.0
9 | regex
10 | scikit-learn
11 | pyyaml
12 | pytest
13 | sentencepiece
14 | tensorboardX
15 | tensorboard
16 | future
17 | fairseq==0.8.0
18 | seqeval==0.0.12
19 | docopt
20 | pandas
21 | tabulate
--------------------------------------------------------------------------------
/mt-bluebert/scripts/blue_prepro.sh:
--------------------------------------------------------------------------------
#! /bin/sh
# Preprocess the BLUE datasets in two stages:
#   1. blue_prepro.py runs over the raw data under $ROOT;
#   2. blue_prepro_std.py tokenizes $ROOT/canonical_data (presumably written by
#      stage 1) with the BlueBERT vocab, lower-cased, max 128 tokens.
# NOTE(review): the experiments/blue/... script paths assume a particular
# working directory; confirm against the repository layout before running.

# Stop at the first failure so stage 2 never runs on missing/partial stage-1 output.
set -e

ROOT="blue_data"
BERT_PATH="${ROOT}/bluebert_models/bert_uncased_lower"
datasets="clinicalsts,biosses,mednli,i2b2-2010-re,chemprot,ddi2013-type,bc5cdr-chemical,bc5cdr-disease,shareclefe"

python experiments/blue/blue_prepro.py \
  --root_dir "$ROOT" \
  --task_def experiments/blue/blue_task_def.yml \
  --datasets "$datasets" \
  --overwrite

python experiments/blue/blue_prepro_std.py \
  --vocab "$BERT_PATH/vocab.txt" \
  --root_dir "$ROOT/canonical_data" \
  --task_def experiments/blue/blue_task_def.yml \
  --do_lower_case \
  --max_seq_len 128 \
  --datasets "$datasets" \
  --overwrite
21 |
--------------------------------------------------------------------------------
/mt-bluebert/scripts/convert_tf_to_pt.py:
--------------------------------------------------------------------------------
1 | # This script converts Google's TF BERT checkpoint to the PyTorch format used by mt-dnn.
2 | # It is a supplementary script.
3 | # Note that it relies on tensorflow==1.12.0, which is not supported by our released docker image.
4 | # If you want to use it, please install tensorflow==1.12.0: pip install tensorflow==1.12.0
5 | # Some codes are adapted from https://github.com/huggingface/pytorch-pretrained-BERT
6 | # by: xiaodl
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | import re
10 | import os
11 | import argparse
12 | import tensorflow as tf
13 | import torch
14 | import numpy as np
15 | from pytorch_pretrained_bert.modeling import BertConfig
16 | from sys import path
17 | path.append(os.getcwd())
18 | from mt_bluebert.mt_dnn.matcher import SANBertNetwork
19 | from mt_bluebert.data_utils.log_wrapper import create_logger
20 |
# Module logger for conversion progress. to_disk=False presumably disables
# file output (create_logger comes from mt_bluebert.data_utils.log_wrapper).
logger = create_logger(__name__, to_disk=False)
def model_config(parser):
    """Register the model-architecture options on ``parser``.

    These options end up in the ``opt`` dict that ``convert`` passes to
    ``SANBertNetwork``.

    Args:
        parser: an ``argparse.ArgumentParser`` extended in place.

    Returns:
        The same parser, so calls can be chained.
    """
    option_specs = (
        ('--update_bert_opt', {'default': 0, 'type': int}),
        ('--multi_gpu_on', {'action': 'store_true'}),
        ('--mem_cum_type', {'type': str, 'default': 'simple', 'help': 'bilinear/simple/defualt'}),
        ('--answer_num_turn', {'type': int, 'default': 5}),
        ('--answer_mem_drop_p', {'type': float, 'default': 0.1}),
        ('--answer_att_hidden_size', {'type': int, 'default': 128}),
        ('--answer_att_type', {'type': str, 'default': 'bilinear', 'help': 'bilinear/simple/defualt'}),
        ('--answer_rnn_type', {'type': str, 'default': 'gru', 'help': 'rnn/gru/lstm'}),
        ('--answer_sum_att_type', {'type': str, 'default': 'bilinear', 'help': 'bilinear/simple/defualt'}),
        ('--answer_merge_opt', {'type': int, 'default': 1}),
        ('--answer_mem_type', {'type': int, 'default': 1}),
        ('--answer_dropout_p', {'type': float, 'default': 0.1}),
        ('--answer_weight_norm_on', {'action': 'store_true'}),
        ('--dump_state_on', {'action': 'store_true'}),
        ('--answer_opt', {'type': int, 'default': 0, 'help': '0,1'}),
        ('--label_size', {'type': str, 'default': '3'}),
        ('--mtl_opt', {'type': int, 'default': 0}),
        ('--ratio', {'type': float, 'default': 0}),
        ('--mix_opt', {'type': int, 'default': 0}),
        ('--max_seq_len', {'type': int, 'default': 512}),
        ('--init_ratio', {'type': float, 'default': 1}),
        ('--encoder_type', {'type': int, 'default': 1}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
50 |
def train_config(parser):
    """Register optimisation/training options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` extended in place.

    Returns:
        The same parser, so calls can be chained.
    """
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is not usable here because ``bool('False')`` is True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser.add_argument('--cuda', type=_str2bool, default=torch.cuda.is_available(),
                        help='whether to use GPU acceleration.')
    parser.add_argument('--log_per_updates', type=int, default=500)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--batch_size_eval', type=int, default=8)
    parser.add_argument('--optimizer', default='adamax',
                        help='supported optimizer: adamax, sgd, adadelta, adam')
    parser.add_argument('--grad_clipping', type=float, default=0)
    parser.add_argument('--global_grad_clipping', type=float, default=1.0)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--momentum', type=float, default=0)
    parser.add_argument('--warmup', type=float, default=0.1)
    parser.add_argument('--warmup_schedule', type=str, default='warmup_linear')

    # Dropout. Note that the store_false flags below default to True;
    # passing the flag turns the option off.
    parser.add_argument('--vb_dropout', action='store_false')
    parser.add_argument('--dropout_p', type=float, default=0.1)
    parser.add_argument('--tasks_dropout_p', type=float, default=0.1)
    parser.add_argument('--dropout_w', type=float, default=0.000)
    parser.add_argument('--bert_dropout_p', type=float, default=0.1)
    parser.add_argument('--dump_feature', action='store_false')

    # EMA
    parser.add_argument('--ema_opt', type=int, default=0)
    parser.add_argument('--ema_gamma', type=float, default=0.995)

    # scheduler
    parser.add_argument('--have_lr_scheduler', dest='have_lr_scheduler', action='store_false')
    parser.add_argument('--multi_step_lr', type=str, default='10,20,30')
    parser.add_argument('--freeze_layers', type=int, default=-1)
    parser.add_argument('--embedding_opt', type=int, default=0)
    parser.add_argument('--lr_gamma', type=float, default=0.5)
    parser.add_argument('--bert_l2norm', type=float, default=0.0)
    parser.add_argument('--scheduler_type', type=str, default='ms', help='ms/rop/exp')
    parser.add_argument('--output_dir', default='checkpoint')
    parser.add_argument('--seed', type=int, default=2018,
                        help='random seed for data shuffling, embedding init, etc.')
    return parser
91 |
92 |
def _read_tf_variables(ckpt_path):
    """Load the TF checkpoint variables that have a model counterpart, as parallel name/array lists."""
    # Optimizer slots and training bookkeeping variables are never copied into the model.
    skipped_suffixes = ('bad_steps', 'steps', 'step', 'adam_m', 'adam_v', 'loss_scale')
    names, arrays = [], []
    for name, shape in tf.train.list_variables(ckpt_path):
        # Map TF LayerNorm beta/gamma to bias/weight, the naming of the newer
        # huggingface LayerNorm implementation. If your model still uses
        # beta/gamma, drop these two renames.
        if name.endswith('LayerNorm/beta'):
            name = name[:-len('LayerNorm/beta')] + 'LayerNorm/bias'
        if name.endswith('LayerNorm/gamma'):
            name = name[:-len('LayerNorm/gamma')] + 'LayerNorm/weight'

        if name.endswith(skipped_suffixes):
            # Checked before loading so large optimizer slots are never read.
            logger.info('Skipping optimizer/bookkeeping variable {}'.format(name))
            continue

        logger.info('Loading {} with shape {}'.format(name, shape))
        array = tf.train.load_variable(ckpt_path, name)
        logger.info('Numpy array shape {}'.format(array.shape))
        names.append(name)
        arrays.append(array)
    return names, arrays


def _copy_tf_variable(model, name, array):
    """Copy one TF variable into the model parameter its '/'-separated name resolves to."""
    logger.info('Loading {}'.format(name))
    parts = name.split('/')
    # Heads outside the encoder are not converted. This includes cls/squad/*: an
    # earlier out_proj remapping for it was never assigned (its assignment
    # followed a `continue`), so skipping it here keeps the previous behaviour.
    if parts[0] in ('redictions', 'eq_relationship', 'cls', 'output'):
        logger.info('Skipping')
        return

    pointer = model
    for m_name in parts:
        # Scopes such as 'layer_11' address element 11 of the 'layer' container.
        if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
            scope = re.split(r'_(\d+)', m_name)
        else:
            scope = [m_name]
        pointer = getattr(pointer, 'weight' if scope[0] == 'kernel' else scope[0])
        if len(scope) >= 2:
            pointer = pointer[int(scope[1])]

    if m_name[-11:] == '_embeddings':
        pointer = getattr(pointer, 'weight')
    elif m_name == 'kernel':
        # TF dense kernels are laid out (in, out); torch Linear weights are (out, in).
        array = np.transpose(array)

    # Explicit check: a bare `assert` would be stripped under `python -O`,
    # silently writing mis-shaped tensors.
    if tuple(pointer.shape) != array.shape:
        raise ValueError('Shape mismatch for {}: model {} vs checkpoint {}'.format(
            name, tuple(pointer.shape), array.shape))
    pointer.data = torch.from_numpy(array)


def convert(args):
    """Convert a TensorFlow BERT checkpoint into an mt-dnn PyTorch checkpoint.

    Reads ``bert_config.json`` and ``bert_model.ckpt`` from
    ``args.tf_checkpoint_root``. Builds a ``SANBertNetwork`` from the CLI
    options merged with the BERT config and copies the matching TF variables
    into it. Saves ``{'state': state_dict, 'config': bert_config_dict}`` to
    ``args.pytorch_checkpoint_path``.

    Args:
        args: namespace from this script's CLI. ``vars(args)`` is updated in
            place with the BERT config entries.

    Raises:
        ValueError: if a TF variable's shape differs from its target parameter.
    """
    tf_checkpoint_path = args.tf_checkpoint_root
    bert_config_file = os.path.join(tf_checkpoint_path, 'bert_config.json')
    pytorch_dump_path = args.pytorch_checkpoint_path
    config = BertConfig.from_json_file(bert_config_file)
    opt = vars(args)
    opt.update(config.to_dict())
    model = SANBertNetwork(opt)

    ckpt_path = os.path.join(tf_checkpoint_path, 'bert_model.ckpt')
    logger.info('Converting TensorFlow checkpoint from {}'.format(ckpt_path))
    names, arrays = _read_tf_variables(ckpt_path)
    for name, array in zip(names, arrays):
        _copy_tf_variable(model, name, array)

    params = {'state': model.state_dict(), 'config': config.to_dict()}
    torch.save(params, pytorch_dump_path)
186 |
if __name__ == "__main__":
    # CLI entry point: the two checkpoint locations plus every model/training
    # option registered by model_config and train_config, then run the conversion.
    cli = argparse.ArgumentParser()
    for required_path in ('--tf_checkpoint_root', '--pytorch_checkpoint_path'):
        cli.add_argument(required_path, type=str, required=True)
    cli = train_config(model_config(cli))
    cli_args = cli.parse_args()
    logger.info(cli_args)
    convert(cli_args)
196 |
--------------------------------------------------------------------------------
/mt-bluebert/scripts/run_blue_fine_tune.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Fine-tune the stripped MT-BlueBERT checkpoint on a single BLUE task.
# Usage: run_blue_fine_tune.sh <batch_size> <gpu_ids> <task>

if [ $# -ne 3 ]; then
  echo "Usage: run_blue_fine_tune.sh <batch_size> <gpu_ids> <task>" >&2
  exit 1
fi

BATCH_SIZE=$1
gpu=$2
task=$3

echo "export CUDA_VISIBLE_DEVICES=${gpu}"
export CUDA_VISIBLE_DEVICES=${gpu}
tstr=$(date +"%FT%H%M")

MODEL_NAME="bluebert-mt-dnn4-biomedical-pubmed_adam_answer_opt0_gc1_ggc1_2020-02-16T1213_97_stripped"
ROOT="bionlp2020"
BERT_PATH="$ROOT/bluebert_models/mt_dnn_bluebert_base_cased/$MODEL_NAME.pt"
DATA_DIR="$ROOT/blue_data/canonical_data/bert_uncased_lower"

# Settings shared by every task; the case below only sets what differs.
answer_opt=0
optim="adam"
grad_clipping=0
global_grad_clipping=1

# NOTE(review): a second, unreachable "mednli" branch (lr=5e-5) existed; the
# branch that actually ran (lr=1e-5) is kept. Confirm which rate is intended.
case "$task" in
  clinicalsts)     epochs=30; lr="5e-5" ;;
  biosses)         epochs=20; lr="5e-6" ;;
  mednli)          epochs=10; lr="1e-5" ;;
  i2b2-2010-re)    grad_clipping=1; epochs=10; lr="2e-5" ;;
  chemprot)        epochs=20; lr="2e-5" ;;
  ddi2013-type)    epochs=10; lr="5e-5" ;;
  bc5cdr-chemical) epochs=20; lr="5e-5" ;;
  bc5cdr-disease)  epochs=20; lr="6e-5" ;;
  shareclefe)      epochs=20; lr="5e-5" ;;
  *)
    # Exit non-zero so callers notice an unknown task.
    echo "Cannot recognize $task" >&2
    exit 1
    ;;
esac

train_datasets=$task
test_datasets=$task

model_dir="$ROOT/checkpoints_blue/${MODEL_NAME}_${task}_${optim}_answer_opt${answer_opt}_gc${grad_clipping}_ggc${global_grad_clipping}_${tstr}"
log_file="${model_dir}/log.log"
python experiments/blue/blue_train.py \
  --data_dir "${DATA_DIR}" \
  --init_checkpoint "${BERT_PATH}" \
  --task_def experiments/blue/blue_task_def.yml \
  --batch_size "${BATCH_SIZE}" \
  --epochs "${epochs}" \
  --output_dir "${model_dir}" \
  --log_file "${log_file}" \
  --answer_opt "${answer_opt}" \
  --optimizer "${optim}" \
  --train_datasets "${train_datasets}" \
  --test_datasets "${test_datasets}" \
  --grad_clipping "${grad_clipping}" \
  --global_grad_clipping "${global_grad_clipping}" \
  --learning_rate "${lr}" \
  --max_seq_len 128 \
  --not_save
118 |
--------------------------------------------------------------------------------
/mt-bluebert/scripts/run_blue_mt_dnn.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Multi-task (MT-DNN) training of BlueBERT on the BLUE tasks listed below.
# Usage: run_blue_mt_dnn.sh <batch_size> <gpu_ids>

if [ $# -ne 2 ]; then
  echo "Usage: run_blue_mt_dnn.sh <batch_size> <gpu_ids>" >&2
  exit 1
fi

prefix="blue-mt-dnn"
BATCH_SIZE=$1
gpu=$2
echo "export CUDA_VISIBLE_DEVICES=${gpu}"
export CUDA_VISIBLE_DEVICES=${gpu}
tstr=$(date +"%FT%H%M")

# The same task list is used for training and evaluation.
datasets="clinicalsts,mednli,i2b2-2010-re,chemprot,ddi2013-type,bc5cdr-chemical,bc5cdr-disease,shareclefe"
train_datasets="$datasets"
test_datasets="$datasets"

ROOT="bionlp2020"
BERT_PATH="$ROOT/bluebert_models/bluebert_base/bluebert_base.pt"
DATA_DIR="$ROOT/blue_data/canonical_data/bert_uncased_lower"

answer_opt=0
optim="adam"
grad_clipping=1
global_grad_clipping=1
lr="5e-5"
epochs=100

model_dir="$ROOT/checkpoints_blue/${prefix}_${optim}_answer_opt${answer_opt}_gc${grad_clipping}_ggc${global_grad_clipping}_${tstr}"
log_file="${model_dir}/log.log"
python experiments/blue/blue_train.py \
  --data_dir "${DATA_DIR}" \
  --init_checkpoint "${BERT_PATH}" \
  --task_def experiments/blue/blue_task_def.yml \
  --batch_size "${BATCH_SIZE}" \
  --output_dir "${model_dir}" \
  --log_file "${log_file}" \
  --answer_opt "${answer_opt}" \
  --optimizer "${optim}" \
  --train_datasets "${train_datasets}" \
  --test_datasets "${test_datasets}" \
  --grad_clipping "${grad_clipping}" \
  --global_grad_clipping "${global_grad_clipping}" \
  --learning_rate "${lr}" \
  --multi_gpu_on \
  --epochs "${epochs}" \
  --max_seq_len 128
# To resume from a checkpoint, append to the command above:
#   --model_ckpt ${MODEL_CKPT} \
#   --resume
50 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 |
3 | tensorflow==2.5.3
4 | # tensorflow-gpu==1.15
5 | google-api-python-client
6 | oauth2client
7 | tqdm
8 | numpy
9 | pandas==0.25.3
10 | torch===1.4.0
11 | allennlp
12 |
--------------------------------------------------------------------------------
/tokenizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ncbi-nlp/bluebert/f4b8af9db9f8c4503d62d0c205de7256f38c5890/tokenizer/__init__.py
--------------------------------------------------------------------------------