The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .github
    └── workflows
    │   └── codeql-analysis.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── data
    ├── BC5CDR
    │   └── raw
    │   │   ├── BC5CDR_Evaluation-0.0.3
    │   │       ├── BioC.dtd
    │   │       ├── DISCLAIMER.txt
    │   │       ├── bc5cdr_eval.jar
    │   │       ├── data
    │   │       │   ├── gold
    │   │       │   │   ├── CDR_sample.gold.BioC.xml
    │   │       │   │   └── CDR_sample.gold.PubTator
    │   │       │   └── test
    │   │       │   │   ├── CDR_sample.old.test.CID.BioC.xml
    │   │       │   │   ├── CDR_sample.old.test.DNER.BioC.xml
    │   │       │   │   ├── CDR_sample.test.CID.BioC.xml
    │   │       │   │   ├── CDR_sample.test.CID.PubTator
    │   │       │   │   ├── CDR_sample.test.DNER.BioC.xml
    │   │       │   │   ├── CDR_sample.test.DNER.PubTator
    │   │       │   │   ├── my.txt
    │   │       │   │   └── rment.py
    │   │       ├── eval_id.sh
    │   │       ├── eval_mention.sh
    │   │       ├── eval_relation.sh
    │   │       ├── my.txt
    │   │       └── readme.txt
    │   │   ├── CDR_Data
    │   │       ├── BC5CDR.corpus.pdf
    │   │       ├── BC5CDR.overview.pdf
    │   │       ├── BC5CDR.presentation.pdf
    │   │       ├── CDR.Corpus.v010516
    │   │       │   ├── CDR_DevelopmentSet.BioC.xml
    │   │       │   ├── CDR_DevelopmentSet.PubTator.txt
    │   │       │   ├── CDR_TestSet.BioC.xml
    │   │       │   ├── CDR_TestSet.PubTator.txt
    │   │       │   ├── CDR_TrainingSet.BioC.xml
    │   │       │   └── CDR_TrainingSet.PubTator.txt
    │   │       ├── DNorm.TestSet
    │   │       │   ├── TestSet.DNorm.BioC.xml
    │   │       │   └── TestSet.DNorm.PubTator.txt
    │   │       ├── README.txt
    │   │       └── tmChem.TestSet
    │   │       │   ├── TestSet.tmChem.BioC.xml
    │   │       │   └── TestSet.tmChem.PubTator.txt
    │   │   ├── test.entities.json
    │   │   ├── test.json
    │   │   ├── train.entities.json
    │   │   ├── train.json
    │   │   ├── valid.entities.json
    │   │   └── valid.json
    ├── BioGPT-Large
    │   ├── bpecodes
    │   └── dict.txt
    ├── BioGPT
    │   ├── bpecodes
    │   └── dict.txt
    ├── DDI
    │   └── raw
    │   │   ├── test.json
    │   │   ├── train.json
    │   │   └── valid.json
    ├── HoC
    │   └── raw
    │   │   ├── test.tsv
    │   │   ├── train.tsv
    │   │   └── valid.tsv
    ├── KD-DTI
    │   └── raw
    │   │   ├── test.json
    │   │   ├── train.json
    │   │   └── valid.json
    ├── biogpt_large_bpecodes
    ├── biogpt_large_dict.txt
    ├── bpecodes
    └── dict.txt
├── examples
    ├── DC-HoC
    │   ├── README.md
    │   ├── hard_match_evaluation.py
    │   ├── infer.sh
    │   ├── postprocess.py
    │   ├── preprocess.sh
    │   ├── rebuild_data.py
    │   └── train.sh
    ├── QA-PubMedQA
    │   ├── README.md
    │   ├── hard_match_evaluation.py
    │   ├── infer.sh
    │   ├── infer_large.sh
    │   ├── postprocess.py
    │   ├── preprocess.sh
    │   ├── preprocess_large.sh
    │   └── rebuild_data.py
    ├── RE-BC5CDR
    │   ├── README.md
    │   ├── infer.sh
    │   ├── postprocess.py
    │   ├── preprocess.sh
    │   ├── rebuild_data.py
    │   └── train.sh
    ├── RE-DDI
    │   ├── README.md
    │   ├── hard_match_evaluation.py
    │   ├── infer.sh
    │   ├── postprocess.py
    │   ├── preprocess.sh
    │   ├── rebuild_data.py
    │   └── train.sh
    ├── RE-DTI
    │   ├── README.md
    │   ├── hard_match_evaluation.py
    │   ├── infer.sh
    │   ├── postprocess.py
    │   ├── preprocess.sh
    │   ├── rebuild_data.py
    │   └── train.sh
    └── text-generation
    │   ├── README.md
    │   └── interactive.py
├── inference.py
├── requirements.txt
├── scripts
    └── average_checkpoints.py
└── src
    ├── __init__.py
    ├── constrained_generator.py
    ├── language_model_prompt_dataset.py
    ├── language_modeling_prompt.py
    └── transformer_lm_prompt.py


/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
 1 | # For most projects, this workflow file will not need changing; you simply need
 2 | # to commit it to your repository.
 3 | #
 4 | # You may wish to alter this file to override the set of languages analyzed,
 5 | # or to provide custom queries or build logic.
 6 | #
 7 | # ******** NOTE ********
 8 | # We have attempted to detect the languages in your repository. Please check
 9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 | 
14 | on:
15 |   push:
16 |     branches: [ "main" ]
17 |   pull_request:
18 |     # The branches below must be a subset of the branches above
19 |     branches: [ "main" ]
20 |   schedule:
21 |     - cron: '39 5 * * 1'
22 | 
23 | jobs:
24 |   analyze:
25 |     name: Analyze
26 |     runs-on: ubuntu-latest
27 |     permissions:
28 |       actions: read
29 |       contents: read
30 |       security-events: write
31 | 
32 |     strategy:
33 |       fail-fast: false
34 |       matrix:
35 |         language: [ 'python' ]
36 |         # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 |         # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
38 | 
39 |     steps:
40 |     - name: Checkout repository
41 |       uses: actions/checkout@v3
42 | 
43 |     # Initializes the CodeQL tools for scanning.
44 |     - name: Initialize CodeQL
45 |       uses: github/codeql-action/init@v2
46 |       with:
47 |         languages: ${{ matrix.language }}
48 |         # If you wish to specify custom queries, you can do so here or in a config file.
49 |         # By default, queries listed here will override any specified in a config file.
50 |         # Prefix the list here with "+" to use these queries and those in the config file.
51 |         
52 |         # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
53 |         # queries: security-extended,security-and-quality
54 | 
55 |         
56 |     # Autobuild attempts to build any compiled languages  (C/C++, C#, or Java).
57 |     # If this step fails, then you should remove it and run the build manually (see below)
58 |     - name: Autobuild
59 |       uses: github/codeql-action/autobuild@v2
60 | 
61 |     # ℹ️ Command-line programs to run using the OS shell.
62 |     # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
63 | 
64 |     #   If the Autobuild fails above, remove it and uncomment the following three lines.
65 |     #   Modify them (or add more) to build your code; please refer to the EXAMPLE below for guidance.
66 | 
67 |     # - run: |
68 |     #   echo "Run, Build Application using script"
69 |     #   ./location_of_script_within_repo/buildscript.sh
70 | 
71 |     - name: Perform CodeQL Analysis
72 |       uses: github/codeql-action/analyze@v2
73 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # JetBrains PyCharm IDE
  2 | .idea/
  3 | 
  4 | # Byte-compiled / optimized / DLL files
  5 | __pycache__/
  6 | *.py[cod]
  7 | *$py.class
  8 | 
  9 | # C extensions
 10 | *.so
 11 | 
 12 | # macOS dir files
 13 | .DS_Store
 14 | 
 15 | # Distribution / packaging
 16 | .Python
 17 | env/
 18 | build/
 19 | develop-eggs/
 20 | dist/
 21 | downloads/
 22 | eggs/
 23 | .eggs/
 24 | lib/
 25 | lib64/
 26 | parts/
 27 | sdist/
 28 | var/
 29 | wheels/
 30 | *.egg-info/
 31 | .installed.cfg
 32 | *.egg
 33 | 
 34 | # PyInstaller
 35 | #  Usually these files are written by a python script from a template
 36 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 37 | *.manifest
 38 | *.spec
 39 | 
 40 | # Installer logs
 41 | pip-log.txt
 42 | pip-delete-this-directory.txt
 43 | 
 44 | # Unit test / coverage reports
 45 | htmlcov/
 46 | .tox/
 47 | .coverage
 48 | .coverage.*
 49 | .cache
 50 | nosetests.xml
 51 | coverage.xml
 52 | *.cover
 53 | .hypothesis/
 54 | 
 55 | # Translations
 56 | *.mo
 57 | *.pot
 58 | 
 59 | # Django stuff:
 60 | *.log
 61 | local_settings.py
 62 | 
 63 | # Flask stuff:
 64 | instance/
 65 | .webassets-cache
 66 | 
 67 | # Scrapy stuff:
 68 | .scrapy
 69 | 
 70 | # Sphinx documentation
 71 | docs/_build/
 72 | 
 73 | # PyBuilder
 74 | target/
 75 | 
 76 | # Jupyter Notebook
 77 | .ipynb_checkpoints
 78 | 
 79 | # pyenv
 80 | .python-version
 81 | 
 82 | # celery beat schedule file
 83 | celerybeat-schedule
 84 | 
 85 | # SageMath parsed files
 86 | *.sage.py
 87 | 
 88 | # dotenv
 89 | .env
 90 | 
 91 | # virtualenv
 92 | .venv
 93 | venv/
 94 | ENV/
 95 | 
 96 | # Spyder project settings
 97 | .spyderproject
 98 | .spyproject
 99 | 
