├── .github └── workflows │ └── codeql-analysis.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── data ├── BC5CDR │ └── raw │ │ ├── BC5CDR_Evaluation-0.0.3 │ │ ├── BioC.dtd │ │ ├── DISCLAIMER.txt │ │ ├── bc5cdr_eval.jar │ │ ├── data │ │ │ ├── gold │ │ │ │ ├── CDR_sample.gold.BioC.xml │ │ │ │ └── CDR_sample.gold.PubTator │ │ │ └── test │ │ │ │ ├── CDR_sample.old.test.CID.BioC.xml │ │ │ │ ├── CDR_sample.old.test.DNER.BioC.xml │ │ │ │ ├── CDR_sample.test.CID.BioC.xml │ │ │ │ ├── CDR_sample.test.CID.PubTator │ │ │ │ ├── CDR_sample.test.DNER.BioC.xml │ │ │ │ ├── CDR_sample.test.DNER.PubTator │ │ │ │ ├── my.txt │ │ │ │ └── rment.py │ │ ├── eval_id.sh │ │ ├── eval_mention.sh │ │ ├── eval_relation.sh │ │ ├── my.txt │ │ └── readme.txt │ │ ├── CDR_Data │ │ ├── BC5CDR.corpus.pdf │ │ ├── BC5CDR.overview.pdf │ │ ├── BC5CDR.presentation.pdf │ │ ├── CDR.Corpus.v010516 │ │ │ ├── CDR_DevelopmentSet.BioC.xml │ │ │ ├── CDR_DevelopmentSet.PubTator.txt │ │ │ ├── CDR_TestSet.BioC.xml │ │ │ ├── CDR_TestSet.PubTator.txt │ │ │ ├── CDR_TrainingSet.BioC.xml │ │ │ └── CDR_TrainingSet.PubTator.txt │ │ ├── DNorm.TestSet │ │ │ ├── TestSet.DNorm.BioC.xml │ │ │ └── TestSet.DNorm.PubTator.txt │ │ ├── README.txt │ │ └── tmChem.TestSet │ │ │ ├── TestSet.tmChem.BioC.xml │ │ │ └── TestSet.tmChem.PubTator.txt │ │ ├── test.entities.json │ │ ├── test.json │ │ ├── train.entities.json │ │ ├── train.json │ │ ├── valid.entities.json │ │ └── valid.json ├── BioGPT-Large │ ├── bpecodes │ └── dict.txt ├── BioGPT │ ├── bpecodes │ └── dict.txt ├── DDI │ └── raw │ │ ├── test.json │ │ ├── train.json │ │ └── valid.json ├── HoC │ └── raw │ │ ├── test.tsv │ │ ├── train.tsv │ │ └── valid.tsv ├── KD-DTI │ └── raw │ │ ├── test.json │ │ ├── train.json │ │ └── valid.json ├── biogpt_large_bpecodes ├── biogpt_large_dict.txt ├── bpecodes └── dict.txt ├── examples ├── DC-HoC │ ├── README.md │ ├── hard_match_evaluation.py │ ├── infer.sh │ ├── postprocess.py │ ├── preprocess.sh 
│ ├── rebuild_data.py │ └── train.sh ├── QA-PubMedQA │ ├── README.md │ ├── hard_match_evaluation.py │ ├── infer.sh │ ├── infer_large.sh │ ├── postprocess.py │ ├── preprocess.sh │ ├── preprocess_large.sh │ └── rebuild_data.py ├── RE-BC5CDR │ ├── README.md │ ├── infer.sh │ ├── postprocess.py │ ├── preprocess.sh │ ├── rebuild_data.py │ └── train.sh ├── RE-DDI │ ├── README.md │ ├── hard_match_evaluation.py │ ├── infer.sh │ ├── postprocess.py │ ├── preprocess.sh │ ├── rebuild_data.py │ └── train.sh ├── RE-DTI │ ├── README.md │ ├── hard_match_evaluation.py │ ├── infer.sh │ ├── postprocess.py │ ├── preprocess.sh │ ├── rebuild_data.py │ └── train.sh └── text-generation │ ├── README.md │ └── interactive.py ├── inference.py ├── requirements.txt ├── scripts └── average_checkpoints.py └── src ├── __init__.py ├── constrained_generator.py ├── language_model_prompt_dataset.py ├── language_modeling_prompt.py └── transformer_lm_prompt.py /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '39 5 * * 1' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 
65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # JetBrains PyCharm IDE 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # macOS dir files 13 | .DS_Store 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # VSCODE 107 | .vscode/ftp-sync.json 108 | .vscode/settings.json 109 | 110 | # Experimental Folder 111 | experimental/* 112 | 113 | # Weights and Biases logs 114 | wandb/ 115 | 116 | # data 117 | data/*/*-bin 118 | data/*/raw/*.x 119 | data/*/raw/*.y 120 | data/*/raw/*.pmid 121 | data/*/raw/*bpecodes 122 | data/*/raw/*dict.txt 123 | 124 | # Checkpoints 125 | checkpoints/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BioGPT 2 | This repository contains the implementation of [BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbac409/6713511?guestAccessKey=a66d9b5d-4f83-4017-bb52-405815c907b9), by Renqian Luo, Liai Sun, Yingce Xia, Tao Qin, Sheng Zhang, Hoifung Poon and Tie-Yan Liu. 3 | 4 | 5 | # Requirements and Installation 6 | 7 | * [PyTorch](http://pytorch.org/) version == 1.12.0 8 | * Python version == 3.10 9 | * fairseq version == 0.12.0: 10 | 11 | ``` bash 12 | git clone https://github.com/pytorch/fairseq 13 | cd fairseq 14 | git checkout v0.12.0 15 | pip install . 16 | python setup.py build_ext --inplace 17 | cd .. 18 | ``` 19 | * Moses 20 | ``` bash 21 | git clone https://github.com/moses-smt/mosesdecoder.git 22 | export MOSES=${PWD}/mosesdecoder 23 | ``` 24 | * fastBPE 25 | ``` bash 26 | git clone https://github.com/glample/fastBPE.git 27 | export FASTBPE=${PWD}/fastBPE 28 | cd fastBPE 29 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 30 | ``` 31 | * sacremoses 32 | ``` bash 33 | pip install sacremoses 34 | ``` 35 | * sklearn 36 | ``` bash 37 | pip install scikit-learn 38 | ``` 39 | 40 | Remember to set the environment variables `MOSES` and `FASTBPE` to the paths of Moses and fastBPE, respectively, as they will be required later. 
41 | 42 | # Getting Started 43 | ## Pre-trained models 44 | We provide our pre-trained BioGPT model checkpoints along with fine-tuned checkpoints for downstream tasks, available both through URL download as well as through the Hugging Face 🤗 Hub. 45 | 46 | |Model|Description|URL|🤗 Hub| 47 | |----|----|---|---| 48 | |BioGPT|Pre-trained BioGPT model checkpoint|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/Pre-trained-BioGPT.tgz?sp=r&st=2023-11-13T15:37:35Z&se=2099-12-30T23:37:35Z&spr=https&sv=2022-11-02&sr=b&sig=3CcG1TOhqJPBhkVutvVn3PtUq0vPyLBgwggUfojypfY%3D)|[link](https://huggingface.co/microsoft/biogpt)| 49 | |BioGPT-Large|Pre-trained BioGPT-Large model checkpoint|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/Pre-trained-BioGPT-Large.tgz?sp=r&st=2023-11-13T15:38:13Z&se=2099-12-30T23:38:13Z&spr=https&sv=2022-11-02&sr=b&sig=ib1SZut9wAwrsxGWtFtIZDhrnRg92dwPJmoY2lr3MTg%3D)|[link](https://huggingface.co/microsoft/biogpt-large)| 50 | |BioGPT-QA-PubMedQA-BioGPT|Fine-tuned BioGPT for question answering task on PubMedQA|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/QA-PubMedQA-BioGPT.tgz?sp=r&st=2023-11-13T15:38:43Z&se=2099-12-30T23:38:43Z&spr=https&sv=2022-11-02&sr=b&sig=A5SQae6ifsXmrsgpj4E2flhyXm4iHc%2FqO5b8HGOMyjc%3D)| | 51 | |BioGPT-QA-PubMedQA-BioGPT-Large|Fine-tuned BioGPT-Large for question answering task on PubMedQA|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/QA-PubMedQA-BioGPT-Large.tgz?sp=r&st=2023-11-13T15:39:40Z&se=2099-12-30T23:39:40Z&spr=https&sv=2022-11-02&sr=b&sig=t%2B%2FD%2BxVoIxiuyDsD0VXv%2FjSGoS0VcrdVXycYhWZoxUc%3D)|| 52 | |BioGPT-RE-BC5CDR|Fine-tuned BioGPT for relation extraction task on 
BC5CDR|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/RE-BC5CDR-BioGPT.tgz?sp=r&st=2023-11-13T15:35:14Z&se=2099-12-30T23:35:14Z&spr=https&sv=2022-11-02&sr=b&sig=uXlLIHlVeKIbS%2BVmdzAmlNCeKdoKO2lxsSmwSi%2FH8nE%3D)| | 53 | |BioGPT-RE-DDI|Fine-tuned BioGPT for relation extraction task on DDI|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/RE-DDI-BioGPT.tgz?sp=r&st=2023-11-13T15:35:58Z&se=2099-12-30T23:35:58Z&spr=https&sv=2022-11-02&sr=b&sig=DkaQMuM%2FXAsM2p8%2BUs45ecuqhlSRF1DUYRBJNcxD6Pk%3D)| | 54 | |BioGPT-RE-DTI|Fine-tuned BioGPT for relation extraction task on KD-DTI|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/RE-DTI-BioGPT.tgz?sp=r&st=2023-11-13T15:36:23Z&se=2099-12-30T23:36:23Z&spr=https&sv=2022-11-02&sr=b&sig=bRgUZyqGuwYdM%2FVFzIv6Xa0GThkXq6bVzszmTe9c%2BKM%3D)| | 55 | |BioGPT-DC-HoC|Fine-tuned BioGPT for document classification task on HoC|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/DC-HoC-BioGPT.tgz?sp=r&st=2023-11-13T15:37:17Z&se=2099-12-30T23:37:17Z&spr=https&sv=2022-11-02&sr=b&sig=1DxroWPt%2FBppCTy7QHs842lLy8SQRcUeUwSfMzDFvl0%3D)| | 56 | 57 | Download them and extract them to the `checkpoints` folder of this project. 
58 | 59 | For example: 60 | ``` bash 61 | mkdir checkpoints 62 | cd checkpoints 63 | wget -O Pre-trained-BioGPT.tgz "https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/Pre-trained-BioGPT.tgz?sp=r&st=2023-11-13T15:37:35Z&se=2099-12-30T23:37:35Z&spr=https&sv=2022-11-02&sr=b&sig=3CcG1TOhqJPBhkVutvVn3PtUq0vPyLBgwggUfojypfY%3D" 64 | tar -zxvf Pre-trained-BioGPT.tgz 65 | ``` 66 | 67 | ## Example Usage 68 | Use the pre-trained BioGPT model in your code: 69 | ```python 70 | import torch 71 | from fairseq.models.transformer_lm import TransformerLanguageModel 72 | m = TransformerLanguageModel.from_pretrained( 73 | "checkpoints/Pre-trained-BioGPT", 74 | "checkpoint.pt", 75 | "data", 76 | tokenizer='moses', 77 | bpe='fastbpe', 78 | bpe_codes="data/bpecodes", 79 | min_len=100, 80 | max_len_b=1024) 81 | m.cuda() 82 | src_tokens = m.encode("COVID-19 is") 83 | generate = m.generate([src_tokens], beam=5)[0] 84 | output = m.decode(generate[0]["tokens"]) 85 | print(output) 86 | ``` 87 | 88 | Use the fine-tuned BioGPT model on KD-DTI for drug-target interaction prediction in your code: 89 | ```python 90 | import torch 91 | from src.transformer_lm_prompt import TransformerLanguageModelPrompt 92 | m = TransformerLanguageModelPrompt.from_pretrained( 93 | "checkpoints/RE-DTI-BioGPT", 94 | "checkpoint_avg.pt", 95 | "data/KD-DTI/relis-bin", 96 | tokenizer='moses', 97 | bpe='fastbpe', 98 | bpe_codes="data/bpecodes", 99 | max_len_b=1024, 100 | beam=1) 101 | m.cuda() 102 | src_text="" # input text, e.g., a PubMed abstract 103 | src_tokens = m.encode(src_text) 104 | generate = m.generate([src_tokens], beam=1)[0] 105 | output = m.decode(generate[0]["tokens"]) 106 | print(output) 107 | ``` 108 | 109 | For more downstream tasks, please see below. 
110 | 111 | ## Downstream tasks 112 | See corresponding folder in [examples](examples): 113 | ### [Relation Extraction on BC5CDR](examples/RE-BC5CDR) 114 | ### [Relation Extraction on KD-DTI](examples/RE-DTI/) 115 | ### [Relation Extraction on DDI](examples/RE-DDI) 116 | ### [Document Classification on HoC](examples/DC-HoC/) 117 | ### [Question Answering on PubMedQA](examples/QA-PubMedQA/) 118 | ### [Text Generation](examples/text-generation/) 119 | 120 | ## Hugging Face 🤗 Usage 121 | 122 | BioGPT has also been integrated into the Hugging Face `transformers` library, and model checkpoints are available on the Hugging Face Hub. 123 | 124 | You can use this model directly with a pipeline for text generation. Since the generation relies on some randomness, we set a seed for reproducibility: 125 | 126 | ```python 127 | from transformers import pipeline, set_seed 128 | from transformers import BioGptTokenizer, BioGptForCausalLM 129 | model = BioGptForCausalLM.from_pretrained("microsoft/biogpt") 130 | tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt") 131 | generator = pipeline('text-generation', model=model, tokenizer=tokenizer) 132 | set_seed(42) 133 | generator("COVID-19 is", max_length=20, num_return_sequences=5, do_sample=True) 134 | ``` 135 | 136 | Here is how to use this model to get the features of a given text in PyTorch: 137 | 138 | ```python 139 | from transformers import BioGptTokenizer, BioGptForCausalLM 140 | tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt") 141 | model = BioGptForCausalLM.from_pretrained("microsoft/biogpt") 142 | text = "Replace me by any text you'd like." 
143 | encoded_input = tokenizer(text, return_tensors='pt') 144 | output = model(**encoded_input) 145 | ``` 146 | 147 | Beam-search decoding: 148 | 149 | ```python 150 | import torch 151 | from transformers import BioGptTokenizer, BioGptForCausalLM, set_seed 152 | 153 | tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt") 154 | model = BioGptForCausalLM.from_pretrained("microsoft/biogpt") 155 | 156 | sentence = "COVID-19 is" 157 | inputs = tokenizer(sentence, return_tensors="pt") 158 | 159 | set_seed(42) 160 | 161 | with torch.no_grad(): 162 | beam_output = model.generate(**inputs, 163 | min_length=100, 164 | max_length=1024, 165 | num_beams=5, 166 | early_stopping=True 167 | ) 168 | tokenizer.decode(beam_output[0], skip_special_tokens=True) 169 | ``` 170 | 171 | For more information, please see the [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/biogpt) on the Hugging Face website. 172 | 173 | ## Demos 174 | 175 | Check out these demos on Hugging Face Spaces: 176 | * [Text Generation with BioGPT-Large](https://huggingface.co/spaces/katielink/biogpt-large-demo) 177 | * [Question Answering with BioGPT-Large-PubMedQA](https://huggingface.co/spaces/katielink/biogpt-qa-demo) 178 | 179 | # License 180 | 181 | BioGPT is MIT-licensed. 182 | The license applies to the pre-trained models as well. 183 | 184 | # Contributing 185 | 186 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 187 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 188 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 189 | 190 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 191 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 192 | provided by the bot. 
You will only need to do this once across all repos using our CLA. 193 | 194 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 195 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 196 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 197 | 198 | # Trademarks 199 | 200 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 201 | trademarks or logos is subject to and must follow 202 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 203 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 204 | Any use of third-party trademarks or logos are subject to those third-party's policies. 205 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | <!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK --> 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | <!-- END MICROSOFT SECURITY.MD BLOCK --> 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/BioC.dtd: -------------------------------------------------------------------------------- 1 | <!-- BioC.dtd --> 2 | 3 | <!-- 4 | 5 | BioC is designed to allow programs that process text and 6 | annotations on that text to easily share data and work 7 | together. This DTD describes how that data is represented in XML 8 | files. 9 | 10 | Some believe XML is easily read by humans and that should be 11 | supported by clearly formatting the elements. In the long run, 12 | this is distracting. While the only meaningful spaces are in text 13 | elements and the other spaces can be ignored, current tools add no 14 | additional space. Formatters and editors may be used to make the 15 | XML file appear more readable. 16 | 17 | The possible variety of annotations that one might want to produce 18 | or use is nearly countless. There is no guarantee that these are 19 | organized in the nice nested structure required for XML 20 | elements. Even if they were, it would be nice to easily ignore 21 | unwanted annotations. So annotations are recorded in a stand off 22 | manner, external to the annotated text. The exceptions are 23 | passages and sentences because of their fundamental place in text. 24 | 25 | The text is expected to be encoded in Unicode, specifically 26 | UTF-8. This is one of the encodings required to be implemented by 27 | XML tools, is portable between big-endian and little-endian 28 | machines and is a superset of 7-bit ASCII. Code points beyond 127 29 | may be expressed directly in UTF-8 or indirectly using numeric 30 | entities. Since many tools today still only directly process 31 | ASCII characters, conversion should be available and 32 | standardized. Offsets should be in 8 bit code units (bytes) for 33 | easier processing by naive programs. 34 | 35 | collection: Group of documents, usually from a larger corpus. 
If 36 | a group of documents is from several corpora, use several 37 | collections. 38 | 39 | source: Name of the source corpus from which the documents were selected 40 | 41 | date: Date documents extracted from original source. Can be as 42 | simple as yyyymmdd or an ISO timestamp. 43 | 44 | key: Separate file describing the infons used and any other useful 45 | information about the data in the file. For example, if a file 46 | includes part-of-speech tags, this file should describe the set of 47 | part-of-speech tags used. 48 | 49 | infon: key-value pairs. Can record essentially arbitrary 50 | information. "type" will be a particular common key in the major 51 | sub elements below. For PubMed references, passage "type" might 52 | signal "title" or "abstract". For annotations, it might indicate 53 | "noun phrase", "gene", or "disease". In the programming language 54 | data structures, infons are typically represented as a map from a 55 | string to a string. This means keys should be unique within each 56 | parent element. 57 | 58 | document: A document in the collection. A single, complete 59 | stand-alone document as described by its parent source. 60 | 61 | id: Typically, the id of the document in the parent 62 | source. Should at least be unique in the collection. 63 | 64 | passage: One portion of the document. In the sample collection of 65 | PubMed documents, each document has a title and frequently an 66 | abstract. Structured abstracts could have additional passages. For 67 | a full text document, passages could be sections such as 68 | Introduction, Materials and Methods, or Conclusion. Another option 69 | would be paragraphs. Passages impose a linear structure on the 70 | document. Further structure in the document can be described by 71 | infon values. 72 | 73 | offset: Where the passage occurs in the parent document. Depending 74 | on the source corpus, this might be a very relevant number. 
They 75 | should be sequential and identify a passage's position in the 76 | document. Since the sample PubMed collection is extracted from an 77 | XML file, literal offsets have little value. The title is given an 78 | offset of zero, while the abstract is assumed to begin after the 79 | title and one space. 80 | 81 | text: The original text of the passage. 82 | 83 | sentence: One sentence of the passage. 84 | 85 | offset: A document offset to where the sentence begins in the 86 | passage. This value is the sum of the passage offset and the local 87 | offset within the passage. 88 | 89 | text: The original text of the sentence. 90 | 91 | annotation: Stand-off annotation 92 | 93 | id: Used to refer to this annotation in relations. Should be 94 | unique at whatever level relations appear. If relations appear 95 | at the sentence level, annotation ids need to be unique within 96 | each sentence. Similarly, if relations appear at the passage 97 | level, annotation ids need to be unique within each passage. 98 | 99 | location: Location of the annotated text. Multiple locations 100 | indicate a multi-span annotation. 101 | 102 | offset: Document offset to where the annotated text begins in 103 | the passage or sentence. The value is the sum of the passage or 104 | sentence offset and the local offset within the passage or 105 | sentence. 106 | 107 | length: Length of the annotated text. While unlikely, this could 108 | be zero to describe an annotation that belongs between two 109 | characters. 110 | 111 | text: Typically the annotated text. 112 | 113 | relation: Relation between multiple annotations and / or other 114 | relations. Relations are allowed to appear at several levels 115 | (document, passage, and sentence). Typically they will all appear 116 | at one level, the level at which they are determined. 117 | Significantly different types of relations might appear at 118 | different levels. 119 | 120 | id: Used to refer to this relation in other relations. 
This id 121 | needs to be unique at whatever level relations appear. (See 122 | discussion of annotation ids.) 123 | 124 | refid: Id of an annotation or another relation. 125 | 126 | role: Describes how the referenced annotation or other relation 127 | participates in the current relation. Has a default value so it 128 | can be left out if there is no meaningful value. 129 | 130 | --> 131 | 132 | <!ELEMENT collection ( source, date, key, infon*, document+ ) > 133 | <!ELEMENT source (#PCDATA)> 134 | <!ELEMENT date (#PCDATA)> 135 | <!ELEMENT key (#PCDATA)> 136 | <!ELEMENT infon (#PCDATA)> 137 | <!ATTLIST infon key CDATA #REQUIRED > 138 | 139 | <!ELEMENT document ( id, infon*, passage+, relation* ) > 140 | <!ELEMENT id (#PCDATA)> 141 | 142 | <!ELEMENT passage ( infon*, offset, ( ( text?, annotation* ) | sentence* ), relation* ) > 143 | <!ELEMENT offset (#PCDATA)> 144 | <!ELEMENT text (#PCDATA)> 145 | 146 | <!ELEMENT sentence ( infon*, offset, text?, annotation*, relation* ) > 147 | 148 | <!ELEMENT annotation ( infon*, location*, text ) > 149 | <!ATTLIST annotation id CDATA #IMPLIED > 150 | <!ELEMENT location EMPTY> 151 | <!ATTLIST location offset CDATA #REQUIRED > 152 | <!ATTLIST location length CDATA #REQUIRED > 153 | 154 | <!ELEMENT relation ( infon*, node* ) > 155 | <!ATTLIST relation id CDATA #IMPLIED > 156 | <!ELEMENT node EMPTY> 157 | <!ATTLIST node refid CDATA #REQUIRED > 158 | <!ATTLIST node role CDATA "" > 159 | -------------------------------------------------------------------------------- /data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/DISCLAIMER.txt: -------------------------------------------------------------------------------- 1 | PUBLIC DOMAIN NOTICE 2 | National Center for Biotechnology Information 3 | 4 | This software/database is a "United States Government Work" under the terms of the United States Copyright Act. It was written as part of the authors' official duties as a United States Government employee and thus cannot be copyrighted. 
This software/database is freely available to the public for use. The National Library of Medicine and the U.S. Government have not placed any restriction on its use or reproduction. 5 | 6 | Although all reasonable efforts have been taken to ensure the accuracy and reliability of the software and data, the NLM and the U.S. Government do not and cannot warrant the performance or results that may be obtained by using this software or data. The NLM and the U.S. Government disclaim all warranties, express or implied, including warranties of performance, merchantability or fitness for any particular purpose. 7 | -------------------------------------------------------------------------------- /data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/bc5cdr_eval.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/bc5cdr_eval.jar -------------------------------------------------------------------------------- /data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/data/test/my.txt: -------------------------------------------------------------------------------- 1 | 26094 CID CHEBI:32875 D006973 1.0 2 | 26094 CID CHEBI:32875 D003866 1. 3 | 354896 CID D008012 D009135 1.0 4 | 354896 CID D008012 D001919 1.0 5 | 354896 CID D008012 D003866 1.0 6 | 354896 CID D008012 D006331 1. 
7 | 1720453 CID D003404 D007674 1.0 8 | 1720453 CID D002945 D007674 1.0 9 | 1720453 CID D007069 D007674 1.0 10 | 1720453 CID D010710 D007674 1.0 11 | 1720453 CID D003609 D007674 1.0 12 | 1720453 CID D014750 D007674 1.0 13 | 1720453 CID D003404 D000608 1.0 14 | 1720453 CID D002945 D000608 1.0 15 | 1720453 CID D007069 D000608 1.0 16 | 1720453 CID D010710 D000608 1.0 17 | 1720453 CID D003609 D000608 1.0 18 | 1720453 CID D014750 D000608 1.0 19 | 1720453 CID D003404 D011507 1.0 20 | 1720453 CID D002945 D011507 1.0 21 | 1720453 CID D007069 D011507 1.0 22 | 1720453 CID D010710 D011507 1.0 23 | 1720453 CID D003609 D011507 1.0 24 | 1720453 CID D014750 D011507 1.0 25 | 1720453 CID D003404 D009369 1.0 26 | 1720453 CID D002945 D009369 1.0 27 | 1720453 CID D007069 D009369 1.0 28 | 1720453 CID D010710 D009369 1.0 29 | 1720453 CID D003609 D009369 1.0 30 | 1720453 CID D014750 D009369 1.0 31 | 1720453 CID D003404 D005199 1.0 32 | 1720453 CID D002945 D005199 1.0 33 | 1720453 CID D007069 D005199 1.0 34 | 1720453 CID D010710 D005199 1.0 35 | 1720453 CID D003609 D005199 1.0 36 | 1720453 CID D014750 D005199 1.0 37 | 1720453 CID D003404 D014786 1.0 38 | 1720453 CID D002945 D014786 1.0 39 | 1720453 CID D007069 D014786 1.0 40 | 1720453 CID D010710 D014786 1.0 41 | 1720453 CID D003609 D014786 1.0 42 | 1720453 CID D014750 D014786 1. 
43 | 2224762 CID D004317 D000505 1.0 44 | 2224762 CID C010012 D000505 1.0 45 | 2224762 CID C010013 D000505 1.0 46 | 2224762 CID C027260 D000505 1.0 47 | 2224762 CID C027260 D000505 1.0 48 | 2224762 CID D004317 D009325 1.0 49 | 2224762 CID C010012 D009325 1.0 50 | 2224762 CID C010013 D009325 1.0 51 | 2224762 CID C027260 D009325 1.0 52 | 2224762 CID C027260 D009325 1.0 53 | 2224762 CID D004317 D000740 1.0 54 | 2224762 CID C010012 D000740 1.0 55 | 2224762 CID C010013 D000740 1.0 56 | 2224762 CID C027260 D000740 1.0 57 | 2224762 CID C027260 D000740 1.0 58 | 2224762 CID D004317 D000380 1.0 59 | 2224762 CID C010012 D000380 1.0 60 | 2224762 CID C010013 D000380 1.0 61 | 2224762 CID C027260 D000380 1.0 62 | 2224762 CID C027260 D000380 1.0 63 | 2224762 CID D004317 D009369 1.0 64 | 2224762 CID C010012 D009369 1.0 65 | 2224762 CID C010013 D009369 1.0 66 | 2224762 CID C027260 D009369 1.0 67 | 2224762 CID C027260 D009369 1.0 68 | 2224762 CID D004317 D014786 1.0 69 | 2224762 CID C010012 D014786 1.0 70 | 2224762 CID C010013 D014786 1.0 71 | 2224762 CID C027260 D014786 1.0 72 | 2224762 CID C027260 D014786 1.0 73 | 2224762 CID D004317 D008107 1.0 74 | 2224762 CID C010012 D008107 1.0 75 | 2224762 CID C010013 D008107 1.0 76 | 2224762 CID C027260 D008107 1.0 77 | 2224762 CID C027260 D008107 1.0 78 | 2224762 CID D004317 D010689 1.0 79 | 2224762 CID C010012 D010689 1.0 80 | 2224762 CID C010013 D010689 1.0 81 | 2224762 CID C027260 D010689 1.0 82 | 2224762 CID C027260 D010689 1.0 83 | 2224762 CID D004317 D013921 1.0 84 | 2224762 CID C010012 D013921 1.0 85 | 2224762 CID C010013 D013921 1.0 86 | 2224762 CID C027260 D013921 1.0 87 | 2224762 CID C027260 D013921 1.0 88 | 2224762 CID D004317 D002280 1.0 89 | 2224762 CID C010012 D002280 1.0 90 | 2224762 CID C010013 D002280 1.0 91 | 2224762 CID C027260 D002280 1.0 92 | 2224762 CID C027260 D002280 1.0 93 | 2224762 CID D004317 D052016 1.0 94 | 2224762 CID C010012 D052016 1.0 95 | 2224762 CID C010013 D052016 1.0 96 | 2224762 CID C027260 D052016 1.0 
97 | 2224762 CID C027260 D052016 1.0 98 | 2224762 CID D004317 D007890 1.0 99 | 2224762 CID C010012 D007890 1.0 100 | 2224762 CID C010013 D007890 1.0 101 | 2224762 CID C027260 D007890 1.0 102 | 2224762 CID C027260 D007890 1. 103 | 2348231 CID D004176 D007383 1.0 104 | 2348231 CID C008514 D007383 1.0 105 | 2348231 CID D013806 D007383 1.0 106 | 2348231 CID C008514 D007383 1.0 107 | 2348231 CID D010431 D007383 1.0 108 | 2348231 CID D010431 D007383 1.0 109 | 2348231 CID D004176 D014652 1.0 110 | 2348231 CID C008514 D014652 1.0 111 | 2348231 CID D013806 D014652 1.0 112 | 2348231 CID C008514 D014652 1.0 113 | 2348231 CID D010431 D014652 1.0 114 | 2348231 CID D010431 D014652 1.0 115 | 2348231 CID D004176 D006940 1.0 116 | 2348231 CID C008514 D006940 1.0 117 | 2348231 CID D013806 D006940 1.0 118 | 2348231 CID C008514 D006940 1.0 119 | 2348231 CID D010431 D006940 1.0 120 | 2348231 CID D010431 D006940 1.0 121 | 2348231 CID D004176 D003327 1.0 122 | 2348231 CID C008514 D003327 1.0 123 | 2348231 CID D013806 D003327 1.0 124 | 2348231 CID C008514 D003327 1.0 125 | 2348231 CID D010431 D003327 1.0 126 | 2348231 CID D010431 D003327 1. 127 | 2385256 CID D000109 D009468 1.0 128 | 2385256 CID D008274 D009468 1.0 129 | 2385256 CID D000109 D009157 1.0 130 | 2385256 CID D008274 D009157 1.0 131 | 2385256 CID D000109 D011225 1.0 132 | 2385256 CID D008274 D011225 1.0 133 | 2385256 CID D000109 D018908 1.0 134 | 2385256 CID D008274 D018908 1.0 135 | 2385256 CID D000109 D010243 1.0 136 | 2385256 CID D008274 D010243 1.0 137 | 2385256 CID D000109 D020511 1.0 138 | 2385256 CID D008274 D020511 1. 139 | 2887062 CID D001971 D000236 1.0 140 | 2887062 CID D004967 D000236 1.0 141 | 2887062 CID D001971 D015175 1.0 142 | 2887062 CID D004967 D015175 1. 
143 | 2894766 CID D012460 D003093 1.0 144 | 2894766 CID D000305 D003093 1.0 145 | 2894766 CID D012460 D002305 1.0 146 | 2894766 CID D000305 D002305 1.0 147 | 2894766 CID D012460 C536397 1.0 148 | 2894766 CID D000305 C536397 1.0 149 | 2894766 CID D012460 D012700 1.0 150 | 2894766 CID D000305 D012700 1.0 151 | 2894766 CID D012460 D004832 1.0 152 | 2894766 CID D000305 D004832 1.0 153 | 2894766 CID D012460 D010996 1.0 154 | 2894766 CID D000305 D010996 1.0 155 | 2894766 CID D012460 D008180 1.0 156 | 2894766 CID D000305 D008180 1.0 157 | 2894766 CID D012460 D011014 1.0 158 | 2894766 CID D000305 D011014 1.0 159 | 2894766 CID D012460 D015212 1.0 160 | 2894766 CID D000305 D015212 1. 161 | 3107448 CID D013453 D007239 1.0 162 | 3107448 CID D005905 D007239 1.0 163 | 3107448 CID D013453 D056486 1.0 164 | 3107448 CID D005905 D056486 1.0 165 | 3107448 CID D013453 D003920 1.0 166 | 3107448 CID D005905 D003920 1. 167 | 3173180 CID D008795 D000743 1. 168 | 3403780 CID D000082 D017093 1.0 169 | 3403780 CID D000082 D058186 1.0 170 | 3403780 CID D000082 D000138 1.0 171 | 3403780 CID D000082 D003128 1.0 172 | 3403780 CID D000082 D051437 1. 173 | 3615541 CID D002116 D006948 1.0 174 | 3615541 CID D000661 D006948 1. 175 | 3827439 CID D000666 D058186 1.0 176 | 3827439 CID D000666 D013174 1.0 177 | 3827439 CID D000666 OMIM:215600 1.0 178 | 3827439 CID D000666 D007674 1.0 179 | 3827439 CID D000666 D051437 1. 180 | 4027862 CID D003891 D003693 1. 
181 | 6287825 CID D012256 D017695 1.0 182 | 6287825 CID D013831 D017695 1.0 183 | 6287825 CID D007538 D017695 1.0 184 | 6287825 CID D012256 D020275 1.0 185 | 6287825 CID D013831 D020275 1.0 186 | 6287825 CID D007538 D020275 1.0 187 | 6287825 CID D012256 D003389 1.0 188 | 6287825 CID D013831 D003389 1.0 189 | 6287825 CID D007538 D003389 1.0 190 | 6287825 CID D012256 D009422 1.0 191 | 6287825 CID D013831 D009422 1.0 192 | 6287825 CID D007538 D009422 1.0 193 | 6287825 CID D012256 D010523 1.0 194 | 6287825 CID D013831 D010523 1.0 195 | 6287825 CID D007538 D010523 1.0 196 | 6287825 CID D012256 D008881 1.0 197 | 6287825 CID D013831 D008881 1.0 198 | 6287825 CID D007538 D008881 1.0 199 | 6287825 CID D012256 D015417 1.0 200 | 6287825 CID D013831 D015417 1.0 201 | 6287825 CID D007538 D015417 1.0 202 | 6287825 CID D012256 D016609 1.0 203 | 6287825 CID D013831 D016609 1.0 204 | 6287825 CID D007538 D016609 1.0 205 | 6287825 CID D012256 D003920 1.0 206 | 6287825 CID D013831 D003920 1.0 207 | 6287825 CID D007538 D003920 1. 208 | 7352670 CID D006221 D007022 1.0 209 | 7352670 CID D012504 D007022 1.0 210 | 7352670 CID D009599 D007022 1.0 211 | 7352670 CID D009599 D007022 1.0 212 | 7352670 CID D006221 D006973 1.0 213 | 7352670 CID D012504 D006973 1.0 214 | 7352670 CID D009599 D006973 1.0 215 | 7352670 CID D009599 D006973 1. 216 | 7420681 CID D000617 D006311 1.0 217 | 7420681 CID D005839 D006311 1.0 218 | 7420681 CID D014031 D006311 1.0 219 | 7420681 CID D000617 D051437 1.0 220 | 7420681 CID D005839 D051437 1.0 221 | 7420681 CID D014031 D051437 1. 222 | 7468724 CID D013726 D007752 1.0 223 | 7468724 CID D013726 D002318 1. 224 | 8643971 CID D002945 D009369 1.0 225 | 8643971 CID D017239 D009369 1.0 226 | 8643971 CID D002945 D014786 1.0 227 | 8643971 CID D017239 D014786 1.0 228 | 8643971 CID D002945 D010051 1.0 229 | 8643971 CID D017239 D010051 1.0 230 | 8643971 CID D002945 D006258 1.0 231 | 8643971 CID D017239 D006258 1. 
232 | 8649546 CID D011433 D010300 1.0 233 | 8649546 CID D011433 D004421 1.0 234 | 8649546 CID D011433 D009069 1.0 235 | 8649546 CID D011433 D004409 1. 236 | 9088814 CID D020155 D007022 1.0 237 | 9088814 CID D000527 D007022 1.0 238 | 9088814 CID D001663 D007022 1.0 239 | 9088814 CID C016635 D007022 1.0 240 | 9088814 CID CHEBI:17087 D007022 1.0 241 | 9088814 CID D020155 D008107 1.0 242 | 9088814 CID D000527 D008107 1.0 243 | 9088814 CID D001663 D008107 1.0 244 | 9088814 CID C016635 D008107 1.0 245 | 9088814 CID CHEBI:17087 D008107 1. 246 | 9249847 CID D014700 D014927 1.0 247 | 9249847 CID D014700 C536277 1.0 248 | 9249847 CID D014700 D013611 1.0 249 | 9249847 CID D014700 D058606 1. 250 | 9660111 CID D019257 D006948 1.0 251 | 9660111 CID D012701 D006948 1.0 252 | 9660111 CID D001058 D006948 1.0 253 | 9660111 CID D007099 D006948 1.0 254 | 9660111 CID D004298 D006948 1.0 255 | 9660111 CID D006916 D006948 1.0 256 | 9660111 CID D003913 D006948 1.0 257 | 9660111 CID D014299 D006948 1.0 258 | 9660111 CID D010656 D006948 1.0 259 | 9660111 CID D014299 D006948 1.0 260 | 9660111 CID D009638 D006948 1.0 261 | 9660111 CID D003000 D006948 1.0 262 | 9660111 CID D019257 D010554 1.0 263 | 9660111 CID D012701 D010554 1.0 264 | 9660111 CID D001058 D010554 1.0 265 | 9660111 CID D007099 D010554 1.0 266 | 9660111 CID D004298 D010554 1.0 267 | 9660111 CID D006916 D010554 1.0 268 | 9660111 CID D003913 D010554 1.0 269 | 9660111 CID D014299 D010554 1.0 270 | 9660111 CID D010656 D010554 1.0 271 | 9660111 CID D014299 D010554 1.0 272 | 9660111 CID D009638 D010554 1.0 273 | 9660111 CID D003000 D010554 1.0 274 | 9660111 CID D019257 D007035 1.0 275 | 9660111 CID D012701 D007035 1.0 276 | 9660111 CID D001058 D007035 1.0 277 | 9660111 CID D007099 D007035 1.0 278 | 9660111 CID D004298 D007035 1.0 279 | 9660111 CID D006916 D007035 1.0 280 | 9660111 CID D003913 D007035 1.0 281 | 9660111 CID D014299 D007035 1.0 282 | 9660111 CID D010656 D007035 1.0 283 | 9660111 CID D014299 D007035 1.0 284 | 9660111 CID 
D009638 D007035 1.0 285 | 9660111 CID D003000 D007035 1. 286 | 9698967 CID D008874 D014474 1.0 287 | 9698967 CID D004837 D014474 1.0 288 | 9698967 CID D008790 D014474 1.0 289 | 9698967 CID D007741 D014474 1.0 290 | 9698967 CID D008619 D014474 1.0 291 | 9698967 CID D008874 D004387 1.0 292 | 9698967 CID D004837 D004387 1.0 293 | 9698967 CID D008790 D004387 1.0 294 | 9698967 CID D007741 D004387 1.0 295 | 9698967 CID D008619 D004387 1.0 296 | 9698967 CID D008874 D001281 1.0 297 | 9698967 CID D004837 D001281 1.0 298 | 9698967 CID D008790 D001281 1.0 299 | 9698967 CID D007741 D001281 1.0 300 | 9698967 CID D008619 D001281 1. 301 | 9855119 CID D003404 D005923 1.0 302 | 9855119 CID D016559 D005923 1.0 303 | 9855119 CID D003404 D007674 1.0 304 | 9855119 CID D016559 D007674 1.0 305 | 9855119 CID D003404 D057770 1.0 306 | 9855119 CID D016559 D057770 1.0 307 | 9855119 CID D003404 D005355 1.0 308 | 9855119 CID D016559 D005355 1. 309 | 9869257 CID D012601 D003072 1.0 310 | 9869257 CID C098725 D003072 1.0 311 | 9869257 CID D000109 D003072 1.0 312 | 9869257 CID D012601 D000647 1.0 313 | 9869257 CID C098725 D000647 1.0 314 | 9869257 CID D000109 D000647 1. 315 | 10457883 CID D001279 D006323 1.0 316 | 10457883 CID D001279 D009468 1.0 317 | 10457883 CID D001279 D001919 1. 318 | 10721819 CID D001971 D013610 1.0 319 | 10721819 CID D007545 D013610 1.0 320 | 10721819 CID D004298 D013610 1.0 321 | 10721819 CID D004294 D013610 1.0 322 | 10721819 CID D001971 D001919 1.0 323 | 10721819 CID D007545 D001919 1.0 324 | 10721819 CID D004298 D001919 1.0 325 | 10721819 CID D004294 D001919 1.0 326 | 10721819 CID D001971 D007022 1.0 327 | 10721819 CID D007545 D007022 1.0 328 | 10721819 CID D004298 D007022 1.0 329 | 10721819 CID D004294 D007022 1.0 330 | 10721819 CID D001971 D006332 1.0 331 | 10721819 CID D007545 D006332 1.0 332 | 10721819 CID D004298 D006332 1.0 333 | 10721819 CID D004294 D006332 1. 
334 | 11147747 CID D012701 D014549 1.0 335 | 11147747 CID D016651 D014549 1.0 336 | 11147747 CID D020280 D014549 1.0 337 | 11147747 CID D017374 D014549 1.0 338 | 11147747 CID C047426 D014549 1. 339 | 11587867 CID D014750 D009370 1.0 340 | 11587867 CID D014750 D009370 1.0 341 | 11587867 CID D014750 D010243 1.0 342 | 11587867 CID D014750 D010243 1.0 343 | 11587867 CID D014750 D015417 1.0 344 | 11587867 CID D014750 D015417 1.0 345 | 11587867 CID D014750 D054198 1.0 346 | 11587867 CID D014750 D054198 1. 347 | 11897407 CID C067171 D006331 1.0 348 | 11897407 CID D007545 D006331 1.0 349 | 11897407 CID D005937 D006331 1.0 350 | 11897407 CID C067171 D007238 1.0 351 | 11897407 CID D007545 D007238 1.0 352 | 11897407 CID D005937 D007238 1.0 353 | 11897407 CID C067171 D009203 1.0 354 | 11897407 CID D007545 D009203 1.0 355 | 11897407 CID D005937 D009203 1. 356 | 12584269 CID D020123 D005355 1.0 357 | 12584269 CID D002857 D005355 1.0 358 | 12584269 CID D004492 D005355 1.0 359 | 12584269 CID D016572 D005355 1.0 360 | 12584269 CID D020123 D005355 1.0 361 | 12584269 CID D016559 D005355 1.0 362 | 12584269 CID D016572 D005355 1. 363 | 12905102 CID D003024 D003693 1. 364 | 14513889 CID D007654 D020335 1.0 365 | 14513889 CID D007654 D009422 1.0 366 | 14513889 CID D007654 D014202 1.0 367 | 14513889 CID D007654 D004401 1.0 368 | 14513889 CID D007654 D010243 1.0 369 | 14513889 CID D007654 D004362 1. 370 | 15188772 CID D002395 D009202 1.0 371 | 15188772 CID D002395 D009202 1.0 372 | 15188772 CID D004837 D009202 1.0 373 | 15188772 CID D002395 D015537 1.0 374 | 15188772 CID D002395 D015537 1.0 375 | 15188772 CID D004837 D015537 1.0 376 | 15188772 CID D002395 D017682 1.0 377 | 15188772 CID D002395 D017682 1.0 378 | 15188772 CID D004837 D017682 1.0 379 | 15188772 CID D002395 D008107 1.0 380 | 15188772 CID D002395 D008107 1.0 381 | 15188772 CID D004837 D008107 1. 382 | 15602202 CID D017963 D009395 1.0 383 | 15602202 CID D017963 D003072 1.0 384 | 15602202 CID D017963 D007674 1. 
385 | 16298782 CID D005839 D006316 1.0 386 | 16298782 CID D019331 D006316 1.0 387 | 16298782 CID D000617 D006316 1.0 388 | 16298782 CID D009569 D006316 1.0 389 | 16298782 CID D005839 D034381 1.0 390 | 16298782 CID D019331 D034381 1.0 391 | 16298782 CID D000617 D034381 1.0 392 | 16298782 CID D009569 D034381 1.0 393 | 16298782 CID D005839 D006311 1.0 394 | 16298782 CID D019331 D006311 1.0 395 | 16298782 CID D000617 D006311 1.0 396 | 16298782 CID D009569 D006311 1.0 397 | 16298782 CID D005839 D006319 1.0 398 | 16298782 CID D019331 D006319 1.0 399 | 16298782 CID D000617 D006319 1.0 400 | 16298782 CID D009569 D006319 1. 401 | 16330766 CID C040029 D017695 1.0 402 | 16330766 CID D002211 D017695 1.0 403 | 16330766 CID C040029 D002493 1.0 404 | 16330766 CID D002211 D002493 1.0 405 | 16330766 CID C040029 D020886 1.0 406 | 16330766 CID D002211 D020886 1.0 407 | 16330766 CID C040029 D010146 1.0 408 | 16330766 CID D002211 D010146 1.0 409 | 16330766 CID C040029 D006930 1.0 410 | 16330766 CID D002211 D006930 1.0 411 | 16330766 CID C040029 D009437 1.0 412 | 16330766 CID D002211 D009437 1. 413 | 17175308 CID D003404 D009395 1.0 414 | 17175308 CID D020123 D009395 1.0 415 | 17175308 CID D020123 D009395 1.0 416 | 17175308 CID D003404 OMIM:274230 1.0 417 | 17175308 CID D020123 OMIM:274230 1.0 418 | 17175308 CID D020123 OMIM:274230 1.0 419 | 17175308 CID D003404 D011507 1.0 420 | 17175308 CID D020123 D011507 1.0 421 | 17175308 CID D020123 D011507 1.0 422 | 17175308 CID D003404 D009369 1.0 423 | 17175308 CID D020123 D009369 1.0 424 | 17175308 CID D020123 D009369 1.0 425 | 17175308 CID D003404 D012514 1.0 426 | 17175308 CID D020123 D012514 1.0 427 | 17175308 CID D020123 D012514 1.0 428 | 17175308 CID D003404 D009404 1.0 429 | 17175308 CID D020123 D009404 1.0 430 | 17175308 CID D020123 D009404 1.0 431 | 17175308 CID D003404 D007674 1.0 432 | 17175308 CID D020123 D007674 1.0 433 | 17175308 CID D020123 D007674 1. 
434 | 17242861 CID D010862 D013180 1.0 435 | 17242861 CID D010862 D012640 1.0 436 | 17242861 CID D010862 D004833 1. 437 | 17466854 CID D008012 D019106 1.0 438 | 17466854 CID D005839 D019106 1.0 439 | 17466854 CID D008775 D019106 1.0 440 | 17466854 CID D008012 D009325 1.0 441 | 17466854 CID D005839 D009325 1.0 442 | 17466854 CID D008775 D009325 1.0 443 | 17466854 CID D008012 D006261 1.0 444 | 17466854 CID D005839 D006261 1.0 445 | 17466854 CID D008775 D006261 1.0 446 | 17466854 CID D008012 D014839 1.0 447 | 17466854 CID D005839 D014839 1.0 448 | 17466854 CID D008775 D014839 1. 449 | 17615423 CID D003404 D058186 1.0 450 | 17615423 CID D003401 D058186 1.0 451 | 17615423 CID D008148 D058186 1.0 452 | 17615423 CID C065179 D058186 1.0 453 | 17615423 CID D017035 D058186 1.0 454 | 17615423 CID C422923 D058186 1.0 455 | 17615423 CID C413408 D058186 1.0 456 | 17615423 CID D019821 D058186 1.0 457 | 17615423 CID D001224 D058186 1.0 458 | 17615423 CID D000638 D058186 1.0 459 | 17615423 CID C065180 D058186 1.0 460 | 17615423 CID D019821 D058186 1.0 461 | 17615423 CID D000409 D058186 1.0 462 | 17615423 CID D003404 D012206 1.0 463 | 17615423 CID D003401 D012206 1.0 464 | 17615423 CID D008148 D012206 1.0 465 | 17615423 CID C065179 D012206 1.0 466 | 17615423 CID D017035 D012206 1.0 467 | 17615423 CID C422923 D012206 1.0 468 | 17615423 CID C413408 D012206 1.0 469 | 17615423 CID D019821 D012206 1.0 470 | 17615423 CID D001224 D012206 1.0 471 | 17615423 CID D000638 D012206 1.0 472 | 17615423 CID C065180 D012206 1.0 473 | 17615423 CID D019821 D012206 1.0 474 | 17615423 CID D000409 D012206 1.0 475 | 17615423 CID D003404 D006949 1.0 476 | 17615423 CID D003401 D006949 1.0 477 | 17615423 CID D008148 D006949 1.0 478 | 17615423 CID C065179 D006949 1.0 479 | 17615423 CID D017035 D006949 1.0 480 | 17615423 CID C422923 D006949 1.0 481 | 17615423 CID C413408 D006949 1.0 482 | 17615423 CID D019821 D006949 1.0 483 | 17615423 CID D001224 D006949 1.0 484 | 17615423 CID D000638 D006949 1.0 485 | 
17615423 CID C065180 D006949 1.0 486 | 17615423 CID D019821 D006949 1.0 487 | 17615423 CID D000409 D006949 1.0 488 | 17615423 CID D003404 D001281 1.0 489 | 17615423 CID D003401 D001281 1.0 490 | 17615423 CID D008148 D001281 1.0 491 | 17615423 CID C065179 D001281 1.0 492 | 17615423 CID D017035 D001281 1.0 493 | 17615423 CID C422923 D001281 1.0 494 | 17615423 CID C413408 D001281 1.0 495 | 17615423 CID D019821 D001281 1.0 496 | 17615423 CID D001224 D001281 1.0 497 | 17615423 CID D000638 D001281 1.0 498 | 17615423 CID C065180 D001281 1.0 499 | 17615423 CID D019821 D001281 1.0 500 | 17615423 CID D000409 D001281 1.0 501 | 17615423 CID D003404 D015658 1.0 502 | 17615423 CID D003401 D015658 1.0 503 | 17615423 CID D008148 D015658 1.0 504 | 17615423 CID C065179 D015658 1.0 505 | 17615423 CID D017035 D015658 1.0 506 | 17615423 CID C422923 D015658 1.0 507 | 17615423 CID C413408 D015658 1.0 508 | 17615423 CID D019821 D015658 1.0 509 | 17615423 CID D001224 D015658 1.0 510 | 17615423 CID D000638 D015658 1.0 511 | 17615423 CID C065180 D015658 1.0 512 | 17615423 CID D019821 D015658 1.0 513 | 17615423 CID D000409 D015658 1.0 514 | 17615423 CID D003404 D003324 1.0 515 | 17615423 CID D003401 D003324 1.0 516 | 17615423 CID D008148 D003324 1.0 517 | 17615423 CID C065179 D003324 1.0 518 | 17615423 CID D017035 D003324 1.0 519 | 17615423 CID C422923 D003324 1.0 520 | 17615423 CID C413408 D003324 1.0 521 | 17615423 CID D019821 D003324 1.0 522 | 17615423 CID D001224 D003324 1.0 523 | 17615423 CID D000638 D003324 1.0 524 | 17615423 CID C065180 D003324 1.0 525 | 17615423 CID D019821 D003324 1.0 526 | 17615423 CID D000409 D003324 1.0 527 | 17615423 CID D003404 D010146 1.0 528 | 17615423 CID D003401 D010146 1.0 529 | 17615423 CID D008148 D010146 1.0 530 | 17615423 CID C065179 D010146 1.0 531 | 17615423 CID D017035 D010146 1.0 532 | 17615423 CID C422923 D010146 1.0 533 | 17615423 CID C413408 D010146 1.0 534 | 17615423 CID D019821 D010146 1.0 535 | 17615423 CID D001224 D010146 1.0 536 | 17615423 
CID D000638 D010146 1.0 537 | 17615423 CID C065180 D010146 1.0 538 | 17615423 CID D019821 D010146 1.0 539 | 17615423 CID D000409 D010146 1. 540 | 18023325 CID D004221 D020275 1.0 541 | 18023325 CID D004221 D011782 1.0 542 | 18023325 CID D004221 D015537 1.0 543 | 18023325 CID D004221 D001480 1.0 544 | 18023325 CID D004221 D010523 1.0 545 | 18023325 CID D004221 D001259 1.0 546 | 18023325 CID D004221 D008107 1.0 547 | 18023325 CID D004221 D010146 1.0 548 | 18023325 CID D004221 D014826 1. 549 | 18186898 CID CHEBI:8764 D005198 1.0 550 | 18186898 CID D019259 D005198 1.0 551 | 18186898 CID D013256 D005198 1.0 552 | 18186898 CID D020123 D005198 1.0 553 | 18186898 CID D016559 D005198 1.0 554 | 18186898 CID CHEBI:8764 D028361 1.0 555 | 18186898 CID D019259 D028361 1.0 556 | 18186898 CID D013256 D028361 1.0 557 | 18186898 CID D020123 D028361 1.0 558 | 18186898 CID D016559 D028361 1.0 559 | 18186898 CID CHEBI:8764 D000138 1.0 560 | 18186898 CID D019259 D000138 1.0 561 | 18186898 CID D013256 D000138 1.0 562 | 18186898 CID D020123 D000138 1.0 563 | 18186898 CID D016559 D000138 1.0 564 | 18186898 CID CHEBI:8764 D018908 1.0 565 | 18186898 CID D019259 D018908 1.0 566 | 18186898 CID D013256 D018908 1.0 567 | 18186898 CID D020123 D018908 1.0 568 | 18186898 CID D016559 D018908 1.0 569 | 18186898 CID CHEBI:8764 D006029 1.0 570 | 18186898 CID D019259 D006029 1.0 571 | 18186898 CID D013256 D006029 1.0 572 | 18186898 CID D020123 D006029 1.0 573 | 18186898 CID D016559 D006029 1.0 574 | 18186898 CID CHEBI:8764 D056486 1.0 575 | 18186898 CID D019259 D056486 1.0 576 | 18186898 CID D013256 D056486 1.0 577 | 18186898 CID D020123 D056486 1.0 578 | 18186898 CID D016559 D056486 1.0 579 | 18186898 CID CHEBI:8764 D007625 1.0 580 | 18186898 CID D019259 D007625 1.0 581 | 18186898 CID D013256 D007625 1.0 582 | 18186898 CID D020123 D007625 1.0 583 | 18186898 CID D016559 D007625 1.0 584 | 18186898 CID CHEBI:8764 D008107 1.0 585 | 18186898 CID D019259 D008107 1.0 586 | 18186898 CID D013256 D008107 1.0 587 
| 18186898 CID D020123 D008107 1.0 588 | 18186898 CID D016559 D008107 1.0 589 | 18186898 CID CHEBI:8764 D017674 1.0 590 | 18186898 CID D019259 D017674 1.0 591 | 18186898 CID D013256 D017674 1.0 592 | 18186898 CID D020123 D017674 1.0 593 | 18186898 CID D016559 D017674 1.0 594 | 18186898 CID CHEBI:8764 D009135 1.0 595 | 18186898 CID D019259 D009135 1.0 596 | 18186898 CID D013256 D009135 1.0 597 | 18186898 CID D020123 D009135 1.0 598 | 18186898 CID D016559 D009135 1.0 599 | 18186898 CID CHEBI:8764 D007674 1.0 600 | 18186898 CID D019259 D007674 1.0 601 | 18186898 CID D013256 D007674 1.0 602 | 18186898 CID D020123 D007674 1.0 603 | 18186898 CID D016559 D007674 1.0 604 | 18186898 CID CHEBI:8764 D000608 1.0 605 | 18186898 CID D019259 D000608 1.0 606 | 18186898 CID D013256 D000608 1.0 607 | 18186898 CID D020123 D000608 1.0 608 | 18186898 CID D016559 D000608 1.0 609 | 18186898 CID CHEBI:8764 D006527 1.0 610 | 18186898 CID D019259 D006527 1.0 611 | 18186898 CID D013256 D006527 1.0 612 | 18186898 CID D020123 D006527 1.0 613 | 18186898 CID D016559 D006527 1. 614 | 18439803 CID D018698 D012640 1.0 615 | 18439803 CID D005998 D012640 1.0 616 | 18439803 CID C108761 D012640 1.0 617 | 18439803 CID D014635 D012640 1.0 618 | 18439803 CID D010862 D012640 1.0 619 | 18439803 CID D005680 D012640 1.0 620 | 18439803 CID D000596 D012640 1.0 621 | 18439803 CID D014635 D012640 1.0 622 | 18439803 CID C108761 D012640 1.0 623 | 18439803 CID D001224 D012640 1. 624 | 18619688 CID D004317 D028361 1.0 625 | 18619688 CID C025946 D028361 1.0 626 | 18619688 CID C025946 D028361 1.0 627 | 18619688 CID D004317 D003643 1.0 628 | 18619688 CID C025946 D003643 1.0 629 | 18619688 CID C025946 D003643 1.0 630 | 18619688 CID D004317 D006333 1.0 631 | 18619688 CID C025946 D006333 1.0 632 | 18619688 CID C025946 D006333 1. 
633 | 18809400 CID D002945 D004194 1.0 634 | 18809400 CID D008063 D004194 1.0 635 | 18809400 CID D017239 D004194 1.0 636 | 18809400 CID D002945 D028361 1.0 637 | 18809400 CID D008063 D028361 1.0 638 | 18809400 CID D017239 D028361 1.0 639 | 18809400 CID D002945 D009422 1.0 640 | 18809400 CID D008063 D009422 1.0 641 | 18809400 CID D017239 D009422 1.0 642 | 18809400 CID D002945 D001480 1.0 643 | 18809400 CID D008063 D001480 1.0 644 | 18809400 CID D017239 D001480 1.0 645 | 18809400 CID D002945 D010523 1.0 646 | 18809400 CID D008063 D010523 1.0 647 | 18809400 CID D017239 D010523 1.0 648 | 18809400 CID D002945 D014786 1.0 649 | 18809400 CID D008063 D014786 1.0 650 | 18809400 CID D017239 D014786 1.0 651 | 18809400 CID D002945 D045888 1.0 652 | 18809400 CID D008063 D045888 1.0 653 | 18809400 CID D017239 D045888 1.0 654 | 18809400 CID D002945 D020258 1.0 655 | 18809400 CID D008063 D020258 1.0 656 | 18809400 CID D017239 D020258 1. 657 | 19657887 CID D000079 D013610 1.0 658 | 19657887 CID D004221 D013610 1.0 659 | 19657887 CID D000431 D013610 1.0 660 | 19657887 CID D003484 D013610 1.0 661 | 19657887 CID D000079 D013610 1.0 662 | 19657887 CID D000079 D007022 1.0 663 | 19657887 CID D004221 D007022 1.0 664 | 19657887 CID D000431 D007022 1.0 665 | 19657887 CID D003484 D007022 1.0 666 | 19657887 CID D000079 D007022 1.0 667 | 19657887 CID D000079 D004417 1.0 668 | 19657887 CID D004221 D004417 1.0 669 | 19657887 CID D000431 D004417 1.0 670 | 19657887 CID D003484 D004417 1.0 671 | 19657887 CID D000079 D004417 1.0 672 | 19657887 CID D000079 C536855 1.0 673 | 19657887 CID D004221 C536855 1.0 674 | 19657887 CID D000431 C536855 1.0 675 | 19657887 CID D003484 C536855 1.0 676 | 19657887 CID D000079 C536855 1. 677 | 19803309 CID D013390 D013035 1.0 678 | 19803309 CID D013390 D009222 1. 679 | 21029050 CID D013390 D001049 1.0 680 | 21029050 CID D013390 D016609 1.0 681 | 21029050 CID D013390 D005316 1. 
"""Extract CID relation lines from a PubTator file and append a score column.

Usage:
    python rment.py INPUT.PubTator OUTPUT.txt

Every line of the input whose second tab-separated field is ``CID`` is
copied to the output with ``\t1.0`` appended. Documents are separated in
the input by blank lines.
"""
import sys
import json
from itertools import groupby

# NOTE: a stray ``from turtle import title`` used to live here. The name was
# never used, and importing ``turtle`` pulls in tkinter, which is missing on
# many headless installations and made the script fail before doing any work.


def read_pubtator(file):
    """Yield the documents of a PubTator file, one list of lines per document.

    Args:
        file: Path of the file to read.

    Yields:
        A list of whitespace-stripped, non-empty lines for each run of
        consecutive non-blank lines in the file.
    """
    # ``with`` guarantees the handle is released even when the consumer stops
    # iterating early or an exception propagates through the generator.
    with open(file, "r") as handle:
        stripped = (line.strip() for line in handle)
        # groupby splits the stream into alternating runs of blank and
        # non-blank lines; only the non-blank runs are documents.
        for has_content, group in groupby(stripped, key=bool):
            if has_content:
                yield list(group)


def extract_pubtator(lines):
    """Return the CID relation lines of one document, each with a score.

    Args:
        lines: The lines of one document as produced by ``read_pubtator``.
            The first two lines are skipped (presumably the PubTator title
            and abstract lines - confirm against the input format).

    Returns:
        A list of strings: every remaining line whose second tab-separated
        field is ``CID``, with NUL characters removed and ``\t1.0`` appended.
    """
    res = []
    for raw in lines[2:]:
        line = raw.replace('\x00', '')
        fields = line.split('\t')
        # A line without any tab (e.g. one that consisted only of NUL bytes)
        # has no second field; skip it instead of raising IndexError.
        if len(fields) > 1 and fields[1] == 'CID':
            res.append(line + '\t1.0')
    return res


def convert(inp_path, out_path):
    """Write the scored CID lines of every document in *inp_path* to *out_path*.

    Each relation is written on its own line. Documents without any CID
    relation contribute nothing to the output.
    """
    with open(out_path, 'w') as out:
        for sample in read_pubtator(inp_path):
            relations = extract_pubtator(sample)
            if relations:
                # The previous implementation printed ``joined[:-1]``, which
                # chopped the last character of every document's output and
                # turned the final score "1.0" into "1.".
                out.write('\n'.join(relations) + '\n')


def main(argv=None):
    """Command-line entry point: ``rment.py INPUT OUTPUT``."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        sys.exit("usage: {} INPUT.PubTator OUTPUT.txt".format(argv[0] if argv else "rment.py"))
    convert(argv[1], argv[2])


if __name__ == "__main__":
    main()
bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate mention Disease $FORMAT $GOLD_FILE $PREDICTION_FILE 6 | -------------------------------------------------------------------------------- /data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/eval_relation.sh: -------------------------------------------------------------------------------- 1 | FORMAT=$1 2 | GOLD_FILE=$2 3 | PREDICTION_FILE=$3 4 | java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate relation CID $FORMAT $GOLD_FILE $PREDICTION_FILE | grep -v INFO 5 | # java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate relation CID $FORMAT $GOLD_FILE $PREDICTION_FILE 6 | 7 | -------------------------------------------------------------------------------- /data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/readme.txt: -------------------------------------------------------------------------------- 1 | [Directory] 2 | A. Introduction 3 | B. Installation 4 | C. Instruction 5 | D. Output control 6 | E. Troubleshooting 7 | 8 | #======================================================================================# 9 | 10 | A. [Introduction] 11 | 12 | This is a set of scripts for executing a Java program for evaluating the performance of BioCreative V CDR task. Please follow the instructions to evaluate your system. We used DNorm, tmChem and coocurrence relations to develop a baseline results in the data/test folder. 13 | 14 | The three scripts each support both the PubTator (http://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/PubTator/) and BioC (http://bioc.sourceforge.net/) formats. 15 | 16 | B. [Installation] 17 | 18 | Users need to install Java in their environment. Scripts are provided for the UNIX command line. Batch files for the Windows environment should be straightforward but are not provided. 19 | 20 | C. [Instruction] 21 | 22 | This program can evaluate the performance of disease mention recognition (mention), normalization (id) and chemical-induces-disease relation (relation) on both PubTator and BioC formats. 
23 | 24 | Instruction: 25 | 26 | ./eval_mention.sh [BioC|PubTator] [gold standard] [result] 27 | ./eval_id.sh [BioC|PubTator] [gold standard] [result] 28 | ./eval_relation.sh [BioC|PubTator] [gold standard] [result] 29 | 30 | Example mention evaluation: 31 | ./eval_mention.sh PubTator data/gold/CDR_sample.gold.PubTator data/test/CDR_sample.test.DNER.PubTator 32 | OR 33 | ./eval_mention.sh BioC data/gold/CDR_sample.gold.BioC.xml data/test/CDR_sample.test.DNER.BioC.xml 34 | 35 | Results: 36 | TP: 303 37 | FP: 105 38 | FN: 121 39 | Precision: 0.7426470588235294 40 | Recall: 0.714622641509434 41 | F-score: 0.7283653846153848 42 | 43 | Example ID evaluation: 44 | ./eval_id.sh PubTator data/gold/CDR_sample.gold.PubTator data/test/CDR_sample.test.DNER.PubTator 45 | OR 46 | ./eval_id.sh BioC data/gold/CDR_sample.gold.BioC.xml data/test/CDR_sample.test.DNER.BioC.xml 47 | 48 | Results: 49 | TP: 150 50 | FP: 56 51 | FN: 64 52 | Precision: 0.7281553398058253 53 | Recall: 0.7009345794392523 54 | F-score: 0.7142857142857142 55 | 56 | Example relation evaluation: 57 | ./eval_relation.sh PubTator data/gold/CDR_sample.gold.PubTator data/test/CDR_sample.test.CID.PubTator 58 | OR 59 | ./eval_relation.sh BioC data/gold/CDR_sample.gold.BioC.xml data/test/CDR_sample.test.CID.BioC.xml 60 | 61 | Results: 62 | TP: 90 63 | FP: 533 64 | FN: 33 65 | Precision: 0.14446227929373998 66 | Recall: 0.7317073170731707 67 | F-score: 0.24128686327077747 68 | 69 | D. [Output control] 70 | 71 | The evaluation code provides some output useful for debugging but this is filtered by the scripts for simplicity. The information removed consists of a tab-delimited string containing the values necessary for determining if the predicted values match the target values. This information is specific to each evaluation type.
72 | 73 | Mention evaluation: 74 | INFO {TP|FP|FN} mention documentId startOffset endOffset mentionType 75 | For example: 76 | INFO TP mention 1720453 34 48 Disease 77 | 78 | ID evaluation: 79 | INFO {TP|FP|FN} id documentId mentionType conceptId 80 | For example: 81 | INFO TP id 7352670 Disease MESH:D007022 82 | 83 | Relation evaluation 84 | INFO {TP|FP|FN} relation documentId relationType conceptId1 conceptId2 85 | For example: 86 | INFO TP relation 2894766 CID MESH:D012460 MESH:D011014 87 | 88 | To obtain the debug output for a specific evaluation, open the appropriate script, and change the Java command that is commented out (with the "#" symbol) to be the line that does not contain "grep -v INFO". 89 | 90 | E. [Troubleshooting] 91 | 92 | The scripts are written to ignore some unexpected data. Writing data in the incorrect format may therefore result in it being ignored. Results with zero TP and zero FP are therefore probably indicating a formatting error. 93 | -------------------------------------------------------------------------------- /data/BC5CDR/raw/CDR_Data/BC5CDR.corpus.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/CDR_Data/BC5CDR.corpus.pdf -------------------------------------------------------------------------------- /data/BC5CDR/raw/CDR_Data/BC5CDR.overview.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/CDR_Data/BC5CDR.overview.pdf -------------------------------------------------------------------------------- /data/BC5CDR/raw/CDR_Data/BC5CDR.presentation.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/CDR_Data/BC5CDR.presentation.pdf -------------------------------------------------------------------------------- /data/BC5CDR/raw/CDR_Data/README.txt: -------------------------------------------------------------------------------- 1 | =========================================================================== 2 | * 3 | * PUBLIC DOMAIN NOTICE 4 | * National Center for Biotechnology Information 5 | * 6 | * This software/database is a "United States Government Work" under the 7 | * terms of the United States Copyright Act. It was written as part of 8 | * the author's official duties as a United States Government employee and 9 | * thus cannot be copyrighted. This software/database is freely available 10 | * to the public for use. The National Library of Medicine and the U.S. 11 | * Government have not placed any restriction on its use or reproduction. 12 | * 13 | * Although all reasonable efforts have been taken to ensure the accuracy 14 | * and reliability of the software and data, the NLM and the U.S. 15 | * Government do not and cannot warrant the performance or results that 16 | * may be obtained by using this software or data. The NLM and the U.S. 17 | * Government disclaim all warranties, express or implied, including 18 | * warranties of performance, merchantability or fitness for any 19 | * particular purpose. 20 | * 21 | * Please cite the authors in any work or product based on this material: 22 | * 23 | * 1. Wei CH, Peng Y, Leaman R, Davis AP, Mattingly CJ, Li J, Wiegers TC, 24 | * Lu Z. Overview of the BioCreative V Chemical Disease Relation (CDR) 25 | * Task, Proceedings of the Fifth BioCreative Challenge Evaluation Workshop, 26 | * p154-166, 2015 27 | * 28 | * 2. Li J, Sun Y, Johnson RJ, Sciaky D, Wei CH, Leaman R, Davis AP, Mattingly CJ, 29 | * Wiegers TC, Lu Z. 
Annotating chemicals, diseases and their interactions in 30 | * biomedical literature, Proceedings of the Fifth BioCreative Challenge 31 | * Evaluation Workshop, p173-182, 2015 32 | * 33 | * 3. Leaman R, Dogan RI, Lu Z. DNorm: disease name normalization with pairwise 34 | * learning to rank, Bioinformatics 29(22):2909-17, 2013 35 | * 36 | * 4. Leaman R, Wei CH, Lu Z. tmChem: a high performance approach for chemical 37 | * named entity recognition and normalization. J Cheminform, 7:S3, 2015 38 | * 39 | * 40 | ========================================================================== 41 | 42 | This directory contains the annotated corpus created and used in the BioCreative 43 | V Chemical Disease Relation (CDR) Challenge Task [1]. In addition, it contains 44 | text-mined results of two computational tools for disease & chemical NER. All data 45 | are made available in both BioC XML and PubTator text formats. 46 | 47 | ./CDR.Corpus: The annotated CDR corpus of 1500 PubMed articles of chemicals, 48 | diseases, and chemical-induced disease relationships [2]. 49 | 50 | ./DNorm.TestSet: The text-mined results of diseases on the test set using DNorm [3]. 51 | The normalization performance is 0.81 (P), 0.80 (R), and 0.81 (F). 52 | 53 | ./tmChem.TestSet: The text-mined results of chemicals on the test set using tmChem [4]. 54 | The normalization performance is 0.92 (P), 0.90 (R), and 0.91 (F). 55 | 56 | -------------------------------------------------------------------------------- /examples/DC-HoC/README.md: -------------------------------------------------------------------------------- 1 | # Document Classification on HoC 2 | 3 | ## Data 4 | You can process the data by: 5 | ``` bash 6 | bash preprocess.sh 7 | ``` 8 | 9 | ## Training 10 | You can fine-tune the pre-trained BioGPT on the task by: 11 | ``` bash 12 | bash train.sh 13 | ``` 14 | 15 | ## Model Checkpoint 16 | We provide our fine-tuned model on the task. 
# coding: utf-8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Micro-averaged F1 evaluation for the HoC document-classification task.

Usage: python hard_match_evaluation.py <pred_file> <gold_file>
"""

from ast import Global
import os
import sys
from sklearn.metrics import f1_score
from sklearn.preprocessing import MultiLabelBinarizer

# Label vocabulary passed to MultiLabelBinarizer; its order fixes the column
# order of the indicator matrices built by convert_hoc_labels.
HOC_CLASSES = [
    'tumor promoting inflammation',
    'inducing angiogenesis',
    'evading growth suppressors',
    'resisting cell death',
    'cellular energetics',
    'empty',
    'genomic instability and mutation',
    'sustaining proliferative signaling',
    'avoiding immune destruction',
    'activating invasion and metastasis',
    'enabling replicative immortality',
]


def convert_hoc_labels(lines):
    """Binarize '|'-separated label strings against HOC_CLASSES.

    Args:
        lines: iterable of strings, each holding labels joined by '|'.

    Returns:
        The matrix produced by ``MultiLabelBinarizer.fit_transform``, one row
        per input line and one column per entry of HOC_CLASSES.
    """
    labels = [[w.strip() for w in line.strip().split('|')] for line in lines]
    return MultiLabelBinarizer(classes=HOC_CLASSES).fit_transform(labels)


def do_eval(preds, golden):
    """Print the micro-averaged F1 of ``preds`` against ``golden``."""
    preds = convert_hoc_labels(preds)
    golden = convert_hoc_labels(golden)
    score = f1_score(golden, preds, average='micro')
    print(score)


def main():
    """Read predictions and gold labels from the paths in sys.argv and score them."""
    # Read the CLI arguments here rather than at import time so the module can
    # be imported without a command line.
    pred_file, gold_file = sys.argv[1], sys.argv[2]

    # Every prediction line is kept (even blank ones) so that predictions stay
    # aligned with the gold examples.
    with open(pred_file) as reader:
        preds = [line.strip() for line in reader]

    # Gold labels are the last tab-separated column; blank lines are skipped.
    golden = []
    with open(gold_file) as reader:
        for line in reader:
            line = line.strip()
            if line:
                golden.append(line.split('\t')[-1])

    assert len(preds) == len(golden), f"{len(preds)} {len(golden)}"

    print("\n====File: ", os.path.basename(pred_file))
    do_eval(preds, golden)


if __name__ == "__main__":
    main()
import sys
import re


# Boilerplate the model may emit before the answer sentence. Patterns are tried
# in order and only the first one found in a line is stripped.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
]


