├── .github └── workflows │ ├── pypi.yml │ ├── pythonapp.yml │ └── testpypi.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── cleanup_results.sh ├── config ├── resnet_config.json ├── transformer_config.json └── transformer_tiny_config.json ├── data_refs.bib ├── download_data.sh ├── download_data_aws.sh ├── environment.yml ├── examples ├── adding_model.py └── adding_task.py ├── gridsearch_config.json ├── mypy.ini ├── requirements.txt ├── scripts ├── fix_lmdb.py ├── generate_plots.py ├── lmdb_to_fasta.py ├── tfrecord_to_json.py └── tfrecord_to_lmdb.py ├── setup.py ├── tape ├── __init__.py ├── datasets.py ├── errors.py ├── main.py ├── metrics.py ├── models │ ├── __init__.py │ ├── file_utils.py │ ├── modeling_bert.py │ ├── modeling_lstm.py │ ├── modeling_onehot.py │ ├── modeling_resnet.py │ ├── modeling_trrosetta.py │ ├── modeling_unirep.py │ └── modeling_utils.py ├── optimization.py ├── registry.py ├── tokenizers.py ├── training.py ├── utils │ ├── __init__.py │ ├── _sampler.py │ ├── distributed_utils.py │ ├── setup_utils.py │ └── utils.py └── visualization.py ├── tests ├── test_basic.py └── test_forceDownload.py └── tox.ini /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload to PyPI 2 | 3 | # Controls when the action will run. 
4 | on: 5 | # Triggers the workflow when a release is created 6 | release: 7 | types: [created] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 13 | jobs: 14 | # This workflow contains a single job called "testpypi" 15 | testpypi: 16 | # The type of runner that the job will run on 17 | runs-on: ubuntu-latest 18 | 19 | # Steps represent a sequence of tasks that will be executed as part of the job 20 | steps: 21 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 22 | - uses: actions/checkout@v2 23 | 24 | # Sets up python3 25 | - uses: actions/setup-python@v2 26 | with: 27 | python-version: 3.8 28 | 29 | - name: "Installs dependencies" 30 | run: | 31 | python3 -m pip install --upgrade pip 32 | python3 -m pip install setuptools wheel twine 33 | 34 | # Upload to TestPyPI 35 | - name: Build and upload to PyPI 36 | run: | 37 | python3 setup.py sdist bdist_wheel 38 | python3 -m twine upload dist/* 39 | env: 40 | TWINE_USERNAME: __token__ 41 | TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} 42 | -------------------------------------------------------------------------------- /.github/workflows/pythonapp.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.7 17 | uses: actions/setup-python@v1 18 | with: 19 | python-version: 3.7 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install flake8 pytest torch 24 | pip install . 25 | - name: Lint with flake8 26 | run: | 27 | # stop the build if there are Python syntax errors or undefined names 28 | flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics 29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 30 | flake8 . --count --exit-zero --max-line-length=96 --statistics 31 | - name: Test with pytest 32 | run: | 33 | pytest 34 | -------------------------------------------------------------------------------- /.github/workflows/testpypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload to TestPyPI 2 | 3 | # Controls when the action will run. 4 | on: 5 | # Allows you to run this workflow manually from the Actions tab 6 | workflow_dispatch: 7 | 8 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 9 | jobs: 10 | # This workflow contains a single job called "testpypi" 11 | testpypi: 12 | # The type of runner that the job will run on 13 | runs-on: ubuntu-latest 14 | 15 | # Steps represent a sequence of tasks that will be executed as part of the job 16 | steps: 17 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 18 | - uses: actions/checkout@v2 19 | 20 | # Sets up python3 21 | - uses: actions/setup-python@v2 22 | with: 23 | python-version: 3.8 24 | 25 | - name: "Installs dependencies" 26 | run: | 27 | python3 -m pip install --upgrade pip 28 | python3 -m pip install setuptools wheel twine 29 | 30 | # Upload to TestPyPI 31 | - name: Build and upload to TestPyPI 32 | run: | 33 | python3 setup.py sdist bdist_wheel 34 | python3 -m twine upload dist/* 35 | env: 36 | TWINE_USERNAME: __token__ 37 | TWINE_PASSWORD: ${{ secrets.TWINE_TEST_TOKEN }} 38 | TWINE_REPOSITORY: testpypi 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | logs/ 3 | results/ 4 | pretrained_models/ 5 | wandb/ 6 | scratch/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | 
*$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | # vim gitignores 114 | *.swp 115 | *.swo 116 | *.npy 117 | *.pkl 118 | *.npz 119 | *.fasta 120 | *.lmdb 121 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Regents of the 
University of California 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Tasks Assessing Protein Embeddings (TAPE) 4 | 5 | ![](https://github.com/songlab-cal/tape/workflows/Build/badge.svg) 6 | 7 | Data, weights, and code for running the TAPE benchmark on a trained protein embedding. We provide a pretraining corpus, five supervised downstream tasks, pretrained language model weights, and benchmarking code. This code has been updated to use pytorch - as such previous pretrained model weights and code will not work. The previous tensorflow TAPE repository is still available at [https://github.com/songlab-cal/tape-neurips2019](https://github.com/songlab-cal/tape-neurips2019). 8 | 9 | This repository is *not* an effort to maintain maximum compatibility and reproducibility with the original paper, but is instead meant to facilitate ease of use and future development (both for us, and for the community). Although we provide much of the same functionality, we have not tested every aspect of training on all models/downstream tasks, and we have also made some deliberate changes. Therefore, if your goal is to reproduce the results from our paper, please use the original code. 10 | 11 | Our paper is available at [https://arxiv.org/abs/1906.08230](https://arxiv.org/abs/1906.08230). 12 | 13 | Some documentation is incomplete. We will try to fill it in over time, but if there is something you would like an explanation for, please open an issue so we know where to focus our effort! 14 | 15 | **Update 09/26/2020:** We no longer recommend trying to train directly with TAPE's training code. 
It will likely still work for some time, but will not be updated for future pytorch versions. Internally, we have been working with different frameworks for training (specifically Pytorch Lightning and Fairseq). We strongly recommend using a framework like these, as it offloads the requirement of maintaining compatibility with Pytorch versions. TAPE models will continue to be available, and if the code is working for you, feel free to use it. However we will not be fixing issues regarding multi-GPU errors, OOM errors, etc during training. 16 | 17 | ## Contents 18 | 19 | * [Installation](#installation) 20 | * [Examples](#examples) 21 | * [Huggingface API for Loading Pretrained Models](#huggingface-api-for-loading-pretrained-models) 22 | * [Embedding Proteins with a Pretrained Model](#embedding-proteins-with-a-pretrained-model) 23 | * [Training a Language Model](#training-a-language-model) 24 | * [Evaluating a Language Model](#evaluating-a-language-model) 25 | * [Training a Downstream Model](#training-a-downstream-model) 26 | * [Evaluating a Downstream Model](#evaluating-a-downstream-model) 27 | * [List of Models and Tasks](#list-of-models-and-tasks) 28 | * [Adding New Models and Tasks](#adding-new-models-and-tasks) 29 | * [Data](#data) 30 | * [LMDB Data](#lmdb-data) 31 | * [Raw Data](#raw-data) 32 | * [Leaderboard](#leaderboard) 33 | * [Secondary Structure](#secondary-structure) 34 | * [Contact Prediction](#contact-prediction) 35 | * [Remote Homology Detection](#remote-homology-detection) 36 | * [Fluorescence](#fluorescence) 37 | * [Stability](#stability) 38 | * [Citation Guidelines](#citation-guidelines) 39 | 40 | ## Installation 41 | 42 | We recommend that you install `tape` into a python [virtual environment](https://virtualenv.pypa.io/en/latest/) using 43 | 44 | ```bash 45 | $ pip install tape_proteins 46 | ``` 47 | 48 | ## Examples 49 | 50 | ### Huggingface API for Loading Pretrained Models 51 | 52 | We build on the excellent [huggingface 
repository](https://github.com/huggingface/transformers) and use this as an API to define models, as well as to provide pretrained models. By using this API, pretrained models will be automatically downloaded when necessary and cached for future use. 53 | 54 | ```python 55 | import torch 56 | from tape import ProteinBertModel, TAPETokenizer 57 | model = ProteinBertModel.from_pretrained('bert-base') 58 | tokenizer = TAPETokenizer(vocab='iupac') # iupac is the vocab for TAPE models, use unirep for the UniRep model 59 | 60 | # Pfam Family: Hexapep, Clan: CL0536 61 | sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ' 62 | token_ids = torch.tensor([tokenizer.encode(sequence)]) 63 | output = model(token_ids) 64 | sequence_output = output[0] 65 | pooled_output = output[1] 66 | 67 | # NOTE: pooled_output is *not* trained for the transformer, do not use 68 | # w/o fine-tuning. A better option for now is to simply take a mean of 69 | # the sequence output 70 | ``` 71 | 72 | Currently available pretrained models are: 73 | 74 | * bert-base (Transformer model) 75 | * babbler-1900 ([UniRep](https://www.biorxiv.org/content/10.1101/589333v1) model) 76 | * xaa, xab, xac, xad, xae ([trRosetta](https://www.pnas.org/content/117/3/1496) model) 77 | 78 | If there is a particular pretrained model that you would like to use, please open an issue and we will try to add it! 79 | 80 | ### Embedding Proteins with a Pretrained Model 81 | 82 | Given an input fasta file, you can generate a `.npz` file containing embedding proteins via the `tape-embed` command. 
83 | 84 | Suppose this is our input fasta file: 85 | 86 | ``` 87 | >seq1 88 | GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ 89 | >seq2 90 | RTIKVRILHAIGFEGGLMLLTIPMVAYAMDMTLFQAILLDLSMTTCILVYTFIFQWCYDILENR 91 | ``` 92 | 93 | Then we could embed it with the UniRep babbler-1900 model like so: 94 | 95 | ```bash 96 | tape-embed unirep my_input.fasta output_filename.npz babbler-1900 --tokenizer unirep 97 | ``` 98 | 99 | There is no need to download the pretrained model manually - it will be automatically downloaded if needed. In addition, note the change of tokenizer to the `unirep` tokenizer. UniRep uses a different vocabulary, and so requires this tokenizer. If you get a cublas runtime error, please double check that you changed tokenizer correctly. 100 | 101 | The embed function is fully batched and will automatically distribute across as many GPUs as the machine has available. On a Titan Xp, it can process around 200 sequences / second. 102 | 103 | Once we have the output file, we can load it into numpy like so: 104 | 105 | ```python 106 | arrays = np.load('output_filename.npz', allow_pickle=True) 107 | 108 | list(arrays.keys()) # Will output the name of the keys in your fasta file (or if unnamed then '0', '1', ...) 109 | 110 | arrays['seq1'] # Returns a dictionary with keys 'pooled' and 'avg', (or 'seq' if using the --full_sequence_embed flag) 111 | ``` 112 | 113 | By default to save memory TAPE returns the average of the sequence embedding along with the pooled embedding generated through the pooling function. For some models (like UniRep), the pooled embedding is trained, and so can be used out of the box. For other models (like the transformer), the pooled embedding is not trained, and so the average embedding should be used. We will be looking into methods of self-supervised training the pooled embedding for all models in the future. 
114 | 115 | If you would like the full embedding rather than the average embedding, this can be specified to `tape-embed` by passing the `--full_sequence_embed` flag. 116 | 117 | ### Training a Language Model 118 | 119 | Tape provides two commands for training, `tape-train` and `tape-train-distributed`. The first command uses standard pytorch data distribution to distribute across all available GPUs. The second one uses `torch.distributed.launch`-style multiprocessing to distribute across the number of specified GPUs (and could also be used for distributing across multiple nodes). We generally recommend using the second command, as it can provide a 10-15% speedup, but both will work. 120 | 121 | To train the transformer on masked language modeling, for example, you could run this 122 | 123 | ```bash 124 | tape-train-distributed transformer masked_language_modeling --batch_size BS --learning_rate LR --fp16 --warmup_steps WS --nproc_per_node NGPU --gradient_accumulation_steps NSTEPS 125 | ``` 126 | 127 | There are a number of features used in training: 128 | 129 | * Distributed training via multiprocessing 130 | * Half-precision training 131 | * Gradient accumulation 132 | * Gradient-allreduce post accumulation 133 | * Automatic batch by sequence length 134 | 135 | The first feature you are likely to need is the `gradient_accumulation_steps`. TAPE specifies a relatively high batch size (1024) by default. This is the batch size that will be used *per backwards pass*. This number will be divided by the number of GPUs as well as the gradient accumulation steps. So with a batch size of 1024, 2 GPUs, and 1 gradient accumulation step, you will do 512 examples per GPU. If you run out of memory (and you likely will), TAPE provides a clear error message and will tell you to increase the gradient accumulation steps. 136 | 137 | There are additional features as well that are not talked about here. See `tape-train-distributed --help` for a list of all commands. 
138 | 139 | ### Evaluating a Language Model 140 | 141 | Once you've trained a language model, you'll have a pretrained weight file located in the `results` folder. To evaluate this model, you can do one of two things. One option is to directly evaluate the language modeling accuracy / perplexity. `tape-train` will report the perplexity over the training and validation set at the end of each epoch. However, we find empirically that language modeling accuracy and perplexity are poor measures of performance on downstream tasks. Therefore, to evaluate the language model we strongly recommend training your model on one or all of our provided tasks. 142 | 143 | ### Training a Downstream Model 144 | 145 | Training a model on a downstream task can also be done with the `tape-train` command. Simply use the same syntax as with training a language model, adding the flag `--from_pretrained `. To train a pretrained transformer on secondary structure prediction, for example, you would run 146 | 147 | ```bash 148 | tape-train-distributed transformer secondary_structure \ 149 | --from_pretrained results/ \ 150 | --batch_size BS \ 151 | --learning_rate LR \ 152 | --fp16 \ 153 | --warmup_steps WS \ 154 | --nproc_per_node NGPU \ 155 | --gradient_accumulation_steps NSTEPS \ 156 | --num_train_epochs NEPOCH \ 157 | --eval_freq EF \ 158 | --save_freq SF 159 | ``` 160 | 161 | For training a downstream model, you will likely need to experiment with hyperparameters to achieve the best results (optimal hyperparameters vary per-task and per-model). The set of parameters to consider are 162 | 163 | ``` 164 | * Batch size 165 | * Learning rate 166 | * Warmup steps 167 | * Num train epochs 168 | ``` 169 | 170 | These can all have significant effects on performance, and by default are set to maximize performance on language modeling rather than downstream tasks. 
In addition the `eval_freq` and `save_freq` parameters can be useful, as they reduce the frequency of running validation passes and saving the model, respectively. Since downstream task epochs are much shorter (and you're likely to need more of them), it makes sense to increase these values so that training takes less time. 171 | 172 | ### Evaluating a Downstream Model 173 | 174 | To evaluate your downstream task model, we provide the `tape-eval` command. This command will output your model predictions along with a set of metrics that you specify. At the moment, we support mean squared error (`mse`), mean absolute error (`mae`), Spearman's rho (`spearmanr`), and accuracy (`accuracy`). Precision @ L/5 will be added shortly. 175 | 176 | The syntax for the command is 177 | 178 | ```bash 179 | tape-eval MODEL TASK TRAINED_MODEL_FOLDER --metrics METRIC1 METRIC2 ... 180 | ``` 181 | 182 | so to evaluate a transformer trained on secondary structure, we can run 183 | 184 | ```bash 185 | tape-eval transformer secondary_structure results/ --metrics accuracy 186 | ``` 187 | 188 | This will report the overall accuracy, and will also dump a `results.pkl` file into the trained model directory for you to analyze however you like. 189 | 190 | ### trRosetta 191 | 192 | We have recently re-implemented the trRosetta model from Yang et al. (2020). A link to the original repository, which was used as a basis for this re-implementation, can be found [here](https://github.com/gjoni/trRosetta). We provide a pytorch implementation and dataset to allow you to play around with the model. Data is available [here](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/trrosetta.tar.gz). This is the same as the data in the original paper, however we've added train / val split files to allow you to train your own model reproducibly. 
To use this model 193 | 194 | ```python 195 | from tape import TRRosetta 196 | from tape.datasets import TRRosettaDataset 197 | 198 | # Download data and place it under `<data_path>/trrosetta` 199 | 200 | train_data = TRRosettaDataset('<data_path>', 'train') # will subsample MSAs 201 | valid_data = TRRosettaDataset('<data_path>', 'valid') # will not subsample MSAs 202 | 203 | model = TRRosetta.from_pretrained('xaa') # valid choices are 'xaa', 'xab', 'xac', 'xad', 'xae'. Each corresponds to one of the ensemble models. 204 | 205 | batch = train_data.collate_fn([train_data[0]]) 206 | loss, predictions = model(**batch) 207 | ``` 208 | 209 | The predictions can be saved as `.npz` files and then fed into the [structure modeling scripts](https://yanglab.nankai.edu.cn/trRosetta/download/) provided by the Yang Lab. 210 | 211 | 212 | ### List of Models and Tasks 213 | 214 | The available models are: 215 | 216 | - `transformer` (pretrained available) 217 | - `resnet` 218 | - `lstm` 219 | - `unirep` (pretrained available) 220 | - `onehot` (no pretraining required) 221 | - `trrosetta` (pretrained available) 222 | 223 | The available standard tasks are: 224 | 225 | - `language_modeling` 226 | - `masked_language_modeling` 227 | - `secondary_structure` 228 | - `contact_prediction` 229 | - `remote_homology` 230 | - `fluorescence` 231 | - `stability` 232 | - `trrosetta` (can only be used with `trrosetta` model) 233 | 234 | The available models and tasks can be found in `tape/datasets.py` and `tape/models/modeling*.py`. 235 | 236 | ### Adding New Models and Tasks 237 | 238 | We have made some efforts to make the new repository easier to understand and extend. See the `examples` folder for an example on how to add a new model and a new task to TAPE. If there are other examples you would like or if there is something missing in the current examples, please open an issue. 239 | 240 | ## Data 241 | Data should be placed in the `./data` folder, although you may also specify a different data directory if you wish. 
242 | 243 | The supervised data is around 120MB compressed and 2GB uncompressed. 244 | The unsupervised Pfam dataset is around 7GB compressed and 19GB uncompressed. The data for training is hosted on AWS. By default we provide data as LMDB - see `tape/datasets.py` for examples on loading the data. If you wish to download all of TAPE, run `download_data.sh` to do so. We also provide links to each individual dataset below in both LMDB format and JSON format. 245 | 246 | ### LMDB Data 247 | 248 | [Pretraining Corpus (Pfam)](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/pfam.tar.gz) __|__ [Secondary Structure](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz) __|__ [Contact (ProteinNet)](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/proteinnet.tar.gz) __|__ [Remote Homology](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/remote_homology.tar.gz) __|__ [Fluorescence](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/fluorescence.tar.gz) __|__ [Stability](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz) 249 | 250 | ### Raw Data 251 | 252 | Raw data files are stored in JSON format for maximum portability. This data is JSON-ified, which removes certain constructs (in particular numpy arrays). As a result they cannot be directly loaded into the provided pytorch datasets (although the conversion should be quite easy by simply adding calls to `np.array`). 
253 | 254 | [Pretraining Corpus (Pfam)](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/pfam.tar.gz) __|__ [Secondary Structure](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/secondary_structure.tar.gz) __|__ [Contact (ProteinNet)](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/proteinnet.tar.gz) __|__ [Remote Homology](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/remote_homology.tar.gz) __|__ [Fluorescence](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/fluorescence.tar.gz) __|__ [Stability](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/stability.tar.gz) 255 | 256 | 257 | ## Leaderboard 258 | 259 | We will soon have a leaderboard available for tracking progress on the core five TAPE tasks, so check back for a link here. See the main tables in our paper for a sense of where performance stands at this point. Publication on the leaderboard will be contingent on meeting the following citation guidelines. 260 | 261 | In the meantime, here's a temporary leaderboard for each task. All reported models on this leaderboard use unsupervised pretraining. 262 | 263 | ### Secondary Structure 264 | 265 | | Ranking | Model | Accuracy (3-class) | 266 | |:-:|:-:|:-:| 267 | | 1. | One Hot + Alignment | 0.80 | 268 | | 2. | LSTM | 0.75 | 269 | | 2. | ResNet | 0.75 | 270 | | 4. | Transformer | 0.73 | 271 | | 4. | Bepler | 0.73 | 272 | | 4. | Unirep | 0.73 | 273 | | 7. | One Hot | 0.69 | 274 | 275 | ### Contact Prediction 276 | 277 | | Ranking | Model | L/5 Medium + Long Range | 278 | |:-:|:-:|:-:| 279 | | 1. | One Hot + Alignment | 0.64 | 280 | | 2. | Bepler | 0.40 | 281 | | 3. | LSTM | 0.39 | 282 | | 4. | Transformer | 0.36 | 283 | | 5. | Unirep | 0.34 | 284 | | 6. | ResNet | 0.29 | 285 | | 6. | One Hot | 0.29 | 286 | 287 | ### Remote Homology Detection 288 | 289 | | Ranking | Model | Top 1 Accuracy | 290 | |:-:|:-:|:-:| 291 | | 1. | LSTM | 0.26 | 292 | | 2. 
| Unirep | 0.23 | 293 | | 3. | Transformer | 0.21 | 294 | | 4. | Bepler | 0.17 | 295 | | 4. | ResNet | 0.17 | 296 | | 6. | One Hot + Alignment | 0.09 | 297 | | 6. | One Hot | 0.09 | 298 | 299 | ### Fluorescence 300 | 301 | | Ranking | Model | Spearman's rho | 302 | |:-:|:-:|:-:| 303 | | 1. | Transformer | 0.68 | 304 | | 2. | LSTM | 0.67 | 305 | | 2. | Unirep | 0.67 | 306 | | 4. | Bepler | 0.33 | 307 | | 5. | ResNet | 0.21 | 308 | | 6. | One Hot | 0.14 | 309 | 310 | ### Stability 311 | 312 | | Ranking | Model | Spearman's rho | 313 | |:-:|:-:|:-:| 314 | | 1. | Transformer | 0.73 | 315 | | 1. | Unirep | 0.73 | 316 | | 1. | ResNet | 0.73 | 317 | | 4. | LSTM | 0.69 | 318 | | 5. | Bepler | 0.64 | 319 | | 6. | One Hot | 0.19 | 320 | 321 | ## Citation Guidelines 322 | 323 | If you find TAPE useful, please cite our corresponding paper. Additionally, __anyone using the datasets provided in TAPE must describe and cite all dataset components they use__. Producing these data is time and resource intensive, and we insist this be recognized by all TAPE users. For convenience,`data_refs.bib` contains all necessary citations. We also provide each individual citation below. 
324 | 325 | __TAPE (Our paper):__ 326 | ``` 327 | @inproceedings{tape2019, 328 | author = {Rao, Roshan and Bhattacharya, Nicholas and Thomas, Neil and Duan, Yan and Chen, Xi and Canny, John and Abbeel, Pieter and Song, Yun S}, 329 | title = {Evaluating Protein Transfer Learning with TAPE}, 330 | booktitle = {Advances in Neural Information Processing Systems}, 331 | year = {2019} 332 | } 333 | ``` 334 | 335 | __Pfam (Pretraining):__ 336 | ``` 337 | @article{pfam, 338 | author = {El-Gebali, Sara and Mistry, Jaina and Bateman, Alex and Eddy, Sean R and Luciani, Aur{\'{e}}lien and Potter, Simon C and Qureshi, Matloob and Richardson, Lorna J and Salazar, Gustavo A and Smart, Alfredo and Sonnhammer, Erik L L and Hirsh, Layla and Paladin, Lisanna and Piovesan, Damiano and Tosatto, Silvio C E and Finn, Robert D}, 339 | doi = {10.1093/nar/gky995}, 340 | file = {::}, 341 | issn = {0305-1048}, 342 | journal = {Nucleic Acids Research}, 343 | keywords = {community,protein domains,tandem repeat sequences}, 344 | number = {D1}, 345 | pages = {D427--D432}, 346 | publisher = {Narnia}, 347 | title = {{The Pfam protein families database in 2019}}, 348 | url = {https://academic.oup.com/nar/article/47/D1/D427/5144153}, 349 | volume = {47}, 350 | year = {2019} 351 | } 352 | ``` 353 | __SCOPe: (Remote Homology and Contact)__ 354 | ``` 355 | @article{scop, 356 | title={SCOPe: Structural Classification of Proteins—extended, integrating SCOP and ASTRAL data and classification of new structures}, 357 | author={Fox, Naomi K and Brenner, Steven E and Chandonia, John-Marc}, 358 | journal={Nucleic acids research}, 359 | volume={42}, 360 | number={D1}, 361 | pages={D304--D309}, 362 | year={2013}, 363 | publisher={Oxford University Press} 364 | } 365 | ``` 366 | __PDB: (Secondary Structure and Contact)__ 367 | ``` 368 | @article{pdb, 369 | title={The protein data bank}, 370 | author={Berman, Helen M and Westbrook, John and Feng, Zukang and Gilliland, Gary and Bhat, Talapady N and Weissig, Helge 
and Shindyalov, Ilya N and Bourne, Philip E}, 371 | journal={Nucleic acids research}, 372 | volume={28}, 373 | number={1}, 374 | pages={235--242}, 375 | year={2000}, 376 | publisher={Oxford University Press} 377 | } 378 | ``` 379 | 380 | __CASP12: (Secondary Structure and Contact)__ 381 | ``` 382 | @article{casp, 383 | author = {Moult, John and Fidelis, Krzysztof and Kryshtafovych, Andriy and Schwede, Torsten and Tramontano, Anna}, 384 | doi = {10.1002/prot.25415}, 385 | issn = {08873585}, 386 | journal = {Proteins: Structure, Function, and Bioinformatics}, 387 | keywords = {CASP,community wide experiment,protein structure prediction}, 388 | pages = {7--15}, 389 | publisher = {John Wiley {\&} Sons, Ltd}, 390 | title = {{Critical assessment of methods of protein structure prediction (CASP)-Round XII}}, 391 | url = {http://doi.wiley.com/10.1002/prot.25415}, 392 | volume = {86}, 393 | year = {2018} 394 | } 395 | ``` 396 | 397 | __NetSurfP2.0: (Secondary Structure)__ 398 | ``` 399 | @article{netsurfp, 400 | title={NetSurfP-2.0: Improved prediction of protein structural features by integrated deep learning}, 401 | author={Klausen, Michael Schantz and Jespersen, Martin Closter and Nielsen, Henrik and Jensen, Kamilla Kjaergaard and Jurtz, Vanessa Isabell and Soenderby, Casper Kaae and Sommer, Morten Otto Alexander and Winther, Ole and Nielsen, Morten and Petersen, Bent and others}, 402 | journal={Proteins: Structure, Function, and Bioinformatics}, 403 | year={2019}, 404 | publisher={Wiley Online Library} 405 | } 406 | ``` 407 | 408 | __ProteinNet: (Contact)__ 409 | ``` 410 | @article{proteinnet, 411 | title={ProteinNet: a standardized data set for machine learning of protein structure}, 412 | author={AlQuraishi, Mohammed}, 413 | journal={arXiv preprint arXiv:1902.00249}, 414 | year={2019} 415 | } 416 | ``` 417 | 418 | __Fluorescence:__ 419 | ``` 420 | @article{sarkisyan2016, 421 | title={Local fitness landscape of the green fluorescent protein}, 422 | author={Sarkisyan, 
#!/bin/bash
#
# Remove result directories that contain no saved model, along with the
# matching log directories.
#
# An entry under results/ is deleted when nothing inside it matches
# pytorch_model*. For every such entry both results/<name> and logs/<name>
# are removed, and each rm command is echoed before it runs.

for f in results/*; do
    # If results/ is missing or empty the glob is left as the literal
    # string "results/*"; skip it instead of acting on a non-existent path.
    [[ -e "$f" || -L "$f" ]] || continue

    # compgen -G prints every path matching the pattern; empty output means
    # the run never saved a checkpoint.
    if [[ -z "$(compgen -G "$f/pytorch_model*")" ]]; then
        run_name="$(basename -- "$f")"

        # Quote every expansion so names containing spaces are not split
        # into several rm arguments.
        echo rm -rf "$f"
        rm -rf -- "$f"
        echo rm -rf "logs/$run_name"
        rm -rf -- "logs/$run_name"
    fi
done
"hidden_dropout_prob": 0.1, 6 | "hidden_size": 64, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 256, 9 | "max_position_embeddings": 8192, 10 | "num_attention_heads": 8, 11 | "num_hidden_layers": 6, 12 | "type_vocab_size": 1, 13 | "vocab_size": 8000, 14 | "layer_norm_eps": 1e-12 15 | } 16 | -------------------------------------------------------------------------------- /data_refs.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{tape, 2 | author = {Rao, Roshan and Bhattacharya, Nicholas and Thomas, Neil and Duan, Yan and Chen, Xi and Canny, John and Abbeel, Pieter and Song, Yun S}, 3 | title = {Evaluating Protein Transfer Learning with TAPE}, 4 | booktitle = {Advances in Neural Information Processing Systems} 5 | year = {2019} 6 | } 7 | @article{pfam, 8 | author = {El-Gebali, Sara and Mistry, Jaina and Bateman, Alex and Eddy, Sean R and Luciani, Aur{\'{e}}lien and Potter, Simon C and Qureshi, Matloob and Richardson, Lorna J and Salazar, Gustavo A and Smart, Alfredo and Sonnhammer, Erik L L and Hirsh, Layla and Paladin, Lisanna and Piovesan, Damiano and Tosatto, Silvio C E and Finn, Robert D}, 9 | doi = {10.1093/nar/gky995}, 10 | file = {::}, 11 | issn = {0305-1048}, 12 | journal = {Nucleic Acids Research}, 13 | keywords = {community,protein domains,tandem repeat sequences}, 14 | number = {D1}, 15 | pages = {D427--D432}, 16 | publisher = {Narnia}, 17 | title = {{The Pfam protein families database in 2019}}, 18 | url = {https://academic.oup.com/nar/article/47/D1/D427/5144153}, 19 | volume = {47}, 20 | year = {2019} 21 | } 22 | @article{scope, 23 | title={SCOPe: Structural Classification of Proteins—extended, integrating SCOP and ASTRAL data and classification of new structures}, 24 | author={Fox, Naomi K and Brenner, Steven E and Chandonia, John-Marc}, 25 | journal={Nucleic acids research}, 26 | volume={42}, 27 | number={D1}, 28 | pages={D304--D309}, 29 | year={2013}, 30 | publisher={Oxford University 
Press} 31 | } 32 | @article{pdb, 33 | title={The protein data bank}, 34 | author={Berman, Helen M and Westbrook, John and Feng, Zukang and Gilliland, Gary and Bhat, Talapady N and Weissig, Helge and Shindyalov, Ilya N and Bourne, Philip E}, 35 | journal={Nucleic acids research}, 36 | volume={28}, 37 | number={1}, 38 | pages={235--242}, 39 | year={2000}, 40 | publisher={Oxford University Press} 41 | } 42 | @article{casp12, 43 | author = {Moult, John and Fidelis, Krzysztof and Kryshtafovych, Andriy and Schwede, Torsten and Tramontano, Anna}, 44 | doi = {10.1002/prot.25415}, 45 | issn = {08873585}, 46 | journal = {Proteins: Structure, Function, and Bioinformatics}, 47 | keywords = {CASP,community wide experiment,protein structure prediction}, 48 | pages = {7--15}, 49 | publisher = {John Wiley {\&} Sons, Ltd}, 50 | title = {{Critical assessment of methods of protein structure prediction (CASP)-Round XII}}, 51 | url = {http://doi.wiley.com/10.1002/prot.25415}, 52 | volume = {86}, 53 | year = {2018} 54 | } 55 | @article{proteinnet, 56 | title={ProteinNet: a standardized data set for machine learning of protein structure}, 57 | author={AlQuraishi, Mohammed}, 58 | journal={arXiv preprint arXiv:1902.00249}, 59 | year={2019} 60 | } 61 | @article{fluorescence-sarkisyan, 62 | title={Local fitness landscape of the green fluorescent protein}, 63 | author={Sarkisyan, Karen S and Bolotin, Dmitry A and Meer, Margarita V and Usmanova, Dinara R and Mishin, Alexander S and Sharonov, George V and Ivankov, Dmitry N and Bozhanova, Nina G and Baranov, Mikhail S and Soylemez, Onuralp and others}, 64 | journal={Nature}, 65 | volume={533}, 66 | number={7603}, 67 | pages={397}, 68 | year={2016}, 69 | publisher={Nature Publishing Group} 70 | } 71 | @article{stability-rocklin, 72 | title={Global analysis of protein folding using massively parallel design, synthesis, and testing}, 73 | author={Rocklin, Gabriel J and Chidyausiku, Tamuka M and Goreshnik, Inna and Ford, Alex and Houliston, Scott 
#!/bin/bash
#
# Download the TAPE data over HTTP and unpack it into ./data.
#
# Steps:
#   1. Optionally fetch and extract the Pfam pretraining corpus (asks first).
#   2. Fetch the pfam.model / pfam.vocab files into ./data.
#   3. Fetch, extract and clean up the archive of each downstream task.
#
# Every download is chained with `&&` so that a failed wget does not cascade
# into tar/mv/rm errors on a missing file (same pattern as download_data_aws.sh).

BASE_URL="http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch"

mkdir -p ./data

# Download pfam
while true; do
    read -r -p "Do you wish to download and unzip the pretraining corpus? It is 7.7GB compressed and 19GB uncompressed. [y/n]" yn
    case $yn in
        [Yy]* )
            wget "$BASE_URL/pfam.tar.gz" \
                && tar -xzf pfam.tar.gz -C ./data \
                && rm pfam.tar.gz
            break;;
        # NOTE(review): answering "no" exits the whole script, so the
        # vocab files and task datasets below are skipped as well. Confirm
        # this is intended; `break` would skip only the pretraining corpus.
        [Nn]* ) exit;;
        * ) echo "Please answer yes (Y/y) or no (N/n).";;
    esac
done

# Download Vocab/Model files
for name in pfam.model pfam.vocab; do
    wget "$BASE_URL/$name" && mv "$name" data
done

# Download Data Files
for task in secondary_structure proteinnet remote_homology fluorescence stability; do
    archive="$task.tar.gz"
    wget "$BASE_URL/$archive" \
        && tar -xzf "$archive" -C ./data \
        && rm "$archive"
done
[y/n]" yn 6 | case $yn in 7 | [Yy]* ) aws s3 cp s3://songlabdata/proteindata/data_pytorch/pfam.tar.gz .; tar -xzf pfam.tar.gz -C ./data; rm pfam.tar.gz; break;; 8 | [Nn]* ) exit;; 9 | * ) echo "Please answer yes (Y/y) or no (N/n).";; 10 | esac 11 | done 12 | 13 | echo "Downloading BPE Vocab/Model files" 14 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/pfam.model . && mv pfam.model data 15 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/pfam.vocab . && mv pfam.vocab data 16 | 17 | # Download Data Files 18 | echo "Download TAPE task datasets" 19 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz . \ 20 | && tar -xzf secondary_structure.tar.gz -C ./data \ 21 | && rm secondary_structure.tar.gz 22 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/proteinnet.tar.gz . \ 23 | && tar -xzf proteinnet.tar.gz -C ./data \ 24 | && rm proteinnet.tar.gz 25 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/remote_homology.tar.gz . \ 26 | && tar -xzf remote_homology.tar.gz -C ./data \ 27 | && rm remote_homology.tar.gz 28 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/fluorescence.tar.gz . \ 29 | && tar -xzf fluorescence.tar.gz -C ./data \ 30 | && rm fluorescence.tar.gz 31 | aws s3 cp s3://songlabdata/proteindata/data_pytorch/stability.tar.gz . 
\ 32 | && tar -xzf stability.tar.gz -C ./data \ 33 | && rm stability.tar.gz 34 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tape 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - backcall=0.1.0=py37_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2019.5.15=1 10 | - certifi=2019.6.16=py37_1 11 | - cffi=1.12.3=py37h2e261b9_0 12 | - cudatoolkit=9.0=h13b8566_0 13 | - cython=0.29.13=py37he6710b0_0 14 | - decorator=4.4.0=py37_1 15 | - freetype=2.9.1=h8a8886c_1 16 | - intel-openmp=2019.4=243 17 | - ipython=7.8.0=py37h39e3cac_0 18 | - ipython_genutils=0.2.0=py37_0 19 | - jedi=0.15.1=py37_0 20 | - jpeg=9b=h024ee3a_2 21 | - libedit=3.1.20181209=hc058e9b_0 22 | - libffi=3.2.1=hd88cf55_4 23 | - libgcc-ng=9.1.0=hdf63c60_0 24 | - libgfortran-ng=7.3.0=hdf63c60_0 25 | - libpng=1.6.37=hbc83047_0 26 | - libstdcxx-ng=9.1.0=hdf63c60_0 27 | - libtiff=4.0.10=h2733197_2 28 | - mkl=2019.4=243 29 | - mkl-service=2.0.2=py37h7b6447c_0 30 | - mkl_fft=1.0.14=py37ha843d7b_0 31 | - mkl_random=1.0.2=py37hd81dba3_0 32 | - ncurses=6.1=he6710b0_1 33 | - ninja=1.9.0=py37hfd86e86_0 34 | - numpy=1.16.4=py37h7e9f1db_0 35 | - numpy-base=1.16.4=py37hde5b4d6_0 36 | - olefile=0.46=py37_0 37 | - openssl=1.1.1c=h7b6447c_1 38 | - parso=0.5.1=py_0 39 | - pexpect=4.7.0=py37_0 40 | - pickleshare=0.7.5=py37_0 41 | - pillow=6.1.0=py37h34e0f95_0 42 | - pip=19.2.2=py37_0 43 | - prompt_toolkit=2.0.9=py37_0 44 | - ptyprocess=0.6.0=py37_0 45 | - pycparser=2.19=py37_0 46 | - pygments=2.4.2=py_0 47 | - python=3.7.4=h265db76_1 48 | - pytorch=1.1.0=py3.7_cuda9.0.176_cudnn7.5.1_0 49 | - readline=7.0=h7b6447c_5 50 | - setuptools=41.0.1=py37_0 51 | - six=1.12.0=py37_0 52 | - sqlite=3.29.0=h7b6447c_0 53 | - tk=8.6.8=hbc83047_0 54 | - torchvision=0.3.0=py37_cu9.0.176_1 55 | - traitlets=4.3.2=py37_0 56 | - wcwidth=0.1.7=py37_0 57 | - 
wheel=0.33.4=py37_0 58 | - xz=5.2.4=h14c3975_4 59 | - zlib=1.2.11=h7b6447c_3 60 | - zstd=1.3.7=h0b5b093_0 61 | - pip: 62 | - absl-py==0.8.0 63 | - apex==0.1 64 | - atomicwrites==1.3.0 65 | - attrs==19.1.0 66 | - boto3==1.9.221 67 | - botocore==1.12.221 68 | - chardet==3.0.4 69 | - click==7.0 70 | - docutils==0.15.2 71 | - grpcio==1.23.0 72 | - idna==2.8 73 | - importlib-metadata==0.20 74 | - ipdb==0.12.2 75 | - jmespath==0.9.4 76 | - joblib==0.13.2 77 | - lmdb==0.97 78 | - markdown==3.1.1 79 | - more-itertools==7.2.0 80 | - packaging==19.1 81 | - pluggy==0.12.0 82 | - protobuf==3.9.1 83 | - py==1.8.0 84 | - pyparsing==2.4.2 85 | - pytest==5.1.2 86 | - python-dateutil==2.8.0 87 | - pytorch-transformers==1.2.0 88 | - regex==2019.8.19 89 | - requests==2.22.0 90 | - s3transfer==0.2.1 91 | - sacremoses==0.0.33 92 | - sentencepiece==0.1.83 93 | - tensorboard==1.14.0 94 | - tensorboardx==1.8 95 | - tqdm==4.35.0 96 | - urllib3==1.25.3 97 | - werkzeug==0.15.5 98 | - zipp==0.6.0 99 | prefix: /home/rmrao/miniconda3/envs/tape 100 | 101 | -------------------------------------------------------------------------------- /examples/adding_model.py: -------------------------------------------------------------------------------- 1 | """Example of how to add a model in tape. 2 | 3 | This file shows an example of how to add a new model to the tape training 4 | pipeline. tape models follow the huggingface API and so require: 5 | 6 | - A config class 7 | - An abstract model class 8 | - A model class to output sequence and pooled embeddings 9 | - Task-specific classes for each individual task 10 | 11 | This will walkthrough how to create each of these, with a task-specific class for 12 | secondary structure prediction. You can look at the other task-specific classes 13 | defined in e.g. tape/models/modeling_bert.py for examples on how to 14 | define these other task-specific models for e.g. contact prediction or fluorescence 15 | prediction. 
class SimpleConvConfig(ProteinConfig):
    """Configuration object for the example convolutional model.

    Subclasses ProteinConfig and simply records the hyperparameters the
    model needs as attributes. Any extra keyword arguments are forwarded
    untouched to ProteinConfig.

    Every argument has a default so that the config can be built without
    a model config file.

    Args:
        vocab_size: Number of distinct input symbols (embedding rows).
        filter_size: Channel width used by the embedding and every conv layer.
        kernel_size: Width of each 1D convolution kernel.
        num_layers: How many convolution layers the encoder stacks.
        **kwargs: Passed through to ``ProteinConfig.__init__``.
    """

    def __init__(
        self,
        vocab_size: int = 30,
        filter_size: int = 128,
        kernel_size: int = 5,
        num_layers: int = 3,
        **kwargs,
    ):
        # Let the base class consume the shared options first.
        super().__init__(**kwargs)

        # Model-specific hyperparameters, stored verbatim.
        self.vocab_size = vocab_size
        self.filter_size = filter_size
        self.kernel_size = kernel_size
        self.num_layers = num_layers
Note that there is a little 65 | more machinery in the models we define, but this is a stripped down 66 | version that should give you what you need 67 | """ 68 | # init expects only a single argument - the config 69 | def __init__(self, config: SimpleConvConfig): 70 | super().__init__(config) 71 | self.embedding = nn.Embedding(config.vocab_size, config.filter_size) 72 | self.encoder = nn.Sequential( 73 | *[nn.Conv1d(config.filter_size, config.filter_size, config.kernel_size, 74 | padding=config.kernel_size // 2) 75 | for _ in range(config.num_layers)]) 76 | 77 | self.pooler = nn.AdaptiveAvgPool1d(1) 78 | 79 | def forward(self, input_ids, input_mask=None): 80 | """ Runs the forward model pass 81 | 82 | Args: 83 | input_ids (Tensor[long]): 84 | Tensor of input symbols of shape [batch_size x protein_length] 85 | input_mask (Tensor[bool]): 86 | Tensor of booleans w/ same shape as input_ids, indicating whether 87 | a given sequence position is valid 88 | 89 | Returns: 90 | sequence_embedding (Tensor[float]): 91 | Embedded sequence of shape [batch_size x protein_length x hidden_size] 92 | pooled_embedding (Tensor[float]): 93 | Pooled representation of the entire sequence of size [batch_size x hidden_size] 94 | """ 95 | 96 | # Embed the input_ids 97 | embed = self.embedding(input_ids) 98 | 99 | # Pass embeddings through the encoder - you may want to use 100 | # the input mask here (not used in this example, but is generally 101 | # used in most of our models). 
# Registering the class under (task, model name) is what makes it usable
# from TAPE's training machinery.
@registry.register_task_model('secondary_structure', 'simple-conv')
class SimpleConvForSequenceToSequenceClassification(SimpleConvAbstractModel):
    """Per-residue classifier built on top of SimpleConvModel."""

    def __init__(self, config: SimpleConvConfig):
        super().__init__(config)
        # The attribute name has to be identical to base_model_prefix.
        self.simple_conv = SimpleConvModel(config)
        # Sequence-to-sequence head: input width is the embedding size
        # produced by SimpleConvModel, output width is config.num_labels
        # (available on every ProteinConfig, used for classification tasks).
        self.classify = SequenceToSequenceClassificationHead(
            config.filter_size, config.num_labels)

    def forward(self, input_ids, input_mask=None, targets=None):
        """Run the model and, when targets are supplied, also score it.

        The label argument must be called `targets`; other tasks use
        other label names (see the models defined in tape for examples).

        Args:
            input_ids (Tensor[long]):
                Input symbols, shape [batch_size x protein_length]
            input_mask (Tensor[bool]):
                Same shape as input_ids, marks which positions are valid
            targets (Tensor[long], optional):
                Target labels, shape [batch_size x protein_length];
                positions equal to -1 are ignored

        Returns:
            ``(prediction,)`` without targets, otherwise
            ``((loss, metrics), prediction)`` where metrics is ``{'acc': ...}``.
        """
        sequence_embedding = self.simple_conv(input_ids, input_mask)[0]

        # NOTE(review): the code below treats the head's output as a single
        # tensor of per-position logits - confirm against the head class.
        prediction = self.classify(sequence_embedding)

        if targets is None:
            return (prediction,)

        criterion = nn.CrossEntropyLoss(ignore_index=-1)
        num_classes = prediction.size(2)
        loss = criterion(prediction.view(-1, num_classes), targets.view(-1))

        # argmax is taken on a float copy b/c float16 has no argmax support
        hits = prediction.float().argmax(-1) == targets
        valid = targets != -1

        # both sums are cast to float so the ratio is not an integer division
        accuracy = torch.sum(hits * valid).float() / torch.sum(valid).float()

        return ((loss, {'acc': accuracy}), prediction)
# Registering the dataset creates the new TAPE task. num_labels tells TAPE
# how many classes downstream models must predict; a non-classification
# task would omit that argument.
@registry.register_task('secondary_structure_8', num_labels=8)
class SecondaryStructure8ClassDataset(Dataset):
    """ Dataset for 8-class secondary structure prediction.

    Args:
        data_path (Union[str, Path]): Path to the tape data directory. By default
            this is assumed to be `./data`; it can be changed on the command line
            with the --data_dir flag.
        split (str): Which dataset split to load. Must be one of 'train', 'valid',
            'casp12', 'ts115' or 'cb513' (secondary structure has three test
            datasets, each with its own split name).
        tokenizer (Union[str, TAPETokenizer]): Tokenizer, or the name of the
            vocabulary to build one from, used to turn sequences into token ids.
        in_memory (bool): Whether to load the entire dataset into memory or to keep
            it on disk.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        known_splits = ('train', 'valid', 'casp12', 'ts115', 'cb513')
        if split not in known_splits:
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'valid', 'casp12', "
                             f"'ts115', 'cb513']")

        # A vocabulary name is turned into a real tokenizer; an existing
        # tokenizer object is used as-is.
        self.tokenizer = (
            TAPETokenizer(vocab=tokenizer) if isinstance(tokenizer, str) else tokenizer
        )

        # tape.datasets ships helper datasets (FastaDataset, JSONDataset,
        # LMDBDataset) for reading raw files; the secondary structure
        # splits are stored as LMDB.
        lmdb_path = Path(data_path) / f'secondary_structure/secondary_structure_{split}.lmdb'
        self.data = LMDBDataset(lmdb_path, in_memory=in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        """ Fetch one example as ``(token_ids, input_mask, labels)``.

        The underlying LMDBDataset yields a dictionary; the amino acid
        sequence under 'primary' is tokenized and the 'ss8' labels are
        converted to an int64 array.
        """
        record = self.data[index]

        # tokenize + convert to numpy
        token_ids = self.tokenizer.encode(record['primary'])

        # attention mask: all ones here, zero-padded later in collate_fn
        input_mask = np.ones_like(token_ids)

        # one -1 on each side because of the cls/sep tokens
        labels = np.pad(
            np.asarray(record['ss8'], np.int64),
            (1, 1), 'constant', constant_values=-1)

        return token_ids, input_mask, labels

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """ Turn a list of __getitem__ outputs into a batch of torch tensors.

        Sequences have different lengths, so each field is padded with the
        `pad_sequences` helper: token ids and mask with 0, classification
        labels with -1.
        """
        token_batch, mask_batch, label_batch = zip(*batch)

        return {
            'input_ids': torch.from_numpy(pad_sequences(token_batch, 0)),
            'input_mask': torch.from_numpy(pad_sequences(mask_batch, 0)),
            'targets': torch.from_numpy(pad_sequences(label_batch, -1)),
        }
128 | """ 129 | from tape.main import run_train 130 | run_train() 131 | -------------------------------------------------------------------------------- /gridsearch_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "transformer", 3 | "model_config_file": null, 4 | "data_dir": "./data", 5 | "vocab_file": "data/pfam.model", 6 | "output_dir": "./results", 7 | "no_cuda": false, 8 | "local_rank": -1, 9 | "seed": 42, 10 | "tokenizer": "bpe", 11 | "num_workers": 16, 12 | "debug": false, 13 | "task": "secondary_structure", 14 | "learning_rate": [1e-5, 1e-4], 15 | "batch_size": [64, 128, 256], 16 | "num_train_epochs": 30, 17 | "patience": 5, 18 | "num_log_iter": 20, 19 | "fp16": false, 20 | "warmup_steps": 1000, 21 | "gradient_accumulation_steps": 16, 22 | "lr_schedule": "warmup_constant", 23 | "loss_scale": 0, 24 | "max_grad_norm": 1.0, 25 | "from_pretrained": "pretrained_models/bert_base_pretrain_pfam_tokenized/", 26 | "log_dir": "./logs", 27 | "nnodes": 1, 28 | "node_rank": 0, 29 | "nproc_per_node": 3, 30 | "master_addr": "127.0.0.1", 31 | "master_port": 29500, 32 | "no_eval": false, 33 | "save_freq": "improvement" 34 | } 35 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | tensorboardX 3 | scipy 4 | lmdb 5 | boto3 6 | requests 7 | biopython 8 | filelock 9 | -------------------------------------------------------------------------------- /scripts/fix_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | from tqdm import tqdm 3 | import pickle as pkl 4 | from pathlib import Path 5 | from 
def _false_fraction(counts, label):
    """Return count[(label, False)] divided by all entries counted for *label*."""
    negatives = counts[(label, False)]
    total = negatives + counts[(label, True)]
    # A class that never occurs has no defined ratio; report 0.0 rather
    # than raising ZeroDivisionError.
    return negatives / total if total else 0.0


def plot_rects(true_overlap, rand_overlap, savefile=None):
    """Draw a grouped bar chart comparing two sets of overlap observations.

    Each input is an iterable of hashable ``(class_index, flag)`` pairs, with
    ``class_index`` in ``0..2`` and ``flag`` a bool. For every class the bar
    height is the fraction of that class's pairs whose flag is ``False``.
    NOTE(review): the chart labels this fraction "Percent Agreement";
    confirm that ``False`` is the "agrees" value in the caller's data.

    Args:
        true_overlap: Pairs for the bars labelled 'BPE'.
        rand_overlap: Pairs for the bars labelled 'Random'.
        savefile: Optional path. When given, the figure is written there and
            closed instead of being shown interactively. When ``None`` the
            figure is displayed with ``plt.show()``.
    """
    num_classes = 3
    true_count = Counter(true_overlap)
    rand_count = Counter(rand_overlap)

    true_percent = [_false_fraction(true_count, i) for i in range(num_classes)]
    rand_percent = [_false_fraction(rand_count, i) for i in range(num_classes)]

    x = np.arange(num_classes)
    width = 0.35

    fig, ax = plt.subplots()
    # The two series sit side by side, centred on each tick.
    rects_true = ax.bar(x - width / 2, true_percent, width=width, label='BPE')
    rects_rand = ax.bar(x + width / 2, rand_percent, width=width, label='Random')

    ax.set_xticks(x)
    ax.set_xticklabels(['Alpha Helix', 'Strand', 'Beta Sheet'])
    ax.set_ylabel('Percent Agreement')
    ax.set_title('Agreement between tokens and secondary structure labels')
    ax.legend(loc='lower right')

    autolabel(ax, rects_true)
    autolabel(ax, rects_rand)

    plt.tight_layout()

    if savefile is not None:
        # Honor the savefile argument, which used to be silently ignored.
        fig.savefig(savefile)
        plt.close(fig)
    else:
        plt.show()
def pythonify(tensor):
    """Convert the value behind ``tensor.numpy()`` into a JSON-serializable object.

    Args:
        tensor: Any object exposing a ``numpy()`` method.

    Returns:
        A nested list for arrays, ``str`` for bytes, ``int`` for integer
        scalars, and the matching built-in Python scalar for any other numpy
        scalar or Python float.

    Raises:
        ValueError: If the value is of none of the supported kinds.
    """
    array = tensor.numpy()
    if isinstance(array, np.ndarray):
        return array.tolist()
    if isinstance(array, bytes):
        return array.decode()
    if isinstance(array, (int, np.integer)):
        # np.integer covers every numpy integer width, not just int32/int64.
        return int(array)
    if isinstance(array, np.generic):
        # Remaining numpy scalars (floats, bools, ...) previously fell through
        # to ValueError; .item() yields the equivalent built-in Python scalar.
        return array.item()
    if isinstance(array, float):
        return array
    raise ValueError(array)
item['primary'] = ''.join(vocab[index] for index in item['primary']) 40 | id_ = str(i).encode() 41 | txn.put(id_, pkl.dumps(item)) 42 | num_examples += 1 43 | txn.put(b'num_examples', pkl.dumps(num_examples)) 44 | 45 | 46 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" # disable tensorflow info logging 47 | tf.enable_eager_execution() 48 | 49 | data_dir = Path('/home/rmrao/projects/tape/data/') 50 | out_dir = Path('/home/rmrao/projects/tape-pytorch/data') 51 | 52 | file_lists: List[Union[Path, List[Path]]] = [] 53 | file_lists.append(list((data_dir / 'pfam').glob('pfam31_train*.tfrecord'))) 54 | file_lists.append(list((data_dir / 'pfam').glob('pfam31_valid*.tfrecord'))) 55 | file_lists.append(data_dir / 'pfam' / 'pfam31_holdout.tfrecord') 56 | file_lists.append(list((data_dir / 'proteinnet').glob('contact_map_train*.tfrecord'))) 57 | file_lists.append(data_dir / 'proteinnet' / 'contact_map_valid.tfrecord') 58 | file_lists.append(data_dir / 'proteinnet' / 'contact_map_test.tfrecord') 59 | file_lists += list((data_dir / 'fluorescence').glob('*.tfrecord')) 60 | file_lists += list((data_dir / 'stability').glob('*.tfrecord')) 61 | file_lists += list((data_dir / 'remote_homology').glob('*.tfrecord')) 62 | file_lists += list((data_dir / 'secondary_structure').glob('*.tfrecord')) 63 | 64 | deserialize_funcs = { 65 | 'pfam': data_utils.deserialize_pfam_sequence, 66 | 'proteinnet': data_utils.deserialize_proteinnet_sequence, 67 | 'fluorescence': data_utils.deserialize_gfp_sequence, 68 | 'stability': data_utils.deserialize_stability_sequence, 69 | 'remote_homology': data_utils.deserialize_remote_homology_sequence, 70 | 'secondary_structure': data_utils.deserialize_secondary_structure} 71 | 72 | # outfile = 'data/proteinnet/proteinnet_test.lmdb' 73 | flist_names = ['pfam_train.lmdb', 'pfam_valid.lmdb', 'proteinnet_train.lmdb'] 74 | 75 | for flist in file_lists: 76 | if isinstance(flist, list): 77 | name = flist_names.pop(0) 78 | task_name = flist[0].relative_to(data_dir).parts[0] 
def get_version():
    """Read the package version out of ``tape/__init__.py`` without importing it.

    Returns:
        The text between the first pair of quotes on the ``__version__`` line.

    Raises:
        RuntimeError: If no line starting with ``__version__`` is found.
    """
    directory = os.path.abspath(os.path.dirname(__file__))
    init_file = os.path.join(directory, 'tape', '__init__.py')
    with open(init_file) as f:
        for line in f:
            if line.startswith('__version__'):
                # Accept either quoting style for the version literal.
                delim = '"' if '"' in line else "'"
                return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")


with open('README.md', 'r') as rf:
    README = rf.read()

with open('LICENSE', 'r') as lf:
    LICENSE = lf.read()

with open('requirements.txt', 'r') as reqs:
    requirements = reqs.read().split()

setup(
    name='tape_proteins',
    packages=find_packages(),
    version=get_version(),
    # NOTE(review): "Repostory" is misspelled in the published description;
    # left untouched here, fix deliberately with the next release.
    description="Repostory of Protein Benchmarking and Modeling",
    author="Roshan Rao, Nick Bhattacharya, Neil Thomas",
    author_email='roshan_rao@berkeley.edu, nickbhat@berkeley.edu, nthomas@berkeley.edu',
    url='https://github.com/songlab-cal/tape',
    license=LICENSE,
    keywords=['Proteins', 'Deep Learning', 'Pytorch', 'TAPE'],
    include_package_data=True,
    install_requires=requirements,
    entry_points={
        'console_scripts': [
            'tape-train = tape.main:run_train',
            'tape-train-distributed = tape.main:run_train_distributed',
            'tape-eval = tape.main:run_eval',
            'tape-embed = tape.main:run_embed',
        ]
    },
    classifiers=[
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Operating System :: POSIX :: Linux',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        # The comma after the next entry was missing, which made Python join
        # the two "Topic" strings into a single invalid classifier.
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Scientific/Engineering :: Bio-Informatics',
    ],
)
class EarlyStopping(Exception):
    """Signal that training is being stopped early.

    Raised when stopping training because the validation loss showed no
    improvement.
    """
def create_base_parser() -> argparse.ArgumentParser:
    """Build the parent parser with the options shared by the TAPE commands.

    The parser is created with ``add_help=False`` so it can be passed through
    ``parents=[...]`` to the command-specific parsers.

    Returns:
        The configured parent ``ArgumentParser``.
    """
    base = argparse.ArgumentParser(
        description='Parent parser for tape functions', add_help=False)

    # Model selection and configuration.
    base.add_argument('model_type', help='Base model class to run')
    base.add_argument('--model_config_file', type=utils.check_is_file, default=None,
                      help='Config file for model')
    base.add_argument('--vocab_file', default=None,
                      help='Pretrained tokenizer vocab file')

    # Runtime environment.
    base.add_argument('--output_dir', type=str, default='./results')
    base.add_argument('--no_cuda', action='store_true', help='CPU-only flag')
    base.add_argument('--seed', type=int, default=42, help='Random seed to use')
    base.add_argument(
        '--local_rank', type=int, default=-1,
        help='Local rank of process in distributed training. Set by launch script.')

    # Data handling.
    base.add_argument('--tokenizer', default='iupac', choices=['iupac', 'unirep'],
                      help='Tokenizes to use on the amino acid sequences')
    base.add_argument('--num_workers', type=int, default=8,
                      help='Number of workers to use for multi-threaded data loading')

    # Logging / debugging. Both level names and numeric levels are accepted.
    level_names = ['DEBUG', 'INFO', 'WARN', 'WARNING', 'ERROR']
    level_numbers = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR]
    base.add_argument('--log_level', default=logging.INFO,
                      choices=level_names + level_numbers,
                      help="log level for the experiment")
    base.add_argument('--debug', action='store_true', help='Run in debug mode')

    return base
def create_eval_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Build the parser for the evaluation command on top of ``base_parser``.

    Args:
        base_parser: Parent parser whose arguments are inherited.

    Returns:
        A parser adding the task, pretrained-model location, batch size, data
        directory, metric names and split to the inherited options.
    """
    task_choices = list(registry.task_name_mapping.keys())
    metric_choices = list(registry.metric_name_mapping.keys())

    eval_parser = argparse.ArgumentParser(
        description='Run Eval on the TAPE Datasets', parents=[base_parser])

    # Positional arguments: which task, and which weights to load.
    eval_parser.add_argument('task', choices=task_choices,
                             help='TAPE Task to train/eval on')
    eval_parser.add_argument(
        'from_pretrained', type=str,
        help='Directory containing config and pretrained model weights')

    # Optional arguments.
    eval_parser.add_argument('--batch_size', type=int, default=1024,
                             help='Batch size')
    eval_parser.add_argument('--data_dir', type=utils.check_is_dir, default='./data',
                             help='Directory from which to load task data')
    eval_parser.add_argument(
        '--metrics', nargs='*', default=[],
        help=f'Metrics to run on the result. Choices: {metric_choices}')
    eval_parser.add_argument('--split', type=str, default='test',
                             help='Which split to run on')
    return eval_parser
def create_distributed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Extend ``base_parser`` with the options used by the distributed launcher.

    Args:
        base_parser: Parent parser whose arguments are inherited.

    Returns:
        A help-less parser (``add_help=False``) carrying the inherited options
        plus ``--nnodes``, ``--node_rank``, ``--nproc_per_node``,
        ``--master_addr`` and ``--master_port``.
    """
    launcher = argparse.ArgumentParser(add_help=False, parents=[base_parser])

    # (flag, type, default, help) for every launch-helper option, in order.
    launch_options = (
        ("--nnodes", int, 1,
         "The number of nodes to use for distributed training"),
        ("--node_rank", int, 0,
         "The rank of the node for multi-node distributed training"),
        ("--nproc_per_node", int, 1,
         "The number of processes to launch on each node, "
         "for GPU training, this is recommended to be set "
         "to the number of GPUs in your system so that "
         "each process can be bound to a single GPU."),
        ("--master_addr", str, "127.0.0.1",
         "Master node (rank 0)'s address, should be either "
         "the IP address or the hostname of node 0, for "
         "single node multi-proc training, the "
         "--master_addr can simply be 127.0.0.1"),
        ("--master_port", int, 29500,
         "Master node (rank 0)'s free port that needs to "
         "be used for communciation during distributed "
         "training"),
    )
    for flag, kind, default, text in launch_options:
        launcher.add_argument(flag, type=kind, default=default, help=text)
    return launcher
def run_eval(args: typing.Optional[argparse.Namespace] = None) -> typing.Dict[str, float]:
    """Entry point for evaluation: validate the arguments and call ``training.run_eval``.

    Args:
        args: Parsed arguments. When ``None`` they are parsed from the command
            line with the base + eval parsers.

    Returns:
        Whatever ``training.run_eval`` returns.

    Raises:
        ValueError: If no pretrained model is given, or ``local_rank`` is not -1.
        RuntimeError: If a parameter named by ``training.run_eval`` is absent
            from ``args``.
    """
    if args is None:
        args = create_eval_parser(create_base_parser()).parse_args()

    if args.from_pretrained is None:
        raise ValueError("Must specify pretrained model")
    if args.local_rank != -1:
        raise ValueError("TAPE does not support distributed validation pass")

    # Forward only the values that training.run_eval names as parameters.
    available = vars(args)
    wanted = inspect.getfullargspec(training.run_eval).args

    missing = set(wanted).difference(available)
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")

    return training.run_eval(**{key: available[key] for key in wanted})
@registry.register_metric('mse')
def mean_squared_error(target: Sequence[float],
                       prediction: Sequence[float]) -> float:
    """Return the mean of the squared element-wise differences.

    Args:
        target: Reference values.
        prediction: Predicted values.
    """
    residual = np.asarray(target) - np.asarray(prediction)
    return np.mean(np.square(residual))
@registry.register_metric('accuracy')
def accuracy(target: Union[Sequence[int], Sequence[Sequence[int]]],
             prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float:
    """Return the fraction of labels matched by the arg-max of the scores.

    When the first target is a plain ``int`` the inputs are compared in one
    vectorized step. Otherwise each (label, score) pair is handled separately
    and positions whose label equals -1 are left out of the count.

    Args:
        target: Integer labels, or a sequence of label sequences.
        prediction: Scores whose last axis is arg-maxed to get the prediction.
    """
    # NOTE(review): numpy integer scalars are not instances of `int`, so such
    # targets take the per-sequence branch below.
    if not isinstance(target[0], int):
        n_correct = 0
        n_counted = 0
        for labels, scores in zip(target, prediction):
            labels_arr = np.asarray(labels)
            predicted = np.asarray(scores).argmax(-1)
            keep = labels_arr != -1
            hits = labels_arr[keep] == predicted[keep]
            n_correct += hits.sum()
            n_counted += hits.size
        return n_correct / n_counted

    # non-sequence case
    predicted_classes = np.asarray(prediction).argmax(-1)
    return np.mean(np.asarray(target) == predicted_classes)
ProteinResNetForSequenceToSequenceClassification # noqa: F401 15 | # # TODO: ProteinResNetForContactPrediction 16 | # # TODO: ProteinLSTM* 17 | # from .modeling_unirep import UniRepModel # noqa: F401 18 | # from .modeling_unirep import UniRepForLM # noqa: F401 19 | # from .modeling_unirep import UniRepForValuePrediction # noqa: F401 20 | # from .modeling_unirep import UniRepForSequenceClassification # noqa: F401 21 | # from .modeling_unirep import UniRepForSequenceToSequenceClassification # noqa: F401 22 | # # TODO: UniRepForContactPrediction 23 | # # TODO: Bepler* 24 | # from .modeling_onehot import OneHotModel # noqa: F401 25 | # from .modeling_onehot import OneHotForValuePrediction # noqa: F401 26 | # from .modeling_onehot import OneHotForSequenceClassification # noqa: F401 27 | # from .modeling_onehot import OneHotForSequenceToSequenceClassification # noqa: F401 28 | # TODO: OneHotForContactPrediction 29 | -------------------------------------------------------------------------------- /tape/models/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the huggingface transformers library at 4 | https://github.com/huggingface/transformers, which in turn is adapted from the AllenNLP 5 | library at https://github.com/allenai/allennlp 6 | Copyright by the AllenNLP authors. 7 | Note - this file goes to effort to support Python 2, but the rest of this repository does not. 
def get_cache():
    """Return ``PROTEIN_MODELS_CACHE``, the default cache directory of this module."""
    return PROTEIN_MODELS_CACHE
def url_to_filename(url, etag=None):
    """
    Map `url` to a deterministic cache filename.

    The name is the SHA-256 hex digest of the UTF-8 encoded url. When a
    non-empty `etag` is given, the digest of the etag is appended after a
    period.
    """
    digests = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(digests)
110 | """ 111 | if cache_dir is None: 112 | cache_dir = PROTEIN_MODELS_CACHE 113 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 114 | cache_dir = str(cache_dir) 115 | 116 | cache_path = os.path.join(cache_dir, filename) 117 | if not os.path.exists(cache_path): 118 | raise EnvironmentError("file {} not found".format(cache_path)) 119 | 120 | meta_path = cache_path + '.json' 121 | if not os.path.exists(meta_path): 122 | raise EnvironmentError("file {} not found".format(meta_path)) 123 | 124 | with open(meta_path, encoding="utf-8") as meta_file: 125 | metadata = json.load(meta_file) 126 | url = metadata['url'] 127 | etag = metadata['etag'] 128 | 129 | return url, etag 130 | 131 | 132 | def cached_path(url_or_filename, force_download=False, cache_dir=None): 133 | """ 134 | Given something that might be a URL (or might be a local path), 135 | determine which. If it's a URL, download the file and cache it, and 136 | return the path to the cached file. If it's already a local path, 137 | make sure the file exists and then return the path. 138 | 139 | Args: 140 | cache_dir: specify a cache directory to save the file to 141 | (overwrite the default cache dir). 142 | force_download: if True, re-dowload the file even if it's 143 | already cached in the cache dir. 144 | """ 145 | if cache_dir is None: 146 | cache_dir = PROTEIN_MODELS_CACHE 147 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 148 | url_or_filename = str(url_or_filename) 149 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 150 | cache_dir = str(cache_dir) 151 | 152 | parsed = urlparse(url_or_filename) 153 | 154 | if parsed.scheme in ('http', 'https', 's3'): 155 | # URL, so get it from the cache (downloading if necessary) 156 | output_path = get_from_cache(url_or_filename, cache_dir, force_download) 157 | elif os.path.exists(url_or_filename): 158 | # File, and it exists. 
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.

    The wrapped callable must take the url as its first positional argument.
    A ``ClientError`` whose error code is 404 is re-raised as an
    ``EnvironmentError`` naming the url; every other ``ClientError``
    propagates unchanged.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # Compare as text: error codes may be non-numeric (for example
            # "AccessDenied"), and int() on those raised ValueError, hiding
            # the original ClientError.
            if str(exc.response["Error"]["Code"]) == "404":
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper
def get_from_cache(url, cache_dir=None, force_download=False, resume_download=False):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        url: ``http(s)://`` or ``s3://`` location of the file.
        cache_dir: Directory holding cached files; defaults to
            ``PROTEIN_MODELS_CACHE``. It is created when missing.
        force_download: If True, download again even when a cached copy for
            the current ETag exists.
        resume_download: If True, download into ``<cache_path>.incomplete``
            instead of an anonymous temporary file. Neither download helper
            can continue a partial transfer, so the file is restarted from
            scratch.

    Returns:
        Path of the cached file.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)

    # exist_ok avoids a crash when several processes create the directory
    # at the same time.
    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    etag = get_etag(url)
    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if etag is None:
        if os.path.exists(cache_path):
            return cache_path

        # If we don't have a connection (etag is None) and can't identify the
        # file, try to get the last downloaded one. Metadata, lock and partial
        # download files share the prefix but are not cached artifacts.
        auxiliary_suffixes = ('.json', '.lock', '.incomplete')
        matching_files = [
            name for name in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
            if not name.endswith(auxiliary_suffixes)]
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir,
                                        delete=False)

        # Download to temporary file, then move to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            temp_path = temp_file.name
            logger.info("%s not in cache or force_download=True, download to %s",
                        url, temp_path)

            if resume_download:
                # The helpers always fetch the whole file; drop any partial
                # content so the full download is not appended to it.
                temp_file.truncate(0)

            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

        # Move only after the file is closed, so all data is flushed and the
        # rename also works on platforms that refuse to replace open files.
        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_path, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
"https://s3.amazonaws.com/songlabdata/proteindata/pytorch-models/" 19 | LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {} 20 | LSTM_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {} 21 | 22 | 23 | class ProteinLSTMConfig(ProteinConfig): 24 | pretrained_config_archive_map = LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP 25 | 26 | def __init__(self, 27 | vocab_size: int = 30, 28 | input_size: int = 128, 29 | hidden_size: int = 1024, 30 | num_hidden_layers: int = 3, 31 | hidden_dropout_prob: float = 0.1, 32 | initializer_range: float = 0.02, 33 | **kwargs): 34 | super().__init__(**kwargs) 35 | self.vocab_size = vocab_size 36 | self.input_size = input_size 37 | self.hidden_size = hidden_size 38 | self.num_hidden_layers = num_hidden_layers 39 | self.hidden_dropout_prob = hidden_dropout_prob 40 | self.initializer_range = initializer_range 41 | 42 | 43 | class ProteinLSTMLayer(nn.Module): 44 | 45 | def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.): 46 | super().__init__() 47 | self.dropout = nn.Dropout(dropout) 48 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) 49 | 50 | def forward(self, inputs): 51 | inputs = self.dropout(inputs) 52 | self.lstm.flatten_parameters() 53 | return self.lstm(inputs) 54 | 55 | 56 | class ProteinLSTMPooler(nn.Module): 57 | def __init__(self, config): 58 | super().__init__() 59 | self.scalar_reweighting = nn.Linear(2 * config.num_hidden_layers, 1) 60 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 61 | self.activation = nn.Tanh() 62 | 63 | def forward(self, hidden_states): 64 | # We "pool" the model by simply taking the hidden state corresponding 65 | # to the first token. 
class ProteinLSTMEncoder(nn.Module):
    """Two independent stacks of LSTM layers: one reads the sequence
    left-to-right, the other reads a reversed copy of it.

    The first layer of each stack maps ``config.input_size`` to
    ``config.hidden_size``; the remaining ``config.num_hidden_layers - 1``
    layers map ``hidden_size`` to ``hidden_size`` with dropout.
    """

    def __init__(self, config: 'ProteinLSTMConfig'):
        super().__init__()
        forward_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
        reverse_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
        for _ in range(config.num_hidden_layers - 1):
            forward_lstm.append(ProteinLSTMLayer(
                config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
            reverse_lstm.append(ProteinLSTMLayer(
                config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
        self.forward_lstm = nn.ModuleList(forward_lstm)
        self.reverse_lstm = nn.ModuleList(reverse_lstm)
        self.output_hidden_states = config.output_hidden_states

    def forward(self, inputs, input_mask=None):
        """Run both stacks over ``inputs``.

        Returns:
            ``(sequence_embedding, pooled_embedding)`` plus, when
            ``output_hidden_states`` is set, a tuple with the input and the
            output of every layer. ``sequence_embedding`` concatenates the
            forward and (re-reversed) reverse outputs on the last axis;
            ``pooled_embedding`` stacks the final hidden state of every layer
            of both stacks along a new trailing axis.
        """
        all_forward_pooled = ()
        all_reverse_pooled = ()
        all_hidden_states = (inputs,)
        forward_output = inputs
        for layer in self.forward_lstm:
            forward_output, forward_pooled = layer(forward_output)
            all_forward_pooled = all_forward_pooled + (forward_pooled[0],)
            all_hidden_states = all_hidden_states + (forward_output,)

        reversed_sequence = self.reverse_sequence(inputs, input_mask)
        reverse_output = reversed_sequence
        for layer in self.reverse_lstm:
            reverse_output, reverse_pooled = layer(reverse_output)
            all_reverse_pooled = all_reverse_pooled + (reverse_pooled[0],)
            all_hidden_states = all_hidden_states + (reverse_output,)
        # Undo the reversal so both directions are aligned position-wise.
        reverse_output = self.reverse_sequence(reverse_output, input_mask)

        output = torch.cat((forward_output, reverse_output), dim=2)
        pooled = all_forward_pooled + all_reverse_pooled
        pooled = torch.stack(pooled, 3).squeeze(0)
        outputs = (output, pooled)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)

        return outputs  # sequence_embedding, pooled_embedding, (hidden_states)

    def reverse_sequence(self, sequence, input_mask):
        """Reverse ``sequence`` along axis 1.

        Args:
            sequence: Batch-first tensor whose axis 1 is the position axis.
            input_mask: ``None`` to reverse every row completely, or a
                per-position mask whose row sums give the number of leading
                valid positions. Only those are reversed; the remainder of
                each row is zero padding.

        Returns:
            Tensor with the same shape as ``sequence``.
        """
        if input_mask is None:
            # The previous implementation passed ``device=`` to index_select,
            # which is not a valid argument and raised TypeError.
            return sequence.flip(1)

        max_length = sequence.size(1)
        reversed_rows = []
        # One host transfer for all lengths; F.pad needs plain Python ints.
        for seq, seqlen in zip(sequence, input_mask.sum(1).tolist()):
            seqlen = int(seqlen)
            idx = torch.arange(seqlen - 1, -1, -1, device=seq.device)
            row = seq.index_select(0, idx)
            reversed_rows.append(F.pad(row, [0, 0, 0, max_length - seqlen]))
        return torch.stack(reversed_rows, 0)
@registry.register_task_model('language_modeling', 'lstm')
class ProteinLSTMForLM(ProteinLSTMAbstractModel):
    """Bidirectional LSTM language model.

    Each direction predicts a position from the states of its neighbours only:
    the forward half of the encoder output is shifted one step right, the
    reverse half one step left, and both go through one shared linear layer.
    """

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        """Return ``(loss), prediction_scores, (hidden_states)``; the loss is
        present only when ``targets`` is given (label ``-1`` is ignored)."""
        encoder_outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output = encoder_outputs[0]
        extras = encoder_outputs[2:]

        left_context, right_context = sequence_output.chunk(2, -1)
        # Shift so position i only sees states computed from positions != i.
        left_context = F.pad(left_context[:, :-1], [0, 0, 1, 0])
        right_context = F.pad(right_context[:, 1:], [0, 0, 0, 1])
        prediction_scores = (
            self.feedforward(left_context) + self.feedforward(right_context)
        ).contiguous()

        result = (prediction_scores,) + extras
        if targets is None:
            return result

        criterion = nn.CrossEntropyLoss(ignore_index=-1)
        flat_scores = prediction_scores.view(-1, self.config.vocab_size)
        lm_loss = criterion(flat_scores, targets.view(-1))
        # (loss), prediction_scores, (hidden_states)
        return (lm_loss,) + result
@registry.register_task_model('secondary_structure', 'lstm')
class ProteinLSTMForSequenceToSequenceClassification(ProteinLSTMAbstractModel):
    """LSTM encoder with a per-residue classification head.

    The head is sized ``2 * hidden_size`` because the encoder output
    concatenates the forward and reverse LSTM states.
    """

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size * 2, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.lstm(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        # The head returns a tuple and takes the targets itself, as in the
        # one-hot and ResNet models. The earlier code treated the tuple as a
        # tensor and failed as soon as targets were supplied.
        outputs = self.classify(sequence_output.contiguous(), targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('contact_prediction', 'lstm')
class ProteinLSTMForContactPrediction(ProteinLSTMAbstractModel):
    """LSTM encoder with a pairwise contact-prediction head."""

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        # The per-residue features fed to the head are the concatenated
        # forward/reverse states, i.e. 2 * hidden_size wide (same sizing as
        # the sequence-to-sequence head above).
        self.predict = PairwiseContactPredictionHead(
            config.hidden_size * 2, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        outputs = self.lstm(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs
class ProteinOneHotModel(ProteinOneHotAbstractModel):
    """Parameter-free baseline: one-hot encodes token ids and pools them
    with a masked average."""

    def __init__(self, config: ProteinOneHotConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size

        # Note: this exists *solely* for fp16 support. The buffer follows the
        # module's dtype, which gives forward() something to cast against.
        self.register_buffer('_buffer', torch.tensor([0.]))

    def forward(self, input_ids, input_mask=None):
        """Return ``(sequence_output, pooled_output)``."""
        mask = torch.ones_like(input_ids) if input_mask is None else input_mask

        one_hot = F.one_hot(input_ids, num_classes=self.vocab_size)
        # fp16 compatibility
        one_hot = one_hot.type_as(self._buffer)
        mask = mask.unsqueeze(2).type_as(one_hot)

        # just a bag-of-words for amino acids
        token_counts = (one_hot * mask).sum(1)
        bag_of_words = token_counts / mask.sum(1)

        return (one_hot, bag_of_words)
@registry.register_task_model('remote_homology', 'onehot')
class ProteinOneHotForSequenceClassification(ProteinOneHotAbstractModel):
    """One-hot baseline with a sequence-level classification head applied to
    the pooled (bag-of-words) representation."""

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceClassificationHead(config.vocab_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoded = self.onehot(input_ids, input_mask=input_mask)
        pooled_output = encoded[1]
        extras = encoded[2:]

        # (loss), prediction_scores, (hidden_states)
        return self.classify(pooled_output, targets) + extras
@registry.register_task_model('contact_prediction', 'onehot')
class ProteinOneHotForContactPrediction(ProteinOneHotAbstractModel):
    """One-hot baseline with a pairwise contact-prediction head."""

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        # The one-hot sequence output is vocab_size wide, and the one-hot
        # config defines no hidden_size, so the head is sized by vocab_size
        # like the other one-hot task heads.
        self.predict = PairwiseContactPredictionHead(config.vocab_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores
        return outputs
class MaskedConv1d(nn.Conv1d):
    """``nn.Conv1d`` that multiplies its input by an optional mask before
    convolving."""

    def forward(self, x, input_mask=None):
        # Without a mask this is a plain convolution.
        masked = x if input_mask is None else x * input_mask
        return super().forward(masked)
class ProteinResNetPooler(nn.Module):
    """Attention pooling over sequence positions.

    A learned linear layer scores every position, the scores are normalised
    with a softmax across positions, and the resulting weighted mean of the
    hidden states goes through a dense layer and ``tanh``.
    """

    def __init__(self, config):
        super().__init__()
        self.attention_weights = nn.Linear(config.hidden_size, 1)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, mask=None):
        """Pool ``hidden_states`` into one vector per sequence.

        Args:
            hidden_states: Tensor whose axis 1 indexes positions and whose
                last axis has ``hidden_size`` features.
            mask: Optional tensor with 1 for valid and 0 for padded
                positions, shaped like the scores (trailing singleton axis)
                or without that trailing axis.

        Returns:
            Tensor with the position axis removed.
        """
        attention_scores = self.attention_weights(hidden_states)
        if mask is not None:
            if mask.dim() == attention_scores.dim() - 1:
                mask = mask.unsqueeze(-1)
            mask = mask.type_as(attention_scores)
            # Large negative bias drives padded positions to ~0 weight.
            attention_scores = attention_scores + -10000. * (1 - mask)
        # Normalise across positions (axis 1). The scores' last axis has
        # size 1, so a softmax over it would make every weight exactly 1 and
        # silently ignore the mask.
        attention_weights = torch.softmax(attention_scores, 1)
        weighted_mean_embedding = torch.matmul(
            hidden_states.transpose(1, 2), attention_weights).squeeze(2)
        pooled_output = self.dense(weighted_mean_embedding)
        pooled_output = self.activation(pooled_output)
        return pooled_output
@registry.register_task_model('embed', 'resnet')
class ProteinResNetModel(ProteinResNetAbstractModel):
    """Embeddings, a stack of residual convolution blocks and an attention
    pooler."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = ProteinResNetEmbeddings(config)
        self.encoder = ResNetEncoder(config)
        self.pooler = ProteinResNetPooler(config)

        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None):
        """Return ``(sequence_output, pooled_output, (hidden_states))``."""
        # A mask is only materialised when it actually hides something.
        pooling_mask = None
        conv_mask = None
        if input_mask is not None and torch.any(input_mask != 1):
            # fp16 compatibility: follow the dtype of the model parameters
            param_dtype = next(self.parameters()).dtype
            pooling_mask = input_mask.unsqueeze(2).to(dtype=param_dtype)
            # The encoder works on the transposed layout, so its mask is
            # transposed the same way as the embeddings below.
            conv_mask = pooling_mask.transpose(1, 2)

        encoder_input = self.embeddings(input_ids).transpose(1, 2)
        encoder_outputs = self.encoder(encoder_input, conv_mask)
        sequence_output = encoder_outputs[0].transpose(1, 2).contiguous()
        pooled_output = self.pooler(sequence_output, pooling_mask)

        # add hidden_states if they are here
        return (sequence_output, pooled_output,) + encoder_outputs[1:]
@registry.register_task_model('fluorescence', 'resnet')
@registry.register_task_model('stability', 'resnet')
class ProteinResNetForValuePrediction(ProteinResNetAbstractModel):
    """ResNet encoder with a value-prediction head on the pooled output."""

    def __init__(self, config):
        super().__init__(config)

        self.resnet = ProteinResNetModel(config)
        self.predict = ValuePredictionHead(config.hidden_size)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoded = self.resnet(input_ids, input_mask=input_mask)
        pooled_output = encoded[1]
        extras = encoded[2:]

        # (loss), prediction_scores, (hidden_states)
        return self.predict(pooled_output, targets) + extras
@registry.register_task_model('contact_prediction', 'resnet')
class ProteinResNetForContactPrediction(ProteinResNetAbstractModel):
    """ResNet encoder with a pairwise contact-prediction head on the
    per-residue output."""

    def __init__(self, config):
        super().__init__(config)

        self.resnet = ProteinResNetModel(config)
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        encoded = self.resnet(input_ids, input_mask=input_mask)
        sequence_output = encoded[0]
        extras = encoded[2:]

        # (loss), prediction_scores, (hidden_states)
        return self.predict(sequence_output, protein_length, targets) + extras
class MSAFeatureExtractor(nn.Module):
    """Turn a one-hot encoded MSA into the 526-channel trRosetta input features.

    The input `msa1hot` is indexed as (batch, num_alignments, seqlen, 21); the
    first alignment row is treated as the query sequence. The output has shape
    (batch, 526, seqlen, seqlen): 42 1D features tiled along rows, the same 42
    tiled along columns, and 442 pairwise (2D) features.
    """

    def __init__(self, config: TRRosettaConfig):
        super().__init__()
        self.msa_cutoff = config.msa_cutoff
        self.penalty_coeff = config.penalty_coeff

    def forward(self, msa1hot):
        """Compute the stacked 1D/2D features, returned in the input's dtype."""
        # Convert to float, then potentially back to half
        # These transforms aren't well suited to half-precision
        initial_type = msa1hot.dtype

        msa1hot = msa1hot.float()
        seqlen = msa1hot.size(2)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        # Tile the per-position features so every (i, j) pair sees the 1D
        # features of position i ("left") and of position j ("right").
        left = features_1d.unsqueeze(2).repeat(1, 1, seqlen, 1)
        right = features_1d.unsqueeze(1).repeat(1, seqlen, 1, 1)
        features = torch.cat((left, right, features_2d), -1)
        features = features.type(initial_type)
        # channels-last -> channels-first for the Conv2d stack
        features = features.permute(0, 3, 1, 2)
        features = features.contiguous()
        return features

    def reweight(self, msa1hot, eps=1e-9):
        """Return per-sequence weights of shape (batch, num_alignments).

        Each weight is 1 / (number of sequences in the same MSA whose overlap
        with it exceeds `msa_cutoff * seqlen`); `eps` guards the division.
        """
        seqlen = msa1hot.size(2)
        id_min = seqlen * self.msa_cutoff
        # Pairwise overlap between alignment rows, contracted over positions
        # and symbols, computed independently for each batch element.
        id_mtx = torch.stack([torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
        id_mask = id_mtx > id_min
        weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
        return weights

    def extract_features_1d(self, msa1hot, weights):
        """Return the (batch, seqlen, 42) per-position features.

        Channels: first 20 symbols of the query row, the 21 weighted symbol
        frequencies, and their entropy.
        """
        f1d_seq = msa1hot[:, 0, :, :20]
        batch_size = msa1hot.size(0)
        seqlen = msa1hot.size(2)

        # msa2pssm
        # Normalise by each sample's own effective sequence count. Summing the
        # weights over the whole batch would make one sample's profile depend
        # on the other MSAs in the batch.
        beff = weights.sum(1)[:, None, None]
        f_i = (weights[:, :, None, None] * msa1hot).sum(1) / beff + 1e-9
        h_i = (-f_i * f_i.log()).sum(2, keepdim=True)
        f1d_pssm = torch.cat((f_i, h_i), dim=2)
        f1d = torch.cat((f1d_seq, f1d_pssm), dim=2)
        f1d = f1d.view(batch_size, seqlen, 42)
        return f1d

    def extract_features_2d(self, msa1hot, weights):
        """Return the (batch, seqlen, seqlen, 442) pairwise features.

        Channels: the 21 * 21 block of the regularised inverse covariance for
        each position pair, plus one APC-corrected contact score. With a single
        alignment row the features are all zeros.
        """
        batch_size = msa1hot.size(0)
        num_alignments = msa1hot.size(1)
        seqlen = msa1hot.size(2)
        num_symbols = 21

        if num_alignments == 1:
            # No alignments, predict from sequence alone
            f2d_dca = torch.zeros(
                batch_size, seqlen, seqlen, 442,
                dtype=torch.float,
                device=msa1hot.device)
            return f2d_dca

        # compute fast_dca
        # covariance
        # reshape (not view) so a non-contiguous float input is accepted too
        x = msa1hot.reshape(batch_size, num_alignments, seqlen * num_symbols)
        num_points = weights.sum(1) - weights.mean(1).sqrt()
        mean = (x * weights.unsqueeze(2)).sum(1, keepdim=True) / num_points[:, None, None]
        x = (x - mean) * weights[:, :, None].sqrt()
        cov = torch.matmul(x.transpose(-1, -2), x) / num_points[:, None, None]

        # inverse covariance, with a diagonal shrinkage term
        reg = torch.eye(seqlen * num_symbols,
                        device=weights.device,
                        dtype=weights.dtype)[None]
        reg = reg * self.penalty_coeff / weights.sum(1, keepdim=True).sqrt().unsqueeze(2)
        cov_reg = cov + reg
        inv_cov = torch.stack([torch.inverse(cr) for cr in cov_reg.unbind(0)], 0)

        x1 = inv_cov.view(batch_size, seqlen, num_symbols, seqlen, num_symbols)
        x2 = x1.permute(0, 1, 3, 2, 4)
        features = x2.reshape(batch_size, seqlen, seqlen, num_symbols * num_symbols)

        # Norm over the first 20 symbols of each pair block, diagonal zeroed
        off_diagonal = 1 - torch.eye(seqlen, device=weights.device, dtype=weights.dtype)[None]
        x3 = (x1[:, :, :-1, :, :-1] ** 2).sum((2, 4)).sqrt() * off_diagonal
        # average product correction
        apc = x3.sum(1, keepdim=True) * x3.sum(2, keepdim=True) / x3.sum(
            (1, 2), keepdim=True)
        contacts = (x3 - apc) * (1 - torch.eye(
            seqlen, device=x3.device, dtype=x3.dtype).unsqueeze(0))

        f2d_dca = torch.cat([features, contacts[:, :, :, None]], dim=3)
        return f2d_dca

    @property
    def feature_size(self) -> int:
        """Number of output channels: 2 * 42 (1D, tiled twice) + 442 (2D)."""
        return 526
class TRRosettaPredictor(TRRosettaAbstractModel):
    """Dilated 2D ResNet over pairwise features with four prediction heads.

    Takes the 526-channel feature map and emits per-pair distributions over
    distance (37 bins), theta (25), omega (25) and phi (13).
    """

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        num_features = config.num_features

        stem = [
            nn.Conv2d(526, num_features, 1),
            nn.InstanceNorm2d(num_features, affine=True, eps=1e-6),
            nn.ELU(),
            nn.Dropout(config.dropout)]

        # Dilation doubles every block and wraps back to 1 after 16.
        dilation_cycle = (1, 2, 4, 8, 16)
        blocks = [
            DilatedResidualBlock(
                num_features,
                config.kernel_size,
                dilation_cycle[index % len(dilation_cycle)],
                config.dropout)
            for index in range(config.num_layers)]

        self.resnet = nn.Sequential(*stem, *blocks)
        self.predict_theta = nn.Conv2d(num_features, 25, 1)
        self.predict_phi = nn.Conv2d(num_features, 13, 1)
        self.predict_dist = nn.Conv2d(num_features, 37, 1)
        self.predict_bb = nn.Conv2d(num_features, 3, 1)
        self.predict_omega = nn.Conv2d(num_features, 25, 1)

        self.init_weights()

    def init_weights(self):
        """Apply the shared initialiser, then zero every head's conv weight."""
        self.apply(self._init_weights)
        heads = (self.predict_theta, self.predict_phi, self.predict_dist,
                 self.predict_bb, self.predict_omega)
        for head in heads:
            nn.init.constant_(head.weight, 0)

    def forward(self,
                features,
                theta=None,
                phi=None,
                dist=None,
                omega=None):
        """Predict pairwise geometry and, when targets are given, the losses.

        Returns `(probs,)`, or `((total_loss, metrics), probs)` if at least one
        of `dist`, `theta`, `omega`, `phi` is provided. `probs` maps 'p_dist',
        'p_theta', 'p_omega' and 'p_phi' to softmaxed channels-last tensors;
        `metrics` maps each supplied target name to its cross-entropy loss.
        """
        batch_size = features.size(0)
        seqlen = features.size(2)
        num_pairs = batch_size * seqlen * seqlen

        embedding = self.resnet(features)
        # distance and omega are predicted from a symmetrized embedding
        sym_embedding = 0.5 * (embedding + embedding.transpose(-1, -2))

        def channels_last(logits):
            return logits.permute(0, 2, 3, 1).contiguous()

        logits_theta = channels_last(self.predict_theta(embedding))
        logits_phi = channels_last(self.predict_phi(embedding))
        logits_dist = channels_last(self.predict_dist(sym_embedding))
        # beta-strand pairings (predict_bb) are not used
        logits_omega = channels_last(self.predict_omega(sym_embedding))

        probs = {
            'p_dist': nn.Softmax(-1)(logits_dist),
            'p_theta': nn.Softmax(-1)(logits_theta),
            'p_omega': nn.Softmax(-1)(logits_omega),
            'p_phi': nn.Softmax(-1)(logits_phi),
        }
        outputs = (probs,)

        # (metric name, logits, targets, number of bins, ignored label)
        loss_specs = (
            ('dist', logits_dist, dist, 37, -1),
            ('theta', logits_theta, theta, 25, 0),
            ('omega', logits_omega, omega, 25, 0),
            ('phi', logits_phi, phi, 13, 0),
        )

        metrics = {}
        total_loss = 0
        for name, logits, target, num_bins, ignored in loss_specs:
            if target is None:
                continue
            flat_logits = logits.reshape(num_pairs, num_bins)
            loss = nn.CrossEntropyLoss(ignore_index=ignored)(flat_logits, target.view(-1))
            metrics[name] = loss
            total_loss += loss

        if metrics:
            outputs = ((total_loss, metrics),) + outputs

        return outputs
class UniRepConfig(ProteinConfig):
    """Configuration object holding the hyperparameters of the UniRep models.

    Every keyword below is stored as an attribute of the same name; any other
    keyword arguments are passed on to `ProteinConfig`.
    """
    pretrained_config_archive_map = UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 26,
                 input_size: int = 10,
                 hidden_size: int = 1900,
                 hidden_dropout_prob: float = 0.1,
                 layer_norm_eps: float = 1e-12,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)

        # Sizes of the embedding table, the embedding vectors and the
        # recurrent state.
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Regularisation, normalisation and initialisation settings.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
class mLSTM(nn.Module):
    """Unrolls an `mLSTMCell` over the time axis of a batch-first input."""

    def __init__(self, config):
        super().__init__()
        self.mlstm_cell = mLSTMCell(config)
        self.hidden_size = config.hidden_size

    def forward(self, inputs, state=None, mask=None):
        """Run the cell over `inputs[:, t, :]` for every step t.

        Args:
            inputs: tensor indexed as (batch, seqlen, features).
            state: optional `(h, c)` pair; defaults to zeros of shape
                (batch, hidden_size).
            mask: optional (batch, seqlen) or (batch, seqlen, 1) tensor; where
                it is 0 the previous state is carried over unchanged.

        Returns:
            `(hidden_states, (h, c))`: the hidden state of every step stacked
            along dim 1, and the final state pair.
        """
        batch_size, seqlen = inputs.size(0), inputs.size(1)
        like_inputs = {'dtype': inputs.dtype, 'device': inputs.device}

        if mask is None:
            mask = torch.ones(batch_size, seqlen, 1, **like_inputs)
        elif mask.dim() == 2:
            mask = mask.unsqueeze(2)

        if state is None:
            initial = torch.zeros(batch_size, self.hidden_size, **like_inputs)
            state = (initial, initial)

        hidden_per_step = []
        for position in range(seqlen):
            h_prev, c_prev = state
            h_new, c_new = self.mlstm_cell(inputs[:, position, :], state)

            # Blend new and previous state so masked positions leave it as is.
            keep = mask[:, position]
            hx = keep * h_new + (1 - keep) * h_prev
            cx = keep * c_new + (1 - keep) * c_prev

            state = (hx, cx)
            hidden_per_step.append(hx)

        return torch.stack(hidden_per_step, 1), (hx, cx)
@registry.register_task_model('language_modeling', 'unirep')
class UniRepForLM(UniRepAbstractModel):
    """UniRep encoder with a linear next-token prediction layer.

    The projection emits `config.vocab_size - 1` logits per position.
    """
    # TODO: Fix this for UniRep - UniRep changes the size of the targets

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size - 1)

        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        """Score the next token at every position and optionally compute the loss.

        Returns `(prediction_scores, ...)`, with the language-modeling loss
        prepended when `targets` is given. The returned scores cover every
        input position; only the loss uses the shifted views.
        """
        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.feedforward(sequence_output)

        # add hidden states if they are here
        outputs = (prediction_scores,) + outputs[2:]

        if targets is not None:
            # Position t predicts token t + 1: drop the first target and the
            # last prediction so the two line up.
            shifted_targets = targets[:, 1:]
            shifted_scores = prediction_scores[:, :-1]
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            # The slices are non-contiguous, so `view` would raise; `reshape`
            # copies when needed. The class dimension is read from the logits
            # because the projection has vocab_size - 1 outputs, not vocab_size.
            lm_loss = loss_fct(
                shifted_scores.reshape(-1, shifted_scores.size(-1)),
                shifted_targets.reshape(-1))
            outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs
@registry.register_task_model('secondary_structure', 'unirep')
class UniRepForSequenceToSequenceClassification(UniRepAbstractModel):
    """UniRep encoder with a per-position classification head.

    Registered under the model name 'unirep' for the 'secondary_structure' task.
    """

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode `input_ids` and classify every position.

        Returns the head's output tuple followed by any encoder outputs beyond
        the first two: (loss), prediction_scores, (hidden_states)
        """
        encoder_outputs = self.unirep(input_ids, input_mask=input_mask)

        # Only the per-position representation is classified; the pooled
        # output (second element) is unused.
        per_position, _ = encoder_outputs[:2]
        extras = encoder_outputs[2:]
        return self.classify(per_position, targets) + extras
class WarmupConstantSchedule(LambdaLR):
    """ Linear warmup followed by a constant learning rate.
        The learning-rate multiplier rises linearly from 0 to 1 over the first
        `warmup_steps` training steps and stays at 1 afterwards.
    """
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        # Must be set before the base initializer runs, since it evaluates
        # the lambda straight away.
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier applied to the base learning rate at `step`."""
        if step >= self.warmup_steps:
            return 1.
        warmup_length = float(max(1.0, self.warmup_steps))
        return float(step) / warmup_length
class WarmupCosineSchedule(LambdaLR):
    """ Linear warmup followed by cosine decay.
        The learning-rate multiplier rises linearly from 0 to 1 over the first
        `warmup_steps` steps, then follows a cosine curve over the remaining
        `t_total - warmup_steps` steps. With the default `cycles=0.5` that
        curve is a single half-period falling from 1 to 0; the result is
        clamped at 0 from below.
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        # Must be set before the base initializer runs, since it evaluates
        # the lambda straight away.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier applied to the base learning rate at `step`."""
        in_warmup = step < self.warmup_steps
        if in_warmup:
            return float(step) / float(max(1.0, self.warmup_steps))

        # fraction of the post-warmup phase completed so far
        decay_steps = float(max(1, self.t_total - self.warmup_steps))
        progress = float(step - self.warmup_steps) / decay_steps
        angle = math.pi * float(self.cycles) * 2.0 * progress
        return max(0.0, 0.5 * (1. + math.cos(angle)))
class AdamW(Optimizer):
    """ Implements Adam algorithm with weight decay fix.

    Parameters:
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
        eps (float): Adams epsilon. Default: 1e-6
        weight_decay (float): Weight decay. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam
            (e.g. like in Bert TF repository). Default True.

    Raises:
        ValueError: if `lr` or `eps` is negative, or a beta is outside [0.0, 1.0).
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 weight_decay=0.0,
                 correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, '
                                       'please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time.
                # The scalar is passed by keyword: the positional-scalar
                # overloads (e.g. add_(scalar, tensor)) are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state['step']
                    bias_correction2 = 1.0 - beta2 ** state['step']
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group['weight_decay'] > 0.0:
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss
21 | """ 22 | 23 | def __init__(self, 24 | name: str, 25 | dataset: Type[Dataset], 26 | num_labels: int = -1, 27 | models: Optional[Dict[str, Type[ProteinModel]]] = None): 28 | self.name = name 29 | self.dataset = dataset 30 | self.num_labels = num_labels 31 | self.models = models if models is not None else {} 32 | 33 | def register_model(self, model_name: str, model_cls: Optional[Type[ProteinModel]] = None): 34 | if model_cls is not None: 35 | if model_name in self.models: 36 | raise KeyError( 37 | f"A model with name '{model_name}' is already registered for this task") 38 | self.models[model_name] = model_cls 39 | return model_cls 40 | else: 41 | return lambda model_cls: self.register_model(model_name, model_cls) 42 | 43 | def get_model(self, model_name: str) -> Type[ProteinModel]: 44 | return self.models[model_name] 45 | 46 | 47 | class Registry: 48 | r"""Class for registry object which acts as the 49 | central repository for TAPE.""" 50 | 51 | task_name_mapping: Dict[str, TAPETaskSpec] = {} 52 | metric_name_mapping: Dict[str, Callable] = {} 53 | 54 | @classmethod 55 | def register_task(cls, 56 | task_name: str, 57 | num_labels: int = -1, 58 | dataset: Optional[Type[Dataset]] = None, 59 | models: Optional[Dict[str, Type[ProteinModel]]] = None): 60 | """ Register a a new TAPE task. This creates a new TAPETaskSpec. 61 | 62 | Args: 63 | 64 | task_name (str): The name of the TAPE task. 65 | num_labels (int): Number of labels used if this is a classification task. If this 66 | is not a classification task, simply leave the default as -1. 67 | dataset (Type[Dataset]): The dataset used in the TAPE task. 68 | models (Optional[Dict[str, ProteinModel]]): The set of models that can be used for 69 | this task. If you do not pass this argument, you can register models to the task 70 | later by using `registry.register_task_model`. Default: {}. 71 | 72 | Examples: 73 | 74 | There are two ways of registering a new task. 
First, one can define the task by simply 75 | declaring all the components, and then calling the register method, like so: 76 | 77 | class SecondaryStructureDataset(Dataset): 78 | ... 79 | 80 | class ProteinBertForSequenceToSequenceClassification(): 81 | ... 82 | 83 | registry.register_task( 84 | 'secondary_structure', 3, SecondaryStructureDataset, 85 | {'transformer': ProteinBertForSequenceToSequenceClassification}) 86 | 87 | This will register a new task, 'secondary_structure', with a single model. More models 88 | can be added with `registry.register_task_model`. Alternatively, this can be used as a 89 | decorator: 90 | 91 | @registry.regsiter_task('secondary_structure', 3) 92 | class SecondaryStructureDataset(Dataset): 93 | ... 94 | 95 | @registry.register_task_model('secondary_structure', 'transformer') 96 | class ProteinBertForSequenceToSequenceClassification(): 97 | ... 98 | 99 | These two pieces of code are exactly equivalent, in terms of the resulting registry 100 | state. 101 | 102 | """ 103 | if dataset is not None: 104 | if models is None: 105 | models = {} 106 | task_spec = TAPETaskSpec(task_name, dataset, num_labels, models) 107 | return cls.register_task_spec(task_name, task_spec).dataset 108 | else: 109 | return lambda dataset: cls.register_task(task_name, num_labels, dataset, models) 110 | 111 | @classmethod 112 | def register_task_spec(cls, task_name: str, task_spec: Optional[TAPETaskSpec] = None): 113 | """ Registers a task_spec directly. If you find it easier to actually create a 114 | TAPETaskSpec manually, and then register it, feel free to use this method, 115 | but otherwise it is likely easier to use `registry.register_task`. 
116 | """ 117 | if task_spec is not None: 118 | if task_name in cls.task_name_mapping: 119 | raise KeyError(f"A task with name '{task_name}' is already registered") 120 | cls.task_name_mapping[task_name] = task_spec 121 | return task_spec 122 | else: 123 | return lambda task_spec: cls.register_task_spec(task_name, task_spec) 124 | 125 | @classmethod 126 | def register_task_model(cls, 127 | task_name: str, 128 | model_name: str, 129 | model_cls: Optional[Type[ProteinModel]] = None): 130 | r"""Register a specific model to a task with the provided model name. 131 | The task must already be in the registry - you cannot register a 132 | model to an unregistered task. 133 | 134 | Args: 135 | task_name (str): Name of task to which to register the model. 136 | model_name (str): Name of model to use when registering task, this 137 | is the name that you will use to refer to the model on the 138 | command line. 139 | model_cls (Type[ProteinModel]): The model to register. 140 | 141 | Examples: 142 | 143 | As with `registry.register_task`, this can both be used as a regular 144 | python function, and as a decorator. For example this: 145 | 146 | class ProteinBertForSequenceToSequenceClassification(): 147 | ... 148 | registry.register_task_model( 149 | 'secondary_structure', 'transformer', 150 | ProteinBertForSequenceToSequenceClassification) 151 | 152 | and as a decorator: 153 | 154 | @registry.register_task_model('secondary_structure', 'transformer') 155 | class ProteinBertForSequenceToSequenceClassification(): 156 | ... 157 | 158 | are both equivalent. 159 | """ 160 | if task_name not in cls.task_name_mapping: 161 | raise KeyError( 162 | f"Tried to register a task model for an unregistered task: {task_name}. 
" 163 | f"Make sure to register the task {task_name} first.") 164 | return cls.task_name_mapping[task_name].register_model(model_name, model_cls) 165 | 166 | @classmethod 167 | def register_metric(cls, name: str) -> Callable[[Callable], Callable]: 168 | r"""Register a metric to registry with key 'name' 169 | 170 | Args: 171 | name: Key with which the metric will be registered. 172 | 173 | Usage:: 174 | from tape.registry import registry 175 | 176 | @registry.register_metric('mse') 177 | def mean_squred_error(inputs, outputs): 178 | ... 179 | """ 180 | 181 | def wrap(fn: Callable) -> Callable: 182 | assert callable(fn), "All metrics must be callable" 183 | cls.metric_name_mapping[name] = fn 184 | return fn 185 | 186 | return wrap 187 | 188 | @classmethod 189 | def get_task_spec(cls, name: str) -> TAPETaskSpec: 190 | return cls.task_name_mapping[name] 191 | 192 | @classmethod 193 | def get_metric(cls, name: str) -> Callable: 194 | return cls.metric_name_mapping[name] 195 | 196 | @classmethod 197 | def get_task_model(cls, 198 | model_name: str, 199 | task_name: str, 200 | config_file: Optional[PathType] = None, 201 | load_dir: Optional[PathType] = None) -> ProteinModel: 202 | """ Create a TAPE task model, either from scratch or from a pretrained model. 203 | This is mostly a helper function that evaluates the if statements in a 204 | sensible order if you pass all three of the arguments. 205 | Args: 206 | model_name (str): Which type of model to create (e.g. transformer, unirep, ...) 
207 | task_name (str): The TAPE task for which to create a model 208 | config_file (str, optional): A json config file that specifies hyperparameters 209 | load_dir (str, optional): A save directory for a pretrained model 210 | Returns: 211 | model (ProteinModel): A TAPE task model 212 | """ 213 | task_spec = registry.get_task_spec(task_name) 214 | model_cls = task_spec.get_model(model_name) 215 | 216 | if load_dir is not None: 217 | model = model_cls.from_pretrained(load_dir, num_labels=task_spec.num_labels) 218 | else: 219 | config_class = model_cls.config_class 220 | if config_file is not None: 221 | config = config_class.from_json_file(config_file) 222 | else: 223 | config = config_class() 224 | config.num_labels = task_spec.num_labels 225 | model = model_cls(config) 226 | return model 227 | 228 | 229 | registry = Registry() 230 | -------------------------------------------------------------------------------- /tape/tokenizers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import logging 3 | from collections import OrderedDict 4 | import numpy as np 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | IUPAC_CODES = OrderedDict([ 9 | ('Ala', 'A'), 10 | ('Asx', 'B'), 11 | ('Cys', 'C'), 12 | ('Asp', 'D'), 13 | ('Glu', 'E'), 14 | ('Phe', 'F'), 15 | ('Gly', 'G'), 16 | ('His', 'H'), 17 | ('Ile', 'I'), 18 | ('Lys', 'K'), 19 | ('Leu', 'L'), 20 | ('Met', 'M'), 21 | ('Asn', 'N'), 22 | ('Pro', 'P'), 23 | ('Gln', 'Q'), 24 | ('Arg', 'R'), 25 | ('Ser', 'S'), 26 | ('Thr', 'T'), 27 | ('Sec', 'U'), 28 | ('Val', 'V'), 29 | ('Trp', 'W'), 30 | ('Xaa', 'X'), 31 | ('Tyr', 'Y'), 32 | ('Glx', 'Z')]) 33 | 34 | IUPAC_VOCAB = OrderedDict([ 35 | ("", 0), 36 | ("", 1), 37 | ("", 2), 38 | ("", 3), 39 | ("", 4), 40 | ("A", 5), 41 | ("B", 6), 42 | ("C", 7), 43 | ("D", 8), 44 | ("E", 9), 45 | ("F", 10), 46 | ("G", 11), 47 | ("H", 12), 48 | ("I", 13), 49 | ("K", 14), 50 | ("L", 15), 51 | ("M", 16), 52 | ("N", 17), 53 | ("O", 
class TAPETokenizer():
    r"""TAPE Tokenizer. Can use different vocabs depending on the model.

    NOTE(review): the start/stop/mask token literals below are empty strings in
    this copy of the file, as are several vocab keys. Presumably the original
    angle-bracket special tokens were stripped somewhere upstream - confirm
    against the canonical source before relying on them.
    """

    def __init__(self, vocab: str = 'iupac'):
        """Select the vocabulary by name ('iupac' or 'unirep').

        Raises:
            ValueError: if `vocab` is not a known vocabulary name.
        """
        if vocab == 'iupac':
            self.vocab = IUPAC_VOCAB
        elif vocab == 'unirep':
            self.vocab = UNIREP_VOCAB
        else:
            # Fail here with a clear message instead of an AttributeError on
            # `self.vocab` further down.
            raise ValueError(
                f"Unrecognized vocab: '{vocab}'. Expected 'iupac' or 'unirep'")
        self.tokens = list(self.vocab.keys())
        self._vocab_type = vocab
        assert self.start_token in self.vocab and self.stop_token in self.vocab

    @property
    def vocab_size(self) -> int:
        """Number of entries (tokens) in the vocabulary."""
        return len(self.vocab)

    @property
    def start_token(self) -> str:
        return ""

    @property
    def stop_token(self) -> str:
        return ""

    @property
    def mask_token(self) -> str:
        """Token used for masking; RuntimeError if the vocab has none."""
        if "" in self.vocab:
            return ""
        else:
            raise RuntimeError(f"{self._vocab_type} vocab does not support masking")

    def _id_lookup(self) -> dict:
        """Return (and cache) the id -> token table derived from `self.vocab`.

        Several tokens may share one id (X/Z/B/J all map to 23 in the unirep
        vocab), so the position of a token in `self.tokens` is not its id. The
        first token listed for an id is treated as its canonical spelling.
        """
        lookup = self.__dict__.get('_id_to_token')
        if lookup is None:
            lookup = {}
            for token, index in self.vocab.items():
                lookup.setdefault(index, token)
            self._id_to_token = lookup
        return lookup

    def tokenize(self, text: str) -> List[str]:
        """Split `text` into single-character tokens."""
        return [x for x in text]

    def convert_token_to_id(self, token: str) -> int:
        """ Converts a token (str/unicode) in an id using the vocab. """
        try:
            return self.vocab[token]
        except KeyError:
            raise KeyError(f"Unrecognized token: '{token}'")

    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
        return [self.convert_token_to_id(token) for token in tokens]

    def convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (string/unicode) using the vocab.

        Raises:
            IndexError: if no token in the vocab has id `index`.
        """
        try:
            return self._id_lookup()[index]
        except KeyError:
            raise IndexError(f"Unrecognized index: '{index}'")

    def convert_ids_to_tokens(self, indices: List[int]) -> List[str]:
        return [self.convert_id_to_token(id_) for id_ in indices]

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """ Converts a sequence of tokens (string) in a single string. """
        return ''.join(tokens)

    def add_special_tokens(self, token_ids: List[str]) -> List[str]:
        """
        Adds special tokens to the a sequence for sequence classification tasks.
        A BERT sequence has the following format: [CLS] X [SEP]
        """
        cls_token = [self.start_token]
        sep_token = [self.stop_token]
        return cls_token + token_ids + sep_token

    def encode(self, text: str) -> np.ndarray:
        """Tokenize `text`, wrap it in start/stop tokens and return int64 ids."""
        tokens = self.tokenize(text)
        tokens = self.add_special_tokens(tokens)
        token_ids = self.convert_tokens_to_ids(tokens)
        return np.array(token_ids, np.int64)

    @classmethod
    def from_pretrained(cls, **kwargs):
        """Return a default tokenizer; keyword arguments are accepted and ignored."""
        return cls()
class SortedSampler(Sampler):
    """ Samples elements sequentially, always in the same order.
    Args:
        dataset (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.
        indices (iterable of int, optional): If given, only these positions of
            `dataset` are sampled (sorted by their key); otherwise every element is.
    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self,
                 dataset,
                 sort_key: typing.Callable[[int], typing.Any],
                 indices: typing.Optional[typing.Iterable[int]] = None):
        super().__init__(dataset)
        self.dataset = dataset
        self.sort_key = sort_key
        # Both branches must yield (index, key) pairs: the sort below orders by
        # the key and keeps only the index.
        if indices is None:
            sort_keys = ((i, sort_key(row)) for i, row in enumerate(dataset))
        else:
            sort_keys = ((i, sort_key(dataset[i])) for i in indices)
        self.sorted_indices = [i for i, _ in sorted(sort_keys, key=operator.itemgetter(1))]

    def __iter__(self):
        return iter(self.sorted_indices)

    def __len__(self):
        # Number of indices actually yielded, which is smaller than the dataset
        # when a subset of `indices` was requested.
        return len(self.sorted_indices)


class BucketBatchSampler(BatchSampler):
    """ `BucketBatchSampler` toggles between `sampler` batches and sorted batches.
    Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted
    and vice versa. Provides ~10-25 percent speedup.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular
        libraries like ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together
        examples with a similar size length to reduce the padding required for each batch
        while maintaining some noise through bucketing.

    Args:
        sampler (torch.data.utils.sampler.Sampler):
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size
            would be less than `batch_size`.
        sort_key (callable): Callable to specify a comparison key for sorting; it is
            applied to `dataset[i]`.
        dataset: Indexable dataset the sampled indices refer to.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.
    Example (batches within a bucket come out in random order):
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> data = list(range(10))
        >>> sampler = SequentialSampler(data)
        >>> batches = BucketBatchSampler(sampler, 3, False, lambda x: x, data)
        >>> sorted(batches)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    """

    def __init__(self,
                 sampler,
                 batch_size,
                 drop_last,
                 sort_key,
                 dataset,
                 bucket_size_multiplier=100):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        self.dataset = dataset
        # Cap the bucket at the sampler length, but keep it >= 1 so an empty
        # sampler yields no batches instead of an invalid batch size of 0.
        bucket_size = max(1, min(batch_size * bucket_size_multiplier, len(sampler)))
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # Sort the bucket by key, cut it into batches of similar-sized
            # items, then emit those batches in random order.
            sorted_sampler = SortedSampler(self.dataset, self.sort_key, indices=bucket)
            for batch in SubsetRandomSampler(
                    list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
def reduce_scalar(scalar: float) -> float:
    """Average `scalar` over all distributed workers.

    Outside an initialized distributed context the value is returned unchanged.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return scalar

    # All-reduce on a one-element CUDA tensor, then divide by the world size.
    buffer = torch.cuda.FloatTensor([scalar])  # type: ignore
    dist.all_reduce(buffer)
    buffer /= dist.get_world_size()
    return buffer.item()


def barrier_if_distributed() -> None:
    """Raises a barrier if in a distributed context, otherwise does nothing."""
    distributed = dist.is_available() and dist.is_initialized()
    if distributed:
        dist.barrier()
class ProcessContext:
    """Tracks a group of worker processes and their per-process error queues."""

    def __init__(self, processes, error_queues):
        # `error_queues[i]` is the queue belonging to `processes[i]`.
        self.error_queues = error_queues
        self.processes = processes
        # Map each process sentinel to its index; entries are removed as the
        # corresponding processes are joined.
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        """Return the pids of all tracked processes, as ints."""
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this process context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        An exit code equal to ``signal.SIGUSR1`` (what ``_wrap`` exits with on
        ``EarlyStopping``) is not reported as an error: the remaining processes
        are still terminated, but ``True`` is returned.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = mp.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )
        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                # Stop at the first failure; it determines the reported error.
                error_index = index
                break
        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0
        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode == signal.SIGUSR1:
                return True
            elif exitcode < 0:
                # A negative exit code means the process was killed by a signal.
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        # The failed worker put its formatted traceback on the queue; re-raise
        # it here so the parent shows the original error.
        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
def launch_process_group(func: typing.Callable,
                         args: argparse.Namespace,
                         num_processes: int,
                         num_nodes: int = 1,
                         node_rank: int = 0,
                         master_addr: str = "127.0.0.1",
                         master_port: int = 29500,
                         join: bool = True,
                         daemon: bool = False):
    """Start `num_processes` worker processes on this node, each running `func`.

    Every worker is started through `_wrap` and receives the keyword arguments
    `args` (with `args.local_rank` set to its local rank) and `env` (a copy of
    the environment extended with MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK
    and LOCAL_RANK).

    Returns:
        The `ProcessContext` immediately when `join` is False; otherwise blocks
        until every worker has been joined and returns None.
    """
    worker_env = os.environ.copy()
    worker_env.update({
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        # World size counts processes across all nodes.
        "WORLD_SIZE": str(num_processes * num_nodes),
    })
    if num_processes > 1 and 'OMP_NUM_THREADS' not in os.environ:
        worker_env["OMP_NUM_THREADS"] = str(4)

    workers = []
    queues = []
    first_rank = num_processes * node_rank

    for local_rank in range(num_processes):
        # The same env dict and args namespace are updated in place before
        # each worker is started.
        worker_env["RANK"] = str(first_rank + local_rank)
        worker_env["LOCAL_RANK"] = str(local_rank)
        args.local_rank = local_rank

        # `_wrap` puts the formatted traceback on this queue if the worker fails.
        queue = mp.SimpleQueue()
        worker = mp.Process(
            target=_wrap,
            args=(func, {'args': args, 'env': worker_env}, queue),
            daemon=daemon)
        worker.start()
        queues.append(queue)
        workers.append(worker)

    context = ProcessContext(workers, queues)
    if join:
        while not context.join():
            pass
        return None
    return context
def setup_logging(local_rank: int,
                  save_path: typing.Optional[typing.Union[str, Path]] = None,
                  log_level: typing.Optional[typing.Union[str, int]] = None) -> None:
    """Configure the root logger for this process.

    Args:
        local_rank (int): Ranks other than -1 and 0 log at WARN or above.
        save_path (str or Path, optional): Directory in which a file named
            'log' is written in addition to any console output.
        log_level (str or int, optional): Level name (case-insensitive) or
            numeric level. Defaults to INFO.

    Raises:
        ValueError: if `log_level` is a string that names no logging level.
        TypeError: if `log_level` is neither None, a string nor an int.
    """
    if log_level is None:
        level = logging.INFO
    elif isinstance(log_level, str):
        level = getattr(logging, log_level.upper(), None)
        if not isinstance(level, int):
            raise ValueError(f"Unknown log level name: '{log_level}'")
    elif isinstance(log_level, int):
        level = log_level
    else:
        # Previously this fell through and hit an unbound `level` below.
        raise TypeError(
            f"log_level must be None, a str or an int, got {type(log_level).__name__}")

    if local_rank not in (-1, 0):
        # Keep non-primary ranks quiet.
        level = max(level, logging.WARN)

    root_logger = logging.getLogger()
    root_logger.setLevel(level)

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%y/%m/%d %H:%M:%S")

    # Only attach a console handler when none is configured yet.
    if not root_logger.hasHandlers():
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level)
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    if save_path is not None:
        # Path() lets callers pass a plain string as well as a Path.
        file_handler = logging.FileHandler(Path(save_path) / 'log')
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)
def setup_distributed(local_rank: int,
                      no_cuda: bool) -> typing.Tuple[torch.device, int, bool]:
    """Pick the compute device and, for distributed runs, join the process group.

    Returns:
        (device, n_gpu, is_master), where `is_master` is True for local ranks
        -1 and 0.
    """
    is_master = local_rank in (-1, 0)

    if local_rank != -1 and not no_cuda:
        # Distributed run: one GPU per process, NCCL process group.
        torch.cuda.set_device(local_rank)
        device: torch.device = torch.device("cuda", local_rank)
        dist.init_process_group(backend="nccl")
        return device, 1, is_master

    if no_cuda or not torch.cuda.is_available():
        return torch.device("cpu"), 1, is_master

    # Single process that can see every available GPU.
    return torch.device("cuda"), torch.cuda.device_count(), is_master
def check_is_dir(dir_path: str) -> str:
    """Return `dir_path` if it is None or an existing directory, else raise."""
    if dir_path is not None and not os.path.isdir(dir_path):
        raise argparse.ArgumentTypeError(f"Directory path: {dir_path} is not a valid directory")
    return dir_path


def path_to_datetime(path: Path) -> datetime:
    """Parse the timestamp that prefixes `path.name` (the part before the first '_').

    Accepts 'Y-M-D-H-M-S' and the deprecated 'Y-M-D-H:M:S'. Any other shape
    yields the sentinel datetime(1, 1, 1).
    """
    sentinel = datetime(1, 1, 1)
    fields = path.name.split('_')[0].split('-')

    if len(fields) == 4:
        # Deprecated datetime strings
        clock = fields[3].split(':')
        if len(clock) != 3:
            return sentinel
        fields = fields[:3] + clock
    elif len(fields) != 6:
        return sentinel

    return datetime(*(int(field) for field in fields))


def get_expname(exp_name: typing.Optional[str],
                task: typing.Optional[str] = None,
                model_type: typing.Optional[str] = None) -> str:
    """Return `exp_name`, or build one from task, model type, UTC time and a random suffix."""
    if exp_name is not None:
        return exp_name
    time_stamp = strftime("%y-%m-%d-%H-%M-%S", gmtime())
    suffix = random.randint(0, int(1e6))
    return f"{task}_{model_type}_{time_stamp}_{suffix:0>6d}"


def set_random_seeds(seed: int, n_gpu: int) -> None:
    """Seed the Python, NumPy and torch generators (and all CUDA devices if n_gpu > 0)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)  # type: ignore


def get_effective_num_gpus(local_rank: int, n_gpu: int) -> int:
    """Return `n_gpu` for non-distributed runs, else the distributed world size."""
    return n_gpu if local_rank == -1 else dist.get_world_size()


def get_effective_batch_size(batch_size: int,
                             local_rank: int,
                             n_gpu: int,
                             gradient_accumulation_steps: int = 1) -> int:
    """Return the batch size per GPU per forward pass, truncated to an int."""
    per_step = float(batch_size) / gradient_accumulation_steps
    return int(per_step / get_effective_num_gpus(local_rank, n_gpu))
class MetricsAccumulator:
    """Accumulates a loss and named metrics, tracking smoothed and overall averages.

    `update` adds one micro-batch. `step` folds everything accumulated since the
    previous step into one averaged update, which feeds an exponential moving
    average (`loss` / `metrics`) and a running total (`final_loss` /
    `final_metrics`).
    """

    def __init__(self, smoothing: float = 0.95):
        # Sums since the last step().
        self._loss_tmp = 0.
        self._metricstmp: typing.Dict[str, float] = defaultdict(lambda: 0.0)
        # Exponential moving averages over steps.
        self._smoothloss: typing.Optional[float] = None
        self._smoothmetrics: typing.Dict[str, float] = {}
        # Sums of per-step averages, divided by the step count on read.
        self._totalloss = 0.
        self._totalmetrics: typing.Dict[str, float] = defaultdict(lambda: 0.0)

        self._nacc_steps = 0
        self._nupdates = 0
        self._smoothing = smoothing

    def update(self,
               loss: typing.Union[float, torch.Tensor],
               metrics: typing.Dict[str, typing.Union[float, torch.Tensor]],
               step: bool = True) -> None:
        """Add one micro-batch; tensors are converted with `.item()`.

        Pass `step=False` while accumulating gradients and call `step()` once
        the accumulated batch is complete.
        """
        if isinstance(loss, torch.Tensor):
            loss = loss.item()

        self._loss_tmp += loss
        for name, value in metrics.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self._metricstmp[name] += value
        self._nacc_steps += 1

        if step:
            self.step()

    def step(self) -> typing.Dict[str, float]:
        """Fold the accumulated micro-batches into one update.

        Returns:
            The averaged metrics of this step, plus the averaged loss under 'loss'.

        Raises:
            RuntimeError: if nothing was accumulated since the previous step.
        """
        if self._nacc_steps == 0:
            raise RuntimeError("Trying to step without any accumulated updates")

        loss_tmp = self._loss_tmp / self._nacc_steps
        metricstmp = {name: value / self._nacc_steps
                      for name, value in self._metricstmp.items()}

        if self._smoothloss is None:
            # The first step seeds the moving average.
            self._smoothloss = loss_tmp
        else:
            self._smoothloss *= self._smoothing
            self._smoothloss += (1 - self._smoothing) * loss_tmp
        self._totalloss += loss_tmp

        for name, value in metricstmp.items():
            if name in self._smoothmetrics:
                currvalue = self._smoothmetrics[name]
                newvalue = currvalue * self._smoothing + value * (1 - self._smoothing)
            else:
                newvalue = value

            self._smoothmetrics[name] = newvalue
            self._totalmetrics[name] += value

        self._nupdates += 1

        # Reset the per-step accumulators.
        self._nacc_steps = 0
        self._loss_tmp = 0
        self._metricstmp = defaultdict(lambda: 0.0)

        metricstmp['loss'] = loss_tmp
        return metricstmp

    def loss(self) -> float:
        """Return the smoothed loss; RuntimeError before the first step."""
        if self._smoothloss is None:
            raise RuntimeError("Trying to get the loss without any updates")
        return self._smoothloss

    def metrics(self) -> typing.Dict[str, float]:
        """Return a copy of the smoothed metrics; RuntimeError before the first step."""
        if self._nupdates == 0:
            raise RuntimeError("Trying to get metrics without any updates")
        return dict(self._smoothmetrics)

    def final_loss(self) -> float:
        """Return the mean per-step loss; RuntimeError before the first step."""
        if self._nupdates == 0:
            raise RuntimeError("Trying to get the final loss without any updates")
        return self._totalloss / self._nupdates

    def final_metrics(self) -> typing.Dict[str, float]:
        """Return the mean per-step value of every metric (empty before the first step)."""
        return {name: value / self._nupdates
                for name, value in self._totalmetrics.items()}
def write_lmdb(filename: str, iterable: typing.Iterable, map_size: int = 2 ** 20):
    """Utility for writing a dataset to an LMDB file.

    Entry number ``i`` is pickled and stored under the key ``str(i).encode()``. The
    total number of entries is pickled and stored under ``b'num_examples'``; this is
    0 for an empty iterable.

    Args:
        filename (str): Output filename to write to
        iterable (Iterable): An iterable dataset to write. Entries must be pickleable.
        map_size (int, optional): Maximum allowable size of database in bytes. Required by LMDB.
            You will likely have to increase this. Default: 1MB.
    """
    import lmdb
    import pickle as pkl
    env = lmdb.open(filename, map_size=map_size)

    try:
        # Stays 0 when the iterable is empty, so the count key is always written
        num_examples = 0
        with env.begin(write=True) as txn:
            for num_examples, entry in enumerate(iterable, start=1):
                txn.put(str(num_examples - 1).encode(), pkl.dumps(entry))
            txn.put(b'num_examples', pkl.dumps(num_examples))
    finally:
        # Close the environment even when pickling or writing fails part-way
        env.close()
class IncrementalNPZ(object):
    """An ``.npz`` writer that can be appended to across several ``savez`` calls.

    Each array is first written to a temporary ``.npy`` file and then copied into
    an uncompressed zip archive. Usable as a context manager, which closes the
    archive on exit.

    Args:
        file: Destination passed to ``zipfile.ZipFile``. If it is a ``str`` that does
            not end in ``'.npz'``, that suffix is appended.
    """
    # Modified npz that allows incremental saving, from https://stackoverflow.com/questions/22712292/how-to-use-numpy-savez-in-a-loop-for-save-more-than-one-array  # noqa: E501

    def __init__(self, file):
        import tempfile
        import zipfile
        import os

        if isinstance(file, str):
            if not file.endswith('.npz'):
                file = file + '.npz'

        archive = self.zipfile_factory(file, mode="w", compression=zipfile.ZIP_STORED)

        # Stage arrays in a temporary file on disk, before writing to zip.
        fd, tmpfile = tempfile.mkstemp(suffix='-numpy.npy')
        os.close(fd)

        self.tmpfile = tmpfile
        self.zip = archive
        self._i = 0  # index used for the next positional ('arr_%d') array

    def zipfile_factory(self, *args, **kwargs):
        """Open a ``zipfile.ZipFile`` with ``allowZip64`` forced on."""
        import zipfile
        kwargs['allowZip64'] = True
        return zipfile.ZipFile(*args, **kwargs)

    def savez(self, *args, **kwds):
        """Append arrays to the archive.

        Positional arrays are stored as ``arr_0``, ``arr_1``, ... with the numbering
        continuing across calls. Keyword arrays are stored under their keyword.

        Raises:
            ValueError: If a keyword collides with the next automatic ``arr_N`` name.
        """
        import numpy.lib.format as fmt

        namedict = kwds
        for val in args:
            key = 'arr_%d' % self._i
            if key in namedict:
                raise ValueError("Cannot use un-named variables and keyword %s" % key)
            namedict[key] = val
            self._i += 1

        try:
            for key, val in namedict.items():
                # The `with` block is the only handle on the staging file, so it is
                # closed before being copied into the archive
                with open(self.tmpfile, 'wb') as fid:
                    fmt.write_array(fid, np.asanyarray(val), allow_pickle=True)
                self.zip.write(self.tmpfile, arcname=key + '.npy')
        finally:
            self._remove_tmpfile()

    def _remove_tmpfile(self):
        """Delete the staging file, ignoring the case where it is already gone."""
        import os
        try:
            os.remove(self.tmpfile)
        except FileNotFoundError:
            # Already removed by an earlier savez(); the next savez() recreates it
            pass

    def close(self):
        """Close the archive and remove the staging file if it still exists."""
        try:
            self.zip.close()
        finally:
            self._remove_tmpfile()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
class TAPEVisualizer(ABC):
    """Base class for visualization in TAPE.

    Subclasses implement config logging, model watching and scalar metric logging.
    """

    @abstractmethod
    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        raise NotImplementedError

    @abstractmethod
    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        raise NotImplementedError

    @abstractmethod
    def watch(self, model: nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        raise NotImplementedError


class DummyVisualizer(TAPEVisualizer):
    """Dummy class that doesn't do anything. Used for non-master processes
    (``get`` returns it whenever ``local_rank`` is not -1 or 0)."""

    def __init__(self,
                 log_dir: typing.Union[str, Path] = '',
                 exp_name: str = '',
                 debug: bool = False):
        pass

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        pass

    def watch(self, model: nn.Module) -> None:
        pass

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        pass


class TBVisualizer(TAPEVisualizer):
    """Visualizer that writes scalar metrics with a tensorboardX ``SummaryWriter``.

    Events are written under ``log_dir / exp_name``. ``debug`` is accepted but unused.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        log_dir = Path(log_dir) / exp_name
        logger.info(f"tensorboard file at: {log_dir}")
        self.logger = SummaryWriter(log_dir=str(log_dir))

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        # Not supported by this backend; only emit a warning
        logger.warning("Cannot log config when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def watch(self, model: nn.Module) -> None:
        # Not supported by this backend; only emit a warning
        logger.warning("Cannot watch models when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        """Write every metric as a scalar tagged ``"<split>/<name>"`` at ``step``."""
        for name, value in metrics_dict.items():
            self.logger.add_scalar(split + "/" + name, value, step)


class WandBVisualizer(TAPEVisualizer):
    """Visualizer backed by the ``wandb`` package.

    Sets ``WANDB_MODE=dryrun`` when ``debug`` is true or when the ``WANDB_PROJECT``
    environment variable is not set.

    Raises:
        ImportError: If ``wandb`` could not be imported.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        if not WANDB_FOUND:
            raise ImportError("wandb module not available")
        if debug:
            os.environ['WANDB_MODE'] = 'dryrun'
        if 'WANDB_PROJECT' not in os.environ:
            # Want the user to set the WANDB_PROJECT.
            logger.warning("WANDB_PROJECT environment variable not found, "
                           "not logging to app.wandb.ai")
            os.environ['WANDB_MODE'] = 'dryrun'
        wandb.init(dir=log_dir, name=exp_name)

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        wandb.config.update(config)

    def watch(self, model: nn.Module) -> None:
        wandb.watch(model)

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        """Log every metric under the key ``"<Split> <Name>"`` (both capitalized)."""
        wandb.log({f"{split.capitalize()} {name.capitalize()}": value
                   for name, value in metrics_dict.items()}, step=step)


def get(log_dir: typing.Union[str, Path],
        exp_name: str,
        local_rank: int,
        debug: bool = False) -> TAPEVisualizer:
    """Build the visualizer for this process.

    Returns a ``DummyVisualizer`` when ``local_rank`` is neither -1 nor 0. Otherwise
    returns a ``WandBVisualizer`` if wandb was importable, else a ``TBVisualizer``.
    """
    if local_rank not in (-1, 0):
        return DummyVisualizer(log_dir, exp_name, debug)
    elif WANDB_FOUND:
        return WandBVisualizer(log_dir, exp_name, debug)
    else:
        return TBVisualizer(log_dir, exp_name, debug)
def test_basic():
    """Smoke test: build a small ProteinBertModel from a config and run one forward pass."""
    import torch
    from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer  # type: ignore

    config = ProteinBertConfig(hidden_size=12, intermediate_size=12 * 4, num_hidden_layers=2)
    model = ProteinBertModel(config)
    tokenizer = TAPETokenizer(vocab='iupac')

    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    output = model(token_ids)
    sequence_output = output[0]  # noqa
    pooled_output = output[1]  # noqa


def test_forcedownload():
    """Check that ``force_download=True`` rewrites the cached weights file."""
    # First load makes sure the file is present in the cache
    model = ProteinBertModel.from_pretrained('bert-base')
    url = BERT_PRETRAINED_MODEL_ARCHIVE_MAP['bert-base']
    filename = url_to_filename(url, get_etag(url))
    wholepath = get_cache()/filename
    # Compare raw nanosecond mtimes: time.ctime() strings only have one-second
    # resolution, so two writes within the same second would compare equal.
    old_mtime_ns = os.stat(wholepath).st_mtime_ns
    model = ProteinBertModel.from_pretrained('bert-base', force_download=True)
    new_mtime_ns = os.stat(wholepath).st_mtime_ns
    assert new_mtime_ns != old_mtime_ns, (
        f"cached file was not rewritten; last modified "
        f"{time.ctime(os.path.getmtime(wholepath))}")
    # Deploy model
    # iupac is the vocab for TAPE models, use unirep for the UniRep model
    tokenizer = TAPETokenizer(vocab='iupac')
    # Pfam Family: Hexapep, Clan: CL0536
    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    model(token_ids)
--------------------------------------------------------------------------------