100 | # Rope project settings
101 | .ropeproject
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # VSCODE
107 | .vscode/ftp-sync.json
108 | .vscode/settings.json
109 | 
110 | # Experimental Folder
111 | experimental/*
112 | 
113 | # Weights and Biases logs
114 | wandb/
115 | 
116 | # data
117 | data/*/*-bin
118 | data/*/raw/*.x
119 | data/*/raw/*.y
120 | data/*/raw/*.pmid
121 | data/*/raw/*bpecodes
122 | data/*/raw/*dict.txt
123 | 
124 | # Checkpoints
125 | checkpoints/


--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
 1 | # Microsoft Open Source Code of Conduct
 2 | 
 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
 4 | 
 5 | Resources:
 6 | 
 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 |     MIT License
 2 | 
 3 |     Copyright (c) Microsoft Corporation.
 4 | 
 5 |     Permission is hereby granted, free of charge, to any person obtaining a copy
 6 |     of this software and associated documentation files (the "Software"), to deal
 7 |     in the Software without restriction, including without limitation the rights
 8 |     to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 |     copies of the Software, and to permit persons to whom the Software is
10 |     furnished to do so, subject to the following conditions:
11 | 
12 |     The above copyright notice and this permission notice shall be included in all
13 |     copies or substantial portions of the Software.
14 | 
15 |     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 |     IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 |     FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 |     AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 |     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 |     OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 |     SOFTWARE


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # BioGPT
  2 | This repository contains the implementation of [BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbac409/6713511?guestAccessKey=a66d9b5d-4f83-4017-bb52-405815c907b9), by Renqian Luo, Liai Sun, Yingce Xia, Tao Qin, Sheng Zhang, Hoifung Poon and Tie-Yan Liu.
  3 | 
  4 | 
  5 | # Requirements and Installation
  6 | 
  7 | * [PyTorch](http://pytorch.org/) version == 1.12.0
  8 | * Python version == 3.10
  9 | * fairseq version == 0.12.0:
 10 | 
 11 | ``` bash
 12 | git clone https://github.com/pytorch/fairseq
 13 | cd fairseq
 14 | git checkout v0.12.0
 15 | pip install .
 16 | python setup.py build_ext --inplace
 17 | cd ..
 18 | ```
 19 | * Moses
 20 | ``` bash
 21 | git clone https://github.com/moses-smt/mosesdecoder.git
 22 | export MOSES=${PWD}/mosesdecoder
 23 | ```
 24 | * fastBPE
 25 | ``` bash
 26 | git clone https://github.com/glample/fastBPE.git
 27 | export FASTBPE=${PWD}/fastBPE
 28 | cd fastBPE
 29 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
 30 | ```
 31 | * sacremoses
 32 | ``` bash
 33 | pip install sacremoses
 34 | ```
 35 | * sklearn
 36 | ``` bash
 37 | pip install scikit-learn
 38 | ```
 39 | 
 40 | Remember to set the environment variables `MOSES` and `FASTBPE` to the paths of Moses and fastBPE, respectively, as they will be required later.
 41 | 
 42 | # Getting Started
 43 | ## Pre-trained models
 44 | We provide our pre-trained BioGPT model checkpoints along with fine-tuned checkpoints for downstream tasks, available both through URL download as well as through the Hugging Face 🤗 Hub. 
 45 | 
 46 | |Model|Description|URL|🤗 Hub|
 47 | |----|----|---|---|
 48 | |BioGPT|Pre-trained BioGPT model checkpoint|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/Pre-trained-BioGPT.tgz?sp=r&st=2023-11-13T15:37:35Z&se=2099-12-30T23:37:35Z&spr=https&sv=2022-11-02&sr=b&sig=3CcG1TOhqJPBhkVutvVn3PtUq0vPyLBgwggUfojypfY%3D)|[link](https://huggingface.co/microsoft/biogpt)|
 49 | |BioGPT-Large|Pre-trained BioGPT-Large model checkpoint|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/Pre-trained-BioGPT-Large.tgz?sp=r&st=2023-11-13T15:38:13Z&se=2099-12-30T23:38:13Z&spr=https&sv=2022-11-02&sr=b&sig=ib1SZut9wAwrsxGWtFtIZDhrnRg92dwPJmoY2lr3MTg%3D)|[link](https://huggingface.co/microsoft/biogpt-large)|
 50 | |BioGPT-QA-PubMedQA-BioGPT|Fine-tuned BioGPT for question answering task on PubMedQA|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/QA-PubMedQA-BioGPT.tgz?sp=r&st=2023-11-13T15:38:43Z&se=2099-12-30T23:38:43Z&spr=https&sv=2022-11-02&sr=b&sig=A5SQae6ifsXmrsgpj4E2flhyXm4iHc%2FqO5b8HGOMyjc%3D)| |
 51 | |BioGPT-QA-PubMedQA-BioGPT-Large|Fine-tuned BioGPT-Large for question answering task on PubMedQA|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/QA-PubMedQA-BioGPT-Large.tgz?sp=r&st=2023-11-13T15:39:40Z&se=2099-12-30T23:39:40Z&spr=https&sv=2022-11-02&sr=b&sig=t%2B%2FD%2BxVoIxiuyDsD0VXv%2FjSGoS0VcrdVXycYhWZoxUc%3D)||
 52 | |BioGPT-RE-BC5CDR|Fine-tuned BioGPT for relation extraction task on BC5CDR|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/RE-BC5CDR-BioGPT.tgz?sp=r&st=2023-11-13T15:35:14Z&se=2099-12-30T23:35:14Z&spr=https&sv=2022-11-02&sr=b&sig=uXlLIHlVeKIbS%2BVmdzAmlNCeKdoKO2lxsSmwSi%2FH8nE%3D)| |
 53 | |BioGPT-RE-DDI|Fine-tuned BioGPT for relation extraction task on DDI|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/RE-DDI-BioGPT.tgz?sp=r&st=2023-11-13T15:35:58Z&se=2099-12-30T23:35:58Z&spr=https&sv=2022-11-02&sr=b&sig=DkaQMuM%2FXAsM2p8%2BUs45ecuqhlSRF1DUYRBJNcxD6Pk%3D)| |
 54 | |BioGPT-RE-DTI|Fine-tuned BioGPT for relation extraction task on KD-DTI|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/RE-DTI-BioGPT.tgz?sp=r&st=2023-11-13T15:36:23Z&se=2099-12-30T23:36:23Z&spr=https&sv=2022-11-02&sr=b&sig=bRgUZyqGuwYdM%2FVFzIv6Xa0GThkXq6bVzszmTe9c%2BKM%3D)| |
 55 | |BioGPT-DC-HoC|Fine-tuned BioGPT for document classification task on HoC|[link](https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/DC-HoC-BioGPT.tgz?sp=r&st=2023-11-13T15:37:17Z&se=2099-12-30T23:37:17Z&spr=https&sv=2022-11-02&sr=b&sig=1DxroWPt%2FBppCTy7QHs842lLy8SQRcUeUwSfMzDFvl0%3D)| |
 56 | 
 57 | Download them and extract them to the `checkpoints` folder of this project.
 58 | 
 59 | For example:
 60 | ``` bash
 61 | mkdir checkpoints
 62 | cd checkpoints
 63 | wget -O Pre-trained-BioGPT.tgz "https://msralaphilly2.blob.core.windows.net/release/BioGPT/checkpoints/Pre-trained-BioGPT.tgz?sp=r&st=2023-11-13T15:37:35Z&se=2099-12-30T23:37:35Z&spr=https&sv=2022-11-02&sr=b&sig=3CcG1TOhqJPBhkVutvVn3PtUq0vPyLBgwggUfojypfY%3D"
 64 | tar -zxvf Pre-trained-BioGPT.tgz
 65 | ```
 66 | 
 67 | ## Example Usage
 68 | Use pre-trained BioGPT model in your code:
 69 | ```python
 70 | import torch
 71 | from fairseq.models.transformer_lm import TransformerLanguageModel
 72 | m = TransformerLanguageModel.from_pretrained(
 73 |         "checkpoints/Pre-trained-BioGPT", 
 74 |         "checkpoint.pt", 
 75 |         "data",
 76 |         tokenizer='moses', 
 77 |         bpe='fastbpe', 
 78 |         bpe_codes="data/bpecodes",
 79 |         min_len=100,
 80 |         max_len_b=1024)
 81 | m.cuda()
 82 | src_tokens = m.encode("COVID-19 is")
 83 | generate = m.generate([src_tokens], beam=5)[0]
 84 | output = m.decode(generate[0]["tokens"])
 85 | print(output)
 86 | ```
 87 | 
 88 | Use fine-tuned BioGPT model on KD-DTI for drug-target-interaction in your code:
 89 | ```python
 90 | import torch
 91 | from src.transformer_lm_prompt import TransformerLanguageModelPrompt
 92 | m = TransformerLanguageModelPrompt.from_pretrained(
 93 |         "checkpoints/RE-DTI-BioGPT", 
 94 |         "checkpoint_avg.pt", 
 95 |         "data/KD-DTI/relis-bin",
 96 |         tokenizer='moses', 
 97 |         bpe='fastbpe', 
 98 |         bpe_codes="data/bpecodes",
 99 |         max_len_b=1024,
100 |         beam=1)
101 | m.cuda()
102 | src_text="" # input text, e.g., a PubMed abstract
103 | src_tokens = m.encode(src_text)
104 | generate = m.generate([src_tokens], beam=1)[0]
105 | output = m.decode(generate[0]["tokens"])
106 | print(output)
107 | ```
108 | 
109 | For more downstream tasks, please see below.
110 | 
111 | ## Downstream tasks
112 | See corresponding folder in [examples](examples):
113 | ### [Relation Extraction on BC5CDR](examples/RE-BC5CDR)
114 | ### [Relation Extraction on KD-DTI](examples/RE-DTI/)
115 | ### [Relation Extraction on DDI](examples/RE-DDI)
116 | ### [Document Classification on HoC](examples/DC-HoC/)
117 | ### [Question Answering on PubMedQA](examples/QA-PubMedQA/)
118 | ### [Text Generation](examples/text-generation/)
119 | 
120 | ## Hugging Face 🤗 Usage
121 | 
122 | BioGPT has also been integrated into the Hugging Face `transformers` library, and model checkpoints are available on the Hugging Face Hub.
123 | 
124 | You can use this model directly with a pipeline for text generation. Since the generation relies on some randomness, we set a seed for reproducibility:
125 | 
126 | ```python
127 | from transformers import pipeline, set_seed
128 | from transformers import BioGptTokenizer, BioGptForCausalLM
129 | model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
130 | tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
131 | generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
132 | set_seed(42)
133 | generator("COVID-19 is", max_length=20, num_return_sequences=5, do_sample=True)
134 | ```
135 | 
136 | Here is how to use this model to get the features of a given text in PyTorch:
137 | 
138 | ```python
139 | from transformers import BioGptTokenizer, BioGptForCausalLM
140 | tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
141 | model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
142 | text = "Replace me by any text you'd like."
143 | encoded_input = tokenizer(text, return_tensors='pt')
144 | output = model(**encoded_input)
145 | ```
146 | 
147 | Beam-search decoding:
148 | 
149 | ```python
150 | import torch
151 | from transformers import BioGptTokenizer, BioGptForCausalLM, set_seed
152 | 
153 | tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
154 | model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
155 | 
156 | sentence = "COVID-19 is"
157 | inputs = tokenizer(sentence, return_tensors="pt")
158 | 
159 | set_seed(42)
160 | 
161 | with torch.no_grad():
162 |     beam_output = model.generate(**inputs,
163 |                                  min_length=100,
164 |                                  max_length=1024,
165 |                                  num_beams=5,
166 |                                  early_stopping=True
167 |                                 )
168 | tokenizer.decode(beam_output[0], skip_special_tokens=True)
169 | ```
170 | 
171 | For more information, please see the [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/biogpt) on the Hugging Face website.
172 | 
173 | ## Demos
174 | 
175 | Check out these demos on Hugging Face Spaces:
176 | * [Text Generation with BioGPT-Large](https://huggingface.co/spaces/katielink/biogpt-large-demo)
177 | * [Question Answering with BioGPT-Large-PubMedQA](https://huggingface.co/spaces/katielink/biogpt-qa-demo)
178 | 
179 | # License
180 | 
181 | BioGPT is MIT-licensed.
182 | The license applies to the pre-trained models as well.
183 | 
184 | # Contributing
185 | 
186 | This project welcomes contributions and suggestions.  Most contributions require you to agree to a
187 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
188 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
189 | 
190 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
191 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
192 | provided by the bot. You will only need to do this once across all repos using our CLA.
193 | 
194 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
195 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
196 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
197 | 
198 | # Trademarks
199 | 
200 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 
201 | trademarks or logos is subject to and must follow 
202 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
203 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
204 | Any use of third-party trademarks or logos is subject to those third parties' policies.
205 | 


--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
 1 | <!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
 2 | 
 3 | ## Security
 4 | 
 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
 6 | 
 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
 8 | 
 9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 |   * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 |   * Full paths of source file(s) related to the manifestation of the issue
23 |   * The location of the affected source code (tag/branch/commit or direct URL)
24 |   * Any special configuration required to reproduce the issue
25 |   * Step-by-step instructions to reproduce the issue
26 |   * Proof-of-concept or exploit code (if possible)
27 |   * Impact of the issue, including how an attacker might exploit the issue
28 | 
29 | This information will help us triage your report more quickly.
30 | 
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 | 
33 | ## Preferred Languages
34 | 
35 | We prefer all communications to be in English.
36 | 
37 | ## Policy
38 | 
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 | 
41 | <!-- END MICROSOFT SECURITY.MD BLOCK -->
42 | 


--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
 1 | # TODO: The maintainer of this repo has not yet edited this file
 2 | 
 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
 4 | 
 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
 8 | 
 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 | 
11 | # Support
12 | 
13 | ## How to file issues and get help  
14 | 
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 
16 | issues before filing new issues to avoid duplicates.  For new issues, file your bug or 
17 | feature request as a new Issue.
18 | 
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 | 
23 | ## Microsoft Support Policy  
24 | 
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/BioC.dtd:
--------------------------------------------------------------------------------
  1 | <!-- BioC.dtd -->
  2 | 
  3 | <!--
  4 | 
  5 |     BioC is designed to allow programs that process text and
  6 |     annotations on that text to easily share data and work
  7 |     together. This DTD describes how that data is represented in XML
  8 |     files.
  9 | 
 10 |     Some believe XML is easily read by humans and that should be
 11 |     supported by clearly formatting the elements. In the long run,
 12 |     this is distracting. While the only meaningful spaces are in text
 13 |     elements and the other spaces can be ignored, current tools add no
 14 |     additional space.  Formatters and editors may be used to make the
 15 |     XML file appear more readable.
 16 | 
 17 |     The possible variety of annotations that one might want to produce
 18 |     or use is nearly countless. There is no guarantee that these are
 19 |     organized in the nice nested structure required for XML
 20 |     elements. Even if they were, it would be nice to easily ignore
 21 |     unwanted annotations.  So annotations are recorded in a stand off
 22 |     manner, external to the annotated text. The exceptions are
 23 |     passages and sentences because of their fundamental place in text.
 24 | 
 25 |     The text is expected to be encoded in Unicode, specifically
 26 |     UTF-8. This is one of the encodings required to be implemented by
 27 |     XML tools, is portable between big-endian and little-endian
 28 |     machines and is a superset of 7-bit ASCII. Code points beyond 127
 29 |     may be expressed directly in UTF-8 or indirectly using numeric
 30 |     entities.  Since many tools today still only directly process
 31 |     ASCII characters, conversion should be available and
 32 |     standardized.  Offsets should be in 8 bit code units (bytes) for
 33 |     easier processing by naive programs.
 34 | 
 35 |     collection:  Group of documents, usually from a larger corpus. If
 36 |     a group of documents is from several corpora, use several
 37 |     collections.
 38 | 
 39 |     source:  Name of the source corpus from which the documents were selected
 40 | 
 41 |     date:  Date documents extracted from original source. Can be as
 42 |     simple as yyyymmdd or an ISO timestamp.
 43 | 
 44 |     key: Separate file describing the infons used and any other useful
 45 |     information about the data in the file. For example, if a file
 46 |     includes part-of-speech tags, this file should describe the set of
 47 |     part-of-speech tags used.
 48 | 
 49 |     infon: key-value pairs. Can record essentially arbitrary
 50 |     information. "type" will be a particular common key in the major
 51 |     sub elements below. For PubMed references, passage "type" might
 52 |     signal "title" or "abstract". For annotations, it might indicate
 53 |     "noun phrase", "gene", or "disease". In the programming language
 54 |     data structures, infons are typically represented as a map from a
 55 |     string to a string.  This means keys should be unique within each
 56 |     parent element.
 57 | 
 58 |     document: A document in the collection. A single, complete
 59 |     stand-alone document as described by its parent source.
 60 | 
 61 |     id:  Typically, the id of the document in the parent
 62 |     source. Should at least be unique in the collection.
 63 | 
 64 |     passage: One portion of the document.  In the sample collection of
 65 |     PubMed documents, each document has a title and frequently an
 66 |     abstract. Structured abstracts could have additional passages. For
 67 |     a full text document, passages could be sections such as
 68 |     Introduction, Materials and Methods, or Conclusion. Another option
 69 |     would be paragraphs. Passages impose a linear structure on the
 70 |     document. Further structure in the document can be described by
 71 |     infon values.
 72 | 
 73 |     offset: Where the passage occurs in the parent document. Depending
 74 |     on the source corpus, this might be a very relevant number.  Offsets
 75 |     should be sequential and identify a passage's position in the
 76 |     document.  Since the sample PubMed collection is extracted from an
 77 |     XML file, literal offsets have little value. The title is given an
 78 |     offset of zero, while the abstract is assumed to begin after the
 79 |     title and one space.
 80 | 
 81 |     text: The original text of the passage.
 82 | 
 83 |     sentence:  One sentence of the passage.
 84 | 
 85 |     offset: A document offset to where the sentence begins in the
 86 |     passage. This value is the sum of the passage offset and the local
 87 |     offset within the passage.
 88 | 
 89 |     text: The original text of the sentence.
 90 | 
 91 |     annotation:  Stand-off annotation
 92 | 
 93 |     id: Used to refer to this annotation in relations. Should be
 94 |     unique at whatever level relations appear. If relations appear
 95 |     at the sentence level, annotation ids need to be unique within
 96 |     each sentence. Similarly, if relations appear at the passage
 97 |     level, annotation ids need to be unique within each passage.
 98 | 
 99 |     location: Location of the annotated text. Multiple locations
100 |     indicate a multi-span annotation.
101 | 
102 |     offset: Document offset to where the annotated text begins in
103 |     the passage or sentence. The value is the sum of the passage or
104 |     sentence offset and the local offset within the passage or
105 |     sentence.
106 | 
107 |     length: Length of the annotated text. While unlikely, this could
108 |     be zero to describe an annotation that belongs between two
109 |     characters.
110 | 
111 |     text:  Typically the annotated text.
112 | 
113 |     relation: Relation between multiple annotations and / or other
114 |     relations. Relations are allowed to appear at several levels
115 |     (document, passage, and sentence). Typically they will all appear
116 |     at one level, the level at which they are determined.
117 |     Significantly different types of relations might appear at
118 |     different levels.
119 | 
120 |     id: Used to refer to this relation in other relations. This id
121 |     needs to be unique at whatever level relations appear. (See
122 |     discussion of annotation ids.)
123 | 
124 |     refid: Id of an annotation or another relation.
125 | 
126 |     role: Describes how the referenced annotation or other relation
127 |     participates in the current relation. Has a default value so it
128 |     can be left out if there is no meaningful value.
129 | 
130 | -->
131 | 
<!-- collection: source, date and key metadata, optional infons, then one or
     more documents. -->
<!ELEMENT collection (source, date, key, infon*, document+)>
<!ELEMENT source (#PCDATA)>
<!ELEMENT date (#PCDATA)>
<!ELEMENT key (#PCDATA)>

<!-- infon: text content, with its key carried in the required "key" attribute. -->
<!ELEMENT infon (#PCDATA)>
<!ATTLIST infon
    key CDATA #REQUIRED>

<!-- document: an id, optional infons, one or more passages, optional relations. -->
<!ELEMENT document (id, infon*, passage+, relation*)>
<!ELEMENT id (#PCDATA)>

<!-- passage: after the offset, holds EITHER optional text plus annotations
     OR a sequence of sentences; relations may follow in both cases. -->
<!ELEMENT passage (infon*, offset, ((text?, annotation*) | sentence*), relation*)>
<!ELEMENT offset (#PCDATA)>
<!ELEMENT text (#PCDATA)>

<!-- sentence: offset, optional text, then annotations and relations. -->
<!ELEMENT sentence (infon*, offset, text?, annotation*, relation*)>

<!-- annotation: optional infons and locations followed by a required text
     element; the id attribute is optional. -->
<!ELEMENT annotation (infon*, location*, text)>
<!ATTLIST annotation
    id CDATA #IMPLIED>

<!-- location: empty element; offset and length are both required attributes. -->
<!ELEMENT location EMPTY>
<!ATTLIST location
    offset CDATA #REQUIRED
    length CDATA #REQUIRED>

<!-- relation: optional infons followed by nodes; the id attribute is optional. -->
<!ELEMENT relation (infon*, node*)>
<!ATTLIST relation
    id CDATA #IMPLIED>

<!-- node: empty element referencing its target through the required refid;
     role defaults to the empty string when omitted. -->
<!ELEMENT node EMPTY>
<!ATTLIST node
    refid CDATA #REQUIRED
    role  CDATA "">
159 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/DISCLAIMER.txt:
--------------------------------------------------------------------------------
1 | PUBLIC DOMAIN NOTICE
2 | National Center for Biotechnology Information
3 |  
4 | This software/database is a "United States Government Work" under the terms of the United States Copyright Act. It was written as part of the authors' official duties as a United States Government employee and thus cannot be copyrighted. This software/database is freely available to the public for use. The National Library of Medicine and the U.S. Government have not placed any restriction on its use or reproduction.
5 |  
6 | Although all reasonable efforts have been taken to ensure the accuracy and reliability of the software and data, the NLM and the U.S. Government do not and cannot warrant the performance or results that may be obtained by using this software or data. The NLM and the U.S. Government disclaim all warranties, express or implied, including warranties of performance, merchantability or fitness for any particular purpose. 
7 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/bc5cdr_eval.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/bc5cdr_eval.jar


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/data/test/my.txt:
--------------------------------------------------------------------------------
  1 | 26094	CID	CHEBI:32875	D006973	1.0
  2 | 26094	CID	CHEBI:32875	D003866	1.
  3 | 354896	CID	D008012	D009135	1.0
  4 | 354896	CID	D008012	D001919	1.0
  5 | 354896	CID	D008012	D003866	1.0
  6 | 354896	CID	D008012	D006331	1.
  7 | 1720453	CID	D003404	D007674	1.0
  8 | 1720453	CID	D002945	D007674	1.0
  9 | 1720453	CID	D007069	D007674	1.0
 10 | 1720453	CID	D010710	D007674	1.0
 11 | 1720453	CID	D003609	D007674	1.0
 12 | 1720453	CID	D014750	D007674	1.0
 13 | 1720453	CID	D003404	D000608	1.0
 14 | 1720453	CID	D002945	D000608	1.0
 15 | 1720453	CID	D007069	D000608	1.0
 16 | 1720453	CID	D010710	D000608	1.0
 17 | 1720453	CID	D003609	D000608	1.0
 18 | 1720453	CID	D014750	D000608	1.0
 19 | 1720453	CID	D003404	D011507	1.0
 20 | 1720453	CID	D002945	D011507	1.0
 21 | 1720453	CID	D007069	D011507	1.0
 22 | 1720453	CID	D010710	D011507	1.0
 23 | 1720453	CID	D003609	D011507	1.0
 24 | 1720453	CID	D014750	D011507	1.0
 25 | 1720453	CID	D003404	D009369	1.0
 26 | 1720453	CID	D002945	D009369	1.0
 27 | 1720453	CID	D007069	D009369	1.0
 28 | 1720453	CID	D010710	D009369	1.0
 29 | 1720453	CID	D003609	D009369	1.0
 30 | 1720453	CID	D014750	D009369	1.0
 31 | 1720453	CID	D003404	D005199	1.0
 32 | 1720453	CID	D002945	D005199	1.0
 33 | 1720453	CID	D007069	D005199	1.0
 34 | 1720453	CID	D010710	D005199	1.0
 35 | 1720453	CID	D003609	D005199	1.0
 36 | 1720453	CID	D014750	D005199	1.0
 37 | 1720453	CID	D003404	D014786	1.0
 38 | 1720453	CID	D002945	D014786	1.0
 39 | 1720453	CID	D007069	D014786	1.0
 40 | 1720453	CID	D010710	D014786	1.0
 41 | 1720453	CID	D003609	D014786	1.0
 42 | 1720453	CID	D014750	D014786	1.
 43 | 2224762	CID	D004317	D000505	1.0
 44 | 2224762	CID	C010012	D000505	1.0
 45 | 2224762	CID	C010013	D000505	1.0
 46 | 2224762	CID	C027260	D000505	1.0
 47 | 2224762	CID	C027260	D000505	1.0
 48 | 2224762	CID	D004317	D009325	1.0
 49 | 2224762	CID	C010012	D009325	1.0
 50 | 2224762	CID	C010013	D009325	1.0
 51 | 2224762	CID	C027260	D009325	1.0
 52 | 2224762	CID	C027260	D009325	1.0
 53 | 2224762	CID	D004317	D000740	1.0
 54 | 2224762	CID	C010012	D000740	1.0
 55 | 2224762	CID	C010013	D000740	1.0
 56 | 2224762	CID	C027260	D000740	1.0
 57 | 2224762	CID	C027260	D000740	1.0
 58 | 2224762	CID	D004317	D000380	1.0
 59 | 2224762	CID	C010012	D000380	1.0
 60 | 2224762	CID	C010013	D000380	1.0
 61 | 2224762	CID	C027260	D000380	1.0
 62 | 2224762	CID	C027260	D000380	1.0
 63 | 2224762	CID	D004317	D009369	1.0
 64 | 2224762	CID	C010012	D009369	1.0
 65 | 2224762	CID	C010013	D009369	1.0
 66 | 2224762	CID	C027260	D009369	1.0
 67 | 2224762	CID	C027260	D009369	1.0
 68 | 2224762	CID	D004317	D014786	1.0
 69 | 2224762	CID	C010012	D014786	1.0
 70 | 2224762	CID	C010013	D014786	1.0
 71 | 2224762	CID	C027260	D014786	1.0
 72 | 2224762	CID	C027260	D014786	1.0
 73 | 2224762	CID	D004317	D008107	1.0
 74 | 2224762	CID	C010012	D008107	1.0
 75 | 2224762	CID	C010013	D008107	1.0
 76 | 2224762	CID	C027260	D008107	1.0
 77 | 2224762	CID	C027260	D008107	1.0
 78 | 2224762	CID	D004317	D010689	1.0
 79 | 2224762	CID	C010012	D010689	1.0
 80 | 2224762	CID	C010013	D010689	1.0
 81 | 2224762	CID	C027260	D010689	1.0
 82 | 2224762	CID	C027260	D010689	1.0
 83 | 2224762	CID	D004317	D013921	1.0
 84 | 2224762	CID	C010012	D013921	1.0
 85 | 2224762	CID	C010013	D013921	1.0
 86 | 2224762	CID	C027260	D013921	1.0
 87 | 2224762	CID	C027260	D013921	1.0
 88 | 2224762	CID	D004317	D002280	1.0
 89 | 2224762	CID	C010012	D002280	1.0
 90 | 2224762	CID	C010013	D002280	1.0
 91 | 2224762	CID	C027260	D002280	1.0
 92 | 2224762	CID	C027260	D002280	1.0
 93 | 2224762	CID	D004317	D052016	1.0
 94 | 2224762	CID	C010012	D052016	1.0
 95 | 2224762	CID	C010013	D052016	1.0
 96 | 2224762	CID	C027260	D052016	1.0
 97 | 2224762	CID	C027260	D052016	1.0
 98 | 2224762	CID	D004317	D007890	1.0
 99 | 2224762	CID	C010012	D007890	1.0
100 | 2224762	CID	C010013	D007890	1.0
101 | 2224762	CID	C027260	D007890	1.0
102 | 2224762	CID	C027260	D007890	1.
103 | 2348231	CID	D004176	D007383	1.0
104 | 2348231	CID	C008514	D007383	1.0
105 | 2348231	CID	D013806	D007383	1.0
106 | 2348231	CID	C008514	D007383	1.0
107 | 2348231	CID	D010431	D007383	1.0
108 | 2348231	CID	D010431	D007383	1.0
109 | 2348231	CID	D004176	D014652	1.0
110 | 2348231	CID	C008514	D014652	1.0
111 | 2348231	CID	D013806	D014652	1.0
112 | 2348231	CID	C008514	D014652	1.0
113 | 2348231	CID	D010431	D014652	1.0
114 | 2348231	CID	D010431	D014652	1.0
115 | 2348231	CID	D004176	D006940	1.0
116 | 2348231	CID	C008514	D006940	1.0
117 | 2348231	CID	D013806	D006940	1.0
118 | 2348231	CID	C008514	D006940	1.0
119 | 2348231	CID	D010431	D006940	1.0
120 | 2348231	CID	D010431	D006940	1.0
121 | 2348231	CID	D004176	D003327	1.0
122 | 2348231	CID	C008514	D003327	1.0
123 | 2348231	CID	D013806	D003327	1.0
124 | 2348231	CID	C008514	D003327	1.0
125 | 2348231	CID	D010431	D003327	1.0
126 | 2348231	CID	D010431	D003327	1.
127 | 2385256	CID	D000109	D009468	1.0
128 | 2385256	CID	D008274	D009468	1.0
129 | 2385256	CID	D000109	D009157	1.0
130 | 2385256	CID	D008274	D009157	1.0
131 | 2385256	CID	D000109	D011225	1.0
132 | 2385256	CID	D008274	D011225	1.0
133 | 2385256	CID	D000109	D018908	1.0
134 | 2385256	CID	D008274	D018908	1.0
135 | 2385256	CID	D000109	D010243	1.0
136 | 2385256	CID	D008274	D010243	1.0
137 | 2385256	CID	D000109	D020511	1.0
138 | 2385256	CID	D008274	D020511	1.
139 | 2887062	CID	D001971	D000236	1.0
140 | 2887062	CID	D004967	D000236	1.0
141 | 2887062	CID	D001971	D015175	1.0
142 | 2887062	CID	D004967	D015175	1.
143 | 2894766	CID	D012460	D003093	1.0
144 | 2894766	CID	D000305	D003093	1.0
145 | 2894766	CID	D012460	D002305	1.0
146 | 2894766	CID	D000305	D002305	1.0
147 | 2894766	CID	D012460	C536397	1.0
148 | 2894766	CID	D000305	C536397	1.0
149 | 2894766	CID	D012460	D012700	1.0
150 | 2894766	CID	D000305	D012700	1.0
151 | 2894766	CID	D012460	D004832	1.0
152 | 2894766	CID	D000305	D004832	1.0
153 | 2894766	CID	D012460	D010996	1.0
154 | 2894766	CID	D000305	D010996	1.0
155 | 2894766	CID	D012460	D008180	1.0
156 | 2894766	CID	D000305	D008180	1.0
157 | 2894766	CID	D012460	D011014	1.0
158 | 2894766	CID	D000305	D011014	1.0
159 | 2894766	CID	D012460	D015212	1.0
160 | 2894766	CID	D000305	D015212	1.
161 | 3107448	CID	D013453	D007239	1.0
162 | 3107448	CID	D005905	D007239	1.0
163 | 3107448	CID	D013453	D056486	1.0
164 | 3107448	CID	D005905	D056486	1.0
165 | 3107448	CID	D013453	D003920	1.0
166 | 3107448	CID	D005905	D003920	1.
167 | 3173180	CID	D008795	D000743	1.
168 | 3403780	CID	D000082	D017093	1.0
169 | 3403780	CID	D000082	D058186	1.0
170 | 3403780	CID	D000082	D000138	1.0
171 | 3403780	CID	D000082	D003128	1.0
172 | 3403780	CID	D000082	D051437	1.
173 | 3615541	CID	D002116	D006948	1.0
174 | 3615541	CID	D000661	D006948	1.
175 | 3827439	CID	D000666	D058186	1.0
176 | 3827439	CID	D000666	D013174	1.0
177 | 3827439	CID	D000666	OMIM:215600	1.0
178 | 3827439	CID	D000666	D007674	1.0
179 | 3827439	CID	D000666	D051437	1.
180 | 4027862	CID	D003891	D003693	1.
181 | 6287825	CID	D012256	D017695	1.0
182 | 6287825	CID	D013831	D017695	1.0
183 | 6287825	CID	D007538	D017695	1.0
184 | 6287825	CID	D012256	D020275	1.0
185 | 6287825	CID	D013831	D020275	1.0
186 | 6287825	CID	D007538	D020275	1.0
187 | 6287825	CID	D012256	D003389	1.0
188 | 6287825	CID	D013831	D003389	1.0
189 | 6287825	CID	D007538	D003389	1.0
190 | 6287825	CID	D012256	D009422	1.0
191 | 6287825	CID	D013831	D009422	1.0
192 | 6287825	CID	D007538	D009422	1.0
193 | 6287825	CID	D012256	D010523	1.0
194 | 6287825	CID	D013831	D010523	1.0
195 | 6287825	CID	D007538	D010523	1.0
196 | 6287825	CID	D012256	D008881	1.0
197 | 6287825	CID	D013831	D008881	1.0
198 | 6287825	CID	D007538	D008881	1.0
199 | 6287825	CID	D012256	D015417	1.0
200 | 6287825	CID	D013831	D015417	1.0
201 | 6287825	CID	D007538	D015417	1.0
202 | 6287825	CID	D012256	D016609	1.0
203 | 6287825	CID	D013831	D016609	1.0
204 | 6287825	CID	D007538	D016609	1.0
205 | 6287825	CID	D012256	D003920	1.0
206 | 6287825	CID	D013831	D003920	1.0
207 | 6287825	CID	D007538	D003920	1.
208 | 7352670	CID	D006221	D007022	1.0
209 | 7352670	CID	D012504	D007022	1.0
210 | 7352670	CID	D009599	D007022	1.0
211 | 7352670	CID	D009599	D007022	1.0
212 | 7352670	CID	D006221	D006973	1.0
213 | 7352670	CID	D012504	D006973	1.0
214 | 7352670	CID	D009599	D006973	1.0
215 | 7352670	CID	D009599	D006973	1.
216 | 7420681	CID	D000617	D006311	1.0
217 | 7420681	CID	D005839	D006311	1.0
218 | 7420681	CID	D014031	D006311	1.0
219 | 7420681	CID	D000617	D051437	1.0
220 | 7420681	CID	D005839	D051437	1.0
221 | 7420681	CID	D014031	D051437	1.
222 | 7468724	CID	D013726	D007752	1.0
223 | 7468724	CID	D013726	D002318	1.
224 | 8643971	CID	D002945	D009369	1.0
225 | 8643971	CID	D017239	D009369	1.0
226 | 8643971	CID	D002945	D014786	1.0
227 | 8643971	CID	D017239	D014786	1.0
228 | 8643971	CID	D002945	D010051	1.0
229 | 8643971	CID	D017239	D010051	1.0
230 | 8643971	CID	D002945	D006258	1.0
231 | 8643971	CID	D017239	D006258	1.
232 | 8649546	CID	D011433	D010300	1.0
233 | 8649546	CID	D011433	D004421	1.0
234 | 8649546	CID	D011433	D009069	1.0
235 | 8649546	CID	D011433	D004409	1.
236 | 9088814	CID	D020155	D007022	1.0
237 | 9088814	CID	D000527	D007022	1.0
238 | 9088814	CID	D001663	D007022	1.0
239 | 9088814	CID	C016635	D007022	1.0
240 | 9088814	CID	CHEBI:17087	D007022	1.0
241 | 9088814	CID	D020155	D008107	1.0
242 | 9088814	CID	D000527	D008107	1.0
243 | 9088814	CID	D001663	D008107	1.0
244 | 9088814	CID	C016635	D008107	1.0
245 | 9088814	CID	CHEBI:17087	D008107	1.
246 | 9249847	CID	D014700	D014927	1.0
247 | 9249847	CID	D014700	C536277	1.0
248 | 9249847	CID	D014700	D013611	1.0
249 | 9249847	CID	D014700	D058606	1.
250 | 9660111	CID	D019257	D006948	1.0
251 | 9660111	CID	D012701	D006948	1.0
252 | 9660111	CID	D001058	D006948	1.0
253 | 9660111	CID	D007099	D006948	1.0
254 | 9660111	CID	D004298	D006948	1.0
255 | 9660111	CID	D006916	D006948	1.0
256 | 9660111	CID	D003913	D006948	1.0
257 | 9660111	CID	D014299	D006948	1.0
258 | 9660111	CID	D010656	D006948	1.0
259 | 9660111	CID	D014299	D006948	1.0
260 | 9660111	CID	D009638	D006948	1.0
261 | 9660111	CID	D003000	D006948	1.0
262 | 9660111	CID	D019257	D010554	1.0
263 | 9660111	CID	D012701	D010554	1.0
264 | 9660111	CID	D001058	D010554	1.0
265 | 9660111	CID	D007099	D010554	1.0
266 | 9660111	CID	D004298	D010554	1.0
267 | 9660111	CID	D006916	D010554	1.0
268 | 9660111	CID	D003913	D010554	1.0
269 | 9660111	CID	D014299	D010554	1.0
270 | 9660111	CID	D010656	D010554	1.0
271 | 9660111	CID	D014299	D010554	1.0
272 | 9660111	CID	D009638	D010554	1.0
273 | 9660111	CID	D003000	D010554	1.0
274 | 9660111	CID	D019257	D007035	1.0
275 | 9660111	CID	D012701	D007035	1.0
276 | 9660111	CID	D001058	D007035	1.0
277 | 9660111	CID	D007099	D007035	1.0
278 | 9660111	CID	D004298	D007035	1.0
279 | 9660111	CID	D006916	D007035	1.0
280 | 9660111	CID	D003913	D007035	1.0
281 | 9660111	CID	D014299	D007035	1.0
282 | 9660111	CID	D010656	D007035	1.0
283 | 9660111	CID	D014299	D007035	1.0
284 | 9660111	CID	D009638	D007035	1.0
285 | 9660111	CID	D003000	D007035	1.
286 | 9698967	CID	D008874	D014474	1.0
287 | 9698967	CID	D004837	D014474	1.0
288 | 9698967	CID	D008790	D014474	1.0
289 | 9698967	CID	D007741	D014474	1.0
290 | 9698967	CID	D008619	D014474	1.0
291 | 9698967	CID	D008874	D004387	1.0
292 | 9698967	CID	D004837	D004387	1.0
293 | 9698967	CID	D008790	D004387	1.0
294 | 9698967	CID	D007741	D004387	1.0
295 | 9698967	CID	D008619	D004387	1.0
296 | 9698967	CID	D008874	D001281	1.0
297 | 9698967	CID	D004837	D001281	1.0
298 | 9698967	CID	D008790	D001281	1.0
299 | 9698967	CID	D007741	D001281	1.0
300 | 9698967	CID	D008619	D001281	1.
301 | 9855119	CID	D003404	D005923	1.0
302 | 9855119	CID	D016559	D005923	1.0
303 | 9855119	CID	D003404	D007674	1.0
304 | 9855119	CID	D016559	D007674	1.0
305 | 9855119	CID	D003404	D057770	1.0
306 | 9855119	CID	D016559	D057770	1.0
307 | 9855119	CID	D003404	D005355	1.0
308 | 9855119	CID	D016559	D005355	1.
309 | 9869257	CID	D012601	D003072	1.0
310 | 9869257	CID	C098725	D003072	1.0
311 | 9869257	CID	D000109	D003072	1.0
312 | 9869257	CID	D012601	D000647	1.0
313 | 9869257	CID	C098725	D000647	1.0
314 | 9869257	CID	D000109	D000647	1.
315 | 10457883	CID	D001279	D006323	1.0
316 | 10457883	CID	D001279	D009468	1.0
317 | 10457883	CID	D001279	D001919	1.
318 | 10721819	CID	D001971	D013610	1.0
319 | 10721819	CID	D007545	D013610	1.0
320 | 10721819	CID	D004298	D013610	1.0
321 | 10721819	CID	D004294	D013610	1.0
322 | 10721819	CID	D001971	D001919	1.0
323 | 10721819	CID	D007545	D001919	1.0
324 | 10721819	CID	D004298	D001919	1.0
325 | 10721819	CID	D004294	D001919	1.0
326 | 10721819	CID	D001971	D007022	1.0
327 | 10721819	CID	D007545	D007022	1.0
328 | 10721819	CID	D004298	D007022	1.0
329 | 10721819	CID	D004294	D007022	1.0
330 | 10721819	CID	D001971	D006332	1.0
331 | 10721819	CID	D007545	D006332	1.0
332 | 10721819	CID	D004298	D006332	1.0
333 | 10721819	CID	D004294	D006332	1.
334 | 11147747	CID	D012701	D014549	1.0
335 | 11147747	CID	D016651	D014549	1.0
336 | 11147747	CID	D020280	D014549	1.0
337 | 11147747	CID	D017374	D014549	1.0
338 | 11147747	CID	C047426	D014549	1.
339 | 11587867	CID	D014750	D009370	1.0
340 | 11587867	CID	D014750	D009370	1.0
341 | 11587867	CID	D014750	D010243	1.0
342 | 11587867	CID	D014750	D010243	1.0
343 | 11587867	CID	D014750	D015417	1.0
344 | 11587867	CID	D014750	D015417	1.0
345 | 11587867	CID	D014750	D054198	1.0
346 | 11587867	CID	D014750	D054198	1.
347 | 11897407	CID	C067171	D006331	1.0
348 | 11897407	CID	D007545	D006331	1.0
349 | 11897407	CID	D005937	D006331	1.0
350 | 11897407	CID	C067171	D007238	1.0
351 | 11897407	CID	D007545	D007238	1.0
352 | 11897407	CID	D005937	D007238	1.0
353 | 11897407	CID	C067171	D009203	1.0
354 | 11897407	CID	D007545	D009203	1.0
355 | 11897407	CID	D005937	D009203	1.
356 | 12584269	CID	D020123	D005355	1.0
357 | 12584269	CID	D002857	D005355	1.0
358 | 12584269	CID	D004492	D005355	1.0
359 | 12584269	CID	D016572	D005355	1.0
360 | 12584269	CID	D020123	D005355	1.0
361 | 12584269	CID	D016559	D005355	1.0
362 | 12584269	CID	D016572	D005355	1.
363 | 12905102	CID	D003024	D003693	1.
364 | 14513889	CID	D007654	D020335	1.0
365 | 14513889	CID	D007654	D009422	1.0
366 | 14513889	CID	D007654	D014202	1.0
367 | 14513889	CID	D007654	D004401	1.0
368 | 14513889	CID	D007654	D010243	1.0
369 | 14513889	CID	D007654	D004362	1.
370 | 15188772	CID	D002395	D009202	1.0
371 | 15188772	CID	D002395	D009202	1.0
372 | 15188772	CID	D004837	D009202	1.0
373 | 15188772	CID	D002395	D015537	1.0
374 | 15188772	CID	D002395	D015537	1.0
375 | 15188772	CID	D004837	D015537	1.0
376 | 15188772	CID	D002395	D017682	1.0
377 | 15188772	CID	D002395	D017682	1.0
378 | 15188772	CID	D004837	D017682	1.0
379 | 15188772	CID	D002395	D008107	1.0
380 | 15188772	CID	D002395	D008107	1.0
381 | 15188772	CID	D004837	D008107	1.
382 | 15602202	CID	D017963	D009395	1.0
383 | 15602202	CID	D017963	D003072	1.0
384 | 15602202	CID	D017963	D007674	1.
385 | 16298782	CID	D005839	D006316	1.0
386 | 16298782	CID	D019331	D006316	1.0
387 | 16298782	CID	D000617	D006316	1.0
388 | 16298782	CID	D009569	D006316	1.0
389 | 16298782	CID	D005839	D034381	1.0
390 | 16298782	CID	D019331	D034381	1.0
391 | 16298782	CID	D000617	D034381	1.0
392 | 16298782	CID	D009569	D034381	1.0
393 | 16298782	CID	D005839	D006311	1.0
394 | 16298782	CID	D019331	D006311	1.0
395 | 16298782	CID	D000617	D006311	1.0
396 | 16298782	CID	D009569	D006311	1.0
397 | 16298782	CID	D005839	D006319	1.0
398 | 16298782	CID	D019331	D006319	1.0
399 | 16298782	CID	D000617	D006319	1.0
400 | 16298782	CID	D009569	D006319	1.
401 | 16330766	CID	C040029	D017695	1.0
402 | 16330766	CID	D002211	D017695	1.0
403 | 16330766	CID	C040029	D002493	1.0
404 | 16330766	CID	D002211	D002493	1.0
405 | 16330766	CID	C040029	D020886	1.0
406 | 16330766	CID	D002211	D020886	1.0
407 | 16330766	CID	C040029	D010146	1.0
408 | 16330766	CID	D002211	D010146	1.0
409 | 16330766	CID	C040029	D006930	1.0
410 | 16330766	CID	D002211	D006930	1.0
411 | 16330766	CID	C040029	D009437	1.0
412 | 16330766	CID	D002211	D009437	1.
413 | 17175308	CID	D003404	D009395	1.0
414 | 17175308	CID	D020123	D009395	1.0
415 | 17175308	CID	D020123	D009395	1.0
416 | 17175308	CID	D003404	OMIM:274230	1.0
417 | 17175308	CID	D020123	OMIM:274230	1.0
418 | 17175308	CID	D020123	OMIM:274230	1.0
419 | 17175308	CID	D003404	D011507	1.0
420 | 17175308	CID	D020123	D011507	1.0
421 | 17175308	CID	D020123	D011507	1.0
422 | 17175308	CID	D003404	D009369	1.0
423 | 17175308	CID	D020123	D009369	1.0
424 | 17175308	CID	D020123	D009369	1.0
425 | 17175308	CID	D003404	D012514	1.0
426 | 17175308	CID	D020123	D012514	1.0
427 | 17175308	CID	D020123	D012514	1.0
428 | 17175308	CID	D003404	D009404	1.0
429 | 17175308	CID	D020123	D009404	1.0
430 | 17175308	CID	D020123	D009404	1.0
431 | 17175308	CID	D003404	D007674	1.0
432 | 17175308	CID	D020123	D007674	1.0
433 | 17175308	CID	D020123	D007674	1.
434 | 17242861	CID	D010862	D013180	1.0
435 | 17242861	CID	D010862	D012640	1.0
436 | 17242861	CID	D010862	D004833	1.
437 | 17466854	CID	D008012	D019106	1.0
438 | 17466854	CID	D005839	D019106	1.0
439 | 17466854	CID	D008775	D019106	1.0
440 | 17466854	CID	D008012	D009325	1.0
441 | 17466854	CID	D005839	D009325	1.0
442 | 17466854	CID	D008775	D009325	1.0
443 | 17466854	CID	D008012	D006261	1.0
444 | 17466854	CID	D005839	D006261	1.0
445 | 17466854	CID	D008775	D006261	1.0
446 | 17466854	CID	D008012	D014839	1.0
447 | 17466854	CID	D005839	D014839	1.0
448 | 17466854	CID	D008775	D014839	1.
449 | 17615423	CID	D003404	D058186	1.0
450 | 17615423	CID	D003401	D058186	1.0
451 | 17615423	CID	D008148	D058186	1.0
452 | 17615423	CID	C065179	D058186	1.0
453 | 17615423	CID	D017035	D058186	1.0
454 | 17615423	CID	C422923	D058186	1.0
455 | 17615423	CID	C413408	D058186	1.0
456 | 17615423	CID	D019821	D058186	1.0
457 | 17615423	CID	D001224	D058186	1.0
458 | 17615423	CID	D000638	D058186	1.0
459 | 17615423	CID	C065180	D058186	1.0
460 | 17615423	CID	D019821	D058186	1.0
461 | 17615423	CID	D000409	D058186	1.0
462 | 17615423	CID	D003404	D012206	1.0
463 | 17615423	CID	D003401	D012206	1.0
464 | 17615423	CID	D008148	D012206	1.0
465 | 17615423	CID	C065179	D012206	1.0
466 | 17615423	CID	D017035	D012206	1.0
467 | 17615423	CID	C422923	D012206	1.0
468 | 17615423	CID	C413408	D012206	1.0
469 | 17615423	CID	D019821	D012206	1.0
470 | 17615423	CID	D001224	D012206	1.0
471 | 17615423	CID	D000638	D012206	1.0
472 | 17615423	CID	C065180	D012206	1.0
473 | 17615423	CID	D019821	D012206	1.0
474 | 17615423	CID	D000409	D012206	1.0
475 | 17615423	CID	D003404	D006949	1.0
476 | 17615423	CID	D003401	D006949	1.0
477 | 17615423	CID	D008148	D006949	1.0
478 | 17615423	CID	C065179	D006949	1.0
479 | 17615423	CID	D017035	D006949	1.0
480 | 17615423	CID	C422923	D006949	1.0
481 | 17615423	CID	C413408	D006949	1.0
482 | 17615423	CID	D019821	D006949	1.0
483 | 17615423	CID	D001224	D006949	1.0
484 | 17615423	CID	D000638	D006949	1.0
485 | 17615423	CID	C065180	D006949	1.0
486 | 17615423	CID	D019821	D006949	1.0
487 | 17615423	CID	D000409	D006949	1.0
488 | 17615423	CID	D003404	D001281	1.0
489 | 17615423	CID	D003401	D001281	1.0
490 | 17615423	CID	D008148	D001281	1.0
491 | 17615423	CID	C065179	D001281	1.0
492 | 17615423	CID	D017035	D001281	1.0
493 | 17615423	CID	C422923	D001281	1.0
494 | 17615423	CID	C413408	D001281	1.0
495 | 17615423	CID	D019821	D001281	1.0
496 | 17615423	CID	D001224	D001281	1.0
497 | 17615423	CID	D000638	D001281	1.0
498 | 17615423	CID	C065180	D001281	1.0
499 | 17615423	CID	D019821	D001281	1.0
500 | 17615423	CID	D000409	D001281	1.0
501 | 17615423	CID	D003404	D015658	1.0
502 | 17615423	CID	D003401	D015658	1.0
503 | 17615423	CID	D008148	D015658	1.0
504 | 17615423	CID	C065179	D015658	1.0
505 | 17615423	CID	D017035	D015658	1.0
506 | 17615423	CID	C422923	D015658	1.0
507 | 17615423	CID	C413408	D015658	1.0
508 | 17615423	CID	D019821	D015658	1.0
509 | 17615423	CID	D001224	D015658	1.0
510 | 17615423	CID	D000638	D015658	1.0
511 | 17615423	CID	C065180	D015658	1.0
512 | 17615423	CID	D019821	D015658	1.0
513 | 17615423	CID	D000409	D015658	1.0
514 | 17615423	CID	D003404	D003324	1.0
515 | 17615423	CID	D003401	D003324	1.0
516 | 17615423	CID	D008148	D003324	1.0
517 | 17615423	CID	C065179	D003324	1.0
518 | 17615423	CID	D017035	D003324	1.0
519 | 17615423	CID	C422923	D003324	1.0
520 | 17615423	CID	C413408	D003324	1.0
521 | 17615423	CID	D019821	D003324	1.0
522 | 17615423	CID	D001224	D003324	1.0
523 | 17615423	CID	D000638	D003324	1.0
524 | 17615423	CID	C065180	D003324	1.0
525 | 17615423	CID	D019821	D003324	1.0
526 | 17615423	CID	D000409	D003324	1.0
527 | 17615423	CID	D003404	D010146	1.0
528 | 17615423	CID	D003401	D010146	1.0
529 | 17615423	CID	D008148	D010146	1.0
530 | 17615423	CID	C065179	D010146	1.0
531 | 17615423	CID	D017035	D010146	1.0
532 | 17615423	CID	C422923	D010146	1.0
533 | 17615423	CID	C413408	D010146	1.0
534 | 17615423	CID	D019821	D010146	1.0
535 | 17615423	CID	D001224	D010146	1.0
536 | 17615423	CID	D000638	D010146	1.0
537 | 17615423	CID	C065180	D010146	1.0
538 | 17615423	CID	D019821	D010146	1.0
539 | 17615423	CID	D000409	D010146	1.
540 | 18023325	CID	D004221	D020275	1.0
541 | 18023325	CID	D004221	D011782	1.0
542 | 18023325	CID	D004221	D015537	1.0
543 | 18023325	CID	D004221	D001480	1.0
544 | 18023325	CID	D004221	D010523	1.0
545 | 18023325	CID	D004221	D001259	1.0
546 | 18023325	CID	D004221	D008107	1.0
547 | 18023325	CID	D004221	D010146	1.0
548 | 18023325	CID	D004221	D014826	1.
549 | 18186898	CID	CHEBI:8764	D005198	1.0
550 | 18186898	CID	D019259	D005198	1.0
551 | 18186898	CID	D013256	D005198	1.0
552 | 18186898	CID	D020123	D005198	1.0
553 | 18186898	CID	D016559	D005198	1.0
554 | 18186898	CID	CHEBI:8764	D028361	1.0
555 | 18186898	CID	D019259	D028361	1.0
556 | 18186898	CID	D013256	D028361	1.0
557 | 18186898	CID	D020123	D028361	1.0
558 | 18186898	CID	D016559	D028361	1.0
559 | 18186898	CID	CHEBI:8764	D000138	1.0
560 | 18186898	CID	D019259	D000138	1.0
561 | 18186898	CID	D013256	D000138	1.0
562 | 18186898	CID	D020123	D000138	1.0
563 | 18186898	CID	D016559	D000138	1.0
564 | 18186898	CID	CHEBI:8764	D018908	1.0
565 | 18186898	CID	D019259	D018908	1.0
566 | 18186898	CID	D013256	D018908	1.0
567 | 18186898	CID	D020123	D018908	1.0
568 | 18186898	CID	D016559	D018908	1.0
569 | 18186898	CID	CHEBI:8764	D006029	1.0
570 | 18186898	CID	D019259	D006029	1.0
571 | 18186898	CID	D013256	D006029	1.0
572 | 18186898	CID	D020123	D006029	1.0
573 | 18186898	CID	D016559	D006029	1.0
574 | 18186898	CID	CHEBI:8764	D056486	1.0
575 | 18186898	CID	D019259	D056486	1.0
576 | 18186898	CID	D013256	D056486	1.0
577 | 18186898	CID	D020123	D056486	1.0
578 | 18186898	CID	D016559	D056486	1.0
579 | 18186898	CID	CHEBI:8764	D007625	1.0
580 | 18186898	CID	D019259	D007625	1.0
581 | 18186898	CID	D013256	D007625	1.0
582 | 18186898	CID	D020123	D007625	1.0
583 | 18186898	CID	D016559	D007625	1.0
584 | 18186898	CID	CHEBI:8764	D008107	1.0
585 | 18186898	CID	D019259	D008107	1.0
586 | 18186898	CID	D013256	D008107	1.0
587 | 18186898	CID	D020123	D008107	1.0
588 | 18186898	CID	D016559	D008107	1.0
589 | 18186898	CID	CHEBI:8764	D017674	1.0
590 | 18186898	CID	D019259	D017674	1.0
591 | 18186898	CID	D013256	D017674	1.0
592 | 18186898	CID	D020123	D017674	1.0
593 | 18186898	CID	D016559	D017674	1.0
594 | 18186898	CID	CHEBI:8764	D009135	1.0
595 | 18186898	CID	D019259	D009135	1.0
596 | 18186898	CID	D013256	D009135	1.0
597 | 18186898	CID	D020123	D009135	1.0
598 | 18186898	CID	D016559	D009135	1.0
599 | 18186898	CID	CHEBI:8764	D007674	1.0
600 | 18186898	CID	D019259	D007674	1.0
601 | 18186898	CID	D013256	D007674	1.0
602 | 18186898	CID	D020123	D007674	1.0
603 | 18186898	CID	D016559	D007674	1.0
604 | 18186898	CID	CHEBI:8764	D000608	1.0
605 | 18186898	CID	D019259	D000608	1.0
606 | 18186898	CID	D013256	D000608	1.0
607 | 18186898	CID	D020123	D000608	1.0
608 | 18186898	CID	D016559	D000608	1.0
609 | 18186898	CID	CHEBI:8764	D006527	1.0
610 | 18186898	CID	D019259	D006527	1.0
611 | 18186898	CID	D013256	D006527	1.0
612 | 18186898	CID	D020123	D006527	1.0
613 | 18186898	CID	D016559	D006527	1.
614 | 18439803	CID	D018698	D012640	1.0
615 | 18439803	CID	D005998	D012640	1.0
616 | 18439803	CID	C108761	D012640	1.0
617 | 18439803	CID	D014635	D012640	1.0
618 | 18439803	CID	D010862	D012640	1.0
619 | 18439803	CID	D005680	D012640	1.0
620 | 18439803	CID	D000596	D012640	1.0
621 | 18439803	CID	D014635	D012640	1.0
622 | 18439803	CID	C108761	D012640	1.0
623 | 18439803	CID	D001224	D012640	1.
624 | 18619688	CID	D004317	D028361	1.0
625 | 18619688	CID	C025946	D028361	1.0
626 | 18619688	CID	C025946	D028361	1.0
627 | 18619688	CID	D004317	D003643	1.0
628 | 18619688	CID	C025946	D003643	1.0
629 | 18619688	CID	C025946	D003643	1.0
630 | 18619688	CID	D004317	D006333	1.0
631 | 18619688	CID	C025946	D006333	1.0
632 | 18619688	CID	C025946	D006333	1.
633 | 18809400	CID	D002945	D004194	1.0
634 | 18809400	CID	D008063	D004194	1.0
635 | 18809400	CID	D017239	D004194	1.0
636 | 18809400	CID	D002945	D028361	1.0
637 | 18809400	CID	D008063	D028361	1.0
638 | 18809400	CID	D017239	D028361	1.0
639 | 18809400	CID	D002945	D009422	1.0
640 | 18809400	CID	D008063	D009422	1.0
641 | 18809400	CID	D017239	D009422	1.0
642 | 18809400	CID	D002945	D001480	1.0
643 | 18809400	CID	D008063	D001480	1.0
644 | 18809400	CID	D017239	D001480	1.0
645 | 18809400	CID	D002945	D010523	1.0
646 | 18809400	CID	D008063	D010523	1.0
647 | 18809400	CID	D017239	D010523	1.0
648 | 18809400	CID	D002945	D014786	1.0
649 | 18809400	CID	D008063	D014786	1.0
650 | 18809400	CID	D017239	D014786	1.0
651 | 18809400	CID	D002945	D045888	1.0
652 | 18809400	CID	D008063	D045888	1.0
653 | 18809400	CID	D017239	D045888	1.0
654 | 18809400	CID	D002945	D020258	1.0
655 | 18809400	CID	D008063	D020258	1.0
656 | 18809400	CID	D017239	D020258	1.
657 | 19657887	CID	D000079	D013610	1.0
658 | 19657887	CID	D004221	D013610	1.0
659 | 19657887	CID	D000431	D013610	1.0
660 | 19657887	CID	D003484	D013610	1.0
661 | 19657887	CID	D000079	D013610	1.0
662 | 19657887	CID	D000079	D007022	1.0
663 | 19657887	CID	D004221	D007022	1.0
664 | 19657887	CID	D000431	D007022	1.0
665 | 19657887	CID	D003484	D007022	1.0
666 | 19657887	CID	D000079	D007022	1.0
667 | 19657887	CID	D000079	D004417	1.0
668 | 19657887	CID	D004221	D004417	1.0
669 | 19657887	CID	D000431	D004417	1.0
670 | 19657887	CID	D003484	D004417	1.0
671 | 19657887	CID	D000079	D004417	1.0
672 | 19657887	CID	D000079	C536855	1.0
673 | 19657887	CID	D004221	C536855	1.0
674 | 19657887	CID	D000431	C536855	1.0
675 | 19657887	CID	D003484	C536855	1.0
676 | 19657887	CID	D000079	C536855	1.
677 | 19803309	CID	D013390	D013035	1.0
678 | 19803309	CID	D013390	D009222	1.
679 | 21029050	CID	D013390	D001049	1.0
680 | 21029050	CID	D013390	D016609	1.0
681 | 21029050	CID	D013390	D005316	1.
682 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/data/test/rment.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import json
 3 | from itertools import groupby
 4 | from turtle import title
 5 | 
 6 | 
 7 | inp_f = sys.argv[1]
 8 | out_f = sys.argv[2]
 9 | 
10 | 
def read_pubtator(file):
    """Yield the documents of a PubTator file, one list of lines per document.

    Every line is stripped of surrounding whitespace. Consecutive non-empty
    lines form one document; runs of empty lines act as separators and are
    not yielded.

    Args:
        file: Path of the text file to read.

    Yields:
        list[str]: The stripped, non-empty lines of one document.
    """
    # The context manager closes the handle even when the consumer stops
    # iterating early; a close() placed after the loop is never reached then.
    with open(file, "r") as handle:
        stripped = (line.strip() for line in handle)
        for has_content, group in groupby(stripped, key=bool):
            if has_content:
                yield list(group)
19 | 
def extract_pubtator(lines):
    """Collect the CID relation rows of one PubTator document.

    The first two entries of ``lines`` are skipped (presumably the title and
    abstract rows -- confirm against the input files). NUL characters are
    removed from the remaining rows, and each row whose second tab-separated
    field is ``CID`` is returned with a ``1.0`` column appended.

    Args:
        lines: The lines of a single document.

    Returns:
        list[str]: The CID rows, each suffixed with ``"\\t1.0"``.
    """
    res = []
    for raw in lines[2:]:
        line = raw.replace('\x00', '')
        fields = line.split('\t')
        # Rows without a second column used to raise IndexError; skip them.
        if len(fields) > 1 and fields[1] == 'CID':
            res.append(line + '\t1.0')
    return res
31 | 
# Convert every document of the input file into its scored CID rows and
# write them to the output file, one document block after another.
data = read_pubtator(inp_f)
with open(out_f, 'w') as f:
    for sample in data:
        lines = extract_pubtator(sample)
        # Write the joined block in full. print() already appends the newline,
        # so slicing off the last character truncated the final "1.0" to "1.".
        print('\n'.join(lines), file=f)
38 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/eval_id.sh:
--------------------------------------------------------------------------------
# Evaluate disease normalization (id) predictions against a gold standard.
# Usage: ./eval_id.sh [BioC|PubTator] <gold standard> <result>
FORMAT=$1
GOLD_FILE=$2
PREDICTION_FILE=$3
# Arguments are quoted so that paths containing spaces reach Java intact.
# The INFO debug rows are filtered out; use the commented command to see them.
java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate id Disease "$FORMAT" "$GOLD_FILE" "$PREDICTION_FILE" | grep -v INFO
# java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate id Disease "$FORMAT" "$GOLD_FILE" "$PREDICTION_FILE"
6 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/eval_mention.sh:
--------------------------------------------------------------------------------
# Evaluate disease mention recognition predictions against a gold standard.
# Usage: ./eval_mention.sh [BioC|PubTator] <gold standard> <result>
FORMAT=$1
GOLD_FILE=$2
PREDICTION_FILE=$3
# Arguments are quoted so that paths containing spaces reach Java intact.
# The INFO debug rows are filtered out; use the commented command to see them.
java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate mention Disease "$FORMAT" "$GOLD_FILE" "$PREDICTION_FILE" | grep -v INFO
# java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate mention Disease "$FORMAT" "$GOLD_FILE" "$PREDICTION_FILE"
6 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/eval_relation.sh:
--------------------------------------------------------------------------------
# Evaluate chemical-induced-disease (CID) relation predictions against a gold standard.
# Usage: ./eval_relation.sh [BioC|PubTator] <gold standard> <result>
FORMAT=$1
GOLD_FILE=$2
PREDICTION_FILE=$3
# Arguments are quoted so that paths containing spaces reach Java intact.
# The INFO debug rows are filtered out; use the commented command to see them.
java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate relation CID "$FORMAT" "$GOLD_FILE" "$PREDICTION_FILE" | grep -v INFO
# java -cp bc5cdr_eval.jar ncbi.bc5cdr_eval.Evaluate relation CID "$FORMAT" "$GOLD_FILE" "$PREDICTION_FILE"
6 | 
7 | 


--------------------------------------------------------------------------------
/data/BC5CDR/raw/BC5CDR_Evaluation-0.0.3/readme.txt:
--------------------------------------------------------------------------------
 1 | [Directory]
 2 | A. Introduction
 3 | B. Installation
 4 | C. Instruction
 5 | D. Output control
 6 | E. Troubleshooting
 7 | 
 8 | #======================================================================================#
 9 | 
10 | A. [Introduction]
11 | 
12 | 	This is a set of scripts for executing a Java program that evaluates performance on the BioCreative V CDR task. Please follow the instructions to evaluate your system. We used DNorm, tmChem and co-occurrence relations to produce the baseline results in the data/test folder.
13 | 	
14 | 	The three scripts each support both the PubTator (http://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/PubTator/) and BioC (http://bioc.sourceforge.net/) formats.
15 | 
16 | B. [Installation]
17 | 
18 | 	Users need to install Java in their environment. Scripts are provided for the UNIX command line. Batch files for the Windows environment should be straightforward but are not provided.
19 | 	
20 | C. [Instruction]
21 | 
22 | 	This program can evaluate the performance of disease mention recognition (mention), normalization (id) and chemical-induces-disease relation (relation) on both PubTator and BioC formats.
23 | 
24 | 	Instruction:
25 | 
26 | 		./eval_mention.sh [BioC|PubTator] [gold standard] [result] 
27 | 		./eval_id.sh [BioC|PubTator] [gold standard] [result] 
28 | 		./eval_relation.sh [BioC|PubTator] [gold standard] [result] 
29 | 
30 | 	Example mention evaluation:
31 | 		./eval_mention.sh PubTator data/gold/CDR_sample.gold.PubTator data/test/CDR_sample.test.DNER.PubTator
32 | 		OR
33 | 		./eval_mention.sh BioC data/gold/CDR_sample.gold.BioC.xml data/test/CDR_sample.test.DNER.BioC.xml
34 | 
35 | 	Results:
36 | 		TP: 303
37 | 		FP: 105
38 | 		FN: 121
39 | 		Precision: 0.7426470588235294
40 | 		Recall: 0.714622641509434
41 | 		F-score: 0.7283653846153848
42 | 		
43 | 	Example ID evaluation:
44 | 		./eval_id.sh PubTator data/gold/CDR_sample.gold.PubTator data/test/CDR_sample.test.DNER.PubTator
45 | 		OR
46 | 		./eval_id.sh BioC data/gold/CDR_sample.gold.BioC.xml data/test/CDR_sample.test.DNER.BioC.xml
47 | 	
48 | 	Results:
49 | 		TP: 150
50 | 		FP: 56
51 | 		FN: 64
52 | 		Precision: 0.7281553398058253
53 | 		Recall: 0.7009345794392523
54 | 		F-score: 0.7142857142857142
55 | 
56 | 	Example relation evaluation:
57 | 		./eval_relation.sh PubTator data/gold/CDR_sample.gold.PubTator data/test/CDR_sample.test.CID.PubTator
58 | 		OR
59 | 		./eval_relation.sh BioC data/gold/CDR_sample.gold.BioC.xml data/test/CDR_sample.test.CID.BioC.xml
60 | 	
61 | 	Results:
62 | 		TP: 90
63 | 		FP: 533
64 | 		FN: 33
65 | 		Precision: 0.14446227929373998
66 | 		Recall: 0.7317073170731707
67 | 		F-score: 0.24128686327077747
68 | 
69 | D. [Output control]
70 | 
71 | 	The evaluation code provides some output useful for debugging, but this is filtered out by the scripts for simplicity. The information removed consists of a tab-delimited string containing the values necessary for determining whether the predicted values match the target values. This information is specific to each evaluation type.
72 | 
73 | 	Mention evaluation:
74 | 		INFO	{TP|FP|FN}	mention	documentId	startOffset	endOffset	mentionType
75 | 	For example:
76 | 		INFO	TP	mention	1720453	34	48	Disease
77 | 
78 | 	ID evaluation:
79 | 		INFO	{TP|FP|FN}	id	documentId	mentionType	conceptId
80 | 	For example:
81 | 		INFO	TP	id	7352670	Disease	MESH:D007022
82 | 
83 | 	Relation evaluation
84 | 		INFO	{TP|FP|FN}	relation	documentId	relationType	conceptId1	conceptId2
85 | 	For example:
86 | 		INFO	TP	relation	2894766	CID	MESH:D012460	MESH:D011014
87 | 
88 | 	To obtain the debug output for a specific evaluation, open the appropriate script, and change the Java command that is commented out (with the "#" symbol) to be the line that does not contain "grep -v INFO".
89 | 		
90 | E. [Troubleshooting]
91 | 
92 | 	The scripts are written to ignore some unexpected data. Writing data in the incorrect format may therefore result in it being ignored. Results with zero TP and zero FP are therefore probably indicating a formatting error.
93 | 	


--------------------------------------------------------------------------------
/data/BC5CDR/raw/CDR_Data/BC5CDR.corpus.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/CDR_Data/BC5CDR.corpus.pdf


--------------------------------------------------------------------------------
/data/BC5CDR/raw/CDR_Data/BC5CDR.overview.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/CDR_Data/BC5CDR.overview.pdf


--------------------------------------------------------------------------------
/data/BC5CDR/raw/CDR_Data/BC5CDR.presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BioGPT/648f7d6503c038b44e70b7510bf431c7be94e891/data/BC5CDR/raw/CDR_Data/BC5CDR.presentation.pdf


--------------------------------------------------------------------------------
/data/BC5CDR/raw/CDR_Data/README.txt:
--------------------------------------------------------------------------------
 1 | ===========================================================================
 2 | *
 3 | *                            PUBLIC DOMAIN NOTICE
 4 | *               National Center for Biotechnology Information
 5 | *
 6 | *  This software/database is a "United States Government Work" under the
 7 | *  terms of the United States Copyright Act.  It was written as part of
 8 | *  the author's official duties as a United States Government employee and
 9 | *  thus cannot be copyrighted.  This software/database is freely available
10 | *  to the public for use. The National Library of Medicine and the U.S.
11 | *  Government have not placed any restriction on its use or reproduction.
12 | *
13 | *  Although all reasonable efforts have been taken to ensure the accuracy
14 | *  and reliability of the software and data, the NLM and the U.S.
15 | *  Government do not and cannot warrant the performance or results that
16 | *  may be obtained by using this software or data. The NLM and the U.S.
17 | *  Government disclaim all warranties, express or implied, including
18 | *  warranties of performance, merchantability or fitness for any 
19 | *  particular purpose.
20 | *
21 | *  Please cite the authors in any work or product based on this material:
22 | *
23 | *  1. Wei CH, Peng Y, Leaman R, Davis AP, Mattingly CJ, Li J, Wiegers TC, 
24 | *     Lu Z. Overview of the BioCreative V Chemical Disease Relation (CDR) 
25 | *     Task, Proceedings of the Fifth BioCreative Challenge Evaluation Workshop, 
26 | *     p154-166, 2015 
27 | *
28 | *  2. Li J, Sun Y, Johnson RJ, Sciaky D, Wei CH, Leaman R, Davis AP, Mattingly CJ,
29 | *     Wiegers TC, Lu Z. Annotating chemicals, diseases and their interactions in 
30 | *     biomedical literature, Proceedings of the Fifth BioCreative Challenge 
31 | *     Evaluation Workshop, p173-182, 2015 
32 | *
33 | *  3. Leaman R, Dogan RI, Lu Z. DNorm: disease name normalization with pairwise 
34 | *     learning to rank, Bioinformatics 29(22):2909-17, 2013
35 | * 
36 | *  4. Leaman R, Wei CH, Lu Z. tmChem: a high performance approach for chemical 
37 | *     named entity recognition and normalization. J Cheminform, 7:S3, 2015
38 | *
39 | *
40 | ==========================================================================
41 | 
42 | This directory contains the annotated corpus created and used in the BioCreative 
43 | V Chemical Disease Relation (CDR) Challenge Task [1]. In addition, it contains 
44 | text-mined results of two computational tools for disease & chemical NER. All data
45 | are made available in both BioC XML and PubTator text formats. 
46 | 
47 | ./CDR.Corpus: The annotated CDR corpus of 1500 PubMed articles of chemicals, 
48 | diseases, and chemical-induced disease relationships [2]. 
49 | 
50 | ./DNorm.TestSet: The text-mined results of diseases on the test set using DNorm [3].
51 | 		 The normalization performance is 0.81 (P), 0.80 (R), and 0.81 (F). 
52 | 
53 | ./tmChem.TestSet: The text-mined results of chemicals on the test set using tmChem [4].
54 | 		  The normalization performance is 0.92 (P), 0.90 (R), and 0.91 (F). 
55 | 
56 | 


--------------------------------------------------------------------------------
/examples/DC-HoC/README.md:
--------------------------------------------------------------------------------
 1 | # Document Classification on HoC
 2 | 
 3 | ## Data
 4 | You can process the data by:
 5 | ``` bash
 6 | bash preprocess.sh
 7 | ```
 8 | 
 9 | ## Training
10 | You can fine-tune the pre-trained BioGPT on the task by:
11 | ``` bash
12 | bash train.sh
13 | ```
14 | 
15 | ## Model Checkpoint
16 | We provide our fine-tuned model on the task. See [here](../../README.md#pre-trained-models)
17 | 
18 | ## Inference and Evaluating
19 | You can run inference and evaluate the model on the test set by:
20 | ``` bash
21 | bash infer.sh
22 | ```


--------------------------------------------------------------------------------
/examples/DC-HoC/hard_match_evaluation.py:
--------------------------------------------------------------------------------
 1 | # coding: utf-8
 2 | # Copyright (c) Microsoft Corporation.
 3 | # Licensed under the MIT License.
 4 | 
 5 | from ast import Global
 6 | import os
 7 | import sys
 8 | from sklearn.metrics import f1_score
 9 | from sklearn.preprocessing import MultiLabelBinarizer
10 | 
11 | pred_file = sys.argv[1]
12 | gold_file = sys.argv[2]
13 | 
14 | 
def convert_hoc_labels(lines):
    """Binarize '|'-separated label strings into a multi-label indicator matrix.

    Each input line is split on '|' and every piece is stripped; the pieces
    are then encoded by ``MultiLabelBinarizer`` against a fixed class list.
    """
    classes = [
        'tumor promoting inflammation',
        'inducing angiogenesis',
        'evading growth suppressors',
        'resisting cell death',
        'cellular energetics',
        'empty',
        'genomic instability and mutation',
        'sustaining proliferative signaling',
        'avoiding immune destruction',
        'activating invasion and metastasis',
        'enabling replicative immortality',
    ]
    labels = [
        [piece.strip() for piece in line.strip().split('|')]
        for line in lines
    ]
    binarizer = MultiLabelBinarizer(classes=classes)
    return binarizer.fit_transform(labels)
21 | 
def do_eval(preds, golden):
    """Print the micro-averaged F1 of predicted versus gold label strings."""
    y_pred = convert_hoc_labels(preds)
    y_true = convert_hoc_labels(golden)
    print(f1_score(y_true, y_pred, average='micro'))
28 | 
29 | 
def main():
    """Load predictions and gold labels, check they align, and print the score.

    Predictions are read one per line from ``pred_file``. Gold labels are the
    last tab-separated column of every non-blank line of ``gold_file``.
    """
    with open(pred_file) as reader:
        preds = [line.strip() for line in reader]

    golden = []
    with open(gold_file) as reader:
        for raw in reader:
            row = raw.strip()
            if not row:
                continue
            golden.append(row.split('\t')[-1])

    assert len(preds) == len(golden), f"{len(preds)} {len(golden)}"

    print("\n====File: ", os.path.basename(pred_file))
    do_eval(preds, golden)
47 | 
48 | 
49 | if __name__ == "__main__":
50 |     main()
51 | 


--------------------------------------------------------------------------------
/examples/DC-HoC/infer.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | MODEL_DIR=../../checkpoints/DC-HoC-BioGPT
 5 | MODEL=checkpoint_last.pt
 6 | DATA_DIR=${PWD}/../../data/HoC/ansis-bin
 7 | BASE_DATA_DIR=${DATA_DIR%/*}
 8 | BIN_DATA_DIR=${DATA_DIR##*/}
 9 | DATA_PREFIX=${BIN_DATA_DIR%-*}