def strip_prefix(line):
    """Return the text after the last occurrence of the first matching prefix.

    If none of the patterns in ``prefix`` occurs in ``line``, the line is
    returned unchanged (and unstripped).
    """
    for p in prefix:
        if re.search(p, line) is not None:
            # The final element of re.split is the text after the last match,
            # also for the pattern that contains a capturing group.
            return re.split(p, line)[-1].strip()
    return line


def convert_ansis_sentence(sentence):
    """Extract the label following 'the type of this document is'.

    Returns the stripped remainder of the sentence, or None when the template
    phrase does not occur.
    """
    match = re.search(r"the type of this document is(.*)", sentence)
    if match is None:
        return None
    return match.group(1).strip()


def load_generations(path):
    """Read one generation per line, stripped and without a single trailing '.'."""
    lines = []
    with open(path, "r", encoding="utf8") as fr:
        for raw in fr:
            text = raw.strip()
            lines.append(text[:-1] if text.endswith(".") else text)
    return lines


def main(out_file):
    """Extract labels from ``out_file`` and write them to ``<out_file>.extracted.txt``.

    Lines from which no label can be extracted are written as "failed" and
    reported on stdout with their 1-based line number.
    """
    all_lines = load_generations(out_file)

    hypothesis = []
    fail_cnt = 0
    for i, line in enumerate(all_lines):
        ans = convert_ansis_sentence(strip_prefix(line))
        if ans is None:
            ans = "failed"
            fail_cnt += 1
            print("Failed:id:{}, line:{}".format(i+1, line))
        hypothesis.append(ans)

    with open(f"{out_file}.extracted.txt", "w", encoding="utf8") as fw:
        for eg in hypothesis:
            print(eg, file=fw)

    print(f"failed = {fail_cnt}, total = {len(all_lines)}")


