├── .github └── workflows │ └── comment_bot.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── docs └── CONTRIBUTING.md ├── protein_lm.yml ├── protein_lm ├── __init__.py ├── configs │ └── train │ │ ├── toy_hf.yaml │ │ └── toy_localcsv.yaml ├── dataset │ ├── __init__.py │ ├── cluster_dataset.py │ └── uniref │ │ └── uniref50_trimmed.csv ├── evaluation │ ├── __init__.py │ ├── scripts.py │ │ ├── download_proteingym_data.py │ │ ├── fitness_supervised.py │ │ ├── fitness_zero_shot_AR.py │ │ └── fitness_zero_shot_ESM.py │ └── scripts │ │ ├── Protein-gym.py │ │ ├── contact_prediction.py │ │ └── utils.py ├── modeling │ ├── __init__.py │ ├── getters │ │ ├── __init__.py │ │ ├── data_collator.py │ │ ├── dataset.py │ │ ├── model.py │ │ ├── tokenizer.py │ │ ├── training_args.py │ │ └── wandb_log.py │ ├── models │ │ ├── __init__.py │ │ └── apt │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── model_pytorch.py │ ├── scripts │ │ └── train.py │ └── utils │ │ ├── __init__.py │ │ ├── alibi_embedding.py │ │ ├── modules.py │ │ ├── rerope_embedding.py │ │ ├── rotary_embedding.py │ │ └── scaled_rope_embedding.py ├── tests │ ├── tensors │ │ ├── 1a3a.pkl │ │ ├── 1xcr.pkl │ │ ├── 5ahw.pkl │ │ ├── 5ahw_1_A_jacobian.pkl │ │ ├── dynamic_rope.pkl │ │ ├── linear_rope.pkl │ │ ├── rerope.pkl │ │ └── rope.pkl │ ├── test_attention.py │ ├── test_cl.py │ ├── test_cl_continuous.py │ ├── test_contact_prediction.py │ ├── test_data │ │ ├── 1a3a_1_A.a3m │ │ ├── 1xcr_1_A.a3m │ │ └── 5ahw_1_A.a3m │ ├── test_encoding.py │ └── test_tokenizer.py └── tokenizer │ ├── __init__.py │ ├── rust_trie │ ├── .github │ │ └── workflows │ │ │ └── CI.yml │ ├── .gitignore │ ├── Cargo.toml │ ├── pyproject.toml │ └── src │ │ └── lib.rs │ └── tokenizer.py ├── protein_lm_cuda.yml └── setup.py /.github/workflows/comment_bot.yml: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
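# This workflow implements the `/take` self-assignment flow described in
# docs/CONTRIBUTING.md: when an issue comment containing `/take` is created or
# edited, the github-script step below assigns the commenter to that issue.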
17 | 18 | name: Comment Bot 19 | 20 | on: 21 | issue_comment: 22 | types: 23 | - created 24 | - edited 25 | 26 | permissions: 27 | contents: read 28 | pull-requests: write 29 | 30 | jobs: 31 | issue_assign: 32 | name: "Auto-assign issue" 33 | permissions: 34 | issues: write 35 | if: contains(github.event.comment.body, '/take') 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/github-script@v6 39 | with: 40 | github-token: ${{ secrets.GITHUB_TOKEN }} 41 | script: | 42 | github.rest.issues.addAssignees({ 43 | owner: context.repo.owner, 44 | repo: context.repo.repo, 45 | issue_number: context.payload.issue.number, 46 | assignees: context.payload.comment.user.login 47 | }); 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | protein_lm/dataset/ProteinGym/ 2 | protein_lm/evaluation/output/ 3 | /esm2*/ 4 | /toy/ 5 | *.lock 6 | *.pyc 7 | wandb/ 8 | checkpoints/ 9 | __pycache__/ 10 | protein_lm.egg-info/ 11 | *.DS_Store 12 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2023 the authors 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Protein language models scaling laws 2 | ============== 3 | 4 | The goal of this project is to uncover the best approach to scale large protein language models (i.e., learn scaling laws for protein language models) and then publicly release a suite of optimally-trained large protein language models. 5 | 6 | ## Installing the environment 7 | 8 | If you want to run on CPU: 9 | ``` 10 | conda env create -f protein_lm.yml 11 | conda activate protein_lm_env 12 | pip install -e . 13 | ``` 14 | 15 | If you plan to use CUDA, use the dedicated `protein_lm_cuda.yml` file: 16 | ``` 17 | conda env create -f protein_lm_cuda.yml 18 | conda activate protein_lm_env 19 | pip install -e . 20 | ``` 21 | 22 | 23 | ## Installing the tokenizer 24 | 25 | This step will eventually be integrated into the rest of the installation process, but for now you need to run the following to build the Rust dependency of the tokenizer: 26 | 27 | ``` 28 | pip install -e protein_lm/tokenizer/rust_trie 29 | ``` 30 | 31 | ## Training 32 | 33 | ### Toy using a local dataset 34 | 35 | We recommend using a tiny toy dataset for testing and debugging new changes that do not rely on having a large dataset. Such a small dataset is provided in the `protein_lm/dataset/uniref` folder, and an example toy training config yaml that uses this dataset is provided in `protein_lm/configs/train/toy_localcsv.yaml`.
To use this config, at the root project directory (e.g., `protein_lm_scaling/`), run 36 | 37 | ``` 38 | python protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_localcsv.yaml 39 | ``` 40 | 41 | This config is actually the default, so the above is equivalent to 42 | 43 | ``` 44 | python protein_lm/modeling/scripts/train.py 45 | ``` 46 | 47 | ### Toy using a HuggingFace dataset 48 | 49 | For testing with a HuggingFace dataset, we have an example config yaml in `protein_lm/configs/train/toy_hf.yaml`. Note that training with this config is a little more involved than with the above `protein_lm/configs/train/toy_localcsv.yaml`: 50 | 51 | * When first run, the script will download the [processed uniref50 dataset](https://huggingface.co/datasets/zpn/uniref50), which could take some time. 52 | * This config will log the loss values and other metrics to Weights and Biases, which requires you to create a wandb account. 53 | 54 | You can run with this config by: 55 | 56 | ``` 57 | python protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_hf.yaml 58 | ``` 59 | 60 | ### Running on multiple GPUs 61 | 62 | We can run on a single node with multiple GPUs by 63 | 64 | ``` 65 | torchrun --standalone --nnodes=1 --nproc-per-node <num_gpus> protein_lm/modeling/scripts/train.py --config-file <config_file> 66 | ``` 67 | 68 | For example, to run on a single node with 3 GPUs with the provided `protein_lm/configs/train/toy_hf.yaml` config file, we can run 69 | 70 | ``` 71 | torchrun --standalone --nnodes=1 --nproc-per-node 3 protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_hf.yaml 72 | ``` 73 | ## Evaluation 74 | 75 | ### Contact Prediction 76 | 77 | The contact prediction script can be run using APT or ESM. It also has the option of choosing between `--method jacobian`, where contacts are computed by categorical Jacobian extraction, and `--method regression`, which uses logistic-regression-based contact prediction. See [BERT-ESM1b](https://github.com/sokrypton/algosb_2021/blob/main/BERT_esm1b.ipynb) and [Rao et al. 2020](https://doi.org/10.1101/2020.12.15.422761) for details. Currently, the script only supports `.a3m` files as input. 78 | 79 | You can run the contact prediction eval script by 80 | ``` 81 | python protein_lm/evaluation/scripts/contact_prediction.py --input /path/to/a3m_dir/ --model ESM --tokenizer EsmTokenizer --method jacobian --output /path/to/outputdir 82 | ``` 83 | ### Outputs 84 | * .png file with predicted contacts overlaid on the ground-truth contact map. 85 | * .csv file with results for each protein. 86 | 87 | ## Getting involved 88 | Your involvement is welcome! If you are interested, you can 89 | - Join the `#protein-lm-scaling` channel on the [OpenBioML discord server](https://discord.com/invite/GgDBFP8ZEt). 90 | - Check out our [contributing guide](docs/CONTRIBUTING.md) if you are interested in contributing to this repository. 91 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | If you are looking for ways to contribute, please check out our [task board](https://github.com/orgs/OpenBioML/projects/8) or our [open issues](https://github.com/OpenBioML/protein-lm-scaling/issues)! 4 | 5 | 6 | ## Issue Assignment 7 | 8 | If you see an interesting issue, please feel free to comment or ask questions!
Issues are generally resolved by code changes to our repository via a [pull request](#creating-a-pull-request). If you would like to be the one responsible for the code changes corresponding to an issue, and if nobody else has already been assigned to the issue, you can be assigned the issue by asking the issue creator or one of the moderators of repository, or by self-assigning the issue by commenting with `/take`. 9 | 10 | ## Unit testing 11 | 12 | When working on a particular issue, please also include relevant unit tests together with the code. Please see the following for example tests: [example 1](https://github.com/OpenBioML/protein-lm-scaling/blob/main/protein_lm/tests/test_tokenizer.py) [example 2](https://github.com/OpenBioML/protein-lm-scaling/blob/main/protein_lm/tests/test_encoding.py). 13 | 14 | ## Creating a Pull Request 15 | 16 | In order to contribute code to this repository, you must make a pull request (PR). 17 | 18 | First, make sure you have a GitHub account and have git installed locally. Then, to make your first pull request, 19 | 20 | 1. [Fork](https://docs.github.com/en/get-started/quickstart/fork-a-repo) our [repository](https://github.com/OpenBioML/protein-lm-scaling) by clicking on the [fork](https://github.com/OpenBioML/protein-lm-scaling/fork) button on the repository's main page. 21 | 2. [Clone](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) your fork. 22 | 3. [Create a new branch](https://git-scm.com/book/en/v2/Git-Branching-Basic-Branching-and-Merging). 23 | 4. Make your code contributions on this new branch. 24 | 5. [Create a PR from your fork](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) to our repository. 25 | 26 | (For subsequent pull requests, skip steps 1 and 2 and start directly from step 3.) 27 | 28 | Once your pull request has been discussed, potentially updated, and approved, it will be merged by one of the project leads. If you have reached this step, congratulations and thank you for contributing! 
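As a concrete illustration of the kind of test meant in the "Unit testing" section above, here is a minimal pytest-style sketch written against the `AptTokenizer` used elsewhere in this repository; the example sequence and the exact assertions are only illustrative and are not part of the existing test suite.

```
# A minimal, illustrative unit test (hypothetical): checks that AptTokenizer
# encodes a protein sequence into a non-empty collection of integer token ids.
from protein_lm.tokenizer import AptTokenizer


def test_apt_tokenizer_encode_returns_token_ids():
    tokenizer = AptTokenizer()
    ids = list(tokenizer.encode("MKTAYIAKQR"))
    assert len(ids) > 0
    assert all(int(i) == i for i in ids)
```

New tests can be placed alongside the existing ones in `protein_lm/tests/` so that they are picked up automatically by `pytest`.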
29 | -------------------------------------------------------------------------------- /protein_lm.yml: -------------------------------------------------------------------------------- 1 | name: protein_lm_env 2 | channels: 3 | - pytorch 4 | - huggingface 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python>=3.8 9 | - numpy 10 | - scipy 11 | - pytorch 12 | - pydantic>=2.0 13 | - wandb 14 | - rust 15 | - biotite 16 | - pip: 17 | - transformers 18 | - datasets 19 | - accelerate 20 | - evaluate 21 | - pytest 22 | - fair-esm 23 | -------------------------------------------------------------------------------- /protein_lm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/__init__.py -------------------------------------------------------------------------------- /protein_lm/configs/train/toy_hf.yaml: -------------------------------------------------------------------------------- 1 | # corresponds to DatasetConfig 2 | dataset: 3 | dataset_type: "huggingface" 4 | dataset_loc: "zpn/uniref50" 5 | subsample_size: 1000 6 | split_seed: 2 7 | val_size: 10 8 | test_size: 10 9 | sequence_column_name: "sequence" 10 | max_sequence_length: 10 11 | do_curriculum_learning: false 12 | 13 | # corresponds to HuggingFace's TrainingArguments 14 | training_arguments: 15 | output_dir: "checkpoints/toy_hf" 16 | num_train_epochs: 2 17 | learning_rate: 0.1 18 | weight_decay: 0.1 19 | save_strategy: "epoch" 20 | per_device_train_batch_size: 10 21 | save_steps: 5 22 | evaluation_strategy: "steps" 23 | eval_steps: 5 24 | report_to: "wandb" 25 | label_names: 26 | - 'labels' 27 | no_cuda: false 28 | ddp_find_unused_parameters: false 29 | 30 | # corresponds to WandBConfig 31 | wandb: 32 | name: "toy_hf" 33 | dir: "wandb_files/" 34 | 35 | # corresponds to TokenizerConfig 36 | tokenizer: 37 | tokenizer_type: "APT" 38 | 39 | # corresponds to NNModelConfig 40 | model: 41 | nn_model_type: "APT" 42 | nn_model_config_args: 43 | position_embedding: "learned" 44 | max_sequence_length: 10 45 | pretrained_checkpoint: null 46 | 47 | # corresponds to DataCollatorConfig 48 | data_collator: 49 | data_collator_type: "default" 50 | -------------------------------------------------------------------------------- /protein_lm/configs/train/toy_localcsv.yaml: -------------------------------------------------------------------------------- 1 | # corresponds to DatasetConfig 2 | dataset: 3 | dataset_type: "csv" 4 | dataset_loc: "protein_lm/dataset/uniref/uniref50_trimmed.csv" 5 | subsample_size: 100 6 | split_seed: 2 7 | val_size: 10 8 | test_size: 10 9 | sequence_column_name: "sequence" 10 | max_sequence_length: 10 11 | do_curriculum_learning: false 12 | 13 | # corresponds to HuggingFace's TrainingArguments 14 | training_arguments: 15 | output_dir: "checkpoints/toy" 16 | max_steps: 1 17 | num_train_epochs: 1 18 | learning_rate: 0.1 19 | weight_decay: 0.1 20 | save_strategy: "epoch" 21 | per_device_train_batch_size: 1 22 | save_steps: 1 23 | report_to: "none" 24 | label_names: 25 | - 'labels' 26 | no_cuda: false 27 | 28 | # corresponds to TokenizerConfig 29 | tokenizer: 30 | tokenizer_type: "APT" 31 | 32 | # corresponds to NNModelConfig 33 | model: 34 | nn_model_type: "APT" 35 | nn_model_config_args: 36 | position_embedding: "learned" 37 | rope_scaling_factor: 1.0 38 | rope_theta: 10000 39 | max_sequence_length: 10 40 | pretrained_checkpoint: null 41 | 42 | # corresponds to 
DataCollatorConfig 43 | data_collator: 44 | data_collator_type: "default" 45 | -------------------------------------------------------------------------------- /protein_lm/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/dataset/__init__.py -------------------------------------------------------------------------------- /protein_lm/dataset/cluster_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from torch.utils.data import Dataset 3 | import pandas as pd 4 | import os 5 | import numpy as np 6 | from Bio import SeqIO 7 | 8 | 9 | class ClusterDataset(Dataset): 10 | def __init__( 11 | self, 12 | dataset_path: str, 13 | cluster_table_path: str, 14 | size_to_sample_prob: Callable = lambda x: x, 15 | seed: int = 42, 16 | ) -> None: 17 | super().__init__() 18 | self.dataset_path = dataset_path 19 | self.cluster_table_path = cluster_table_path 20 | self.cluster_to_seqs = {} 21 | self.cluster_table = pd.read_csv( 22 | cluster_table_path, dtype={'cluster_name': str, 'cluster_size': int} 23 | ) 24 | self.cluster_table['sample_prob'] = self.cluster_table['cluster_size'].apply(size_to_sample_prob) 25 | self.cluster_table['sample_prob'] /= self.cluster_table['sample_prob'].sum() 26 | self.generator = np.random.default_rng(seed) 27 | 28 | def __len__(self) -> int: 29 | return len(self.cluster_table) 30 | 31 | def get_cluster_seqs(self, cluster_path: str) -> list: 32 | if cluster_path not in self.cluster_to_seqs: 33 | self.cluster_to_seqs[cluster_path] = [ 34 | str(x.seq) for x in SeqIO.parse(cluster_path, 'fasta') 35 | ] 36 | return self.cluster_to_seqs[cluster_path] 37 | 38 | def __iter__(self): 39 | for _ in range(len(self)): 40 | cluster_name = self.cluster_table.sample( 41 | n=1, weights='sample_prob', random_state=self.generator 42 | )[['cluster_name']].values[0][0] 43 | # Now we map cluster_name to the folder it is in 44 | if cluster_name == "unk": 45 | cluster_path = os.path.join(self.dataset_path, "unk", "unk.fasta") 46 | else: 47 | cluster_dir = f"{int(cluster_name) // 1000}000" 48 | cluster_path = os.path.join(self.dataset_path, cluster_dir, f"{cluster_name}.fasta") 49 | seqs = self.get_cluster_seqs(cluster_path) 50 | yield seqs[self.generator.integers(len(seqs))] 51 | -------------------------------------------------------------------------------- /protein_lm/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/evaluation/__init__.py -------------------------------------------------------------------------------- /protein_lm/evaluation/scripts.py/download_proteingym_data.py: -------------------------------------------------------------------------------- 1 | ###################################################################################################### 2 | # Script: ProteinGym Download script 3 | # Authors: Maximilian Sprang, Muedi 4 | # Date: 09/2023 5 | # Description: This script downloads and handles the zipped Protein Gym Csvs and preprocesses them 6 | # to be able to tokenized by EMS/ProtBERT tokenizers. 
7 | ###################################################################################################### 8 | import pandas as pd 9 | import requests, zipfile, io, os 10 | from datasets import load_dataset 11 | import argparse 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description='Download ProteinGym Data') 15 | parser.add_argument("--data_path", default="protein_lm/dataset/ProteinGym/", type=str, help="Path to drop data") 16 | args = parser.parse_args() 17 | # relative ProteinGym path 18 | data_path = args.data_path 19 | 20 | # download substitutions, unzip, save to disk 21 | dat_url = "https://marks.hms.harvard.edu/proteingym/ProteinGym_substitutions.zip" 22 | 23 | if os.path.exists(data_path + "ProteinGym_substitutions"): 24 | print("substitution data is already here :)") 25 | else: 26 | print("download substitution data ...") 27 | r = requests.get(dat_url) 28 | z = zipfile.ZipFile(io.BytesIO(r.content)) 29 | z.extractall(data_path) 30 | 31 | # download indels, unzip, save to disk 32 | dat_url = "https://marks.hms.harvard.edu/proteingym/ProteinGym_indels.zip" 33 | 34 | if os.path.exists(data_path + "ProteinGym_indels"): 35 | print("indel data is already here :)") 36 | else: 37 | print("download indel data ...") 38 | r = requests.get(dat_url) 39 | z = zipfile.ZipFile(io.BytesIO(r.content)) 40 | z.extractall(data_path) 41 | 42 | # download ref files 43 | dat_url = "https://raw.githubusercontent.com/OATML-Markslab/ProteinGym/main/ProteinGym_reference_file_substitutions.csv" 44 | if os.path.exists(data_path + "ProteinGym_reference_file_substitutions.csv"): 45 | print("Substitution reference file is already here :)") 46 | else: 47 | print("download substitution reference ...") 48 | r = requests.get(dat_url) 49 | df = pd.read_csv(io.BytesIO(r.content)) 50 | df.to_csv(data_path + "ProteinGym_reference_file_substitutions.csv", index=False) 51 | 52 | dat_url = "https://raw.githubusercontent.com/OATML-Markslab/ProteinGym/main/ProteinGym_reference_file_indels.csv" 53 | if os.path.exists(data_path + "ProteinGym_reference_file_indels.csv"): 54 | print("Indel reference file is already here :)") 55 | else: 56 | print("download Indel reference ...") 57 | r = requests.get(dat_url) 58 | df = pd.read_csv(io.BytesIO(r.content)) 59 | df.to_csv(data_path + "ProteinGym_reference_file_indels.csv", index=False) 60 | # %% 61 | # load substitution data, introduce whitespaces and CLS/EOS tokens 62 | # save complete data as csv, load as HF dataset 63 | if os.path.exists(data_path + "ProteinGym_substitutions.csv"): 64 | print("preprocessing was already done, load csv") 65 | dataset = load_dataset("csv", data_files=(data_path + "ProteinGym_substitutions.csv")) 66 | else: 67 | print("preprocess substitutions ...") 68 | folder_path = data_path + "ProteinGym_substitutions" # use the configured data_path rather than a hard-coded location 69 | all_data = [] 70 | for filename in os.listdir(folder_path): 71 | if filename.endswith(".csv"): 72 | file_path = os.path.join(folder_path, filename) 73 | df = pd.read_csv(file_path) 74 | experiment = filename[:-4] 75 | 76 | 77 | # add experiment name to track and get base sequence for zero-shot tasks 78 | df["mutant"] = experiment + "_" + df["mutant"] 79 | all_data.append(df) 80 | 81 | # get dataframe 82 | merged_data = pd.concat(all_data, ignore_index=True) 83 | # save the baseseqs 84 | # Add spaces between each amino acid in the "mutated_sequences" column 85 | # merged_data["mutated_sequence"] = merged_data["mutated_sequence"].apply(lambda seq: " ".join(list(seq))) 86 | # add cls and end tokens
merged_data["mutated_sequence"] = "" + merged_data["mutated_sequence"] + "" 88 | # save csv 89 | merged_data.to_csv(data_path + "ProteinGym_substitutions.csv", index=False) 90 | dataset = load_dataset("csv", data_files=(data_path + "ProteinGym_substitutions.csv")) 91 | del merged_data 92 | 93 | 94 | if __name__ == '__main__': 95 | main() -------------------------------------------------------------------------------- /protein_lm/evaluation/scripts.py/fitness_supervised.py: -------------------------------------------------------------------------------- 1 | # %% 2 | ###################################################################################################### 3 | # Script: ProteinGym Supervised Eval Script 4 | # Authors: Maximilian Sprang, Muedi 5 | # Date: 09/2023 6 | # Description: This script uses HF's evaluate library to test supervised perfromance of a given model 7 | # on ProteinGym data. 8 | # ATM only substitution data is implemented for the finetunning but both are preprocessed and the 9 | # complete datasets saved as CSV. 10 | ###################################################################################################### 11 | import sys, os 12 | sys.path.append(os.getcwd()) #needed to run script from base dir. 13 | # Otherwise prot_lm throws module not found exception 14 | 15 | # huggingface 16 | from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, TrainingArguments, Trainer 17 | from evaluate import load 18 | from datasets import load_dataset 19 | # ours 20 | from protein_lm.tokenizer import EsmTokenizer, AptTokenizer 21 | # others 22 | from tqdm import tqdm 23 | import numpy as np 24 | import argparse 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser(description="Supervised Training Script") 28 | parser.add_argument("--data_path", default="protein_lm/dataset/ProteinGym/", type=str, help="Path to ProteinGym data") 29 | parser.add_argument("--checkpoint", default="facebook/esm2_t33_650M_UR50D", type=str, help="Checkpoint, of online model, or path to local checkpoints") 30 | 31 | args = parser.parse_args() 32 | 33 | checkpoint = args.checkpoint 34 | data_path = args.data_path 35 | dataset = load_dataset("csv", data_files=(data_path + "ProteinGym_substitutions.csv")) 36 | 37 | # load model for seq classification 38 | num_labels = 2 39 | model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels) 40 | model.config.pad_token_id = 2 # needed for apt as long as tokenizer is not API compatible :) 41 | 42 | model_name = checkpoint.split("/")[-1] 43 | batch_size = 8 44 | 45 | tokenizer = AptTokenizer() 46 | 47 | def tokenize(batch): 48 | tokens = tokenizer(batch["mutated_sequence"], return_tensors=True, max_sequence_length=760) 49 | return {"input_ids": tokens} 50 | 51 | token_data = dataset.map(tokenize, batched=True) 52 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 53 | 54 | # rename and remove stuff to fit into dataloader seemlessly 55 | # removes info we don"t use in the network, as we only use tokens and binned scores 56 | token_data = token_data.remove_columns(["DMS_score", "mutant", "mutated_sequence"]) 57 | # binned scores are renamed to "labels" 58 | token_data = token_data.rename_column("DMS_score_bin", "labels") 59 | 60 | # Split the train dataset into train, valid, and test subsets 61 | dict_train_test = token_data["train"].train_test_split(test_size=0.4, shuffle=True) 62 | train_dataset = dict_train_test["train"] 63 | test_dataset = 
dict_train_test["test"] 64 | 65 | # subset for testruns: 66 | # train_dataset = train_dataset.select([x for x in range(200)]) 67 | # test_dataset = test_dataset.select([x for x in range(100)]) 68 | 69 | # # here we could split into validation and test if needed 70 | # dict_test_valid = test_dataset.train_test_split(test_size=0.5, shuffle=True) 71 | # test_dataset = dict_test_valid["test"] 72 | # valid_dataset = dict_test_valid["train"] 73 | 74 | args = TrainingArguments( 75 | f"{model_name}-finetuned-localization", 76 | evaluation_strategy = "epoch", 77 | save_strategy = "epoch", 78 | learning_rate=2e-5, 79 | per_device_train_batch_size=batch_size, 80 | per_device_eval_batch_size=batch_size, 81 | num_train_epochs=3, 82 | weight_decay=0.01, 83 | load_best_model_at_end=True, 84 | metric_for_best_model="accuracy", 85 | push_to_hub=False, 86 | ) 87 | 88 | metric = load("accuracy") 89 | 90 | def compute_metrics(eval_pred): 91 | predictions, labels = eval_pred 92 | predictions = np.argmax(predictions, axis=1) 93 | return metric.compute(predictions=predictions, references=labels) 94 | 95 | trainer = Trainer( 96 | model, 97 | args, 98 | train_dataset=train_dataset, 99 | eval_dataset=test_dataset, 100 | # tokenizer=tokenizer, 101 | compute_metrics=compute_metrics, 102 | ) 103 | # run trainer, this will return eval loass andd accuracy every few steps 104 | # and save this to the disk in the model-ceckpoint* folder 105 | trainer.train() 106 | 107 | if __name__ == '__main__': 108 | main() -------------------------------------------------------------------------------- /protein_lm/evaluation/scripts.py/fitness_zero_shot_AR.py: -------------------------------------------------------------------------------- 1 | # %% 2 | ###################################################################################################### 3 | # Script: ProteinGym Supervised Eval Script 4 | # Authors: Maximilian Sprang, Muedi 5 | # Date: 09/2023 6 | # Description: zero shot for autoregressive models, as fopund in RITA 7 | # https://github.com/lightonai/RITA/blob/master/compute_fitness.py 8 | ###################################################################################################### 9 | import sys, os 10 | sys.path.append(os.getcwd()) #needed to run script from base dir. 
11 | # Otherwise prot_lm throws module not found exception 12 | from transformers import AutoModelForCausalLM 13 | from torch.nn import CrossEntropyLoss 14 | import torch 15 | from tqdm import tqdm 16 | import numpy as np 17 | import pandas as pd 18 | from scipy.stats import spearmanr 19 | import argparse 20 | # ours 21 | from protein_lm.tokenizer import AptTokenizer 22 | 23 | 24 | def calc_fitness(model, prots, tokenizer, device="cuda:0", model_context_len=1023): 25 | # calculates the fitness 26 | loss_list = [] 27 | loss_fn = CrossEntropyLoss() 28 | with torch.no_grad(): 29 | for prot in tqdm(prots): 30 | loss_val = 0 31 | 32 | sequence_chunks=[] 33 | if len(prot) < model_context_len: 34 | sequence_chunks = [prot] 35 | else: 36 | len_target_seq = len(prot) 37 | num_windows = 1 + int( len_target_seq / model_context_len) 38 | start=0 39 | for window_index in range(1, num_windows+1): 40 | sequence_chunks.append(prot[start:start+model_context_len]) 41 | start += model_context_len 42 | 43 | for chunk in sequence_chunks: 44 | for p in [chunk, chunk[::-1]]: 45 | ids = torch.tensor([tokenizer.encode(p)]).to(device) 46 | input_ids = ids[:, :-1] 47 | targets = ids[:, 1:] 48 | 49 | logits=model(input_ids).logits 50 | loss = loss_fn(target=targets.view(-1), input=logits.view(-1,logits.size(-1))) 51 | loss_val += -loss.item() 52 | 53 | loss_list += [loss_val] 54 | return np.array(loss_list) 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser(description="Supervised Training Script") 58 | parser.add_argument("--checkpoint", default="checkpoints/toy", type=str, help="Checkpoint, path to local checkpoints") 59 | parser.add_argument("--data_path", default="protein_lm/dataset/ProteinGym/", type=str, help="Path to ProteinGym data") 60 | parser.add_argument("--outdir", default="protein_lm/evaluation/output/likelihood-autoreg/", type=str, help="Directory for output files") 61 | parser.add_argument("--nogpu", default=False, type=bool, help="Set true to run model on CPU") 62 | args = parser.parse_args() 63 | 64 | checkpoint = args.checkpoint 65 | data_path = args.data_path 66 | outdir = args.outdir 67 | nogpu = args.nogpu 68 | # check if output path exists 69 | if not os.path.exists(outdir): 70 | os.makedirs(outdir) 71 | 72 | model = AutoModelForCausalLM.from_pretrained(checkpoint) 73 | if torch.cuda.is_available() and not nogpu: 74 | model = model.cuda() 75 | print("Transferred model to GPU") 76 | model.eval() 77 | tokenizer = AptTokenizer() 78 | 79 | # get experiments and base seqs 80 | ref_df = pd.read_csv(data_path + "ProteinGym_reference_file_substitutions.csv") 81 | dms_ids = ref_df.DMS_id 82 | dms_file = ref_df.DMS_filename 83 | dms_ref_seqs = ref_df.target_seq 84 | for experiment, file_name, sequence in zip(dms_ids, dms_file, dms_ref_seqs): 85 | 86 | # Load the deep mutational scan 87 | DMS_data_path = data_path + "ProteinGym_substitutions/" + file_name 88 | DMS_data = pd.read_csv(DMS_data_path) 89 | DMS_output = "scores_{}".format(file_name) 90 | 91 | # compute scores 92 | model_scores = calc_fitness(model=model, prots=np.array(DMS_data["mutated_sequence"]), tokenizer=tokenizer) 93 | 94 | DMS_data["APT_score"] = model_scores 95 | DMS_data.to_csv(outdir + DMS_output, index=False) 96 | 97 | spearman, _ = spearmanr(DMS_data["APT_score"], DMS_data["DMS_score"]) 98 | print("Performance of APT on experiment {}: {}".format(experiment, spearman)) 99 | 100 | if __name__ == "__main__": 101 | main() -------------------------------------------------------------------------------- 
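The scoring in `fitness_zero_shot_AR.py` above boils down to the summed next-token log-likelihood of each mutated sequence, computed over the sequence and its reverse in windows that fit the model context. The following is a condensed, single-window sketch of that idea, not part of the repository code; the `checkpoints/toy` path mirrors the script's default `--checkpoint` argument and is only a placeholder.

```
# Simplified single-window version of the log-likelihood fitness score
# computed by calc_fitness() above (illustrative sketch, not repository code).
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM

from protein_lm.tokenizer import AptTokenizer


def sequence_log_likelihood(model, tokenizer, sequence, device="cpu"):
    loss_fn = CrossEntropyLoss()
    score = 0.0
    with torch.no_grad():
        # Score the sequence and its reverse, as calc_fitness() does.
        for s in (sequence, sequence[::-1]):
            ids = torch.tensor([tokenizer.encode(s)], device=device)
            logits = model(ids[:, :-1]).logits
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), ids[:, 1:].reshape(-1))
            score += -loss.item()  # higher (less negative) means more likely
    return score


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained("checkpoints/toy")  # placeholder path
    tokenizer = AptTokenizer()
    print(sequence_log_likelihood(model, tokenizer, "MKTAYIAKQR"))
```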
/protein_lm/evaluation/scripts.py/fitness_zero_shot_ESM.py: -------------------------------------------------------------------------------- 1 | # %% 2 | ###################################################################################################### 3 | # Script: ProteinGym Supervised Eval Script 4 | # Authors: Maximilian Sprang, Muedi 5 | # Date: 09/2023 6 | # Description: Zero-shot eval script for MLM models like ESM: 7 | # https://github.com/facebookresearch/esm/blob/main/examples/variant-prediction/predict.py 8 | ###################################################################################################### 9 | import sys, os 10 | sys.path.append(os.getcwd()) #needed to run script from base dir. 11 | # Otherwise prot_lm throws module not found exception 12 | import torch 13 | from tqdm import tqdm 14 | import numpy as np 15 | import pandas as pd 16 | from scipy.stats import spearmanr 17 | from Bio import SeqIO 18 | import itertools 19 | from typing import List, Tuple 20 | import argparse 21 | # ours 22 | from protein_lm.tokenizer import AptTokenizer 23 | # esm 24 | from esm import pretrained, Alphabet, FastaBatchedDataset, MSATransformer 25 | 26 | 27 | # def remove_insertions(sequence: str) -> str: 28 | # """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """ 29 | # # This is an efficient way to delete lowercase characters and insertion characters from a string 30 | # deletekeys = dict.fromkeys(string.ascii_lowercase) 31 | # deletekeys["."] = None 32 | # deletekeys["*"] = None 33 | 34 | # translation = str.maketrans(deletekeys) 35 | # return sequence.translate(translation) 36 | 37 | 38 | # def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]: 39 | # """ Reads the first nseq sequences from an MSA file, automatically removes insertions. 40 | 41 | # The input file must be in a3m format (although we use the SeqIO fasta parser) 42 | # for remove_insertions to work properly.""" 43 | 44 | # msa = [ 45 | # (record.description, remove_insertions(str(record.seq))) 46 | # for record in itertools.islice(SeqIO.parse(filename, "fasta"), nseq) 47 | # ] 48 | # return msa 49 | 50 | 51 | def label_row(rows, sequence, token_probs, alphabet, offset_idx): 52 | rows = rows.split(":") 53 | score = 0 54 | for row in rows: 55 | wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1] 56 | assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence" 57 | 58 | wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt) 59 | 60 | # add 1 for BOS 61 | score_obj = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded] 62 | score += score_obj.item() 63 | return score / len(rows) 64 | 65 | 66 | def compute_pppl(mutated_sequence, model, alphabet): 67 | """ 68 | The original methods changes the given base_sequence to the mutated one, we"ll just read it from the df. 69 | We compute the pseudo-Perplexity of the complete mutated sequence. 
70 | The code to achieve this has not been changed from esm's repo 71 | """ 72 | # wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1] 73 | # assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence" 74 | 75 | # # modify the sequence 76 | # sequence = sequence[:idx] + mt + sequence[(idx + 1) :] 77 | 78 | # encode the sequence 79 | data = [ 80 | ("protein1", mutated_sequence), 81 | ] 82 | 83 | batch_converter = alphabet.get_batch_converter() 84 | 85 | batch_labels, batch_strs, batch_tokens = batch_converter(data) 86 | 87 | # wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt) 88 | 89 | # compute token probabilities at each position 90 | log_probs = [] 91 | for i in range(1, len(mutated_sequence) - 1): 92 | batch_tokens_masked = batch_tokens.clone() 93 | batch_tokens_masked[0, i] = alphabet.mask_idx 94 | with torch.no_grad(): 95 | token_probs = torch.log_softmax(model(batch_tokens_masked.cuda())["logits"], dim=-1) 96 | log_probs.append(token_probs[0, i, alphabet.get_idx(mutated_sequence[i])].item()) # vocab size 97 | return sum(log_probs) 98 | 99 | 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser(description="Supervised Training Script") 103 | parser.add_argument("--checkpoint", default="facebook/esm2_t33_650M_UR50D", type=str, help="Checkpoint, path to local checkpoints") 104 | parser.add_argument("--data_path", default="protein_lm/dataset/ProteinGym/", type=str, help="Path to ProteinGym data") 105 | parser.add_argument("--outdir", default="protein_lm/evaluation/output/likelihood-autoreg/", type=str, help="Directory for output files") 106 | parser.add_argument("--scoring_strategy", default="masked-marginals", choices=["masked-marginals", "pseudo-ppl", "wt-marginals"], type=str, help="Scoring strategies for MLMs") 107 | parser.add_argument("--nogpu", default=False, type=bool, help="Set true to run model on CPU") 108 | args = parser.parse_args() 109 | 110 | # assign vars 111 | checkpoint = args.checkpoint 112 | data_path = args.data_path 113 | outdir = args.outdir 114 | nogpu = args.nogpu 115 | scoring_strategy = args.scoring_strategy 116 | 117 | # fixed vars (can be added as args if needed later 118 | mutation_col = 0 # column that holds info on mutations 119 | offset_idx = 1 # offset index, default was zero, but in our case it needs to be one 120 | 121 | # get experiments and base seqs 122 | ref_df = pd.read_csv(data_path + "ProteinGym_reference_file_substitutions.csv") 123 | dms_ids = ref_df.DMS_id 124 | dms_file = ref_df.DMS_filename 125 | dms_ref_seqs = ref_df.target_seq 126 | 127 | # relative output path 128 | outdir = "protein_lm/evaluation/output/{}/".format(scoring_strategy) 129 | # check if output path exists 130 | if not os.path.exists(outdir): 131 | os.makedirs(outdir) 132 | 133 | # inference for given model 134 | # set checkpoint to be mnodel location for now 135 | model_name = checkpoint.split("/")[-1] 136 | model, alphabet = pretrained.load_model_and_alphabet(model_name) 137 | model.eval() 138 | if torch.cuda.is_available() and not nogpu: 139 | model = model.cuda() 140 | print("Transferred model to GPU") 141 | 142 | for experiment, file_name, sequence in zip(dms_ids, dms_file, dms_ref_seqs): 143 | 144 | # Load the deep mutational scan 145 | DMS_data_path = data_path + "ProteinGym_substitutions/" + file_name 146 | DMS_data = pd.read_csv(DMS_data_path) 147 | DMS_output = "scores_{}".format(file_name) 148 | 149 | 150 | batch_converter = alphabet.get_batch_converter() 151 | 152 | if isinstance(model, MSATransformer): 
153 | # as far as I know we do not plan on using this? I kept it around for now. 154 | print("MSATransformer is currently not supported :)") 155 | pass 156 | # data = [read_msa(args.msa_path, args.msa_samples)] 157 | # assert ( 158 | # scoring_strategy == "masked-marginals" 159 | # ), "MSA Transformer only supports masked marginal strategy" 160 | 161 | # batch_labels, batch_strs, batch_tokens = batch_converter(data) 162 | 163 | # all_token_probs = [] 164 | # for i in tqdm(range(batch_tokens.size(2))): 165 | # batch_tokens_masked = batch_tokens.clone() 166 | # batch_tokens_masked[0, 0, i] = alphabet.mask_idx # mask out first sequence 167 | # with torch.no_grad(): 168 | # token_probs = torch.log_softmax( 169 | # model(batch_tokens_masked.cuda())["logits"], dim=-1 170 | # ) 171 | # all_token_probs.append(token_probs[:, 0, i]) # vocab size 172 | # token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0) 173 | # DMS_data[model_name+"_score"] = DMS_data.apply( 174 | # lambda row: label_row( 175 | # row[mutation_col], sequence, token_probs, alphabet, offset_idx 176 | # ), 177 | # axis=1, 178 | # ) 179 | 180 | else: 181 | data = [ 182 | ("protein1", sequence), 183 | ] 184 | batch_labels, batch_strs, batch_tokens = batch_converter(data) 185 | 186 | if scoring_strategy == "wt-marginals": 187 | with torch.no_grad(): 188 | token_probs = torch.log_softmax(model(batch_tokens.cuda())["logits"], dim=-1) 189 | DMS_data[model_name+"_score"] = DMS_data.apply( 190 | lambda row: label_row( 191 | row[mutation_col], 192 | sequence, 193 | token_probs, 194 | alphabet, 195 | offset_idx, 196 | ), 197 | axis=1, 198 | ) 199 | elif scoring_strategy == "masked-marginals": 200 | all_token_probs = [] 201 | for i in tqdm(range(batch_tokens.size(1))): 202 | batch_tokens_masked = batch_tokens.clone() 203 | batch_tokens_masked[0, i] = alphabet.mask_idx 204 | with torch.no_grad(): 205 | token_probs = torch.log_softmax( 206 | model(batch_tokens_masked.cuda())["logits"], dim=-1 207 | ) 208 | all_token_probs.append(token_probs[:, i]) # vocab size 209 | token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0) 210 | DMS_data[model_name+"_score"] = DMS_data.apply( 211 | lambda row: label_row( 212 | row[mutation_col], 213 | sequence, 214 | token_probs, 215 | alphabet, 216 | offset_idx, 217 | ), 218 | axis=1, 219 | ) 220 | elif scoring_strategy == "pseudo-ppl": 221 | tqdm.pandas() 222 | DMS_data[model_name+"_score"] = DMS_data.progress_apply( 223 | lambda row: compute_pppl( 224 | #row[mutation_col], 225 | # sequence, 226 | row["mutated_sequence"], 227 | model, 228 | alphabet 229 | #offset_idx 230 | ), 231 | axis=1, 232 | ) 233 | # save experiment 234 | DMS_data.to_csv(outdir + DMS_output, index=None) 235 | spearman, _ = spearmanr(DMS_data["{}_score".format(model_name)], DMS_data["DMS_score"]) 236 | print("Performance of {} on experiment {}: {}".format(model_name, experiment, spearman)) 237 | 238 | if __name__ == "__main__": 239 | main() -------------------------------------------------------------------------------- /protein_lm/evaluation/scripts/Protein-gym.py: -------------------------------------------------------------------------------- 1 | ###################################################################################################### 2 | # Script: ProteinGym Eval script 3 | # Authors: Maximilian Sprang, Muedi 4 | # Date: 08/2023 5 | # Description: This script downloads and handles the zipped Protein Gym Csvs and preprocesses them 6 | # to be able to tokenized by EMS/ProtBERT tokenizers. 
7 | # Tokenization is done and then the esm 630M Model is used to be finetuned on ProteinGyms data 8 | # ATM only substitution data is implemented for the finetunning but both are preprocessed and the 9 | # complete datasets saved as CSV. 10 | # finetuning is done with the evaluaten libray, which we'll likely change to an own trainling loop 11 | # to be more flexible with our own models. 12 | ###################################################################################################### 13 | # %% 14 | # huggingface 15 | from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, TrainingArguments, Trainer 16 | from datasets import load_dataset 17 | from evaluate import load 18 | # others 19 | # import matplotlib.pyplot as plt 20 | from datetime import datetime 21 | import numpy as np 22 | import pandas as pd 23 | # http requests 24 | import requests, zipfile, io, os 25 | 26 | # %% 27 | # download substitutions, unzip, save to disk 28 | path = "data/ProteinGym/" 29 | sub_url = "https://marks.hms.harvard.edu/proteingym/ProteinGym_substitutions.zip" 30 | 31 | if os.path.exists(path + "ProteinGym_substitutions"): 32 | print("substitution data is already here :)") 33 | else: 34 | print("download substitution data ...") 35 | r = requests.get(sub_url) 36 | z = zipfile.ZipFile(io.BytesIO(r.content)) 37 | z.extractall(path) 38 | 39 | # download indels, unzip, save to disk 40 | sub_url = "https://marks.hms.harvard.edu/proteingym/ProteinGym_indels.zip" 41 | 42 | if os.path.exists(path + "ProteinGym_indels"): 43 | print("indel data is already here :)") 44 | else: 45 | print("download indel data ...") 46 | r = requests.get(sub_url) 47 | z = zipfile.ZipFile(io.BytesIO(r.content)) 48 | z.extractall(path) 49 | 50 | # %% 51 | # load substitution data, introduce whitespeces and CLS/EOS tokens 52 | # save complete data as csv, load as HF dataset 53 | if os.path.exists(path + "ProteinGym_substitutions.csv"): 54 | print("preprocessing was already done, load csv") 55 | dataset = load_dataset("csv", data_files=(path + "ProteinGym_substitutions.csv")) 56 | else: 57 | print("preprocess substitutions ...") 58 | folder_path = "data/ProteinGym/ProteinGym_substitutions" 59 | all_data = [] 60 | for filename in os.listdir(folder_path): 61 | if filename.endswith('.csv'): 62 | file_path = os.path.join(folder_path, filename) 63 | df = pd.read_csv(file_path) 64 | all_data.append(df) 65 | merged_data = pd.concat(all_data, ignore_index=True) 66 | # Add spaces between each amino acid in the 'mutated_sequences' column 67 | merged_data['mutated_sequence'] = merged_data['mutated_sequence'].apply(lambda seq: ' '.join(list(seq))) 68 | # add cls and end tokens 69 | merged_data['mutated_sequence'] = " " + merged_data['mutated_sequence'] + " " 70 | # save csv 71 | merged_data.to_csv(path + "ProteinGym_substitutions.csv", index=False) 72 | dataset = load_dataset("csv", data_files=(path + "ProteinGym_substitutions.csv")) 73 | del merged_data 74 | 75 | # %% tokenize, with esm2_t33_650M_UR50D, use same checkpoint for model 76 | checkpoint = "facebook/esm2_t33_650M_UR50D" 77 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 78 | def tokenize(batch): 79 | return tokenizer(batch["mutated_sequence"], truncation=True, padding='max_length', max_length=760) 80 | 81 | token_data = dataset.map(tokenize, batched=True) 82 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 83 | 84 | # rename and remove stuff to fit into dataloader seemlessly 85 | # removes info we don't use in the 
network, as we only use tokens and binned scores 86 | token_data = token_data.remove_columns(["DMS_score", "mutant", "mutated_sequence"]) 87 | # binned scores are renamed to 'labels' 88 | token_data = token_data.rename_column("DMS_score_bin", "labels") 89 | 90 | # Split the train dataset into train, valid, and test subsets 91 | dict_train_test = token_data['train'].train_test_split(test_size=0.4, shuffle=True) 92 | train_dataset = dict_train_test['train'] 93 | test_dataset = dict_train_test['test'] 94 | 95 | # subset for testruns: 96 | # train_dataset = train_dataset.select([x for x in range(200)]) 97 | # test_dataset = test_dataset.select([x for x in range(100)]) 98 | 99 | # # here we could split into validation and test if needed 100 | # dict_test_valid = test_dataset.train_test_split(test_size=0.5, shuffle=True) 101 | # test_dataset = dict_test_valid['test'] 102 | # valid_dataset = dict_test_valid['train'] 103 | # %% taken from facebooks pretrained-finetuning notebook here: 104 | # https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_language_modeling.ipynb#scrollTo=fc164b49 105 | supervised=True 106 | if supervised: 107 | # load model for seq classification 108 | num_labels = 2 109 | model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels) 110 | 111 | model_name = checkpoint.split("/")[-1] 112 | batch_size = 8 113 | 114 | args = TrainingArguments( 115 | f"{model_name}-finetuned-localization", 116 | evaluation_strategy = "epoch", 117 | save_strategy = "epoch", 118 | learning_rate=2e-5, 119 | per_device_train_batch_size=batch_size, 120 | per_device_eval_batch_size=batch_size, 121 | num_train_epochs=3, 122 | weight_decay=0.01, 123 | load_best_model_at_end=True, 124 | metric_for_best_model="accuracy", 125 | push_to_hub=False, 126 | ) 127 | 128 | metric = load("accuracy") 129 | 130 | def compute_metrics(eval_pred): 131 | predictions, labels = eval_pred 132 | predictions = np.argmax(predictions, axis=1) 133 | return metric.compute(predictions=predictions, references=labels) 134 | 135 | trainer = Trainer( 136 | model, 137 | args, 138 | train_dataset=train_dataset, 139 | eval_dataset=test_dataset, 140 | tokenizer=tokenizer, 141 | compute_metrics=compute_metrics, 142 | ) 143 | # run trainer, this will return eval loass andd accuracy every few steps 144 | # and save this to the disk in the esm2* folder 145 | trainer.train() 146 | else: 147 | # here we'll add a zero-shot eval script like this: 148 | # https://github.com/facebookresearch/esm/blob/main/examples/variant-prediction/predict.py 149 | pass -------------------------------------------------------------------------------- /protein_lm/evaluation/scripts/contact_prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from protein_lm.modeling.getters.model import get_model 4 | from protein_lm.modeling.getters.tokenizer import get_tokenizer 5 | from protein_lm.tokenizer.tokenizer import EsmTokenizer 6 | from protein_lm.evaluation.scripts.utils import * 7 | import yaml 8 | 9 | import sys 10 | import argparse 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import matplotlib as mpl 16 | import pandas as pd 17 | torch.set_grad_enabled(False) 18 | 19 | import esm 20 | 21 | def predict_contacts_jacobian(modelname,model,x,ln,device): 22 | with torch.no_grad(): 23 | #model.logits returns batch_size x seq_len x vocab_size tensor 24 | if modelname=="APT": 
25 | f = lambda x: model(x.to(device)).logits[...,1:(ln+1),3:23].cpu().numpy() 26 | elif modelname=="ESM": 27 | f = lambda x: model(x)["logits"][...,1:(ln+1),4:24].cpu().numpy() 28 | fx = f(x.to(device))[0] 29 | x = torch.tile(x,[20,1]).to(device) 30 | fx_h = np.zeros((ln,20,ln,20)) 31 | for n in range(ln): # for each position 32 | x_h = torch.clone(x) 33 | if modelname=="APT": 34 | x_h[:,n] = torch.arange(3,23) # mutate to all 20 aa 35 | elif modelname=="ESM": 36 | x_h[:,n+1] = torch.arange(4,24) 37 | fx_h[n] = f(x_h) 38 | jac=fx-fx_h 39 | # center & symmetrize 40 | for i in range(4): jac -= jac.mean(i,keepdims=True) 41 | jac = (jac + jac.transpose(2,3,0,1))/2 42 | return get_contacts(jac) 43 | 44 | def predict_contacts_regression(model,inputs,tokenizer,device): 45 | with torch.no_grad(): 46 | token_ids = tokenizer.encode(inputs[1],add_special_tokens=True) 47 | token_ids = torch.tensor(token_ids, dtype=torch.long) 48 | token_ids = token_ids.to(device) # Move token_ids to the same device as the model 49 | token_ids=token_ids.unsqueeze(0) 50 | return model.predict_contacts(token_ids)[0].cpu() 51 | 52 | def output_results(predictions,results,PDB_IDS): 53 | if not os.path.exists(args.output+args.method): 54 | os.makedirs(args.output+args.method) 55 | 56 | results = pd.DataFrame(results) 57 | results.to_csv(args.output+args.method+"/contact_prediction_results.csv",sep=",",index=False) 58 | for name in PDB_IDS: 59 | prediction = predictions[name] 60 | target = contacts[name] 61 | plot_contacts_and_predictions( 62 | prediction, target, title = lambda prec: f"{name}: Long Range P@L: {100 * prec:0.1f}" 63 | ) 64 | plt.savefig(args.output+args.method+"/"+name+".png") 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser(description="Contact Prediction Script") 69 | parser.add_argument("--input", type=str,help="dir containing .a3m files for contact prediction") 70 | parser.add_argument("--configfile",default="protein_lm/configs/train/toy_localcsv.yaml",type=str, help="path to config file") 71 | parser.add_argument("--model", help="APT or ESM") 72 | parser.add_argument("--tokenizer", type=str,help="AptTokenizer or EsmTokenizer") 73 | parser.add_argument("--method",type=str,help="contact prediction method either jacobian or regression") 74 | parser.add_argument("--output",type=str,help="output dir for contact maps") 75 | args = parser.parse_args() 76 | 77 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 78 | 79 | with open(args.configfile, "r") as cf: 80 | config_dict = yaml.safe_load(cf) 81 | print(config_dict) 82 | 83 | if args.model=="APT": 84 | model = get_model( 85 | config_dict=config_dict["model"], 86 | ) 87 | elif args.model=="ESM": 88 | model, _ = esm.pretrained.esm2_t33_650M_UR50D() 89 | 90 | if args.tokenizer=="AptTokenizer": 91 | tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"]) 92 | elif args.tokenizer=="EsmTokenizer": 93 | tokenizer = EsmTokenizer() 94 | 95 | model.to(device) 96 | PDB_IDS = [f.split("_")[0] for f in os.listdir(args.input) if f.endswith(".a3m")] 97 | 98 | structures = { 99 | name.lower(): get_structure(PDBxFile.read(rcsb.fetch(name, "cif")))[0] 100 | for name in PDB_IDS 101 | } 102 | 103 | contacts = { 104 | name: contacts_from_pdb(structure, chain="A") 105 | for name, structure in structures.items() 106 | } 107 | 108 | msas = { 109 | name: read_msa(args.input+f"{name.lower()}_1_A.a3m") 110 | for name in PDB_IDS 111 | } 112 | 113 | sequences = { 114 | name: msa[0] for name, msa in msas.items() 115 | } 116 | 
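# For each query sequence (the first entry of its MSA), run the selected
# contact-prediction method, compare against the PDB-derived contact map,
# and collect per-protein precision metrics.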
117 | predictions = {} 118 | results = [] 119 | 120 | if args.method=="jacobian": 121 | for name, inputs in sequences.items(): 122 | x,ln = tokenizer.batch_encode([inputs[1]],add_special_tokens=True),len(inputs[1]) 123 | x=torch.tensor(x) 124 | predictions[name]=predict_contacts_jacobian(args.model,model,x,ln,device) 125 | metrics = {"id": name, "model": args.model+"(Unsupervised)"} 126 | metrics.update(evaluate_prediction(predictions[name], contacts[name])) 127 | results.append(metrics) 128 | output_results(predictions,results,PDB_IDS) 129 | 130 | 131 | elif args.method=="regression": 132 | 133 | for name, inputs in sequences.items(): 134 | predictions[name]=predict_contacts_regression(model,inputs,tokenizer,device) 135 | metrics = {"id": name, "model": args.model+"(Unsupervised)"} 136 | metrics.update(evaluate_prediction(predictions[name], contacts[name])) 137 | results.append(metrics) 138 | 139 | output_results(predictions,results,PDB_IDS) -------------------------------------------------------------------------------- /protein_lm/evaluation/scripts/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import string 4 | from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable 5 | import biotite.structure as bs 6 | from biotite.structure.io.pdbx import PDBxFile, get_structure 7 | from biotite.database import rcsb 8 | import matplotlib.pyplot as plt 9 | import matplotlib as mpl 10 | from scipy.spatial.distance import squareform, pdist, cdist 11 | 12 | from Bio import SeqIO 13 | 14 | def parse_fasta(filename, a3m=True): 15 | '''function to parse fasta file''' 16 | 17 | if a3m: 18 | # for a3m files the lowercase letters are removed 19 | # as these do not align to the query sequence 20 | rm_lc = str.maketrans(dict.fromkeys(string.ascii_lowercase)) 21 | 22 | header, sequence = [],[] 23 | lines = open(filename, "r") 24 | for line in lines: 25 | line = line.rstrip() 26 | if line[0] == ">": 27 | header.append(line[1:]) 28 | sequence.append([]) 29 | else: 30 | if a3m: line = line.translate(rm_lc) 31 | else: line = line.upper() 32 | sequence[-1].append(line) 33 | lines.close() 34 | sequence = [''.join(seq) for seq in sequence] 35 | 36 | return header, sequence 37 | 38 | def do_apc(x, rm=1): 39 | '''given matrix do apc correction''' 40 | # trying to remove different number of components 41 | # rm=0 remove none 42 | # rm=1 apc 43 | x = np.copy(x) 44 | if rm == 0: 45 | return x 46 | elif rm == 1: 47 | a1 = x.sum(0,keepdims=True) 48 | a2 = x.sum(1,keepdims=True) 49 | y = x - (a1*a2)/x.sum() 50 | else: 51 | # decompose matrix, rm largest(s) eigenvectors 52 | u,s,v = np.linalg.svd(x) 53 | y = s[rm:] * u[:,rm:] @ v[rm:,:] 54 | np.fill_diagonal(y,0) 55 | return y 56 | 57 | def get_contacts(x, symm=True, center=True, rm=1): 58 | # convert jacobian (L,A,L,A) to contact map (L,L) 59 | j = x.copy() 60 | if center: 61 | for i in range(4): j -= j.mean(i,keepdims=True) 62 | j_fn = np.sqrt(np.square(j).sum((1,3))) 63 | np.fill_diagonal(j_fn,0) 64 | j_fn_corrected = do_apc(j_fn, rm=rm) 65 | if symm: 66 | j_fn_corrected = (j_fn_corrected + j_fn_corrected.T)/2 67 | return torch.tensor(j_fn_corrected) 68 | 69 | 70 | 71 | 72 | deletekeys = dict.fromkeys(string.ascii_lowercase) 73 | deletekeys["."] = None 74 | deletekeys["*"] = None 75 | translation = str.maketrans(deletekeys) 76 | 77 | 78 | def read_sequence(filename: str) -> Tuple[str, str]: 79 | """ Reads the first (reference) sequences from a fasta or MSA file.""" 80 
| record = next(SeqIO.parse(filename, "fasta")) 81 | return record.description, str(record.seq) 82 | 83 | def remove_insertions(sequence: str) -> str: 84 | """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """ 85 | return sequence.translate(translation) 86 | 87 | def read_msa(filename: str) -> List[Tuple[str, str]]: 88 | """ Reads the sequences from an MSA file, automatically removes insertions.""" 89 | return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")] 90 | def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]: 91 | assert mode in ("max", "min") 92 | if len(msa) <= num_seqs: 93 | return msa 94 | 95 | array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8) 96 | 97 | optfunc = np.argmax if mode == "max" else np.argmin 98 | all_indices = np.arange(len(msa)) 99 | indices = [0] 100 | pairwise_distances = np.zeros((0, len(msa))) 101 | for _ in range(num_seqs - 1): 102 | dist = cdist(array[indices[-1:]], array, "hamming") 103 | pairwise_distances = np.concatenate([pairwise_distances, dist]) 104 | shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) 105 | shifted_index = optfunc(shifted_distance) 106 | index = np.delete(all_indices, indices)[shifted_index] 107 | indices.append(index) 108 | indices = sorted(indices) 109 | return [msa[idx] for idx in indices] 110 | 111 | 112 | 113 | def extend(a, b, c, L, A, D): 114 | """ 115 | input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral 116 | output: 4th coord 117 | """ 118 | 119 | def normalize(x): 120 | return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True) 121 | 122 | bc = normalize(b - c) 123 | n = normalize(np.cross(b - a, bc)) 124 | m = [bc, np.cross(n, bc), n] 125 | d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)] 126 | return c + sum([m * d for m, d in zip(m, d)]) 127 | 128 | 129 | def contacts_from_pdb( 130 | structure: bs.AtomArray, 131 | distance_threshold: float = 8.0, 132 | chain: Optional[str] = None, 133 | ) -> np.ndarray: 134 | mask = ~structure.hetero 135 | if chain is not None: 136 | mask &= structure.chain_id == chain 137 | 138 | N = structure.coord[mask & (structure.atom_name == "N")] 139 | CA = structure.coord[mask & (structure.atom_name == "CA")] 140 | C = structure.coord[mask & (structure.atom_name == "C")] 141 | 142 | Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143) 143 | dist = squareform(pdist(Cbeta)) 144 | 145 | contacts = dist < distance_threshold 146 | contacts = contacts.astype(np.int64) 147 | contacts[np.isnan(dist)] = -1 148 | return contacts 149 | 150 | 151 | # Select sequences from the MSA to maximize the hamming distance 152 | # Alternatively, can use hhfilter 153 | def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]: 154 | assert mode in ("max", "min") 155 | if len(msa) <= num_seqs: 156 | return msa 157 | 158 | array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8) 159 | 160 | optfunc = np.argmax if mode == "max" else np.argmin 161 | all_indices = np.arange(len(msa)) 162 | indices = [0] 163 | pairwise_distances = np.zeros((0, len(msa))) 164 | for _ in range(num_seqs - 1): 165 | dist = cdist(array[indices[-1:]], array, "hamming") 166 | pairwise_distances = np.concatenate([pairwise_distances, dist]) 167 | shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) 168 | shifted_index = 
optfunc(shifted_distance) 169 | index = np.delete(all_indices, indices)[shifted_index] 170 | indices.append(index) 171 | indices = sorted(indices) 172 | return [msa[idx] for idx in indices] 173 | 174 | 175 | 176 | 177 | def compute_precisions( 178 | predictions: torch.Tensor, 179 | targets: torch.Tensor, 180 | src_lengths: Optional[torch.Tensor] = None, 181 | minsep: int = 6, 182 | maxsep: Optional[int] = None, 183 | override_length: Optional[int] = None, # for casp 184 | ): 185 | if isinstance(predictions, np.ndarray): 186 | predictions = torch.from_numpy(predictions) 187 | if isinstance(targets, np.ndarray): 188 | targets = torch.from_numpy(targets) 189 | if predictions.dim() == 2: 190 | predictions = predictions.unsqueeze(0) 191 | if targets.dim() == 2: 192 | targets = targets.unsqueeze(0) 193 | override_length = (targets[0, 0] >= 0).sum() 194 | 195 | # Check sizes 196 | if predictions.size() != targets.size(): 197 | raise ValueError( 198 | f"Size mismatch. Received predictions of size {predictions.size()}, " 199 | f"targets of size {targets.size()}" 200 | ) 201 | device = predictions.device 202 | 203 | batch_size, seqlen, _ = predictions.size() 204 | seqlen_range = torch.arange(seqlen, device=device) 205 | 206 | sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1) 207 | sep = sep.unsqueeze(0) 208 | valid_mask = sep >= minsep 209 | valid_mask = valid_mask & (targets >= 0) # negative targets are invalid 210 | 211 | if maxsep is not None: 212 | valid_mask &= sep < maxsep 213 | 214 | if src_lengths is not None: 215 | valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1) 216 | valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2) 217 | else: 218 | src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long) 219 | 220 | predictions = predictions.masked_fill(~valid_mask, float("-inf")) 221 | 222 | x_ind, y_ind = np.triu_indices(seqlen, minsep) 223 | predictions_upper = predictions[:, x_ind, y_ind] 224 | targets_upper = targets[:, x_ind, y_ind] 225 | 226 | topk = seqlen if override_length is None else max(seqlen, override_length) 227 | indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk] 228 | topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices] 229 | if topk_targets.size(1) < topk: 230 | topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)]) 231 | 232 | cumulative_dist = topk_targets.type_as(predictions).cumsum(-1) 233 | 234 | gather_lengths = src_lengths.unsqueeze(1) 235 | if override_length is not None: 236 | gather_lengths = override_length * torch.ones_like( 237 | gather_lengths, device=device 238 | ) 239 | 240 | gather_indices = ( 241 | torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths 242 | ).type(torch.long) - 1 243 | 244 | binned_cumulative_dist = cumulative_dist.gather(1, gather_indices) 245 | binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as( 246 | binned_cumulative_dist 247 | ) 248 | 249 | pl5 = binned_precisions[:, 1] 250 | pl2 = binned_precisions[:, 4] 251 | pl = binned_precisions[:, 9] 252 | auc = binned_precisions.mean(-1) 253 | 254 | return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5} 255 | 256 | 257 | def evaluate_prediction( 258 | predictions: torch.Tensor, 259 | targets: torch.Tensor, 260 | ) -> Dict[str, float]: 261 | if isinstance(targets, np.ndarray): 262 | targets = torch.from_numpy(targets) 263 | contact_ranges = [ 264 | ("local", 3, 6), 265 | ("short", 6, 12), 266 | ("medium", 12, 24), 267 | ("long", 24, None), 268 | ] 269 | 
metrics = {} 270 | targets = targets.to(predictions.device) 271 | for name, minsep, maxsep in contact_ranges: 272 | rangemetrics = compute_precisions( 273 | predictions, 274 | targets, 275 | minsep=minsep, 276 | maxsep=maxsep, 277 | ) 278 | for key, val in rangemetrics.items(): 279 | metrics[f"{name}_{key}"] = val.item() 280 | return metrics 281 | 282 | 283 | 284 | """Adapted from: https://github.com/rmrao/evo/blob/main/evo/visualize.py""" 285 | def plot_contacts_and_predictions( 286 | predictions: Union[torch.Tensor, np.ndarray], 287 | contacts: Union[torch.Tensor, np.ndarray], 288 | ax: Optional[mpl.axes.Axes] = None, 289 | # artists: Optional[ContactAndPredictionArtists] = None, 290 | cmap: str = "Blues", 291 | ms: float = 1, 292 | title: Union[bool, str, Callable[[float], str]] = True, 293 | animated: bool = False, 294 | ) -> None: 295 | 296 | if isinstance(predictions, torch.Tensor): 297 | predictions = predictions.detach().cpu().numpy() 298 | if isinstance(contacts, torch.Tensor): 299 | contacts = contacts.detach().cpu().numpy() 300 | if ax is None: 301 | ax = plt.gca() 302 | 303 | seqlen = contacts.shape[0] 304 | #print(seqlen) 305 | relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen)) 306 | #print(relative_distance) 307 | bottom_mask = relative_distance < 0 308 | 309 | predictions=np.squeeze(predictions) 310 | masked_image = np.ma.masked_where(bottom_mask,predictions) 311 | 312 | invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6 313 | predictions = predictions.copy() 314 | predictions[invalid_mask] = float("-inf") 315 | 316 | topl_val = np.sort(predictions.reshape(-1))[-seqlen] 317 | 318 | pred_contacts = predictions >= topl_val 319 | #print(pred_contacts) 320 | true_positives = contacts & pred_contacts & ~bottom_mask 321 | false_positives = ~contacts & pred_contacts & ~bottom_mask 322 | other_contacts = contacts & ~pred_contacts & ~bottom_mask 323 | 324 | if isinstance(title, str): 325 | title_text: Optional[str] = title 326 | elif title: 327 | long_range_pl = compute_precisions(predictions, contacts, minsep=24)[ 328 | "P@L" 329 | ].item() 330 | if callable(title): 331 | title_text = title(long_range_pl) 332 | else: 333 | title_text = f"Long Range P@L: {100 * long_range_pl:0.1f}" 334 | else: 335 | title_text = None 336 | # Adjust the printing options to display the complete array 337 | #print("s") 338 | 339 | np.set_printoptions(threshold=np.inf) 340 | 341 | # Print the array 342 | 343 | img = ax.imshow(masked_image, cmap=cmap, animated=animated) 344 | oc = ax.plot(*np.where(other_contacts), "o", c="grey", ms=ms)[0] 345 | fn = ax.plot(*np.where(false_positives), "o", c="r", ms=ms)[0] 346 | tp = ax.plot(*np.where(true_positives), "o", c="b", ms=ms)[0] 347 | ti = ax.set_title(title_text) if title_text is not None else None 348 | # artists = ContactAndPredictionArtists(img, oc, fn, tp, ti) 349 | 350 | ax.axis("square") 351 | ax.set_xlim([0, seqlen]) 352 | ax.set_ylim([0, seqlen]) 353 | -------------------------------------------------------------------------------- /protein_lm/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/modeling/__init__.py -------------------------------------------------------------------------------- /protein_lm/modeling/getters/__init__.py: -------------------------------------------------------------------------------- 
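The contact-prediction utilities in `protein_lm/evaluation/scripts/utils.py` above fit together roughly as follows. This is a minimal sketch rather than repository code: the PDB id, chain, MSA path and the random "predictions" are illustrative placeholders, and fetching the structure assumes network access.

```
import torch
from biotite.database import rcsb
from biotite.structure.io.pdbx import PDBxFile, get_structure
from protein_lm.evaluation.scripts.utils import (
    read_msa, greedy_select, contacts_from_pdb, evaluate_prediction,
)

# Load an MSA (insertions removed) and keep a diverse subset of sequences.
msa = read_msa("protein_lm/tests/test_data/1a3a_1_A.a3m")
msa = greedy_select(msa, num_seqs=64)

# Ground-truth contacts: C-beta pairs closer than 8 Angstrom in the structure.
structure = get_structure(PDBxFile.read(rcsb.fetch("1a3a", "cif")), model=1)
contacts = contacts_from_pdb(structure, chain="A")

# Score a predicted (L, L) contact map; a random map stands in for a real model here.
L = contacts.shape[0]
predictions = torch.rand(L, L)
metrics = evaluate_prediction(predictions, contacts)
print(metrics)  # keys like "long_P@L", "medium_AUC", ...
```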
https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/modeling/getters/__init__.py -------------------------------------------------------------------------------- /protein_lm/modeling/getters/data_collator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Literal 2 | 3 | from pydantic import BaseModel 4 | from transformers import default_data_collator 5 | 6 | 7 | class DataCollatorConfig(BaseModel): 8 | data_collator_type: Literal["default"] 9 | 10 | 11 | def get_data_collator(config_dict: Dict): 12 | config = DataCollatorConfig(**config_dict) 13 | if config.data_collator_type == "default": 14 | return default_data_collator 15 | else: 16 | raise ValueError(f"Invalid data_collator_type {config.data_collator_type}") 17 | -------------------------------------------------------------------------------- /protein_lm/modeling/getters/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Literal, Optional 2 | 3 | from datasets import Dataset, load_dataset 4 | from datasets.dataset_dict import DatasetDict 5 | from pydantic import BaseModel 6 | 7 | 8 | class DatasetConfig(BaseModel): 9 | dataset_type: Literal["csv", "huggingface"] 10 | 11 | # The path if local or the huggingface dataset name if huggingface 12 | dataset_loc: str 13 | 14 | # sample size to limit to, if any, usually for debugging 15 | subsample_size: Optional[int] = None 16 | 17 | """ 18 | Args for splitting into train, val, test 19 | to be updated once we have more options 20 | """ 21 | # split seed 22 | split_seed: Optional[int] = None 23 | # size of validation dataset 24 | val_size: int 25 | # size of test dataset 26 | test_size: int 27 | 28 | # name of the column that contains the sequence 29 | sequence_column_name: str 30 | 31 | max_sequence_length: int 32 | do_curriculum_learning: bool 33 | curriculum_learning_strategy: Optional[str] = None 34 | curriculum_learning_column_name: Optional[str] = None 35 | 36 | 37 | def set_input_ids( 38 | result=None, 39 | tokenizer=None, 40 | sequence_column_name="sequence", 41 | max_sequence_length=1024, 42 | ): 43 | result["input_ids"] = tokenizer( 44 | result[sequence_column_name], 45 | max_sequence_length=max_sequence_length, 46 | add_special_tokens=True, 47 | return_tensors=True, 48 | ) 49 | return result 50 | 51 | def batch_set_curriculum_learning_column( 52 | result=None, 53 | input_column_name='sequence', 54 | curriculum_learning_column_name='sequence_length', 55 | strategy='sequence_length' 56 | ): 57 | supported_strategies = ['sequence_length', 'ppl', 'plddt'] 58 | 59 | if strategy not in supported_strategies: 60 | raise Exception(f'Invalid {strategy} provided. 
Supported strategy values include {", ".join(supported_strategies)}') 61 | 62 | if strategy == 'sequence_length': 63 | # LengthGroupedSampler sorts in descending so we make it ascending by multiplying with -1 64 | result[curriculum_learning_column_name] = [-len(x) for x in result[input_column_name]] 65 | elif strategy in ['ppl', 'plddt']: 66 | result[curriculum_learning_column_name] = [-x for x in result[strategy]] 67 | 68 | return result 69 | 70 | def set_labels(result): 71 | result["labels"] = result["input_ids"].copy() 72 | return result 73 | 74 | 75 | def train_val_test_split( 76 | dataset_dict: DatasetDict, 77 | config: DatasetConfig, 78 | ) -> DatasetDict: 79 | """ 80 | Given a dictionary of datasets that only contains the split "train", 81 | optionally subsamples it, and then splits it 82 | so that it has potentially 3 splits: "train", "val", "test", where 83 | "val" and "test" splits do not exist if the specified sizes are 0 84 | """ 85 | assert set(dataset_dict.keys()) == { 86 | "train" 87 | }, f"{train_val_test_split.__name__} expects its input to have the keys \ 88 | ['train'] but the input has keys {list(dataset_dict.keys())}" 89 | 90 | dataset = dataset_dict["train"] 91 | 92 | val_size = config.val_size 93 | test_size = config.test_size 94 | 95 | assert isinstance( 96 | dataset, Dataset 97 | ), f"Invalid dataset type {type(dataset)}, only datasets.Dataset allowed" 98 | 99 | dataset = dataset.shuffle(seed=config.split_seed) 100 | 101 | if config.subsample_size is not None: 102 | dataset = dataset.select(range(config.subsample_size)) 103 | 104 | valtest_size = val_size + test_size 105 | 106 | if valtest_size > 0: 107 | train_valtest = dataset.train_test_split( 108 | test_size=val_size + test_size, 109 | shuffle=False, 110 | ) 111 | split_dict = { 112 | "train": train_valtest["train"], 113 | } 114 | if test_size > 0 and val_size > 0: 115 | test_val = train_valtest["test"].train_test_split( 116 | test_size=test_size, 117 | shuffle=False, 118 | ) 119 | split_dict["val"] = test_val["train"] 120 | split_dict["test"] = test_val["test"] 121 | elif val_size > 0: 122 | split_dict["val"] = train_valtest["test"] 123 | else: 124 | split_dict["test"] = train_valtest["test"] 125 | else: 126 | split_dict = { 127 | "train": dataset, 128 | } 129 | 130 | split_dataset_dict = DatasetDict(split_dict) 131 | return split_dataset_dict 132 | 133 | 134 | def get_csv_dataset(config: DatasetConfig) -> Dataset: 135 | # note that a csv is read as having just one split "train" 136 | dataset_dict = load_dataset("csv", data_files=config.dataset_loc) 137 | return train_val_test_split(dataset_dict, config) 138 | 139 | 140 | def get_huggingface_dataset(config: DatasetConfig) -> Dataset: 141 | dataset_dict = load_dataset(config.dataset_loc) 142 | if set(dataset_dict.keys()) == {"train", "val", "test"}: 143 | return dataset_dict 144 | 145 | assert set(dataset_dict.keys()) == { 146 | "train" 147 | }, f"Huggingface DatasetDicts should have the keys {{'train'}} or \ 148 | {{'train', 'val', 'split'}} but this DatasetDict has keys \ 149 | {set(dataset_dict.keys())}" 150 | return train_val_test_split(dataset_dict, config) 151 | 152 | 153 | def get_dataset(config_dict: Dict, tokenizer) -> Dataset: 154 | config = DatasetConfig(**config_dict) 155 | 156 | if config.dataset_type == "csv": 157 | train_ds = get_csv_dataset(config) 158 | elif config.dataset_type == "huggingface": 159 | train_ds = get_huggingface_dataset(config) 160 | else: 161 | raise ValueError(f"Invalid dataset_type {config.dataset_type}!") 162 | 163 | 
train_ds = train_ds.map( 164 | lambda e: set_input_ids( 165 | result=e, 166 | tokenizer=tokenizer, 167 | sequence_column_name=config.sequence_column_name, 168 | max_sequence_length=config.max_sequence_length, 169 | ), 170 | batched=True, 171 | ) 172 | train_ds = train_ds.map(set_labels, batched=True) 173 | if config.do_curriculum_learning: 174 | train_ds = train_ds.map(lambda e: batch_set_curriculum_learning_column( 175 | result = e, 176 | input_column_name = config.sequence_column_name, 177 | curriculum_learning_column_name = config.curriculum_learning_column_name, 178 | strategy = config.curriculum_learning_strategy 179 | 180 | ),batched=True) 181 | 182 | return train_ds 183 | -------------------------------------------------------------------------------- /protein_lm/modeling/getters/model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Literal, Optional 2 | 3 | import torch 4 | from pydantic import BaseModel 5 | 6 | from protein_lm.modeling.models.apt.config import APTConfig 7 | from protein_lm.modeling.models.apt.model_pytorch import APTLMHeadModel 8 | 9 | 10 | class NNModelConfig(BaseModel): 11 | # If desired, this can be modified to support a variety of model types 12 | # Note: "nn_model_.." because anything with the "model_" prefix leads to 13 | # pydantic namespace warnings 14 | nn_model_type: Literal["APT"] 15 | nn_model_config_args: Dict 16 | pretrained_checkpoint: Optional[str] 17 | 18 | 19 | def get_model(config_dict: Dict): 20 | config = NNModelConfig(**config_dict) 21 | if config.nn_model_type == "APT": 22 | model_constructor = APTLMHeadModel 23 | model_config_constructor = APTConfig 24 | else: 25 | raise ValueError(f"Invalid NNModelConfig.nn_model_type {config.nn_model_type}") 26 | 27 | model_config = model_config_constructor(**config.nn_model_config_args) 28 | if config.pretrained_checkpoint is None: 29 | model = model_constructor(config=model_config) 30 | else: 31 | model = model_constructor.from_pretrained( 32 | pretrained_model_name_or_path=config.pretrained_checkpoint, 33 | config=model_config, 34 | ) 35 | 36 | return model 37 | -------------------------------------------------------------------------------- /protein_lm/modeling/getters/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Literal 2 | 3 | from pydantic import BaseModel 4 | 5 | from protein_lm.tokenizer.tokenizer import AptTokenizer 6 | 7 | 8 | class TokenizerConfig(BaseModel): 9 | tokenizer_type: Literal["APT"] 10 | 11 | 12 | def get_tokenizer(config_dict: Dict): 13 | config = TokenizerConfig(**config_dict) 14 | if config.tokenizer_type == "APT": 15 | tokenizer_constructor = AptTokenizer 16 | else: 17 | raise ValueError(f"Invalid tokenizer_type {config.tokenizer_type}") 18 | 19 | return tokenizer_constructor() 20 | -------------------------------------------------------------------------------- /protein_lm/modeling/getters/training_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | from transformers import TrainingArguments 5 | 6 | 7 | def get_training_args(config_dict: Dict) -> TrainingArguments: 8 | config = TrainingArguments(**config_dict) 9 | 10 | if not os.path.isdir(config.output_dir): 11 | print(f"creating checkpoint directory at {config.output_dir}") 12 | os.makedirs(config.output_dir) 13 | 14 | return config 15 | 
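The getters above all take plain dictionaries that are validated by their pydantic config classes, so they can also be wired up by hand. A minimal sketch, where the field names come from `DatasetConfig`, `NNModelConfig`, `TokenizerConfig` and `DataCollatorConfig` above, and the concrete values (csv path, column name, model sizes) are illustrative assumptions:

```
from protein_lm.modeling.getters.tokenizer import get_tokenizer
from protein_lm.modeling.getters.dataset import get_dataset
from protein_lm.modeling.getters.model import get_model
from protein_lm.modeling.getters.data_collator import get_data_collator

tokenizer = get_tokenizer({"tokenizer_type": "APT"})

dataset = get_dataset(
    config_dict={
        "dataset_type": "csv",
        "dataset_loc": "protein_lm/dataset/uniref/uniref50_trimmed.csv",  # example local csv
        "subsample_size": 100,
        "split_seed": 0,
        "val_size": 10,
        "test_size": 10,
        "sequence_column_name": "sequence",  # assumed column name
        "max_sequence_length": 1024,
        "do_curriculum_learning": False,
    },
    tokenizer=tokenizer,
)

model = get_model(
    config_dict={
        "nn_model_type": "APT",
        # Passed through to APTConfig / GPT2Config; sizes here are toy values, and
        # vocab_size etc. should match the tokenizer in a real config.
        "nn_model_config_args": {"n_layer": 4, "n_head": 4, "n_embd": 64},
        "pretrained_checkpoint": None,  # or a checkpoint path to load from
    }
)

data_collator = get_data_collator({"data_collator_type": "default"})
```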
-------------------------------------------------------------------------------- /protein_lm/modeling/getters/wandb_log.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from pydantic import BaseModel 3 | from typing import Dict, Optional 4 | import os 5 | 6 | 7 | class WandBConfig(BaseModel): 8 | project: str = "protein_lm_scaling" 9 | name: str 10 | # directory to save to 11 | dir: Optional[str] = None 12 | 13 | 14 | def setup_wandb(config_dict: Dict) -> None: 15 | config = WandBConfig(**config_dict) 16 | if config.dir is not None: 17 | if not os.path.isdir(config.dir): 18 | print(f"creating wandb directory at {config.dir}") 19 | os.makedirs(config.dir) 20 | 21 | os.environ["WANDB_PROJECT"] = config.project 22 | os.environ["WANDB_NAME"] = config.name 23 | os.environ["WANDB_DIR"] = config.dir 24 | -------------------------------------------------------------------------------- /protein_lm/modeling/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/modeling/models/__init__.py -------------------------------------------------------------------------------- /protein_lm/modeling/models/apt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/modeling/models/apt/__init__.py -------------------------------------------------------------------------------- /protein_lm/modeling/models/apt/config.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Config 2 | from typing import Literal 3 | 4 | class APTConfig(GPT2Config): 5 | """ 6 | Config subclass for Autoregressive Protein Transformer (APT) model architecture. 
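    Adds APT-specific options on top of GPT2Config: `position_embedding` (one of "alibi",
    "learned", "rope", "rerope", "linear_rope_scaling" or "dynamic_rope_scaling"),
    `max_sequence_length`, and `attn_type` (e.g. "standard", "gqa" or "reorder_and_upcast_attn").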
7 | """ 8 | 9 | def __init__( 10 | self, 11 | position_embedding: Literal["alibi", "learned", "rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling"]="learned", 12 | tokenizer=None, 13 | max_sequence_length = 1024, 14 | attn_type="standard", 15 | **kwargs 16 | ): 17 | super().__init__(**kwargs) 18 | self.nn_model_type = "APT" 19 | self.position_embedding = position_embedding 20 | self.tokenizer = tokenizer 21 | self.max_sequence_length = max_sequence_length 22 | self.attn_type = attn_type 23 | 24 | -------------------------------------------------------------------------------- /protein_lm/modeling/models/apt/model_pytorch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | import torch 3 | from torch import nn 4 | from torch.nn import CrossEntropyLoss 5 | from transformers.models.gpt2.modeling_gpt2 import GPT2Block,GPT2Attention 6 | from transformers import GPT2PreTrainedModel 7 | from transformers.modeling_outputs import (BaseModelOutputWithPastAndCrossAttentions,CausalLMOutputWithCrossAttentions) 8 | from transformers.pytorch_utils import Conv1D 9 | from transformers.activations import ACT2FN 10 | from transformers.utils import logging 11 | 12 | from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding 13 | from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding 14 | from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor 15 | from protein_lm.modeling.utils.scaled_rope_embedding import LlamaLinearScalingRotaryEmbedding,LlamaDynamicNTKScalingRotaryEmbedding 16 | from protein_lm.modeling.utils.modules import ContactPredictionHead 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | class APTAttention(GPT2Attention): 21 | def __init__(self, config, is_cross_attention=False, layer_idx=None): 22 | super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) 23 | self.max_positions = config.max_position_embeddings 24 | self.register_buffer( 25 | "bias", 26 | torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)).view( 27 | 1, 1, self.max_positions, self.max_positions 28 | ), 29 | persistent=False, 30 | ) 31 | self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) 32 | self.position_embedding = config.position_embedding 33 | 34 | self.max_sequence_length = config.max_sequence_length 35 | self.embed_dim = config.hidden_size 36 | self.num_heads = config.num_attention_heads 37 | self.attn_type = config.attn_type 38 | self.head_dim = self.embed_dim // self.num_heads 39 | self.split_size = self.embed_dim 40 | if self.head_dim * self.num_heads != self.embed_dim: 41 | raise ValueError( 42 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 43 | f" {self.num_heads})." 
44 | ) 45 | 46 | self.scale_attn_weights = config.scale_attn_weights 47 | self.is_cross_attention = is_cross_attention 48 | 49 | # Layer-wise attention scaling, reordering, and upcasting 50 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 51 | self.layer_idx = layer_idx 52 | 53 | if self.attn_type == "gqa": 54 | self.gqa_attn = True 55 | elif self.attn_type == "reorder_and_upcast_attn": 56 | self.reorder_and_upcast_attn = True 57 | elif self.attn_type == "standard": 58 | self.standard_attn = True 59 | 60 | #self.reorder_and_upcast_attn = config.reorder_and_upcast_attn #comment out because config now states attn type 61 | 62 | if self.is_cross_attention: 63 | self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) 64 | self.q_attn = Conv1D(self.embed_dim, self.embed_dim) 65 | else: 66 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) 67 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim) 68 | 69 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 70 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 71 | 72 | self.pruned_heads = set() 73 | 74 | self.rot_emb = None 75 | if self.position_embedding in ["rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling"]: 76 | self.rope_scaling_factor = config.rope_scaling_factor 77 | self.rope_theta = config.rope_theta 78 | if self.position_embedding == "rope": 79 | self.rot_emb=RotaryEmbedding(dim=self.head_dim) 80 | elif self.position_embedding == "rerope": 81 | self.rot_emb = RectifiedRotaryEmbedding(dim=self.head_dim,max_position_embeddings = self.max_positions) 82 | elif self.position_embedding=="linear_rope_scaling": 83 | self.rot_emb=LlamaLinearScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta) 84 | elif self.position_embedding=="dynamic_rope_scaling": 85 | self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta) 86 | 87 | 88 | 89 | def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None): 90 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 91 | 92 | if self.scale_attn_weights: 93 | attn_weights = attn_weights / torch.full( 94 | [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device 95 | ) 96 | 97 | # Layer-wise attention scaling 98 | if self.scale_attn_by_inverse_layer_idx: 99 | attn_weights = attn_weights / float(self.layer_idx + 1) 100 | 101 | if not self.is_cross_attention: 102 | # if only "normal" attention layer implements causal mask 103 | query_length, key_length = query.size(-2), key.size(-2) 104 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] 105 | mask_value = torch.finfo(attn_weights.dtype).min 106 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
107 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 108 | mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 109 | attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) 110 | if alibi_bias is not None: 111 | attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)] 112 | 113 | if attention_mask is not None: 114 | # Apply the attention mask 115 | attn_weights = attn_weights + attention_mask 116 | 117 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 118 | 119 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise 120 | attn_weights = attn_weights.type(value.dtype) 121 | attn_weights = self.attn_dropout(attn_weights) 122 | 123 | # Mask heads if we want to 124 | if head_mask is not None: 125 | attn_weights = attn_weights * head_mask 126 | 127 | attn_output = torch.matmul(attn_weights, value) 128 | 129 | return attn_output, attn_weights 130 | 131 | def _gqa_attn(self, query, key, value, attention_mask=None, 132 | alibi_bias =None, dropout=0.0): 133 | """Group Query Attention implementation.""" 134 | 135 | # Check for potential issues before moving on 136 | if not query.ndim == key.ndim == value.ndim == 4: 137 | raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes " 138 | f"{query.shape}, {key.shape}, and {value.shape}.") 139 | 140 | """ 141 | Expected shapes: (batch_size, num_heads, query_len, query_dim) similar to _upcast_and_reordered_attn 142 | """ 143 | batch_size, num_heads, query_len, query_dim = query.shape 144 | 145 | 146 | scale_factor = 1.0 147 | if self.scale_attn_weights: 148 | scale_factor /= float(value.size(-1)) ** 0.5 149 | query = query / scale_factor 150 | 151 | ''' 152 | Determine the number of groups 153 | For example lets say we have 4 queries heads and 2 keys heads, then we have 2 groups 154 | Lets say the number of group are 2 and head are 2, 155 | then reshape the query tensor to (batch_size, (2, 2), query_len, query_dim) 156 | query shape (batch_size, num_groups, num_heads, query_len, query_dim) 157 | attention_weights_grouped shape (batch_size, num_groups, num_heads, query_len, key_len). 158 | attention weights shape: (batch_size, num_heads, query_len, key_len) 159 | ''' 160 | 161 | n_groups = query.size(1) // key.size(1) 162 | 163 | if n_groups > 1: 164 | query_shape = query.shape 165 | grouped_shape = (query_shape[0], n_groups, query_shape[1]//n_groups, query_shape[2], query_shape[3]) 166 | query_grouped = query.reshape(grouped_shape) 167 | attn_weights_grouped = torch.matmul(query_grouped, key.transpose(-2, -1)) 168 | attn_weights = attn_weights_grouped.sum(dim=1) 169 | #print("attn_weights:", attn_weights.shape) 170 | 171 | else: 172 | ''' 173 | If the number of groups is 1, then we can use the normal attention function 174 | ''' 175 | attn_weights = torch.matmul(query, key.transpose(-2, -1)) 176 | 177 | if self.scale_attn_by_inverse_layer_idx: 178 | attn_weights = attn_weights / float(self.layer_idx + 1) 179 | 180 | if attention_mask is not None: 181 | # Apply the attention mask 182 | ''' 183 | Input attention_mask shape: (batch_size, query_len, key_len) 184 | ''' 185 | attn_weights += attention_mask.unsqueeze(1) # Unsqueeze to Add head dimension 186 | 187 | # Causal masking ensures that the attention mechanism doesn't attend to "future" tokens in sequences. 
188 | ## Adapted to work with groups and ensure similarity with vanilla attention 189 | if not self.is_cross_attention: 190 | query_length, key_length = query.size(-2), key.size(-2) 191 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] 192 | mask_value = torch.finfo(attn_weights.dtype).min 193 | mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 194 | attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) 195 | 196 | # print("attn_weights:", attn_weights) 197 | # Softmax normalization to get the attention scores 198 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 199 | 200 | if alibi_bias is not None: 201 | attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)] 202 | 203 | # Apply dropout if specified 204 | attn_weights = attn_weights.type(value.dtype) 205 | attn_weights = self.attn_dropout(attn_weights) 206 | 207 | # Compute the output by multiplying the attention scores with the value tensor. 208 | attn_output = torch.matmul(attn_weights, value) 209 | 210 | return attn_output, attn_weights 211 | 212 | def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None): 213 | # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) 214 | bsz, num_heads, q_seq_len, dk = query.size() 215 | _, _, k_seq_len, _ = key.size() 216 | 217 | # Preallocate attn_weights for `baddbmm` 218 | attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) 219 | 220 | # Compute Scale Factor 221 | scale_factor = 1.0 222 | if self.scale_attn_weights: 223 | scale_factor /= float(value.size(-1)) ** 0.5 224 | 225 | if self.scale_attn_by_inverse_layer_idx: 226 | scale_factor /= float(self.layer_idx + 1) 227 | 228 | # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) 229 | with autocast(enabled=False): 230 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) 231 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) 232 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 233 | 234 | if not self.is_cross_attention: 235 | # if only "normal" attention layer implements causal mask 236 | query_length, key_length = query.size(-2), key.size(-2) 237 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] 238 | mask_value = torch.finfo(attn_weights.dtype).min 239 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
240 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 241 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 242 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 243 | 244 | if alibi_bias is not None: 245 | attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)] 246 | 247 | if attention_mask is not None: 248 | # Apply the attention mask 249 | attn_weights = attn_weights + attention_mask 250 | 251 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 252 | 253 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise 254 | if attn_weights.dtype != torch.float32: 255 | raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") 256 | attn_weights = attn_weights.type(value.dtype) 257 | attn_weights = self.attn_dropout(attn_weights) 258 | 259 | # Mask heads if we want to 260 | if head_mask is not None: 261 | attn_weights = attn_weights * head_mask 262 | 263 | attn_output = torch.matmul(attn_weights, value) 264 | 265 | return attn_output, attn_weights 266 | 267 | 268 | def forward( 269 | self, 270 | hidden_states: Optional[Tuple[torch.FloatTensor]], 271 | layer_past: Optional[Tuple[torch.Tensor]] = None, 272 | attention_mask: Optional[torch.FloatTensor] = None, 273 | head_mask: Optional[torch.FloatTensor] = None, 274 | encoder_hidden_states: Optional[torch.Tensor] = None, 275 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 276 | use_cache: Optional[bool] = False, 277 | output_attentions: Optional[bool] = False, 278 | alibi_bias: Optional[Tuple[torch.Tensor]] = None, 279 | position_ids: Optional[torch.LongTensor] = None, 280 | 281 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: 282 | if encoder_hidden_states is not None: 283 | if not hasattr(self, "q_attn"): 284 | raise ValueError( 285 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 286 | "Please make sure to instantiate class with `APTAttention(..., is_cross_attention=True)`." 
287 | ) 288 | 289 | query = self.q_attn(hidden_states) 290 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 291 | attention_mask = encoder_attention_mask 292 | else: 293 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 294 | 295 | query = self._split_heads(query, self.num_heads, self.head_dim) 296 | key = self._split_heads(key, self.num_heads, self.head_dim) 297 | value = self._split_heads(value, self.num_heads, self.head_dim) 298 | 299 | kv_seq_len=key.shape[-2] 300 | if layer_past is not None: 301 | kv_seq_len+=layer_past[0].shape[-2] 302 | 303 | # Apply rope embedding to query and key 304 | if self.rot_emb: 305 | bsz, q_len, _ = hidden_states.size() 306 | if self.position_embedding == 'rope': 307 | query, key = self.rot_emb(query,key) 308 | elif self.position_embedding == 'rerope': 309 | query = query.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 310 | query *= ((position_ids + 1)[:, None, :, None].log() / torch.log(torch.tensor(self.max_sequence_length)).item()).clip(1).to(query.dtype) 311 | query, key = self.rot_emb(query,key,seq_len = self.max_sequence_length,position_ids=position_ids) 312 | elif self.position_embedding=="linear_rope_scaling" or self.position_embedding=="dynamic_rope_scaling": 313 | query = query.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 314 | query, key = self.rot_emb(query, key, seq_len=kv_seq_len,position_ids=position_ids) 315 | 316 | if layer_past is not None: 317 | past_key, past_value = layer_past 318 | key = torch.cat((past_key, key), dim=-2) 319 | value = torch.cat((past_value, value), dim=-2) 320 | 321 | 322 | if use_cache is True: 323 | present = (key, value) 324 | else: 325 | present = None 326 | 327 | if self.reorder_and_upcast_attn: 328 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias) 329 | elif self.standard_attn: 330 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias) 331 | elif self.gqa_attn: 332 | attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask,alibi_bias=alibi_bias) 333 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 334 | attn_output = self.c_proj(attn_output) 335 | attn_output = self.resid_dropout(attn_output) 336 | 337 | outputs = (attn_output, present) 338 | if output_attentions: 339 | outputs += (attn_weights,) 340 | 341 | return outputs # a, present, (attentions) 342 | 343 | 344 | class APTMLP(nn.Module): 345 | def __init__(self, intermediate_size, config): 346 | super().__init__() 347 | embed_dim = config.hidden_size 348 | self.c_fc = Conv1D(intermediate_size, embed_dim) 349 | self.c_proj = Conv1D(embed_dim, intermediate_size) 350 | self.act = ACT2FN[config.activation_function] 351 | self.dropout = nn.Dropout(config.resid_pdrop) 352 | 353 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 354 | hidden_states = self.c_fc(hidden_states) 355 | hidden_states = self.act(hidden_states) 356 | hidden_states = self.c_proj(hidden_states) 357 | hidden_states = self.dropout(hidden_states) 358 | return hidden_states 359 | 360 | 361 | class APTBlock(nn.Module): 362 | def __init__(self, config, layer_idx=None): 363 | super().__init__() 364 | hidden_size = config.hidden_size 365 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 366 | 367 | self.ln_1 = nn.LayerNorm(hidden_size, 
eps=config.layer_norm_epsilon) 368 | self.attn = APTAttention(config, layer_idx=layer_idx) 369 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 370 | 371 | if config.add_cross_attention: 372 | self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx) 373 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 374 | 375 | self.mlp = APTMLP(inner_dim, config) 376 | 377 | def forward( 378 | self, 379 | hidden_states: Optional[Tuple[torch.FloatTensor]], 380 | layer_past: Optional[Tuple[torch.Tensor]] = None, 381 | attention_mask: Optional[torch.FloatTensor] = None, 382 | head_mask: Optional[torch.FloatTensor] = None, 383 | encoder_hidden_states: Optional[torch.Tensor] = None, 384 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 385 | use_cache: Optional[bool] = False, 386 | output_attentions: Optional[bool] = False, 387 | alibi_bias: Optional[torch.FloatTensor] = None, 388 | position_ids: Optional[torch.LongTensor] = None, 389 | 390 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: 391 | residual = hidden_states 392 | hidden_states = self.ln_1(hidden_states) 393 | attn_outputs = self.attn( 394 | hidden_states, 395 | layer_past=layer_past, 396 | attention_mask=attention_mask, 397 | head_mask=head_mask, 398 | use_cache=use_cache, 399 | output_attentions=output_attentions, 400 | alibi_bias=alibi_bias, 401 | position_ids=position_ids, 402 | ) 403 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 404 | outputs = attn_outputs[1:] 405 | # residual connection 406 | hidden_states = attn_output + residual 407 | 408 | if encoder_hidden_states is not None: 409 | # add one self-attention block for cross-attention 410 | if not hasattr(self, "crossattention"): 411 | raise ValueError( 412 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 413 | "cross-attention layers by setting `config.add_cross_attention=True`" 414 | ) 415 | residual = hidden_states 416 | hidden_states = self.ln_cross_attn(hidden_states) 417 | cross_attn_outputs = self.crossattention( 418 | hidden_states, 419 | attention_mask=attention_mask, 420 | head_mask=head_mask, 421 | encoder_hidden_states=encoder_hidden_states, 422 | encoder_attention_mask=encoder_attention_mask, 423 | output_attentions=output_attentions, 424 | ) 425 | attn_output = cross_attn_outputs[0] 426 | # residual connection 427 | hidden_states = residual + attn_output 428 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 429 | 430 | residual = hidden_states 431 | hidden_states = self.ln_2(hidden_states) 432 | feed_forward_hidden_states = self.mlp(hidden_states) 433 | # residual connection 434 | hidden_states = residual + feed_forward_hidden_states 435 | 436 | if use_cache: 437 | outputs = (hidden_states,) + outputs 438 | else: 439 | outputs = (hidden_states,) + outputs[1:] 440 | 441 | return outputs # hidden_states, present, (attentions, cross_attentions) 442 | 443 | 444 | 445 | """The bare APT Model transformer outputting raw hidden-states without any specific head on top.""" 446 | class APTModel(GPT2PreTrainedModel): 447 | def __init__(self, config): 448 | super().__init__(config) 449 | 450 | self.embed_dim = config.hidden_size 451 | 452 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 453 | self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned" 454 | 455 | if 
self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": 456 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 457 | self.alibi = None 458 | elif self.position_embedding=="alibi": 459 | maxpos = config.n_positions 460 | attn_heads = config.n_head 461 | alibi = create_alibi_tensor(attn_heads,maxpos) 462 | self.register_buffer('alibi',alibi) 463 | else: 464 | raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi') 465 | 466 | self.drop = nn.Dropout(config.embd_pdrop) 467 | self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 468 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 469 | 470 | # Model parallel 471 | self.model_parallel = False 472 | self.device_map = None 473 | self.gradient_checkpointing = False 474 | 475 | # Initialize weights and apply final processing 476 | self.post_init() 477 | def forward( 478 | self, 479 | input_ids: Optional[torch.LongTensor] = None, 480 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 481 | attention_mask: Optional[torch.FloatTensor] = None, 482 | token_type_ids: Optional[torch.LongTensor] = None, 483 | position_ids: Optional[torch.LongTensor] = None, 484 | head_mask: Optional[torch.FloatTensor] = None, 485 | inputs_embeds: Optional[torch.FloatTensor] = None, 486 | encoder_hidden_states: Optional[torch.Tensor] = None, 487 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 488 | use_cache: Optional[bool] = None, 489 | output_attentions: Optional[bool] = None, 490 | output_hidden_states: Optional[bool] = None, 491 | return_dict: Optional[bool] = None, 492 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 493 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 494 | output_hidden_states = ( 495 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 496 | ) 497 | use_cache = use_cache if use_cache is not None else self.config.use_cache 498 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 499 | 500 | if input_ids is not None and inputs_embeds is not None: 501 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 502 | elif input_ids is not None: 503 | self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) 504 | input_shape = input_ids.size() 505 | input_ids = input_ids.view(-1, input_shape[-1]) 506 | batch_size = input_ids.shape[0] 507 | elif inputs_embeds is not None: 508 | input_shape = inputs_embeds.size()[:-1] 509 | batch_size = inputs_embeds.shape[0] 510 | else: 511 | raise ValueError("You have to specify either input_ids or inputs_embeds") 512 | 513 | device = input_ids.device if input_ids is not None else inputs_embeds.device 514 | 515 | if token_type_ids is not None: 516 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 517 | if position_ids is not None: 518 | position_ids = position_ids.view(-1, input_shape[-1]) 519 | 520 | if past_key_values is None: 521 | past_length = 0 522 | past_key_values = tuple([None] * len(self.h)) 523 | else: 524 | past_length = past_key_values[0][0].size(-2) 525 | if position_ids is None: 526 | position_ids = torch.arange(past_length, 
input_shape[-1] + past_length, dtype=torch.long, device=device) 527 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 528 | 529 | # GPT2Attention mask. 530 | if attention_mask is not None: 531 | if batch_size <= 0: 532 | raise ValueError("batch_size has to be defined and > 0") 533 | attention_mask = attention_mask.view(batch_size, -1) 534 | # We create a 3D attention mask from a 2D tensor mask. 535 | # Sizes are [batch_size, 1, 1, to_seq_length] 536 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 537 | # this attention mask is more simple than the triangular masking of causal attention 538 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 539 | attention_mask = attention_mask[:, None, None, :] 540 | 541 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 542 | # masked positions, this operation will create a tensor which is 0.0 for 543 | # positions we want to attend and the dtype's smallest value for masked positions. 544 | # Since we are adding it to the raw scores before the softmax, this is 545 | # effectively the same as removing these entirely. 546 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 547 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 548 | 549 | # If a 2D or 3D attention mask is provided for the cross-attention 550 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 551 | if self.config.add_cross_attention and encoder_hidden_states is not None: 552 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 553 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 554 | if encoder_attention_mask is None: 555 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 556 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 557 | else: 558 | encoder_attention_mask = None 559 | 560 | # Prepare head mask if needed 561 | # 1.0 in head_mask indicate we keep the head 562 | # attention_probs has shape bsz x n_heads x N x N 563 | # head_mask has shape n_layer x batch x n_heads x N x N 564 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 565 | 566 | if inputs_embeds is None: 567 | inputs_embeds = self.wte(input_ids) 568 | 569 | if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": 570 | position_embeds = self.wpe(position_ids) 571 | hidden_states = inputs_embeds + position_embeds 572 | else: 573 | hidden_states = inputs_embeds 574 | 575 | 576 | if token_type_ids is not None: 577 | token_type_embeds = self.wte(token_type_ids) 578 | hidden_states = hidden_states + token_type_embeds 579 | 580 | hidden_states = self.drop(hidden_states) 581 | 582 | output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) 583 | 584 | if self.gradient_checkpointing and self.training: 585 | if use_cache: 586 | logger.warning_once( 587 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
588 | ) 589 | use_cache = False 590 | 591 | presents = () if use_cache else None 592 | all_self_attentions = () if output_attentions else None 593 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 594 | all_hidden_states = () if output_hidden_states else None 595 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 596 | # Model parallel 597 | if self.model_parallel: 598 | torch.cuda.set_device(hidden_states.device) 599 | # Ensure layer_past is on same device as hidden_states (might not be correct) 600 | if layer_past is not None: 601 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 602 | # Ensure that attention_mask is always on the same device as hidden_states 603 | if attention_mask is not None: 604 | attention_mask = attention_mask.to(hidden_states.device) 605 | if isinstance(head_mask, torch.Tensor): 606 | head_mask = head_mask.to(hidden_states.device) 607 | if output_hidden_states: 608 | all_hidden_states = all_hidden_states + (hidden_states,) 609 | 610 | if self.gradient_checkpointing and self.training: 611 | 612 | def create_custom_forward(module): 613 | def custom_forward(*inputs): 614 | # None for past_key_value 615 | return module(*inputs, use_cache, output_attentions) 616 | 617 | return custom_forward 618 | 619 | outputs = torch.utils.checkpoint.checkpoint( 620 | create_custom_forward(block), 621 | hidden_states, 622 | None, 623 | attention_mask, 624 | head_mask[i], 625 | encoder_hidden_states, 626 | encoder_attention_mask, 627 | ) 628 | else: 629 | outputs = block( 630 | hidden_states, 631 | layer_past=layer_past, 632 | attention_mask=attention_mask, 633 | head_mask=head_mask[i], 634 | encoder_hidden_states=encoder_hidden_states, 635 | encoder_attention_mask=encoder_attention_mask, 636 | use_cache=use_cache, 637 | output_attentions=output_attentions, 638 | alibi_bias=self.alibi if hasattr(self, "alibi") else None, 639 | position_ids=position_ids 640 | ) 641 | 642 | hidden_states = outputs[0] 643 | if use_cache is True: 644 | presents = presents + (outputs[1],) 645 | 646 | if output_attentions: 647 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 648 | if self.config.add_cross_attention: 649 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 650 | 651 | # Model Parallel: If it's the last layer for that device, put things on the next device 652 | if self.model_parallel: 653 | for k, v in self.device_map.items(): 654 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 655 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 656 | 657 | hidden_states = self.ln_f(hidden_states) 658 | 659 | hidden_states = hidden_states.view(output_shape) 660 | # Add last hidden state 661 | if output_hidden_states: 662 | all_hidden_states = all_hidden_states + (hidden_states,) 663 | 664 | if not return_dict: 665 | return tuple( 666 | v 667 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 668 | if v is not None 669 | ) 670 | 671 | return BaseModelOutputWithPastAndCrossAttentions( 672 | last_hidden_state=hidden_states, 673 | past_key_values=presents, 674 | hidden_states=all_hidden_states, 675 | attentions=all_self_attentions, 676 | cross_attentions=all_cross_attentions, 677 | ) 678 | 679 | 680 | """ 681 | The APT Model transformer with a language modeling head on top (linear layer with weights tied to the input 682 | embeddings). 
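It also exposes `predict_contacts`, which stacks the per-layer self-attention maps and feeds
them through the ContactPredictionHead from protein_lm/modeling/utils/modules.py.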
683 | """ 684 | class APTLMHeadModel(GPT2PreTrainedModel): 685 | _tied_weights_keys = ["lm_head.weight"] 686 | 687 | def __init__(self, config): 688 | super().__init__(config) 689 | self.transformer = APTModel(config) 690 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 691 | 692 | # Model parallel 693 | self.model_parallel = False 694 | self.device_map = None 695 | 696 | self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads, 697 | prepend_bos=True, 698 | append_eos=True, 699 | eos_idx=2) 700 | 701 | # Initialize weights and apply final processing 702 | self.post_init() 703 | def forward( 704 | self, 705 | input_ids: Optional[torch.LongTensor] = None, 706 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 707 | attention_mask: Optional[torch.FloatTensor] = None, 708 | token_type_ids: Optional[torch.LongTensor] = None, 709 | position_ids: Optional[torch.LongTensor] = None, 710 | head_mask: Optional[torch.FloatTensor] = None, 711 | inputs_embeds: Optional[torch.FloatTensor] = None, 712 | encoder_hidden_states: Optional[torch.Tensor] = None, 713 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 714 | labels: Optional[torch.LongTensor] = None, 715 | use_cache: Optional[bool] = None, 716 | output_attentions: Optional[bool] = None, 717 | output_hidden_states: Optional[bool] = None, 718 | return_dict: Optional[bool] = None, 719 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 720 | r""" 721 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 722 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 723 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 724 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 725 | """ 726 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 727 | 728 | transformer_outputs = self.transformer( 729 | input_ids, 730 | past_key_values=past_key_values, 731 | attention_mask=attention_mask, 732 | token_type_ids=token_type_ids, 733 | position_ids=position_ids, 734 | head_mask=head_mask, 735 | inputs_embeds=inputs_embeds, 736 | encoder_hidden_states=encoder_hidden_states, 737 | encoder_attention_mask=encoder_attention_mask, 738 | use_cache=use_cache, 739 | output_attentions=output_attentions, 740 | output_hidden_states=output_hidden_states, 741 | return_dict=return_dict, 742 | ) 743 | hidden_states = transformer_outputs[0] 744 | 745 | # Set device for model parallelism 746 | if self.model_parallel: 747 | torch.cuda.set_device(self.transformer.first_device) 748 | hidden_states = hidden_states.to(self.lm_head.weight.device) 749 | 750 | lm_logits = self.lm_head(hidden_states) 751 | 752 | loss = None 753 | if labels is not None: 754 | # move labels to correct device to enable model parallelism 755 | labels = labels.to(lm_logits.device) 756 | # Shift so that tokens < n predict n 757 | shift_logits = lm_logits[..., :-1, :].contiguous() 758 | shift_labels = labels[..., 1:].contiguous() 759 | # Flatten the tokens 760 | loss_fct = CrossEntropyLoss() 761 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 762 | 763 | if not return_dict: 764 | output = (lm_logits,) + transformer_outputs[1:] 765 | return ((loss,) + output) if loss is not None else output 766 | 767 | return CausalLMOutputWithCrossAttentions( 768 | loss=loss, 769 | 
logits=lm_logits, 770 | past_key_values=transformer_outputs.past_key_values, 771 | hidden_states=transformer_outputs.hidden_states, 772 | attentions=transformer_outputs.attentions, 773 | cross_attentions=transformer_outputs.cross_attentions, 774 | ) 775 | 776 | def predict_contacts(self, input_ids): 777 | transformer_outputs = self.transformer( 778 | input_ids, 779 | return_dict=True, 780 | output_attentions=True, 781 | ) 782 | # Convert attention tuples to list 783 | attentions_list = list(transformer_outputs.attentions) 784 | 785 | # Stack the attention tensors 786 | stacked_attentions = torch.stack( 787 | [attn for attn in attentions_list], 788 | dim=1 789 | ) 790 | 791 | contact_predictions = self.contact_head(input_ids, stacked_attentions) 792 | 793 | return contact_predictions -------------------------------------------------------------------------------- /protein_lm/modeling/scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | 4 | 5 | import yaml 6 | from transformers import Trainer 7 | 8 | from protein_lm.modeling.getters.data_collator import get_data_collator 9 | from protein_lm.modeling.getters.dataset import get_dataset 10 | from protein_lm.modeling.getters.model import get_model 11 | from protein_lm.modeling.getters.tokenizer import get_tokenizer 12 | from protein_lm.modeling.getters.training_args import get_training_args 13 | from protein_lm.modeling.getters.wandb_log import setup_wandb 14 | 15 | 16 | def train( 17 | config_file: str, 18 | ): 19 | """ 20 | Main script to train APT. 21 | """ 22 | with open(config_file, "r") as cf: 23 | config_dict = yaml.safe_load(cf) 24 | print(config_dict) 25 | 26 | tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"]) 27 | 28 | dataset = get_dataset( 29 | config_dict=config_dict["dataset"], 30 | tokenizer=tokenizer, 31 | ) 32 | 33 | model = get_model( 34 | config_dict=config_dict["model"], 35 | ) 36 | model.train() 37 | 38 | data_collator = get_data_collator( 39 | config_dict=config_dict["data_collator"], 40 | ) 41 | if config_dict['dataset']['do_curriculum_learning']: 42 | #groupy_by_length uses the LengthGroupedSampler, 43 | #we have precomputed the lengths (or any discrete column) which can be used as sampling criteria 44 | config_dict["training_arguments"]['group_by_length'] = config_dict['dataset']['do_curriculum_learning'] 45 | config_dict["training_arguments"]['length_column_name'] = config_dict['dataset']['curriculum_learning_column_name'] 46 | 47 | training_args = get_training_args( 48 | config_dict=config_dict["training_arguments"], 49 | ) 50 | 51 | if "wandb" in training_args.report_to: 52 | setup_wandb( 53 | config_dict["wandb"], 54 | ) 55 | 56 | trainer = Trainer( 57 | model=model, 58 | args=training_args, 59 | train_dataset=dataset["train"], 60 | eval_dataset=dataset.get("val", None), 61 | data_collator=data_collator, 62 | ) 63 | 64 | trainer.train() 65 | trainer.save_model() 66 | trainer.save_state() 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument( 72 | "--config-file", 73 | default="protein_lm/configs/train/toy_localcsv.yaml", 74 | type=str, 75 | help="Config yaml for training", 76 | ) 77 | args = parser.parse_args() 78 | 79 | train(config_file=args.config_file) 80 | -------------------------------------------------------------------------------- /protein_lm/modeling/utils/__init__.py: -------------------------------------------------------------------------------- 
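The `train` entry point above can be launched from the command line (as in its `__main__` block) or called directly; a minimal sketch using the repository's toy config:

```
# Programmatic equivalent of:
#   python protein_lm/modeling/scripts/train.py --config-file protein_lm/configs/train/toy_localcsv.yaml
# The YAML is expected to provide the sections read above: tokenizer, dataset, model,
# data_collator, training_arguments and, when reporting to wandb, a wandb section.
from protein_lm.modeling.scripts.train import train

train(config_file="protein_lm/configs/train/toy_localcsv.yaml")
```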
https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/modeling/utils/__init__.py -------------------------------------------------------------------------------- /protein_lm/modeling/utils/alibi_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def get_slopes(n): 5 | """ 6 | Function to compute the m constant for each attention head. Code has been adapted from the official ALiBi codebase at: 7 | https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py 8 | """ 9 | def get_slopes_power_of_2(n): 10 | start = (2**(-2**-(math.log2(n)-3))) 11 | ratio = start 12 | return [start*ratio**i for i in range(n)] 13 | 14 | if math.log2(n).is_integer(): 15 | return get_slopes_power_of_2(n) 16 | else: 17 | closest_power_of_2 = 2**math.floor(math.log2(n)) 18 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2] 19 | 20 | 21 | def create_alibi_tensor(attn_heads,maxpos): 22 | slopes = torch.Tensor(get_slopes(attn_heads)) 23 | #The softmax operation is invariant to translation, and bias functions used are always linear. 24 | alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1) 25 | return alibi.view(attn_heads, 1, maxpos) 26 | 27 | 28 | -------------------------------------------------------------------------------- /protein_lm/modeling/utils/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def symmetrize(x): 9 | "Make layer symmetric in final two dimensions, used for contact prediction." 10 | return x + x.transpose(-1, -2) 11 | def apc(x): 12 | "Perform average product correct, used for contact prediction." 
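    # The lines below implement APC: normalized_ij = x_ij - (row_sum_i * col_sum_j) / total_sum.
    # Subtracting this rank-one background term down-weights rows/columns that attend
    # strongly everywhere, sharpening the signal used for contact prediction.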
13 | a1 = x.sum(-1, keepdims=True) 14 | a2 = x.sum(-2, keepdims=True) 15 | a12 = x.sum((-1, -2), keepdims=True) 16 | 17 | avg = a1 * a2 18 | avg.div_(a12) # in-place to reduce memory 19 | normalized = x - avg 20 | return normalized 21 | 22 | class ContactPredictionHead(nn.Module): 23 | """Performs symmetrization, apc, and computes a logistic regression on the output features""" 24 | 25 | def __init__( 26 | self, 27 | in_features: int, 28 | prepend_bos: bool, 29 | append_eos: bool, 30 | bias=True, 31 | eos_idx: Optional[int] = None, 32 | ): 33 | super().__init__() 34 | self.in_features = in_features 35 | self.prepend_bos = prepend_bos 36 | self.append_eos = append_eos 37 | if append_eos and eos_idx is None: 38 | raise ValueError("Using an alphabet with eos token, but no eos token was passed in.") 39 | self.eos_idx = eos_idx 40 | self.regression = nn.Linear(in_features, 1, bias) 41 | self.activation = nn.Sigmoid() 42 | 43 | def forward(self, tokens, attentions): 44 | # remove eos token attentions 45 | if self.append_eos: 46 | eos_mask = tokens.ne(self.eos_idx).to(attentions) 47 | eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) # Expand dimensions along last two dimensions 48 | attentions = attentions * eos_mask[:, None, None, :, :] 49 | attentions = attentions[..., :-1, :-1] 50 | # remove cls token attentions 51 | if self.prepend_bos: 52 | attentions = attentions[..., 1:, 1:] 53 | 54 | batch_size, layers, heads, seqlen, _ = attentions.size() 55 | attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) 56 | 57 | # features: B x C x T x T 58 | attentions = attentions.to( 59 | self.regression.weight.device 60 | ) # attentions always float32, may need to convert to float16 61 | attentions = apc(symmetrize(attentions)) 62 | attentions = attentions.permute(0, 2, 3, 1) 63 | return self.activation(self.regression(attentions).squeeze(3)) 64 | -------------------------------------------------------------------------------- /protein_lm/modeling/utils/rerope_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | 4 | def rotate_half(x): 5 | """Rotates half the hidden dims of the input.""" 6 | x1 = x[..., : x.shape[-1] // 2] 7 | x2 = x[..., x.shape[-1] // 2 :] 8 | return torch.cat((-x2, x1), dim=-1) 9 | 10 | def apply_rectified_rotary_pos_emb(q, k, cos, sin, position_ids): 11 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 12 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 13 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 14 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 15 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 16 | q_embed = (q * cos[:, :, -q.shape[2]:]) + (rotate_half(q) * sin[:, :, -q.shape[2]:]) if q is not None else None 17 | k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None 18 | return q_embed, k_embed 19 | 20 | class RectifiedRotaryEmbedding(torch.nn.Module): 21 | def __init__(self, dim, max_position_embeddings:int=2048, base:int=10000, device=None): 22 | super().__init__() 23 | 24 | self.dim = dim 25 | self.max_position_embeddings = max_position_embeddings 26 | self.base = base 27 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 28 | self.register_buffer("inv_freq", inv_freq, persistent=False) 29 | 30 | # Build here to make `torch.jit.trace` work. 
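        # The cached cos/sin tables cover positions up to max_position_embeddings;
        # forward() rebuilds them on the fly whenever it sees a longer sequence.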
31 | self._set_cos_sin_cache( 32 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 33 | ) 34 | 35 | def _set_cos_sin_cache(self, seq_len, device, dtype): 36 | self.max_seq_len_cached = seq_len 37 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 38 | 39 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 40 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 41 | emb = torch.cat((freqs, freqs), dim=-1) 42 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 43 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 44 | 45 | def forward(self,q, k, seq_len=None,position_ids = None): 46 | # x: [bs, num_attention_heads, seq_len, head_size] 47 | if seq_len > self.max_seq_len_cached: 48 | self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype) 49 | 50 | return ( 51 | apply_rectified_rotary_pos_emb(q,k,self.cos_cached[:, :, :seq_len, ...].to(dtype=q.dtype), 52 | self.sin_cached[:, :, :seq_len, ...].to(dtype=q.dtype),position_ids) 53 | ) 54 | 55 | -------------------------------------------------------------------------------- /protein_lm/modeling/utils/rotary_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | 10 | 11 | def rotate_half(x): 12 | x1, x2 = x.chunk(2, dim=-1) 13 | return torch.cat((-x2, x1), dim=-1) 14 | 15 | 16 | def apply_rotary_pos_emb(x, cos, sin): 17 | cos = cos[:, : x.shape[-2], :] 18 | sin = sin[:, : x.shape[-2], :] 19 | 20 | return (x * cos) + (rotate_half(x) * sin) 21 | 22 | 23 | class RotaryEmbedding(torch.nn.Module): 24 | """ 25 | The rotary position embeddings from RoFormer_ (Su et. al). 26 | A crucial insight from the method is that the query and keys are 27 | transformed by rotation matrices which depend on the relative positions. 28 | Other implementations are available in the Rotary Transformer repo_ and in 29 | GPT-NeoX_, GPT-NeoX was an inspiration 30 | .. _RoFormer: https://arxiv.org/abs/2104.09864 31 | .. _repo: https://github.com/ZhuiyiTechnology/roformer 32 | .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox 33 | .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative 34 | (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis 35 | """ 36 | 37 | def __init__(self, dim: int, *_, **__): 38 | super().__init__() 39 | # Generate and save the inverse frequency buffer (non trainable) 40 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 41 | self.register_buffer("inv_freq", inv_freq) 42 | 43 | self._seq_len_cached = None 44 | self._cos_cached = None 45 | self._sin_cached = None 46 | 47 | def _update_cos_sin_tables(self, x, seq_dimension=1): 48 | seq_len = x.shape[seq_dimension] 49 | 50 | # Reset the tables if the sequence length has changed, 51 | # or if we're on a new device (possibly due to tracing for instance) 52 | if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: 53 | self._seq_len_cached = seq_len 54 | t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) 55 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 56 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 57 | 58 | self._cos_cached = emb.cos()[None, :, :] 59 | self._sin_cached = emb.sin()[None, :, :] 60 | 61 | return self._cos_cached, self._sin_cached 62 | 63 | def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 64 | self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) 65 | 66 | return ( 67 | apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), 68 | apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), 69 | ) -------------------------------------------------------------------------------- /protein_lm/modeling/utils/scaled_rope_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | 4 | def rotate_half(x): 5 | """Rotates half the hidden dims of the input.""" 6 | x1 = x[..., : x.shape[-1] // 2] 7 | x2 = x[..., x.shape[-1] // 2 :] 8 | return torch.cat((-x2, x1), dim=-1) 9 | 10 | 11 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 12 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 13 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 14 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 15 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 16 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 17 | q_embed = (q * cos) + (rotate_half(q) * sin) 18 | k_embed = (k * cos) + (rotate_half(k) * sin) 19 | return q_embed, k_embed 20 | 21 | 22 | 23 | class LlamaRotaryEmbedding(torch.nn.Module): 24 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 25 | super().__init__() 26 | 27 | self.dim = dim 28 | self.max_position_embeddings = max_position_embeddings 29 | self.base = base 30 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 31 | self.register_buffer("inv_freq", inv_freq, persistent=False) 32 | 33 | # Build here to make `torch.jit.trace` work. 
34 | self._set_cos_sin_cache( 35 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 36 | ) 37 | 38 | def _set_cos_sin_cache(self, seq_len, device, dtype): 39 | self.max_seq_len_cached = seq_len 40 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 41 | 42 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 43 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 44 | emb = torch.cat((freqs, freqs), dim=-1) 45 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 46 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 47 | 48 | def forward(self,q, k, seq_len=None,position_ids = None): 49 | # x: [bs, num_attention_heads, seq_len, head_size] 50 | if seq_len > self.max_seq_len_cached: 51 | self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype) 52 | 53 | return ( 54 | apply_rotary_pos_emb(q,k,self.cos_cached[:, :, :seq_len, ...].to(dtype=q.dtype), 55 | self.sin_cached[:, :, :seq_len, ...].to(dtype=q.dtype),position_ids) 56 | ) 57 | 58 | 59 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 60 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 61 | 62 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 63 | self.scaling_factor = scaling_factor 64 | super().__init__(dim, max_position_embeddings, base, device) 65 | 66 | def _set_cos_sin_cache(self, seq_len, device, dtype): 67 | self.max_seq_len_cached = seq_len 68 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 69 | t = t / self.scaling_factor 70 | 71 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 72 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 73 | emb = torch.cat((freqs, freqs), dim=-1) 74 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 75 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 76 | 77 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 78 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 79 | 80 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 81 | self.scaling_factor = scaling_factor 82 | super().__init__(dim, max_position_embeddings, base, device) 83 | 84 | def _set_cos_sin_cache(self, seq_len, device, dtype): 85 | self.max_seq_len_cached = seq_len 86 | 87 | if seq_len > self.max_position_embeddings: 88 | base = self.base * ( 89 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 90 | ) ** (self.dim / (self.dim - 2)) 91 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 92 | self.register_buffer("inv_freq", inv_freq, persistent=False) 93 | 94 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 95 | 96 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 97 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 98 | emb = torch.cat((freqs, freqs), dim=-1) 99 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 100 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) -------------------------------------------------------------------------------- /protein_lm/tests/tensors/1a3a.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/1a3a.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/1xcr.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/1xcr.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/5ahw.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/5ahw.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/5ahw_1_A_jacobian.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/5ahw_1_A_jacobian.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/dynamic_rope.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/dynamic_rope.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/linear_rope.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/linear_rope.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/rerope.pkl: -------------------------------------------------------------------------------- 
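Before the test fixtures, a quick numeric illustration of the dynamic NTK rescaling implemented in LlamaDynamicNTKScalingRotaryEmbedding above. The head dimension and sequence lengths below are made-up values chosen only to show the effect; they are not project defaults.

# Standalone sketch of the base rescaling applied when seq_len exceeds max_position_embeddings.
import torch

dim, base, scaling_factor = 64, 10000.0, 1.0
max_position_embeddings, seq_len = 2048, 4096  # illustrative: context doubled

if seq_len > max_position_embeddings:
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# With these numbers, base grows to 10000 * 2 ** (64 / 62) ≈ 20,450, so the rotary
# frequencies shrink and the positional encoding stretches to cover the longer context.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
print(base, inv_freq[:3])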
https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/rerope.pkl -------------------------------------------------------------------------------- /protein_lm/tests/tensors/rope.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBioML/protein-lm-scaling/57a530d065921b5b867c369bbef7b583f10521dd/protein_lm/tests/tensors/rope.pkl -------------------------------------------------------------------------------- /protein_lm/tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from model_pytorch import APTAttention 6 | 7 | class ParameterConfig: 8 | def __init__(self): 9 | self.max_position_embeddings = 512 10 | self.position_embedding = 'rope' 11 | self.max_sequence_length = 512 12 | self.hidden_size = 768 13 | self.num_attention_heads = 12 14 | self.scale_attn_weights = True 15 | self.scale_attn_by_inverse_layer_idx = True 16 | self.reorder_and_upcast_attn = True 17 | self.attn_pdrop = 0.1 18 | self.resid_pdrop = 0.1 19 | self.rope_scaling_factor = 1 20 | self.rope_theta = 1 21 | self.attn_type = 'gqa' 22 | 23 | 24 | def test_vanilla_attn(): 25 | # Initialize with mock config 26 | config = ParameterConfig() 27 | attention = APTAttention(config, is_cross_attention=False, layer_idx=0) 28 | 29 | # generate random input tensors 30 | batch_size = 4 31 | seq_length = 100 32 | num_heads = config.num_attention_heads 33 | query_dim = config.hidden_size // num_heads 34 | query = torch.randn(batch_size, num_heads, seq_length, query_dim) 35 | key = torch.randn(batch_size, num_heads, seq_length, query_dim) 36 | value = torch.randn(batch_size, num_heads, seq_length, query_dim) 37 | 38 | # Create a random attention mask for testing 39 | attention_mask = torch.ones(batch_size,seq_length, seq_length) 40 | padding_positions = 10 41 | attention_mask[:, -padding_positions:, :] = float('-inf') 42 | attention_mask[:, :, -padding_positions:] = float('-inf') 43 | attention_mask = attention_mask.unsqueeze(1) 44 | # Pass them through the _attn method 45 | attn_output, attn_weights = attention._attn(query, key, value, attention_mask=attention_mask) 46 | 47 | # Check the shapes and types of the output 48 | assert isinstance(attn_output, torch.Tensor) 49 | assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim) 50 | assert isinstance(attn_weights, torch.Tensor) 51 | assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length) 52 | print("Test passed!") 53 | 54 | def test_gqa_attn(): 55 | # Initialize with mock config 56 | config = ParameterConfig() 57 | attention = APTAttention(config, is_cross_attention=False, layer_idx=0) 58 | 59 | # generate random input tensors 60 | batch_size = 4 61 | seq_length = 100 62 | num_heads = config.num_attention_heads 63 | query_dim = config.hidden_size // num_heads 64 | query = torch.randn(batch_size, num_heads, seq_length, query_dim) 65 | key = torch.randn(batch_size, num_heads, seq_length, query_dim) 66 | value = torch.randn(batch_size, num_heads, seq_length, query_dim) 67 | 68 | # Create a random attention mask for testing 69 | attention_mask = torch.ones(batch_size,seq_length, seq_length) 70 | padding_positions = 10 71 | attention_mask[:, -padding_positions:, :] = float('-inf') 72 | attention_mask[:, :, -padding_positions:] = float('-inf') 73 | 74 | # Pass them through the 
_gqa_attn method 75 | attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask) 76 | 77 | # Check the shapes and types of the output 78 | assert isinstance(attn_output, torch.Tensor) 79 | assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim) 80 | assert isinstance(attn_weights, torch.Tensor) 81 | assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length) 82 | print("Test passed!") 83 | 84 | 85 | test_gqa_attn() 86 | test_vanilla_attn() 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /protein_lm/tests/test_cl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import torch 4 | import torch.nn as nn 5 | import yaml 6 | from transformers import Trainer 7 | from protein_lm.modeling.getters.data_collator import get_data_collator 8 | from protein_lm.modeling.getters.model import get_model 9 | from protein_lm.modeling.getters.tokenizer import get_tokenizer 10 | from protein_lm.modeling.getters.training_args import get_training_args 11 | from datasets import Dataset, load_dataset 12 | from datasets.dataset_dict import DatasetDict 13 | from pydantic import BaseModel 14 | from protein_lm.modeling.getters.dataset import DatasetConfig,get_csv_dataset,set_input_ids,set_labels,batch_set_curriculum_learning_column 15 | ##data collator imports 16 | from dataclasses import dataclass 17 | from typing import Dict, Literal,Any, Callable, Dict, List, NewType, Optional, Tuple, Union 18 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 19 | import pandas as pd 20 | import random 21 | 22 | CONFIG_PATH = "protein_lm/configs/train/toy_localcsv.yaml" 23 | strategies = ['sequence_length'] 24 | strategy2col = {'sequence_length': 'sequence_length'} #mapping of strategy to the computed column name storing the values of respective strategy 25 | total = 0 #number of batches/steps 26 | unsorted = 0 #number of unsorted batches/steps 27 | InputDataClass = NewType("InputDataClass", Any) 28 | def cl_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]: 29 | global total 30 | global unsorted 31 | """ 32 | Very simple data collator that simply collates batches of dict-like objects and performs special handling for 33 | potential keys named: 34 | 35 | - ``label``: handles a single value (int or float) per object 36 | - ``label_ids``: handles a list of values per object 37 | 38 | Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs 39 | to the model. See glue and ner for example of how it's useful. 40 | """ 41 | 42 | # In this function we'll make the assumption that all `features` in the batch 43 | # have the same attributes. 44 | # So we will look at the first element as a proxy for what attributes exist 45 | # on the whole batch. 46 | if not isinstance(features[0], (dict, BatchEncoding)): 47 | features = [vars(f) for f in features] 48 | 49 | first = features[0] 50 | batch = {} 51 | 52 | # Special handling for labels. 53 | # Ensure that tensor is created with the correct type 54 | # (it should be automatically the case, but let's make sure of it.) 
55 | if "label" in first and first["label"] is not None: 56 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 57 | dtype = torch.long if isinstance(label, int) else torch.float 58 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 59 | elif "label_ids" in first and first["label_ids"] is not None: 60 | if isinstance(first["label_ids"], torch.Tensor): 61 | batch["labels"] = torch.stack([f["label_ids"] for f in features]) 62 | else: 63 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float 64 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) 65 | 66 | # Handling of all other possible keys. 67 | # Again, we will use the first element to figure out which key/values are not None for this model. 68 | for k, v in first.items(): 69 | 70 | if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): 71 | if isinstance(v, torch.Tensor): 72 | batch[k] = torch.stack([f[k] for f in features]) 73 | else: 74 | if k == 'sequence_length': 75 | batch[k] = [-f[k] for f in features] 76 | else: 77 | batch[k] = torch.tensor([f[k] for f in features]) 78 | lens = batch['sequence_length'] 79 | print('######lens(cl_data_collator)#########') 80 | print(lens) 81 | total = total + 1 82 | try: 83 | assert lens == sorted(lens) 84 | except: 85 | unsorted = unsorted + 1 86 | print('not sorted') 87 | return {'input_ids':batch['input_ids'],'labels': batch['labels']} 88 | 89 | 90 | def create_random_dataframe(sequence_column_name = 'sequence',curriculum_learning_column_name = 'sequence_length',curriculum_learning_strategy = 'sequence_length',max_sequence_length = 30, n = 5000): 91 | assert max_sequence_length > 2 92 | random.seed(42) 93 | df = pd.DataFrame() 94 | def create_sequence(length): 95 | seq = ''.join(random.choice('ACDEFGHIKLMNPQRSTVWY') for _ in range(length)) 96 | return seq 97 | 98 | if curriculum_learning_strategy == 'sequence_length': 99 | df[sequence_column_name] = [create_sequence(random.randint(2, max_sequence_length)) for i in range(n)] 100 | df[curriculum_learning_column_name] = df[sequence_column_name].apply(lambda x: len(x)) 101 | return df 102 | 103 | @pytest.mark.parametrize("strategy",strategies) 104 | def test_curriculum_learning(strategy): 105 | 106 | with open(CONFIG_PATH, "r") as cf: 107 | print('loading file.....') 108 | config_dict = yaml.safe_load(cf) 109 | 110 | config_dict['dataset']['max_sequence_length'] = 40 111 | config_dict['dataset']['do_curriculum_learning'] = True 112 | config_dict['dataset']['curriculum_learning_column_name'] = strategy2col[strategy] 113 | config_dict['dataset']['curriculum_learning_strategy'] = strategy 114 | config_dict['dataset']['val_size'] = 100 115 | config_dict['dataset']['test_size'] = 100 116 | config_dict['dataset']['subsample_size'] = 500 117 | config_dict["training_arguments"]['group_by_length'] = True 118 | config_dict["training_arguments"]['length_column_name'] = config_dict['dataset']['curriculum_learning_column_name'] 119 | config_dict["training_arguments"]['remove_unused_columns'] = False # this is necessary to keep curriculum_learning_column_name 120 | config_dict["training_arguments"]['per_device_train_batch_size'] = 20 121 | config_dict["training_arguments"]['max_steps'] = -1 122 | config_dict["training_arguments"]['num_train_epochs'] = 2 123 | 124 | print(config_dict) 125 | 126 | tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"]) 127 | dataset = DatasetDict() 128 | val_df = 
create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']['curriculum_learning_column_name'],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['val_size'] ) 129 | test_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']["curriculum_learning_column_name"],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['test_size'] ) 130 | train_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']["curriculum_learning_column_name"],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['subsample_size'] ) 131 | 132 | dataset['train'] = Dataset.from_pandas(train_df) 133 | dataset['val'] = Dataset.from_pandas(val_df) 134 | dataset['test'] = Dataset.from_pandas(test_df) 135 | dataset = dataset.map( 136 | lambda e: set_input_ids( 137 | result=e, 138 | tokenizer=tokenizer, 139 | sequence_column_name=config_dict['dataset']['sequence_column_name'], 140 | max_sequence_length=config_dict['dataset']['max_sequence_length'], 141 | ), 142 | batched=True, 143 | ) 144 | dataset = dataset.map(set_labels, batched=True) 145 | dataset = dataset.map(lambda e: batch_set_curriculum_learning_column( 146 | result = e, 147 | input_column_name = config_dict['dataset']['sequence_column_name'], 148 | curriculum_learning_column_name = config_dict['dataset']['curriculum_learning_column_name'], 149 | strategy = config_dict['dataset']['curriculum_learning_strategy'] 150 | 151 | ),batched=True) 152 | dataset = dataset.select_columns(['input_ids', 'labels', strategy2col[strategy]]) 153 | model = get_model( 154 | config_dict=config_dict["model"], 155 | ) 156 | 157 | training_args = get_training_args( 158 | config_dict=config_dict["training_arguments"], 159 | ) 160 | 161 | trainer = Trainer( 162 | model=model, 163 | args=training_args, 164 | train_dataset=dataset["train"], 165 | eval_dataset=dataset.get("val", None), 166 | data_collator=cl_data_collator, 167 | ) 168 | 169 | trainer.train() 170 | percentage_unsorted = int((unsorted / total) * 100) #computing the number of times the list in collator was not sorted 171 | #there are sometimes cases where the list is off by a few entries aa the LengthGroupedSampler has a bit of randomness 172 | print(f'percentage_unsorted:{percentage_unsorted}') 173 | assert percentage_unsorted < 10 # just a rough heuristic 174 | 175 | 176 | -------------------------------------------------------------------------------- /protein_lm/tests/test_cl_continuous.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import torch 4 | import torch.nn as nn 5 | import yaml 6 | from transformers import Trainer 7 | from protein_lm.modeling.getters.data_collator import get_data_collator 8 | from protein_lm.modeling.getters.model import get_model 9 | from protein_lm.modeling.getters.tokenizer import get_tokenizer 10 | from protein_lm.modeling.getters.training_args import get_training_args 11 | from datasets import Dataset, load_dataset 12 | from datasets.dataset_dict import DatasetDict 13 | from pydantic import BaseModel 14 | from protein_lm.modeling.getters.dataset import 
DatasetConfig,get_csv_dataset,set_input_ids,set_labels,batch_set_curriculum_learning_column 15 | ##data collator imports 16 | from dataclasses import dataclass 17 | from typing import Dict, Literal,Any, Callable, Dict, List, NewType, Optional, Tuple, Union 18 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 19 | import pandas as pd 20 | import random 21 | 22 | CONFIG_PATH = "protein_lm/configs/train/toy_localcsv.yaml" 23 | strategies = ['ppl'] 24 | strategy2col = {'ppl': 'ppl'} #mapping of strategy to the computed column name storing the values of respective strategy 25 | total = 0 #number of batches/steps 26 | unsorted = 0 #number of unsorted batches/steps 27 | InputDataClass = NewType("InputDataClass", Any) 28 | 29 | global max_value_of_previous_batch 30 | max_value_of_previous_batch = None 31 | global batch_comparison_values 32 | batch_comparison_values = [] 33 | 34 | def cl_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]: 35 | global total 36 | global unsorted 37 | global max_value_of_previous_batch 38 | global batch_comparison_values 39 | """ 40 | Very simple data collator that simply collates batches of dict-like objects and performs special handling for 41 | potential keys named: 42 | 43 | - ``label``: handles a single value (int or float) per object 44 | - ``label_ids``: handles a list of values per object 45 | 46 | Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs 47 | to the model. See glue and ner for example of how it's useful. 48 | """ 49 | 50 | # In this function we'll make the assumption that all `features` in the batch 51 | # have the same attributes. 52 | # So we will look at the first element as a proxy for what attributes exist 53 | # on the whole batch. 54 | if not isinstance(features[0], (dict, BatchEncoding)): 55 | features = [vars(f) for f in features] 56 | 57 | first = features[0] 58 | batch = {} 59 | 60 | # Special handling for labels. 61 | # Ensure that tensor is created with the correct type 62 | # (it should be automatically the case, but let's make sure of it.) 63 | if "label" in first and first["label"] is not None: 64 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 65 | dtype = torch.long if isinstance(label, int) else torch.float 66 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 67 | elif "label_ids" in first and first["label_ids"] is not None: 68 | if isinstance(first["label_ids"], torch.Tensor): 69 | batch["labels"] = torch.stack([f["label_ids"] for f in features]) 70 | else: 71 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float 72 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) 73 | 74 | # Handling of all other possible keys. 75 | # Again, we will use the first element to figure out which key/values are not None for this model. 
76 | for k, v in first.items(): 77 | 78 | if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): 79 | if isinstance(v, torch.Tensor): 80 | batch[k] = torch.stack([f[k] for f in features]) 81 | else: 82 | if k == 'ppl': 83 | batch[k] = [-f[k] for f in features] 84 | else: 85 | batch[k] = torch.tensor([f[k] for f in features]) 86 | lens = batch['ppl'] 87 | print('######lens(cl_data_collator)#########') 88 | print(lens) 89 | total = total + 1 90 | try: 91 | assert lens == sorted(lens) 92 | except: 93 | unsorted = unsorted + 1 94 | print('not sorted') 95 | 96 | # Compare between currect batch and previous one 97 | # Append min of current batch and placeholder 98 | batch_comparison_values.append([lens[0], None]) 99 | 100 | if max_value_of_previous_batch is not None: 101 | # Append max of the previous batch 102 | batch_comparison_values[-1][1] = max_value_of_previous_batch 103 | 104 | max_value_of_previous_batch = lens[-1] 105 | 106 | return {'input_ids':batch['input_ids'],'labels': batch['labels']} 107 | 108 | 109 | def create_random_dataframe(sequence_column_name = 'sequence', 110 | curriculum_learning_column_name = 'ppl', 111 | curriculum_learning_strategy = 'ppl', 112 | max_sequence_length = 30, 113 | max_perplexity = 100.0, 114 | n = 5000): 115 | assert max_sequence_length > 2 116 | random.seed(42) 117 | df = pd.DataFrame() 118 | def create_sequence(length): 119 | seq = ''.join(random.choice('ACDEFGHIKLMNPQRSTVWY') for _ in range(length)) 120 | return seq 121 | 122 | if curriculum_learning_strategy == 'ppl': 123 | df[sequence_column_name] = [create_sequence(random.randint(2, max_sequence_length)) for i in range(n)] 124 | df[curriculum_learning_column_name] = [random.uniform(1.0, max_perplexity) for _ in range(n)] 125 | return df 126 | 127 | @pytest.mark.parametrize("strategy",strategies) 128 | def test_curriculum_learning(strategy): 129 | 130 | with open(CONFIG_PATH, "r") as cf: 131 | print('loading file.....') 132 | config_dict = yaml.safe_load(cf) 133 | 134 | config_dict['dataset']['max_sequence_length'] = 40 135 | config_dict['dataset']['do_curriculum_learning'] = True 136 | config_dict['dataset']['curriculum_learning_column_name'] = strategy2col[strategy] 137 | config_dict['dataset']['curriculum_learning_strategy'] = strategy 138 | config_dict['dataset']['val_size'] = 100 139 | config_dict['dataset']['test_size'] = 100 140 | config_dict['dataset']['subsample_size'] = 500 141 | config_dict["training_arguments"]['group_by_length'] = True 142 | config_dict["training_arguments"]['length_column_name'] = config_dict['dataset']['curriculum_learning_column_name'] 143 | config_dict["training_arguments"]['remove_unused_columns'] = False # this is necessary to keep curriculum_learning_column_name 144 | config_dict["training_arguments"]['per_device_train_batch_size'] = 20 145 | config_dict["training_arguments"]['max_steps'] = -1 146 | config_dict["training_arguments"]['num_train_epochs'] = 2 147 | 148 | print(config_dict) 149 | 150 | tokenizer = get_tokenizer(config_dict=config_dict["tokenizer"]) 151 | dataset = DatasetDict() 152 | val_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']['curriculum_learning_column_name'],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['val_size'] ) 153 | test_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = 
config_dict['dataset']["curriculum_learning_column_name"],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['test_size'] ) 154 | train_df = create_random_dataframe(sequence_column_name = config_dict['dataset']['sequence_column_name'],curriculum_learning_column_name = config_dict['dataset']["curriculum_learning_column_name"],max_sequence_length = config_dict['dataset']['max_sequence_length'], n = config_dict['dataset']['subsample_size'] ) 155 | 156 | dataset['train'] = Dataset.from_pandas(train_df) 157 | dataset['val'] = Dataset.from_pandas(val_df) 158 | dataset['test'] = Dataset.from_pandas(test_df) 159 | dataset = dataset.map( 160 | lambda e: set_input_ids( 161 | result=e, 162 | tokenizer=tokenizer, 163 | sequence_column_name=config_dict['dataset']['sequence_column_name'], 164 | max_sequence_length=config_dict['dataset']['max_sequence_length'], 165 | ), 166 | batched=True, 167 | ) 168 | dataset = dataset.map(set_labels, batched=True) 169 | dataset = dataset.map(lambda e: batch_set_curriculum_learning_column( 170 | result = e, 171 | input_column_name = config_dict['dataset']['sequence_column_name'], 172 | curriculum_learning_column_name = config_dict['dataset']['curriculum_learning_column_name'], 173 | strategy = config_dict['dataset']['curriculum_learning_strategy'] 174 | 175 | ),batched=True) 176 | dataset = dataset.select_columns(['input_ids', 'labels', strategy2col[strategy]]) 177 | model = get_model( 178 | config_dict=config_dict["model"], 179 | ) 180 | 181 | training_args = get_training_args( 182 | config_dict=config_dict["training_arguments"], 183 | ) 184 | 185 | trainer = Trainer( 186 | model=model, 187 | args=training_args, 188 | train_dataset=dataset["train"], 189 | eval_dataset=dataset.get("val", None), 190 | data_collator=cl_data_collator, 191 | ) 192 | 193 | trainer.train() 194 | 195 | threshold = 10 196 | num = 0 197 | # Iterate over the list 198 | print(batch_comparison_values) 199 | for i in batch_comparison_values: 200 | print(i) 201 | current_min_val, previous_max_val = i 202 | if previous_max_val is not None: 203 | if current_min_val < previous_max_val and previous_max_val - current_min_val <= threshold: 204 | num += 1 205 | assert num == 0 206 | percentage_unsorted = int((unsorted / total) * 100) #computing the number of times the list in collator was not sorted 207 | #there are sometimes cases where the list is off by a few entries aa the LengthGroupedSampler has a bit of randomness 208 | print(f'percentage_unsorted:{percentage_unsorted}') 209 | assert percentage_unsorted < 10 # just a rough heuristic -------------------------------------------------------------------------------- /protein_lm/tests/test_contact_prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import esm 3 | from protein_lm.evaluation.scripts.contact_prediction import predict_contacts_regression,predict_contacts_jacobian 4 | from protein_lm.evaluation.scripts.utils import * 5 | from protein_lm.tokenizer.tokenizer import EsmTokenizer 6 | import pytest 7 | import os 8 | 9 | proteins = ["1a3a", "5ahw", "1xcr"] 10 | @pytest.mark.parametrize("protein",proteins) 11 | def test_contact_predictions_regression(protein): 12 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 13 | msa=read_msa("test_data/"+f"{protein.lower()}_1_A.a3m") 14 | seq=msa[0] 15 | model, _ = esm.pretrained.esm2_t33_650M_UR50D() 16 | tokenizer = EsmTokenizer() 17 | model.to(device) 18 | 
prediction=predict_contacts_regression(model,seq,tokenizer,device) 19 | contact_path = os.path.join(os.path.dirname(__file__),'tensors',protein+'.pkl') 20 | 21 | contact= torch.load(contact_path) 22 | torch.testing.assert_close(prediction,contact) 23 | 24 | proteins = ["5ahw"] 25 | @pytest.mark.parametrize("protein",proteins) 26 | def test_contact_predictions_jacobian(protein): 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | msa=read_msa("test_data/"+f"{protein.lower()}_1_A.a3m") 29 | seq=msa[0][1] 30 | model, _= esm.pretrained.esm2_t33_650M_UR50D() 31 | model.to(device) 32 | tokenizer = EsmTokenizer() 33 | x,ln = tokenizer.batch_encode([seq],add_special_tokens=True),len(seq) 34 | x=torch.tensor(x) 35 | prediction=predict_contacts_jacobian("ESM",model,x,ln,device) 36 | contact_path = os.path.join(os.path.dirname(__file__),'tensors',protein+'_1_A_jacobian.pkl') 37 | 38 | contact= torch.load(contact_path) 39 | torch.testing.assert_close(prediction,contact) -------------------------------------------------------------------------------- /protein_lm/tests/test_encoding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import functools 4 | import pytest 5 | import torch 6 | from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding 7 | from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding 8 | from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor 9 | from protein_lm.modeling.utils.scaled_rope_embedding import LlamaLinearScalingRotaryEmbedding,LlamaDynamicNTKScalingRotaryEmbedding 10 | 11 | assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) 12 | encodings = ['rope','rerope','alibi','linear_rope_scaling','dynamic_rope_scaling'] 13 | @pytest.mark.parametrize("encoding",encodings) 14 | def test_encoding(encoding): 15 | if encoding == 'rope': 16 | head_dim = 64 17 | seq_len = 10 18 | heads = 12 19 | rot_emb = RotaryEmbedding(dim = head_dim) 20 | q = torch.zeros(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 21 | k = torch.zeros(1, heads, seq_len,head_dim) # keys 22 | 23 | qr,kr = rot_emb(q,k) 24 | assert_equal(q,qr) 25 | assert_equal(k,kr) 26 | q = torch.ones(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 27 | k = torch.ones(1, heads, seq_len,head_dim) # keys 28 | qr1,kr1 = rot_emb(q,k) 29 | rope_path = os.path.join(os.path.dirname(__file__),'tensors','rope.pkl') 30 | rope = torch.load(rope_path) 31 | qr2,kr2 = rope[0],rope[1] 32 | torch.testing.assert_close(qr1,qr2) 33 | torch.testing.assert_close(kr1,kr2) 34 | 35 | 36 | elif encoding == 'rerope': 37 | head_dim = 64 38 | seq_len = 10 39 | rot_emb = RectifiedRotaryEmbedding(dim = head_dim,max_position_embeddings=seq_len ) 40 | heads = 12 41 | position_ids = torch.arange(0,seq_len,dtype=torch.int32).unsqueeze(0) 42 | q = torch.zeros(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 43 | k = torch.zeros(1, heads, seq_len,head_dim) # keys 44 | qr,kr = rot_emb(q,k,seq_len=seq_len,position_ids = position_ids) 45 | assert_equal(q,qr) 46 | assert_equal(k,kr) 47 | q = torch.ones(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 48 | k = torch.ones(1, heads, seq_len,head_dim) # keys 49 | qr1,kr1= rot_emb(q,k,seq_len=seq_len,position_ids = position_ids) 50 | rerope_path = os.path.join(os.path.dirname(__file__),'tensors','rerope.pkl') 51 
| rerope = torch.load(rerope_path) 52 | qr2,kr2 = rerope[0],rerope[1] 53 | torch.testing.assert_close(qr1,qr2) 54 | torch.testing.assert_close(kr1,kr2) 55 | 56 | elif encoding == 'linear_rope_scaling': 57 | head_dim = 64 58 | seq_len = 10 59 | scaling_factor=1.0 60 | rope_theta=10000 61 | rot_emb = LlamaLinearScalingRotaryEmbedding(dim=head_dim,max_position_embeddings=seq_len,scaling_factor=scaling_factor,base=rope_theta) 62 | heads = 12 63 | position_ids = torch.arange(0,seq_len,dtype=torch.int32).unsqueeze(0) 64 | q = torch.zeros(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 65 | k = torch.zeros(1, heads, seq_len,head_dim) # keys 66 | qr,kr = rot_emb(q,k,seq_len=seq_len,position_ids = position_ids) 67 | assert_equal(q,qr) 68 | assert_equal(k,kr) 69 | q = torch.ones(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 70 | k = torch.ones(1, heads, seq_len,head_dim) # keys 71 | qr1,kr1= rot_emb(q,k,seq_len=seq_len,position_ids = position_ids) 72 | rerope_path = os.path.join(os.path.dirname(__file__),'tensors','linear_rope.pkl') 73 | rerope = torch.load(rerope_path) 74 | qr2,kr2 = rerope[0],rerope[1] 75 | torch.testing.assert_close(qr1,qr2) 76 | torch.testing.assert_close(kr1,kr2) 77 | 78 | elif encoding == "dynamic_rope_scaling": 79 | head_dim = 64 80 | seq_len = 10 81 | scaling_factor=1.0 82 | rope_theta=10000 83 | rot_emb = LlamaDynamicNTKScalingRotaryEmbedding(dim=head_dim,max_position_embeddings=seq_len,scaling_factor=scaling_factor,base=rope_theta) 84 | heads = 12 85 | position_ids = torch.arange(0,seq_len,dtype=torch.int32).unsqueeze(0) 86 | q = torch.zeros(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 87 | k = torch.zeros(1, heads, seq_len,head_dim) # keys 88 | qr,kr = rot_emb(q,k,seq_len=seq_len,position_ids = position_ids) 89 | assert_equal(q,qr) 90 | assert_equal(k,kr) 91 | q = torch.ones(1, heads, seq_len, head_dim) # queries - (batch, heads, seq len, dimension of head) 92 | k = torch.ones(1, heads, seq_len,head_dim) # keys 93 | qr1,kr1= rot_emb(q,k,seq_len=seq_len,position_ids = position_ids) 94 | rerope_path = os.path.join(os.path.dirname(__file__),'tensors','dynamic_rope.pkl') 95 | rerope = torch.load(rerope_path) 96 | qr2,kr2 = rerope[0],rerope[1] 97 | torch.testing.assert_close(qr1,qr2) 98 | torch.testing.assert_close(kr1,kr2) 99 | 100 | elif encoding == 'alibi': 101 | heads = 12 102 | maxpos = 10 103 | batch_size = 1 104 | def build_alibi_tensor(max_seq_len, num_attention_heads, batch_size): 105 | #adpated from https://github.com/bigscience-workshop/bigscience/ 106 | # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 107 | """Returns tensor shaped (batch_size * num_attention_heads, 1, max_seq_len)""" 108 | def get_slopes(n): 109 | def get_slopes_power_of_2(n): 110 | start = (2 ** (-2 ** -(math.log2(n) - 3))) 111 | ratio = start 112 | return [start * ratio ** i for i in range(n)] 113 | 114 | if math.log2(n).is_integer(): 115 | return get_slopes_power_of_2(n) 116 | else: 117 | closest_power_of_2 = 2 ** math.floor(math.log2(n)) 118 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][ 119 | :n - closest_power_of_2] 120 | slopes = torch.Tensor(get_slopes(num_attention_heads)) 121 | alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(num_attention_heads, -1, -1) 122 | alibi = 
alibi.repeat(batch_size, 1, 1) 123 | return alibi 124 | alibi1 = create_alibi_tensor(heads,maxpos) 125 | alibi2 = build_alibi_tensor(maxpos, heads, batch_size) 126 | torch.testing.assert_close(alibi1,alibi2) 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /protein_lm/tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from protein_lm.tokenizer import EsmTokenizer, AptTokenizer 4 | 5 | # Test parameters 6 | TOKENIZERS = [EsmTokenizer(), AptTokenizer()] 7 | 8 | # 1. Basic Encoding and Decoding 9 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 10 | def test_basic_encoding_decoding(tokenizer): 11 | sequence = "LAGERT" 12 | encoded = tokenizer.encode(sequence) 13 | decoded = tokenizer.decode(encoded) 14 | assert decoded == sequence 15 | 16 | # 2. Special Tokens Handling 17 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 18 | def test_special_tokens(tokenizer): 19 | sequence = "LAGERT" 20 | encoded = tokenizer.encode(sequence, add_special_tokens=True) 21 | assert encoded[0] == tokenizer.ids_to_tokens.index("") 22 | assert encoded[-1] == tokenizer.ids_to_tokens.index("") 23 | assert len(encoded) == len(sequence) + 2 24 | 25 | # 3. Max Sequence Length Handling 26 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 27 | def test_max_sequence_length(tokenizer): 28 | sequence = "LAGERT" 29 | max_length = 3 30 | encoded = tokenizer.encode(sequence, max_sequence_length=max_length) 31 | assert len(encoded) == max_length 32 | 33 | # 4. Returning Tensors 34 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 35 | def test_return_tensors(tokenizer): 36 | sequence = "LAGERT" 37 | encoded_tensor = tokenizer.encode(sequence, return_tensor=True) 38 | assert isinstance(encoded_tensor, torch.Tensor) 39 | 40 | # 5. Encoding with Special Tokens 41 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 42 | def test_encoding_special_tokens(tokenizer): 43 | sequence = "LAGERT" 44 | encoded = tokenizer.encode(sequence, add_special_tokens=True) 45 | assert encoded[0] == tokenizer.ids_to_tokens.index("") 46 | assert encoded[-1] == tokenizer.ids_to_tokens.index("") 47 | assert len(encoded) == len(sequence) + 2 48 | 49 | # 6. Encoding with Max Sequence Length 50 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 51 | def test_encoding_max_length(tokenizer): 52 | sequence = "LAGERT" 53 | max_length = 3 54 | encoded = tokenizer.encode(sequence, max_sequence_length=max_length) 55 | assert len(encoded) == max_length 56 | 57 | # 7. Encoding Returning Tensors 58 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 59 | def test_encoding_return_tensors(tokenizer): 60 | sequence = "LAGERT" 61 | encoded_tensor = tokenizer.encode(sequence, return_tensor=True) 62 | assert isinstance(encoded_tensor, torch.Tensor) 63 | 64 | # 8. Encoding with Special Tokens and Max Length 65 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 66 | def test_encoding_special_tokens_max_length(tokenizer): 67 | sequence = "LAGERT" 68 | max_length = 3 69 | encoded = tokenizer.encode(sequence, add_special_tokens=True, max_sequence_length=max_length) 70 | assert len(encoded) == max_length 71 | 72 | # 9. 
Encoding with Special Tokens and Tensors 73 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 74 | def test_encoding_special_tokens_tensors(tokenizer): 75 | sequence = "LAGERT" 76 | encoded_tensor = tokenizer.encode(sequence, add_special_tokens=True, return_tensor=True) 77 | assert isinstance(encoded_tensor, torch.Tensor) 78 | assert encoded_tensor[0] == tokenizer.ids_to_tokens.index("") 79 | assert encoded_tensor[-1] == tokenizer.ids_to_tokens.index("") 80 | 81 | # 10. Encoding with Max Length and Tensors 82 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 83 | def test_encoding_max_length_tensors(tokenizer): 84 | sequence = "LAGERT" 85 | max_length = 3 86 | encoded_tensor = tokenizer.encode(sequence, max_sequence_length=max_length, return_tensor=True) 87 | assert isinstance(encoded_tensor, torch.Tensor) 88 | assert len(encoded_tensor) == max_length 89 | 90 | # 11. Encoding with Special Tokens, Max Length, and Tensors 91 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 92 | def test_encoding_all_combinations(tokenizer): 93 | sequence = "LAGERT" 94 | max_length = 3 95 | encoded_tensor = tokenizer.encode(sequence, add_special_tokens=True, max_sequence_length=max_length, return_tensor=True) 96 | assert isinstance(encoded_tensor, torch.Tensor) 97 | assert len(encoded_tensor) == max_length 98 | 99 | 100 | # 12. Batch Encoding 101 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 102 | def test_batch_encoding(tokenizer): 103 | sequences = ["LAGERT", "SERPK"] 104 | batch_encoded = tokenizer.batch_encode(sequences) 105 | assert len(batch_encoded) == len(sequences) 106 | 107 | # 13. Handling of Unknown Tokens 108 | def test_unknown_tokens(): 109 | tokenizer = AptTokenizer() # Assuming unknown tokens will be handled as 110 | sequence = "XLAGERT" # "X" is not in the token list 111 | encoded = tokenizer.encode(sequence) 112 | assert encoded[0] == tokenizer.ids_to_tokens.index("") 113 | 114 | # 14.Test Batch Encoding with Special Tokens 115 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 116 | def test_batch_encoding_special_tokens(tokenizer): 117 | sequences = ["LAGERT", "SERPK"] 118 | batch_encoded = tokenizer.batch_encode(sequences, add_special_tokens=True) 119 | for encoded in batch_encoded: 120 | assert encoded[0] == tokenizer.ids_to_tokens.index("") 121 | assert encoded[-1] == tokenizer.ids_to_tokens.index("") 122 | 123 | # 15.Test Batch Encoding with Max Sequence Length 124 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 125 | def test_batch_encoding_max_length(tokenizer): 126 | sequences = ["LAGERT", "SERPK"] 127 | max_length = 3 128 | batch_encoded = tokenizer.batch_encode(sequences, max_sequence_length=max_length) 129 | for encoded in batch_encoded: 130 | assert len(encoded) == max_length 131 | 132 | # 16.Test Batch Encoding Returning Tensors 133 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 134 | def test_batch_encoding_return_tensors(tokenizer): 135 | sequences = ["LAGERT", "SERPK"] 136 | batch_encoded = tokenizer.batch_encode(sequences, return_tensors=True) 137 | assert isinstance(batch_encoded, torch.Tensor) 138 | 139 | # 17.Test Batch Encoding with Special Tokens, Max Length, and Tensors 140 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 141 | def test_batch_encoding_all_combinations(tokenizer): 142 | sequences = ["LAGERT", "SERPK"] 143 | max_length = 5 144 | batch_encoded = tokenizer.batch_encode( 145 | sequences, 146 | add_special_tokens=True, 147 | return_tensors=True, 148 | max_sequence_length=max_length 149 | ) 150 | assert isinstance(batch_encoded, 
torch.Tensor) 151 | assert batch_encoded.size(1) == max_length 152 | 153 | # 18.Test Batch Encoding with Empty List 154 | @pytest.mark.parametrize("tokenizer", TOKENIZERS) 155 | def test_batch_encoding_empty_list(tokenizer): 156 | sequences = [] 157 | batch_encoded = tokenizer.batch_encode(sequences) 158 | assert batch_encoded == [] 159 | -------------------------------------------------------------------------------- /protein_lm/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizer import EsmTokenizer, AptTokenizer 2 | -------------------------------------------------------------------------------- /protein_lm/tokenizer/rust_trie/.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by maturin v1.2.0 2 | # To update, run 3 | # 4 | # maturin generate-ci github 5 | # 6 | name: CI 7 | 8 | on: 9 | push: 10 | branches: 11 | - main 12 | - master 13 | tags: 14 | - '*' 15 | pull_request: 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | linux: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | matrix: 26 | target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] 27 | steps: 28 | - uses: actions/checkout@v3 29 | - uses: actions/setup-python@v4 30 | with: 31 | python-version: '3.10' 32 | - name: Build wheels 33 | uses: PyO3/maturin-action@v1 34 | with: 35 | target: ${{ matrix.target }} 36 | args: --release --out dist --find-interpreter 37 | sccache: 'true' 38 | manylinux: auto 39 | - name: Upload wheels 40 | uses: actions/upload-artifact@v3 41 | with: 42 | name: wheels 43 | path: dist 44 | 45 | windows: 46 | runs-on: windows-latest 47 | strategy: 48 | matrix: 49 | target: [x64, x86] 50 | steps: 51 | - uses: actions/checkout@v3 52 | - uses: actions/setup-python@v4 53 | with: 54 | python-version: '3.10' 55 | architecture: ${{ matrix.target }} 56 | - name: Build wheels 57 | uses: PyO3/maturin-action@v1 58 | with: 59 | target: ${{ matrix.target }} 60 | args: --release --out dist --find-interpreter 61 | sccache: 'true' 62 | - name: Upload wheels 63 | uses: actions/upload-artifact@v3 64 | with: 65 | name: wheels 66 | path: dist 67 | 68 | macos: 69 | runs-on: macos-latest 70 | strategy: 71 | matrix: 72 | target: [x86_64, aarch64] 73 | steps: 74 | - uses: actions/checkout@v3 75 | - uses: actions/setup-python@v4 76 | with: 77 | python-version: '3.10' 78 | - name: Build wheels 79 | uses: PyO3/maturin-action@v1 80 | with: 81 | target: ${{ matrix.target }} 82 | args: --release --out dist --find-interpreter 83 | sccache: 'true' 84 | - name: Upload wheels 85 | uses: actions/upload-artifact@v3 86 | with: 87 | name: wheels 88 | path: dist 89 | 90 | sdist: 91 | runs-on: ubuntu-latest 92 | steps: 93 | - uses: actions/checkout@v3 94 | - name: Build sdist 95 | uses: PyO3/maturin-action@v1 96 | with: 97 | command: sdist 98 | args: --out dist 99 | - name: Upload sdist 100 | uses: actions/upload-artifact@v3 101 | with: 102 | name: wheels 103 | path: dist 104 | 105 | release: 106 | name: Release 107 | runs-on: ubuntu-latest 108 | if: "startsWith(github.ref, 'refs/tags/')" 109 | needs: [linux, windows, macos, sdist] 110 | steps: 111 | - uses: actions/download-artifact@v3 112 | with: 113 | name: wheels 114 | - name: Publish to PyPI 115 | uses: PyO3/maturin-action@v1 116 | env: 117 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 118 | with: 119 | command: upload 120 | args: --skip-existing * 121 | 
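Stepping back from the CI workflow for a moment: the tokenizer behaviour exercised by test_tokenizer.py above can be condensed into a short usage sketch. This assumes the protein_lm package and its Rust trie extension are already built and importable; the sequences are arbitrary examples.

# Usage sketch distilled from test_tokenizer.py (arbitrary example sequences).
import torch

from protein_lm.tokenizer import AptTokenizer

tokenizer = AptTokenizer()

ids = tokenizer.encode("LAGERT")                    # list of token ids
assert tokenizer.decode(ids) == "LAGERT"            # round-trips for standard residues

with_special = tokenizer.encode("LAGERT", add_special_tokens=True)
assert len(with_special) == len("LAGERT") + 2       # start/end special tokens added

truncated = tokenizer.encode("LAGERT", max_sequence_length=3, return_tensor=True)
assert isinstance(truncated, torch.Tensor) and len(truncated) == 3

batch = tokenizer.batch_encode(
    ["LAGERT", "SERPK"],
    add_special_tokens=True,
    max_sequence_length=5,
    return_tensors=True,
)
assert isinstance(batch, torch.Tensor) and batch.size(1) == 5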
-------------------------------------------------------------------------------- /protein_lm/tokenizer/rust_trie/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | include/ 26 | man/ 27 | venv/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | pip-selfcheck.json 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version -------------------------------------------------------------------------------- /protein_lm/tokenizer/rust_trie/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust_trie" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | [lib] 8 | name = "rust_trie" 9 | crate-type = ["cdylib"] 10 | 11 | [dependencies] 12 | pyo3 = "0.19.0" 13 | -------------------------------------------------------------------------------- /protein_lm/tokenizer/rust_trie/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.2,<2.0", "setuptools"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "rust_trie" 7 | requires-python = ">=3.7" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python :: Implementation :: CPython", 11 | "Programming Language :: Python :: Implementation :: PyPy", 12 | ] 13 | 14 | 15 | [tool.maturin] 16 | features = ["pyo3/extension-module"] 17 | -------------------------------------------------------------------------------- /protein_lm/tokenizer/rust_trie/src/lib.rs: -------------------------------------------------------------------------------- 1 | use pyo3::prelude::*; 2 | use std::collections::HashMap; 3 | 4 | struct TrieNode { 5 | children: HashMap<char, TrieNode>, 6 | token_id: Option<usize>, 7 | } 8 | 9 | impl TrieNode { 10 | fn new() -> Self { 11 | TrieNode { 12 | children: HashMap::new(), 13 | token_id: None, 14 | } 15 | } 16 | } 17 | 18 | #[pyclass] 19 | pub struct Trie { 20 | root: TrieNode, 21 | next_id: usize, // for assigning unique IDs to tokens 22 | unk_token_set: bool, 23 | unk_token_id: usize, 24 | } 25 | 26 | // A Trie: See https://en.wikipedia.org/wiki/Trie 27 | // This is a data structure that allows for tokenizing a stream of text 28 | // by matching complete (possibly multi-character) tokens in a single left-to-right pass. 29 | // 30 | // To explain how this works, let's first consider how we add new tokens. 31 | // Let's say we have four possible tokens: 'A', 'B', 'AA', 'AB' 32 | // The trie always has an empty root node. There will be two children: 33 | // the node 'A' with token_id 0 and 'B' with token_id 1. 
The node 'B' has 34 | // no children since we have no tokens that start with 'B' and continue to 35 | // another character. 36 | // The node 'A' has two children, one other node with 'A' with token_id of 2 37 | // and a node 'B' with token_id of 3. 38 | // Now, to tokenize a string, we start at the root and walk down the trie one 39 | // character at a time until we reach a node that completes a token (i.e. has 40 | // a token_id); that token is emitted and we continue from the next character. 41 | // With the tokens above, the node 'A' already carries token_id 0, so 42 | // 'ABAABA' -> [0, 1, 0, 0, 1, 0]: the shorter token wins whenever one token 43 | // is a prefix of another. That never happens in the protein vocabularies used 44 | // here (special tokens such as '<cls>' share no prefix with the single-letter 45 | // amino acid tokens), so multi-character special tokens are always matched in full. 46 | #[pymethods] 47 | impl Trie { 48 | #[new] 49 | pub fn new(unk_token_id: Option<usize>) -> Self { 50 | Trie { 51 | root: TrieNode::new(), 52 | next_id: 0, // We start the IDs at 0 53 | unk_token_set: unk_token_id.is_some(), 54 | unk_token_id: unk_token_id.unwrap_or(0), 55 | } 56 | } 57 | 58 | // Function responsible for figuring out the tree structure. 59 | // Children are represented as hash maps to make the search simpler. 60 | // In fact, for our purposes, where the number of children is small, 61 | // it would probably be faster to use lists. 62 | pub fn add(&mut self, word: &str) { 63 | let mut node = &mut self.root; 64 | for ch in word.chars() { 65 | node = node.children.entry(ch).or_insert_with(TrieNode::new); 66 | } 67 | if node.token_id.is_none() { 68 | node.token_id = Some(self.next_id); 69 | self.next_id += 1; 70 | if !self.unk_token_set { 71 | self.unk_token_id = self.next_id; 72 | } 73 | } 74 | } 75 | 76 | // Tokenizing function. Does what is described in the comment above: 77 | // we keep consuming characters until we either complete a token or 78 | // fall off the trie (no matching child). 
79 | pub fn tokenize(&self, text: &str) -> Vec<usize> { 80 | let mut tokens = vec![]; 81 | let mut start = 0; 82 | 83 | while start < text.len() { 84 | let mut node = &self.root; 85 | let mut matched = false; 86 | let mut end = start; 87 | for ch in text[start..].chars() { 88 | if let Some(next_node) = node.children.get(&ch) { 89 | // If the character matches a child, we go to the next node 90 | node = next_node; 91 | end += ch.len_utf8(); 92 | if node.token_id.is_some() { // This node completes a token, so we take it 93 | matched = true; 94 | break; 95 | } 96 | } else { // No child matches this character, so it will become an '<unk>' token 97 | break; 98 | } 99 | } 100 | 101 | if matched { 102 | tokens.push(node.token_id.unwrap()); 103 | start = end; 104 | } else { 105 | tokens.push(self.unk_token_id); // Assign unknown token ID 106 | start += text[start..].chars().next().unwrap().len_utf8(); 107 | } 108 | } 109 | 110 | tokens 111 | } 112 | } 113 | 114 | #[pymodule] 115 | fn rust_trie(_py: Python, m: &PyModule) -> PyResult<()> { 116 | m.add_class::<Trie>()?; 117 | Ok(()) 118 | } 119 | 120 | // These unit tests exercise the Rust Trie directly; the Python-facing behaviour 121 | // is covered by the tests under protein_lm/tests. 122 | #[cfg(test)] 123 | mod tests { 124 | use super::*; 125 | 126 | #[test] 127 | fn test_trie_special_token() { 128 | let mut trie = Trie::new(None); 129 | trie.add("[CLS]"); 130 | let tokens = trie.tokenize("[CLS] This is a test"); 131 | // "[CLS]" is token 0; each remaining character maps to the unknown id, 132 | // which is 1 because no explicit unk id was supplied. 133 | assert_eq!(tokens[0], 0); 134 | assert!(tokens[1..].iter().all(|&t| t == 1)); 135 | assert_eq!(tokens.len(), 1 + " This is a test".chars().count()); 136 | } 137 | 138 | #[test] 139 | fn test_trie_unk_token() { 140 | let mut trie = Trie::new(None); 141 | trie.add("<unk>"); 142 | let tokens = trie.tokenize("<unk> This is a test"); 143 | assert_eq!(tokens[0], 0); 144 | assert!(tokens[1..].iter().all(|&t| t == 1)); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /protein_lm/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Union, Optional 3 | 4 | from rust_trie import Trie 5 | 6 | 7 | class Tokenizer: 8 | def __init__(self, tokens: List[str], unk_token_id: Optional[int] = None): 9 | self.ids_to_tokens = tokens 10 | self.trie = Trie(unk_token_id) 11 | for token in tokens: 12 | self.trie.add(token) 13 | # If unk_token_id is not provided, add <unk> to the end of the tokens list 14 | if unk_token_id is None: 15 | self.ids_to_tokens += ["<unk>"] 16 | self.pad_token_id = self.ids_to_tokens.index("<pad>") 17 | self.mask_token_id = self.ids_to_tokens.index("<mask>") 18 | 19 | 20 | def __call__(self, sequences: Union[str, List], *args, **kwargs): 21 | if isinstance(sequences, str): 22 | return self.encode(sequences, *args, **kwargs) 23 | else: 24 | return self.batch_encode(sequences, *args, **kwargs) 25 | 26 | def encode( 27 | self, 28 | sequence: str, 29 | add_special_tokens: 
bool = False, 30 | return_tensor: bool = False, 31 | max_sequence_length: Optional[int] = None, 32 | ) -> List[int]: 33 | if max_sequence_length is not None: 34 | if add_special_tokens: 35 | max_sequence_length -= 2 36 | sequence = sequence[:max_sequence_length] 37 | if add_special_tokens: 38 | sequence = "<cls>" + sequence + "<eos>" 39 | output = self.trie.tokenize(sequence) 40 | if return_tensor: 41 | output = torch.tensor(output, dtype=torch.long) 42 | return output 43 | 44 | def batch_encode( 45 | self, 46 | sequences: List[str], 47 | add_special_tokens: bool = False, 48 | return_tensors: bool = False, 49 | max_sequence_length: Optional[int] = None, 50 | ) -> List[List[int]]: 51 | output = [] 52 | if max_sequence_length is None and return_tensors: 53 | max_sequence_length = max([len(sequence) for sequence in sequences]) 54 | if add_special_tokens: 55 | max_sequence_length += 2 56 | if max_sequence_length is not None: 57 | sequences = [ 58 | sequence[:(max_sequence_length - 2) if add_special_tokens else max_sequence_length] 59 | for sequence in sequences 60 | ] 61 | for sequence in sequences: 62 | output.append(self.encode(sequence, add_special_tokens, return_tensors)) 63 | if return_tensors: 64 | tensor_out = torch.full((len(output), max_sequence_length), self.pad_token_id) 65 | for i, sequence in enumerate(output): 66 | tensor_out[i, :len(sequence)] = sequence 67 | output = tensor_out 68 | return output 69 | 70 | def decode(self, tokens: List[int]) -> str: 71 | return "".join([self.ids_to_tokens[idx] for idx in tokens]) 72 | 73 | 74 | 75 | class EsmTokenizer(Tokenizer): 76 | def __init__(self): 77 | tokens = [ 78 | "<cls>", "<pad>", "<eos>", "<unk>", "L", "A", "G", 79 | "V", "S", "E", "R", "T", "I", "D", "P", "K", "Q", 80 | "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", 81 | "Z", "O", ".", "-", "<null_1>", "<mask>" 82 | ] 83 | super().__init__(tokens, unk_token_id=3) 84 | 85 | 86 | 87 | class AptTokenizer(Tokenizer): 88 | def __init__(self): 89 | # For our own tokenizers, we don't need to explicitly add the <unk> token 90 | # because it gets added as the last token in the tokens list 91 | # I've also removed X so that it gets translated to <unk> 92 | tokens = [ 93 | "<cls>", "<pad>", "<eos>", "L", "A", "G", "V", 94 | "S", "E", "R", "T", "I", "D", "P", "K", "Q", "N", 95 | "F", "Y", "M", "H", "W", "C", "B", "U", "Z", "O", 96 | "<mask>" 97 | ] 98 | super().__init__(tokens) -------------------------------------------------------------------------------- /protein_lm_cuda.yml: -------------------------------------------------------------------------------- 1 | name: protein_lm_cuda 2 | channels: 3 | - pytorch 4 | - huggingface 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - python>=3.8 10 | - numpy 11 | - scipy 12 | - cudatoolkit=11.7 13 | - pytorch-cuda=11.7 14 | - pydantic>=2.0 15 | - rust 16 | - pip: 17 | - transformers 18 | - datasets 19 | - accelerate 20 | - evaluate 21 | - pytest 22 | - fair-esm 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='protein_lm', version='1.0', packages=find_packages()) --------------------------------------------------------------------------------
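
The trie behaviour documented in `lib.rs` can also be exercised directly through the compiled `rust_trie` binding. A minimal sketch (assuming the extension has been built and installed; the ids shown follow from the insertion order of the tokens in this example):

```
from rust_trie import Trie

trie = Trie(None)  # no explicit unknown-token id
for token in ["<cls>", "<pad>", "<eos>", "L", "A"]:
    trie.add(token)

# "<cls>" is matched as one token; ids follow insertion order (0, 1, 2, ...).
print(trie.tokenize("<cls>LA"))  # [0, 3, 4]

# A character that reaches no token node maps to the unknown id, which here is
# 5 (the number of added tokens) because no explicit unk id was supplied.
print(trie.tokenize("X"))  # [5]
```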
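
The `Tokenizer` classes in `tokenizer.py` are the user-facing API built on top of that trie. A minimal usage sketch (assuming `protein_lm` and the `rust_trie` extension are installed; the exact ids depend on the `AptTokenizer` token order above):

```
import torch

from protein_lm.tokenizer import AptTokenizer

tokenizer = AptTokenizer()

# Single sequence with special tokens: ids start with <cls> and end with <eos>.
ids = tokenizer.encode("LAGERT", add_special_tokens=True)
print(ids)                    # e.g. [0, 3, 4, 5, 8, 9, 10, 2]
print(tokenizer.decode(ids))  # "<cls>LAGERT<eos>"

# A batch padded with <pad> to a common length and returned as a tensor.
batch = tokenizer.batch_encode(
    ["LAGERT", "SERPK"],
    add_special_tokens=True,
    return_tensors=True,
    max_sequence_length=10,
)
assert isinstance(batch, torch.Tensor) and batch.shape == (2, 10)
```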