10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw
11 | OUTPUT_FILE=generate_${MODEL}
12 | 
13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
15 | GOLD_FILE=${RAW_DATA_DIR}/test.tsv
16 | 
17 | # inference
18 | if [ ! -f "$OUTPUT_FILE" ]; then
19 |     echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
20 |     python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE}
21 | fi
22 | 
23 | # debpe
24 | sed -i "s/@@ //g" ${OUTPUT_FILE}
25 | # detok
26 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok
27 | # postprocess
28 | python postprocess.py ${OUTPUT_FILE}.detok
29 | # eval
30 | python hard_match_evaluation.py ${OUTPUT_FILE}.detok.extracted.txt ${GOLD_FILE}
31 | 


--------------------------------------------------------------------------------
/examples/DC-HoC/postprocess.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import sys
 5 | import re
 6 | 
 7 | 
 8 | out_file = sys.argv[1]
 9 | 
# Regular expressions tried in order; only the first one that matches is used.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
]


def strip_prefix(line):
    """Return the text after the first matching prefix pattern.

    The patterns in ``prefix`` are tried in order. For the first one found
    anywhere in ``line``, the line is split on it and the stripped final
    piece is returned. The line is returned unchanged when nothing matches.
    """
    for pattern in prefix:
        if re.search(pattern, line) is None:
            continue
        return re.split(pattern, line)[-1].strip()
    return line
25 | 
26 | 
def convert_ansis_sentence(sentence):
    """Extract the text following "the type of this document is".

    Returns the stripped remainder of the line after that phrase, or ``None``
    when the phrase does not occur in ``sentence``.
    """
    match = re.search(r"the type of this document is(.*)", sentence)
    if match is None:
        return None
    return match.group(1).strip()
34 | 
35 | 
# Read the generated sentences, dropping one trailing period from each.
all_lines = []
with open(out_file, "r", encoding="utf8") as fr:
    for raw in fr:
        text = raw.strip()
        all_lines.append(text[:-1] if text.endswith(".") else text)


# Extract one answer per sentence; unparseable sentences become "failed".
hypothesis = []
fail_cnt = 0
for idx, line in enumerate(all_lines, start=1):
    ans = convert_ansis_sentence(strip_prefix(line))
    if ans is None:
        ans = "failed"
        fail_cnt += 1
        print("Failed:id:{}, line:{}".format(idx, line))
    hypothesis.append(ans)
cnt = len(all_lines)


# Write the answers next to the input file, one per line.
with open(f"{out_file}.extracted.txt", "w", encoding="utf8") as fw:
    fw.writelines(f"{eg}\n" for eg in hypothesis)


print(f"failed = {fail_cnt}, total = {cnt}")
69 | 


--------------------------------------------------------------------------------
/examples/DC-HoC/preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | DATA_DIR=../../data/HoC
 5 | prefix=ansis
 6 | RAW_DATA_DIR=${DATA_DIR}/raw
 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin
 8 | 
 9 | if [ -d "${OUTPUT_DIR}" ]; then