# Run only as a script, so importing the helpers does not touch sys.argv or
# the filesystem.
if __name__ == "__main__":
    main(sys.argv[1])
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Convert HoC ``{train,valid,test}.tsv`` files into ``ansis_<split>.x`` / ``.y`` text files.

Usage: python rebuild_data.py <data_dir>
"""

import os
import sys


def build_target_seq(tgt):
    """Wrap a label string in the target sentence template."""
    return 'the type of this document is ' + tgt + '.'


def loader(fname, fn):
    """Read a tab-separated file into ``[source, target]`` pairs.

    Args:
        fname: path of the TSV file; column 0 is the source text and column 1
            the label. Blank lines are skipped.
        fn: callable applied to the label to produce the target string.

    Returns:
        A list of ``[source, fn(label)]`` lists. A '.' is appended to any
        source that does not already end with one.

    Raises:
        IndexError: if a non-blank line has no tab-separated second column.
    """
    ret = []
    # The context manager closes the handle; the original left it open.
    with open(fname) as file:
        for line in file:
            # Skip blank lines, including whitespace-only and CRLF ones.
            if not line.strip():
                continue
            sent = line.split('\t')
            source = sent[0].replace('\n', '').strip()
            target = sent[1].replace('\n', '').strip()
            # endswith() is also safe when the source column is empty,
            # where source[-1] would raise IndexError.
            if not source.endswith('.'):
                source += '.'
            ret.append([source, fn(target)])

    print(f"{len(ret)} samples in {fname} has been processed")
    return ret


def dumper(content_list, prefix):
    """Write sources to ``<prefix>.x`` and targets to ``<prefix>.y``, one per line."""
    # Both handles are closed even if a write fails part-way through.
    with open(prefix + ".x", "w") as fw_source, open(prefix + ".y", "w") as fw_target:
        for ele in content_list:
            print(ele[0], file=fw_source)
            print(ele[1], file=fw_target)


def worker(fname, prefix, fn):
    """Load ``fname`` with ``loader`` and dump the pairs under ``prefix``."""
    dumper(loader(fname, fn), prefix)


def main(data_dir):
    """Rebuild the train/valid/test splits found in ``data_dir``."""
    for split in ['train', 'valid', 'test']:
        worker(
            os.path.join(f"{data_dir}", f"{split}.tsv"),
            os.path.join(f"{data_dir}", f"ansis_{split}"),
            build_target_seq,
        )


# Run only as a script, so importing the helpers does not read sys.argv or
# rewrite data files.
if __name__ == "__main__":
    main(sys.argv[1])
3 | 4 | SAVE_DIR=../../checkpoints/DC-HoC-BioGPT 5 | mkdir -p ${SAVE_DIR} 6 | 7 | fairseq-train \ 8 | ../../data/HoC/ansis-bin --save-dir ${SAVE_DIR} \ 9 | --user-dir ../../src \ 10 | --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \ 11 | --task language_modeling_prompt \ 12 | --arch transformer_lm_prompt_biogpt \ 13 | --share-decoder-input-output-embed --decoder-learned-pos \ 14 | --optimizer adam --adam-betas '(0.9, 0.98)' \ 15 | --weight-decay 0.01 --clip-norm 0.0 \ 16 | --lr 1e-5 --lr-scheduler inverse_sqrt --warmup-updates 1000 --warmup-init-lr 1e-07 \ 17 | --tokens-per-sample 1024 --max-source-positions 900 --max-target-positions 1024 \ 18 | --max-tokens 1024 --update-freq 32 \ 19 | --skip-invalid-size-inputs-valid-test \ 20 | --max-update 20000 --save-interval-updates 1000 --no-epoch-checkpoints \ 21 | --learned-prompt 1 -------------------------------------------------------------------------------- /examples/QA-PubMedQA/README.md: -------------------------------------------------------------------------------- 1 | # Question Answering on PubMedQA in Reasoning Required Setting 2 | 3 | ## Data 4 | Download data from [PubMedQA](https://github.com/pubmedqa/pubmedqa) and following the steps of splitting dataset. 5 | 6 | Copy the files `pqal_fold0/train_set.json`, `pqal_fold0/dev_set.json`, `test_set.json` and `test_ground_truth.json` to `../../data/PubMedQA/raw` 7 | 8 | Then, you can process the data by: 9 | ``` bash 10 | bash preprocess.sh # for BioGPT 11 | ``` 12 | or 13 | ``` bash 14 | bash preprocess_large.sh # for BioGPT-Large 15 | ``` 16 | 17 | 18 | ## Model Checkpoint 19 | We provide our fine-tuned model on the task. 
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Exact-match accuracy evaluation for PubMedQA predictions.

Usage: python hard_match_evaluation.py <pred_file> <gold_file>
"""

