├── .gitignore ├── LICENSE ├── README.md ├── config ├── parallel_config.yaml └── training_config.yaml ├── confit ├── data_utils.py ├── inference.py ├── stat_utils.py └── train.py ├── requirements.txt └── scripts ├── download.sh └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Luo Group 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConFit 2 | 3 | - [ConFit](#confit) 4 | * [Overview](#overview) 5 | * [Dependencies](#dependencies) 6 | * [Data](#data) 7 | + [Source Data](#source-data) 8 | + [Running on customized data](#running-on-customized-data) 9 | * [Train ConFit](#train-confit) 10 | + [customizing config files](#customizing-config-files) 11 | 12 | ## Overview 13 | 14 | ConFit is a pLM-based ML method for learning the protein fitness landscape with limited experimental fitness measurements as training data. It uses a contrastive learning strategy to fine-tune the pre-trained pLM, tailoring it to achieve protein-specific fitness prediction while avoiding overfitting. 15 | 16 | ## Dependencies 17 | 18 | This code is based on Python 3.9.18 and PyTorch 2.1.0 with CUDA 12.2. Please first install the correct PyTorch version and then install the required packages as follows: 19 | 20 | ``` 21 | conda install pytorch torchvision torchaudio pytorch-cuda=12.2 -c pytorch -c nvidia 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## Data 26 | 27 | ### Source Data 28 | 29 | To use the source data from the paper, please run the script `scripts/download.sh` to download it. 30 | 31 | ### Running on customized data 32 | 33 | To run our model on customized data, please follow the steps below (a minimal example of preparing these files is sketched further down, after the training arguments): 34 | 35 | 1. Collect the DMS dataset and put it in `data/$dataset/data.csv`; the CSV file should have the following necessary columns: 36 | 37 | **seq**: the mutant sequence. 38 | 39 | **log_fitness**: the ground-truth fitness value for the sequence. 40 | 41 | **mutated_position**: the mutated position(s) in the assay; please note that the position numbering should start from 0. 42 | 43 | **PID**: a unique number for each sequence, which can be generated by auto-increment. 44 | 45 | 2. Generate the FASTA file for the wild-type sequence and put it in `data/$dataset/wt.fasta`. 46 | 47 | 3. If you want to use retrieval augmentation, please follow the DeepSequence repo (https://github.com/debbiemarkslab/DeepSequence) to generate the ELBO for each mutant assay. Please put the predicted ELBOs in `data/$dataset/vae_elbo.csv`, which contains the following columns: 48 | 49 | **seq**: the mutant sequence. 50 | 51 | **PID**: a unique number for each sequence, which should be consistent with `data.csv`. 52 | 53 | **elbo**: the predicted ELBO value for each mutant assay. 54 | 55 | 56 | 57 | ## Train ConFit 58 | 59 | We provide an example of training ConFit in the `scripts/train.sh` file, which helps you quickly try our model. For example, the following script trains our model on the GB1_Olson2014_ddg dataset using 48 shots of training data: 60 | 61 | ``` 62 | accelerate launch --config_file config/parallel_config.yaml confit/train.py \ 63 | --config config/training_config.yaml \ 64 | --dataset GB1_Olson2014_ddg \ 65 | --sample_seed 0 \ 66 | --model_seed 1 67 | ``` 68 | 69 | `--config`: (required) specifies the file containing the training hyperparameters 70 | 71 | `--dataset`: (required) specifies the dataset name 72 | 73 | `--sample_seed`: (optional) specifies the random seed used when sampling the training and test data. 74 | 75 | `--model_seed`: (optional) specifies the initialization seed for the pretrained ESM-1v model; please choose a value from 1-5.
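For the customized-data setup described in the Data section above, here is a minimal, illustrative sketch of how the expected files might be prepared. The dataset name `my_protein`, the toy sequences, and all numeric values are placeholders rather than real assay data:

```
# Sketch only: replace the toy dataset name, sequences and values with your own assay.
import os
import pandas as pd

dataset = "my_protein"
os.makedirs(f"data/{dataset}", exist_ok=True)

wt = "MKTAYIAKQR"  # toy wild-type sequence

# data.csv: one row per mutant; mutated_position is 0-based and
# comma-separated for multi-site mutants (see confit/data_utils.py).
df = pd.DataFrame({
    "seq":              ["AKTAYIAKQR", "MKTAYIAKQW", "AKTAYIAKQW"],
    "log_fitness":      [0.12, -0.87, -0.35],
    "mutated_position": ["0", "9", "0,9"],
    "PID":              [0, 1, 2],  # unique, auto-incremented IDs
})
df.to_csv(f"data/{dataset}/data.csv")

# wt.fasta: the wild-type sequence used as the reference during scoring.
with open(f"data/{dataset}/wt.fasta", "w") as f:
    f.write(f">{dataset}\n{wt}\n")

# Optional vae_elbo.csv for retrieval augmentation (ELBOs computed with DeepSequence).
elbo = pd.DataFrame({"seq": df["seq"], "PID": df["PID"], "elbo": [-120.3, -130.1, -125.7]})
elbo.to_csv(f"data/{dataset}/vae_elbo.csv")
```

With these files in place, the training command above and the inference command below can be run with `--dataset my_protein`.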
76 | 77 | After training, please use the following script to conduct inference on the test set: 78 | 79 | ``` 80 | python confit/inference.py --dataset $dataset --shot $shot 81 | ``` 82 | 83 | `--dataset`: (required) specifies the dataset name 84 | 85 | `--shot`: (required) specifies the training size 86 | 87 | `--no_retrival`: (optional) disables the retrieval augmentation during inference (this is the flag name as defined in `confit/inference.py`) 88 | 89 | The test Spearman correlation will be written to `results/$dataset/summary.csv`. 90 | 91 | ### customizing config files 92 | 93 | **Training size**: For different training sizes, please modify **shot** in `config/training_config.yaml` and change the training hyperparameters in that file accordingly. 94 | 95 | **GPU:** We trained ConFit using 4 A40 GPUs. Depending on the number of GPUs you use, please modify **num_processes** and **gpu_number** in `config/parallel_config.yaml` and `config/training_config.yaml`, respectively. 96 | 97 | **PLM**: We utilized ESM-1v as our PLM to be fine-tuned. Similar protein language models can also be used. Please modify **model** in `config/training_config.yaml` to ESM-2 or ESM-1b to change the PLM. 98 | 99 | -------------------------------------------------------------------------------- /config/parallel_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: FSDP 3 | downcast_bf16: 'no' 4 | fsdp_config: 5 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 6 | fsdp_backward_prefetch_policy: BACKWARD_PRE 7 | fsdp_offload_params: true 8 | fsdp_sharding_strategy: 1 9 | fsdp_state_dict_type: FULL_STATE_DICT 10 | fsdp_transformer_layer_cls_to_wrap: EsmLayer 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'no' 14 | num_machines: 1 15 | num_processes: 4 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /config/training_config.yaml: -------------------------------------------------------------------------------- 1 | model: ESM-2 2 | gpu_number: 4 3 | shot: 48 4 | batch_size: 16 5 | max_epochs: 30 6 | ini_lr: 5e-4 7 | min_lr: 1e-4 8 | endure_time: 5 9 | lambda_reg: 0.1 10 | lora_r: 8 11 | lora_alpha: 8 12 | lora_dropout: 0.1 13 | 14 | -------------------------------------------------------------------------------- /confit/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | import pandas as pd 4 | from scipy.stats import spearmanr 5 | from scipy import stats 6 | from scipy.stats import bootstrap 7 | import numpy as np 8 | import os 9 | from Bio import SeqIO 10 | 11 | 12 | class Mutation_Set(Dataset): 13 | def __init__(self, data, fname, tokenizer, sep_len=1024): 14 | self.data = data 15 | self.tokenizer = tokenizer 16 | self.seq_len = sep_len 17 | self.seq, self.attention_mask = tokenizer(list(self.data['seq']), padding='max_length', 18 | truncation=True, 19 | max_length=self.seq_len).values() 20 | wt_path = os.path.join('data', fname, 'wt.fasta') 21 | for seq_record in SeqIO.parse(wt_path, "fasta"): 22 | wt = str(seq_record.seq) 23 | target = [wt]*len(self.data) 24 | self.target, self.tgt_mask = tokenizer(target, padding='max_length', truncation=True, 25 | max_length=self.seq_len).values() 26 | self.score = torch.tensor(np.array(self.data['log_fitness'])) 27 | self.pid =
np.asarray(data['PID']) 28 | 29 | if type(list(self.data['mutated_position'])[0]) != str: 30 | self.position = [[u] for u in self.data['mutated_position']] 31 | 32 | else: 33 | 34 | temp = [u.split(',') for u in self.data['mutated_position']] 35 | self.position = [] 36 | for u in temp: 37 | pos = [int(v) for v in u] 38 | self.position.append(pos) 39 | 40 | def __getitem__(self, idx): 41 | return [self.seq[idx], self.attention_mask[idx], self.target[idx],self.tgt_mask[idx] ,self.position[idx], self.score[idx], self.pid[idx]] 42 | 43 | def __len__(self): 44 | return len(self.score) 45 | 46 | def collate_fn(self, data): 47 | seq = torch.tensor(np.array([u[0] for u in data])) 48 | att_mask = torch.tensor(np.array([u[1] for u in data])) 49 | tgt = torch.tensor(np.array([u[2] for u in data])) 50 | tgt_mask = torch.tensor(np.array([u[3] for u in data])) 51 | pos = [torch.tensor(u[4]) for u in data] 52 | score = torch.tensor(np.array([u[5] for u in data]), dtype=torch.float32) 53 | pid = torch.tensor(np.array([u[6] for u in data])) 54 | return seq, att_mask, tgt, tgt_mask, pos, score, pid 55 | 56 | 57 | def sample_data(dataset_name, seed, shot, frac=0.2): 58 | ''' 59 | sample the train data and test data 60 | :param seed: sample seed 61 | :param frac: the fraction of testing data, default to 0.2 62 | :param shot: the size of training data 63 | ''' 64 | 65 | data = pd.read_csv(f'data/{dataset_name}/data.csv', index_col=0) 66 | test_data = data.sample(frac=frac, random_state=seed) 67 | train_data = data.drop(test_data.index) 68 | kshot_data = train_data.sample(n=shot, random_state=seed) 69 | assert len(kshot_data) == shot, ( 70 | f'expected {shot} train examples, received {len(train_data)}') 71 | 72 | kshot_data.to_csv(f'data/{dataset_name}/train.csv') 73 | test_data.to_csv(f'data/{dataset_name}/test.csv') 74 | 75 | 76 | def split_train(dataset_name): 77 | ''' 78 | five equal split training data, one of which will be used as validation set when training ConFit 79 | ''' 80 | train = pd.read_csv(f'data/{dataset_name}/train.csv', index_col=0) 81 | tlen = int(np.ceil(len(train) / 5)) 82 | start = 0 83 | for i in range(1, 5): 84 | csv = train[start:start + tlen] 85 | start += tlen 86 | csv.to_csv(f'data/{dataset_name}/train_{i}.csv') 87 | csv = train[start:] 88 | csv.to_csv(f'data/{dataset_name}/train_{5}.csv') 89 | 90 | 91 | 92 | 93 | 94 | def spearman(y_pred, y_true): 95 | if np.var(y_pred) < 1e-6 or np.var(y_true) < 1e-6: 96 | return 0.0 97 | return spearmanr(y_pred, y_true)[0] 98 | 99 | def compute_stat(sr): 100 | sr = np.asarray(sr) 101 | mean = np.mean(sr) 102 | std = np.std(sr) 103 | sr = (sr,) 104 | ci = list(bootstrap(sr, np.mean).confidence_interval) 105 | return mean, std, ci 106 | 107 | 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /confit/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import pandas as pd 5 | import numpy as np 6 | from stat_utils import compute_stat, spearman 7 | import argparse 8 | 9 | 10 | parser = argparse.ArgumentParser(description='inference!') 11 | parser.add_argument('--dataset', type=str, help='dataset name') 12 | parser.add_argument('--shot', type=int, help='training size') 13 | parser.add_argument('--no_retrival', action='store_true', help='whether use retrival') 14 | parser.add_argument('--alpha', type=float, default=0.8, 15 | help='retrieval alpha') 16 | args = parser.parse_args() 17 | 18 | 19 | if 
os.path.exists(f'results/{args.dataset}/summary.csv'): 20 | summary = pd.read_csv(f'results/{args.dataset}/summary.csv', index_col=0) 21 | else: 22 | summary = pd.DataFrame(None) 23 | 24 | 25 | if os.path.exists(f'predicted/{args.dataset}/pred.csv'): 26 | pred = pd.read_csv(f'predicted/{args.dataset}/pred.csv', index_col=0) 27 | pred = pred.drop_duplicates(subset='PID') 28 | if not args.no_retrival: 29 | elbo = pd.read_csv(f'data/{args.dataset}/vae_elbo.csv', index_col=0) 30 | seed_list = [] 31 | for i in range(1, 6): 32 | if f'{i}' in pred.columns: 33 | seed_list.append(f'{i}') 34 | temp = pred[seed_list] 35 | temp = temp.mean(axis=1) 36 | pred = pd.concat([pred, temp], axis=1) 37 | pred = pred.rename(columns={0: 'avg'}) 38 | test = pd.read_csv(f'data/{args.dataset}/test.csv', index_col=0) 39 | avg = pred[['avg', 'PID']] 40 | label = test[['PID', 'log_fitness']] 41 | perf = pd.merge(avg, label, on='PID') 42 | 43 | if not args.no_retrival: 44 | perf = pd.merge(perf, elbo, on='PID') 45 | perf['retrival'] = args.alpha * perf['avg'] + (1 - args.alpha) * perf['elbo'] 46 | score = list(perf['retrival']) 47 | gscore = list(perf['log_fitness']) 48 | score = np.asarray(score) 49 | gscore = np.asarray(gscore) 50 | sr = spearman(score, gscore) 51 | out = pd.DataFrame({'spearman': sr, 'shot': args.shot}, index=[f'{args.dataset}']) 52 | summary = pd.concat([summary, out], axis=0) 53 | else: 54 | 55 | score = list(perf['avg']) 56 | gscore = list(perf['log_fitness']) 57 | score = np.asarray(score) 58 | gscore = np.asarray(gscore) 59 | sr = spearman(score, gscore) 60 | out = pd.DataFrame({'spearman': sr, 'shot': args.shot}, index=[f'{args.dataset}']) 61 | summary = pd.concat([summary, out], axis=0) 62 | 63 | if not os.path.isdir(f'results/{args.dataset}'): 64 | os.makedirs(f'results/{args.dataset}') 65 | 66 | summary.to_csv(f'results/{args.dataset}/summary.csv') 67 | -------------------------------------------------------------------------------- /confit/stat_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | from scipy.stats import spearmanr 5 | from scipy import stats 6 | 7 | 8 | def spearman(y_pred, y_true): 9 | if np.var(y_pred) < 1e-6 or np.var(y_true) < 1e-6: 10 | return 0.0 11 | return spearmanr(y_pred, y_true)[0] 12 | 13 | 14 | def compute_stat(sr): 15 | sr = np.asarray(sr) 16 | mean = np.mean(sr) 17 | std = np.std(sr) 18 | return mean, std 19 | 20 | 21 | def compute_score(model, seq, mask, wt, pos, tokenizer): 22 | ''' 23 | compute mutational proxy using masked marginal probability 24 | :param seq:mutant seq 25 | :param mask:attention mask for input seq 26 | :param wt: wild type sequence 27 | :param pos:mutant position 28 | :return: 29 | score: mutational proxy score 30 | logits: output logits for masked sequence 31 | ''' 32 | device = seq.device 33 | 34 | mask_seq = seq.clone() 35 | m_id = tokenizer.mask_token_id 36 | 37 | batch_size = int(seq.shape[0]) 38 | for i in range(batch_size): 39 | mut_pos = pos[i] 40 | mask_seq[i, mut_pos+1] = m_id 41 | 42 | out = model(mask_seq, mask, output_hidden_states=True) 43 | logits = out.logits 44 | log_probs = torch.log_softmax(logits, dim=-1) 45 | scores = torch.zeros(batch_size) 46 | scores = scores.to(device) 47 | 48 | for i in range(batch_size): 49 | 50 | mut_pos = pos[i] 51 | score_i = log_probs[i] 52 | wt_i = wt[i] 53 | seq_i = seq[i] 54 | scores[i] = torch.sum(score_i[mut_pos+1, seq_i[mut_pos+1]])-torch.sum(score_i[mut_pos+1, wt_i[mut_pos+1]]) 
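        # For each sample, the proxy score sums log p(mutant residue) - log p(wild-type residue)
        # over the mutated positions (indices shifted by +1 to skip the BOS token added by the
        # tokenizer), i.e. masked-marginal scoring of the mutations under the masked language model.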
55 | 56 | return scores, logits 57 | 58 | 59 | def BT_loss(scores, golden_score): 60 | loss = torch.tensor(0.) 61 | loss = loss.cuda() 62 | for i in range(len(scores)): 63 | for j in range(i, len(scores)): 64 | if golden_score[i] > golden_score[j]: 65 | loss += torch.log(1+torch.exp(scores[j]-scores[i])) 66 | else: 67 | loss += torch.log(1+torch.exp(scores[i]-scores[j])) 68 | return loss 69 | 70 | 71 | def KLloss(logits, logits_reg, seq, att_mask): 72 | 73 | creterion_reg = torch.nn.KLDivLoss(reduction='mean') 74 | batch_size = int(seq.shape[0]) 75 | 76 | loss = torch.tensor(0.) 77 | loss = loss.cuda() 78 | probs = torch.softmax(logits, dim=-1) 79 | probs_reg = torch.softmax(logits_reg, dim=-1) 80 | for i in range(batch_size): 81 | 82 | probs_i = probs[i] 83 | probs_reg_i = probs_reg[i] 84 | 85 | 86 | seq_len = torch.sum(att_mask[i]) 87 | 88 | reg = probs_reg_i[torch.arange(0, seq_len), seq[i, :seq_len]] 89 | pred = probs_i[torch.arange(0, seq_len), seq[i, :seq_len]] 90 | 91 | loss += creterion_reg(reg.log(), pred) 92 | return loss -------------------------------------------------------------------------------- /confit/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | import pandas as pd 5 | from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model 6 | from peft.utils.other import fsdp_auto_wrap_policy 7 | from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig 8 | import os 9 | import argparse 10 | from pathlib import Path 11 | import accelerate 12 | from accelerate import Accelerator 13 | 14 | from data_utils import Mutation_Set, split_train, sample_data 15 | from stat_utils import spearman, compute_score, BT_loss, KLloss 16 | import gc 17 | import warnings 18 | import time 19 | import yaml 20 | warnings.filterwarnings("ignore") 21 | 22 | 23 | 24 | def train(model, model_reg, trainloder, optimizer, tokenizer, lambda_reg): 25 | 26 | model.train() 27 | 28 | total_loss = 0. 
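    # Each step below: compute masked-marginal scores for the mutant batch, rank them against the
    # measured fitness with the pairwise ranking loss (BT_loss), and add a KL term (KLloss) that
    # keeps the fine-tuned model's predictions close to those of the frozen regularization model.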
29 | 30 | for step, data in enumerate(trainloder): 31 | seq, mask = data[0], data[1] 32 | wt, wt_mask = data[2], data[3] 33 | pos = data[4] 34 | golden_score = data[5] 35 | score, logits = compute_score(model, seq, mask, wt, pos, tokenizer) 36 | score = score.cuda() 37 | 38 | l_BT = BT_loss(score, golden_score) 39 | 40 | out_reg = model_reg(wt, wt_mask) 41 | logits_reg = out_reg.logits 42 | l_reg = KLloss(logits, logits_reg, seq, mask) 43 | 44 | loss = l_BT + lambda_reg*l_reg 45 | 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | total_loss += loss.item() 50 | return total_loss 51 | 52 | 53 | def evaluate(model, testloader, tokenizer, accelerator, istest=False): 54 | model.eval() 55 | seq_list = [] 56 | score_list = [] 57 | gscore_list = [] 58 | with torch.no_grad(): 59 | for step, data in enumerate(testloader): 60 | seq, mask = data[0], data[1] 61 | wt, wt_mask = data[2], data[3] 62 | pos = data[4] 63 | golden_score = data[5] 64 | pid = data[6] 65 | if istest: 66 | pid = pid.cuda() 67 | pid = accelerator.gather(pid) 68 | for s in pid: 69 | seq_list.append(s.cpu()) 70 | 71 | score, logits = compute_score(model, seq, mask, wt, pos, tokenizer) 72 | 73 | score = score.cuda() 74 | score = accelerator.gather(score) 75 | golden_score = accelerator.gather(golden_score) 76 | score = np.asarray(score.cpu()) 77 | golden_score = np.asarray(golden_score.cpu()) 78 | score_list.extend(score) 79 | gscore_list.extend(golden_score) 80 | score_list = np.asarray(score_list) 81 | gscore_list = np.asarray(gscore_list) 82 | sr = spearman(score_list, gscore_list) 83 | 84 | if istest: 85 | seq_list = np.asarray(seq_list) 86 | 87 | return sr, score_list, seq_list 88 | else: 89 | return sr 90 | 91 | 92 | def main(): 93 | parser = argparse.ArgumentParser(description='ConFit train, set hyperparameters') 94 | parser.add_argument('--config', type=str, default='48shot_config.yaml', 95 | help='the config file name') 96 | parser.add_argument('--dataset', type=str, help='the dataset name') 97 | parser.add_argument('--sample_seed', type=int, default=0, help='the sample seed for dataset') 98 | parser.add_argument('--model_seed', type=int, default=1, help='the random seed for the pretrained model initiate') 99 | args = parser.parse_args() 100 | dataset = args.dataset 101 | 102 | #read in config 103 | with open(f'{args.config}', 'r', encoding='utf-8') as f: 104 | config = yaml.load(f.read(), Loader=yaml.FullLoader) 105 | 106 | batch_size = int(int(config['batch_size'])/int(config['gpu_number'])) 107 | 108 | 109 | accelerator = Accelerator() 110 | 111 | ### creat model 112 | if config['model'] == 'ESM-1v': 113 | basemodel = EsmForMaskedLM.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{args.model_seed}') 114 | model_reg = EsmForMaskedLM.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{args.model_seed}') 115 | tokenizer = EsmTokenizer.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{args.model_seed}') 116 | 117 | elif config['model'] == 'ESM-2': 118 | basemodel = EsmForMaskedLM.from_pretrained('facebook/esm2_t48_15B_UR50D') 119 | model_reg = EsmForMaskedLM.from_pretrained('facebook/esm2_t48_15B_UR50D') 120 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t48_15B_UR50D') 121 | 122 | elif config['model'] == 'ESM-1b': 123 | basemodel = EsmForMaskedLM.from_pretrained('facebook/esm1b_t33_650M_UR50S') 124 | model_reg = EsmForMaskedLM.from_pretrained('facebook/esm1b_t33_650M_UR50S') 125 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm1b_t33_650M_UR50S') 126 | 127 | for pm in 
model_reg.parameters(): 128 | pm.requires_grad = False 129 | model_reg.eval() #regularization model 130 | 131 | 132 | peft_config = LoraConfig( 133 | task_type="CAUSAL_LM", 134 | r=int(config['lora_r']), 135 | lora_alpha=int(config['lora_alpha']), 136 | lora_dropout=float(config['lora_dropout']), 137 | target_modules=["query", "value"] 138 | ) 139 | 140 | model = get_peft_model(basemodel, peft_config) 141 | 142 | # create optimizer and scheduler 143 | optimizer = torch.optim.Adam(model.parameters(), lr=float(config['ini_lr'])) 144 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2*int(config['max_epochs']), eta_min=float(config['min_lr'])) 145 | if os.environ.get("ACCELERATE_USE_FSDP", None) is not None: 146 | accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) 147 | model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) 148 | model_reg = accelerator.prepare(model_reg) 149 | 150 | accelerator.print(f'===================dataset:{dataset}, preparing data=============') 151 | 152 | # sample data 153 | if accelerator.is_main_process: 154 | sample_data(dataset, args.sample_seed, int(config['shot'])) 155 | split_train(dataset) 156 | 157 | with accelerator.main_process_first(): 158 | train_csv = pd.DataFrame(None) 159 | test_csv = pd.read_csv(f'data/{dataset}/test.csv') 160 | for i in range(1, 6): 161 | if i == args.model_seed: 162 | val_csv = pd.read_csv(f'data/{dataset}/train_{i}.csv') #using 1/5 train data as validation set 163 | temp_csv = pd.read_csv(f'data/{dataset}/train_{i}.csv') 164 | train_csv = pd.concat([train_csv, temp_csv], axis=0) 165 | 166 | 167 | #creat dataset and dataloader 168 | trainset = Mutation_Set(data=train_csv, fname=dataset, tokenizer=tokenizer) 169 | testset = Mutation_Set(data=test_csv, fname=dataset, tokenizer=tokenizer) 170 | valset = Mutation_Set(data=val_csv, fname=dataset, tokenizer=tokenizer) 171 | with accelerator.main_process_first(): 172 | trainloader = DataLoader(trainset, batch_size=batch_size, collate_fn=trainset.collate_fn, shuffle=True) 173 | testloader = DataLoader(testset, batch_size=2, collate_fn=testset.collate_fn) 174 | valloader = DataLoader(valset, batch_size=2, collate_fn=testset.collate_fn) 175 | 176 | trainloader = accelerator.prepare(trainloader) 177 | testloader = accelerator.prepare(testloader) 178 | valloader = accelerator.prepare(valloader) 179 | accelerator.print('==============data preparing done!================') 180 | # accelerator.print("Current allocated memory:", torch.cuda.memory_allocated()) 181 | # accelerator.print("cached:", torch.cuda.memory_reserved()) 182 | 183 | 184 | best_sr = -np.inf 185 | endure = 0 186 | best_epoch = 0 187 | 188 | for epoch in range(int(config['max_epochs'])): 189 | loss = train(model, model_reg, trainloader, optimizer, tokenizer, float(config['lambda_reg'])) 190 | accelerator.print(f'========epoch{epoch}; training loss :{loss}=================') 191 | sr = evaluate(model, valloader, tokenizer, accelerator) 192 | accelerator.print(f'========epoch{epoch}; val spearman correlation :{sr}=================') 193 | scheduler.step() 194 | if best_sr > sr: 195 | endure += 1 196 | else: 197 | endure = 0 198 | best_sr = sr 199 | best_epoch = epoch 200 | 201 | if not os.path.isdir(f'checkpoint/{dataset}'): 202 | if accelerator.is_main_process: 203 | os.makedirs(f'checkpoint/{dataset}') 204 | save_path = os.path.join('checkpoint', f'{dataset}', 205 | f'seed{args.model_seed}') 206 | accelerator.wait_for_everyone() 207 | 
unwrapped_model = accelerator.unwrap_model(model) 208 | unwrapped_model.save_pretrained(save_path) 209 | if sr == 1.0: 210 | accelerator.print(f'========early stop at epoch{epoch}!============') 211 | break 212 | if endure > int(config['endure_time']): 213 | accelerator.print(f'========early stop at epoch{epoch}!============') 214 | break 215 | 216 | # inference on the test sest 217 | accelerator.print('=======training done!, test the performance!========') 218 | save_path = Path(os.path.join('checkpoint', f'{dataset}', f'seed{args.model_seed}')) 219 | del basemodel 220 | del model 221 | accelerator.free_memory() 222 | 223 | if config['model'] == 'ESM-1v': 224 | basemodel = EsmForMaskedLM.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{args.model_seed}') 225 | tokenizer = EsmTokenizer.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{args.model_seed}') 226 | 227 | if config['model'] == 'ESM-2': 228 | basemodel = EsmForMaskedLM.from_pretrained('facebook/esm2_t48_15B_UR50D') 229 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t48_15B_UR50D') 230 | 231 | if config['model'] == 'ESM-1b': 232 | basemodel = EsmForMaskedLM.from_pretrained('facebook/esm1b_t33_650M_UR50S') 233 | tokenizer = EsmTokenizer.from_pretrained('facebook/esm1b_t33_650M_UR50S') 234 | 235 | model = PeftModel.from_pretrained(basemodel, save_path) 236 | model = accelerator.prepare(model) 237 | sr, score, pid = evaluate(model, testloader, tokenizer, accelerator, istest=True) 238 | pred_csv = pd.DataFrame({f'{args.model_seed}': score, 'PID': pid}) 239 | if accelerator.is_main_process: 240 | if not os.path.isdir(f'predicted/{dataset}'): 241 | os.makedirs(f'predicted/{dataset}') 242 | if os.path.exists(f'predicted/{dataset}/pred.csv'): 243 | pred = pd.read_csv(f'predicted/{dataset}/pred.csv', index_col=0) 244 | pred = pd.merge(pred, pred_csv, on='PID') 245 | else: 246 | pred = pred_csv 247 | pred.to_csv(f'predicted/{dataset}/pred.csv') 248 | accelerator.print(f'=============the test spearman correlation for early stop: {sr}==================') 249 | 250 | 251 | if __name__ == "__main__": 252 | main() 253 | 254 | 255 | 256 | 257 | 258 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | appdirs==1.4.4 3 | loralib 4 | bitsandbytes==0.41.1 5 | datasets==2.12.0 6 | fire==0.5.0 7 | peft==0.5.0 8 | transformers==4.34.0 9 | gradio 10 | scipy 11 | numpy 12 | pandas 13 | pyyaml 14 | biopython 15 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # fix the downloading issues before 3 | wget --no-check-certificate "https://docs.google.com/uc?export=download&id=1jF2weMxWEi4AolGW1OT73wRsudI3i3zZ" -O "data.zip" 4 | unzip data.zip 5 | rm data.zip -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for((seed=1;seed<=5;seed++)); 3 | do 4 | accelerate launch --config_file config/parallel_config.yaml confit/train.py \ 5 | --config config/training_config.yaml \ 6 | --dataset GB1_Olson2014_ddg \ 7 | --sample_seed 0 \ 8 | --model_seed $seed 9 | done 10 | python confit/inference.py --dataset GB1_Olson2014_ddg --shot 48 
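# The loop above trains one adapter per --model_seed (1-5); each run adds its predictions as a
# column of predicted/GB1_Olson2014_ddg/pred.csv, and inference.py then averages those columns
# and writes the test Spearman correlation to results/GB1_Olson2014_ddg/summary.csv.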
--------------------------------------------------------------------------------