10 |     rm -rf ${OUTPUT_DIR}
11 | fi
12 | 
13 | python rebuild_data.py ${RAW_DATA_DIR}
14 | 
15 | cp ${DATA_DIR}/../dict.txt ${RAW_DATA_DIR}/
16 | cp ${DATA_DIR}/../bpecodes ${RAW_DATA_DIR}/
17 | 
18 | SPLIT=(train valid test)
19 | 
20 | for ff in ${SPLIT[@]}; do
21 |     if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
22 |         echo "Preprocessing ${ff}"
23 | 
24 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x
25 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
26 | 
27 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/bpecodes
28 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/bpecodes
29 | 
30 |         rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
31 |     fi
32 | done
33 | 
34 | # do binarize
35 | fairseq-preprocess \
36 |     -s x -t y --workers 8 \
37 |     --joined-dictionary \
38 |     --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \
39 |     --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \
40 |     --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \
41 |     --destdir ${OUTPUT_DIR} \
42 |     --srcdict ${RAW_DATA_DIR}/dict.txt


--------------------------------------------------------------------------------
/examples/DC-HoC/rebuild_data.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | 
 7 | data_dir=sys.argv[1]
 8 | 
 9 | 
def build_target_seq(tgt):
    """Wrap a label string in the fixed target sentence template."""
    return 'the type of this document is ' + tgt + '.'
13 | 
14 | 
def loader(fname, fn):
    """Read a tab-separated file into ``[source, target]`` pairs.

    The first column is the source text and the second the raw target. The
    source gets a trailing period if it lacks one, and the target is passed
    through ``fn``. Blank lines are skipped. A summary line with the number
    of processed samples is printed.

    Args:
        fname: Path of the TSV file to read.
        fn: Callable applied to the stripped target column.

    Returns:
        list[list[str]]: One ``[source, fn(target)]`` entry per data line.

    Raises:
        IndexError: If a non-blank line has no tab-separated second column.
    """
    ret = []
    cnt = 0
    # The context manager guarantees the handle is closed, even on errors.
    with open(fname) as file:
        for line in file:
            # Skip whitespace-only lines too, not just a bare newline.
            if not line.strip():
                continue
            cnt += 1
            sent = line.split('\t')
            source, target = sent[0].strip(), sent[1].strip()
            # endswith() is safe on an empty source, where source[-1] raised.
            if not source.endswith('.'):
                source += '.'
            ret.append([source, fn(target)])

    print(f"{cnt} samples in {fname} has been processed")
    return ret
34 | 
35 | 
def dumper(content_list, prefix):
    """Write sources to ``prefix + ".x"`` and targets to ``prefix + ".y"``.

    Args:
        content_list: Iterable of items whose element 0 is the source text
            and element 1 the target text.
        prefix: Output path stem; the two suffixes are appended to it.
    """
    # Both handles are closed by the context manager even if a write fails.
    with open(prefix + ".x", "w") as fw_source, open(prefix + ".y", "w") as fw_target:
        for ele in content_list:
            print(ele[0], file=fw_source)
            print(ele[1], file=fw_target)
46 | 
47 | 
def worker(fname, prefix, fn):
    """Load ``fname`` with ``loader`` and write the result with ``dumper``."""
    dumper(loader(fname, fn), prefix)
51 | 
52 | 
# Convert each split's TSV file into ansis_<split>.x / ansis_<split>.y.
for split in ['train', 'valid', 'test']:
    tsv_path = os.path.join(f"{data_dir}", f"{split}.tsv")
    out_prefix = os.path.join(f"{data_dir}", f"ansis_{split}")
    worker(tsv_path, out_prefix, build_target_seq)


--------------------------------------------------------------------------------
/examples/DC-HoC/train.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | SAVE_DIR=../../checkpoints/DC-HoC-BioGPT
 5 | mkdir -p ${SAVE_DIR}
 6 | 
 7 | fairseq-train \
 8 |     ../../data/HoC/ansis-bin --save-dir ${SAVE_DIR} \
 9 |     --user-dir ../../src \
10 |     --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \
11 |     --task language_modeling_prompt \
12 |     --arch transformer_lm_prompt_biogpt \
13 |     --share-decoder-input-output-embed --decoder-learned-pos \
14 |     --optimizer adam --adam-betas '(0.9, 0.98)' \
15 |     --weight-decay 0.01 --clip-norm 0.0 \
16 |     --lr 1e-5 --lr-scheduler inverse_sqrt --warmup-updates 1000 --warmup-init-lr 1e-07 \
17 |     --tokens-per-sample 1024 --max-source-positions 900 --max-target-positions 1024 \
18 |     --max-tokens 1024 --update-freq 32 \
19 |     --skip-invalid-size-inputs-valid-test \
20 |     --max-update 20000 --save-interval-updates 1000 --no-epoch-checkpoints \
21 |     --learned-prompt 1


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/README.md:
--------------------------------------------------------------------------------
 1 | # Question Answering on PubMedQA in Reasoning Required Setting
 2 | 
 3 | ## Data
 4 | Download the data from [PubMedQA](https://github.com/pubmedqa/pubmedqa) and follow its steps for splitting the dataset.
 5 | 
 6 | Copy the files `pqal_fold0/train_set.json`, `pqal_fold0/dev_set.json`, `test_set.json` and `test_ground_truth.json` to `../../data/PubMedQA/raw`
 7 | 
 8 | Then, you can process the data by:
 9 | ``` bash
10 | bash preprocess.sh # for BioGPT
11 | ```
12 | or 
13 | ``` bash
14 | bash preprocess_large.sh # for BioGPT-Large
15 | ```
16 | 
17 | 
18 | ## Model Checkpoint
19 | We provide our fine-tuned model on the task. See [here](../../README.md#pre-trained-models)
20 | 
21 | ## Inference and Evaluating
22 | You can run inference and evaluate the model on the test set by:
23 | ``` bash
24 | bash infer.sh # for BioGPT
25 | ```
26 | or 
27 | ``` bash
28 | bash infer_large.sh # for BioGPT-Large
29 | ```


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/hard_match_evaluation.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | import json
 7 | from sklearn.metrics import accuracy_score
 8 | 
 9 | pred_file = sys.argv[1]
10 | gold_file = sys.argv[2]
11 | 
12 | 
def do_eval(preds, golden):
    """Print the accuracy of ``preds`` measured against ``golden``."""
    accuracy = accuracy_score(golden, preds)
    print(accuracy)
16 | 
17 | 
def main():
    """Load predictions and gold labels, check they align, and print the score.

    Predictions are read one per line from ``pred_file``. For a ``.tsv`` gold
    file the label is the last tab-separated column of each non-blank line;
    for a ``.json`` gold file the labels are the values of the loaded mapping,
    in file order.
    """
    with open(pred_file) as reader:
        preds = [line.strip() for line in reader]

    golden = []
    if gold_file.endswith('.tsv'):
        with open(gold_file) as reader:
            rows = (line.strip() for line in reader)
            golden = [row.split('\t')[-1] for row in rows if row]
    elif gold_file.endswith('.json'):
        with open(gold_file) as reader:
            golden = list(json.load(reader).values())
    assert len(preds) == len(golden), f"{len(preds)} {len(golden)}"

    print("\n====File: ", os.path.basename(pred_file))
    do_eval(preds, golden)
39 | 
40 | 
41 | if __name__ == "__main__":
42 |     main()
43 | 


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/infer.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | MODEL_DIR=../../checkpoints/QA-PubMedQA-BioGPT
 5 | MODEL=checkpoint.pt
 6 | DATA_DIR=${PWD}/../../data/PubMedQA/pqal_qcl_ansis-bin
 7 | BASE_DATA_DIR=${DATA_DIR%/*}
 8 | BIN_DATA_DIR=${DATA_DIR##*/}
 9 | DATA_PREFIX=${BIN_DATA_DIR%-*}
10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw
11 | OUTPUT_FILE=generate_${MODEL}
12 | 
13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
15 | GOLD_FILE=${RAW_DATA_DIR}/test_ground_truth.json
16 | 
17 | # inference
18 | if [ ! -f "$OUTPUT_FILE" ]; then
19 |     echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
20 |     python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE}
21 | fi
22 | 
23 | # debpe
24 | sed -i "s/@@ //g" ${OUTPUT_FILE}
25 | # detok
26 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok
27 | # postprocess
28 | python postprocess.py ${OUTPUT_FILE}.detok
29 | # eval
30 | python hard_match_evaluation.py ${OUTPUT_FILE}.detok.extracted.txt ${GOLD_FILE}
31 | 


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/infer_large.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | MODEL_DIR=../../checkpoints/QA-PubMedQA-BioGPT-Large
 5 | MODEL=checkpoint.pt
 6 | DATA_DIR=${PWD}/../../data/PubMedQA/biogpt-large-pqal_qcl_ansis-bin
 7 | BASE_DATA_DIR=${DATA_DIR%/*}
 8 | BIN_DATA_DIR=${DATA_DIR##*/}
 9 | DATA_PREFIX=${BIN_DATA_DIR%-*}
10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw
11 | OUTPUT_FILE=generate_${MODEL}
12 | 
13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
15 | GOLD_FILE=${RAW_DATA_DIR}/test_ground_truth.json
16 | 
17 | # inference
18 | if [ ! -f "$OUTPUT_FILE" ]; then
19 |     echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
20 |     python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE}
21 | fi
22 | 
23 | # debpe
24 | sed -i "s/@@ //g" ${OUTPUT_FILE}
25 | # detok
26 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok
27 | # postprocess
28 | python postprocess.py ${OUTPUT_FILE}.detok
29 | # eval
30 | python hard_match_evaluation.py ${OUTPUT_FILE}.detok.extracted.txt ${GOLD_FILE}
31 | 


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/postprocess.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import sys
 5 | import re
 6 | 
 7 | 
 8 | out_file = sys.argv[1]
 9 | 
# Regular expressions tried in order; only the first one that matches is used.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
]


def strip_prefix(line):
    """Return the text after the first matching prefix pattern.

    The patterns in ``prefix`` are tried in order. For the first one found
    anywhere in ``line``, the line is split on it and the stripped final
    piece is returned. The line is returned unchanged when nothing matches.
    """
    for pattern in prefix:
        if re.search(pattern, line) is None:
            continue
        return re.split(pattern, line)[-1].strip()
    return line
25 | 
26 | 
def convert_relis_sentence(sentence):
    """Extract the text following "the answer to the question given the context is".

    Returns the stripped remainder of the line after that phrase, or ``None``
    when the phrase does not occur in ``sentence``.
    """
    match = re.search(r"the answer to the question given the context is(.*)", sentence)
    if match is None:
        return None
    return match.group(1).strip()
34 | 
35 | 
# Read the generated sentences, dropping one trailing period from each.
all_lines = []
with open(out_file, "r", encoding="utf8") as fr:
    for raw in fr:
        text = raw.strip()
        all_lines.append(text[:-1] if text.endswith(".") else text)


# Extract one answer per sentence; unparseable sentences become "failed".
hypothesis = []
fail_cnt = 0
for idx, line in enumerate(all_lines, start=1):
    ans = convert_relis_sentence(strip_prefix(line))
    if ans is None:
        ans = "failed"
        fail_cnt += 1
        print("Failed:id:{}, line:{}".format(idx, line))
    hypothesis.append(ans)
cnt = len(all_lines)


# Write the answers next to the input file, one per line.
with open(f"{out_file}.extracted.txt", "w", encoding="utf8") as fw:
    fw.writelines(f"{eg}\n" for eg in hypothesis)


print(f"failed = {fail_cnt}, total = {cnt}")
69 | 


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | DATA_DIR=../../data/PubMedQA
 5 | prefix=pqal_qcl_ansis
 6 | RAW_DATA_DIR=${DATA_DIR}/raw
 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin
 8 | 
 9 | if [ -d "${OUTPUT_DIR}" ]; then
10 |     rm -rf ${OUTPUT_DIR}
11 | fi
12 | 
13 | python rebuild_data.py ${RAW_DATA_DIR} ${prefix}
14 | 
15 | cp ${DATA_DIR}/../dict.txt ${RAW_DATA_DIR}/
16 | cp ${DATA_DIR}/../bpecodes ${RAW_DATA_DIR}/
17 | 
18 | SPLIT=(train valid test)
19 | 
20 | for ff in ${SPLIT[@]}; do
21 |     if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
22 |         echo "Preprocessing ${ff}"
23 | 
24 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x
25 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
26 | 
27 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/bpecodes
28 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/bpecodes
29 | 
30 |         rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
31 |     fi
32 | done
33 | 
34 | # do binarize
35 | fairseq-preprocess \
36 |     -s x -t y --workers 8 \
37 |     --joined-dictionary \
38 |     --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \
39 |     --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \
40 |     --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \
41 |     --destdir ${OUTPUT_DIR} \
42 |     --srcdict ${RAW_DATA_DIR}/dict.txt


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/preprocess_large.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | DATA_DIR=../../data/PubMedQA
 5 | prefix=biogpt-large-pqal_qcl_ansis
 6 | RAW_DATA_DIR=${DATA_DIR}/raw
 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin
 8 | 
 9 | if [ -d "${OUTPUT_DIR}" ]; then
10 |     rm -rf ${OUTPUT_DIR}
11 | fi
12 | 
13 | python rebuild_data.py ${RAW_DATA_DIR} ${prefix}
14 | 
15 | cp ${DATA_DIR}/../biogpt_large_dict.txt ${RAW_DATA_DIR}/
16 | cp ${DATA_DIR}/../biogpt_large_bpecodes ${RAW_DATA_DIR}/
17 | 
18 | SPLIT=(train valid test)
19 | 
20 | for ff in ${SPLIT[@]}; do
21 |     if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
22 |         echo "Preprocessing ${ff}"
23 | 
24 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x
25 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
26 | 
27 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/biogpt_large_bpecodes
28 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/biogpt_large_bpecodes
29 | 
30 |         rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
31 |     fi
32 | done
33 | 
34 | # do binarize
35 | fairseq-preprocess \
36 |     -s x -t y --workers 8 \
37 |     --joined-dictionary \
38 |     --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \
39 |     --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \
40 |     --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \
41 |     --destdir ${OUTPUT_DIR} \
42 |     --srcdict ${RAW_DATA_DIR}/biogpt_large_dict.txt


--------------------------------------------------------------------------------
/examples/QA-PubMedQA/rebuild_data.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | import re
 7 | import json
 8 | 
 9 | data_dir=sys.argv[1]
10 | prefix=sys.argv[2]
11 | 
12 | 
def build_source_seq(question, context, long_answer=None):
    """Render the model input for one sample.

    The result is "question: <q> context: <c> " and, when a non-empty
    long_answer is supplied, "answer: <a>" is appended. Every field is
    stripped of surrounding whitespace first.
    """
    head = "question: {} context: {} ".format(question.strip(), context.strip())
    if not long_answer:
        return head
    return head + "answer: {}".format(long_answer.strip())
19 | 
20 | 
def build_target_seq(tgt):
    """Wrap a label in the fixed answer sentence used as the target sequence."""
    return ''.join(('the answer to the question given the context is ', tgt, '.'))
24 | 
25 | 
def loader(fname, fn, required_long_answer=False):
    """Read a JSON file of samples and build [source, target] text pairs.

    Args:
        fname: Path to a JSON file mapping an id to a record. Each record must
            provide 'QUESTION' and 'CONTEXTS' (an iterable of strings); it may
            provide 'final_decision'. 'LONG_ANSWER' is only read when
            required_long_answer is true.
        fn: Callable applied to the 'final_decision' label. It may return one
            target string or a list of target strings.
        required_long_answer: When true, the record's 'LONG_ANSWER' is
            appended to the source sequence (KeyError if it is absent).

    Returns:
        A list of [source, target] pairs. A record whose fn result is a list
        yields one pair per element; a record without 'final_decision' gets an
        empty target.
    """
    ret = []
    cnt = 0

    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(fname, 'r', encoding='utf8') as file:
        data = json.load(file)

    for pmid, content in data.items():
        cnt += 1
        question = content['QUESTION']
        context = ' '.join(sen.strip() for sen in content['CONTEXTS'])
        # Collapse newlines and repeated whitespace into single spaces.
        context = re.sub(r'\s+', ' ', context)
        if required_long_answer:
            source = build_source_seq(question, context, content['LONG_ANSWER'])
        else:
            # LONG_ANSWER is not needed here, so records lacking it still load.
            source = build_source_seq(question, context)

        if 'final_decision' in content:
            target = fn(content['final_decision'])
        else:
            target = ''

        if isinstance(target, list):
            ret.extend([source, one_target] for one_target in target)
        else:
            ret.append([source, target])

    print(f"{cnt} samples in {fname} has been processed")
    return ret
61 | 
62 | 
def dumper(content_list, prefix):
    """Write pairs to parallel text files, one sample per line.

    Args:
        content_list: Iterable of (source, target) pairs.
        prefix: Output path prefix; sources go to prefix + ".x" and targets to
            prefix + ".y".
    """
    # Context managers close both files even if a write fails; UTF-8 keeps the
    # output independent of the platform default encoding.
    with open(prefix + ".x", "w", encoding="utf8") as fw_source, \
            open(prefix + ".y", "w", encoding="utf8") as fw_target:
        for ele in content_list:
            print(ele[0], file=fw_source)
            print(ele[1], file=fw_target)
73 | 
74 | 
def worker(fname, prefix, fn):
    """Load the samples in fname with target builder fn and dump them under prefix."""
    dumper(loader(fname, fn), prefix)
78 | 
79 | 
# Convert the train/dev/test JSON files into <prefix>_{train,valid,test}.x/.y.
for json_name, split_name in (
    ("train_set.json", "train"),
    ("dev_set.json", "valid"),
    ("test_set.json", "test"),
):
    worker(
        os.path.join(f"{data_dir}", json_name),
        os.path.join(f"{data_dir}", f"{prefix}_{split_name}"),
        build_target_seq,
    )


--------------------------------------------------------------------------------
/examples/RE-BC5CDR/README.md:
--------------------------------------------------------------------------------
 1 | # Relation Extraction on BC5CDR
 2 | 
 3 | ## Data
 4 | You can process the data by:
 5 | ``` bash
 6 | bash preprocess.sh
 7 | ```
 8 | 
 9 | ## Training
10 | You can fine-tune the pre-trained BioGPT on the task by:
11 | ``` bash
12 | bash train.sh
13 | ```
14 | 
15 | ## Model Checkpoint
16 | We provide our fine-tuned model on the task. See [here](../../README.md#pre-trained-models)
17 | 
18 | ## Inference and Evaluating
19 | You can run inference and evaluate the model on the test set by:
20 | ``` bash
21 | bash infer.sh
22 | ```


--------------------------------------------------------------------------------
/examples/RE-BC5CDR/infer.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | MODEL_DIR=../../checkpoints/RE-BC5CDR-BioGPT
 5 | MODEL=checkpoint_avg.pt
 6 | DATA_DIR=${PWD}/../../data/BC5CDR/relis-bin
 7 | BASE_DATA_DIR=${DATA_DIR%/*}
 8 | BIN_DATA_DIR=${DATA_DIR##*/}
 9 | DATA_PREFIX=${BIN_DATA_DIR%-*}
10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw
11 | OUTPUT_FILE=generate_${MODEL}
12 | 
13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
15 | GOLD_FILE=${RAW_DATA_DIR}/CDR_Data/CDR.Corpus.v010516/CDR_TestSet.PubTator.txt
16 | ENTITY_FILE=${RAW_DATA_DIR}/test.entities.json
17 | PMID_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.pmid
18 | 
19 | # average checkpoints
20 | if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
21 |     python ../../scripts/average_checkpoints.py --inputs=${MODEL_DIR} --output=${MODEL_DIR}/${MODEL} --num-epoch-checkpoints=5
22 | fi
23 | 
24 | 
25 | # inference
26 | if [ ! -f "$OUTPUT_FILE" ]; then
27 |     echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
28 |     python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE}
29 | fi
30 | 
31 | # debpe
32 | sed -i "s/@@ //g" ${OUTPUT_FILE}
33 | # detok
34 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok
35 | # postprocess
36 | python postprocess.py ${OUTPUT_FILE}.detok ${ENTITY_FILE} ${PMID_FILE}
37 | # eval
38 | cd ${RAW_DATA_DIR}/BC5CDR_Evaluation-0.0.3
39 | bash eval_relation.sh PubTator ${OLDPWD}/${GOLD_FILE} ${OLDPWD}/${OUTPUT_FILE}.detok.extracted.PubTator
40 | cd ${OLDPWD}


--------------------------------------------------------------------------------
/examples/RE-BC5CDR/postprocess.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | import re
 7 | import json
 8 | 
 9 | 
# Command-line arguments: detokenized model output, entity-name-to-id JSON,
# and the file of PMIDs aligned line by line with the model output.
out_file, entity_file, pmids_file = sys.argv[1], sys.argv[2], sys.argv[3]

# Regex patterns for lead-in text that may precede the generated relation
# sentences; they are tried in this order by strip_prefix.
prefix = [
    '(learned[0-9]+ )+',
    'in conclusion ,',
    'we can conclude that',
    'we have that',
]
20 | 
21 | 
def strip_prefix(line):
    """Remove the lead-in from a generated line.

    The first pattern in `prefix` that occurs in the line is used: everything
    up to and including its last occurrence is dropped and the remainder is
    stripped. Lines matching no pattern are returned unchanged.
    """
    for pattern in prefix:
        if re.search(pattern, line) is None:
            continue
        return re.split(pattern, line)[-1].strip()
    return line
29 | 
30 | 
def split_sentence(line):
    """Split a generated line into its ';'-separated sentences."""
    return re.split(r";", line)
34 | 
35 | 
def convert_relis_sentence(sentence):
    """Parse "the relation between <chemical> and <disease> exists".

    Returns a (chemical, disease) tuple of stripped names, or None when the
    stripped sentence does not start with that template.
    """
    matched = re.match(r"the relation between (.*) and (.*) exists", sentence.strip())
    if matched is None:
        return None
    first, second = matched.groups()
    return (first.strip(), second.strip())
45 | 
46 | 
# Load the generated lines, dropping one trailing period from each.
all_lines = []
with open(out_file, "r", encoding="utf8") as fr:
    for line in fr:
        e = line.strip()
        # endswith() is safe on blank lines, where e[-1] would raise IndexError.
        all_lines.append(e[:-1] if e.endswith(".") else e)
# Mapping with 'chemical2id' and 'disease2id' name-to-identifier tables.
with open(entity_file, "r", encoding="utf8") as fr:
    ent2id = json.load(fr)
# One PMID per generated line, either as a JSON list or as a plain text file.
with open(pmids_file, "r") as reader:
    if '.json' in pmids_file:
        pmids = json.load(reader)
    else:
        pmids = [line.strip() for line in reader]


hypothesis = []
cnt = 0
fail_cnt = 0
for i, line in enumerate(all_lines):
    cnt += 1
    strip_line = strip_prefix(line)
    ret = []
    for sen in split_sentence(strip_line):
        ans = convert_relis_sentence(sen)
        if ans is None:
            continue
        chemical, disease = ans
        # Names missing from the entity tables are emitted with the id "-1".
        chemicalID = ent2id['chemical2id'].get(chemical.strip(), "-1")
        diseaseID = ent2id['disease2id'].get(disease.strip(), "-1")
        ret.append(f"{pmids[i]}\tCID\t{chemicalID}\t{diseaseID}\t1.0")
    if ret:
        hypothesis.extend(ret)
    else:
        # No sentence of this line matched the relation template.
        fail_cnt += 1
        print("Failed:id:{}, line:{}".format(i+1, line))


# One tab-separated CID relation line per extracted relation.
with open(f"{out_file}.extracted.PubTator", "w", encoding="utf8") as fw:
    for line in hypothesis:
        print(line, file=fw)


print(f"failed = {fail_cnt}, total = {cnt}")


--------------------------------------------------------------------------------
/examples/RE-BC5CDR/preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | DATA_DIR=../../data/BC5CDR
 5 | prefix=relis
 6 | RAW_DATA_DIR=${DATA_DIR}/raw
 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin
 8 | 
 9 | if [ -d "${OUTPUT_DIR}" ]; then
10 |     rm -rf ${OUTPUT_DIR}
11 | fi
12 | 
13 | python rebuild_data.py ${RAW_DATA_DIR}
14 | 
15 | cp ${DATA_DIR}/../dict.txt ${RAW_DATA_DIR}/
16 | cp ${DATA_DIR}/../bpecodes ${RAW_DATA_DIR}/
17 | 
18 | SPLIT=(train valid test)
19 | 
20 | for ff in ${SPLIT[@]}; do
21 |     if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
22 |         echo "Preprocessing ${ff}"
23 | 
24 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x
25 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
26 | 
27 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/bpecodes
28 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/bpecodes
29 | 
30 |         rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
31 |     fi
32 | done
33 | 
34 | # do binarize
35 | fairseq-preprocess \
36 |     -s x -t y --workers 8 \
37 |     --joined-dictionary \
38 |     --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \
39 |     --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \
40 |     --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \
41 |     --destdir ${OUTPUT_DIR} \
42 |     --srcdict ${RAW_DATA_DIR}/dict.txt


--------------------------------------------------------------------------------
/examples/RE-BC5CDR/rebuild_data.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import os
  5 | import sys
  6 | import json
  7 | import re
  8 | 
# First command-line argument: directory that holds {train,valid,test}.json
# and receives the generated relis_<split>.{pmid,x,y} files.
data_dir=sys.argv[1]
 10 | 
 11 | 
 12 | def unify_ent2id(ent2id, method='max'):
 13 |     id2ent = {}
 14 |     for k, v in ent2id.items():
 15 |         if v in id2ent:
 16 |             if method == 'min':
 17 |                 id2ent[v] = k if len(k) < len(id2ent[v]) else id2ent[v]
 18 |             else:
 19 |                 id2ent[v] = k if len(k) > len(id2ent[v]) else id2ent[v]
 20 |         else:
 21 |             id2ent[v] = k
 22 |     ent2id = {v:k for k, v in id2ent.items()}
 23 |     return ent2id, id2ent
 24 |     
 25 | 
 26 | def sort_triples(triples, text):
 27 |     sorted_triples = sorted(triples, key=lambda x:text.find(x['chemical']))
 28 |     return sorted_triples
 29 | 
 30 | 
 31 | def build_target_seq_svo(relations, id2chem, id2disease):        
 32 |     answer = ""
 33 |     for z in relations:
 34 |         chemical = id2chem[z["chemical"]]
 35 |         disease = id2disease[z["disease"]]
 36 |         answer += f"{chemical} correlates with {disease}; "
 37 |     return answer[:-2] + "."
 38 | 
 39 | 
 40 | def build_target_seq_relis(relations, id2chem, id2disease):        
 41 |     answer = ""
 42 |     for z in relations:
 43 |         chemical = id2chem[z["chemical"]]
 44 |         disease = id2disease[z["disease"]]
 45 |         answer += f"the relation between {chemical} and {disease} exists; "
 46 |     return answer[:-2] + "."
 47 | 
 48 | 
 49 | def loader(fname, fn):
 50 |     ret = []
 51 |     null_cnt = 0
 52 |     suc_cnt = 0
 53 |     null_flag = False
 54 |     with open(fname, "r", encoding="utf8") as fr:
 55 |         data = json.load(fr)
 56 |     for pmid, v in data.items():
 57 |         if re.search(r"\W
quot;, v["title"]):
 58 |             content = v["title"] + " " + v["abstract"]
 59 |         else:
 60 |             content = v["title"] + ". " + v["abstract"]
 61 | 
 62 |         content = content.lower()
 63 |         
 64 |         if v["relations"] is None or len(v["relations"]) == 0:
 65 |             if not null_flag:
 66 |                 print(f"Following PMID in {fname} has no extracted relations:")
 67 |                 null_flag = True
 68 |             print(f"{pmid} ", end="")
 69 |             null_cnt += 1
 70 |         else:
 71 |             chemical2id = v["chemical2id"]
 72 |             disease2id = v["disease2id"]
 73 |             unified_chemical2id, id2chemical = unify_ent2id(chemical2id, method='max')
 74 |             unified_disease2id, id2disease = unify_ent2id(disease2id, method='max')
 75 |             answer = fn(v["relations"], id2chemical, id2disease)
 76 |             ret.append((pmid, content, answer))
 77 |             suc_cnt += 1
 78 |     if null_flag:
 79 |         print("")
 80 |     print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no relations extracted.")
 81 |     return ret
 82 | 
 83 | def dumper(content_list, prefix):
 84 |     fw_pmid = open(prefix + ".pmid", "w")
 85 |     fw_content = open(prefix + ".x", "w")
 86 |     fw_label = open(prefix + ".y", "w")
 87 |     
 88 |     for ele in content_list:
 89 |         print(ele[0], file=fw_pmid)
 90 |         print(ele[1], file=fw_content)
 91 |         print(ele[2], file=fw_label)
 92 | 
 93 |     fw_pmid.close()
 94 |     fw_content.close()
 95 |     fw_label.close()
 96 | 
 97 | def worker(fname, prefix, fn):
 98 |     ret = loader(fname, fn)
 99 |     dumper(ret, prefix)
100 |            
101 | 
# Convert every split's JSON file into relis_<split>.{pmid,x,y}.
for split in ['train', 'valid', 'test']:
    worker(
        os.path.join(f"{data_dir}", f"{split}.json"),
        os.path.join(f"{data_dir}", f"relis_{split}"),
        build_target_seq_relis,
    )


--------------------------------------------------------------------------------
/examples/RE-BC5CDR/train.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | SAVE_DIR=../../checkpoints/RE-BC5CDR-BioGPT
 5 | mkdir -p ${SAVE_DIR}
 6 | 
 7 | fairseq-train \
 8 |     ../../data/BC5CDR/relis-bin --save-dir ${SAVE_DIR} \
 9 |     --user-dir ../../src \
10 |     --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \
11 |     --task language_modeling_prompt \
12 |     --arch transformer_lm_prompt_biogpt \
13 |     --share-decoder-input-output-embed --decoder-learned-pos \
14 |     --optimizer adam --adam-betas '(0.9, 0.98)' \
15 |     --weight-decay 0.01 --clip-norm 0.0 \
16 |     --lr 1e-5 --lr-scheduler inverse_sqrt --warmup-updates 100 --warmup-init-lr 1e-07 \
17 |     --tokens-per-sample 1024 --max-source-positions 640 --max-target-positions 1024 \
18 |     --max-tokens 1024 --update-freq 32 \
19 |     --skip-invalid-size-inputs-valid-test \
20 |     --max-epoch 100 --keep-last-epochs 5 \
21 |     --learned-prompt 9