import os
import sys
import json
from sklearn.metrics import accuracy_score


def do_eval(preds, golden):
    """Print the accuracy of ``preds`` against ``golden``."""
    print(accuracy_score(golden, preds))


def _read_predictions(path):
    """Return every line of ``path`` stripped; blank lines are kept for alignment."""
    with open(path) as reader:
        return [line.strip() for line in reader]


def _read_golden(path):
    """Load gold labels from a ``.tsv`` (last column) or ``.json`` (dict values) file.

    Raises:
        ValueError: if ``path`` has neither a ``.tsv`` nor a ``.json`` extension.
            Previously this case silently produced an empty gold list.
    """
    if path.endswith('.tsv'):
        golden = []
        with open(path) as reader:
            for line in reader:
                line = line.strip()
                if line:
                    golden.append(line.split('\t')[-1])
        return golden
    if path.endswith('.json'):
        with open(path) as reader:
            data = json.load(reader)
        # Labels are taken in the order of the JSON object's keys.
        return list(data.values())
    raise ValueError(f"unsupported gold file format: {path}")


def main():
    """Score the prediction file against the gold file named in sys.argv."""
    # Read the CLI arguments here rather than at import time so the module can
    # be imported without a command line.
    pred_file, gold_file = sys.argv[1], sys.argv[2]

    preds = _read_predictions(pred_file)
    golden = _read_golden(gold_file)
    assert len(preds) == len(golden), f"{len(preds)} {len(golden)}"

    print("\n====File: ", os.path.basename(pred_file))
    do_eval(preds, golden)


if __name__ == "__main__":
    main()
3 | 4 | MODEL_DIR=../../checkpoints/QA-PubMedQA-BioGPT 5 | MODEL=checkpoint.pt 6 | DATA_DIR=${PWD}/../../data/PubMedQA/pqal_qcl_ansis-bin 7 | BASE_DATA_DIR=${DATA_DIR%/*} 8 | BIN_DATA_DIR=${DATA_DIR##*/} 9 | DATA_PREFIX=${BIN_DATA_DIR%-*} 10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw 11 | OUTPUT_FILE=generate_${MODEL} 12 | 13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x 14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE} 15 | GOLD_FILE=${RAW_DATA_DIR}/test_ground_truth.json 16 | 17 | # inference 18 | if [ ! -f "$OUTPUT_FILE" ]; then 19 | echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}" 20 | python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE} 21 | fi 22 | 23 | # debpe 24 | sed -i "s/@@ //g" ${OUTPUT_FILE} 25 | # detok 26 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok 27 | # postprocess 28 | python postprocess.py ${OUTPUT_FILE}.detok 29 | # eval 30 | python hard_match_evaluation.py ${OUTPUT_FILE}.detok.extracted.txt ${GOLD_FILE} 31 | -------------------------------------------------------------------------------- /examples/QA-PubMedQA/infer_large.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | MODEL_DIR=../../checkpoints/QA-PubMedQA-BioGPT-Large 5 | MODEL=checkpoint.pt 6 | DATA_DIR=${PWD}/../../data/PubMedQA/biogpt-large-pqal_qcl_ansis-bin 7 | BASE_DATA_DIR=${DATA_DIR%/*} 8 | BIN_DATA_DIR=${DATA_DIR##*/} 9 | DATA_PREFIX=${BIN_DATA_DIR%-*} 10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw 11 | OUTPUT_FILE=generate_${MODEL} 12 | 13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x 14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE} 15 | GOLD_FILE=${RAW_DATA_DIR}/test_ground_truth.json 16 | 17 | # inference 18 | if [ ! 
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Extract PubMedQA answers from detokenized generations.