--------------------------------------------------------------------------------
/examples/RE-DDI/README.md:
--------------------------------------------------------------------------------
 1 | # Relation Extraction on DDI
 2 | 
 3 | ## Data
 4 | You can process the data by:
 5 | ``` bash
 6 | bash preprocess.sh
 7 | ```
 8 | 
 9 | ## Training
10 | You can fine-tune the pre-trained BioGPT on the task by:
11 | ``` bash
12 | bash train.sh
13 | ```
14 | 
15 | ## Model Checkpoint
16 | We provide our fine-tuned model on the task. See [here](../../README.md#pre-trained-models)
17 | 
18 | ## Inference and Evaluating
19 | You can run inference and evaluate the model on the test set by:
20 | ``` bash
21 | bash infer.sh
22 | ```


--------------------------------------------------------------------------------
/examples/RE-DDI/hard_match_evaluation.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import re
  5 | import json
  6 | import sys
  7 | import os
  8 | 
  9 | pred_file = sys.argv[1]
 10 | gold_file = sys.argv[2]
 11 | pmids_file = sys.argv[3]
 12 | 
 13 | def normalize_name(s: str):
 14 |     s = s.strip()
 15 | 
 16 |     # normalize roman type id at end of string
 17 |     num2roman = {"0": "0", "1": "I", "2": "II", "3": "III", "4": "IV", "5": "V", "6": "VI", "7": "VII", "8": "VIII", "9": "IX"}
 18 |     if len(s) > 2 and s[-1].isnumeric() and not s[-2].isnumeric() and s[-1] in num2roman:
 19 |         tmps = list(s)
 20 |         s = ''.join(tmps[:-1]) + num2roman[tmps[-1]]
 21 | 
 22 |     # remove useless end string
 23 |     s = s.replace("-type", '')
 24 | 
 25 |     return re.sub('[^a-zA-Z0-9]+', '', s)
 26 | 
 27 | 
 28 | def rm_abbr(tgt_set):
 29 |     """ remove abbreviation in the brackets of entity, eg: aaa (bb) -> aaa """
 30 |     def rm(s):
 31 |         s = s.strip()
 32 |         if "(" in s and s[-1] == ')':  # entity end with a bracketed short cut
 33 |             return normalize_name(s[:s.rfind("(")].strip())
 34 |         else:
 35 |             return normalize_name(s)
 36 | 
 37 |     tgt_set = list(tgt_set)
 38 |     if tgt_set and type(tgt_set[0]) in [tuple, list]:  # process triples
 39 |         return set([(rm(tp[0]), rm(tp[1]), rm(tp[2])) for tp in tgt_set])
 40 |     else:  # process entities
 41 |         return set([rm(e) for e in tgt_set])
 42 | 
 43 | 
 44 | def get_abbr(tgt_set):
 45 |     """ extract abbreviation in the brackets of entity, eg: aaa (bb) -> bb """
 46 |     def rm(s):
 47 |         s = s.strip()
 48 |         if "(" in s and s[-1] == ')':
 49 |             return normalize_name(s[s.rfind("(")+1:-1].strip())
 50 |         else:
 51 |             return normalize_name(s)
 52 | 
 53 |     tgt_set = list(tgt_set)
 54 |     if tgt_set and type(tgt_set[0]) in [tuple, list]:  # process triples
 55 |         return set([(rm(tp[0]), rm(tp[1]), rm(tp[2])) for tp in tgt_set])
 56 |     else:  # process entities
 57 |         return set([rm(e) for e in tgt_set])
 58 | 
 59 | 
 60 | def acc(pred_set, gold_set):
 61 |     """ Multi-label style acc """
 62 |     tp_num = len(pred_set & gold_set)
 63 |     return int(pred_set == gold_set) if len(gold_set) == 0 else 1.0 * tp_num / len(pred_set | gold_set)
 64 | 
 65 | 
 66 | def precision(pred_set, gold_set):
 67 |     """ Multi-label style precision """
 68 |     tp_num = len(pred_set & gold_set)
 69 |     return int(pred_set == gold_set) if len(pred_set) == 0 else 1.0 * tp_num / len(pred_set)
 70 | 
 71 | 
 72 | def recall(pred_set, gold_set):
 73 |     """ Multi-label style recall """
 74 |     tp_num = len(pred_set & gold_set)
 75 |     return int(pred_set == gold_set) if len(gold_set) == 0 else 1.0 * tp_num / len(gold_set)
 76 | 
 77 | 
 78 | def normed_eval(pred_set, gold_set, metric):
 79 |     """ Both body and abbreviation match are considered correct """
 80 |     abbr_pred_set, abbr_gold_set = get_abbr(pred_set), get_abbr(gold_set)
 81 |     rm_pred_set, rm_gold_set = rm_abbr(pred_set), rm_abbr(gold_set)
 82 |     return max(metric(abbr_pred_set, abbr_gold_set), metric(rm_pred_set, rm_gold_set))
 83 | 
 84 | 
 85 | def get_f1(p, r):
 86 |     return 0 if (p + r) == 0 else (2.0 * p * r / (p + r))
 87 | 
 88 | 
 89 | def ave(scores):
 90 |     return 1.0 * sum(scores) / len(scores)
 91 | 
 92 | 
def do_eval(preds, pmids, golden):
    """Score predicted triples against gold triples and print the metrics.

    Args:
        preds: Sequence of dicts; each has 'triple_list_pred', a list of dicts
            with 'subject', 'object' and 'relation'.
        pmids: Sequence of ids paired with preds by position (via zip).
        golden: Dict keyed by id; each value has 'triples' (dicts with 'drug',
            'target', 'interaction'), 'abstract' and optionally 'title'.

    Returns:
        A list with one result dict per scored sample, holding the id, title,
        abstract and, per field and for whole triples, the score together
        with the predicted and gold sets as lists.
    """
    ret = []
    num_pred, num_gold, num_missing = 0, 0, 0
    # Per-sample scores.
    all_f1, p, r, d_acc, t_acc, i_acc = [], [], [], [], [], []
    # Items pooled over all scored samples for the set-level ("Ontology") scores.
    all_pred_triple, all_pred_d, all_pred_t, all_pred_i, all_gold_triple, all_gold_d, all_gold_t, all_gold_i = [], [], [], [], [], [], [], [],

    for pred, idx in zip(preds, pmids):
        gold_d_set, gold_t_set, gold_i_set, gold_set = set(), set(), set(), set()
        pred_d_set, pred_t_set, pred_i_set, pred_set = set(), set(), set(), set()
        
        # A first triple whose subject is 'failed' marks a sample with no
        # usable prediction; its prediction sets stay empty.
        if pred["triple_list_pred"] and pred["triple_list_pred"][0]["subject"] != 'failed':
            for tp in pred["triple_list_pred"]:
                # Compare case-insensitively and ignoring spaces.
                d = tp["subject"].strip().lower().replace(' ', '')
                t = tp["object"].strip().lower().replace(' ', '')
                i = tp["relation"].strip().lower().replace(' ', '')

                pred_d_set.add(d)
                pred_t_set.add(t)
                pred_i_set.add(i)
                pred_set.add((d, t, i))
        # Samples without a gold entry are skipped entirely: they contribute
        # to no score and to no pooled list.
        if idx not in golden:
            num_missing += 1
            # print("----Missing:", idx)
            continue
        if golden[idx]["triples"]:
            for tp in golden[idx]["triples"]:
                d = tp["drug"].strip().lower().replace(' ', '')
                t = tp["target"].strip().lower().replace(' ', '')
                i = tp["interaction"].strip().lower().replace(' ', '')
                gold_d_set.add(d)
                gold_t_set.add(t)
                gold_i_set.add(i)
                gold_set.add((d, t, i))

        # sample level eval
        p.append(normed_eval(pred_set, gold_set, metric=precision))
        r.append(normed_eval(pred_set, gold_set, metric=recall))
        all_f1.append(get_f1(p[-1], r[-1]))
        d_acc.append(normed_eval(pred_d_set, gold_d_set, metric=acc))
        t_acc.append(normed_eval(pred_t_set, gold_t_set, metric=acc))
        i_acc.append(normed_eval(pred_i_set, gold_i_set, metric=acc))

        # onto level eval
        all_pred_d.extend(pred_d_set)
        all_pred_t.extend(pred_t_set)
        all_pred_i.extend(pred_i_set)
        all_pred_triple.extend(pred_set)
        all_gold_d.extend(gold_d_set)
        all_gold_t.extend(gold_t_set)
        all_gold_i.extend(gold_i_set)
        all_gold_triple.extend(gold_set)
        
        # if len(gold_set) < len(golden[idx]["triples"]):
            # print("Duplicate extists, ori", golden[idx]["triples"], gold_set)

        num_pred += len(pred_set)
        num_gold += len(gold_set)

        ret.append({
            "pmid": idx,
            "title": golden[idx]["title"] if "title" in golden[idx] else None,
            "abstract": golden[idx]["abstract"],
            "d_pred_gold": [d_acc[-1], list(pred_d_set), list(gold_d_set)],
            "t_pred_gold": [t_acc[-1], list(pred_t_set), list(gold_t_set)],
            "i_pred_gold": [i_acc[-1], list(pred_i_set), list(gold_i_set)],
            "all_pred_gold": [all_f1[-1], list(pred_set), list(gold_set)],
        })


    # "missing" is reported as the number of predictions that were not scored.
    print("num sample", len(all_f1), "missing", len(preds) - len(all_f1), "num_gold tp", num_gold, "num_pred", num_pred)

    # Note: we adopt multi-label metrics following: http://129.211.169.156/publication/tkde13rev.pdf
    # The last two values are the mean of the per-sample F1 scores and the F1
    # of the mean precision and mean recall, in that order.
    print("Sample: acc d: {:.4f}\tt:{:.4f}\ti: {:.4f}\ntp p: {:.4f}\ttp r: {:.4f}\ttp micro f1: {:.4f}\ttp macro f1: {:.4f} ".format(
        ave(d_acc), ave(t_acc), ave(i_acc), ave(p), ave(r), ave(all_f1), get_f1(ave(p), ave(r))))

    # Ontology evaluation_scripts
    # Precision/recall over the de-duplicated items pooled across all samples.
    all_p, all_r = normed_eval(set(all_pred_triple), set(all_gold_triple), metric=precision), normed_eval(set(all_pred_triple), set(all_gold_triple), metric=recall)
    d_p, d_r = normed_eval(set(all_pred_d), set(all_gold_d), metric=precision), normed_eval(set(all_pred_d), set(all_gold_d), metric=recall)
    t_p, t_r = normed_eval(set(all_pred_t), set(all_gold_t), metric=precision), normed_eval(set(all_pred_t), set(all_gold_t), metric=recall)
    i_p, i_r = normed_eval(set(all_pred_i), set(all_gold_i), metric=precision), normed_eval(set(all_pred_i), set(all_gold_i), metric=recall)

    print("Ontology: f1 d: {:.4f}\tt:{:.4f}\ti: {:.4f}\t \nall p: {:.4f}\tall r: {:.4f}\tonto f1: {:.4f}".format(
        get_f1(d_p, d_r), get_f1(t_p, t_r), get_f1(i_p, i_r), all_p, all_r, get_f1(all_p, all_r)
    ))
    return ret
178 | 
179 | 
def main():
    """Load predictions, gold data and ids, evaluate, and save per-sample results.

    Reads the module-level pred_file (one JSON object per line), gold_file
    (a JSON document) and pmids_file (a JSON list when its name contains
    '.json', otherwise one id per line). The per-sample results returned by
    do_eval are written next to pred_file with the suffix '.eval_res.json'.
    """
    preds = []
    with open(pred_file) as reader:
        for line in reader:
            preds.append(json.loads(line))

    with open(gold_file) as reader:
        golden = json.load(reader)

    with open(pmids_file) as reader:
        if '.json' in pmids_file:
            pmids = json.load(reader)
        else:
            pmids = [line.strip() for line in reader]

    print("\n====File: ", os.path.basename(pred_file))
    result = do_eval(preds, pmids, golden)

    last_pos = pred_file.rfind('.json')
    # rfind gives -1 when pred_file has no '.json'; slicing with -1 would
    # silently drop the last character of the name, so keep the whole name.
    stem = pred_file if last_pos == -1 else pred_file[:last_pos]
    res_file_name = stem + '.eval_res.json'
    with open(res_file_name, 'w') as writer:
        json.dump(result, writer, indent=2)
204 | 
205 | 
# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()
208 | 


--------------------------------------------------------------------------------
/examples/RE-DDI/infer.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | MODEL_DIR=../../checkpoints/RE-DDI-BioGPT
 5 | MODEL=checkpoint_avg.pt
 6 | DATA_DIR=${PWD}/../../data/DDI/relis-bin
 7 | BASE_DATA_DIR=${DATA_DIR%/*}
 8 | BIN_DATA_DIR=${DATA_DIR##*/}
 9 | DATA_PREFIX=${BIN_DATA_DIR%-*}
10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw
11 | OUTPUT_FILE=generate_${MODEL}
12 | 
13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
15 | GOLD_FILE=${RAW_DATA_DIR}/test.json
16 | PMID_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.pmid
17 | 
18 | # average checkpoints
19 | if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
20 |     python ../../scripts/average_checkpoints.py --inputs=${MODEL_DIR} --output=${MODEL_DIR}/${MODEL} --num-epoch-checkpoints=5
21 | fi
22 | 
23 | # inference
24 | if [ ! -f "$OUTPUT_FILE" ]; then
25 |     echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
26 |     python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE}
27 | fi
28 | 
29 | # debpe
30 | sed -i "s/@@ //g" ${OUTPUT_FILE}
31 | # detok
32 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok
33 | # postprocess
34 | python postprocess.py ${OUTPUT_FILE}.detok
35 | # eval
36 | python hard_match_evaluation.py ${OUTPUT_FILE}.detok.extracted.json ${GOLD_FILE} ${PMID_FILE}
37 | 


--------------------------------------------------------------------------------
/examples/RE-DDI/postprocess.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | import re
 7 | import json
 8 | 
 9 | 
# Command-line argument: the detokenized model output to post-process.
out_file = sys.argv[1]

# Regex patterns for lead-in text that may precede the generated interaction
# sentences; they are tried in this order by strip_prefix.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
]
18 | 
19 | 
def strip_prefix(line):
    """Remove the lead-in from a generated line.

    The first pattern in `prefix` that occurs in the line is used: everything
    up to and including its last occurrence is dropped and the remainder is
    stripped. Lines matching no pattern are returned unchanged.
    """
    for pattern in prefix:
        if re.search(pattern, line) is None:
            continue
        return re.split(pattern, line)[-1].strip()
    return line
27 | 
28 | 
def split_sentence(line):
    """Split a generated line into its "; "-separated sentences."""
    return re.split(r"; ", line)
32 | 
33 | 
def convert_relis_sentence(sentence):
    """Parse "the interaction between <a> and <b> is <kind>".

    Returns the stripped parts reordered as (a, kind, b), or None when the
    sentence does not start with that template.
    """
    matched = re.match(r"the interaction between (.*) and (.*) is (.*)", sentence)
    if matched is None:
        return None
    first, second, kind = matched.groups()
    return (first.strip(), kind.strip(), second.strip())
41 | 
42 | 
def converter(sample, h_idx=0, r_idx=1, t_idx=2):
    """Wrap one sample's triples in the record layout written to the output file.

    Args:
        sample: Iterable of indexable triples.
        h_idx, r_idx, t_idx: Positions of subject, relation and object inside
            each triple.

    Returns:
        A dict whose 'triple_list_pred' holds one subject/relation/object dict
        per triple; the remaining fields are fixed placeholders.
    """
    predicted = [
        {"subject": triple[h_idx], "relation": triple[r_idx], "object": triple[t_idx]}
        for triple in sample
    ]
    return {"triple_list_gold": [], "triple_list_pred": predicted, "new": [], "lack": [], "id": [0]}
48 | 
49 | 
# Read the generated lines and drop one trailing period from each.
with open(out_file, "r", encoding="utf8") as fr:
    stripped_lines = [raw.strip() for raw in fr]
all_lines = [text[:-1] if text.endswith(".") else text for text in stripped_lines]


# Extract the triples of every line; a line that yields none is recorded with
# a single ("failed", "failed", "failed") placeholder.
hypothesis = []
fail_cnt = 0
for line_no, line in enumerate(all_lines, start=1):
    parsed = (convert_relis_sentence(sen) for sen in split_sentence(strip_prefix(line)))
    triples = [ans for ans in parsed if ans is not None]
    if triples:
        hypothesis.append(triples)
        continue
    hypothesis.append([("failed", "failed", "failed")])
    fail_cnt += 1
    print("Failed:id:{}, line:{}".format(line_no, line))
cnt = len(all_lines)


ret_formatted = [converter(triples) for triples in hypothesis]


# One JSON object per input line.
with open(f"{out_file}.extracted.json", "w", encoding="utf8") as fw:
    fw.writelines(json.dumps(eg) + "\n" for eg in ret_formatted)


print(f"failed = {fail_cnt}, total = {cnt}")
93 | 


--------------------------------------------------------------------------------
/examples/RE-DDI/preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | DATA_DIR=../../data/DDI
 5 | prefix=relis
 6 | RAW_DATA_DIR=${DATA_DIR}/raw
 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin
 8 | 
 9 | if [ -d "${OUTPUT_DIR}" ]; then
10 |     rm -rf ${OUTPUT_DIR}
11 | fi
12 | 
13 | python rebuild_data.py ${RAW_DATA_DIR}
14 | 
15 | cp ${DATA_DIR}/../dict.txt ${RAW_DATA_DIR}/
16 | cp ${DATA_DIR}/../bpecodes ${RAW_DATA_DIR}/
17 | 
18 | SPLIT=(train valid test)
19 | 
20 | for ff in ${SPLIT[@]}; do
21 |     if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
22 |         echo "Preprocessing ${ff}"
23 | 
24 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x
25 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
26 | 
27 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/bpecodes
28 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/bpecodes
29 | 
30 |         rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
31 |     fi
32 | done
33 | 
34 | # do binarize
35 | fairseq-preprocess \
36 |     -s x -t y --workers 8 \
37 |     --joined-dictionary \
38 |     --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \
39 |     --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \
40 |     --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \
41 |     --destdir ${OUTPUT_DIR} \
42 |     --srcdict ${RAW_DATA_DIR}/dict.txt


--------------------------------------------------------------------------------
/examples/RE-DDI/rebuild_data.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | import json
 7 | 
# Directory that holds the {train,valid,test}.json inputs; the generated
# relis_* files are written into the same directory.
data_dir=sys.argv[1]
 9 | 
10 | 
def sort_triples(triples, text):
    """Order triples by where their drug name first occurs in ``text``.

    The lookup is case-insensitive: the loader passes lower-cased text while
    the drug names keep their original casing, so a case-sensitive search
    would miss every mixed-case drug.

    Args:
        triples: iterable of dicts that each have a 'drug' string.
        text: the document text to search in.

    Returns:
        A new list. Drugs that do not occur in the text get position -1 and
        therefore come first; ties keep their input order (stable sort).
    """
    haystack = text.lower()
    return sorted(triples, key=lambda triple: haystack.find(triple['drug'].lower()))
14 | 
15 | 
def build_target_seq_relis(triples):
    """Render triples as "the interaction between D and T is R" clauses.

    Clauses are lower-cased, joined with "; " and terminated by a period.
    An empty input yields just ".".
    """
    clauses = []
    for triple in triples:
        drug = triple["drug"].lower()
        target = triple["target"].lower()
        rel = triple["interaction"].lower()
        clauses.append(f"the interaction between {drug} and {target} is {rel}")
    return "; ".join(clauses) + "."
25 | 
26 | 
def build_target_seq_2type(triples):
    """Render triples as "D and T are R" clauses.

    Clauses are lower-cased, joined with "; " and terminated by a period.
    An empty input yields just ".".
    """
    clauses = []
    for triple in triples:
        drug = triple["drug"].lower()
        target = triple["target"].lower()
        rel = triple["interaction"].lower()
        clauses.append(f"{drug} and {target} are {rel}")
    return "; ".join(clauses) + "."
36 | 
37 | 
def loader(fname, fn):
    """Load one JSON split and turn it into (pmid, source, target) samples.

    Args:
        fname: path of a JSON file mapping pmid -> record with "abstract"
            and "triples".
        fn: callable that renders a sorted triple list as the target string.

    Returns:
        List of (pmid, lower-cased single-line abstract, target string) for
        every record that has at least one triple. Records without triples
        are skipped and their pmids are printed.
    """
    with open(fname, "r", encoding="utf8") as fr:
        data = json.load(fr)

    samples = []
    null_cnt = 0
    for pmid, record in data.items():
        content = record["abstract"].strip().replace('\n', ' ').lower()
        triples = record["triples"]
        if triples is None or len(triples) == 0:
            # Print the heading once, before the first skipped pmid.
            if null_cnt == 0:
                print(f"Following PMID in {fname} has no extracted triples:")
            print(f"{pmid} ", end="")
            null_cnt += 1
            continue
        samples.append((pmid, content, fn(sort_triples(triples, content))))

    if null_cnt > 0:
        # Terminate the line of skipped pmids.
        print("")
    print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no triples extracted.")
    return samples
65 | 
66 | 
def dumper(content_list, prefix):
    """Write samples to three line-aligned files next to each other.

    Creates ``prefix + ".pmid"``, ``prefix + ".x"`` and ``prefix + ".y"``,
    each with one line per sample.

    Args:
        content_list: iterable of (pmid, source text, target text) tuples.
        prefix: path prefix of the output files.
    """
    # utf8 matches the encoding the loader reads with; the context manager
    # closes all three files even if writing fails part-way.
    with open(prefix + ".pmid", "w", encoding="utf8") as fw_pmid, \
            open(prefix + ".x", "w", encoding="utf8") as fw_content, \
            open(prefix + ".y", "w", encoding="utf8") as fw_label:
        for pmid, x, y in content_list:
            print(pmid, file=fw_pmid)
            print(x, file=fw_content)
            print(y, file=fw_label)
80 | 
81 | 
def worker(fname, prefix, fn):
    """Load the split at ``fname`` with ``fn`` as target builder and dump it under ``prefix``."""
    samples = loader(fname, fn)
    dumper(samples, prefix)
85 |            
86 | 
# Convert every split found in data_dir into relis_<split>.{pmid,x,y}.
for split in ('train', 'valid', 'test'):
    split_json = os.path.join(data_dir, f"{split}.json")
    out_prefix = os.path.join(data_dir, f"relis_{split}")
    worker(split_json, out_prefix, build_target_seq_relis)
89 | 


--------------------------------------------------------------------------------
/examples/RE-DDI/train.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | SAVE_DIR=../../checkpoints/RE-DDI-BioGPT
 5 | mkdir -p ${SAVE_DIR}
 6 | 
 7 | fairseq-train \
 8 |     ../../data/DDI/relis-bin --save-dir ${SAVE_DIR} \
 9 |     --user-dir ../../src \
10 |     --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \
11 |     --task language_modeling_prompt \
12 |     --arch transformer_lm_prompt_biogpt \
13 |     --share-decoder-input-output-embed --decoder-learned-pos \
14 |     --optimizer adam --adam-betas '(0.9, 0.98)' \
15 |     --weight-decay 0.01 --clip-norm 0.0 \
16 |     --lr 1e-4 --lr-scheduler inverse_sqrt --warmup-updates 500 --warmup-init-lr 1e-07 \
17 |     --tokens-per-sample 1024 --max-source-positions 640 --max-target-positions 1024 \
18 |     --max-tokens 1024 --update-freq 32 \
19 |     --skip-invalid-size-inputs-valid-test \
20 |     --max-epoch 100 --keep-last-epochs 5 \
21 |     --learned-prompt 9


--------------------------------------------------------------------------------
/examples/RE-DTI/README.md:
--------------------------------------------------------------------------------
 1 | # Relation Extraction on KD-DTI
 2 | 
 3 | ## Data
 4 | According to the original [KD-DTI dataset](https://github.com/bert-nmt/BERT-DTI), before processing the data, you should first register a DrugBank account, download the xml dataset and replace the entity id with the entity name in the drugbank.
 5 | 
 6 | Then, you can process the data by:
 7 | ``` bash
 8 | bash preprocess.sh
 9 | ```
10 | 
11 | For more details, please see [here](https://github.com/bert-nmt/BERT-DTI).
12 | 
13 | ## Training
14 | You can fine-tune the pre-trained BioGPT on the task by:
15 | ``` bash
16 | bash train.sh
17 | ```
18 | 
19 | ## Model Checkpoint
20 | We provide our fine-tuned model on the task. See [here](../../README.md#pre-trained-models)
21 | 
22 | ## Inference and Evaluating
23 | You can run inference and evaluate the model on the test set by:
24 | ``` bash
25 | bash infer.sh
26 | ```


--------------------------------------------------------------------------------
/examples/RE-DTI/hard_match_evaluation.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import re
  5 | import json
  6 | import sys
  7 | import os
  8 | 
# Positional command-line arguments:
#   1. predictions file, one JSON record per line
#   2. gold annotations, a single JSON document
#   3. ids file, either JSON (name contains ".json") or one id per line
pred_file = sys.argv[1]
gold_file = sys.argv[2]
pmids_file = sys.argv[3]
 12 | 
 13 | def normalize_name(s: str):
 14 |     s = s.strip()
 15 | 
 16 |     # normalize roman type id at end of string
 17 |     num2roman = {"0": "0", "1": "I", "2": "II", "3": "III", "4": "IV", "5": "V", "6": "VI", "7": "VII", "8": "VIII", "9": "IX"}
 18 |     if len(s) > 2 and s[-1].isnumeric() and not s[-2].isnumeric() and s[-1] in num2roman:
 19 |         tmps = list(s)
 20 |         s = ''.join(tmps[:-1]) + num2roman[tmps[-1]]
 21 | 
 22 |     # remove useless end string
 23 |     s = s.replace("-type", '')
 24 | 
 25 |     return re.sub('[^a-zA-Z0-9]+', '', s)
 26 | 
 27 | 
 28 | def rm_abbr(tgt_set):
 29 |     """ remove abbreviation in the brackets of entity, eg: aaa (bb) -> aaa """
 30 |     def rm(s):
 31 |         s = s.strip()
 32 |         if "(" in s and s[-1] == ')':  # entity end with a bracketed short cut
 33 |             return normalize_name(s[:s.rfind("(")].strip())
 34 |         else:
 35 |             return normalize_name(s)
 36 | 
 37 |     tgt_set = list(tgt_set)
 38 |     if tgt_set and type(tgt_set[0]) in [tuple, list]:  # process triples
 39 |         return set([(rm(tp[0]), rm(tp[1]), rm(tp[2])) for tp in tgt_set])
 40 |     else:  # process entities
 41 |         return set([rm(e) for e in tgt_set])
 42 | 
 43 | 
 44 | def get_abbr(tgt_set):
 45 |     """ extract abbreviation in the brackets of entity, eg: aaa (bb) -> bb """
 46 |     def rm(s):
 47 |         s = s.strip()
 48 |         if "(" in s and s[-1] == ')':
 49 |             return normalize_name(s[s.rfind("(")+1:-1].strip())
 50 |         else:
 51 |             return normalize_name(s)
 52 | 
 53 |     tgt_set = list(tgt_set)
 54 |     if tgt_set and type(tgt_set[0]) in [tuple, list]:  # process triples
 55 |         return set([(rm(tp[0]), rm(tp[1]), rm(tp[2])) for tp in tgt_set])
 56 |     else:  # process entities
 57 |         return set([rm(e) for e in tgt_set])
 58 | 
 59 | 
 60 | def acc(pred_set, gold_set):
 61 |     """ Multi-label style acc """
 62 |     tp_num = len(pred_set & gold_set)
 63 |     return int(pred_set == gold_set) if len(gold_set) == 0 else 1.0 * tp_num / len(pred_set | gold_set)
 64 | 
 65 | 
 66 | def precision(pred_set, gold_set):
 67 |     """ Multi-label style precision """
 68 |     tp_num = len(pred_set & gold_set)
 69 |     return int(pred_set == gold_set) if len(pred_set) == 0 else 1.0 * tp_num / len(pred_set)
 70 | 
 71 | 
 72 | def recall(pred_set, gold_set):
 73 |     """ Multi-label style recall """
 74 |     tp_num = len(pred_set & gold_set)
 75 |     return int(pred_set == gold_set) if len(gold_set) == 0 else 1.0 * tp_num / len(gold_set)
 76 | 
 77 | 
 78 | def normed_eval(pred_set, gold_set, metric):
 79 |     """ Both body and abbreviation match are considered correct """
 80 |     abbr_pred_set, abbr_gold_set = get_abbr(pred_set), get_abbr(gold_set)
 81 |     rm_pred_set, rm_gold_set = rm_abbr(pred_set), rm_abbr(gold_set)
 82 |     return max(metric(abbr_pred_set, abbr_gold_set), metric(rm_pred_set, rm_gold_set))
 83 | 
 84 | 
 85 | def get_f1(p, r):
 86 |     return 0 if (p + r) == 0 else (2.0 * p * r / (p + r))
 87 | 
 88 | 
 89 | def ave(scores):
 90 |     return 1.0 * sum(scores) / len(scores)
 91 | 
 92 | 
def do_eval(preds, pmids, golden):
    """Score predicted triples against the gold annotations and print a summary.

    Args:
        preds: prediction records; each has a "triple_list_pred" list of
            dicts with "subject", "relation" and "object".
        pmids: ids paired positionally with ``preds``.
        golden: mapping from id to a record with "triples" (dicts with
            "drug", "target" and "interaction"), "abstract" and optionally
            "title".

    Returns:
        One dict per evaluated sample with its id, title, abstract and, for
        drugs, targets, interactions and full triples, a
        [score, predicted list, gold list] entry.
    """
    ret = []
    num_pred, num_gold, num_missing = 0, 0, 0
    all_f1, p, r, d_acc, t_acc, i_acc = [], [], [], [], [], []
    all_pred_triple, all_pred_d, all_pred_t, all_pred_i, all_gold_triple, all_gold_d, all_gold_t, all_gold_i = [], [], [], [], [], [], [], [],

    # zip stops at the shorter of the two sequences.
    for pred, idx in zip(preds, pmids):
        gold_d_set, gold_t_set, gold_i_set, gold_set = set(), set(), set(), set()
        pred_d_set, pred_t_set, pred_i_set, pred_set = set(), set(), set(), set()

        # A first subject of 'failed' marks a line that could not be parsed;
        # such a sample keeps empty prediction sets.
        if pred["triple_list_pred"] and pred["triple_list_pred"][0]["subject"] != 'failed':
            for tp in pred["triple_list_pred"]:
                # Compare case-insensitively and without spaces.
                d = tp["subject"].strip().lower().replace(' ', '')
                t = tp["object"].strip().lower().replace(' ', '')
                i = tp["relation"].strip().lower().replace(' ', '')

                pred_d_set.add(d)
                pred_t_set.add(t)
                pred_i_set.add(i)
                pred_set.add((d, t, i))
        # Samples without a gold record are skipped entirely.
        if idx not in golden:
            num_missing += 1
            # print("----Missing:", idx)
            continue
        if golden[idx]["triples"]:
            for tp in golden[idx]["triples"]:
                d = tp["drug"].strip().lower().replace(' ', '')
                t = tp["target"].strip().lower().replace(' ', '')
                i = tp["interaction"].strip().lower().replace(' ', '')
                gold_d_set.add(d)
                gold_t_set.add(t)
                gold_i_set.add(i)
                gold_set.add((d, t, i))

        # sample level eval
        p.append(normed_eval(pred_set, gold_set, metric=precision))
        r.append(normed_eval(pred_set, gold_set, metric=recall))
        all_f1.append(get_f1(p[-1], r[-1]))
        d_acc.append(normed_eval(pred_d_set, gold_d_set, metric=acc))
        t_acc.append(normed_eval(pred_t_set, gold_t_set, metric=acc))
        i_acc.append(normed_eval(pred_i_set, gold_i_set, metric=acc))

        # onto level eval
        all_pred_d.extend(pred_d_set)
        all_pred_t.extend(pred_t_set)
        all_pred_i.extend(pred_i_set)
        all_pred_triple.extend(pred_set)
        all_gold_d.extend(gold_d_set)
        all_gold_t.extend(gold_t_set)
        all_gold_i.extend(gold_i_set)
        all_gold_triple.extend(gold_set)

        # if len(gold_set) < len(golden[idx]["triples"]):
            # print("Duplicate extists, ori", golden[idx]["triples"], gold_set)

        num_pred += len(pred_set)
        num_gold += len(gold_set)

        ret.append({
            "pmid": idx,
            "title": golden[idx]["title"] if "title" in golden[idx] else None,
            "abstract": golden[idx]["abstract"],
            "d_pred_gold": [d_acc[-1], list(pred_d_set), list(gold_d_set)],
            "t_pred_gold": [t_acc[-1], list(pred_t_set), list(gold_t_set)],
            "i_pred_gold": [i_acc[-1], list(pred_i_set), list(gold_i_set)],
            "all_pred_gold": [all_f1[-1], list(pred_set), list(gold_set)],
        })


    # "missing" is derived from the number of scored samples, not from num_missing.
    print("num sample", len(all_f1), "missing", len(preds) - len(all_f1), "num_gold tp", num_gold, "num_pred", num_pred)

    # Note: we adopt multi-label metrics following: http://129.211.169.156/publication/tkde13rev.pdf
    # The value labelled "micro f1" is the mean of per-sample F1; the one
    # labelled "macro f1" is the F1 of the mean precision and mean recall.
    print("Sample: acc d: {:.4f}\tt:{:.4f}\ti: {:.4f}\ntp p: {:.4f}\ttp r: {:.4f}\ttp micro f1: {:.4f}\ttp macro f1: {:.4f} ".format(
        ave(d_acc), ave(t_acc), ave(i_acc), ave(p), ave(r), ave(all_f1), get_f1(ave(p), ave(r))))

    # Ontology evaluation_scripts
    # Here predictions and gold items are pooled over all samples and de-duplicated before scoring.
    all_p, all_r = normed_eval(set(all_pred_triple), set(all_gold_triple), metric=precision), normed_eval(set(all_pred_triple), set(all_gold_triple), metric=recall)
    d_p, d_r = normed_eval(set(all_pred_d), set(all_gold_d), metric=precision), normed_eval(set(all_pred_d), set(all_gold_d), metric=recall)
    t_p, t_r = normed_eval(set(all_pred_t), set(all_gold_t), metric=precision), normed_eval(set(all_pred_t), set(all_gold_t), metric=recall)
    i_p, i_r = normed_eval(set(all_pred_i), set(all_gold_i), metric=precision), normed_eval(set(all_pred_i), set(all_gold_i), metric=recall)

    print("Ontology: f1 d: {:.4f}\tt:{:.4f}\ti: {:.4f}\t \nall p: {:.4f}\tall r: {:.4f}\tonto f1: {:.4f}".format(
        get_f1(d_p, d_r), get_f1(t_p, t_r), get_f1(i_p, i_r), all_p, all_r, get_f1(all_p, all_r)
    ))
    return ret
178 | 
179 | 
def main():
    """Load predictions, gold data and ids, evaluate, and save per-sample results.

    The result file name is the predictions path with everything from its
    last ".json" replaced by ".eval_res.json"; if the path contains no
    ".json", the suffix is appended to the whole path.
    """
    # Read everything as utf8, the encoding the rest of the pipeline uses.
    with open(pred_file, encoding="utf8") as reader:
        preds = [json.loads(line) for line in reader]

    with open(gold_file, encoding="utf8") as reader:
        golden = json.load(reader)

    with open(pmids_file, encoding="utf8") as reader:
        if '.json' in pmids_file:
            pmids = json.load(reader)
        else:
            pmids = [line.strip() for line in reader]

    print("\n====File: ", os.path.basename(pred_file))
    result = do_eval(preds, pmids, golden)

    last_pos = pred_file.rfind('.json')
    # rfind returns -1 when ".json" is absent; slicing with -1 would silently
    # drop the last character of the path.
    stem = pred_file if last_pos == -1 else pred_file[:last_pos]
    res_file_name = stem + '.eval_res.json'
    with open(res_file_name, 'w') as writer:
        json.dump(result, writer, indent=2)


if __name__ == "__main__":
    main()
208 | 


--------------------------------------------------------------------------------
/examples/RE-DTI/infer.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | MODEL_DIR=../../checkpoints/RE-DTI-BioGPT
 5 | MODEL=checkpoint_avg.pt
 6 | DATA_DIR=${PWD}/../../data/KD-DTI/relis-bin
 7 | BASE_DATA_DIR=${DATA_DIR%/*}
 8 | BIN_DATA_DIR=${DATA_DIR##*/}
 9 | DATA_PREFIX=${BIN_DATA_DIR%-*}
10 | RAW_DATA_DIR=${BASE_DATA_DIR}/raw
11 | OUTPUT_FILE=generate_${MODEL}
12 | 
13 | INPUT_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.tok.bpe.x
14 | OUTPUT_FILE=${MODEL_DIR}/${OUTPUT_FILE}
15 | GOLD_FILE=${RAW_DATA_DIR}/test.json
16 | PMID_FILE=${RAW_DATA_DIR}/${DATA_PREFIX}_test.pmid
17 | 
18 | # average checkpoints
19 | if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
20 |     python ../../scripts/average_checkpoints.py --inputs=${MODEL_DIR} --output=${MODEL_DIR}/${MODEL} --num-epoch-checkpoints=5
21 | fi
22 | 
23 | # inference
24 | if [ ! -f "$OUTPUT_FILE" ]; then
25 |     echo "Begin inferencing ${INPUT_FILE} using ${MODEL_DIR}/${MODEL}"
26 |     python ../../inference.py --data_dir=${DATA_DIR} --model_dir=${MODEL_DIR} --model_file=${MODEL} --src_file=${INPUT_FILE} --output_file=${OUTPUT_FILE}
27 | fi
28 | 
29 | # debpe
30 | sed -i "s/@@ //g" ${OUTPUT_FILE}
31 | # detok
32 | perl ${MOSES}/scripts/tokenizer/detokenizer.perl -l en -a < ${OUTPUT_FILE} > ${OUTPUT_FILE}.detok
33 | # postprocess
34 | python postprocess.py ${OUTPUT_FILE}.detok
35 | # eval
36 | python hard_match_evaluation.py ${OUTPUT_FILE}.detok.extracted.json ${GOLD_FILE} ${PMID_FILE}
37 | 


--------------------------------------------------------------------------------
/examples/RE-DTI/postprocess.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import os
 5 | import sys
 6 | import re
 7 | import json
 8 | 
 9 | 
# Path of the generated-text file to parse, one generated sequence per line.
out_file = sys.argv[1]

# Regex patterns for lead-in text; strip_prefix tries them in order and keeps
# only what follows the first one that matches.
prefix = [
    '(learned[0-9]+ )+',
    'we can conclude that',
    'we have that',
    'in conclusion,',
    ]
18 | 
19 | 
def strip_prefix(line):
    """Drop everything up to and including the first matching lead-in pattern.

    Patterns from the module-level ``prefix`` list are tried in order; for the
    first one found, the text after its last occurrence is returned stripped.
    Lines matching no pattern are returned unchanged.
    """
    for pattern in prefix:
        if re.search(pattern, line) is None:
            continue
        return re.split(pattern, line)[-1].strip()
    return line
27 | 
28 | 
def split_sentence(line):
    """Split a generated line into clauses on the literal separator "; "."""
    return line.split("; ")
32 | 
33 | 
def convert_relis_sentence(sentence):
    """Parse one "the interaction between H and T is R" clause.

    Returns:
        (head, relation, tail) with each part stripped, or None when the
        clause does not start with that template.
    """
    matched = re.match(r"the interaction between (.*) and (.*) is (.*)", sentence)
    if matched is None:
        return None
    head, tail, relation = (part.strip() for part in matched.groups())
    return (head, relation, tail)
41 | 
42 | 
def converter(sample, h_idx=0, r_idx=1, t_idx=2):
    """Wrap the triples extracted for one example in the evaluation record layout.

    Args:
        sample: iterable of indexable triples.
        h_idx: position of the head entity inside each triple.
        r_idx: position of the relation inside each triple.
        t_idx: position of the tail entity inside each triple.

    Returns:
        dict whose "triple_list_pred" holds one subject/relation/object
        mapping per triple; the other keys are fixed placeholders.
    """
    predictions = [
        {"subject": triple[h_idx], "relation": triple[r_idx], "object": triple[t_idx]}
        for triple in sample
    ]
    return {
        "triple_list_gold": [],
        "triple_list_pred": predictions,
        "new": [],
        "lack": [],
        "id": [0],
    }
48 | 
49 | 
# Read the generated sequences, dropping a single trailing period per line.
all_lines = []
with open(out_file, "r", encoding="utf8") as fr:
    for raw in fr:
        text = raw.strip()
        all_lines.append(text[:-1] if text.endswith(".") else text)


hypothesis = []
cnt = 0
fail_cnt = 0


# Parse every line into a list of triples; lines that yield nothing get a
# ("failed", "failed", "failed") placeholder so outputs stay line-aligned.
for line_no, line in enumerate(all_lines, start=1):
    cnt += 1
    candidates = [convert_relis_sentence(sen) for sen in split_sentence(strip_prefix(line))]
    ret = [ans for ans in candidates if ans is not None]
    if ret:
        hypothesis.append(ret)
        continue
    hypothesis.append([("failed", "failed", "failed")])
    fail_cnt += 1
    print("Failed:id:{}, line:{}".format(line_no, line))


ret_formatted = [converter(triples) for triples in hypothesis]


# One JSON record per input line.
with open(f"{out_file}.extracted.json", "w", encoding="utf8") as fw:
    for eg in ret_formatted:
        fw.write(json.dumps(eg) + "\n")


print(f"failed = {fail_cnt}, total = {cnt}")
93 | 


--------------------------------------------------------------------------------
/examples/RE-DTI/preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | DATA_DIR=../../data/KD-DTI
 5 | prefix=relis
 6 | RAW_DATA_DIR=${DATA_DIR}/raw
 7 | OUTPUT_DIR=${DATA_DIR}/${prefix}-bin
 8 | 
 9 | if [ -d "${OUTPUT_DIR}" ]; then
10 |     rm -rf ${OUTPUT_DIR}
11 | fi
12 | 
13 | python rebuild_data.py ${RAW_DATA_DIR}
14 | 
15 | cp ${DATA_DIR}/../dict.txt ${RAW_DATA_DIR}/
16 | cp ${DATA_DIR}/../bpecodes ${RAW_DATA_DIR}/
17 | 
18 | SPLIT=(train valid test)
19 | 
20 | for ff in ${SPLIT[@]}; do
21 |     if [ -f "${RAW_DATA_DIR}/${prefix}_$ff.y" ]; then
22 |         echo "Preprocessing ${ff}"
23 | 
24 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.x > ${RAW_DATA_DIR}/${prefix}_$ff.tok.x
25 |         perl ${MOSES}/scripts/tokenizer/tokenizer.perl -l en -a -threads 8 < ${RAW_DATA_DIR}/${prefix}_$ff.y > ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
26 | 
27 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/bpecodes
28 |         ${FASTBPE}/fast applybpe ${RAW_DATA_DIR}/${prefix}_$ff.tok.bpe.y ${RAW_DATA_DIR}/${prefix}_$ff.tok.y ${RAW_DATA_DIR}/bpecodes
29 | 
30 |         rm ${RAW_DATA_DIR}/${prefix}_$ff.tok.x ${RAW_DATA_DIR}/${prefix}_$ff.tok.y
31 |     fi
32 | done
33 | 
34 | # do binarize
35 | fairseq-preprocess \
36 |     -s x -t y --workers 8 \
37 |     --joined-dictionary \
38 |     --trainpref ${RAW_DATA_DIR}/${prefix}_train.tok.bpe \
39 |     --validpref ${RAW_DATA_DIR}/${prefix}_valid.tok.bpe \
40 |     --testpref ${RAW_DATA_DIR}/${prefix}_test.tok.bpe \
41 |     --destdir ${OUTPUT_DIR} \
42 |     --srcdict ${RAW_DATA_DIR}/dict.txt


--------------------------------------------------------------------------------
/examples/RE-DTI/rebuild_data.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import os
  5 | import sys
  6 | import json
  7 | import re
  8 | 
# Directory that holds the {train,valid,test}.json inputs; the generated
# relis_* files are written into the same directory.
data_dir=sys.argv[1]
 10 | 
 11 | def map_relation_to_verb(relation):
 12 |     special_mapping = {
 13 |         "product of": "is the product of",
 14 |         "negative modulator": "negatively modulates",
 15 |         "other/unknown": "randomly works with",
 16 |         "other": "randomly works with",
 17 |         "incorporation into and destabilization": "incorporates into and destabilizates",
 18 |         "cross-linking/alkylation": "cross lines / alkylates",
 19 |         "antibody": "is the antibody of",
 20 |         "downregulator": "downregulates",
 21 |         "desensitize the target": "desensitizes",
 22 |         "protector": "protects",
 23 |         "inhibitor": "inhibits",
 24 |         "weak inhibitor": "weakly inhibits",
 25 |         "blocker": "blocks"
 26 |     }
 27 |     if relation in special_mapping:
 28 |         return special_mapping[relation]
 29 | 
 30 |     if relation.endswith("agonist") or relation.endswith("antagonist"):
 31 |         return relation + "s"
 32 |     
 33 |     if relation.endswith("or") or relation.endswith("er"):
 34 |         return relation[:-2] + "es"
 35 | 
 36 |     if relation.endswith("tion"):
 37 |         return relation[:-3] + "es"
 38 | 
 39 |     if relation.endswith("ing"):
 40 |         return relation[:-3] + "s"
 41 | 
 42 |     return relation + "s"
 43 | 
 44 | 
 45 | def sort_triples(triples, text):
 46 |     sorted_triples = sorted(triples, key=lambda x:text.find(x['drug']))
 47 |     return sorted_triples
 48 | 
 49 | 
 50 | def build_target_seq_svo(triples):        
 51 |     answer = ""
 52 |     for z in triples:
 53 |         drug = z["drug"].lower()
 54 |         target = z["target"].lower()
 55 |         rel = map_relation_to_verb(z["interaction"].lower())
 56 |         answer += f"{drug} {rel} {target}; "
 57 | 
 58 |     return answer[:-2] + "."
 59 | 
 60 | 
 61 | def build_target_seq_isof(triples):        
 62 |     answer = ""
 63 |     for z in triples:
 64 |         drug = z["drug"].lower()
 65 |         target = z["target"].lower()
 66 |         rel = z["interaction"].lower()
 67 |         answer += f"{drug} is the {rel} of {target}; "
 68 | 
 69 |     return answer[:-2] + "."
 70 | 
 71 | 
 72 | def build_target_seq_htr(triples):        
 73 |     answer = ""
 74 |     for z in triples:
 75 |         drug = z["drug"].lower()
 76 |         target = z["target"].lower()
 77 |         rel = z["interaction"].lower()
 78 |         answer += f"<h> {drug} <t> {target} <r> {rel} "
 79 | 
 80 |     return answer[:-1] + "."
 81 | 
 82 | 
 83 | def build_target_seq_relis(triples):        
 84 |     answer = ""
 85 |     for z in triples:
 86 |         drug = z["drug"].lower()
 87 |         target = z["target"].lower()
 88 |         rel = z["interaction"].lower()
 89 |         answer += f"the interaction between {drug} and {target} is {rel}; "
 90 | 
 91 |     return answer[:-2] + "."
 92 | 
 93 | 
 94 | def loader(fname, fn):
 95 |     ret = []
 96 |     null_cnt = 0
 97 |     suc_cnt = 0
 98 |     null_flag = False
 99 |     with open(fname, "r", encoding="utf8") as fr:
100 |         data = json.load(fr)
101 |     for pmid, v in data.items():
102 |         if re.search(r"\W
quot;, v["title"]):
103 |             content = v["title"] + " " + v["abstract"]
104 |         else:
105 |             content = v["title"] + ". " + v["abstract"]
106 | 
107 |         content = content.lower()
108 |         if v["triples"] is None or len(v["triples"]) == 0:
109 |             if not null_flag:
110 |                 print(f"Following PMID in {fname} has no extracted triples:")
111 |                 null_flag = True
112 |             print(f"{pmid} ", end="")
113 |             null_cnt += 1
114 |         else:
115 |             triples = v['triples']
116 |             triples = sort_triples(triples, content)
117 |             answer = fn(triples)
118 |             ret.append((pmid, content, answer))
119 |             suc_cnt += 1
120 |     if null_flag:
121 |         print("")
122 |     print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no triples extracted.")
123 |     return ret
124 | 
125 | 
def dumper(content_list, prefix):
    """Write samples to three line-aligned files next to each other.

    Creates ``prefix + ".pmid"``, ``prefix + ".x"`` and ``prefix + ".y"``,
    each with one line per sample.

    Args:
        content_list: iterable of sequences whose first three items are the
            pmid, the source text and the target text.
        prefix: path prefix of the output files.
    """
    # utf8 matches the encoding the loader reads with; the context manager
    # closes all three files even if writing fails part-way.
    with open(prefix + ".pmid", "w", encoding="utf8") as fw_pmid, \
            open(prefix + ".x", "w", encoding="utf8") as fw_content, \
            open(prefix + ".y", "w", encoding="utf8") as fw_label:
        for ele in content_list:
            print(ele[0], file=fw_pmid)
            print(ele[1], file=fw_content)
            print(ele[2], file=fw_label)
139 | 
140 | 
def worker(fname, prefix, fn):
    """Load the split at ``fname`` with ``fn`` as target builder and dump it under ``prefix``."""
    samples = loader(fname, fn)
    dumper(samples, prefix)
144 |            
145 | 
# Convert every split found in data_dir into relis_<split>.{pmid,x,y}.
for split in ('train', 'valid', 'test'):
    split_json = os.path.join(data_dir, f"{split}.json")
    worker(split_json, os.path.join(data_dir, f"relis_{split}"), build_target_seq_relis)
    # Alternative target formats, currently disabled:
    #worker(os.path.join(f"{data_dir}", f"{split}.json"), os.path.join(f"{data_dir}", f"isof_{split}"), build_target_seq_isof)
    #worker(os.path.join(f"{data_dir}", f"{split}.json"), os.path.join(f"{data_dir}", f"svo_{split}"), build_target_seq_svo)
    #worker(os.path.join(f"{data_dir}", f"{split}.json"), os.path.join(f"{data_dir}", f"htr_{split}"), build_target_seq_htr)


--------------------------------------------------------------------------------
/examples/RE-DTI/train.sh:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | SAVE_DIR=../../checkpoints/RE-DTI-BioGPT
 5 | mkdir -p ${SAVE_DIR}
 6 | 
 7 | fairseq-train \
 8 |     ../../data/KD-DTI/relis-bin --save-dir ${SAVE_DIR} \
 9 |     --user-dir ../../src \
10 |     --finetune-from-model ../../checkpoints/Pre-trained-BioGPT/checkpoint.pt \
11 |     --task language_modeling_prompt \
12 |     --arch transformer_lm_prompt_biogpt \
13 |     --share-decoder-input-output-embed --decoder-learned-pos \
14 |     --optimizer adam --adam-betas '(0.9, 0.98)' \
15 |     --weight-decay 0.01 --clip-norm 0.0 \
16 |     --lr 1e-5 --lr-scheduler inverse_sqrt --warmup-updates 1000 --warmup-init-lr 1e-07 \
17 |     --tokens-per-sample 1024 --max-source-positions 640 --max-target-positions 1024 \
18 |     --max-tokens 1024 --update-freq 32 \
19 |     --skip-invalid-size-inputs-valid-test \
20 |     --max-epoch 30 --keep-last-epochs 5 \
21 |     --learned-prompt 9


--------------------------------------------------------------------------------
/examples/text-generation/README.md:
--------------------------------------------------------------------------------
 1 | # Text Generation
 2 | You can use the pre-trained BioGPT model for free text generation, just as you would use other GPT models.
 3 | ## Model Checkpoint
 4 | We provide our pre-trained BioGPT model. See [here](../../README.md#pre-trained-models)
 5 | 
 6 | ## Generation
 7 | We provide an interactive script for generation:
 8 | ``` bash
 9 | python interactive.py
10 | ```
11 | 


--------------------------------------------------------------------------------
/examples/text-generation/interactive.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import argparse
 5 | from fairseq.models.transformer_lm import TransformerLanguageModel
 6 | 
# Command-line options. parse_known_args is used, so unrecognized arguments
# are ignored instead of raising an error.
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="../../data/PubMed/data-bin")
parser.add_argument("--model_dir", type=str, default="../../checkpoints/Pre-trained-BioGPT")
parser.add_argument("--model_file", type=str, default="checkpoint.pt")
parser.add_argument("--bpecodes", type=str, default="../../data/bpecodes")
parser.add_argument("--beam", type=int, default=5)
parser.add_argument("--lenpen", type=float, default=1.0)
parser.add_argument("--min_len", type=int, default=100)
# NOTE(review): --lower is parsed but not referenced anywhere else in this file.
parser.add_argument("--lower", default=False, action="store_true")
args, _ = parser.parse_known_args()
17 | 
18 | 
19 | def main(args):
20 | 
21 |     m = TransformerLanguageModel.from_pretrained(
22 |         args.model_dir, 
23 |         args.model_file, 
24 |         args.data_dir, 
25 |         tokenizer='moses', 
26 |         bpe='fastbpe', 
27 |         bpe_codes=args.bpecodes,
28 |         min_len=args.min_len,
29 |         max_len_b=1024,
30 |         beam=args.beam,
31 |         lenpen=args.lenpen,
32 |         max_tokens=12000)
33 | 
34 |     print(m.cfg)
35 |     if m.cfg.common.fp16:
36 |         print('Converting to float 16')
37 |         m.half()
38 |     m.cuda()
39 | 
40 |     while True:
41 |         print("Please input and press enter:")
42 |         _src = input().strip()
43 |         src_tokens = m.encode(_src)
44 |         generate = m.generate([src_tokens], beam=args.beam)[0]
45 |         output = m.decode(generate[0]["tokens"])
46 |         print(output)
# Script entry point: run the interactive loop with the module-level parsed arguments.
if __name__ == "__main__":
    main(args)


--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import argparse
 5 | from src.transformer_lm_prompt import TransformerLanguageModelPrompt
 6 | 
 7 | 
 8 | parser = argparse.ArgumentParser()
 9 | parser.add_argument("--data_dir", type=str, default='')
10 | parser.add_argument("--model_dir", type=str, default=None)
11 | parser.add_argument("--model_file", type=str, default="checkpoint_last.pt")
12 | parser.add_argument("--src_file", type=str, default=None)
13 | parser.add_argument("--output_file", type=str, default=None)
14 | parser.add_argument("--beam", type=int, default=1)
15 | parser.add_argument("--decoding_length", type=int, default=1024)
16 | args, _ = parser.parse_known_args()
17 | 
18 | 
def main(args):
    """Generate one output line for every prompt line in ``args.src_file``.

    Reads the prompts (one per line, surrounding whitespace stripped), loads a
    ``TransformerLanguageModelPrompt`` checkpoint, converts it to half
    precision when the checkpoint configuration has ``common.fp16`` set, moves
    it to the GPU, samples with ``args.beam`` and writes the results to
    ``args.output_file``, one per line.

    Args:
        args: parsed command-line namespace providing ``src_file``,
            ``output_file``, ``model_dir``, ``model_file``, ``data_dir``,
            ``decoding_length`` and ``beam``.

    Raises:
        ValueError: if ``src_file`` or ``output_file`` is not set. Both default
            to ``None``; previously a missing output path silently produced a
            file literally named "None".
    """
    # Fail fast, before the (slow) model load, when a required path is missing.
    if not args.src_file:
        raise ValueError("--src_file is required")
    if not args.output_file:
        raise ValueError("--output_file is required")

    # Read with the same encoding that is used for writing the results.
    with open(args.src_file, encoding='utf8') as reader:
        src_inputs = [line.strip() for line in reader]

    m = TransformerLanguageModelPrompt.from_pretrained(
        args.model_dir,
        args.model_file,
        args.data_dir,
        max_len_b=args.decoding_length,
        max_tokens=12000,)

    print(m.cfg)

    if m.cfg.common.fp16:
        print('Converting to float 16')
        m.half()
    m.cuda()

    outputs = m.sample(src_inputs, beam=args.beam)

    with open(args.output_file, "w", encoding='utf8') as fw:
        for output in outputs:
            fw.write(output + '\n')
44 | 
45 | 
# Script entry point: run inference with the module-level parsed arguments.
if __name__ == "__main__":
    main(args)


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | antlr4-python3-runtime==4.8
 2 | bitarray==2.6.2
 3 | cffi==1.15.1
 4 | click==8.1.3
 5 | colorama==0.4.6
 6 | Cython==0.29.33
 7 | fairseq==0.12.2
 8 | hydra-core==1.0.7
 9 | joblib==1.2.0
10 | lxml==4.9.2
11 | numpy==1.24.1
12 | omegaconf==2.0.6
13 | portalocker==2.7.0
14 | protobuf==3.20.1
15 | pycparser==2.21
16 | PyYAML==6.0
17 | regex==2022.10.31
18 | sacrebleu==2.3.1
19 | sacremoses==0.0.53
20 | scikit-learn==1.2.1
21 | scipy==1.10.0
22 | six==1.16.0
23 | tabulate==0.9.0
24 | tensorboardX==2.5.1
25 | threadpoolctl==3.1.0
26 | torch==1.12.0
27 | torchaudio==0.12.0
28 | tqdm==4.64.1
29 | typing-extensions==4.4.0
30 | 


--------------------------------------------------------------------------------
/scripts/average_checkpoints.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | #!/usr/bin/env python3
  5 | # Copyright (c) Facebook, Inc. and its affiliates.
  6 | #
  7 | # This source code is licensed under the MIT license found in the
  8 | # LICENSE file in the root directory of this source tree.
  9 | 
 10 | import argparse
 11 | import collections
 12 | import os
 13 | import re
 14 | 
 15 | import torch
 16 | from fairseq.file_io import PathManager
 17 | 
 18 | 
 19 | def average_checkpoints(inputs):
 20 |     """Loads checkpoints from inputs and returns a model with averaged weights.
 21 | 
 22 |     Args:
 23 |       inputs: An iterable of string paths of checkpoints to load from.
 24 | 
 25 |     Returns:
 26 |       A dict of string keys mapping to various values. The 'model' key
 27 |       from the returned dict should correspond to an OrderedDict mapping
 28 |       string parameter names to torch Tensors.
 29 |     """
 30 |     params_dict = collections.OrderedDict()
 31 |     params_keys = None
 32 |     new_state = None
 33 |     num_models = len(inputs)
 34 | 
 35 |     for fpath in inputs:
 36 |         with PathManager.open(fpath, "rb") as f:
 37 |             state = torch.load(
 38 |                 f,
 39 |                 map_location=(
 40 |                     lambda s, _: torch.serialization.default_restore_location(s, "cpu")
 41 |                 ),
 42 |             )
 43 |         # Copies over the settings from the first checkpoint
 44 |         if new_state is None:
 45 |             new_state = state
 46 | 
 47 |         model_params = state["model"]
 48 | 
 49 |         model_params_keys = list(model_params.keys())
 50 |         if params_keys is None:
 51 |             params_keys = model_params_keys
 52 |         elif params_keys != model_params_keys:
 53 |             raise KeyError(
 54 |                 "For checkpoint {}, expected list of params: {}, "
 55 |                 "but found: {}".format(f, params_keys, model_params_keys)
 56 |             )
 57 | 
 58 |         for k in params_keys:
 59 |             p = model_params[k]
 60 |             if isinstance(p, torch.HalfTensor):
 61 |                 p = p.float()
 62 |             if k not in params_dict:
 63 |                 params_dict[k] = p.clone()
 64 |                 # NOTE: clone() is needed in case of p is a shared parameter
 65 |             else:
 66 |                 params_dict[k] += p
 67 | 
 68 |     averaged_params = collections.OrderedDict()
 69 |     for k, v in params_dict.items():
 70 |         averaged_params[k] = v
 71 |         if averaged_params[k].is_floating_point():
 72 |             averaged_params[k].div_(num_models)
 73 |         else:
 74 |             averaged_params[k] //= num_models
 75 |     new_state["model"] = averaged_params
 76 |     return new_state
 77 | 
 78 | 
 79 | def last_n_checkpoints(paths, n, update_based, upper_bound=None):
 80 |     assert len(paths) == 1
 81 |     path = paths[0]
 82 |     if update_based:
 83 |         pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
 84 |     else:
 85 |         pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
 86 |     files = PathManager.ls(path)
 87 | 
 88 |     entries = []
 89 |     for f in files:
 90 |         m = pt_regexp.fullmatch(f)
 91 |         if m is not None:
 92 |             sort_key = int(m.group(1))
 93 |             if upper_bound is None or sort_key <= upper_bound:
 94 |                 entries.append((sort_key, m.group(0)))
 95 |     if len(entries) < n:
 96 |         raise Exception(
 97 |             "Found {} checkpoint files but need at least {}", len(entries), n
 98 |         )
 99 |     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
100 | 
101 | 
def main():
    """Command-line entry point: average checkpoints and save the result.

    Either averages the checkpoint files listed in ``--inputs`` directly, or,
    when ``--num-epoch-checkpoints`` / ``--num-update-checkpoints`` is given,
    treats ``--inputs`` as a single directory and averages the last N matching
    checkpoints found there (optionally capped by ``--checkpoint-upper-bound``).
    The averaged state is written to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Tool to average the params of input checkpoints to "
        "produce a new checkpoint",
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    # The file-name patterns below mirror the regular expressions used by
    # last_n_checkpoints when scanning the input directory.
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint<epoch>.pt in the path specified by input, '
                           'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_<epoch>_<update>.pt in the path specified by input, '
                           'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
                        'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
                        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged. '
                        'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 45500-50000 would be averaged assuming --save-interval-updates 500'
                        )
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (
        args.num_epoch_checkpoints is not None
        or args.num_update_checkpoints is not None
    ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
    assert (
        args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
    ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"

    if num is not None:
        # In "last N" mode --inputs names a directory; replace it with the
        # concrete checkpoint paths that were selected.
        args.inputs = last_n_checkpoints(
            args.inputs,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print("averaging checkpoints: ", args.inputs)

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, "wb") as f:
        torch.save(new_state, f)
    print("Finished writing averaged checkpoint to {}".format(args.output))
158 | 
159 | 
# Script entry point: parse the command line and average the checkpoints.
if __name__ == "__main__":
    main()
162 | 


--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | from .language_modeling_prompt import *
5 | from .transformer_lm_prompt import *
6 | from .language_model_prompt_dataset import *
7 | from .constrained_generator import *


--------------------------------------------------------------------------------
/src/constrained_generator.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import math
  5 | from typing import Dict, List, Optional
  6 | 
  7 | import torch
  8 | import torch.nn as nn
  9 | from fairseq import search, utils
 10 | from fairseq.data import data_utils
 11 | from fairseq.models import FairseqIncrementalDecoder
 12 | from torch import Tensor
 13 | from fairseq.ngram_repeat_block import NGramRepeatBlock
 14 | from fairseq.sequence_generator import SequenceGenerator, EnsembleModel
 15 | 
 16 | 
 17 | class ConstrainedGenerator(nn.Module):
 18 |     def __init__(
 19 |         self,
 20 |         models,
 21 |         tgt_dict,
 22 |         beam_size=1,
 23 |         max_len_a=0,
 24 |         max_len_b=200,
 25 |         max_len=0,
 26 |         min_len=1,
 27 |         normalize_scores=True,
 28 |         len_penalty=1.0,
 29 |         unk_penalty=0.0,
 30 |         temperature=1.0,
 31 |         match_source_len=False,
 32 |         no_repeat_ngram_size=0,
 33 |         search_strategy=None,
 34 |         eos=None,
 35 |         symbols_to_strip_from_output=None,
 36 |         lm_model=None,
 37 |         lm_weight=1.0,
 38 |     ):
 39 |         """Generates translations of a given source sentence.
 40 | 
 41 |         Args:
 42 |             models (List[~fairseq.models.FairseqModel]): ensemble of models,
 43 |                 currently support fairseq.models.TransformerModel for scripting
 44 |             beam_size (int, optional): beam width (default: 1)
 45 |             max_len_a/b (int, optional): generate sequences of maximum length
 46 |                 ax + b, where x is the source length
 47 |             max_len (int, optional): the maximum length of the generated output
 48 |                 (not including end-of-sentence)
 49 |             min_len (int, optional): the minimum length of the generated output
 50 |                 (not including end-of-sentence)
 51 |             normalize_scores (bool, optional): normalize scores by the length
 52 |                 of the output (default: True)
 53 |             len_penalty (float, optional): length penalty, where <1.0 favors
 54 |                 shorter, >1.0 favors longer sentences (default: 1.0)
 55 |             unk_penalty (float, optional): unknown word penalty, where <0
 56 |                 produces more unks, >0 produces fewer (default: 0.0)
 57 |             temperature (float, optional): temperature, where values
 58 |                 >1.0 produce more uniform samples and values <1.0 produce
 59 |                 sharper samples (default: 1.0)
 60 |             match_source_len (bool, optional): outputs should match the source
 61 |                 length (default: False)
 62 |         """
 63 |         super().__init__()
 64 |         if isinstance(models, EnsembleModel):
 65 |             self.model = models
 66 |         else:
 67 |             self.model = EnsembleModel(models)
 68 |         self.tgt_dict = tgt_dict
 69 |         self.pad = tgt_dict.pad()
 70 |         self.unk = tgt_dict.unk()
 71 |         self.eos = tgt_dict.eos() if eos is None else eos
 72 |         self.symbols_to_strip_from_output = (
 73 |             symbols_to_strip_from_output.union({self.eos})
 74 |             if symbols_to_strip_from_output is not None
 75 |             else {self.eos}
 76 |         )
 77 |         self.vocab_size = len(tgt_dict)
 78 |         self.beam_size = beam_size
 79 |         # the max beam size is the dictionary size - 1, since we never select pad
 80 |         self.beam_size = min(beam_size, self.vocab_size - 1)
 81 |         self.max_len_a = max_len_a
 82 |         self.max_len_b = max_len_b
 83 |         self.min_len = min_len
 84 |         self.max_len = max_len or self.model.max_decoder_positions()
 85 | 
 86 |         self.normalize_scores = normalize_scores
 87 |         self.len_penalty = len_penalty
 88 |         self.unk_penalty = unk_penalty
 89 |         self.temperature = temperature
 90 |         self.match_source_len = match_source_len
 91 | 
 92 |         if no_repeat_ngram_size > 0:
 93 |             self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
 94 |         else:
 95 |             self.repeat_ngram_blocker = None
 96 | 
 97 |         assert temperature > 0, "--temperature must be greater than 0"
 98 | 
 99 |         self.search = (
100 |             search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
101 |         )
102 |         # We only need to set src_lengths in LengthConstrainedBeamSearch.
103 |         # As a module attribute, setting it would break in multithread
104 |         # settings when the model is shared.
105 |         self.should_set_src_lengths = (
106 |             hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
107 |         )
108 | 
109 |         self.model.eval()
110 | 
111 |         self.lm_model = lm_model
112 |         self.lm_weight = lm_weight
113 |         if self.lm_model is not None:
114 |             self.lm_model.eval()
115 | 
116 |     def cuda(self):
117 |         self.model.cuda()
118 |         return self
119 | 
120 |     @torch.no_grad()
121 |     def forward(
122 |         self,
123 |         sample: Dict[str, Dict[str, Tensor]],
124 |         prefix_tokens: Optional[Tensor] = None,
125 |         bos_token: Optional[int] = None,
126 |     ):
127 |         """Generate a batch of translations.
128 | 
129 |         Args:
130 |             sample (dict): batch
131 |             prefix_tokens (torch.LongTensor, optional): force decoder to begin
132 |                 with these tokens
133 |             bos_token (int, optional): beginning of sentence token
134 |                 (default: self.eos)
135 |         """
136 |         return self._generate(sample, prefix_tokens, bos_token=bos_token)
137 | 
138 |     # TODO(myleott): unused, deprecate after pytorch-translate migration
139 |     def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
140 |         """Iterate over a batched dataset and yield individual translations.
141 |         Args:
142 |             cuda (bool, optional): use GPU for generation
143 |             timer (StopwatchMeter, optional): time generations
144 |         """
145 |         for sample in data_itr:
146 |             s = utils.move_to_cuda(sample) if cuda else sample
147 |             if "net_input" not in s:
148 |                 continue
149 |             input = s["net_input"]
150 |             # model.forward normally channels prev_output_tokens into the decoder
151 |             # separately, but SequenceGenerator directly calls model.encoder
152 |             encoder_input = {
153 |                 k: v for k, v in input.items() if k != "prev_output_tokens"
154 |             }
155 |             if timer is not None:
156 |                 timer.start()
157 |             with torch.no_grad():
158 |                 hypos = self.generate(encoder_input)
159 |             if timer is not None:
160 |                 timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
161 |             for i, id in enumerate(s["id"].data):
162 |                 # remove padding
163 |                 src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
164 |                 ref = (
165 |                     utils.strip_pad(s["target"].data[i, :], self.pad)
166 |                     if s["target"] is not None
167 |                     else None
168 |                 )
169 |                 yield id, src, ref, hypos[i]
170 | 
171 |     @torch.no_grad()
172 |     def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
173 |         """Generate translations. Match the api of other fairseq generators.
174 | 
175 |         Args:
176 |             models (List[~fairseq.models.FairseqModel]): ensemble of models
177 |             sample (dict): batch
178 |             prefix_tokens (torch.LongTensor, optional): force decoder to begin
179 |                 with these tokens
180 |             constraints (torch.LongTensor, optional): force decoder to include
181 |                 the list of constraints
182 |             bos_token (int, optional): beginning of sentence token
183 |                 (default: self.eos)
184 |         """
185 |         return self._generate(sample, **kwargs)
186 | 
187 |     def _generate(
188 |         self,
189 |         sample: Dict[str, Dict[str, Tensor]],
190 |         prefix_tokens: Optional[Tensor] = None,
191 |         constraints: Optional[Tensor] = None,
192 |         bos_token: Optional[int] = None,
193 |         allowed_text: Optional[List[Tensor]] = None,
194 |     ):
195 |         incremental_states = torch.jit.annotate(
196 |             List[Dict[str, Dict[str, Optional[Tensor]]]],
197 |             [
198 |                 torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
199 |                 for i in range(self.model.models_size)
200 |             ],
201 |         )
202 |         net_input = sample["net_input"]
203 | 
204 |         if "src_tokens" in net_input:
205 |             src_tokens = net_input["src_tokens"]
206 |             # length of the source text being the character length except EndOfSentence and pad
207 |             src_lengths = (
208 |                 (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
209 |             )
210 |         elif "source" in net_input:
211 |             src_tokens = net_input["source"]
212 |             src_lengths = (
213 |                 net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
214 |                 if net_input["padding_mask"] is not None
215 |                 else torch.tensor(src_tokens.size(-1)).to(src_tokens)
216 |             )
217 |         else:
218 |             raise Exception("expected src_tokens or source in net input")
219 | 
220 |         # bsz: total number of sentences in beam
221 |         # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
222 |         bsz, src_len = src_tokens.size()[:2]
223 |         beam_size = self.beam_size
224 | 
225 |         if constraints is not None and not self.search.supports_constraints:
226 |             raise NotImplementedError(
227 |                 "Target-side constraints were provided, but search method doesn't support them"
228 |             )
229 | 
230 |         # Initialize constraints, when active
231 |         self.search.init_constraints(constraints, beam_size)
232 | 
233 |         # Allowed text
234 |         if allowed_text is not None:
235 |             if allowed_text.dim() == 1:
236 |                 allowed_text = allowed_text.unsqueeze(dim=0).repeat_interleave(bsz, dim=0)
237 |             allowed_text = torch.cat([prefix_tokens, allowed_text], dim=1)
238 |         
239 |         max_len: int = -1
240 |         if self.match_source_len:
241 |             max_len = src_lengths.max().item()
242 |         else:
243 |             max_len = min(
244 |                 int(self.max_len_a * src_len + self.max_len_b),
245 |                 self.max_len - 1,
246 |             )
247 |         assert (
248 |             self.min_len <= max_len
249 |         ), "min_len cannot be larger than max_len, please adjust these!"
250 |         # compute the encoder output for each beam
251 |         encoder_outs = self.model.forward_encoder(net_input)
252 | 
253 |         # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
254 |         new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
255 |         new_order = new_order.to(src_tokens.device).long()
256 |         encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
257 |         # ensure encoder_outs is a List.
258 |         assert encoder_outs is not None
259 | 
260 |         # initialize buffers
261 |         scores = (
262 |             torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
263 |         )  # +1 for eos; pad is never chosen for scoring
264 |         tokens = (
265 |             torch.zeros(bsz * beam_size, max_len + 2)
266 |             .to(src_tokens)
267 |             .long()
268 |             .fill_(self.pad)
269 |         )  # +2 for eos and pad
270 |         tokens[:, 0] = self.eos if bos_token is None else bos_token
271 |         attn: Optional[Tensor] = None
272 | 
273 |         # A list that indicates candidates that should be ignored.
274 |         # For example, suppose we're sampling and have already finalized 2/5
275 |         # samples. Then cands_to_ignore would mark 2 positions as being ignored,
276 |         # so that we only finalize the remaining 3 samples.
277 |         cands_to_ignore = (
278 |             torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
279 |         )  # forward and backward-compatible False mask
280 | 
281 |         # list of completed sentences
282 |         finalized = torch.jit.annotate(
283 |             List[List[Dict[str, Tensor]]],
284 |             [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
285 |         )  # contains lists of dictionaries of information about the hypothesis being finalized at each step
286 | 
287 |         # a boolean array indicating if the sentence at the index is finished or not
288 |         finished = [False for i in range(bsz)]
289 |         num_remaining_sent = bsz  # number of sentences remaining
290 | 
291 |         # number of candidate hypos per step
292 |         cand_size = 2 * beam_size  # 2 x beam size in case half are EOS
293 | 
294 |         # offset arrays for converting between different indexing schemes
295 |         bbsz_offsets = (
296 |             (torch.arange(0, bsz) * beam_size)
297 |             .unsqueeze(1)
298 |             .type_as(tokens)
299 |             .to(src_tokens.device)
300 |         )
301 |         cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
302 | 
303 |         reorder_state: Optional[Tensor] = None
304 |         batch_idxs: Optional[Tensor] = None
305 | 
306 |         original_batch_idxs: Optional[Tensor] = None
307 |         if "id" in sample and isinstance(sample["id"], Tensor):
308 |             original_batch_idxs = sample["id"]
309 |         else:
310 |             original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
311 | 
312 |         for step in range(max_len + 1):  # one extra step for EOS marker
313 |             # reorder decoder internal states based on the prev choice of beams
314 |             if reorder_state is not None:
315 |                 if batch_idxs is not None:
316 |                     # update beam indices to take into account removed sentences
317 |                     corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
318 |                         batch_idxs
319 |                     )
320 |                     reorder_state.view(-1, beam_size).add_(
321 |                         corr.unsqueeze(-1) * beam_size
322 |                     )
323 |                     original_batch_idxs = original_batch_idxs[batch_idxs]
324 |                 self.model.reorder_incremental_state(incremental_states, reorder_state)
325 |                 encoder_outs = self.model.reorder_encoder_out(
326 |                     encoder_outs, reorder_state
327 |                 )
328 | 
329 |             lprobs, avg_attn_scores = self.model.forward_decoder(
330 |                 tokens[:, : step + 1],
331 |                 encoder_outs,
332 |                 incremental_states,
333 |                 self.temperature,
334 |             )
335 | 
336 |             if self.lm_model is not None:
337 |                 lm_out = self.lm_model(tokens[:, : step + 1])
338 |                 probs = self.lm_model.get_normalized_probs(
339 |                     lm_out, log_probs=True, sample=None
340 |                 )
341 |                 probs = probs[:, -1, :] * self.lm_weight
342 |                 lprobs += probs
343 | 
344 |             lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
345 | 
346 |             # hack the lprobs here to constrain the output to be in the allowed text
347 |             if allowed_text is not None:
348 |                 if batch_idxs is not None:
349 |                     allowed_text = allowed_text[batch_idxs]
350 |                 bi = torch.arange(lprobs.size(0)).unsqueeze(dim=1).repeat_interleave(allowed_text.size(-1), dim=-1)
351 |                 mask = torch.ones(lprobs.size()).to(lprobs)
352 |                 mask[bi.view(-1), allowed_text.view(-1)] = 0
353 |                 mask[mask==1] = -math.inf
354 |                 lprobs +=  mask
355 | 
356 |             lprobs[:, self.pad] = -math.inf  # never select pad
357 |             lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
358 | 
359 |             # handle max length constraint
360 |             if step >= max_len:
361 |                 lprobs[:, : self.eos] = -math.inf
362 |                 lprobs[:, self.eos + 1 :] = -math.inf
363 | 
364 |             # handle prefix tokens (possibly with different lengths)
365 |             if (
366 |                 prefix_tokens is not None
367 |                 and step < prefix_tokens.size(1)
368 |                 and step < max_len
369 |             ):
370 |                 lprobs, tokens, scores = self._prefix_tokens(
371 |                     step, lprobs, scores, tokens, prefix_tokens, beam_size
372 |                 )
373 |             elif step < self.min_len:
374 |                 # minimum length constraint (does not apply if using prefix_tokens)
375 |                 lprobs[:, self.eos] = -math.inf
376 | 
377 |             # Record attention scores, only support avg_attn_scores is a Tensor
378 |             if avg_attn_scores is not None:
379 |                 if attn is None:
380 |                     attn = torch.empty(
381 |                         bsz * beam_size, avg_attn_scores.size(1), max_len + 2
382 |                     ).to(scores)
383 |                 attn[:, :, step + 1].copy_(avg_attn_scores)
384 | 
385 |             scores = scores.type_as(lprobs)
386 |             eos_bbsz_idx = torch.empty(0).to(
387 |                 tokens
388 |             )  # indices of hypothesis ending with eos (finished sentences)
389 |             eos_scores = torch.empty(0).to(
390 |                 scores
391 |             )  # scores of hypothesis ending with eos (finished sentences)
392 | 
393 |             if self.should_set_src_lengths:
394 |                 self.search.set_src_lengths(src_lengths)
395 | 
396 |             if self.repeat_ngram_blocker is not None:
397 |                 lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
398 | 
399 |             # Shape: (batch, cand_size)
400 |             cand_scores, cand_indices, cand_beams = self.search.step(
401 |                 step,
402 |                 lprobs.view(bsz, -1, self.vocab_size),
403 |                 scores.view(bsz, beam_size, -1)[:, :, :step],
404 |                 tokens[:, : step + 1],
405 |                 original_batch_idxs,
406 |             )
407 | 
408 |             # cand_bbsz_idx contains beam indices for the top candidate
409 |             # hypotheses, with a range of values: [0, bsz*beam_size),
410 |             # and dimensions: [bsz, cand_size]
411 |             cand_bbsz_idx = cand_beams.add(bbsz_offsets)
412 | 
413 |             # finalize hypotheses that end in eos
414 |             # Shape of eos_mask: (batch size, beam size)
415 |             eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
416 |             eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
417 | 
418 |             # only consider eos when it's among the top beam_size indices
419 |             # Now we know what beam item(s) to finish
420 |             # Shape: 1d list of absolute-numbered
421 |             eos_bbsz_idx = torch.masked_select(
422 |                 cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
423 |             )
424 | 
425 |             finalized_sents: List[int] = []
426 |             if eos_bbsz_idx.numel() > 0:
427 |                 eos_scores = torch.masked_select(
428 |                     cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
429 |                 )
430 | 
431 |                 finalized_sents = self.finalize_hypos(
432 |                     step,
433 |                     eos_bbsz_idx,
434 |                     eos_scores,
435 |                     tokens,
436 |                     scores,
437 |                     finalized,
438 |                     finished,
439 |                     beam_size,
440 |                     attn,
441 |                     src_lengths,
442 |                     max_len,
443 |                 )
444 |                 num_remaining_sent -= len(finalized_sents)
445 | 
446 |             assert num_remaining_sent >= 0
447 |             if num_remaining_sent == 0:
448 |                 break
449 |             if self.search.stop_on_max_len and step >= max_len:
450 |                 break
451 |             assert step < max_len, f"{step} < {max_len}"
452 | 
453 |             # Remove finalized sentences (ones for which {beam_size}
454 |             # finished hypotheses have been generated) from the batch.
455 |             if len(finalized_sents) > 0:
456 |                 new_bsz = bsz - len(finalized_sents)
457 | 
458 |                 # construct batch_idxs which holds indices of batches to keep for the next pass
459 |                 batch_mask = torch.ones(
460 |                     bsz, dtype=torch.bool, device=cand_indices.device
461 |                 )
462 |                 batch_mask[finalized_sents] = False
463 |                 # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
464 |                 batch_idxs = torch.arange(
465 |                     bsz, device=cand_indices.device
466 |                 ).masked_select(batch_mask)
467 | 
468 |                 # Choose the subset of the hypothesized constraints that will continue
469 |                 self.search.prune_sentences(batch_idxs)
470 | 
471 |                 eos_mask = eos_mask[batch_idxs]
472 |                 cand_beams = cand_beams[batch_idxs]
473 |                 bbsz_offsets.resize_(new_bsz, 1)
474 |                 cand_bbsz_idx = cand_beams.add(bbsz_offsets)
475 |                 cand_scores = cand_scores[batch_idxs]
476 |                 cand_indices = cand_indices[batch_idxs]
477 | 
478 |                 if prefix_tokens is not None:
479 |                     prefix_tokens = prefix_tokens[batch_idxs]
480 |                 src_lengths = src_lengths[batch_idxs]
481 |                 cands_to_ignore = cands_to_ignore[batch_idxs]
482 | 
483 |                 scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
484 |                 tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
485 |                 if attn is not None:
486 |                     attn = attn.view(bsz, -1)[batch_idxs].view(
487 |                         new_bsz * beam_size, attn.size(1), -1
488 |                     )
489 |                 bsz = new_bsz
490 |             else:
491 |                 batch_idxs = None
492 | 
493 |             # Set active_mask so that values > cand_size indicate eos hypos
494 |             # and values < cand_size indicate candidate active hypos.
495 |             # After, the min values per row are the top candidate active hypos
496 | 
497 |             # Rewrite the operator since the element wise or is not supported in torchscript.
498 | 
499 |             eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
500 |             active_mask = torch.add(
501 |                 eos_mask.type_as(cand_offsets) * cand_size,
502 |                 cand_offsets[: eos_mask.size(1)],
503 |             )
504 | 
505 |             # get the top beam_size active hypotheses, which are just
506 |             # the hypos with the smallest values in active_mask.
507 |             # {active_hypos} indicates which {beam_size} hypotheses
508 |             # from the list of {2 * beam_size} candidates were
509 |             # selected. Shapes: (batch size, beam size)
510 |             new_cands_to_ignore, active_hypos = torch.topk(
511 |                 active_mask, k=beam_size, dim=1, largest=False
512 |             )
513 | 
514 |             # update cands_to_ignore to ignore any finalized hypos.
515 |             cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
516 |             # Make sure there is at least one active item for each sentence in the batch.
517 |             assert (~cands_to_ignore).any(dim=1).all()
518 | 
519 |             # update cands_to_ignore to ignore any finalized hypos
520 | 
521 |             # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
522 |             # can be selected more than once).
523 |             active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
524 |             active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
525 | 
526 |             active_bbsz_idx = active_bbsz_idx.view(-1)
527 |             active_scores = active_scores.view(-1)
528 | 
529 |             # copy tokens and scores for active hypotheses
530 | 
531 |             # Set the tokens for each beam (can select the same row more than once)
532 |             tokens[:, : step + 1] = torch.index_select(
533 |                 tokens[:, : step + 1], dim=0, index=active_bbsz_idx
534 |             )
535 |             # Select the next token for each of them
536 |             tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
537 |                 cand_indices, dim=1, index=active_hypos
538 |             )
539 |             if step > 0:
540 |                 scores[:, :step] = torch.index_select(
541 |                     scores[:, :step], dim=0, index=active_bbsz_idx
542 |                 )
543 |             scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
544 |                 cand_scores, dim=1, index=active_hypos
545 |             )
546 | 
547 |             # Update constraints based on which candidates were selected for the next beam
548 |             self.search.update_constraints(active_hypos)
549 | 
550 |             # copy attention for active hypotheses
551 |             if attn is not None:
552 |                 attn[:, :, : step + 2] = torch.index_select(
553 |                     attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
554 |                 )
555 | 
556 |             # reorder incremental state in decoder
557 |             reorder_state = active_bbsz_idx
558 | 
559 |         # sort by score descending
560 |         for sent in range(len(finalized)):
561 |             scores = torch.tensor(
562 |                 [float(elem["score"].item()) for elem in finalized[sent]]
563 |             )
564 |             _, sorted_scores_indices = torch.sort(scores, descending=True)
565 |             finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
566 |             finalized[sent] = torch.jit.annotate(
567 |                 List[Dict[str, Tensor]], finalized[sent]
568 |             )
569 |         return finalized
570 | 
    def _prefix_tokens(
        self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
    ):
        """Constrain the candidates at ``step`` to the forced prefix tokens.

        For every row whose prefix token at ``step`` is not ``self.pad``, all
        log-probabilities are set to ``-inf`` except the prefix token's own
        entry, which keeps the value it had on entry. When a prefix token
        equals ``self.eos``, the first beam of the affected sentences is
        copied over all of their beams in ``tokens``, ``scores`` and
        ``lprobs``.

        Args:
            step (int): column of ``prefix_tokens`` to enforce
            lprobs: log-probabilities; written to in place
            scores: scores tensor, only touched when a prefix token is eos
            tokens: token tensor, only touched when a prefix token is eos
            prefix_tokens: 2-D tensor indexed as ``[:, step]``
                (presumably one row per sentence -- confirm with the caller)
            beam_size (int): number of beams per sentence

        Returns:
            tuple: ``(lprobs, tokens, scores)``
        """
        # One prefix token per row of lprobs: each sentence's token at `step`
        # is repeated beam_size times and flattened.
        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
        # Remember the prefix token's log-prob before the row is wiped below.
        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
        # Rows whose prefix token is pad are left unconstrained.
        prefix_mask = prefix_toks.ne(self.pad)
        lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
        # Restore the saved log-prob at the prefix token's vocabulary index.
        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
        )
        # if prefix includes eos, then we should make sure tokens and
        # scores are the same across all beams
        eos_mask = prefix_toks.eq(self.eos)
        if eos_mask.any():
            # validate that the first beam matches the prefix
            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
                :, 0, 1 : step + 1
            ]
            # Per-sentence mask, read from the first beam of each sentence.
            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
            assert (first_beam == target_prefix).all()

            # copy tokens, scores and lprobs from the first beam to all beams
            tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
            scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
            lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
        return lprobs, tokens, scores
599 | 
600 |     def replicate_first_beam(self, tensor, mask, beam_size: int):
601 |         tensor = tensor.view(-1, beam_size, tensor.size(-1))
602 |         tensor[mask] = tensor[mask][:, :1, :]
603 |         return tensor.view(-1, tensor.size(-1))
604 | 
    def finalize_hypos(
        self,
        step: int,
        bbsz_idx,
        eos_scores,
        tokens,
        scores,
        finalized: List[List[Dict[str, Tensor]]],
        finished: List[bool],
        beam_size: int,
        attn: Optional[Tensor],
        src_lengths,
        max_len: int,
    ):
        """Record hypotheses that just ended in EOS and report finished sentences.

        Each hypothesis selected by ``bbsz_idx`` is appended to
        ``finalized[sent]`` as long as that sentence holds fewer than
        ``beam_size`` entries. Afterwards every sentence touched in this call
        is checked with ``self.is_finished`` and, if done, flagged in
        ``finished``.

        Args:
            step (int): current decoding step
            bbsz_idx (Tensor): row indices into the flattened
                ``(batch * beam)`` dimension of ``tokens``/``scores`` for the
                hypotheses being finalized
            eos_scores (Tensor): one score per entry of ``bbsz_idx``; divided
                in place when ``self.normalize_scores`` is set
            tokens: token tensor, indexed by row with ``bbsz_idx``
            scores: cumulative score tensor, indexed by row with ``bbsz_idx``
            finalized: per original sentence, the hypotheses collected so
                far; appended to in place
            finished: per original sentence, whether it is done; updated in
                place
            beam_size (int): number of beams per sentence
            attn (Tensor, optional): attention tensor, or None
            src_lengths: per-sentence source lengths, only read when
                ``self.match_source_len`` is set
            max_len (int): maximum step, forwarded to ``self.is_finished``

        Returns:
            List[int]: indices, in the current (possibly reduced) batch, of
            the sentences that became finished in this call.
        """
        assert bbsz_idx.numel() == eos_scores.numel()

        # clone relevant token and attention tensors.
        # tokens is (batch * beam, max_len). So the index_select
        # gets the newly EOS rows, then selects cols 1..{step + 2}
        tokens_clone = tokens.index_select(0, bbsz_idx)[
            :, 1 : step + 2
        ]  # skip the first index, which is EOS

        # The last kept column is overwritten with EOS.
        tokens_clone[:, step] = self.eos
        attn_clone = (
            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
            if attn is not None
            else None
        )

        # compute scores per token position
        pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
        pos_scores[:, step] = eos_scores
        # convert from cumulative to per-position scores
        pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

        # normalize sentence-level scores
        # (in-place division: the tensor passed in by the caller changes too;
        # pos_scores above was filled from the un-normalized values)
        if self.normalize_scores:
            eos_scores /= (step + 1) ** self.len_penalty

        # cum_unfin has one entry per still-unfinished sentence: the number
        # of finished sentences that precede it in the original batch.
        # Adding it to an index in the current, possibly-reduced batch gives
        # the index in the original batch.
        cum_unfin: List[int] = []
        prev = 0
        for f in finished:
            if f:
                prev += 1
            else:
                cum_unfin.append(prev)

        # The keys here are of the form "{sent}_{unfin_idx}", where
        # "unfin_idx" is the index in the current (possibly reduced)
        # list of sentences, and "sent" is the index in the original,
        # unreduced batch
        # set() is not supported in script export
        sents_seen: Dict[str, Optional[Tensor]] = {}

        # For every finished beam item
        for i in range(bbsz_idx.size()[0]):
            idx = bbsz_idx[i]
            score = eos_scores[i]
            # sentence index in the current (possibly reduced) batch
            unfin_idx = idx // beam_size
            # sentence index in the original (unreduced) batch
            sent = unfin_idx + cum_unfin[unfin_idx]
            # Cannot create dict for key type '(int, int)' in torchscript.
            # The workaround is to cast int to string
            seen = str(sent.item()) + "_" + str(unfin_idx.item())
            if seen not in sents_seen:
                sents_seen[seen] = None

            # With match_source_len, a hypothesis ending past the source
            # length is stored with a score of -inf.
            if self.match_source_len and step > src_lengths[unfin_idx]:
                score = torch.tensor(-math.inf).to(score)

            # An input sentence (among those in a batch) is finished when
            # beam_size hypotheses have been collected for it
            if len(finalized[sent]) < beam_size:
                if attn_clone is not None:
                    # attention slice belonging to this hypothesis
                    hypo_attn = attn_clone[i]
                else:
                    hypo_attn = torch.empty(0)

                finalized[sent].append(
                    {
                        "tokens": tokens_clone[i],
                        "score": score,
                        "attention": hypo_attn,  # src_len x tgt_len
                        "alignment": torch.empty(0),
                        "positional_scores": pos_scores[i],
                    }
                )

        newly_finished: List[int] = []

        for seen in sents_seen.keys():
            # check termination conditions for this sentence
            # (both indices are parsed back out of the "{sent}_{unfin_idx}" key)
            sent: int = int(float(seen.split("_")[0]))
            unfin_idx: int = int(float(seen.split("_")[1]))

            if not finished[sent] and self.is_finished(
                step, unfin_idx, max_len, len(finalized[sent]), beam_size
            ):
                finished[sent] = True
                newly_finished.append(unfin_idx)

        return newly_finished
722 | 
723 |     def is_finished(
724 |         self,
725 |         step: int,
726 |         unfin_idx: int,
727 |         max_len: int,
728 |         finalized_sent_len: int,
729 |         beam_size: int,
730 |     ):
731 |         """
732 |         Check whether decoding for a sentence is finished, which
733 |         occurs when the list of finalized sentences has reached the
734 |         beam size, or when we reach the maximum length.
735 |         """
736 |         assert finalized_sent_len <= beam_size
737 |         if finalized_sent_len == beam_size or step == max_len:
738 |             return True
739 |         return False
740 | 


--------------------------------------------------------------------------------
/src/language_model_prompt_dataset.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import logging
  5 | 
  6 | import numpy as np
  7 | import torch
  8 | from fairseq.data import FairseqDataset, data_utils
  9 | 
 10 | 
 11 | logger = logging.getLogger(__name__)
 12 | 
 13 | 
 14 | def collate(samples, pad_idx, eos_idx, prefix=False, sep_idx=None, prompt=None):
 15 |     if len(samples) == 0:
 16 |         return {}
 17 | 
 18 |     def make_sentence(prompt, source, target):
 19 |         if source[-1] == eos_idx:
 20 |             source = source[:-1]
 21 |         if prompt is None:
 22 |             return torch.cat([source, target], dim=0)
 23 |         if prefix:
 24 |             sep = torch.LongTensor([sep_idx])
 25 |             return torch.cat([prompt, source, sep, target], dim=0)
 26 |         return torch.cat([source, prompt, target], dim=0)
 27 |         
 28 |             
 29 |     def merge(tokens, move_eos_to_beginning=False):
 30 |         return data_utils.collate_tokens(
 31 |             tokens,
 32 |             pad_idx,
 33 |             eos_idx,
 34 |             move_eos_to_beginning=move_eos_to_beginning,
 35 |         )
 36 | 
 37 |     id = torch.LongTensor([s["id"] for s in samples])
 38 |     #src_tokens = merge([s["source"] for s in samples])
 39 |     #src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
 40 | 
 41 |     target_tokens = []
 42 |     target_lengths = []
 43 |     for s in samples:
 44 |         target_tokens.append(make_sentence(prompt, s["source"], s["target"]))
 45 |         
 46 |     target_lengths = [t.ne(pad_idx).long().sum() for t in target_tokens]
 47 |     target = merge(target_tokens)
 48 |     target_lengths = torch.LongTensor(target_lengths)
 49 |     prev_output_tokens = merge(target_tokens, move_eos_to_beginning=True)
 50 |     ntokens = target_lengths.sum().item()
 51 |     batch = {
 52 |         "id": id,
 53 |         "nsentences": len(samples),
 54 |         "ntokens": ntokens,
 55 |         "net_input": {
 56 |             "src_tokens": prev_output_tokens, #src_tokens,
 57 |             "src_lengths": target_lengths, #src_lengths,
 58 |             #"prev_output_tokens": prev_output_tokens,
 59 |             #"target_lengths": target_lengths,
 60 |         },
 61 |         "target": target,
 62 |     }
 63 |     return batch
 64 | 
 65 | 
 66 | class LanguageModelPromptDataset(FairseqDataset):
 67 |     """
 68 |     A pair of torch.utils.data.Datasets.
 69 | 
 70 |     Args:
 71 |         src (torch.utils.data.Dataset): source dataset to wrap
 72 |         src_sizes (List[int]): source sentence lengths
 73 |         dictionary (~fairseq.data.Dictionary): vocabulary
 74 |         tgt (torch.utils.data.Dataset, optional): target dataset to wrap
 75 |         tgt_sizes (List[int], optional): target sentence lengths
 76 |         prefix (bool, optional): prefix
 77 |         prompt (str, optional): prompt to use
 78 |         shuffle (bool, optional): shuffle dataset elements before batching
 79 |             (default: True).
 80 |         max_source_length (int): max source text length
 81 |         max_length (int): max text length
 82 |         prompt_length (int): length of the prompt text
 83 |         
 84 |     """
 85 | 
 86 |     def __init__(
 87 |         self,
 88 |         src,
 89 |         src_sizes,
 90 |         dictionary,
 91 |         tgt,
 92 |         tgt_sizes,
 93 |         prefix=False,
 94 |         prompt=None,
 95 |         shuffle=True,
 96 |         eos=None,
 97 |         max_source_length=None,
 98 |         max_length=None,
 99 |         prompt_length=None,
100 |     ):
101 |         self.src = src
102 |         self.tgt = tgt
103 |         self.prefix = prefix
104 |         self.seq_sep = None
105 |         self.prompt = prompt
106 |         self.dict = dictionary
107 |         self.shuffle = shuffle
108 |         self.eos = eos if eos is not None else dictionary.eos()
109 |         self.max_source_length = max_source_length
110 |         self.max_target_length = max_length - max_source_length - prompt_length
111 |         if self.prefix:
112 |             self.max_target_length -= 1
113 |         self.src_sizes = [min(s-1, self.max_source_length) for s in src_sizes]
114 |         self.tgt_sizes = [min(t, self.max_target_length) for t in tgt_sizes]
115 |         self.sizes = np.array([s+t for s,t in zip(self.src_sizes, self.tgt_sizes)])
116 |         self.buckets = None
117 |         
118 |     def get_batch_shapes(self):
119 |         return self.buckets
120 | 
121 |     def __getitem__(self, index):
122 |         src_item = self.src[index]
123 |         if src_item.size(0) - 1 > self.max_source_length:
124 |             src_item = src_item[:self.max_source_length + 1]
125 |             src_item[-2] = self.dict.index('...')
126 |             src_item[-1] = self.eos
127 |         
128 |         tgt_item = self.tgt[index]
129 |         if tgt_item.size(0) > self.max_target_length:
130 |             tgt_item = tgt_item[:self.max_target_length]
131 |             tgt_item[-2] = self.dict.index('...')
132 |             tgt_item[-1] = self.eos
133 |         example = {
134 |             "id": index,
135 |             "source": src_item,
136 |             "target": tgt_item,
137 |         }
138 |         return example
139 | 
140 |     def __len__(self):
141 |         return len(self.src)
142 | 
143 |     def collater(self, samples):
144 |         """Merge a list of samples to form a mini-batch.
145 | 
146 |         Args:
147 |             samples (List[dict]): samples to collate
148 |         Returns:
149 |             dict: a mini-batch with the following keys:
150 | 
151 |                 - `id` (LongTensor): example IDs in the original input order
152 |                 - `ntokens` (int): total number of tokens in the batch
153 |                 - `net_input` (dict): the input to the Model, containing keys:
154 |                   - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
155 |                     the source sentence of shape `(bsz, src_len)`.
156 |                   - `src_lengths` (LongTensor): 1D Tensor of the unpadded
157 |                     lengths of each source sentence of shape `(bsz)`
158 |                   - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
159 |                     tokens in the target sentence, shifted right by one
160 |                     position for teacher forcing, of shape `(bsz, tgt_len)`.
161 |                   - `lengths` (LongTensor): 1D Tensor of the unpadded
162 |                     lengths of each target sentence of shape `(bsz)`
163 |         """
164 |         res = collate(
165 |             samples, 
166 |             pad_idx=self.dict.pad(),
167 |             eos_idx=self.dict.eos(),
168 |             prefix=self.prefix,
169 |             sep_idx=self.dict.sep_index,
170 |             prompt=self.prompt,
171 |         )
172 |         return res
173 | 
174 |     def num_tokens(self, index):
175 |         """Return the number of tokens in a sample. This value is used to
176 |         enforce ``--max-tokens`` during batching."""
177 |         return self.sizes[index]
178 | 
179 |     def num_tokens_vec(self, indices):
180 |         """Return the number of tokens for a set of positions defined by indices.
181 |         This value is used to enforce ``--max-tokens`` during batching."""
182 |         sizes = self.sizes[indices]
183 |         return sizes
184 | 
185 |     def size(self, index):
186 |         """Return an example's size as a float or tuple. This value is used when
187 |         filtering a dataset with ``--max-positions``."""
188 |         return self.sizes[index]
189 | 
190 |     def ordered_indices(self):
191 |         """Return an ordered list of indices. Batches will be constructed based
192 |         on this order."""
193 |         if self.shuffle:
194 |             indices = np.random.permutation(len(self)).astype(np.int64)
195 |         else:
196 |             indices = np.arange(len(self), dtype=np.int64)
197 |         return indices[np.argsort(self.sizes[indices], kind="mergesort")]
198 | 
199 |     @property
200 |     def supports_prefetch(self):
201 |         return getattr(self.src, "supports_prefetch", False) and (
202 |             getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
203 |         )
204 | 
205 |     def prefetch(self, indices):
206 |         self.src.prefetch(indices)
207 |         self.tgt.prefetch(indices)
208 | 


--------------------------------------------------------------------------------
/src/language_modeling_prompt.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Microsoft Corporation.
  2 | # Licensed under the MIT License.
  3 | 
  4 | import logging
  5 | import os
  6 | from dataclasses import dataclass, field
  7 | from typing import Optional
  8 | 
  9 | import torch
 10 | from fairseq import search, utils
 11 | from fairseq.data import (
 12 |     Dictionary,
 13 |     data_utils,
 14 |     indexed_dataset,
 15 | )
 16 | 
 17 | from fairseq.tasks import register_task
 18 | from fairseq.tasks.language_modeling import LanguageModelingConfig, LanguageModelingTask
 19 | from .language_model_prompt_dataset import LanguageModelPromptDataset
 20 | from omegaconf import II
 21 | 
 22 | 
 23 | logger = logging.getLogger(__name__)
 24 | 
 25 | 
 26 | @dataclass
 27 | class LanguageModelingPromptConfig(LanguageModelingConfig):
 28 |     source_lang: Optional[str] = field(
 29 |         default=None, metadata={"help": "source language", "argparse_alias": "-s",}
 30 |     )
 31 |     target_lang: Optional[str] = field(
 32 |         default=None, metadata={"help": "target language","argparse_alias": "-t",}
 33 |     )
 34 |     max_source_positions: Optional[int] = field(
 35 |         default=384, metadata={"help": "max number of tokens in the source sequence, exclude eos."}
 36 |     )
 37 |     manual_prompt: Optional[str] = field(
 38 |         default=None, metadata={"help": "manual prompt to use",}
 39 |     )
 40 |     learned_prompt: Optional[int] = field(
 41 |         default=None, metadata={"help": "number of virtual tokens to use",}
 42 |     )
 43 |     learned_prompt_pattern: Optional[str] = field(
 44 |         default='learned', metadata={"help": "pattern of virtual tokens, default is learned",}
 45 |     )
 46 |     prefix: Optional[bool] = field(
 47 |         default=False, metadata={"help": "whether put prompt as prefix."}
 48 |     )
 49 |     sep_token: Optional[str] = field(
 50 |         default="<seqsep>", metadata={"help": "token to seperate prompt source and target."}
 51 |     )
 52 | 
 53 | 
 54 | @register_task("language_modeling_prompt", dataclass=LanguageModelingPromptConfig)
 55 | class LanguageModelingPromptTask(LanguageModelingTask):
 56 |     """
 57 |     Train a language model.
 58 | 
 59 |     Args:
 60 |         dictionary (~fairseq.data.Dictionary): the dictionary for the input of
 61 |             the language model
 62 |         output_dictionary (~fairseq.data.Dictionary): the dictionary for the
 63 |             output of the language model. In most cases it will be the same as
 64 |             *dictionary*, but could possibly be a more limited version of the
 65 |             dictionary (if ``--output-dictionary-size`` is used).
 66 |         targets (List[str]): list of the target types that the language model
 67 |             should predict.  Can be one of "self", "future", and "past".
 68 |             Defaults to "future".
 69 | 
 70 |     .. note::
 71 | 
 72 |         The language modeling task is compatible with :mod:`fairseq-train`,
 73 |         :mod:`fairseq-generate`, :mod:`fairseq-interactive` and
 74 |         :mod:`fairseq-eval-lm`.
 75 | 
 76 |     The language modeling task provides the following additional command-line
 77 |     arguments:
 78 | 
 79 |     .. argparse::
 80 |         :ref: fairseq.tasks.language_modeling_parser
 81 |         :prog:
 82 |     """
 83 |     def __init__(self, args, dictionary, output_dictionary=None, prompt=None, targets=None):
 84 |         super().__init__(args, dictionary, output_dictionary, targets)
 85 |         self.prompt = prompt
 86 |         self.prompt_length = self.prompt.size(0) if self.prompt is not None else 0
 87 |         self.prefix = args.prefix
 88 | 
 89 |     @classmethod
 90 |     def setup_prompt(cls, args, dictionary):
 91 |         if args.prefix:
 92 |             dictionary.sep_index = dictionary.add_symbol(args.sep_token)
 93 |         else:
 94 |             dictionary.sep_index = None
 95 |         assert not (args.manual_prompt and args.learned_prompt), "manual prompt and learned prompt can not be set "
 96 |         if args.manual_prompt and len(args.manual_prompt) != 0:
 97 |             prompt = dictionary.encode_line(args.manual_prompt, append_eos=False).long()
 98 |         elif args.learned_prompt:
 99 |             prompt = ''
100 |             for idx in range(args.learned_prompt):
101 |                 prompt += args.learned_prompt_pattern + str(idx+1) + ' '
102 |             prompt = dictionary.encode_line(prompt, append_eos=False).long()
103 |         else:
104 |             prompt = None
105 |         return prompt
106 | 
107 |     @classmethod
108 |     def setup_dictionary(cls, args, **kwargs):
109 |         dictionary = None
110 |         output_dictionary = None
111 |         if args.data:
112 |             paths = utils.split_paths(args.data)
113 |             assert len(paths) > 0
114 |             dictionary = Dictionary.load(os.path.join(paths[0], "dict.{}.txt".format(args.source_lang)))
115 |             logger.info("dictionary: {} types".format(len(dictionary)))
116 |             #output_dictionary = Dictionary.load(os.path.join(paths[0], "dict.{}.txt".format(args.target_lang)))
117 |             output_dictionary = dictionary
118 |         return (dictionary, output_dictionary)
119 | 
120 |     @classmethod
121 |     def setup_task(cls, args, **kwargs):
122 |         """Setup the task (e.g., load dictionaries).
123 | 
124 |         Args:
125 |             args (argparse.Namespace): parsed command-line arguments
126 |         """
127 |         paths = utils.split_paths(args.data)
128 |         assert len(paths) > 0
129 |         # find language pair automatically
130 |         if args.source_lang is None or args.target_lang is None:
131 |             args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
132 |         if args.source_lang is None or args.target_lang is None:
133 |             raise Exception(
134 |                 "Could not infer language pair, please provide it explicitly"
135 |             )
136 | 
137 |         dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)
138 |         prompt = cls.setup_prompt(args, dictionary)
139 | 
140 |         # upgrade old checkpoints
141 |         if getattr(args, "exclude_self_target", False):
142 |             args.self_target = False
143 | 
144 |         targets = []
145 |         if getattr(args, "self_target", False):
146 |             targets.append("self")
147 |         if getattr(args, "future_target", False):
148 |             targets.append("future")
149 |         if getattr(args, "past_target", False):
150 |             targets.append("past")
151 |         if len(targets) == 0:
152 |             # standard language modeling
153 |             targets = ["future"]
154 | 
155 |         return cls(args, dictionary, output_dictionary, prompt, targets=targets)
156 | 
157 |     def load_dataset(
158 |         self, split: str, epoch=1, combine=False, **kwargs
159 |     ) -> LanguageModelPromptDataset:
160 |         """Load a given dataset split.
161 | 
162 |         Args:
163 |             split (str): name of the split (e.g., train, valid, test)
164 |         """
165 |         def split_exists(split, src, tgt, lang, data_path):
166 |             filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
167 |             return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
168 | 
169 |         paths = utils.split_paths(self.args.data)
170 |         assert len(paths) > 0
171 |         data_path = paths[(epoch - 1) % len(paths)]
172 | 
173 |         # source
174 |         if split_exists(split, self.args.source_lang, self.args.target_lang, self.args.source_lang, data_path):
175 |             prefix = os.path.join(data_path, "{}.{}-{}.".format(split, self.args.source_lang, self.args.target_lang))
176 |         else:
177 |             raise FileNotFoundError(
178 |                     "Dataset not found: {} ({})".format(split, data_path)
179 |                 )
180 |         src_dataset = data_utils.load_indexed_dataset(
181 |             prefix + self.args.source_lang, self.dictionary, self.args.dataset_impl
182 |         )
183 | 
184 |         tgt_dataset = data_utils.load_indexed_dataset(
185 |                 prefix + self.args.target_lang, self.output_dictionary, self.args.dataset_impl
186 |         )
187 | 
188 |         src_sizes = src_dataset.sizes
189 |         tgt_sizes = tgt_dataset.sizes
190 | 
191 |         dataset = LanguageModelPromptDataset(
192 |             src_dataset,
193 |             src_sizes,
194 |             self.dictionary,
195 |             tgt_dataset,
196 |             tgt_sizes,
197 |             prefix = self.prefix,
198 |             prompt=self.prompt,
199 |             max_source_length=self.args.max_source_positions,
200 |             max_length=self.args.max_target_positions,
201 |             prompt_length=self.prompt_length
202 |         )
203 | 
204 |         self.datasets[split] = dataset
205 | 
206 |     def build_dataset_for_inference(self, src_tokens, src_lengths, tgt_tokens=None, tgt_lengths=None):
207 |         """
208 |         Generate batches for inference. We prepend an eos token to src_tokens
209 |         (or bos if `--add-bos-token` is set) and we append a <pad> to target.
210 |         This is convenient both for generation with a prefix and LM scoring.
211 |         """
212 |         bs = len(src_tokens)
213 |         if tgt_tokens is None:
214 |             tgt_tokens = [torch.LongTensor([self.dictionary.eos()]) for _ in range(bs)]
215 |             tgt_lengths = torch.LongTensor([t.numel() for t in tgt_tokens])
216 |             
217 |         dataset = LanguageModelPromptDataset(
218 |             src_tokens,
219 |             src_lengths,
220 |             self.dictionary,
221 |             tgt_tokens,
222 |             tgt_lengths,
223 |             prefix = self.prefix,
224 |             prompt=self.prompt,
225 |             max_source_length=self.args.max_source_positions,
226 |             max_length=self.args.max_target_positions,
227 |             prompt_length=self.prompt_length
228 |         )
229 | 
230 |         return dataset
231 |     
232 |     def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None, allowed_text=None):
233 |         with torch.no_grad():
234 |             # Generation will always be conditioned on bos_token
235 |             if getattr(self.args, "add_bos_token", False):
236 |                 bos_token = self.source_dictionary.bos()
237 |             else:
238 |                 bos_token = self.source_dictionary.eos()
239 | 
240 |             if constraints is not None:
241 |                 raise NotImplementedError(
242 |                     "Constrained decoding with the language_modeling task is not supported"
243 |                 )
244 | 
245 |             if allowed_text is not None:
246 |                 allowed_text = self.target_dictionary.encode_line(allowed_text, add_if_not_exist=False).to(sample['net_input']['src_tokens'])
247 |             # SequenceGenerator doesn't use src_tokens directly, we need to
248 |             # pass the `prefix_tokens` argument instead
249 |         
250 |             if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
251 |                 prefix_tokens = sample["net_input"]["src_tokens"]
252 |                 if prefix_tokens[:, 0].eq(bos_token).all():
253 |                     prefix_tokens = prefix_tokens[:, 1:]
254 | 
255 |             return generator.generate(
256 |                 models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token, allowed_text=allowed_text
257 |             )
258 | 
259 |     def build_generator(
260 |         self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None
261 |     ):
262 |         from .constrained_generator import ConstrainedGenerator
263 | 
264 |         # Choose search strategy. Defaults to Beam Search.
265 |         sampling = getattr(args, "sampling", False)
266 |         sampling_topk = getattr(args, "sampling_topk", -1)
267 |         sampling_topp = getattr(args, "sampling_topp", -1.0)
268 |         diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
269 |         diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
270 |         match_source_len = getattr(args, "match_source_len", False)
271 |         diversity_rate = getattr(args, "diversity_rate", -1)
272 |         constrained = getattr(args, "constraints", False)
273 |         if prefix_allowed_tokens_fn is None:
274 |             prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
275 |         if (
276 |             sum(
277 |                 int(cond)
278 |                 for cond in [
279 |                     sampling,
280 |                     diverse_beam_groups > 0,
281 |                     match_source_len,
282 |                     diversity_rate > 0,
283 |                 ]
284 |             )
285 |             > 1
286 |         ):
287 |             raise ValueError("Provided Search parameters are mutually exclusive.")
288 |         assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
289 |         assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
290 | 
291 |         if sampling:
292 |             search_strategy = search.Sampling(
293 |                 self.target_dictionary, sampling_topk, sampling_topp
294 |             )
295 |         elif diverse_beam_groups > 0:
296 |             search_strategy = search.DiverseBeamSearch(
297 |                 self.target_dictionary, diverse_beam_groups, diverse_beam_strength
298 |             )
299 |         elif match_source_len:
300 |             # this is useful for tagging applications where the output
301 |             # length should match the input length, so we hardcode the
302 |             # length constraints for simplicity
303 |             search_strategy = search.LengthConstrainedBeamSearch(
304 |                 self.target_dictionary,
305 |                 min_len_a=1,
306 |                 min_len_b=0,
307 |                 max_len_a=1,
308 |                 max_len_b=0,
309 |             )
310 |         elif diversity_rate > -1:
311 |             search_strategy = search.DiverseSiblingsSearch(
312 |                 self.target_dictionary, diversity_rate
313 |             )
314 |         elif constrained:
315 |             search_strategy = search.LexicallyConstrainedBeamSearch(
316 |                 self.target_dictionary, args.constraints
317 |             )
318 |         elif prefix_allowed_tokens_fn:
319 |             search_strategy = search.PrefixConstrainedBeamSearch(
320 |                 self.target_dictionary, prefix_allowed_tokens_fn
321 |             )
322 |         else:
323 |             search_strategy = search.BeamSearch(self.target_dictionary)
324 | 
325 |         extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
326 |         
327 |         seq_gen_cls = ConstrainedGenerator
328 | 
329 |         return seq_gen_cls(
330 |             models,
331 |             self.target_dictionary,
332 |             beam_size=getattr(args, "beam", 5),
333 |             max_len_a=getattr(args, "max_len_a", 0),
334 |             max_len_b=getattr(args, "max_len_b", 200),
335 |             min_len=getattr(args, "min_len", 1),
336 |             normalize_scores=(not getattr(args, "unnormalized", False)),
337 |             len_penalty=getattr(args, "lenpen", 1),
338 |             unk_penalty=getattr(args, "unkpen", 0),
339 |             temperature=getattr(args, "temperature", 1.0),
340 |             match_source_len=getattr(args, "match_source_len", False),
341 |             no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
342 |             search_strategy=search_strategy,
343 |             **extra_gen_cls_kwargs,
344 |         )
345 | 


--------------------------------------------------------------------------------
/src/transformer_lm_prompt.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Microsoft Corporation.
 2 | # Licensed under the MIT License.
 3 | 
 4 | import logging
 5 | from dataclasses import dataclass, field
 6 | from typing import Optional, Dict, List, Tuple
 7 | from argparse import Namespace
 8 | import torch
 9 | from torch import Tensor
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | 
13 | from fairseq import options, utils
14 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass
15 | from fairseq.models import (
16 |     FairseqLanguageModel,
17 |     register_model,
18 |     register_model_architecture,
19 | )
20 | from fairseq.models.transformer import (
21 |     DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder
22 | )
23 | from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
24 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf
25 | from fairseq.models.transformer_lm import (
26 |     TransformerLanguageModelConfig,
27 |     TransformerLanguageModel, 
28 |     transformer_lm_gpt2_small,
29 |     transformer_lm_gpt2_big,
30 | )
31 | from omegaconf import II, DictConfig
32 | 
33 | 
34 | logger = logging.getLogger(__name__)
35 | 
36 | 
@register_model("transformer_lm_prompt", dataclass=TransformerLanguageModelConfig)
class TransformerLanguageModelPrompt(TransformerLanguageModel):
    """Transformer language model whose checkpoint loading tolerates a model
    vocabulary that is larger than the one stored in the checkpoint.

    When the model's embedding or output-projection matrix has more rows than
    the checkpoint's, the missing trailing rows are taken from the model's
    current weights before the state dict is loaded.
    """

    @staticmethod
    def _pad_vocab_rows(state_dict, key, model_weight):
        """Extend ``state_dict[key]`` in place with the trailing rows of
        *model_weight* when the model has more rows than the checkpoint."""
        saved = state_dict[key]
        saved_rows = saved.shape[0]
        if model_weight.shape[0] > saved_rows:
            # Detach so the patched checkpoint tensor carries no autograd
            # graph back to the live parameter.
            extra_rows = model_weight[saved_rows:].detach().to(saved.device)
            state_dict[key] = torch.cat([saved, extra_rows])

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints and
        pads the vocabulary-sized matrices
        (``decoder.embed_tokens.weight`` and
        ``decoder.output_projection.weight``) with the model's own rows when
        the checkpoint holds fewer rows than the model.

        Note that *state_dict* is modified in place by the upgrade and the
        padding steps.

        Args:
            state_dict: checkpoint mapping; must contain both weight keys
                named above.
            strict: forwarded to the parent ``load_state_dict``.
            model_cfg: model config passed to ``prune_state_dict``.
            args: deprecated namespace alternative to ``model_cfg``; only
                used when ``model_cfg`` is ``None``.

        Returns:
            Whatever the parent class's ``load_state_dict`` returns.
        """

        if model_cfg is None and args is not None:
            # ``Logger.warn`` is a deprecated alias; use ``warning``.
            logger.warning("using 'args' is deprecated, please update your code to use dataclass config")
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        self._pad_vocab_rows(
            state_dict, "decoder.embed_tokens.weight", self.decoder.embed_tokens.weight
        )
        self._pad_vocab_rows(
            state_dict, "decoder.output_projection.weight", self.decoder.output_projection.weight
        )

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
79 | 
80 | 
@register_model_architecture("transformer_lm_prompt", "transformer_lm_prompt_biogpt")
def transformer_lm_prompt_biogpt(args):
    """Architecture ``transformer_lm_prompt_biogpt``: fills *args* by
    delegating to ``transformer_lm_gpt2_small``."""
    transformer_lm_gpt2_small(args)
84 | 
@register_model_architecture("transformer_lm_prompt", "transformer_lm_prompt_biogpt_large")
def transformer_lm_prompt_gpt2_big(args):
    """Architecture ``transformer_lm_prompt_biogpt_large``: fills *args* by
    delegating to ``transformer_lm_gpt2_big``."""
    transformer_lm_gpt2_big(args)


--------------------------------------------------------------------------------