Usage: python postprocess.py <generated_file>
"""

import sys
import re


# Boilerplate the model may emit before the answer sentence. Patterns are tried
# in order and only the first one found in a line is stripped.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
]


def strip_prefix(line):
    """Return the text after the last occurrence of the first matching prefix.

    If none of the patterns in ``prefix`` occurs in ``line``, the line is
    returned unchanged (and unstripped).
    """
    for p in prefix:
        if re.search(p, line) is not None:
            # The final element of re.split is the text after the last match,
            # also for the pattern that contains a capturing group.
            return re.split(p, line)[-1].strip()
    return line


def convert_relis_sentence(sentence):
    """Extract the answer following the PubMedQA target template.

    Returns the stripped text after 'the answer to the question given the
    context is', or None when that phrase does not occur.
    """
    match = re.search(r"the answer to the question given the context is(.*)", sentence)
    if match is None:
        return None
    return match.group(1).strip()


def load_generations(path):
    """Read one generation per line, stripped and without a single trailing '.'."""
    lines = []
    with open(path, "r", encoding="utf8") as fr:
        for raw in fr:
            text = raw.strip()
            lines.append(text[:-1] if text.endswith(".") else text)
    return lines


def main(out_file):
    """Extract answers from ``out_file`` and write them to ``<out_file>.extracted.txt``.

    Lines from which no answer can be extracted are written as "failed" and
    reported on stdout with their 1-based line number.
    """
    all_lines = load_generations(out_file)

    hypothesis = []
    fail_cnt = 0
    for i, line in enumerate(all_lines):
        ans = convert_relis_sentence(strip_prefix(line))
        if ans is None:
            ans = "failed"
            fail_cnt += 1
            print("Failed:id:{}, line:{}".format(i+1, line))
        hypothesis.append(ans)

    with open(f"{out_file}.extracted.txt", "w", encoding="utf8") as fw:
        for eg in hypothesis:
            print(eg, file=fw)

    print(f"failed = {fail_cnt}, total = {len(all_lines)}")


# Run only as a script, so importing the helpers does not touch sys.argv or
# the filesystem.
if __name__ == "__main__":
    main(sys.argv[1])
${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \ 40 | --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \ 41 | --destdir ${OUTPUT_DIR} \ 42 | --srcdict ${RAW_DATA_DIR}/dict.txt -------------------------------------------------------------------------------- /examples/QA-PubMedQA/preprocess_large.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | DATA_DIR=../../data/PubMedQA 5 | prefix=biogpt-large-pqal_qcl_ansis 6 | RAW_DATA_DIR=${DATA_DIR}/raw 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin 8 | 9 | if [ -d "${OUTPUT_DIR}" ]; then 10 | rm -rf ${OUTPUT_DIR} 11 | fi 12 | 13 | python rebuild_data.py ${RAW_DATA_DIR} ${prefix} 14 | 15 | cp ${DATA_DIR}/../biogpt_large_dict.txt ${RAW_DATA_DIR}/ 16 | cp ${DATA_DIR}/../biogpt_large_bpecodes ${RAW_DATA_DIR}/ 17 | 18 | SPLIT=(train valid test) 19 | 20 | for ff in ${SPLIT[@]}; do 21 | if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then 22 | echo "Preprocessing ${ff}" 23 | 24 | perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x 25 | perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y 26 | 27 | ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/biogpt_large_bpecodes 28 | ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/biogpt_large_bpecodes 29 | 30 | rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y 31 | fi 32 | done 33 | 34 | # do binarize 35 | fairseq-preprocess \ 36 | -s x -t y --workers 8 \ 37 | --joined-dictionary \ 38 | --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \ 39 | --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \ 40 | --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \ 41 
| --destdir ${OUTPUT_DIR} \ 42 | --srcdict ${RAW_DATA_DIR}/biogpt_large_dict.txt -------------------------------------------------------------------------------- /examples/QA-PubMedQA/rebuild_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import re 7 | import json 8 | 9 | data_dir=sys.argv[1] 10 | prefix=sys.argv[2] 11 | 12 | 13 | def build_source_seq(question, context, long_answer=None): 14 | if long_answer: 15 | src = "question: {} context: {} answer: {}".format(question.strip(), context.strip(), long_answer.strip()) 16 | else: 17 | src = "question: {} context: {} ".format(question.strip(), context.strip()) 18 | return src 19 | 20 | 21 | def build_target_seq(tgt): 22 | tgt = 'the answer to the question given the context is ' + tgt + '.' 23 | return tgt 24 | 25 | 26 | def loader(fname, fn, required_long_answer=False): 27 | ret = [] 28 | cnt = 0 29 | 30 | with open(fname, 'r') as file: 31 | data = json.load(file) 32 | 33 | for pmid, content in data.items(): 34 | cnt += 1 35 | question = content['QUESTION'] 36 | context = ' '.join(sen.strip() for sen in content['CONTEXTS']) 37 | context = re.sub(r'\n', ' ', context) 38 | # remove duplicate spaces 39 | context = re.sub(r'\s+', ' ', context) 40 | long_answer = content['LONG_ANSWER'] 41 | if required_long_answer: 42 | source = build_source_seq(question, context, long_answer) 43 | else: 44 | source = build_source_seq(question, context) 45 | 46 | if 'final_decision' in content: 47 | label = content['final_decision'] 48 | target = fn(label) 49 | else: 50 | target = '' 51 | if isinstance(target, list): 52 | for i in range(len(target)): 53 | data_pair = [source, target[i]] 54 | ret.append(data_pair) 55 | else: 56 | data_pair = [source, target] 57 | ret.append(data_pair) 58 | 59 | print(f"{cnt} samples in {fname} has been processed") 60 | return ret 61 | 62 | 63 | 
def dumper(content_list, prefix):
    """Write parallel source/target files for a list of pairs.

    Args:
        content_list: Iterable of indexable items; element 0 is written to
            ``<prefix>.x`` and element 1 to ``<prefix>.y``, one per line.
        prefix: Path prefix of the two output files.
    """
    # `with` closes both files even if a write fails; UTF-8 keeps the output
    # independent of the locale.
    with open(prefix + ".x", "w", encoding="utf8") as fw_source, \
            open(prefix + ".y", "w", encoding="utf8") as fw_target:
        for ele in content_list:
            print(ele[0], file=fw_source)
            print(ele[1], file=fw_target)


def worker(fname, prefix, fn):
    """Convert one JSON split into its ``.x`` / ``.y`` text files."""
    dumper(loader(fname, fn), prefix)
MODEL_DIR=../../checkpoints/RE-BC5CDR-BioGPT
MODEL=checkpoint_avg.pt
# DATA_DIR is anchored at ${PWD}, so every path derived from it is absolute.
DATA_DIR=${PWD}/../../data/BC5CDR/relis-bin
BASE_DATA_DIR=${DATA_DIR%/*}      # DATA_DIR without its last component
BIN_DATA_DIR=${DATA_DIR##*/}      # last component of DATA_DIR
DATA_PREFIX=${BIN_DATA_DIR%-*}    # BIN_DATA_DIR without its "-<suffix>"
RAW_DATA_DIR=${BASE_DATA_DIR}/raw
OUTPUT_FILE=generate_${MODEL}

INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
# OUTPUT_FILE is relative to the directory this script is started from.
OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
GOLD_FILE=${RAW_DATA_DIR}/CDR_Data/CDR.Corpus.v010516/CDR_TestSet.PubTator.txt
ENTITY_FILE=${RAW_DATA_DIR}/test.entities.json
PMID_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.pmid

# average checkpoints
if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
    python ../../scripts/average_checkpoints.py --inputs="${MODEL_DIR}" --output="${MODEL_DIR}/${MODEL}" --num-epoch-checkpoints=5
fi


# inference
if [ ! -f "$OUTPUT_FILE" ]; then
    echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
    python ../../inference.py --data_dir="${DATA_DIR}" --model_dir="${MODEL_DIR}" --model_file="${MODEL}" --src_file="${INPUT_FILE}" --output_file="${OUTPUT_FILE}"
fi

# debpe
sed -i "s/@@ //g" "${OUTPUT_FILE}"
# detok
perl "${MOSES}/scripts/tokenizer/detokenizer.perl" -l en -a < "${OUTPUT_FILE}" > "${OUTPUT_FILE}.detok"
# postprocess
python postprocess.py "${OUTPUT_FILE}.detok" "${ENTITY_FILE}" "${PMID_FILE}"
# eval
cd "${RAW_DATA_DIR}/BC5CDR_Evaluation-0.0.3"
# GOLD_FILE is already absolute (it derives from ${PWD}); prefixing it with
# ${OLDPWD} produced a non-existent path. Only the relative OUTPUT_FILE needs
# to be re-anchored at the directory we came from.
bash eval_relation.sh PubTator "${GOLD_FILE}" "${OLDPWD}/${OUTPUT_FILE}.detok.extracted.PubTator"
cd "${OLDPWD}"
# Usage: python postprocess.py <detokenized_output> <entity_json> <pmid_file>
out_file = sys.argv[1]
entity_file=sys.argv[2]
pmids_file = sys.argv[3]

# Regex patterns for lead-in text that may precede the relation sentences.
prefix = [
    '(learned[0-9]+ )+',
    'in conclusion ,',
    'we can conclude that',
    'we have that',
]


def strip_prefix(line):
    """Return the text after the first pattern of ``prefix`` found in *line*.

    Patterns are tried in list order and only the first one that matches is
    applied; the text after its last occurrence is returned stripped. A line
    matching no pattern is returned unchanged.
    """
    for p in prefix:
        res = re.search(p, line)
        if res is not None:
            line = re.split(p, line)[-1].strip()
            break
    return line


def split_sentence(line):
    """Split a generated line into candidate relation sentences on ``;``."""
    sentences = re.split(r";", line)
    return sentences


def convert_relis_sentence(sentence):
    """Parse ``the relation between X and Y exists`` into ``(X, Y)``.

    Returns:
        A ``(chemical, disease)`` tuple of stripped names, or ``None`` when
        the sentence does not start with that template.
    """
    ans = None
    segs = re.match(r"the relation between (.*) and (.*) exists", sentence.strip())
    if segs is not None:
        segs = segs.groups()
        chemical = segs[0].strip()
        disease = segs[1].strip()
        ans = (chemical, disease)
    return ans


all_lines = []
with open(out_file, "r", encoding="utf8") as fr:
    for line in fr:
        e = line.strip()
        # str.endswith is safe on an empty line, whereas e[-1] raised
        # IndexError. An empty line is kept so line numbers stay aligned
        # with the PMID list; it is reported as a failure below.
        all_lines.append(e[:-1] if e.endswith(".") else e)
with open(entity_file, "r", encoding="utf8") as fr:
    ent2id = json.load(fr)
with open(pmids_file, "r") as reader:
    if '.json' in pmids_file:
        pmids = json.load(reader)
    else:
        pmids = []
        for line in reader:
            pmids.append(line.strip())


hypothesis = []
cnt = 0
fail_cnt = 0
for i, line in enumerate(all_lines):
    cnt += 1
    strip_line = strip_prefix(line)
    ret = []
    sentences = split_sentence(strip_line)
    for sen in sentences:
        ans = convert_relis_sentence(sen)
        if ans is not None:
            chemical, disease = ans
            # Names missing from the entity table are emitted with id "-1".
            chemicalID = ent2id['chemical2id'].get(chemical.strip(), "-1")
            diseaseID = ent2id['disease2id'].get(disease.strip(), "-1")
            ret.append(f"{pmids[i]}\tCID\t{chemicalID}\t{diseaseID}\t1.0")
    if len(ret) > 0:
        hypothesis.extend(ret)
    else:
        fail_cnt += 1
        print("Failed:id:{}, line:{}".format(i+1, line))


with open(f"{out_file}.extracted.PubTator", "w", encoding="utf8") as fw:
    for line in hypothesis:
        print(line, file=fw)


print(f"failed = {fail_cnt}, total = {cnt}")
def unify_ent2id(ent2id, method='max'):
    """Keep exactly one surface name per entity id.

    Args:
        ent2id: Mapping from entity name to entity id; several names may
            share an id.
        method: ``'min'`` keeps the shortest name per id; any other value
            keeps the longest. On a length tie the first name seen wins.

    Returns:
        ``(ent2id, id2ent)``: the reduced name -> id mapping and its inverse.
    """
    id2ent = {}
    for k, v in ent2id.items():
        if v in id2ent:
            if method == 'min':
                id2ent[v] = k if len(k) < len(id2ent[v]) else id2ent[v]
            else:
                id2ent[v] = k if len(k) > len(id2ent[v]) else id2ent[v]
        else:
            id2ent[v] = k
    ent2id = {v:k for k, v in id2ent.items()}
    return ent2id, id2ent


def sort_triples(triples, text):
    """Sort triples by the first position of their ``'chemical'`` value in *text*."""
    sorted_triples = sorted(triples, key=lambda x:text.find(x['chemical']))
    return sorted_triples


def build_target_seq_svo(relations, id2chem, id2disease):
    """Render relations as ``"<chemical> correlates with <disease>"`` clauses.

    Clauses are joined with ``"; "`` and terminated with a period.
    """
    clauses = [
        f"{id2chem[z['chemical']]} correlates with {id2disease[z['disease']]}"
        for z in relations
    ]
    return "; ".join(clauses) + "."


def build_target_seq_relis(relations, id2chem, id2disease):
    """Render relations as ``"the relation between <chemical> and <disease> exists"``.

    Clauses are joined with ``"; "`` and terminated with a period.
    """
    clauses = [
        f"the relation between {id2chem[z['chemical']]} and {id2disease[z['disease']]} exists"
        for z in relations
    ]
    return "; ".join(clauses) + "."


def loader(fname, fn):
    """Load one BC5CDR JSON split into ``(pmid, content, answer)`` tuples.

    Args:
        fname: Path of a JSON dict keyed by PMID. Each record provides
            ``title``, ``abstract``, ``relations``, ``chemical2id`` and
            ``disease2id``.
        fn: Target builder called as ``fn(relations, id2chemical, id2disease)``.

    Returns:
        One tuple per record that has relations. ``content`` is the
        lower-cased title followed by the abstract. Records without relations
        are skipped and their PMIDs are printed.
    """
    ret = []
    null_cnt = 0
    with open(fname, "r", encoding="utf8") as fr:
        data = json.load(fr)
    for pmid, v in data.items():
        # A title that already ends in a non-word character (e.g. "?" or ".")
        # is joined with a space; otherwise a sentence-ending ". " is added.
        if re.search(r"\W$", v["title"]):
            content = v["title"] + " " + v["abstract"]
        else:
            content = v["title"] + ". " + v["abstract"]

        content = content.lower()

        if not v["relations"]:
            if null_cnt == 0:
                print(f"Following PMID in {fname} has no extracted relations:")
            print(f"{pmid} ", end="")
            null_cnt += 1
        else:
            # Each id is rendered with its longest known surface name.
            _, id2chemical = unify_ent2id(v["chemical2id"], method='max')
            _, id2disease = unify_ent2id(v["disease2id"], method='max')
            answer = fn(v["relations"], id2chemical, id2disease)
            ret.append((pmid, content, answer))
    if null_cnt:
        print("")
    print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no relations extracted.")
    return ret

def dumper(content_list, prefix):
    """Write ``<prefix>.pmid``, ``<prefix>.x`` and ``<prefix>.y`` line by line."""
    # `with` guarantees the three files are closed even if a write fails.
    with open(prefix + ".pmid", "w") as fw_pmid, \
            open(prefix + ".x", "w") as fw_content, \
            open(prefix + ".y", "w") as fw_label:
        for ele in content_list:
            print(ele[0], file=fw_pmid)
            print(ele[1], file=fw_content)
            print(ele[2], file=fw_label)

def worker(fname, prefix, fn):
    """Convert one JSON split into its ``.pmid`` / ``.x`` / ``.y`` files."""
    ret = loader(fname, fn)
    dumper(ret, prefix)
3 | 4 | SAVE_DIR=../../checkpoints/RE-BC5CDR-BioGPT 5 | mkdir -p ${SAVE_DIR} 6 | 7 | fairseq-train \ 8 | ../../data/BC5CDR/relis-bin --save-dir ${SAVE_DIR} \ 9 | --user-dir ../../src \ 10 | --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \ 11 | --task language_modeling_prompt \ 12 | --arch transformer_lm_prompt_biogpt \ 13 | --share-decoder-input-output-embed --decoder-learned-pos \ 14 | --optimizer adam --adam-betas '(0.9, 0.98)' \ 15 | --weight-decay 0.01 --clip-norm 0.0 \ 16 | --lr 1e-5 --lr-scheduler inverse_sqrt --warmup-updates 100 --warmup-init-lr 1e-07 \ 17 | --tokens-per-sample 1024 --max-source-positions 640 --max-target-positions 1024 \ 18 | --max-tokens 1024 --update-freq 32 \ 19 | --skip-invalid-size-inputs-valid-test \ 20 | --max-epoch 100 --keep-last-epochs 5 \ 21 | --learned-prompt 9 -------------------------------------------------------------------------------- /examples/RE-DDI/README.md: -------------------------------------------------------------------------------- 1 | # Relation Extraction on DDI 2 | 3 | ## Data 4 | You can process the data by: 5 | ``` bash 6 | bash preprocess.sh 7 | ``` 8 | 9 | ## Training 10 | You can fine-tune the pre-trained BioGPT on the task by: 11 | ``` bash 12 | bash train.sh 13 | ``` 14 | 15 | ## Model Checkpoint 16 | We provide our fine-tuned model on the task. See [here](../../README.md#pre-trained-models) 17 | 18 | ## Inference and Evaluating 19 | You can run inference and evaluate the model on the test set by: 20 | ``` bash 21 | bash infer.sh 22 | ``` -------------------------------------------------------------------------------- /examples/RE-DDI/hard_match_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
# Trailing digit -> roman numeral used by normalize_name.
_NUM2ROMAN = {"0": "0", "1": "I", "2": "II", "3": "III", "4": "IV", "5": "V", "6": "VI", "7": "VII", "8": "VIII", "9": "IX"}


def normalize_name(s: str):
    """Normalize an entity name for hard matching.

    A single trailing digit (one not preceded by another digit) in a name
    longer than two characters is converted to its roman numeral, every
    ``"-type"`` substring is removed, and all characters other than ASCII
    letters and digits are dropped.
    """
    s = s.strip()

    # normalize roman type id at end of string
    if len(s) > 2 and s[-1] in _NUM2ROMAN and not s[-2].isnumeric():
        s = s[:-1] + _NUM2ROMAN[s[-1]]

    # remove every "-type" marker
    s = s.replace("-type", '')

    return re.sub('[^a-zA-Z0-9]+', '', s)


def _map_names(tgt_set, fn):
    """Apply *fn* to every name of a set of entities or of 3-item triples."""
    items = list(tgt_set)
    # The first element decides whether the collection holds triples.
    if items and isinstance(items[0], (tuple, list)):
        return {(fn(tp[0]), fn(tp[1]), fn(tp[2])) for tp in items}
    return {fn(e) for e in items}


def rm_abbr(tgt_set):
    """ remove abbreviation in the brackets of entity, eg: aaa (bb) -> aaa """
    def rm(s):
        s = s.strip()
        if "(" in s and s[-1] == ')':  # entity ends with a bracketed short form
            return normalize_name(s[:s.rfind("(")].strip())
        return normalize_name(s)

    return _map_names(tgt_set, rm)


def get_abbr(tgt_set):
    """ extract abbreviation in the brackets of entity, eg: aaa (bb) -> bb """
    def rm(s):
        s = s.strip()
        if "(" in s and s[-1] == ')':
            return normalize_name(s[s.rfind("(")+1:-1].strip())
        return normalize_name(s)

    return _map_names(tgt_set, rm)


def acc(pred_set, gold_set):
    """ Multi-label style acc: |pred & gold| / |pred | gold|.

    With an empty gold set the score is 1 only if the prediction is empty too.
    """
    tp_num = len(pred_set & gold_set)
    return int(pred_set == gold_set) if len(gold_set) == 0 else 1.0 * tp_num / len(pred_set | gold_set)


def precision(pred_set, gold_set):
    """ Multi-label style precision; an empty prediction scores 1 only against empty gold. """
    tp_num = len(pred_set & gold_set)
    return int(pred_set == gold_set) if len(pred_set) == 0 else 1.0 * tp_num / len(pred_set)


def recall(pred_set, gold_set):
    """ Multi-label style recall; an empty gold set scores 1 only for an empty prediction. """
    tp_num = len(pred_set & gold_set)
    return int(pred_set == gold_set) if len(gold_set) == 0 else 1.0 * tp_num / len(gold_set)


def normed_eval(pred_set, gold_set, metric):
    """ Both body and abbreviation match are considered correct.

    Returns the larger of *metric* on the abbreviation view and on the
    abbreviation-stripped view of the two sets.
    """
    abbr_pred_set, abbr_gold_set = get_abbr(pred_set), get_abbr(gold_set)
    rm_pred_set, rm_gold_set = rm_abbr(pred_set), rm_abbr(gold_set)
    return max(metric(abbr_pred_set, abbr_gold_set), metric(rm_pred_set, rm_gold_set))


def get_f1(p, r):
    """Harmonic mean of *p* and *r*; 0 when both are 0."""
    return 0 if (p + r) == 0 else (2.0 * p * r / (p + r))


def ave(scores):
    """Arithmetic mean of *scores*; 0.0 for an empty list.

    The empty case occurs when no prediction could be matched to a gold
    record; it previously raised ZeroDivisionError.
    """
    return 1.0 * sum(scores) / len(scores) if scores else 0.0
print("----Missing:", idx) 116 | continue 117 | if golden[idx]["triples"]: 118 | for tp in golden[idx]["triples"]: 119 | d = tp["drug"].strip().lower().replace(' ', '') 120 | t = tp["target"].strip().lower().replace(' ', '') 121 | i = tp["interaction"].strip().lower().replace(' ', '') 122 | gold_d_set.add(d) 123 | gold_t_set.add(t) 124 | gold_i_set.add(i) 125 | gold_set.add((d, t, i)) 126 | 127 | # sample level eval 128 | p.append(normed_eval(pred_set, gold_set, metric=precision)) 129 | r.append(normed_eval(pred_set, gold_set, metric=recall)) 130 | all_f1.append(get_f1(p[-1], r[-1])) 131 | d_acc.append(normed_eval(pred_d_set, gold_d_set, metric=acc)) 132 | t_acc.append(normed_eval(pred_t_set, gold_t_set, metric=acc)) 133 | i_acc.append(normed_eval(pred_i_set, gold_i_set, metric=acc)) 134 | 135 | # onto level eval 136 | all_pred_d.extend(pred_d_set) 137 | all_pred_t.extend(pred_t_set) 138 | all_pred_i.extend(pred_i_set) 139 | all_pred_triple.extend(pred_set) 140 | all_gold_d.extend(gold_d_set) 141 | all_gold_t.extend(gold_t_set) 142 | all_gold_i.extend(gold_i_set) 143 | all_gold_triple.extend(gold_set) 144 | 145 | # if len(gold_set) < len(golden[idx]["triples"]): 146 | # print("Duplicate extists, ori", golden[idx]["triples"], gold_set) 147 | 148 | num_pred += len(pred_set) 149 | num_gold += len(gold_set) 150 | 151 | ret.append({ 152 | "pmid": idx, 153 | "title": golden[idx]["title"] if "title" in golden[idx] else None, 154 | "abstract": golden[idx]["abstract"], 155 | "d_pred_gold": [d_acc[-1], list(pred_d_set), list(gold_d_set)], 156 | "t_pred_gold": [t_acc[-1], list(pred_t_set), list(gold_t_set)], 157 | "i_pred_gold": [i_acc[-1], list(pred_i_set), list(gold_i_set)], 158 | "all_pred_gold": [all_f1[-1], list(pred_set), list(gold_set)], 159 | }) 160 | 161 | 162 | print("num sample", len(all_f1), "missing", len(preds) - len(all_f1), "num_gold tp", num_gold, "num_pred", num_pred) 163 | 164 | # Note: we adopt multi-label metrics following: 
def main():
    """Evaluate the prediction file against the gold file and dump the result.

    Reads the module-level ``pred_file`` (one JSON object per line),
    ``gold_file`` (a JSON dict) and ``pmids_file`` (a JSON list when the path
    contains ``.json``, otherwise one id per line), runs ``do_eval`` and
    writes its return value next to the prediction file as
    ``<stem>.eval_res.json``.
    """
    # All inputs are read as UTF-8 instead of the locale default.
    with open(pred_file, encoding="utf8") as reader:
        preds = [json.loads(line) for line in reader]

    with open(gold_file, encoding="utf8") as reader:
        golden = json.load(reader)

    with open(pmids_file, encoding="utf8") as reader:
        if '.json' in pmids_file:
            pmids = json.load(reader)
        else:
            pmids = [line.strip() for line in reader]

    print("\n====File: ", os.path.basename(pred_file))
    result = do_eval(preds, pmids, golden)

    last_pos = pred_file.rfind('.json')
    # rfind returns -1 when ".json" is absent; slicing with -1 would silently
    # drop the last character of the file name, so keep the full name then.
    stem = pred_file[:last_pos] if last_pos != -1 else pred_file
    res_file_name = stem + '.eval_res.json'
    with open(res_file_name, 'w') as writer:
        json.dump(result, writer, indent=2)


if __name__ == "__main__":
    main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

MODEL_DIR=../../checkpoints/RE-DDI-BioGPT
MODEL=checkpoint_avg.pt
# DATA_DIR is anchored at ${PWD}, so the paths derived from it are absolute.
DATA_DIR=${PWD}/../../data/DDI/relis-bin
BASE_DATA_DIR=${DATA_DIR%/*}      # DATA_DIR without its last component
BIN_DATA_DIR=${DATA_DIR##*/}      # last component of DATA_DIR
DATA_PREFIX=${BIN_DATA_DIR%-*}    # BIN_DATA_DIR without its "-<suffix>"
RAW_DATA_DIR=${BASE_DATA_DIR}/raw
OUTPUT_FILE=generate_${MODEL}

INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
GOLD_FILE=${RAW_DATA_DIR}/test.json
PMID_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.pmid

# Every expansion below is quoted: the paths start with ${PWD} and would be
# split into several arguments if the working directory contains whitespace.

# average checkpoints
if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
    python ../../scripts/average_checkpoints.py --inputs="${MODEL_DIR}" --output="${MODEL_DIR}/${MODEL}" --num-epoch-checkpoints=5
fi

# inference
if [ ! -f "$OUTPUT_FILE" ]; then
    echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
    python ../../inference.py --data_dir="${DATA_DIR}" --model_dir="${MODEL_DIR}" --model_file="${MODEL}" --src_file="${INPUT_FILE}" --output_file="${OUTPUT_FILE}"
fi

# debpe
sed -i "s/@@ //g" "${OUTPUT_FILE}"
# detok
perl "${MOSES}/scripts/tokenizer/detokenizer.perl" -l en -a < "${OUTPUT_FILE}" > "${OUTPUT_FILE}.detok"
# postprocess
python postprocess.py "${OUTPUT_FILE}.detok"
# eval
python hard_match_evaluation.py "${OUTPUT_FILE}.detok.extracted.json" "${GOLD_FILE}" "${PMID_FILE}"
# Regex patterns for lead-in text that may precede the relation sentences.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
]


def strip_prefix(line):
    """Drop everything up to the first matching lead-in pattern.

    The patterns of ``prefix`` are tried in order; for the first one found in
    *line*, the stripped text after its last occurrence is returned. Lines
    without any lead-in are returned untouched.
    """
    matched = next((pat for pat in prefix if re.search(pat, line) is not None), None)
    if matched is None:
        return line
    return re.split(matched, line)[-1].strip()


def split_sentence(line):
    """Break a generated line into its ``"; "``-separated sentences."""
    return line.split("; ")


def convert_relis_sentence(sentence):
    """Parse ``the interaction between A and B is R``.

    Returns:
        ``(A, R, B)`` with each part stripped, or ``None`` if *sentence* does
        not start with the template.
    """
    found = re.match(r"the interaction between (.*) and (.*) is (.*)", sentence)
    if found is None:
        return None
    head, tail, relation = found.groups()
    return (head.strip(), relation.strip(), tail.strip())


def converter(sample, h_idx=0, r_idx=1, t_idx=2):
    """Wrap parsed triples in the record layout written to the output file.

    Args:
        sample: Iterable of indexable triples.
        h_idx: Index of the subject inside a triple.
        r_idx: Index of the relation inside a triple.
        t_idx: Index of the object inside a triple.

    Returns:
        A dict whose ``triple_list_pred`` holds one
        ``{"subject", "relation", "object"}`` dict per triple; the other
        fields are fixed placeholders.
    """
    predicted = [
        {"subject": item[h_idx], "relation": item[r_idx], "object": item[t_idx]}
        for item in sample
    ]
    return {"triple_list_gold": [], "triple_list_pred": predicted, "new": [], "lack": [], "id": [0]}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# MOSES and FASTBPE locate the tokenizer and the BPE binary used below.
# Check them before anything is deleted, so a missing variable aborts with a
# clear message instead of failing later on "/scripts/..." or "/fast".
: "${MOSES:?MOSES must be set (directory containing scripts/tokenizer/tokenizer.perl)}"
: "${FASTBPE:?FASTBPE must be set (directory containing the fast binary)}"

DATA_DIR=../../data/DDI
prefix=relis
RAW_DATA_DIR=${DATA_DIR}/raw
OUTPUT_DIR=${DATA_DIR}/${prefix}-bin

# Start from a clean output directory.
if [ -d "${OUTPUT_DIR}" ]; then
    rm -rf "${OUTPUT_DIR}"
fi

python rebuild_data.py "${RAW_DATA_DIR}"

cp "${DATA_DIR}/../dict.txt" "${RAW_DATA_DIR}/"
cp "${DATA_DIR}/../bpecodes" "${RAW_DATA_DIR}/"

SPLIT=(train valid test)

for ff in "${SPLIT[@]}"; do
    if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
        echo "Preprocessing ${ff}"

        perl "${MOSES}/scripts/tokenizer/tokenizer.perl" -l en -a -threads 8 < "${RAW_DATA_DIR}/${prefix}_$ff.x" > "${RAW_DATA_DIR}/${prefix}_$ff.tok.x"
        perl "${MOSES}/scripts/tokenizer/tokenizer.perl" -l en -a -threads 8 < "${RAW_DATA_DIR}/${prefix}_$ff.y" > "${RAW_DATA_DIR}/${prefix}_$ff.tok.y"

        "${FASTBPE}/fast" applybpe "${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x" "${RAW_DATA_DIR}/${prefix}_$ff.tok.x" "${RAW_DATA_DIR}/bpecodes"
        "${FASTBPE}/fast" applybpe "${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y" "${RAW_DATA_DIR}/${prefix}_$ff.tok.y" "${RAW_DATA_DIR}/bpecodes"

        # The tokenized intermediates are not needed once BPE is applied.
        rm "${RAW_DATA_DIR}/${prefix}_$ff.tok.x" "${RAW_DATA_DIR}/${prefix}_$ff.tok.y"
    fi
done

# do binarize
fairseq-preprocess \
    -s x -t y --workers 8 \
    --joined-dictionary \
    --trainpref "${RAW_DATA_DIR}/${prefix}_train.tok.bpe" \
    --validpref "${RAW_DATA_DIR}/${prefix}_valid.tok.bpe" \
    --testpref "${RAW_DATA_DIR}/${prefix}_test.tok.bpe" \
    --destdir "${OUTPUT_DIR}" \
    --srcdict "${RAW_DATA_DIR}/dict.txt"
def sort_triples(triples, text):
    """Order triples by where their drug is first mentioned in *text*.

    The lookup is case-insensitive. The caller in this file passes
    lower-cased text while the drug names keep their original casing, so a
    case-sensitive search missed every capitalised name. Drugs that do not
    occur in the text get position -1 and therefore sort first; ties keep
    their input order.
    """
    lowered = text.lower()
    return sorted(triples, key=lambda x: lowered.find(x['drug'].lower()))


def build_target_seq_relis(triples):
    """Render triples as ``"the interaction between <drug> and <target> is <rel>"``.

    All three fields are lower-cased; clauses are joined with ``"; "`` and the
    sequence ends with a period.
    """
    clauses = [
        f"the interaction between {z['drug'].lower()} and {z['target'].lower()} is {z['interaction'].lower()}"
        for z in triples
    ]
    return "; ".join(clauses) + "."


def build_target_seq_2type(triples):
    """Render triples as ``"<drug> and <target> are <rel>"`` clauses.

    All three fields are lower-cased; clauses are joined with ``"; "`` and the
    sequence ends with a period.
    """
    clauses = [
        f"{z['drug'].lower()} and {z['target'].lower()} are {z['interaction'].lower()}"
        for z in triples
    ]
    return "; ".join(clauses) + "."
def loader(fname, fn):
    """Load one DDI JSON split into ``(pmid, content, answer)`` tuples.

    Args:
        fname: Path of a JSON dict keyed by PMID; each record provides
            ``abstract`` and ``triples``.
        fn: Target builder called with the sorted triple list.

    Returns:
        One tuple per record that has triples. ``content`` is the stripped,
        lower-cased abstract with newlines replaced by spaces. Records
        without triples are skipped and their PMIDs are printed.
    """
    ret = []
    null_cnt = 0
    with open(fname, "r", encoding="utf8") as fr:
        data = json.load(fr)
    for pmid, v in data.items():
        content = v["abstract"].strip().replace('\n',' ')

        content = content.lower()
        if not v["triples"]:
            if null_cnt == 0:
                print(f"Following PMID in {fname} has no extracted triples:")
            print(f"{pmid} ", end="")
            null_cnt += 1
        else:
            triples = sort_triples(v['triples'], content)
            answer = fn(triples)
            ret.append((pmid, content, answer))
    if null_cnt:
        print("")
    print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no triples extracted.")
    return ret


def dumper(content_list, prefix):
    """Write ``<prefix>.pmid``, ``<prefix>.x`` and ``<prefix>.y`` line by line.

    Files are written as UTF-8, matching the encoding ``loader`` reads with,
    and are closed even if a write fails.
    """
    with open(prefix + ".pmid", "w", encoding="utf8") as fw_pmid, \
            open(prefix + ".x", "w", encoding="utf8") as fw_content, \
            open(prefix + ".y", "w", encoding="utf8") as fw_label:
        for pmid, x, y in content_list:
            print(pmid, file=fw_pmid)
            print(x, file=fw_content)
            print(y, file=fw_label)


def worker(fname, prefix, fn):
    """Convert one JSON split into its ``.pmid`` / ``.x`` / ``.y`` files."""
    ret = loader(fname, fn)
    dumper(ret, prefix)
3 | 4 | SAVE_DIR=../../checkpoints/RE-DDI-BioGPT 5 | mkdir -p ${SAVE_DIR} 6 | 7 | fairseq-train \ 8 | ../../data/DDI/relis-bin --save-dir ${SAVE_DIR} \ 9 | --user-dir ../../src \ 10 | --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \ 11 | --task language_modeling_prompt \ 12 | --arch transformer_lm_prompt_biogpt \ 13 | --share-decoder-input-output-embed --decoder-learned-pos \ 14 | --optimizer adam --adam-betas '(0.9, 0.98)' \ 15 | --weight-decay 0.01 --clip-norm 0.0 \ 16 | --lr 1e-4 --lr-scheduler inverse_sqrt --warmup-updates 500 --warmup-init-lr 1e-07 \ 17 | --tokens-per-sample 1024 --max-source-positions 640 --max-target-positions 1024 \ 18 | --max-tokens 1024 --update-freq 32 \ 19 | --skip-invalid-size-inputs-valid-test \ 20 | --max-epoch 100 --keep-last-epochs 5 \ 21 | --learned-prompt 9 -------------------------------------------------------------------------------- /examples/RE-DTI/README.md: -------------------------------------------------------------------------------- 1 | # Relation Extraction on KD-DTI 2 | 3 | ## Data 4 | According to the original [KD-DTI dataset](https://github.com/bert-nmt/BERT-DTI), before processing the data, you should first register a DrugBank account, download the xml dataset and replace the entity id with the entity name in the drugbank. 5 | 6 | Then, you can process the data by: 7 | ``` bash 8 | bash preprocess.sh 9 | ``` 10 | 11 | For more details, please see [here](https://github.com/bert-nmt/BERT-DTI). 12 | 13 | ## Training 14 | You can fine-tune the pre-trained BioGPT on the task by: 15 | ``` bash 16 | bash train.sh 17 | ``` 18 | 19 | ## Model Checkpoint 20 | We provide our fine-tuned model on the task. 
See [here](../../README.md#pre-trained-models) 21 | 22 | ## Inference and Evaluating 23 | You can inference and evalaute the model on the test set by: 24 | ``` bash 25 | bash infer.sh 26 | ``` -------------------------------------------------------------------------------- /examples/RE-DTI/hard_match_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import re 5 | import json 6 | import sys 7 | import os 8 | 9 | pred_file = sys.argv[1] 10 | gold_file = sys.argv[2] 11 | pmids_file = sys.argv[3] 12 | 13 | def normalize_name(s: str): 14 | s = s.strip() 15 | 16 | # normalize roman type id at end of string 17 | num2roman = {"0": "0", "1": "I", "2": "II", "3": "III", "4": "IV", "5": "V", "6": "VI", "7": "VII", "8": "VIII", "9": "IX"} 18 | if len(s) > 2 and s[-1].isnumeric() and not s[-2].isnumeric() and s[-1] in num2roman: 19 | tmps = list(s) 20 | s = ''.join(tmps[:-1]) + num2roman[tmps[-1]] 21 | 22 | # remove useless end string 23 | s = s.replace("-type", '') 24 | 25 | return re.sub('[^a-zA-Z0-9]+', '', s) 26 | 27 | 28 | def rm_abbr(tgt_set): 29 | """ remove abbreviation in the brackets of entity, eg: aaa (bb) -> aaa """ 30 | def rm(s): 31 | s = s.strip() 32 | if "(" in s and s[-1] == ')': # entity end with a bracketed short cut 33 | return normalize_name(s[:s.rfind("(")].strip()) 34 | else: 35 | return normalize_name(s) 36 | 37 | tgt_set = list(tgt_set) 38 | if tgt_set and type(tgt_set[0]) in [tuple, list]: # process triples 39 | return set([(rm(tp[0]), rm(tp[1]), rm(tp[2])) for tp in tgt_set]) 40 | else: # process entities 41 | return set([rm(e) for e in tgt_set]) 42 | 43 | 44 | def get_abbr(tgt_set): 45 | """ extract abbreviation in the brackets of entity, eg: aaa (bb) -> bb """ 46 | def rm(s): 47 | s = s.strip() 48 | if "(" in s and s[-1] == ')': 49 | return normalize_name(s[s.rfind("(")+1:-1].strip()) 50 | else: 51 | return 
def ave(scores):
    """Return the arithmetic mean of ``scores`` as a float.

    An empty sequence yields 0.0 instead of raising ZeroDivisionError; this
    happens when no prediction could be matched to a gold entry, in which case
    the per-sample score lists stay empty.
    """
    if len(scores) == 0:
        return 0.0
    return 1.0 * sum(scores) / len(scores)
if pred["triple_list_pred"] and pred["triple_list_pred"][0]["subject"] != 'failed': 104 | for tp in pred["triple_list_pred"]: 105 | d = tp["subject"].strip().lower().replace(' ', '') 106 | t = tp["object"].strip().lower().replace(' ', '') 107 | i = tp["relation"].strip().lower().replace(' ', '') 108 | 109 | pred_d_set.add(d) 110 | pred_t_set.add(t) 111 | pred_i_set.add(i) 112 | pred_set.add((d, t, i)) 113 | if idx not in golden: 114 | num_missing += 1 115 | # print("----Missing:", idx) 116 | continue 117 | if golden[idx]["triples"]: 118 | for tp in golden[idx]["triples"]: 119 | d = tp["drug"].strip().lower().replace(' ', '') 120 | t = tp["target"].strip().lower().replace(' ', '') 121 | i = tp["interaction"].strip().lower().replace(' ', '') 122 | gold_d_set.add(d) 123 | gold_t_set.add(t) 124 | gold_i_set.add(i) 125 | gold_set.add((d, t, i)) 126 | 127 | # sample level eval 128 | p.append(normed_eval(pred_set, gold_set, metric=precision)) 129 | r.append(normed_eval(pred_set, gold_set, metric=recall)) 130 | all_f1.append(get_f1(p[-1], r[-1])) 131 | d_acc.append(normed_eval(pred_d_set, gold_d_set, metric=acc)) 132 | t_acc.append(normed_eval(pred_t_set, gold_t_set, metric=acc)) 133 | i_acc.append(normed_eval(pred_i_set, gold_i_set, metric=acc)) 134 | 135 | # onto level eval 136 | all_pred_d.extend(pred_d_set) 137 | all_pred_t.extend(pred_t_set) 138 | all_pred_i.extend(pred_i_set) 139 | all_pred_triple.extend(pred_set) 140 | all_gold_d.extend(gold_d_set) 141 | all_gold_t.extend(gold_t_set) 142 | all_gold_i.extend(gold_i_set) 143 | all_gold_triple.extend(gold_set) 144 | 145 | # if len(gold_set) < len(golden[idx]["triples"]): 146 | # print("Duplicate extists, ori", golden[idx]["triples"], gold_set) 147 | 148 | num_pred += len(pred_set) 149 | num_gold += len(gold_set) 150 | 151 | ret.append({ 152 | "pmid": idx, 153 | "title": golden[idx]["title"] if "title" in golden[idx] else None, 154 | "abstract": golden[idx]["abstract"], 155 | "d_pred_gold": [d_acc[-1], list(pred_d_set), 
def main():
    """Evaluate a prediction file against gold triples and dump per-sample results.

    Reads the module-level ``pred_file`` (one JSON object per line),
    ``gold_file`` (a single JSON document) and ``pmids_file`` (a JSON list when
    the path contains ".json", otherwise one id per line), runs ``do_eval`` and
    writes its return value next to the prediction file as
    ``<pred_file without .json>.eval_res.json``.
    """
    # postprocess.py writes the prediction file as UTF-8, so read it the same
    # way instead of depending on the locale's default encoding.
    with open(pred_file, encoding="utf8") as reader:
        preds = [json.loads(line) for line in reader]

    with open(gold_file, encoding="utf8") as reader:
        golden = json.load(reader)

    with open(pmids_file, encoding="utf8") as reader:
        if '.json' in pmids_file:
            pmids = json.load(reader)
        else:
            pmids = [line.strip() for line in reader]

    print("\n====File: ", os.path.basename(pred_file))
    result = do_eval(preds, pmids, golden)

    # rfind returns -1 when ".json" is absent; slicing with -1 would silently
    # drop the last character of the name, so keep the full name in that case.
    last_pos = pred_file.rfind('.json')
    base_name = pred_file if last_pos == -1 else pred_file[:last_pos]
    res_file_name = base_name + '.eval_res.json'
    with open(res_file_name, 'w', encoding="utf8") as writer:
        json.dump(result, writer, indent=2)
def convert_relis_sentence(sentence):
    """Parse one "the interaction between X and Y is Z" sentence.

    Returns the stripped tuple ``(X, Z, Y)`` when the sentence starts with
    that pattern, otherwise ``None``.
    """
    matched = re.match(r"the interaction between (.*) and (.*) is (.*)", sentence)
    if matched is None:
        return None
    first, second, third = (piece.strip() for piece in matched.groups())
    # Output order is (first entity, relation, second entity).
    return (first, third, second)
# Load the generated lines, dropping surrounding whitespace and one trailing
# period per line.
with open(out_file, "r", encoding="utf8") as fr:
    all_lines = [
        text[:-1] if text.endswith(".") else text
        for text in (raw.strip() for raw in fr)
    ]


hypothesis = []
fail_cnt = 0


# Turn every line into a list of triples; a line from which nothing could be
# extracted is recorded as a single ("failed", "failed", "failed") triple.
for line_no, line in enumerate(all_lines, start=1):
    parsed = (convert_relis_sentence(sen) for sen in split_sentence(strip_prefix(line)))
    ret = [ans for ans in parsed if ans is not None]
    if ret:
        hypothesis.append(ret)
        continue
    hypothesis.append([("failed", "failed", "failed")])
    fail_cnt += 1
    print("Failed:id:{}, line:{}".format(line_no, line))

cnt = len(all_lines)


ret_formatted = [converter(triples) for triples in hypothesis]


# One JSON object per input line, in input order.
with open(f"{out_file}.extracted.json", "w", encoding="utf8") as fw:
    for eg in ret_formatted:
        print(json.dumps(eg), file=fw)


print(f"failed = {fail_cnt}, total = {cnt}")
# Rebuild the KD-DTI text pairs, tokenize them with the Moses tokenizer,
# apply BPE with fastBPE and binarize the result with fairseq-preprocess.
#
# MOSES must be a directory containing scripts/tokenizer/tokenizer.perl and
# FASTBPE a directory containing the `fast` binary. Abort with a clear message
# when either is unset instead of running e.g. "perl /scripts/tokenizer/...".
: "${MOSES:?MOSES must be set to the directory containing scripts/tokenizer/tokenizer.perl}"
: "${FASTBPE:?FASTBPE must be set to the directory containing the fast binary}"

DATA_DIR=../../data/KD-DTI
prefix=relis
RAW_DATA_DIR=${DATA_DIR}/raw
OUTPUT_DIR=${DATA_DIR}/${prefix}-bin

# Start from a clean binarized directory.
if [ -d "${OUTPUT_DIR}" ]; then
    rm -rf "${OUTPUT_DIR}"
fi

python rebuild_data.py "${RAW_DATA_DIR}"

cp "${DATA_DIR}/../dict.txt" "${RAW_DATA_DIR}/"
cp "${DATA_DIR}/../bpecodes" "${RAW_DATA_DIR}/"

SPLIT=(train valid test)

for ff in "${SPLIT[@]}"; do
    if [ -f "${RAW_DATA_DIR}/${prefix}_${ff}.y" ]; then
        echo "Preprocessing ${ff}"

        perl "${MOSES}/scripts/tokenizer/tokenizer.perl" -l en -a -threads 8 < "${RAW_DATA_DIR}/${prefix}_${ff}.x" > "${RAW_DATA_DIR}/${prefix}_${ff}.tok.x"
        perl "${MOSES}/scripts/tokenizer/tokenizer.perl" -l en -a -threads 8 < "${RAW_DATA_DIR}/${prefix}_${ff}.y" > "${RAW_DATA_DIR}/${prefix}_${ff}.tok.y"

        "${FASTBPE}/fast" applybpe "${RAW_DATA_DIR}/${prefix}_${ff}.tok.bpe.x" "${RAW_DATA_DIR}/${prefix}_${ff}.tok.x" "${RAW_DATA_DIR}/bpecodes"
        "${FASTBPE}/fast" applybpe "${RAW_DATA_DIR}/${prefix}_${ff}.tok.bpe.y" "${RAW_DATA_DIR}/${prefix}_${ff}.tok.y" "${RAW_DATA_DIR}/bpecodes"

        # The un-BPE'd tokenized files are intermediate only.
        rm "${RAW_DATA_DIR}/${prefix}_${ff}.tok.x" "${RAW_DATA_DIR}/${prefix}_${ff}.tok.y"
    fi
done

# do binarize
fairseq-preprocess \
    -s x -t y --workers 8 \
    --joined-dictionary \
    --trainpref "${RAW_DATA_DIR}/${prefix}_train.tok.bpe" \
    --validpref "${RAW_DATA_DIR}/${prefix}_valid.tok.bpe" \
    --testpref "${RAW_DATA_DIR}/${prefix}_test.tok.bpe" \
    --destdir "${OUTPUT_DIR}" \
    --srcdict "${RAW_DATA_DIR}/dict.txt"
def sort_triples(triples, text):
    """Return ``triples`` ordered by where each triple's drug first occurs in ``text``.

    The lookup is case-insensitive: ``loader`` passes lower-cased text while
    the drug names keep their original casing, so a case-sensitive ``find``
    returned -1 for every capitalised drug and the ordering degenerated to the
    input order. Drugs that do not occur in ``text`` at all still get -1 and
    therefore sort first; ties keep their input order (``sorted`` is stable).

    Args:
        triples: iterable of dicts that each have a ``'drug'`` string.
        text: the document text to search.

    Returns:
        A new list; the input is not modified.
    """
    lowered_text = text.lower()
    return sorted(triples, key=lambda x: lowered_text.find(x['drug'].lower()))
def loader(fname, fn):
    """Load one split and build ``(pmid, content, answer)`` tuples.

    Args:
        fname: path of a JSON file mapping each pmid to a dict with the keys
            ``"title"``, ``"abstract"`` and ``"triples"``.
        fn: callable that turns the sorted triple list into the target string.

    Returns:
        A list of ``(pmid, lower-cased "title abstract" text, fn(triples))``
        for every entry that has at least one triple. Entries whose triples
        are ``None`` or empty are skipped and their pmids are printed.
    """
    ret = []
    null_cnt = 0
    null_flag = False
    with open(fname, "r", encoding="utf8") as fr:
        data = json.load(fr)
    for pmid, v in data.items():
        # A title that already ends in a non-word character (punctuation) is
        # joined with a space; otherwise a period is inserted as separator.
        if re.search(r"\W$", v["title"]):
            content = v["title"] + " " + v["abstract"]
        else:
            content = v["title"] + ". " + v["abstract"]

        content = content.lower()
        if v["triples"] is None or len(v["triples"]) == 0:
            # Print the header once, then the skipped pmids on one line.
            if not null_flag:
                print(f"Following PMID in {fname} has no extracted triples:")
                null_flag = True
            print(f"{pmid} ", end="")
            null_cnt += 1
        else:
            triples = v['triples']
            triples = sort_triples(triples, content)
            answer = fn(triples)
            ret.append((pmid, content, answer))
    if null_flag:
        # Terminate the line of pmids printed above.
        print("")
    print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no triples extracted.")
    return ret
# Write the .pmid/.x/.y files in the "relis" target format for every split.
for split in ('train', 'valid', 'test'):
    split_json = os.path.join(f"{data_dir}", f"{split}.json")
    relis_prefix = os.path.join(f"{data_dir}", f"relis_{split}")
    worker(split_json, relis_prefix, build_target_seq_relis)
    # Other target formats, currently disabled:
    #worker(os.path.join(f"{data_dir}", f"{split}.json"), os.path.join(f"{data_dir}", f"isof_{split}"), build_target_seq_isof)
    #worker(os.path.join(f"{data_dir}", f"{split}.json"), os.path.join(f"{data_dir}", f"svo_{split}"), build_target_seq_svo)
    #worker(os.path.join(f"{data_dir}", f"{split}.json"), os.path.join(f"{data_dir}", f"htr_{split}"), build_target_seq_htr)
BioGPT model for free text generation, just as how you use GPT models. 3 | ## Model Checkpoint 4 | We provide our pre-trained BioGPT model. See [here](../../README.md#pre-trained-models) 5 | 6 | ## Generation 7 | We here provide an interactive way for generation: 8 | ``` bash 9 | python interactive.py 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/text-generation/interactive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import argparse 5 | from fairseq.models.transformer_lm import TransformerLanguageModel 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--data_dir", type=str, default="../../data/PubMed/data-bin") 9 | parser.add_argument("--model_dir", type=str, default="../../checkpoints/Pre-trained-BioGPT") 10 | parser.add_argument("--model_file", type=str, default="checkpoint.pt") 11 | parser.add_argument("--bpecodes", type=str, default="../../data/bpecodes") 12 | parser.add_argument("--beam", type=int, default=5) 13 | parser.add_argument("--lenpen", type=float, default=1.0) 14 | parser.add_argument("--min_len", type=int, default=100) 15 | parser.add_argument("--lower", default=False, action="store_true") 16 | args, _ = parser.parse_known_args() 17 | 18 | 19 | def main(args): 20 | 21 | m = TransformerLanguageModel.from_pretrained( 22 | args.model_dir, 23 | args.model_file, 24 | args.data_dir, 25 | tokenizer='moses', 26 | bpe='fastbpe', 27 | bpe_codes=args.bpecodes, 28 | min_len=args.min_len, 29 | max_len_b=1024, 30 | beam=args.beam, 31 | lenpen=args.lenpen, 32 | max_tokens=12000) 33 | 34 | print(m.cfg) 35 | if m.cfg.common.fp16: 36 | print('Converting to float 16') 37 | m.half() 38 | m.cuda() 39 | 40 | while True: 41 | print("Please input and press enter:") 42 | _src = input().strip() 43 | src_tokens = m.encode(_src) 44 | generate = m.generate([src_tokens], 
def main(args):
    """Generate one output line per line of ``args.src_file``.

    Loads the prompt language model from ``args.model_dir`` /
    ``args.model_file`` with data from ``args.data_dir``, moves it to the GPU
    (in half precision when the checkpoint config has ``common.fp16`` set),
    samples with beam size ``args.beam`` and writes the outputs, one per line,
    to ``args.output_file``.
    """
    # Read with the same encoding the outputs are written in, instead of the
    # locale default, so non-ASCII input decodes identically on every machine.
    with open(args.src_file, encoding='utf8') as reader:
        src_inputs = [line.strip() for line in reader]

    m = TransformerLanguageModelPrompt.from_pretrained(
        args.model_dir,
        args.model_file,
        args.data_dir,
        max_len_b=args.decoding_length,
        max_tokens=12000,)

    print(m.cfg)

    if m.cfg.common.fp16:
        print('Converting to float 16')
        m.half()
    m.cuda()

    outputs = m.sample(src_inputs, beam=args.beam)

    with open(f"{args.output_file}", "w", encoding='utf8') as fw:
        fw.writelines(output + '\n' for output in outputs)
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors. All other entries are taken
      from the first checkpoint.

    Raises:
      ValueError: if ``inputs`` is empty.
      KeyError: if a checkpoint's parameter names differ from the first one's.
    """
    # Materialise once so that any iterable (not only sized sequences) works.
    inputs = list(inputs)
    if not inputs:
        raise ValueError("average_checkpoints needs at least one checkpoint path")

    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for fpath in inputs:
        with PathManager.open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(
                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
                ),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state["model"]

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            # Report the path, not the (already closed) file object.
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(fpath, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                # Accumulate half-precision tensors in float32.
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            # Integer tensors cannot be divided in place to a float result.
            averaged_params[k] //= num_models
    new_state["model"] = averaged_params
    return new_state
def main():
    """Command-line entry point: average checkpoints and write the result.

    With ``--num-epoch-checkpoints`` or ``--num-update-checkpoints``,
    ``--inputs`` is handed to ``last_n_checkpoints`` (which requires exactly
    one path) and replaced by the checkpoint files it returns; otherwise the
    given paths are averaged directly. The averaged state is saved to
    ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Tool to average the params of input checkpoints to "
        "produce a new checkpoint",
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    # The file-name patterns below mirror the regular expressions used by
    # last_n_checkpoints: checkpoint(\d+)\.pt and checkpoint_\d+_(\d+)\.pt.
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpointxx.pt in the path specified by input, '
                           'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
                        'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
                        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged. '
                        'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged assuming --save-interval-updates 500'
                        )
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (
        args.num_epoch_checkpoints is not None
        or args.num_update_checkpoints is not None
    ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
    assert (
        args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
    ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"

    if num is not None:
        # Resolve the single input path to its last `num` checkpoint files.
        args.inputs = last_n_checkpoints(
            args.inputs,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print("averaging checkpoints: ", args.inputs)

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, "wb") as f:
        torch.save(new_state, f)
    print("Finished writing averaged checkpoint to {}".format(args.output))
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting;
                an existing EnsembleModel is used as is, anything else is
                wrapped in a new EnsembleModel
            tgt_dict: target dictionary; supplies the pad/unk/eos indices and
                the vocabulary size
            beam_size (int, optional): beam width (default: 1); capped at
                the vocabulary size minus one
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence); a falsy value falls back to
                the model's max_decoder_positions()
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0); must be greater than 0
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): when > 0, an
                NGramRepeatBlock of that size is created (default: 0, none)
            search_strategy (optional): search object to use; defaults to
                search.BeamSearch(tgt_dict)
            eos (int, optional): end-of-sentence index; defaults to
                tgt_dict.eos()
            symbols_to_strip_from_output (set, optional): extra symbols stored
                together with eos in self.symbols_to_strip_from_output
            lm_model (optional): additional language model; stored and put
                in eval mode when given
            lm_weight (float, optional): stored as self.lm_weight
                (default: 1.0)
        """
        super().__init__()
        # Accept either a ready-made ensemble or a plain list of models.
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.tgt_dict = tgt_dict
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos() if eos is None else eos
        # eos is always part of the strip set, whatever the caller passed.
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )
        self.vocab_size = len(tgt_dict)
        # NOTE: this assignment is overwritten by the capped value just below.
        self.beam_size = beam_size
        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len
        # max_len=0 (the default) means "use the model's decoder limit".
        self.max_len = max_len or self.model.max_decoder_positions()

        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.temperature = temperature
        self.match_source_len = match_source_len

        if no_repeat_ngram_size > 0:
            self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
        else:
            self.repeat_ngram_blocker = None

        assert temperature > 0, "--temperature must be greater than 0"

        self.search = (
            search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
        )
        # We only need to set src_lengths in LengthConstrainedBeamSearch.
        # As a module attribute, setting it would break in multithread
        # settings when the model is shared.
        self.should_set_src_lengths = (
            hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
        )

        # Generation always runs the model(s) in eval mode.
        self.model.eval()

        self.lm_model = lm_model
        self.lm_weight = lm_weight
        if self.lm_model is not None:
            self.lm_model.eval()
141 | Args: 142 | cuda (bool, optional): use GPU for generation 143 | timer (StopwatchMeter, optional): time generations 144 | """ 145 | for sample in data_itr: 146 | s = utils.move_to_cuda(sample) if cuda else sample 147 | if "net_input" not in s: 148 | continue 149 | input = s["net_input"] 150 | # model.forward normally channels prev_output_tokens into the decoder 151 | # separately, but SequenceGenerator directly calls model.encoder 152 | encoder_input = { 153 | k: v for k, v in input.items() if k != "prev_output_tokens" 154 | } 155 | if timer is not None: 156 | timer.start() 157 | with torch.no_grad(): 158 | hypos = self.generate(encoder_input) 159 | if timer is not None: 160 | timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) 161 | for i, id in enumerate(s["id"].data): 162 | # remove padding 163 | src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) 164 | ref = ( 165 | utils.strip_pad(s["target"].data[i, :], self.pad) 166 | if s["target"] is not None 167 | else None 168 | ) 169 | yield id, src, ref, hypos[i] 170 | 171 | @torch.no_grad() 172 | def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]: 173 | """Generate translations. Match the api of other fairseq generators. 
174 | 175 | Args: 176 | models (List[~fairseq.models.FairseqModel]): ensemble of models 177 | sample (dict): batch 178 | prefix_tokens (torch.LongTensor, optional): force decoder to begin 179 | with these tokens 180 | constraints (torch.LongTensor, optional): force decoder to include 181 | the list of constraints 182 | bos_token (int, optional): beginning of sentence token 183 | (default: self.eos) 184 | """ 185 | return self._generate(sample, **kwargs) 186 | 187 | def _generate( 188 | self, 189 | sample: Dict[str, Dict[str, Tensor]], 190 | prefix_tokens: Optional[Tensor] = None, 191 | constraints: Optional[Tensor] = None, 192 | bos_token: Optional[int] = None, 193 | allowed_text: Optional[List[Tensor]] = None, 194 | ): 195 | incremental_states = torch.jit.annotate( 196 | List[Dict[str, Dict[str, Optional[Tensor]]]], 197 | [ 198 | torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) 199 | for i in range(self.model.models_size) 200 | ], 201 | ) 202 | net_input = sample["net_input"] 203 | 204 | if "src_tokens" in net_input: 205 | src_tokens = net_input["src_tokens"] 206 | # length of the source text being the character length except EndOfSentence and pad 207 | src_lengths = ( 208 | (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) 209 | ) 210 | elif "source" in net_input: 211 | src_tokens = net_input["source"] 212 | src_lengths = ( 213 | net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) 214 | if net_input["padding_mask"] is not None 215 | else torch.tensor(src_tokens.size(-1)).to(src_tokens) 216 | ) 217 | else: 218 | raise Exception("expected src_tokens or source in net input") 219 | 220 | # bsz: total number of sentences in beam 221 | # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) 222 | bsz, src_len = src_tokens.size()[:2] 223 | beam_size = self.beam_size 224 | 225 | if constraints is not None and not self.search.supports_constraints: 226 | raise NotImplementedError( 227 | "Target-side constraints were provided, but search method doesn't support them" 228 | ) 229 | 230 | # Initialize constraints, when active 231 | self.search.init_constraints(constraints, beam_size) 232 | 233 | # Allowed text 234 | if allowed_text is not None: 235 | if allowed_text.dim() == 1: 236 | allowed_text = allowed_text.unsqueeze(dim=0).repeat_interleave(bsz, dim=0) 237 | allowed_text = torch.cat([prefix_tokens, allowed_text], dim=1) 238 | 239 | max_len: int = -1 240 | if self.match_source_len: 241 | max_len = src_lengths.max().item() 242 | else: 243 | max_len = min( 244 | int(self.max_len_a * src_len + self.max_len_b), 245 | self.max_len - 1, 246 | ) 247 | assert ( 248 | self.min_len <= max_len 249 | ), "min_len cannot be larger than max_len, please adjust these!" 250 | # compute the encoder output for each beam 251 | encoder_outs = self.model.forward_encoder(net_input) 252 | 253 | # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores 254 | new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) 255 | new_order = new_order.to(src_tokens.device).long() 256 | encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) 257 | # ensure encoder_outs is a List. 258 | assert encoder_outs is not None 259 | 260 | # initialize buffers 261 | scores = ( 262 | torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() 263 | ) # +1 for eos; pad is never chosen for scoring 264 | tokens = ( 265 | torch.zeros(bsz * beam_size, max_len + 2) 266 | .to(src_tokens) 267 | .long() 268 | .fill_(self.pad) 269 | ) # +2 for eos and pad 270 | tokens[:, 0] = self.eos if bos_token is None else bos_token 271 | attn: Optional[Tensor] = None 272 | 273 | # A list that indicates candidates that should be ignored. 
274 | # For example, suppose we're sampling and have already finalized 2/5 275 | # samples. Then cands_to_ignore would mark 2 positions as being ignored, 276 | # so that we only finalize the remaining 3 samples. 277 | cands_to_ignore = ( 278 | torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) 279 | ) # forward and backward-compatible False mask 280 | 281 | # list of completed sentences 282 | finalized = torch.jit.annotate( 283 | List[List[Dict[str, Tensor]]], 284 | [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], 285 | ) # contains lists of dictionaries of information about the hypothesis being finalized at each step 286 | 287 | # a boolean array indicating if the sentence at the index is finished or not 288 | finished = [False for i in range(bsz)] 289 | num_remaining_sent = bsz # number of sentences remaining 290 | 291 | # number of candidate hypos per step 292 | cand_size = 2 * beam_size # 2 x beam size in case half are EOS 293 | 294 | # offset arrays for converting between different indexing schemes 295 | bbsz_offsets = ( 296 | (torch.arange(0, bsz) * beam_size) 297 | .unsqueeze(1) 298 | .type_as(tokens) 299 | .to(src_tokens.device) 300 | ) 301 | cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) 302 | 303 | reorder_state: Optional[Tensor] = None 304 | batch_idxs: Optional[Tensor] = None 305 | 306 | original_batch_idxs: Optional[Tensor] = None 307 | if "id" in sample and isinstance(sample["id"], Tensor): 308 | original_batch_idxs = sample["id"] 309 | else: 310 | original_batch_idxs = torch.arange(0, bsz).type_as(tokens) 311 | 312 | for step in range(max_len + 1): # one extra step for EOS marker 313 | # reorder decoder internal states based on the prev choice of beams 314 | if reorder_state is not None: 315 | if batch_idxs is not None: 316 | # update beam indices to take into account removed sentences 317 | corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( 318 | batch_idxs 319 | ) 320 | 
reorder_state.view(-1, beam_size).add_( 321 | corr.unsqueeze(-1) * beam_size 322 | ) 323 | original_batch_idxs = original_batch_idxs[batch_idxs] 324 | self.model.reorder_incremental_state(incremental_states, reorder_state) 325 | encoder_outs = self.model.reorder_encoder_out( 326 | encoder_outs, reorder_state 327 | ) 328 | 329 | lprobs, avg_attn_scores = self.model.forward_decoder( 330 | tokens[:, : step + 1], 331 | encoder_outs, 332 | incremental_states, 333 | self.temperature, 334 | ) 335 | 336 | if self.lm_model is not None: 337 | lm_out = self.lm_model(tokens[:, : step + 1]) 338 | probs = self.lm_model.get_normalized_probs( 339 | lm_out, log_probs=True, sample=None 340 | ) 341 | probs = probs[:, -1, :] * self.lm_weight 342 | lprobs += probs 343 | 344 | lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) 345 | 346 | # hack the lprobs here to constrain the output to be in the allowed text 347 | if allowed_text is not None: 348 | if batch_idxs is not None: 349 | allowed_text = allowed_text[batch_idxs] 350 | bi = torch.arange(lprobs.size(0)).unsqueeze(dim=1).repeat_interleave(allowed_text.size(-1), dim=-1) 351 | mask = torch.ones(lprobs.size()).to(lprobs) 352 | mask[bi.view(-1), allowed_text.view(-1)] = 0 353 | mask[mask==1] = -math.inf 354 | lprobs += mask 355 | 356 | lprobs[:, self.pad] = -math.inf # never select pad 357 | lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty 358 | 359 | # handle max length constraint 360 | if step >= max_len: 361 | lprobs[:, : self.eos] = -math.inf 362 | lprobs[:, self.eos + 1 :] = -math.inf 363 | 364 | # handle prefix tokens (possibly with different lengths) 365 | if ( 366 | prefix_tokens is not None 367 | and step < prefix_tokens.size(1) 368 | and step < max_len 369 | ): 370 | lprobs, tokens, scores = self._prefix_tokens( 371 | step, lprobs, scores, tokens, prefix_tokens, beam_size 372 | ) 373 | elif step < self.min_len: 374 | # minimum length constraint (does not apply if using prefix_tokens) 375 | lprobs[:, 
self.eos] = -math.inf 376 | 377 | # Record attention scores, only support avg_attn_scores is a Tensor 378 | if avg_attn_scores is not None: 379 | if attn is None: 380 | attn = torch.empty( 381 | bsz * beam_size, avg_attn_scores.size(1), max_len + 2 382 | ).to(scores) 383 | attn[:, :, step + 1].copy_(avg_attn_scores) 384 | 385 | scores = scores.type_as(lprobs) 386 | eos_bbsz_idx = torch.empty(0).to( 387 | tokens 388 | ) # indices of hypothesis ending with eos (finished sentences) 389 | eos_scores = torch.empty(0).to( 390 | scores 391 | ) # scores of hypothesis ending with eos (finished sentences) 392 | 393 | if self.should_set_src_lengths: 394 | self.search.set_src_lengths(src_lengths) 395 | 396 | if self.repeat_ngram_blocker is not None: 397 | lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) 398 | 399 | # Shape: (batch, cand_size) 400 | cand_scores, cand_indices, cand_beams = self.search.step( 401 | step, 402 | lprobs.view(bsz, -1, self.vocab_size), 403 | scores.view(bsz, beam_size, -1)[:, :, :step], 404 | tokens[:, : step + 1], 405 | original_batch_idxs, 406 | ) 407 | 408 | # cand_bbsz_idx contains beam indices for the top candidate 409 | # hypotheses, with a range of values: [0, bsz*beam_size), 410 | # and dimensions: [bsz, cand_size] 411 | cand_bbsz_idx = cand_beams.add(bbsz_offsets) 412 | 413 | # finalize hypotheses that end in eos 414 | # Shape of eos_mask: (batch size, beam size) 415 | eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) 416 | eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) 417 | 418 | # only consider eos when it's among the top beam_size indices 419 | # Now we know what beam item(s) to finish 420 | # Shape: 1d list of absolute-numbered 421 | eos_bbsz_idx = torch.masked_select( 422 | cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] 423 | ) 424 | 425 | finalized_sents: List[int] = [] 426 | if eos_bbsz_idx.numel() > 0: 427 | eos_scores = torch.masked_select( 428 | 
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] 429 | ) 430 | 431 | finalized_sents = self.finalize_hypos( 432 | step, 433 | eos_bbsz_idx, 434 | eos_scores, 435 | tokens, 436 | scores, 437 | finalized, 438 | finished, 439 | beam_size, 440 | attn, 441 | src_lengths, 442 | max_len, 443 | ) 444 | num_remaining_sent -= len(finalized_sents) 445 | 446 | assert num_remaining_sent >= 0 447 | if num_remaining_sent == 0: 448 | break 449 | if self.search.stop_on_max_len and step >= max_len: 450 | break 451 | assert step < max_len, f"{step} < {max_len}" 452 | 453 | # Remove finalized sentences (ones for which {beam_size} 454 | # finished hypotheses have been generated) from the batch. 455 | if len(finalized_sents) > 0: 456 | new_bsz = bsz - len(finalized_sents) 457 | 458 | # construct batch_idxs which holds indices of batches to keep for the next pass 459 | batch_mask = torch.ones( 460 | bsz, dtype=torch.bool, device=cand_indices.device 461 | ) 462 | batch_mask[finalized_sents] = False 463 | # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it 464 | batch_idxs = torch.arange( 465 | bsz, device=cand_indices.device 466 | ).masked_select(batch_mask) 467 | 468 | # Choose the subset of the hypothesized constraints that will continue 469 | self.search.prune_sentences(batch_idxs) 470 | 471 | eos_mask = eos_mask[batch_idxs] 472 | cand_beams = cand_beams[batch_idxs] 473 | bbsz_offsets.resize_(new_bsz, 1) 474 | cand_bbsz_idx = cand_beams.add(bbsz_offsets) 475 | cand_scores = cand_scores[batch_idxs] 476 | cand_indices = cand_indices[batch_idxs] 477 | 478 | if prefix_tokens is not None: 479 | prefix_tokens = prefix_tokens[batch_idxs] 480 | src_lengths = src_lengths[batch_idxs] 481 | cands_to_ignore = cands_to_ignore[batch_idxs] 482 | 483 | scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) 484 | tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) 485 | if attn is not None: 486 | attn = attn.view(bsz, -1)[batch_idxs].view( 
487 | new_bsz * beam_size, attn.size(1), -1 488 | ) 489 | bsz = new_bsz 490 | else: 491 | batch_idxs = None 492 | 493 | # Set active_mask so that values > cand_size indicate eos hypos 494 | # and values < cand_size indicate candidate active hypos. 495 | # After, the min values per row are the top candidate active hypos 496 | 497 | # Rewrite the operator since the element wise or is not supported in torchscript. 498 | 499 | eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) 500 | active_mask = torch.add( 501 | eos_mask.type_as(cand_offsets) * cand_size, 502 | cand_offsets[: eos_mask.size(1)], 503 | ) 504 | 505 | # get the top beam_size active hypotheses, which are just 506 | # the hypos with the smallest values in active_mask. 507 | # {active_hypos} indicates which {beam_size} hypotheses 508 | # from the list of {2 * beam_size} candidates were 509 | # selected. Shapes: (batch size, beam size) 510 | new_cands_to_ignore, active_hypos = torch.topk( 511 | active_mask, k=beam_size, dim=1, largest=False 512 | ) 513 | 514 | # update cands_to_ignore to ignore any finalized hypos. 515 | cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] 516 | # Make sure there is at least one active item for each sentence in the batch. 517 | assert (~cands_to_ignore).any(dim=1).all() 518 | 519 | # update cands_to_ignore to ignore any finalized hypos 520 | 521 | # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam 522 | # can be selected more than once). 
523 | active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) 524 | active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) 525 | 526 | active_bbsz_idx = active_bbsz_idx.view(-1) 527 | active_scores = active_scores.view(-1) 528 | 529 | # copy tokens and scores for active hypotheses 530 | 531 | # Set the tokens for each beam (can select the same row more than once) 532 | tokens[:, : step + 1] = torch.index_select( 533 | tokens[:, : step + 1], dim=0, index=active_bbsz_idx 534 | ) 535 | # Select the next token for each of them 536 | tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( 537 | cand_indices, dim=1, index=active_hypos 538 | ) 539 | if step > 0: 540 | scores[:, :step] = torch.index_select( 541 | scores[:, :step], dim=0, index=active_bbsz_idx 542 | ) 543 | scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( 544 | cand_scores, dim=1, index=active_hypos 545 | ) 546 | 547 | # Update constraints based on which candidates were selected for the next beam 548 | self.search.update_constraints(active_hypos) 549 | 550 | # copy attention for active hypotheses 551 | if attn is not None: 552 | attn[:, :, : step + 2] = torch.index_select( 553 | attn[:, :, : step + 2], dim=0, index=active_bbsz_idx 554 | ) 555 | 556 | # reorder incremental state in decoder 557 | reorder_state = active_bbsz_idx 558 | 559 | # sort by score descending 560 | for sent in range(len(finalized)): 561 | scores = torch.tensor( 562 | [float(elem["score"].item()) for elem in finalized[sent]] 563 | ) 564 | _, sorted_scores_indices = torch.sort(scores, descending=True) 565 | finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] 566 | finalized[sent] = torch.jit.annotate( 567 | List[Dict[str, Tensor]], finalized[sent] 568 | ) 569 | return finalized 570 | 571 | def _prefix_tokens( 572 | self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int 573 | ): 574 | """Handle prefix tokens""" 575 | prefix_toks = prefix_tokens[:, 
step].unsqueeze(-1).repeat(1, beam_size).view(-1) 576 | prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) 577 | prefix_mask = prefix_toks.ne(self.pad) 578 | lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) 579 | lprobs[prefix_mask] = lprobs[prefix_mask].scatter( 580 | -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] 581 | ) 582 | # if prefix includes eos, then we should make sure tokens and 583 | # scores are the same across all beams 584 | eos_mask = prefix_toks.eq(self.eos) 585 | if eos_mask.any(): 586 | # validate that the first beam matches the prefix 587 | first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ 588 | :, 0, 1 : step + 1 589 | ] 590 | eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] 591 | target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] 592 | assert (first_beam == target_prefix).all() 593 | 594 | # copy tokens, scores and lprobs from the first beam to all beams 595 | tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) 596 | scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) 597 | lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) 598 | return lprobs, tokens, scores 599 | 600 | def replicate_first_beam(self, tensor, mask, beam_size: int): 601 | tensor = tensor.view(-1, beam_size, tensor.size(-1)) 602 | tensor[mask] = tensor[mask][:, :1, :] 603 | return tensor.view(-1, tensor.size(-1)) 604 | 605 | def finalize_hypos( 606 | self, 607 | step: int, 608 | bbsz_idx, 609 | eos_scores, 610 | tokens, 611 | scores, 612 | finalized: List[List[Dict[str, Tensor]]], 613 | finished: List[bool], 614 | beam_size: int, 615 | attn: Optional[Tensor], 616 | src_lengths, 617 | max_len: int, 618 | ): 619 | """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. 620 | A sentence is finalized when {beam_size} finished items have been collected for it. 
621 | 622 | Returns number of sentences (not beam items) being finalized. 623 | These will be removed from the batch and not processed further. 624 | Args: 625 | bbsz_idx (Tensor): 626 | """ 627 | assert bbsz_idx.numel() == eos_scores.numel() 628 | 629 | # clone relevant token and attention tensors. 630 | # tokens is (batch * beam, max_len). So the index_select 631 | # gets the newly EOS rows, then selects cols 1..{step + 2} 632 | tokens_clone = tokens.index_select(0, bbsz_idx)[ 633 | :, 1 : step + 2 634 | ] # skip the first index, which is EOS 635 | 636 | tokens_clone[:, step] = self.eos 637 | attn_clone = ( 638 | attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] 639 | if attn is not None 640 | else None 641 | ) 642 | 643 | # compute scores per token position 644 | pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] 645 | pos_scores[:, step] = eos_scores 646 | # convert from cumulative to per-position scores 647 | pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] 648 | 649 | # normalize sentence-level scores 650 | if self.normalize_scores: 651 | eos_scores /= (step + 1) ** self.len_penalty 652 | 653 | # cum_unfin records which sentences in the batch are finished. 654 | # It helps match indexing between (a) the original sentences 655 | # in the batch and (b) the current, possibly-reduced set of 656 | # sentences. 
657 | cum_unfin: List[int] = [] 658 | prev = 0 659 | for f in finished: 660 | if f: 661 | prev += 1 662 | else: 663 | cum_unfin.append(prev) 664 | 665 | # The keys here are of the form "{sent}_{unfin_idx}", where 666 | # "unfin_idx" is the index in the current (possibly reduced) 667 | # list of sentences, and "sent" is the index in the original, 668 | # unreduced batch 669 | # set() is not supported in script export 670 | sents_seen: Dict[str, Optional[Tensor]] = {} 671 | 672 | # For every finished beam item 673 | for i in range(bbsz_idx.size()[0]): 674 | idx = bbsz_idx[i] 675 | score = eos_scores[i] 676 | # sentence index in the current (possibly reduced) batch 677 | unfin_idx = idx // beam_size 678 | # sentence index in the original (unreduced) batch 679 | sent = unfin_idx + cum_unfin[unfin_idx] 680 | # Cannot create dict for key type '(int, int)' in torchscript. 681 | # The workaround is to cast int to string 682 | seen = str(sent.item()) + "_" + str(unfin_idx.item()) 683 | if seen not in sents_seen: 684 | sents_seen[seen] = None 685 | 686 | if self.match_source_len and step > src_lengths[unfin_idx]: 687 | score = torch.tensor(-math.inf).to(score) 688 | 689 | # An input sentence (among those in a batch) is finished when 690 | # beam_size hypotheses have been collected for it 691 | if len(finalized[sent]) < beam_size: 692 | if attn_clone is not None: 693 | # remove padding tokens from attn scores 694 | hypo_attn = attn_clone[i] 695 | else: 696 | hypo_attn = torch.empty(0) 697 | 698 | finalized[sent].append( 699 | { 700 | "tokens": tokens_clone[i], 701 | "score": score, 702 | "attention": hypo_attn, # src_len x tgt_len 703 | "alignment": torch.empty(0), 704 | "positional_scores": pos_scores[i], 705 | } 706 | ) 707 | 708 | newly_finished: List[int] = [] 709 | 710 | for seen in sents_seen.keys(): 711 | # check termination conditions for this sentence 712 | sent: int = int(float(seen.split("_")[0])) 713 | unfin_idx: int = int(float(seen.split("_")[1])) 714 | 715 | if 
not finished[sent] and self.is_finished( 716 | step, unfin_idx, max_len, len(finalized[sent]), beam_size 717 | ): 718 | finished[sent] = True 719 | newly_finished.append(unfin_idx) 720 | 721 | return newly_finished 722 | 723 | def is_finished( 724 | self, 725 | step: int, 726 | unfin_idx: int, 727 | max_len: int, 728 | finalized_sent_len: int, 729 | beam_size: int, 730 | ): 731 | """ 732 | Check whether decoding for a sentence is finished, which 733 | occurs when the list of finalized sentences has reached the 734 | beam size, or when we reach the maximum length. 735 | """ 736 | assert finalized_sent_len <= beam_size 737 | if finalized_sent_len == beam_size or step == max_len: 738 | return True 739 | return False 740 | -------------------------------------------------------------------------------- /src/language_model_prompt_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
import logging

import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils


logger = logging.getLogger(__name__)


def collate(samples, pad_idx, eos_idx, prefix=False, sep_idx=None, prompt=None):
    """Collate prompt-augmented language-model samples into one mini-batch.

    Every sample is flattened into a single token sequence:

    - no prompt:            ``source + target``
    - prompt as suffix:     ``source + prompt + target``
    - prompt as prefix:     ``prompt + source + sep + target``

    A trailing ``eos_idx`` on the source is dropped before concatenation so
    the only end-of-sentence marker is the one carried by the target.

    Args:
        samples (List[dict]): items with keys ``"id"``, ``"source"`` and
            ``"target"`` (1-D token tensors), as produced by
            ``LanguageModelPromptDataset.__getitem__``.
        pad_idx (int): padding symbol index.
        eos_idx (int): end-of-sentence symbol index.
        prefix (bool, optional): put the prompt before the source and insert
            ``sep_idx`` between source and target.
        sep_idx (int, optional): separator symbol index; required when
            ``prefix`` is true and a prompt is given.
        prompt (torch.LongTensor, optional): 1-D tensor of prompt tokens.

    Returns:
        dict: ``{}`` for an empty ``samples``; otherwise a batch with keys
        ``id``, ``nsentences``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths``) and ``target``.

    Raises:
        ValueError: if ``prefix`` is set with a prompt but ``sep_idx`` is None.
    """
    if len(samples) == 0:
        return {}

    if prefix and prompt is not None and sep_idx is None:
        # Previously surfaced as an opaque TypeError from LongTensor([None]).
        raise ValueError(
            "sep_idx must be provided when prefix=True and a prompt is given"
        )

    def make_sentence(source, target):
        # Strip the source's own eos; guard against an empty source tensor.
        if source.numel() > 0 and source[-1] == eos_idx:
            source = source[:-1]
        if prompt is None:
            return torch.cat([source, target], dim=0)
        if prefix:
            sep = torch.LongTensor([sep_idx])
            return torch.cat([prompt, source, sep, target], dim=0)
        return torch.cat([source, prompt, target], dim=0)

    def merge(tokens, move_eos_to_beginning=False):
        return data_utils.collate_tokens(
            tokens,
            pad_idx,
            eos_idx,
            move_eos_to_beginning=move_eos_to_beginning,
        )

    ids = torch.LongTensor([s["id"] for s in samples])
    sentences = [make_sentence(s["source"], s["target"]) for s in samples]
    lengths = torch.LongTensor([int(t.ne(pad_idx).sum()) for t in sentences])

    target = merge(sentences)
    # The decoder-only model consumes the full sequence shifted right by one
    # position (eos moved to the front), so it is exposed as ``src_tokens``;
    # there is no separate encoder input or ``prev_output_tokens`` entry.
    prev_output_tokens = merge(sentences, move_eos_to_beginning=True)

    return {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": lengths.sum().item(),
        "net_input": {
            "src_tokens": prev_output_tokens,
            "src_lengths": lengths,
        },
        "target": target,
    }


class LanguageModelPromptDataset(FairseqDataset):
    """
    A source/target dataset pair served as prompt-augmented LM sequences.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        dictionary (~fairseq.data.Dictionary): vocabulary
        tgt (torch.utils.data.Dataset): target dataset to wrap
        tgt_sizes (List[int]): target sentence lengths
        prefix (bool, optional): place the prompt before the source and
            reserve one position for the separator token
        prompt (torch.LongTensor, optional): 1-D tensor of prompt tokens
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        eos (int, optional): end-of-sentence index written after truncation
            (default: ``dictionary.eos()``)
        max_source_length (int): max source text length, excluding eos
        max_length (int): max length of the whole concatenated sequence
        prompt_length (int, optional): length of the prompt; derived from
            ``prompt`` when omitted

    Raises:
        ValueError: if ``max_source_length`` or ``max_length`` is None.
    """

    def __init__(
        self,
        src,
        src_sizes,
        dictionary,
        tgt,
        tgt_sizes,
        prefix=False,
        prompt=None,
        shuffle=True,
        eos=None,
        max_source_length=None,
        max_length=None,
        prompt_length=None,
    ):
        # The length budget below cannot be computed without these two; fail
        # with a clear message instead of a TypeError from ``int - None``.
        if max_source_length is None or max_length is None:
            raise ValueError(
                "max_source_length and max_length must both be provided"
            )
        if prompt_length is None:
            prompt_length = len(prompt) if prompt is not None else 0

        self.src = src
        self.tgt = tgt
        self.prefix = prefix
        self.seq_sep = None
        self.prompt = prompt
        self.dict = dictionary
        self.shuffle = shuffle
        self.eos = eos if eos is not None else dictionary.eos()
        self.max_source_length = max_source_length
        # Whatever is left after the source and the prompt (and the separator
        # in prefix mode) is the budget for the target.
        self.max_target_length = max_length - max_source_length - prompt_length
        if self.prefix:
            self.max_target_length -= 1
        if self.max_target_length < 2:
            logger.warning(
                "max_target_length is %d; truncated targets cannot hold both "
                "the '...' marker and eos",
                self.max_target_length,
            )
        # Source sizes are counted without the trailing eos (hence ``s - 1``).
        self.src_sizes = [min(s - 1, self.max_source_length) for s in src_sizes]
        self.tgt_sizes = [min(t, self.max_target_length) for t in tgt_sizes]
        # NOTE(review): sizes exclude the prompt (and separator) tokens that
        # collate() inserts, so --max-tokens batching under-counts each sample
        # by that amount. Left unchanged to keep batch composition stable.
        self.sizes = np.array(
            [s + t for s, t in zip(self.src_sizes, self.tgt_sizes)]
        )
        self.buckets = None

    def get_batch_shapes(self):
        """Return the batch-shape buckets (always ``None`` for this dataset)."""
        return self.buckets

    def _truncate(self, item, limit):
        """Cut *item* to ``limit`` tokens, ending it with ``'...'`` and eos.

        Items that already fit are returned untouched. Truncated items are
        cloned before the markers are written: a plain slice is a view, and
        writing through it would modify the tensor owned by the wrapped
        dataset.
        """
        if item.size(0) <= limit:
            return item
        item = item[:limit].clone()
        if item.size(0) >= 2:
            item[-2] = self.dict.index('...')
        if item.size(0) >= 1:
            item[-1] = self.eos
        return item

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for *index*, truncated to fit."""
        # ``+ 1`` because max_source_length does not count the trailing eos.
        src_item = self._truncate(self.src[index], self.max_source_length + 1)
        tgt_item = self._truncate(self.tgt[index], self.max_target_length)
        return {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }

    def __len__(self):
        return len(self.src)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch with the following keys (``{}`` if *samples*
            is empty):

                - `id` (LongTensor): example IDs in the original input order
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of non-pad tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:
                  - `src_tokens` (LongTensor): padded 2D Tensor holding the
                    concatenated prompt/source/target sequence shifted right
                    by one position (eos moved to the front) for teacher
                    forcing.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each concatenated sequence, of shape `(bsz)`
                - `target` (LongTensor): padded 2D Tensor holding the
                  unshifted concatenated sequence.
        """
        # NOTE(review): collation uses the dictionary's eos, while truncation
        # in __getitem__ writes ``self.eos``; the two differ only when a
        # custom ``eos`` was passed to the constructor.
        return collate(
            samples,
            pad_idx=self.dict.pad(),
            eos_idx=self.dict.eos(),
            prefix=self.prefix,
            sep_idx=self.dict.sep_index,
            prompt=self.prompt,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        return self.sizes[indices]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        # Stable sort keeps the (possibly shuffled) order among equal sizes.
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        """Prefetch *indices* in the wrapped datasets."""
        self.src.prefetch(indices)
        # supports_prefetch tolerates a missing target dataset; so must this.
        if self.tgt is not None:
            self.tgt.prefetch(indices)


# --------------------------------------------------------------------------------
# /src/language_modeling_prompt.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from fairseq import search, utils
from fairseq.data import (
    Dictionary,
    data_utils,
    indexed_dataset,
)

from fairseq.tasks import register_task
from fairseq.tasks.language_modeling import LanguageModelingConfig, LanguageModelingTask
from .language_model_prompt_dataset import LanguageModelPromptDataset
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class LanguageModelingPromptConfig(LanguageModelingConfig):
    """Options added on top of ``LanguageModelingConfig`` for prompt-based LM.

    Each field's ``help`` metadata is its user-facing description.
    """

    source_lang: Optional[str] = field(
        default=None, metadata={"help": "source language", "argparse_alias": "-s",}
    )
    target_lang: Optional[str] = field(
        default=None, metadata={"help": "target language","argparse_alias": "-t",}
    )
    max_source_positions: Optional[int] = field(
        default=384, metadata={"help": "max number of tokens in the source sequence, exclude eos."}
    )
    manual_prompt: Optional[str] = field(
        default=None, metadata={"help": "manual prompt to use",}
    )
    learned_prompt: Optional[int] = field(
        default=None, metadata={"help": "number of virtual tokens to use",}
    )
    learned_prompt_pattern: Optional[str] = field(
        default='learned', metadata={"help": "pattern of virtual tokens, default is learned",}
    )
    prefix: Optional[bool] = field(
        default=False, metadata={"help": "whether put prompt as prefix."}
    )
    sep_token: Optional[str] = field(
        default="<seqsep>", metadata={"help": "token to seperate prompt source and target."}
    )


@register_task("language_modeling_prompt", dataclass=LanguageModelingPromptConfig)
class LanguageModelingPromptTask(LanguageModelingTask):
    """
    Language modeling over (source, target) pairs joined by an optional prompt.

    The task loads a source and a target indexed dataset for each split and
    wraps them in a :class:`LanguageModelPromptDataset`, passing along the
    encoded prompt (manual or learned virtual tokens) and the ``prefix`` flag.
    Generation always goes through ``ConstrainedGenerator``.

    Args:
        dictionary (~fairseq.data.Dictionary): the dictionary for the input of
            the language model
        output_dictionary (~fairseq.data.Dictionary): the dictionary for the
            output of the language model. :meth:`setup_dictionary` uses the
            same object as *dictionary*.
        prompt (torch.LongTensor, optional): encoded prompt tokens as returned
            by :meth:`setup_prompt`, or ``None`` when no prompt is configured.
        targets (List[str]): list of the target types that the language model
            should predict. Can be one of "self", "future", and "past".
            Defaults to "future".
    """

    def __init__(self, args, dictionary, output_dictionary=None, prompt=None, targets=None):
        super().__init__(args, dictionary, output_dictionary, targets)
        self.prompt = prompt
        # Number of prompt tokens; 0 when the task runs without a prompt.
        self.prompt_length = self.prompt.size(0) if self.prompt is not None else 0
        self.prefix = args.prefix

    @classmethod
    def setup_prompt(cls, args, dictionary):
        """Encode the configured prompt and register the separator symbol.

        Side effect: sets ``dictionary.sep_index`` to the index of
        ``args.sep_token`` (adding the symbol) when ``args.prefix`` is set,
        otherwise to ``None``.

        Args:
            args: task configuration; reads ``prefix``, ``sep_token``,
                ``manual_prompt``, ``learned_prompt`` and
                ``learned_prompt_pattern``.
            dictionary (~fairseq.data.Dictionary): dictionary used to encode
                the prompt.

        Returns:
            A long tensor with the encoded prompt (no EOS appended), or
            ``None`` when neither a manual nor a learned prompt is requested.

        Raises:
            AssertionError: if both a manual and a learned prompt are set.
        """
        if args.prefix:
            dictionary.sep_index = dictionary.add_symbol(args.sep_token)
        else:
            dictionary.sep_index = None

        assert not (args.manual_prompt and args.learned_prompt), (
            "manual prompt and learned prompt can not be set at the same time"
        )

        if args.manual_prompt:
            return dictionary.encode_line(args.manual_prompt, append_eos=False).long()

        if args.learned_prompt:
            # Virtual tokens are named "<pattern>1 <pattern>2 ... <pattern>N ",
            # each followed by a single space.
            virtual_tokens = "".join(
                "{}{} ".format(args.learned_prompt_pattern, idx)
                for idx in range(1, args.learned_prompt + 1)
            )
            return dictionary.encode_line(virtual_tokens, append_eos=False).long()

        return None

    @classmethod
    def setup_dictionary(cls, args, **kwargs):
        """Load ``dict.<source_lang>.txt`` from the first data path.

        Returns:
            ``(dictionary, output_dictionary)``; both are the same object, or
            both ``None`` when ``args.data`` is empty.
        """
        dictionary = None
        output_dictionary = None
        if args.data:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.{}.txt".format(args.source_lang)))
            logger.info("dictionary: {} types".format(len(dictionary)))
            # Source and target share one vocabulary; the target-language
            # dictionary file is deliberately not loaded.
            output_dictionary = dictionary
        return (dictionary, output_dictionary)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Raises:
            Exception: if the language pair is not given and cannot be
                inferred from the first data path.
        """
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                "Could not infer language pair, please provide it explicitly"
            )

        dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)
        prompt = cls.setup_prompt(args, dictionary)

        # upgrade old checkpoints
        if getattr(args, "exclude_self_target", False):
            args.self_target = False

        targets = []
        if getattr(args, "self_target", False):
            targets.append("self")
        if getattr(args, "future_target", False):
            targets.append("future")
        if getattr(args, "past_target", False):
            targets.append("past")
        if len(targets) == 0:
            # standard language modeling
            targets = ["future"]

        return cls(args, dictionary, output_dictionary, prompt, targets=targets)

    def load_dataset(
        self, split: str, epoch=1, combine=False, **kwargs
    ) -> None:
        """Load a given dataset split into ``self.datasets[split]``.

        Files are looked up as ``<split>.<src>-<tgt>.<lang>`` inside the data
        path selected for this epoch (paths are cycled by ``epoch``).

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path.
            combine: accepted for interface compatibility; not used here.

        Raises:
            FileNotFoundError: if the source side of the split does not exist.
        """
        src_lang = self.args.source_lang
        tgt_lang = self.args.target_lang

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # Common file prefix "<data_path>/<split>.<src>-<tgt>."
        prefix = os.path.join(data_path, "{}.{}-{}.".format(split, src_lang, tgt_lang))

        # Only the source side is probed before loading.
        if not indexed_dataset.dataset_exists(prefix + src_lang, impl=self.args.dataset_impl):
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src_lang, self.dictionary, self.args.dataset_impl
        )
        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt_lang, self.output_dictionary, self.args.dataset_impl
        )

        self.datasets[split] = LanguageModelPromptDataset(
            src_dataset,
            src_dataset.sizes,
            self.dictionary,
            tgt_dataset,
            tgt_dataset.sizes,
            prefix=self.prefix,
            prompt=self.prompt,
            max_source_length=self.args.max_source_positions,
            max_length=self.args.max_target_positions,
            prompt_length=self.prompt_length,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, tgt_tokens=None, tgt_lengths=None):
        """
        Wrap raw source (and optional target) tokens for inference.

        When ``tgt_tokens`` is omitted, every sample gets a placeholder target
        consisting of a single EOS token. When ``tgt_lengths`` is omitted it is
        computed from ``tgt_tokens``.

        Returns:
            LanguageModelPromptDataset: dataset configured with this task's
            prompt, prefix flag and length limits.
        """
        if tgt_tokens is None:
            eos = self.dictionary.eos()
            tgt_tokens = [torch.LongTensor([eos]) for _ in range(len(src_tokens))]
            # Lengths supplied for missing targets are meaningless; recompute.
            tgt_lengths = None
        if tgt_lengths is None:
            tgt_lengths = torch.LongTensor([t.numel() for t in tgt_tokens])

        return LanguageModelPromptDataset(
            src_tokens,
            src_lengths,
            self.dictionary,
            tgt_tokens,
            tgt_lengths,
            prefix=self.prefix,
            prompt=self.prompt,
            max_source_length=self.args.max_source_positions,
            max_length=self.args.max_target_positions,
            prompt_length=self.prompt_length,
        )

    def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None, allowed_text=None):
        """Generate for one batch, conditioning on the source tokens as prefix.

        Args:
            generator: generator whose ``generate`` accepts ``prefix_tokens``,
                ``bos_token`` and ``allowed_text`` keyword arguments.
            models: models passed through to the generator.
            sample (dict): batch; ``sample["net_input"]["src_tokens"]`` is used
                as the prefix when ``prefix_tokens`` is not given.
            prefix_tokens: explicit prefix overriding the source tokens.
            constraints: must be ``None``.
            allowed_text (str, optional): text encoded with the target
                dictionary (unknown words are not added) and forwarded to the
                generator.

        Raises:
            NotImplementedError: if ``constraints`` is provided.
        """
        with torch.no_grad():
            # Generation will always be conditioned on bos_token
            if getattr(self.args, "add_bos_token", False):
                bos_token = self.source_dictionary.bos()
            else:
                bos_token = self.source_dictionary.eos()

            if constraints is not None:
                raise NotImplementedError(
                    "Constrained decoding with the language_modeling task is not supported"
                )

            if allowed_text is not None:
                # Match the device/dtype of the batch's source tokens.
                allowed_text = self.target_dictionary.encode_line(
                    allowed_text, add_if_not_exist=False
                ).to(sample['net_input']['src_tokens'])

            # SequenceGenerator doesn't use src_tokens directly, we need to
            # pass the `prefix_tokens` argument instead
            if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
                prefix_tokens = sample["net_input"]["src_tokens"]
                # Drop the leading bos column when every row starts with it.
                if prefix_tokens[:, 0].eq(bos_token).all():
                    prefix_tokens = prefix_tokens[:, 1:]

            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token, allowed_text=allowed_text
            )

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None
    ):
        """Build a ``ConstrainedGenerator`` with the search strategy from *args*.

        Args:
            models: models to generate with.
            args: generation options, read with ``getattr`` and defaults.
            seq_gen_cls: accepted for signature compatibility; the generator
                class is always ``ConstrainedGenerator``.
            extra_gen_cls_kwargs (dict, optional): extra keyword arguments
                forwarded to the generator constructor.
            prefix_allowed_tokens_fn: optional callable enabling
                prefix-constrained beam search; falls back to
                ``args.prefix_allowed_tokens_fn``.

        Raises:
            ValueError: if more than one of sampling, diverse beam groups,
                match-source-len and diversity rate is requested.
        """
        from .constrained_generator import ConstrainedGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}

        # inference_step passes `allowed_text` to generate(), so the generator
        # class is fixed regardless of the `seq_gen_cls` argument.
        seq_gen_cls = ConstrainedGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )

# --------------------------------------------------------------------------------
# /src/transformer_lm_prompt.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Tuple
from argparse import Namespace
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder
)
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.transformer_lm import (
    TransformerLanguageModelConfig,
    TransformerLanguageModel,
    transformer_lm_gpt2_small,
    transformer_lm_gpt2_big,
)
from omegaconf import II, DictConfig


logger = logging.getLogger(__name__)


def _extend_rows_from_model(state_dict, key, model_weight):
    """Grow ``state_dict[key]`` along dim 0 to match ``model_weight``.

    When the model's weight has more rows than the checkpoint tensor, the
    missing trailing rows are copied from the model's current weight (moved to
    the checkpoint tensor's device) and appended. Otherwise nothing changes.
    """
    loaded = state_dict[key]
    n_loaded = loaded.shape[0]
    if model_weight.shape[0] > n_loaded:
        state_dict[key] = torch.cat(
            [loaded, model_weight[n_loaded:].to(loaded.device)]
        )


@register_model("transformer_lm_prompt", dataclass=TransformerLanguageModelConfig)
class TransformerLanguageModelPrompt(TransformerLanguageModel):
    """Transformer LM that can load checkpoints with a smaller vocabulary."""

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints, and
        extends the checkpoint's token embedding and output projection with
        this model's own trailing rows when the model has more rows than the
        checkpoint (e.g. after symbols were added to the dictionary).

        Args:
            state_dict (dict): checkpoint tensors; modified in place. Must
                contain ``decoder.embed_tokens.weight`` and
                ``decoder.output_projection.weight``.
            strict (bool): forwarded to the parent ``load_state_dict``.
            model_cfg (DictConfig, optional): model config used for pruning.
            args (Namespace, optional): deprecated; converted to a config
                when *model_cfg* is not given.

        Returns:
            Whatever the parent class's ``load_state_dict`` returns.
        """

        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning("using 'args' is deprecated, please update your code to use dataclass config")
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        _extend_rows_from_model(
            state_dict,
            "decoder.embed_tokens.weight",
            self.decoder.embed_tokens.weight,
        )
        _extend_rows_from_model(
            state_dict,
            "decoder.output_projection.weight",
            self.decoder.output_projection.weight,
        )

        # Imported here as in the original code (presumably to avoid an
        # import cycle with fairseq.checkpoint_utils -- TODO confirm).
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)


@register_model_architecture("transformer_lm_prompt", "transformer_lm_prompt_biogpt")
def transformer_lm_prompt_biogpt(args):
    """Apply the ``transformer_lm_gpt2_small`` architecture defaults to *args*."""
    transformer_lm_gpt2_small(args)


@register_model_architecture("transformer_lm_prompt", "transformer_lm_prompt_biogpt_large")
def transformer_lm_prompt_gpt2_big(args):
    """Apply the ``transformer_lm_gpt2_big`` architecture defaults to *args*."""
    transformer_lm_gpt2_big(args)

# --------------------------------------------------------------------------------