├── .gitignore
├── README.md
├── ckpts
│   └── .placeholder
├── configs
│   ├── esm.yaml
│   ├── mutaplm_ft.yaml
│   ├── mutaplm_inference.yaml
│   ├── mutaplm_pt.yaml
│   ├── ontoprotein.yaml
│   └── random.yaml
├── data
│   └── .placeholder
├── dataset
│   ├── __init__.py
│   ├── finetune_dataset.py
│   ├── fitness_dataset.py
│   └── literature_dataset.py
├── eval.py
├── evaluator.py
├── example.ipynb
├── logs
│   └── .placeholder
├── metrics.py
├── model
│   ├── __init__.py
│   ├── esm_landscape.py
│   ├── modeling_esm.py
│   ├── mutaplm.py
│   └── vanllina_esm.py
├── outputs
│   └── .placeholder
├── scripts
│   ├── optimize
│   │   ├── evoprotgrad.sh
│   │   ├── mutaplm.sh
│   │   └── random.sh
│   ├── test
│   │   ├── mutaplm_engineer.sh
│   │   └── mutaplm_explain.sh
│   └── train
│       ├── finetune.sh
│       └── pretrain.sh
├── train.py
├── trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | /outputs/*
3 | !/outputs/.placeholder
4 | /logs/*
5 | !/logs/.placeholder
6 | /data/*
7 | !/data/.placeholder
8 | /ckpts/*
9 | !/ckpts/.placeholder
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MutaPLM
2 | 
3 | This is the official repository for the NeurIPS 2024 paper [MutaPLM: Protein Language Modeling for Mutation Explanation and Engineering](https://arxiv.org/abs/2410.22949).
4 | 
5 | #### Requirements
6 | 
7 | ```bash
8 | pytorch==1.13.1+cu117
9 | transformers==4.36.1
10 | peft==0.9.0
11 | pandas
12 | numpy
13 | scipy
14 | evoprotgrad
15 | nltk
16 | rouge_score
17 | sequence_models
18 | scikit-learn
19 | ```
20 | 
21 | #### Data
22 | 
23 | The pre-training dataset and the **MutaDescribe** dataset are available at [HuggingFace](https://huggingface.co/datasets/icycookies/MutaDescribe). Download the data and place them under the `data` folder.
24 | 
25 | #### Model Checkpoints
26 | 
27 | Before running the scripts, you should:
28 | - Download the PLM checkpoint [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D) and put it in `ckpts/esm2-650m`.
29 | - Download the LLM checkpoint [BioMedGPT-LM](https://huggingface.co/PharMolix/BioMedGPT-LM-7B) and put it in `ckpts/biomedgpt-lm`. If you intend to perform evaluation only, you can just download the configuration files.
30 | - Download the fine-tuned checkpoint [MutaPLM](https://huggingface.co/PharMolix/MutaPLM) and put it in `ckpts/mutaplm`.
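One way to fetch all of the above programmatically (a sketch, not part of the original scripts; it assumes the `huggingface_hub` Python package is installed and downloads into the folder names used by the configs):

```python
# Sketch: download the checkpoints and the MutaDescribe data into the expected folders.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="facebook/esm2_t33_650M_UR50D", local_dir="ckpts/esm2-650m")
snapshot_download(repo_id="PharMolix/BioMedGPT-LM-7B", local_dir="ckpts/biomedgpt-lm")
snapshot_download(repo_id="PharMolix/MutaPLM", local_dir="ckpts/mutaplm")
snapshot_download(repo_id="icycookies/MutaDescribe", repo_type="dataset", local_dir="data")
```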
31 | 32 | 33 | #### Implementation 34 | 35 | For pre-training on protein literature, run the following script: 36 | 37 | ```bash 38 | bash scripts/train/pretrain.sh 39 | ``` 40 | 41 | For fine-tuning on the MutaDescribe dataset, run the following script: 42 | 43 | ```bash 44 | bash scripts/train/finetune.sh 45 | ``` 46 | 47 | For evaluating MutaPLM on mutation explanation, run the following script: 48 | 49 | ```bash 50 | bash scripts/test/mutaplm_explain.sh 51 | ``` 52 | 53 | For evaluating MutaPLM on mutation engineering, run the following script: 54 | 55 | ```bash 56 | bash scripts/test/mutaplm_engineer.sh 57 | ``` 58 | 59 | #### Citation 60 | ``` 61 | @misc{luo2024mutaplm, 62 | title={MutaPLM: Protein Language Modeling for Mutation Explanation and Engineering}, 63 | author={Yizhen Luo and Zikun Nie and Massimo Hong and Suyuan Zhao and Hao Zhou and Zaiqing Nie}, 64 | year={2024}, 65 | eprint={2410.22949}, 66 | archivePrefix={arXiv}, 67 | primaryClass={cs.LG}, 68 | url={https://arxiv.org/abs/2410.22949}, 69 | } 70 | ``` -------------------------------------------------------------------------------- /ckpts/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/MutaPLM/495815b2069419d19d9449a59070d0fb24b596d1/ckpts/.placeholder -------------------------------------------------------------------------------- /configs/esm.yaml: -------------------------------------------------------------------------------- 1 | path: ./ckpts/esm2-650m 2 | protein_maxlen: 1026 -------------------------------------------------------------------------------- /configs/mutaplm_ft.yaml: -------------------------------------------------------------------------------- 1 | protein_model: "./ckpts/esm2-650m" 2 | llama_ckpt: "./ckpts/biomedgpt-lm" 3 | num_query_tokens_protein1: 64 4 | num_query_tokens_protein2: 64 5 | ca_num_head: 8 6 | protein_maxlen: 1026 7 | text_maxlen: 256 8 | func_maxlen: 256 9 | m2t: true 10 | t2m: true 11 | test_mode: false 12 | pretrain: false -------------------------------------------------------------------------------- /configs/mutaplm_inference.yaml: -------------------------------------------------------------------------------- 1 | protein_model: "./ckpts/esm2-650m" 2 | llama_ckpt: "./ckpts/biomedgpt-lm" 3 | num_query_tokens_protein1: 64 4 | num_query_tokens_protein2: 64 5 | ca_num_head: 8 6 | protein_maxlen: 1026 7 | text_maxlen: 512 8 | func_maxlen: 256 9 | m2t: true 10 | t2m: false 11 | test_mode: true 12 | pretrain: false -------------------------------------------------------------------------------- /configs/mutaplm_pt.yaml: -------------------------------------------------------------------------------- 1 | protein_model: "./ckpts/esm2-650m" 2 | llama_ckpt: "./ckpts/biomedgpt-lm" 3 | num_query_tokens_protein1: 64 4 | num_query_tokens_protein2: 64 5 | ca_num_head: 8 6 | protein_maxlen: 1026 7 | text_maxlen: 512 8 | func_maxlen: 512 9 | m2t: true 10 | t2m: true 11 | test_mode: false 12 | pretrain: true -------------------------------------------------------------------------------- /configs/ontoprotein.yaml: -------------------------------------------------------------------------------- 1 | path: ./ckpts/ontoprotein 2 | ontoprotein: true 3 | protein_maxlen: 1026 -------------------------------------------------------------------------------- /configs/random.yaml: -------------------------------------------------------------------------------- 1 | path: ./ckpts/esm2-650m 
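The YAML files above are flat keyword dictionaries; as `evaluator.py` and `example.ipynb` further down show, each one is read with `yaml.load` and unpacked directly into the corresponding model constructor. A minimal sketch of that loading pattern (paths assume the checkpoint layout from the README):

```python
import torch
import yaml

from model.mutaplm import MutaPLM

# Read an inference config and unpack it into the MutaPLM constructor,
# mirroring Evaluator._setup_model() in evaluator.py and example.ipynb.
cfg = yaml.load(open("configs/mutaplm_inference.yaml", "r"), Loader=yaml.Loader)
cfg["device"] = torch.device("cuda:0")  # the device key is injected at runtime, not stored in the YAML
model = MutaPLM(**cfg).to(cfg["device"])
```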
-------------------------------------------------------------------------------- /data/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/MutaPLM/495815b2069419d19d9449a59070d0fb24b596d1/data/.placeholder -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from dataset.literature_dataset import LiteratureDataset 2 | from dataset.finetune_dataset import MutaDescribeDataset 3 | from dataset.fitness_dataset import FitnessDataset 4 | 5 | dataset_name2cls = { 6 | "literature": LiteratureDataset, 7 | "mutadescribe": MutaDescribeDataset, 8 | "AAV": FitnessDataset, 9 | "AMIE": FitnessDataset, 10 | "avGFP": FitnessDataset, 11 | "E4B": FitnessDataset, 12 | "LGK": FitnessDataset, 13 | "UBE2I": FitnessDataset, 14 | } -------------------------------------------------------------------------------- /dataset/finetune_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import pandas as pd 3 | import json 4 | import os 5 | 6 | class MutaDescribeDataset(Dataset): 7 | def __init__(self, path, split="train", data_percent=1.0, **kwargs) -> None: 8 | """ path: path to dataset.csv """ 9 | super().__init__() 10 | self.df = pd.read_csv(path) 11 | self.df = self.df[:int(len(self.df)*data_percent)] # 5% 12 | 13 | def __getitem__(self, index): 14 | site = self.df["entry"][index].split("-")[1] 15 | prot1 = self.df["protein1"][index] 16 | prot2 = self.df["protein2"][index] 17 | uni_despt = self.df["uniprot_description"][index] if not pd.isna(self.df["uniprot_description"][index]) else '' 18 | GPT_despt = self.df["GPT_description"][index] if not pd.isna(self.df["GPT_description"][index]) else '' 19 | prot_function = self.df["function"][index] 20 | template = "Next is a feature of the mutation {} to {} at position {}. Please generate a {} text to describe it." 21 | mut_prompt = template.format(site[0], site[-1], int(site[1:-1]), "long detailed" if len(GPT_despt) >= 1 else "brief summary") 22 | return prot1, prot2, site, (uni_despt + ' ' + GPT_despt).strip(), prot_function, mut_prompt 23 | 24 | def __len__(self): 25 | return len(self.df) 26 | 27 | def get_example(self): 28 | for i in range(len(self)): 29 | prot1, prot2, site, desc, _, _ = self[i] 30 | yield "Wild Type:" + prot1[:20] + "...\t" + "Site: " + site + "\tMutation Effect: " + desc[:50] + "..." 31 | raise RuntimeError("Number of examples exceed dataset length!") -------------------------------------------------------------------------------- /dataset/fitness_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import numpy as np 5 | import re 6 | from torch.utils.data import Dataset 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | name2target = { 11 | "AAV": "Strongly increased viability for packaging of a DNA payload for gene therapy.", 12 | "AMIE": "Increase in activity.", 13 | "avGFP": "Leads to enhanced fluorescence at 37 degrees Celsius.", 14 | "E4B": "Enhances cleavage by caspase-6 and granzyme B.", 15 | "LGK": "Increase in activity.", 16 | "UBE2I": "Increased growth rescue rate at high temperature in a yeast strain." 
17 | } 18 | 19 | name2prompt = { 20 | "AAV": "Capsid protein self-assembles to form an icosahedral capsid with a T=1 symmetry, about 22 nm in diameter, and consisting of 60 copies of three size variants of the capsid protein VP1, VP2 and VP3 which differ in their N-terminus. The capsid encapsulates the genomic ssDNA. Binds to host cell heparan sulfate and uses host ITGA5-ITGB1 as coreceptor on the cell surface to provide virion attachment to target cell. This attachment induces virion internalization predominantly through clathrin-dependent endocytosis. Binding to the host receptor also induces capsid rearrangements leading to surface exposure of VP1 N-terminus, specifically its phospholipase A2-like region and putative nuclear localization signal(s). VP1 N-terminus might serve as a lipolytic enzyme to breach the endosomal membrane during entry into host cell and might contribute to virus transport to the nucleus.", 21 | "AMIE": "Catalyzes the hydrolysis of short-chain aliphatic amides to their corresponding organic acids with release of ammonia.", 22 | "avGFP": "Energy-transfer acceptor. Its role is to transduce the blue chemiluminescence of the protein aequorin into green fluorescent light by energy transfer. Fluoresces in vivo upon receiving energy from the Ca(2+)-activated photoprotein aequorin.", 23 | "E4B": "Ubiquitin-protein ligase that probably functions as an E3 ligase in conjunction with specific E1 and E2 ligases. May also function as an E4 ligase mediating the assembly of polyubiquitin chains on substrates ubiquitinated by another E3 ubiquitin ligase. May regulate myosin assembly in striated muscles together with STUB1 and VCP/p97 by targeting myosin chaperone UNC45B for proteasomal degradation.", 24 | "LGK": "Levoglucosan kinase that catalyzes the transfer of a phosphate group from ATP to levoglucosan (1,6-anhydro-beta-D-glucopyranose, LG) to yield glucose 6-phosphate in the presence of magnesium ion and ATP. In addition to the canonical kinase phosphotransfer reaction, the conversion requires cleavage of the 1,6-anhydro ring to allow ATP-dependent phosphorylation of the sugar O-6 atom.", 25 | "UBE2I": "Accepts the ubiquitin-like proteins SUMO1, SUMO2, SUMO3, SUMO4 and SUMO1P1/SUMO5 from the UBLE1A-UBLE1B E1 complex and catalyzes their covalent attachment to other proteins with the help of an E3 ligase such as RANBP2, CBX4 and ZNF451. Can catalyze the formation of poly-SUMO chains. Necessary for sumoylation of FOXL2 and KAT5. Essential for nuclear architecture and chromosome segregation. Sumoylates p53/TP53 at 'Lys-386'. 
Mediates sumoylation of ERCC6 which is essential for its transcription-coupled nucleotide excision repair activity", 26 | } 27 | 28 | class FitnessDataset(Dataset): 29 | def __init__(self, path, split="valid", name=None, nshot=None, **kwargs): 30 | super().__init__() 31 | if os.path.exists(os.path.join(path, "wild_type.json")): 32 | self.wild_type = json.load(open(os.path.join(path, "wild_type.json"), "r"))["seq"] 33 | else: 34 | self.wild_type = json.load(open(os.path.join(path, "starting_sequence.json"), "r")) 35 | data = json.load(open(os.path.join(path, split + ".json"), "r")) 36 | self.data = [] 37 | 38 | for i in range(len(data)): 39 | if len(data[i]["seq"]) == len(self.wild_type): 40 | self.data.append(data[i]) 41 | if nshot is not None: 42 | perm = np.random.permutation(len(self.data))[:nshot] 43 | new_data = [self.data[i] for i in perm] 44 | self.data = new_data 45 | self.starting_sequence = json.load(open(os.path.join(path, "starting_sequence.json"), "r")) 46 | self.prompt = name2prompt[path.split("/")[-1]] 47 | self.target = name2target[path.split("/")[-1]] 48 | 49 | def __len__(self): 50 | return len(self.data) 51 | 52 | def get_example(self): 53 | for i in range(len(self)): 54 | yield "Sequence:" + self.data[i]["seq"][:30] + "...\tFitness:" + str(self.data[i]["fitness"][0]) + "\tNum mutations:" + str(self.data[i]["num_mutations"]) 55 | raise RuntimeError("Number of examples exceed dataset length!") 56 | 57 | def __getitem__(self, index): 58 | return self.data[index]["seq"], self.data[index]["fitness"][0] -------------------------------------------------------------------------------- /dataset/literature_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | import re 6 | 7 | class LiteratureDataset(Dataset): 8 | def __init__(self, path, **kwargs) -> None: 9 | super().__init__() 10 | self.uniprot2pubmed = json.load(open(os.path.join(path, "uniprot_pubmed.json"), "r")) 11 | self.uniprot2seq = {} 12 | self.uniprot2func = {} 13 | uniprot_data = json.load(open(os.path.join(path, "uniprot_accession.json"), "r")) 14 | keys = set() 15 | for id in uniprot_data: 16 | for key in uniprot_data[id]: 17 | if key not in keys: 18 | keys.add(key) 19 | if "Sequence" in uniprot_data[id] and id in self.uniprot2pubmed: 20 | self.uniprot2seq[id] = uniprot_data[id]["Sequence"] 21 | if "Description" in uniprot_data[id]: 22 | pattern1 = r'\(PubMed:\d+(, PubMed:\d+)*\)' 23 | pattern2 = r'\(By similarity\)' 24 | self.uniprot2func[id] = "; ".join([re.sub(pattern2, '', re.sub(pattern1, '', text)) for text in uniprot_data[id]["Description"]]) 25 | self.uniprot_ids = list(self.uniprot2seq.keys()) 26 | self.pubmed_corpus = {} 27 | with open(os.path.join(path, "corpus.jsonl"), "r") as f: 28 | for line in f.readlines(): 29 | data = json.loads(line) 30 | if data["title"] is not None and data["abstract"] is not None: 31 | self.pubmed_corpus[data["pubmed"]] = data["title"] + " " + data["abstract"] 32 | elif data["title"] is not None: 33 | self.pubmed_corpus[data["pubmed"]] = data["title"] 34 | elif data["abstract"] is not None: 35 | self.pubmed_corpus[data["pubmed"]] = data["abstract"] 36 | for id in self.uniprot2pubmed: 37 | for i, pubmed_id in enumerate(self.uniprot2pubmed[id]): 38 | if pubmed_id not in self.pubmed_corpus: 39 | self.uniprot2pubmed[id].pop(i) 40 | 41 | def get_by_uniport(self, id): 42 | print(id) 43 | if id in self.uniprot2func: 44 | print("Function:", 
self.uniprot2func[id]) 45 | print("---------------------------------------------") 46 | for j in self.uniprot2pubmed[id]: 47 | print(self.pubmed_corpus[j]) 48 | 49 | def __len__(self): 50 | return len(self.uniprot_ids) 51 | 52 | def __getitem__(self, index): 53 | id = self.uniprot_ids[index] 54 | seq = self.uniprot2seq[id] 55 | text_id = random.sample(self.uniprot2pubmed[id], k=1)[0] 56 | return seq, self.pubmed_corpus[text_id] 57 | 58 | def get_example(self): 59 | for i in range(len(self)): 60 | seq, text = self[i] 61 | yield "Accession: " + self.uniprot_ids[i] + "\tSequence:" + seq[:30] + "...\tText:" + text[:100] 62 | raise RuntimeError("Number of examples exceed dataset length!") 63 | 64 | if __name__ == "__main__": 65 | dataset = LiteratureDataset("./data/pubs") 66 | cnt = 0 67 | print(len(dataset)) 68 | for i in range(len(dataset)): 69 | cnt += len(dataset.uniprot2pubmed[dataset.uniprot_ids[i]]) 70 | print(cnt) 71 | print(dataset.get_by_uniport("B3VI55")) -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from evaluator import Evaluator, MutaExplainEvaluator, MutaEngineerEvaluator, FitnessOptimizeEvaluator 3 | 4 | def add_arguments(parser): 5 | parser.add_argument("--muta_explain", action="store_true") 6 | parser.add_argument("--muta_engineer", action="store_true") 7 | parser.add_argument("--fitness_optimize", action="store_true") 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | add_arguments(parser) 12 | args = parser.parse_known_args()[0] 13 | 14 | if args.muta_explain: 15 | cls = MutaExplainEvaluator 16 | elif args.muta_engineer: 17 | cls = MutaEngineerEvaluator 18 | if args.fitness_optimize: 19 | cls = FitnessOptimizeEvaluator 20 | 21 | parser = cls.add_arguments(parser) 22 | args = parser.parse_args() 23 | evaluator = cls(args) 24 | evaluator.evaluate() -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import logging 4 | logger = logging.getLogger(__name__) 5 | 6 | from collections import OrderedDict 7 | import copy 8 | import yaml 9 | from tqdm import tqdm 10 | from nltk.translate.bleu_score import corpus_bleu 11 | from nltk.translate.meteor_score import meteor_score 12 | from rouge_score import rouge_scorer 13 | from transformers import EsmForMaskedLM, EsmTokenizer, BertForMaskedLM, BertTokenizer 14 | 15 | import numpy as np 16 | import pandas as pd 17 | import torch 18 | from torch.cuda.amp import autocast 19 | from torch.utils.data import DataLoader 20 | 21 | from dataset import dataset_name2cls 22 | from dataset.fitness_dataset import name2prompt, name2target 23 | from model import model_name2cls 24 | from model.esm_landscape import EsmForLandscapeRegression 25 | 26 | class Evaluator(ABC): 27 | @staticmethod 28 | def add_arguments(parser): 29 | parser.add_argument("--dataset_name", type=str, default="mutadescribe") 30 | parser.add_argument("--dataset_path", type=str, default="./data/") 31 | parser.add_argument("--model_name", type=str, default="mutaplm") 32 | parser.add_argument("--model_config_path", type=str, default="./configs/mutaplm_inference.yaml") 33 | parser.add_argument("--model_checkpoint", type=str, default=None) 34 | parser.add_argument("--pred_save_path", type=str, default="./outputs/pred.txt") 35 | 
parser.add_argument("--batch_size", type=int, default=64) 36 | parser.add_argument("--num_workers", type=int, default=1) 37 | parser.add_argument("--device", type=int, default=0) 38 | return parser 39 | 40 | def __init__(self, args) -> None: 41 | super().__init__() 42 | self.args = args 43 | self.device = torch.device("cuda", self.args.device) 44 | # self.device = torch.device("cpu") 45 | logging.basicConfig( 46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 47 | datefmt="%m/%d/%Y %H:%M:%S", 48 | level=logging.INFO, 49 | ) 50 | self._setup_data() 51 | self._setup_model() 52 | 53 | def _setup_data(self): 54 | logger.info(f"Loading dataset {self.args.dataset_name}...") 55 | self.dataset = dataset_name2cls[self.args.dataset_name](self.args.dataset_path) 56 | logger.info(f"Num Samples: {len(self.dataset)}",) 57 | if hasattr(self.dataset, "get_example"): 58 | for i, example in enumerate(self.dataset.get_example()): 59 | if i >= 2: 60 | break 61 | logger.info(example) 62 | self.dataloader = DataLoader( 63 | self.dataset, 64 | batch_size=self.args.batch_size, 65 | shuffle=False, 66 | num_workers=self.args.num_workers, 67 | ) 68 | 69 | def _setup_model(self): 70 | logger.info("Loading model...") 71 | model_cls = model_name2cls[self.args.model_name] 72 | model_cfg = yaml.load(open(self.args.model_config_path, "r"), Loader=yaml.Loader) 73 | model_cfg["device"] = self.device 74 | self.model = model_cls(**model_cfg).to(self.device) 75 | 76 | if self.args.model_checkpoint is not None: 77 | logger.info(f"Load model checkpoint from {self.args.model_checkpoint}") 78 | state_dict = torch.load(open(self.args.model_checkpoint, "rb"), map_location="cpu") 79 | new_ckpt = state_dict["model"] 80 | self.model.load_state_dict(new_ckpt, strict=False) 81 | 82 | @abstractmethod 83 | def evaluate(self): 84 | raise NotImplementedError 85 | 86 | class MutaExplainEvaluator(Evaluator): 87 | def __init__(self, args) -> None: 88 | super().__init__(args) 89 | 90 | def evaluate(self): 91 | logger.info("Start evaluation!") 92 | self.model.eval() 93 | all_preds_func, all_labels_func = [], [] 94 | all_preds_mut, all_labels_mut = [], [] 95 | all_preds_func_tokens, all_labels_func_tokens = [], [] 96 | all_preds_mut_tokens, all_labels_mut_tokens = [], [] 97 | meteor_func, meteor_mut = [], [] 98 | with open(self.args.pred_save_path, "w") as f: 99 | f.write("Site\tPred_Func\tLabel_Func\tPred_Effect\tLabel_Effect\n") 100 | for i, data in enumerate(tqdm(self.dataloader)): 101 | with torch.no_grad(): 102 | with autocast(dtype=torch.bfloat16): 103 | preds_func, preds_mut = self.model.generate(data[0], data[1], data[5], pfunction=data[4]) 104 | for j in range(len(data[-1])): 105 | all_preds_func.append(preds_func[j]) 106 | all_labels_func.append(data[4][j]) 107 | all_preds_mut.append(preds_mut[j]) 108 | all_labels_mut.append(data[3][j]) 109 | with open(self.args.pred_save_path, "a+") as f: 110 | f.write(f"{data[2][j]}\t{preds_func[j]}\t{data[4][j]}\t{preds_mut[j]}\t{data[3][j]}\n") 111 | all_preds_func_tokens.append(preds_func[j].split(" ")) 112 | all_labels_func_tokens.append([data[4][j].split(" ")]) 113 | meteor_func.append(meteor_score(all_labels_func_tokens[-1], all_preds_func_tokens[-1])) 114 | all_preds_mut_tokens.append(preds_mut[j].split(" ")) 115 | all_labels_mut_tokens.append([data[3][j].split(" ")]) 116 | meteor_mut.append(meteor_score(all_labels_mut_tokens[-1], all_preds_mut_tokens[-1])) 117 | 118 | scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL']) 119 | scores_func, scores_mut = [], [] 120 | 
for i in range(len(all_preds_func)): 121 | scores_func.append(scorer.score(all_preds_func[i], all_labels_func[i])) 122 | scores_mut.append(scorer.score(all_preds_mut[i], all_labels_mut[i])) 123 | bleu2_func = corpus_bleu(all_labels_func_tokens, all_preds_func_tokens, weights=(0.5, 0.5)) 124 | bleu4_func = corpus_bleu(all_labels_func_tokens, all_preds_func_tokens, weights=(0.25, 0.25, 0.25, 0.25)) 125 | print("------------Function--------------") 126 | print("BLEU-2 = %.4lf" % bleu2_func) 127 | print("BLEU-4 = %.4lf" % bleu4_func) 128 | print("METEOR = %.4lf" % np.mean(meteor_func)) 129 | print("ROUGE-1 = %.4lf" % (np.mean([rs['rouge1'].fmeasure for rs in scores_func]))) 130 | print("ROUGE-2 = %.4lf" % (np.mean([rs['rouge2'].fmeasure for rs in scores_func]))) 131 | print("ROUGE-L = %.4lf" % (np.mean([rs['rougeL'].fmeasure for rs in scores_func]))) 132 | print("------------Mutation--------------") 133 | bleu2_mut = corpus_bleu(all_labels_mut_tokens, all_preds_mut_tokens, weights=(0.5, 0.5)) 134 | bleu4_mut = corpus_bleu(all_labels_mut_tokens, all_preds_mut_tokens, weights=(0.25, 0.25, 0.25, 0.25)) 135 | print("BLEU-2 = %.4lf" % bleu2_mut) 136 | print("BLEU-4 = %.4lf" % bleu4_mut) 137 | print("METEOR = %.4lf" % np.mean(meteor_mut)) 138 | print("ROUGE-1 = %.4lf" % (np.mean([rs['rouge1'].fmeasure for rs in scores_mut]))) 139 | print("ROUGE-2 = %.4lf" % (np.mean([rs['rouge2'].fmeasure for rs in scores_mut]))) 140 | print("ROUGE-L = %.4lf" % (np.mean([rs['rougeL'].fmeasure for rs in scores_mut]))) 141 | 142 | class MutaEngineerEvaluator(Evaluator): 143 | def __init__(self, args): 144 | super().__init__(args) 145 | 146 | def evaluate(self): 147 | logger.info("Start evaluation!") 148 | self.model.eval() 149 | all_preds = [] 150 | all_preds_with_pos = [] 151 | all_labels = [] 152 | for data in tqdm(self.dataloader): 153 | with torch.no_grad(): 154 | with autocast(dtype=torch.bfloat16): 155 | preds = self.model.lm_design(data[0], data[3], pfunction=data[-2], muta_prompt=data[-1]) 156 | all_labels += data[2] 157 | all_preds.append(preds) 158 | pos = torch.tensor([int(x[1:-1]) for x in data[2]]) 159 | preds = preds[torch.arange(len(data[-1])), pos] 160 | all_preds_with_pos.append(torch.argmax(preds, dim=1)) 161 | all_preds = torch.cat(all_preds, dim=0).flatten(1, 2) 162 | all_preds_with_pos = "".join(self.model.protein_tokenizer.decode(torch.cat(all_preds_with_pos)).split(" ")) 163 | 164 | all_pos = [] 165 | all_aa = [] 166 | for i in range(len(all_preds)): 167 | top50 = all_preds[i].topk(50).indices 168 | all_pos.append(top50 // len(self.model.protein_tokenizer)) 169 | all_aa.append(top50 % len(self.model.protein_tokenizer)) 170 | all_aa = self.model.protein_tokenizer.batch_decode(torch.stack(all_aa, dim=0)) 171 | 172 | acc, rec = 0, 0 173 | with open(self.args.pred_save_path, "w") as f: 174 | f.write("Labels\tPreds\tSequence\n") 175 | for i in range(len(all_preds)): 176 | seq = self.dataset[i][0] 177 | all_aa[i] = "".join(all_aa[i].split(" ")) 178 | preds = [] 179 | for j in range(50): 180 | pos = all_pos[i][j].item() 181 | preds.append(seq[pos - 1] + str(pos) + all_aa[i][j]) 182 | f.write(all_labels[i] + "\t" + ",".join(preds[:10]) + "\t" + self.dataset[i][0] + "\n") 183 | if all_labels[i] in preds: 184 | rec += 1 185 | if all_preds_with_pos[i] == all_labels[i][-1]: 186 | acc += 1 187 | print("Accuracy = ", acc / len(all_labels)) 188 | print("Recall@50 = ", rec / len(all_labels)) 189 | 190 | class FitnessOptimizeEvaluator(Evaluator): 191 | @staticmethod 192 | def add_arguments(parser): 193 | parser 
= Evaluator.add_arguments(parser) 194 | parser.add_argument("--surrogate_path", type=str, default="./ckpts/landscape_ckpts/") 195 | parser.add_argument("--num_candidates", type=int, default=100) 196 | parser.add_argument("--num_rounds", type=int, default=10) 197 | parser.add_argument("--score_save_path", type=str, default="./outputs/") 198 | parser.add_argument("--evo_prot_grad", action="store_true") 199 | return parser 200 | 201 | def __init__(self, args) -> None: 202 | super().__init__(args) 203 | self._setup_surrogate() 204 | self.prompt = name2prompt[self.args.dataset_name] 205 | self.target = name2target[self.args.dataset_name] 206 | if self.args.evo_prot_grad: 207 | import evo_prot_grad 208 | self.expert = evo_prot_grad.get_expert( 209 | expert_name="bert", 210 | model=BertForMaskedLM.from_pretrained("./ckpts/protein_ckpts/ontoprotein"), 211 | tokenizer=BertTokenizer.from_pretrained("./ckpts/protein_ckpts/ontoprotein"), 212 | device=self.device, 213 | temperature=1.0 214 | ) 215 | 216 | def _setup_surrogate(self): 217 | self.surrogate = EsmForLandscapeRegression("./ckpts/protein_ckpts/esm1b", self.args.surrogate_path, self.device) 218 | self.surrogate.to(self.device) 219 | self.surrogate.eval() 220 | 221 | def evaluate(self): 222 | logger.info("Start evaluation!") 223 | self.model.eval() 224 | 225 | mx_scores, mean_scores = [], [] 226 | cur_fitness = 0 227 | for i in tqdm(range(20)): 228 | protein = [self.dataset.starting_sequence] 229 | prev_scores = torch.tensor([0.0]) 230 | cur_fitness = self.surrogate(protein).item() 231 | print("Initial protein fitness:", cur_fitness) 232 | print("Function:", self.prompt) 233 | print("Target:", self.target) 234 | if self.args.evo_prot_grad: 235 | import evo_prot_grad 236 | cur_mx_scores, cur_mean_scores = [], [] 237 | all_proteins = [[] for j in range(self.args.num_rounds)] 238 | for j in range(self.args.num_candidates // 10): 239 | new_proteins, scores = evo_prot_grad.DirectedEvolution( 240 | n_steps=10, 241 | max_mutations=self.args.num_rounds + 1, 242 | wt_protein=protein[0], 243 | parallel_chains=10, 244 | experts=[self.expert], 245 | output='all', 246 | random_seed=i+42 247 | )() 248 | for round in range(self.args.num_rounds): 249 | all_proteins[round] += ["".join(p.split(" ")) for p in new_proteins[round]] 250 | 251 | for round in range(self.args.num_rounds): 252 | round_scores = [] 253 | with torch.no_grad(): 254 | for batch in range((self.args.num_candidates - 1) // self.args.batch_size + 1): 255 | st, ed = batch * self.args.batch_size, min(len(all_proteins[round]), (batch + 1) * self.args.batch_size) 256 | round_scores.append(self.surrogate(all_proteins[round][st:ed])) 257 | round_scores = torch.cat(round_scores, dim=0) 258 | cur_mx_scores.append(torch.max(round_scores).item()) 259 | if round >= 1: 260 | cur_mx_scores[-1] = max(cur_mx_scores[-1], cur_mx_scores[-2]) 261 | cur_mean_scores.append(torch.mean(round_scores).item()) 262 | mx_scores.append(cur_mx_scores) 263 | mean_scores.append(cur_mean_scores) 264 | else: 265 | torch.random.manual_seed(i) 266 | cur_mx_scores, cur_mean_scores = [], [] 267 | for round in range(self.args.num_rounds): 268 | with torch.no_grad(): 269 | with autocast(dtype=torch.bfloat16): 270 | all_preds = [] 271 | for batch in range((len(protein) - 1) // self.args.batch_size + 1): 272 | st, ed = batch * self.args.batch_size, min(len(protein), (batch + 1) * self.args.batch_size) 273 | preds = self.model.lm_design( 274 | protein[st:ed], 275 | muta_prompt=["Not Available"] * (ed - st), 276 | pfunction=[self.prompt] 
* (ed - st), 277 | text=[self.target] * (ed - st), 278 | use_gt_function=True 279 | ) 280 | preds += prev_scores[st:ed].to(self.device).view(ed - st, 1, 1).expand(ed - st, preds.shape[1], preds.shape[2]) 281 | all_preds.append(preds) 282 | preds = torch.cat(all_preds, dim=0) 283 | topk = torch.multinomial(preds.flatten(), self.args.num_candidates) 284 | # topk = torch.topk(preds.flatten(), self.args.num_candidates) 285 | indices = topk 286 | idx = indices // (preds.shape[1] * 33) 287 | pos = indices % (preds.shape[1] * 33) // 33 288 | aa = self.model.protein_tokenizer.batch_decode(indices % 33) 289 | print(pos, aa) 290 | prev_scores = preds.flatten()[indices] 291 | new_protein = [] 292 | for j in range(self.args.num_candidates): 293 | cur = protein[idx[j].item()] 294 | new_protein.append(cur[:pos[j].item() - 1] + aa[j] + cur[pos[j].item():]) 295 | 296 | protein = new_protein 297 | round_scores = [] 298 | with torch.no_grad(): 299 | for batch in range((self.args.num_candidates - 1) // self.args.batch_size + 1): 300 | st, ed = batch * self.args.batch_size, min(len(protein), (batch + 1) * self.args.batch_size) 301 | round_scores.append(self.surrogate(protein[st:ed])) 302 | round_scores = torch.cat(round_scores, dim=0) 303 | print(protein, round_scores) 304 | cur_mx_scores.append(torch.max(round_scores).item()) 305 | if round >= 1: 306 | cur_mx_scores[-1] = max(cur_mx_scores[-1], cur_mx_scores[-2]) 307 | cur_mean_scores.append(torch.mean(round_scores).item()) 308 | mx_scores.append(cur_mx_scores) 309 | mean_scores.append(cur_mean_scores) 310 | mx_scores = np.array(mx_scores) 311 | print("Max scores:") 312 | for i in range(self.args.num_rounds): 313 | print("Round ", i, " Fitness=", np.mean(mx_scores[:, i]), "\pm", np.var(mx_scores[:, i])) 314 | mean_scores = np.array(mean_scores) 315 | print("Avg scores:") 316 | for i in range(self.args.num_rounds): 317 | print("Round ", i, " Fitness=", np.mean(mean_scores[:, i]), "\pm", np.var(mx_scores[:, i])) 318 | with open(self.args.score_save_path, "a") as f: 319 | f.write(self.args.dataset_name + "\t") 320 | f.write("%.4lf" % (cur_fitness) + "\t") 321 | for i in range(self.args.num_rounds): 322 | f.write("%.4lf" % (np.mean(mx_scores[:, i])) + "\t" + "%.4lf" % (np.var(mx_scores[:, i])) + "\t") 323 | f.write("\n") -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/mnt/niezk_dair/anaconda3/envs/openbiomed/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "*** loading protein model...\n" 21 | ] 22 | }, 23 | { 24 | "name": "stderr", 25 | "output_type": "stream", 26 | "text": [ 27 | "Some weights of EsmForMutationDesign were not initialized from the model checkpoint at /data3/niezk/model/esm/esm2_t33_650M_UR50D and are newly initialized: ['esm.encoder.layer.32.crossattention_adapter.self.query.bias', 'esm.encoder.layer.32.crossattention_adapter.self.value.bias', 'esm.encoder.layer.32.crossattention_adapter.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.32.crossattention_adapter.self.key.weight', 'esm.encoder.layer.32.crossattention_adapter.self.key.bias', 'mutation_classifier.bias', 'esm.encoder.layer.32.crossattention_adapter.self.value.weight', 'esm.encoder.layer.32.crossattention_adapter.output.dense.weight', 'esm.encoder.layer.32.intermediate_adapter.dense.weight', 'esm.encoder.layer.32.intermediate_adapter.dense.bias', 'esm.encoder.layer.32.crossattention_adapter.LayerNorm.weight', 'esm.encoder.layer.32.output_adapter.dense.bias', 'esm.encoder.layer.32.crossattention_adapter.self.query.weight', 'mutation_classifier.weight', 'esm.encoder.layer.32.LayerNorm_adapter.bias', 'esm.encoder.layer.32.output_adapter.dense.weight', 'esm.encoder.layer.32.LayerNorm_adapter.weight', 'esm.encoder.layer.32.crossattention_adapter.LayerNorm.bias', 'esm.encoder.layer.32.crossattention_adapter.output.dense.bias']\n", 28 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 29 | ] 30 | }, 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "*** freezing protein model...\n", 36 | "*** loading llm tokenizer...\n" 37 | ] 38 | }, 39 | { 40 | "name": "stderr", 41 | "output_type": "stream", 42 | "text": [ 43 | "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. 
use identity normalization.\n" 44 | ] 45 | }, 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "*** loading llm from /data3/niezk/model/biomedgpt-lm...\n", 51 | "*** adding LoRA...\n", 52 | "trainable params: 0 || all params: 6,774,206,464 || trainable%: 0.0\n", 53 | "*** building delta encoder...\n", 54 | "*** model built successfully.\n" 55 | ] 56 | }, 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "MutaPLM(\n", 61 | " (protein_model): EsmForMutationDesign(\n", 62 | " (esm): EsmModel(\n", 63 | " (embeddings): EsmEmbeddings(\n", 64 | " (word_embeddings): Embedding(33, 1280, padding_idx=1)\n", 65 | " (dropout): Dropout(p=0.0, inplace=False)\n", 66 | " (position_embeddings): Embedding(1026, 1280, padding_idx=1)\n", 67 | " )\n", 68 | " (encoder): EsmEncoder(\n", 69 | " (layer): ModuleList(\n", 70 | " (0-31): 32 x EsmLayer(\n", 71 | " (attention): EsmAttention(\n", 72 | " (self): EsmSelfAttention(\n", 73 | " (query): Linear(in_features=1280, out_features=1280, bias=True)\n", 74 | " (key): Linear(in_features=1280, out_features=1280, bias=True)\n", 75 | " (value): Linear(in_features=1280, out_features=1280, bias=True)\n", 76 | " (dropout): Dropout(p=0.0, inplace=False)\n", 77 | " (rotary_embeddings): RotaryEmbedding()\n", 78 | " )\n", 79 | " (output): EsmSelfOutput(\n", 80 | " (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", 81 | " (dropout): Dropout(p=0.0, inplace=False)\n", 82 | " )\n", 83 | " (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 84 | " )\n", 85 | " (intermediate): EsmIntermediate(\n", 86 | " (dense): Linear(in_features=1280, out_features=5120, bias=True)\n", 87 | " )\n", 88 | " (output): EsmOutput(\n", 89 | " (dense): Linear(in_features=5120, out_features=1280, bias=True)\n", 90 | " (dropout): Dropout(p=0.0, inplace=False)\n", 91 | " )\n", 92 | " (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 93 | " )\n", 94 | " (32): EsmLayer(\n", 95 | " (attention): EsmAttention(\n", 96 | " (self): EsmSelfAttention(\n", 97 | " (query): Linear(in_features=1280, out_features=1280, bias=True)\n", 98 | " (key): Linear(in_features=1280, out_features=1280, bias=True)\n", 99 | " (value): Linear(in_features=1280, out_features=1280, bias=True)\n", 100 | " (dropout): Dropout(p=0.0, inplace=False)\n", 101 | " (rotary_embeddings): RotaryEmbedding()\n", 102 | " )\n", 103 | " (output): EsmSelfOutput(\n", 104 | " (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", 105 | " (dropout): Dropout(p=0.0, inplace=False)\n", 106 | " )\n", 107 | " (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 108 | " )\n", 109 | " (crossattention_adapter): EsmAttention(\n", 110 | " (self): EsmSelfAttention(\n", 111 | " (query): Linear(in_features=1280, out_features=1280, bias=True)\n", 112 | " (key): Linear(in_features=1280, out_features=1280, bias=True)\n", 113 | " (value): Linear(in_features=1280, out_features=1280, bias=True)\n", 114 | " (dropout): Dropout(p=0.0, inplace=False)\n", 115 | " (rotary_embeddings): RotaryEmbedding()\n", 116 | " )\n", 117 | " (output): EsmSelfOutput(\n", 118 | " (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", 119 | " (dropout): Dropout(p=0.0, inplace=False)\n", 120 | " )\n", 121 | " (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 122 | " )\n", 123 | " (intermediate_adapter): EsmIntermediate(\n", 124 | " (dense): Linear(in_features=1280, out_features=640, bias=True)\n", 125 | " )\n", 126 | " (output_adapter): EsmOutput(\n", 
127 | " (dense): Linear(in_features=640, out_features=1280, bias=True)\n", 128 | " (dropout): Dropout(p=0.0, inplace=False)\n", 129 | " )\n", 130 | " (LayerNorm_adapter): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 131 | " (intermediate): EsmIntermediate(\n", 132 | " (dense): Linear(in_features=1280, out_features=5120, bias=True)\n", 133 | " )\n", 134 | " (output): EsmOutput(\n", 135 | " (dense): Linear(in_features=5120, out_features=1280, bias=True)\n", 136 | " (dropout): Dropout(p=0.0, inplace=False)\n", 137 | " )\n", 138 | " (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 139 | " )\n", 140 | " )\n", 141 | " (emb_layer_norm_after): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 142 | " )\n", 143 | " (contact_head): EsmContactPredictionHead(\n", 144 | " (regression): Linear(in_features=660, out_features=1, bias=True)\n", 145 | " (activation): Sigmoid()\n", 146 | " )\n", 147 | " )\n", 148 | " (dropout): Dropout(p=0.0, inplace=False)\n", 149 | " (mutation_classifier): Linear(in_features=1280, out_features=2, bias=True)\n", 150 | " (lm_head): EsmLMHead(\n", 151 | " (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", 152 | " (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", 153 | " (decoder): Linear(in_features=1280, out_features=33, bias=False)\n", 154 | " )\n", 155 | " )\n", 156 | " (llm): PeftModel(\n", 157 | " (base_model): LoraModel(\n", 158 | " (model): LlamaForCausalLM(\n", 159 | " (model): LlamaModel(\n", 160 | " (embed_tokens): Embedding(32001, 4096)\n", 161 | " (layers): ModuleList(\n", 162 | " (0-31): 32 x LlamaDecoderLayer(\n", 163 | " (self_attn): LlamaAttention(\n", 164 | " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", 165 | " (k_proj): lora.Linear(\n", 166 | " (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n", 167 | " (lora_dropout): ModuleDict(\n", 168 | " (default): Dropout(p=0.05, inplace=False)\n", 169 | " )\n", 170 | " (lora_A): ModuleDict(\n", 171 | " (default): Linear(in_features=4096, out_features=16, bias=False)\n", 172 | " )\n", 173 | " (lora_B): ModuleDict(\n", 174 | " (default): Linear(in_features=16, out_features=4096, bias=False)\n", 175 | " )\n", 176 | " (lora_embedding_A): ParameterDict()\n", 177 | " (lora_embedding_B): ParameterDict()\n", 178 | " )\n", 179 | " (v_proj): lora.Linear(\n", 180 | " (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n", 181 | " (lora_dropout): ModuleDict(\n", 182 | " (default): Dropout(p=0.05, inplace=False)\n", 183 | " )\n", 184 | " (lora_A): ModuleDict(\n", 185 | " (default): Linear(in_features=4096, out_features=16, bias=False)\n", 186 | " )\n", 187 | " (lora_B): ModuleDict(\n", 188 | " (default): Linear(in_features=16, out_features=4096, bias=False)\n", 189 | " )\n", 190 | " (lora_embedding_A): ParameterDict()\n", 191 | " (lora_embedding_B): ParameterDict()\n", 192 | " )\n", 193 | " (o_proj): lora.Linear(\n", 194 | " (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n", 195 | " (lora_dropout): ModuleDict(\n", 196 | " (default): Dropout(p=0.05, inplace=False)\n", 197 | " )\n", 198 | " (lora_A): ModuleDict(\n", 199 | " (default): Linear(in_features=4096, out_features=16, bias=False)\n", 200 | " )\n", 201 | " (lora_B): ModuleDict(\n", 202 | " (default): Linear(in_features=16, out_features=4096, bias=False)\n", 203 | " )\n", 204 | " (lora_embedding_A): ParameterDict()\n", 205 | " (lora_embedding_B): ParameterDict()\n", 206 | " )\n", 207 | " (rotary_emb): 
LlamaRotaryEmbedding()\n", 208 | " )\n", 209 | " (mlp): LlamaMLP(\n", 210 | " (gate_proj): lora.Linear(\n", 211 | " (base_layer): Linear(in_features=4096, out_features=11008, bias=False)\n", 212 | " (lora_dropout): ModuleDict(\n", 213 | " (default): Dropout(p=0.05, inplace=False)\n", 214 | " )\n", 215 | " (lora_A): ModuleDict(\n", 216 | " (default): Linear(in_features=4096, out_features=16, bias=False)\n", 217 | " )\n", 218 | " (lora_B): ModuleDict(\n", 219 | " (default): Linear(in_features=16, out_features=11008, bias=False)\n", 220 | " )\n", 221 | " (lora_embedding_A): ParameterDict()\n", 222 | " (lora_embedding_B): ParameterDict()\n", 223 | " )\n", 224 | " (up_proj): lora.Linear(\n", 225 | " (base_layer): Linear(in_features=4096, out_features=11008, bias=False)\n", 226 | " (lora_dropout): ModuleDict(\n", 227 | " (default): Dropout(p=0.05, inplace=False)\n", 228 | " )\n", 229 | " (lora_A): ModuleDict(\n", 230 | " (default): Linear(in_features=4096, out_features=16, bias=False)\n", 231 | " )\n", 232 | " (lora_B): ModuleDict(\n", 233 | " (default): Linear(in_features=16, out_features=11008, bias=False)\n", 234 | " )\n", 235 | " (lora_embedding_A): ParameterDict()\n", 236 | " (lora_embedding_B): ParameterDict()\n", 237 | " )\n", 238 | " (down_proj): lora.Linear(\n", 239 | " (base_layer): Linear(in_features=11008, out_features=4096, bias=False)\n", 240 | " (lora_dropout): ModuleDict(\n", 241 | " (default): Dropout(p=0.05, inplace=False)\n", 242 | " )\n", 243 | " (lora_A): ModuleDict(\n", 244 | " (default): Linear(in_features=11008, out_features=16, bias=False)\n", 245 | " )\n", 246 | " (lora_B): ModuleDict(\n", 247 | " (default): Linear(in_features=16, out_features=4096, bias=False)\n", 248 | " )\n", 249 | " (lora_embedding_A): ParameterDict()\n", 250 | " (lora_embedding_B): ParameterDict()\n", 251 | " )\n", 252 | " (act_fn): SiLU()\n", 253 | " )\n", 254 | " (input_layernorm): LlamaRMSNorm()\n", 255 | " (post_attention_layernorm): LlamaRMSNorm()\n", 256 | " )\n", 257 | " )\n", 258 | " (norm): LlamaRMSNorm()\n", 259 | " )\n", 260 | " (lm_head): Linear(in_features=4096, out_features=32001, bias=False)\n", 261 | " )\n", 262 | " )\n", 263 | " )\n", 264 | " (pooler_protein1): MultiheadAttention(\n", 265 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)\n", 266 | " )\n", 267 | " (pooler_protein2): MultiheadAttention(\n", 268 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)\n", 269 | " )\n", 270 | " (proj_protein1): Linear(in_features=1280, out_features=4096, bias=True)\n", 271 | " (proj_protein2): Linear(in_features=1280, out_features=4096, bias=True)\n", 272 | " (proj_text): Linear(in_features=4096, out_features=1280, bias=True)\n", 273 | ")" 274 | ] 275 | }, 276 | "execution_count": 1, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | } 280 | ], 281 | "source": [ 282 | "import torch\n", 283 | "import yaml\n", 284 | "from torch.cuda.amp import autocast\n", 285 | "from model.mutaplm import MutaPLM\n", 286 | "\n", 287 | "# load model\n", 288 | "device = torch.device(\"cuda:2\")\n", 289 | "model_config_path = \"./configs/mutaplm_inference.yaml\"\n", 290 | "model_cfg = yaml.load(open(model_config_path, \"r\"), Loader=yaml.Loader)\n", 291 | "model_cfg[\"device\"] = device\n", 292 | "model = MutaPLM(**model_cfg).to(device)\n", 293 | "new_ckpt = torch.load(open(\"./ckpts/mutaplm.pth\", \"rb\"), map_location=\"cpu\")[\"model\"]\n", 294 | "model.load_state_dict(new_ckpt, strict=False)\n", 295 | 
"model.eval()" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 4, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "Predicted function: ase that can recognize specific palindromic sequences and target them to the proteasome for degradation. Can recognize the palindromic IN box and upstream of a 5'-AAA-3' motif in different proteins such as cyclins B1/2 (CCNB1 and CCNB2), histone H4 (H4), and histone H2B (H2B). Exhibits an endogenous activity in HEK293 cells (human embryonic kidney cells) and can induce the degradation of CCNB1 in these cells. Can drive the degradation of CCNB1 in HEK293 cells even if CCNB1 does not contain the IN box in its C-terminus. Also involved in the cell cycle regulation of G1/S phase by controlling KIP1/CHFR-mediated destabilization of CDK4 and phosphorylation of histone H3, histone H4 and histone H2B. Involved in DNA damage response by triggering protein degradation through the 26S proteasome in a p53/TP53-dependent\n", 308 | "Predicted effect: Decrease of IN box-dependent E3 ubiquitin ligase activity.\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "# Explanation: given wildtype protein and mutation site, predict its original function and mutational effect.\n", 314 | "wildtype_protein = \"MASDAAAEPSSGVTHPPRYVIGYALAPKKQQSFIQPSLVAQAASRGMDLVPVDASQPLAEQGPFHLLIHALYGDDWRAQLVAFAARHPAVPIVDPPHAIDRLHNRISMLQVVSELDHAADQDSTFGIPSQVVVYDAAALADFGLLAALRFPLIAKPLVADGTAKSHKMSLVYHREGLGKLRPPLVLQEFVNHGGVIFKVYVVGGHVTCVKRRSLPDVSPEDDASAQGSVSFSQVSNLPTERTAEEYYGEKSLEDAVVPPAAFINQIAGGLRRALGLQLFNFDMIRDVRAGDRYLVIDINYFPGYAKMPGYETVLTDFFWEMVHKDGVGNQQEEKGANHVVVK\"\n", 315 | "site = \"A70K\"\n", 316 | "mutated_protein = wildtype_protein[:int(site[1:-1])-1] + site[-1] + wildtype_protein[int(site[1:-1]):]\n", 317 | "muta_prompt = f\"Next is a feature of the mutation {site[0]} to {site[-1]} at position {site[1:-1]}. Please generate a brief summary text to describe it.\"\n", 318 | "\n", 319 | "with torch.no_grad():\n", 320 | " with autocast(dtype=torch.bfloat16):\n", 321 | " pred_func, pred_mut = model.generate([wildtype_protein], [mutated_protein], [muta_prompt])\n", 322 | "\n", 323 | "print(\"Predicted function:\", pred_func[0])\n", 324 | "print(\"Predicted effect:\", pred_mut[0])" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 5, 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "name": "stdout", 334 | "output_type": "stream", 335 | "text": [ 336 | "mutated position: 70\n", 337 | "new amino acid: K\n" 338 | ] 339 | } 340 | ], 341 | "source": [ 342 | "# Engineering: given wildtype protein and mutational effect, predict mutated position and new amino acid.\n", 343 | "wildtype_protein = \"MASDAAAEPSSGVTHPPRYVIGYALAPKKQQSFIQPSLVAQAASRGMDLVPVDASQPLAEQGPFHLLIHALYGDDWRAQLVAFAARHPAVPIVDPPHAIDRLHNRISMLQVVSELDHAADQDSTFGIPSQVVVYDAAALADFGLLAALRFPLIAKPLVADGTAKSHKMSLVYHREGLGKLRPPLVLQEFVNHGGVIFKVYVVGGHVTCVKRRSLPDVSPEDDASAQGSVSFSQVSNLPTERTAEEYYGEKSLEDAVVPPAAFINQIAGGLRRALGLQLFNFDMIRDVRAGDRYLVIDINYFPGYAKMPGYETVLTDFFWEMVHKDGVGNQQEEKGANHVVVK\"\n", 344 | "effect_text = \"Strongly enhanced InsP6 kinase activity. 
The mutation in the ITPK protein causes a change in its catalytic activity.\"\n", 345 | "muta_prompt = \"What is the mutated position and new amino acid?\"\n", 346 | "\n", 347 | "with torch.no_grad():\n", 348 | " with autocast(dtype=torch.bfloat16):\n", 349 | " preds = model.lm_design([wildtype_protein], [effect_text], muta_prompt=[muta_prompt])\n", 350 | "\n", 351 | "top50 = preds[0].flatten().topk(50).indices\n", 352 | "top50_pos = top50 // len(model.protein_tokenizer)\n", 353 | "top50_aa = top50 % len(model.protein_tokenizer)\n", 354 | "top50_aa = model.protein_tokenizer.batch_decode(top50_aa)\n", 355 | "print(\"mutated position:\", top50_pos[0].item())\n", 356 | "print(\"new amino acid:\", top50_aa[0])" 357 | ] 358 | } 359 | ], 360 | "metadata": { 361 | "kernelspec": { 362 | "display_name": "openbiomed", 363 | "language": "python", 364 | "name": "python3" 365 | }, 366 | "language_info": { 367 | "codemirror_mode": { 368 | "name": "ipython", 369 | "version": 3 370 | }, 371 | "file_extension": ".py", 372 | "mimetype": "text/x-python", 373 | "name": "python", 374 | "nbconvert_exporter": "python", 375 | "pygments_lexer": "ipython3", 376 | "version": "3.9.17" 377 | } 378 | }, 379 | "nbformat": 4, 380 | "nbformat_minor": 2 381 | } 382 | -------------------------------------------------------------------------------- /logs/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/MutaPLM/495815b2069419d19d9449a59070d0fb24b596d1/logs/.placeholder -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import spearmanr 3 | 4 | def loss(y_true, y_pred): 5 | return np.mean(y_pred) 6 | 7 | def spearman(y_true, y_pred): 8 | return spearmanr(y_true, y_pred).statistic 9 | 10 | name2metric = { 11 | "spearmanr": spearman, 12 | "loss": loss, 13 | } -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.vanllina_esm import VanllinaEsm, RandomModel 2 | from model.mutaplm import MutaPLM 3 | 4 | model_name2cls = { 5 | "esm": VanllinaEsm, 6 | "mutaplm": MutaPLM, 7 | "random": RandomModel 8 | } -------------------------------------------------------------------------------- /model/esm_landscape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import EsmModel, EsmTokenizer 4 | from sequence_models.structure import Attention1d 5 | 6 | class Decoder(nn.Module): 7 | def __init__(self, input_dim=1280, hidden_dim=512): 8 | super().__init__() 9 | self.dense_1 = nn.Linear(input_dim, hidden_dim) 10 | self.dense_2 = nn.Linear(hidden_dim, hidden_dim) 11 | self.attention1d = Attention1d(in_dim=hidden_dim) 12 | self.dense_3 = nn.Linear(hidden_dim, hidden_dim) 13 | self.dense_4 = nn.Linear(hidden_dim, 1) 14 | 15 | def forward(self, x): 16 | x = torch.relu(self.dense_1(x)) 17 | x = torch.relu(self.dense_2(x)) 18 | x = self.attention1d(x) 19 | x = torch.relu(self.dense_3(x)) 20 | x = self.dense_4(x) 21 | return x 22 | 23 | 24 | class EsmForLandscapeRegression(nn.Module): 25 | def __init__(self, esm_path, decoder_ckpt, device): 26 | super().__init__() 27 | self.esm = EsmModel.from_pretrained(esm_path) 28 | self.esm_tokenizer = 
EsmTokenizer.from_pretrained(esm_path) 29 | self.decoder = Decoder() 30 | ckpt = torch.load(decoder_ckpt) 31 | self.decoder.load_state_dict(ckpt) 32 | self.device = device 33 | 34 | def forward(self, protein): 35 | protein = self.esm_tokenizer( 36 | list(protein), 37 | add_special_tokens=True, 38 | truncation=True, 39 | padding=True, 40 | max_length=1024, 41 | return_tensors='pt' 42 | ).to(self.device) 43 | h = self.esm(**protein, return_dict=True).last_hidden_state 44 | return self.decoder(h).squeeze() 45 | 46 | def predict_fitness(self, protein, *kwargs): 47 | return self.forward(protein) 48 | 49 | if __name__ == "__main__": 50 | import os 51 | import sys 52 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 53 | from tqdm import tqdm 54 | from dataset.fitness_dataset import FitnessDataset 55 | from torch.utils.data import DataLoader 56 | from sklearn.metrics import r2_score 57 | from scipy.stats import spearmanr 58 | ckpt_path = "./ckpts/landscape_ckpts/landscape_params/esm1b_landscape/" 59 | data_path = "./data/fitness/" 60 | device = torch.device("cuda", 0) 61 | for dataset_name in ["AAV", "AMIE", "avGFP", "E4B", "LGK", "Pab1", "TEM", "UBE2I"]: 62 | print(dataset_name) 63 | dataset = FitnessDataset(data_path + dataset_name, "valid") 64 | dataloader = DataLoader(dataset, batch_size=64, shuffle=False) 65 | model = EsmForLandscapeRegression("./ckpts/protein_ckpts/esm1b", ckpt_path + dataset_name + "/decoder.pt", device).to(device) 66 | preds, gts = [], [] 67 | model.eval() 68 | with torch.no_grad(): 69 | for i, (seq, score) in enumerate(tqdm(dataloader)): 70 | preds += model(seq).tolist() 71 | gts += score.tolist() 72 | if i == 0: 73 | print("Preds:", model(seq).tolist()) 74 | print("gts: ", score.tolist()) 75 | print("R2:", r2_score(gts, preds)) 76 | print("Spearman: ", spearmanr(gts, preds)) -------------------------------------------------------------------------------- /model/modeling_esm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ PyTorch ESM model.""" 16 | 17 | import copy 18 | import math 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import torch 22 | import torch.utils.checkpoint 23 | from torch import nn 24 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 25 | import torch.nn.functional as F 26 | 27 | from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward 28 | from transformers.modeling_outputs import ( 29 | BaseModelOutputWithPastAndCrossAttentions, 30 | BaseModelOutputWithPoolingAndCrossAttentions, 31 | MaskedLMOutput, 32 | SequenceClassifierOutput, 33 | TokenClassifierOutput, 34 | ModelOutput, 35 | ) 36 | from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer 37 | from transformers.utils import logging 38 | from transformers.models.esm.configuration_esm import EsmConfig 39 | 40 | from model.esm_landscape import Decoder 41 | 42 | logger = logging.get_logger(__name__) 43 | 44 | _CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D" 45 | _CONFIG_FOR_DOC = "EsmConfig" 46 | 47 | ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [ 48 | "facebook/esm2_t6_8M_UR50D", 49 | "facebook/esm2_t12_35M_UR50D", 50 | # This is not a complete list of all ESM models! 51 | # See all ESM models at https://huggingface.co/models?filter=esm 52 | ] 53 | 54 | 55 | def rotate_half(x): 56 | x1, x2 = x.chunk(2, dim=-1) 57 | return torch.cat((-x2, x1), dim=-1) 58 | 59 | 60 | def apply_rotary_pos_emb(x, cos, sin): 61 | cos = cos[:, :, : x.shape[-2], :] 62 | sin = sin[:, :, : x.shape[-2], :] 63 | 64 | return (x * cos) + (rotate_half(x) * sin) 65 | 66 | 67 | def gelu(x): 68 | """ 69 | This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results. 70 | """ 71 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 72 | 73 | 74 | def symmetrize(x): 75 | "Make layer symmetric in final two dimensions, used for contact prediction." 76 | return x + x.transpose(-1, -2) 77 | 78 | 79 | def average_product_correct(x): 80 | "Perform average product correct, used for contact prediction." 81 | a1 = x.sum(-1, keepdims=True) 82 | a2 = x.sum(-2, keepdims=True) 83 | a12 = x.sum((-1, -2), keepdims=True) 84 | 85 | avg = a1 * a2 86 | avg.div_(a12) # in-place to reduce memory 87 | normalized = x - avg 88 | return normalized 89 | 90 | 91 | class RotaryEmbedding(torch.nn.Module): 92 | """ 93 | Rotary position embeddings based on those in 94 | [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation 95 | matrices which depend on their relative positions. 
96 | """ 97 | 98 | def __init__(self, dim: int): 99 | super().__init__() 100 | # Generate and save the inverse frequency buffer (non trainable) 101 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 102 | inv_freq = inv_freq 103 | self.register_buffer("inv_freq", inv_freq) 104 | 105 | self._seq_len_cached = None 106 | self._cos_cached = None 107 | self._sin_cached = None 108 | 109 | def _update_cos_sin_tables(self, x, seq_dimension=2): 110 | seq_len = x.shape[seq_dimension] 111 | 112 | # Reset the tables if the sequence length has changed, 113 | # or if we're on a new device (possibly due to tracing for instance) 114 | if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: 115 | self._seq_len_cached = seq_len 116 | t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) 117 | freqs = torch.outer(t, self.inv_freq) 118 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 119 | 120 | self._cos_cached = emb.cos()[None, None, :, :] 121 | self._sin_cached = emb.sin()[None, None, :, :] 122 | 123 | return self._cos_cached, self._sin_cached 124 | 125 | def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 126 | self._cos_cached, self._sin_cached = self._update_cos_sin_tables(q, seq_dimension=-2) 127 | rotated_q = apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached) 128 | 129 | self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) 130 | rotated_k = apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached) 131 | 132 | return (rotated_q, rotated_k) 133 | 134 | 135 | class EsmContactPredictionHead(nn.Module): 136 | """Performs symmetrization, apc, and computes a logistic regression on the output features""" 137 | 138 | def __init__( 139 | self, 140 | in_features: int, 141 | bias=True, 142 | eos_idx: int = 2, 143 | ): 144 | super().__init__() 145 | self.in_features = in_features 146 | self.eos_idx = eos_idx 147 | self.regression = nn.Linear(in_features, 1, bias) 148 | self.activation = nn.Sigmoid() 149 | 150 | def forward(self, tokens, attentions): 151 | # remove eos token attentions 152 | eos_mask = tokens.ne(self.eos_idx).to(attentions) 153 | eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) 154 | attentions = attentions * eos_mask[:, None, None, :, :] 155 | attentions = attentions[..., :-1, :-1] 156 | # remove cls token attentions 157 | attentions = attentions[..., 1:, 1:] 158 | batch_size, layers, heads, seqlen, _ = attentions.size() 159 | attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) 160 | 161 | # features: batch x channels x tokens x tokens (symmetric) 162 | attentions = attentions.to( 163 | self.regression.weight.device 164 | ) # attentions always float32, may need to convert to float16 165 | attentions = average_product_correct(symmetrize(attentions)) 166 | attentions = attentions.permute(0, 2, 3, 1) 167 | return self.activation(self.regression(attentions).squeeze(3)) 168 | 169 | 170 | class EsmEmbeddings(nn.Module): 171 | """ 172 | Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
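    The tweak: position ids come from `create_position_ids_from_input_ids` (defined at the bottom
    of this file), so non-padding tokens are numbered starting at `padding_idx + 1` while padding
    tokens keep `padding_idx`. A small illustration (token ids are placeholders; the ESM vocabulary
    uses pad_token_id = 1):

        >>> import torch
        >>> ids = torch.tensor([[0, 5, 6, 2, 1, 1]])   # <cls>, two residues, <eos>, 2 x <pad>
        >>> create_position_ids_from_input_ids(ids, padding_idx=1)
        tensor([[2, 3, 4, 5, 1, 1]])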
173 | """ 174 | 175 | def __init__(self, config): 176 | super().__init__() 177 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 178 | 179 | if config.emb_layer_norm_before: 180 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 181 | else: 182 | self.layer_norm = None 183 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 184 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 185 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 186 | self.register_buffer( 187 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False 188 | ) 189 | 190 | self.padding_idx = config.pad_token_id 191 | self.position_embeddings = nn.Embedding( 192 | config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx 193 | ) 194 | self.token_dropout = config.token_dropout 195 | self.mask_token_id = config.mask_token_id 196 | 197 | def forward( 198 | self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 199 | ): 200 | if position_ids is None: 201 | if input_ids is not None: 202 | # Create the position ids from the input token ids. Any padded tokens remain padded. 203 | position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) 204 | else: 205 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) 206 | 207 | if inputs_embeds is None: 208 | inputs_embeds = self.word_embeddings(input_ids) 209 | 210 | # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an 211 | # embedding_scale factor here. 212 | embeddings = inputs_embeds 213 | 214 | # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout 215 | # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, 216 | # masked tokens are treated as if they were selected for input dropout and zeroed out. 217 | # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by 218 | # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). 219 | # This is analogous to the way that dropout layers scale down outputs during evaluation when not 220 | # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). 221 | if self.token_dropout: 222 | embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) 223 | mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs 224 | src_lengths = attention_mask.sum(-1) 225 | mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths 226 | embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to( 227 | embeddings.dtype 228 | ) 229 | 230 | if self.position_embedding_type == "absolute": 231 | position_embeddings = self.position_embeddings(position_ids) 232 | embeddings = embeddings + position_embeddings 233 | 234 | if self.layer_norm is not None: 235 | embeddings = self.layer_norm(embeddings) 236 | if attention_mask is not None: 237 | embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) 238 | # Matt: I think this line was copied incorrectly from BERT, disabling it for now. 
239 | # embeddings = self.dropout(embeddings) 240 | return embeddings 241 | 242 | def create_position_ids_from_inputs_embeds(self, inputs_embeds): 243 | """ 244 | We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 245 | 246 | Args: 247 | inputs_embeds: torch.Tensor 248 | 249 | Returns: torch.Tensor 250 | """ 251 | input_shape = inputs_embeds.size()[:-1] 252 | sequence_length = input_shape[1] 253 | 254 | position_ids = torch.arange( 255 | self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device 256 | ) 257 | return position_ids.unsqueeze(0).expand(input_shape) 258 | 259 | 260 | class EsmSelfAttention(nn.Module): 261 | def __init__(self, config, position_embedding_type=None): 262 | super().__init__() 263 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 264 | raise ValueError( 265 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 266 | f"heads ({config.num_attention_heads})" 267 | ) 268 | 269 | self.num_attention_heads = config.num_attention_heads 270 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 271 | self.all_head_size = self.num_attention_heads * self.attention_head_size 272 | 273 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 274 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 275 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 276 | 277 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 278 | self.position_embedding_type = position_embedding_type or getattr( 279 | config, "position_embedding_type", "absolute" 280 | ) 281 | self.rotary_embeddings = None 282 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 283 | self.max_position_embeddings = config.max_position_embeddings 284 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 285 | elif self.position_embedding_type == "rotary": 286 | self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) 287 | 288 | self.is_decoder = config.is_decoder 289 | 290 | def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: 291 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 292 | x = x.view(new_x_shape) 293 | return x.permute(0, 2, 1, 3) 294 | 295 | def forward( 296 | self, 297 | hidden_states: torch.Tensor, 298 | attention_mask: Optional[torch.FloatTensor] = None, 299 | head_mask: Optional[torch.FloatTensor] = None, 300 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 301 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 302 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 303 | output_attentions: Optional[bool] = False, 304 | ) -> Tuple[torch.Tensor]: 305 | mixed_query_layer = self.query(hidden_states) 306 | 307 | # If this is instantiated as a cross-attention module, the keys 308 | # and values come from an encoder; the attention mask needs to be 309 | # such that the encoder's padding tokens are not attended to. 
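        # The four branches below choose where keys/values come from:
        #   1) cross-attention with a cache: reuse the encoder K/V stored in past_key_value;
        #   2) fresh cross-attention: project encoder_hidden_states into K/V;
        #   3) cached self-attention (incremental decoding): project the new tokens and
        #      concatenate them with the cached K/V along the sequence dimension;
        #   4) plain self-attention: project hidden_states directly.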
310 | is_cross_attention = encoder_hidden_states is not None 311 | 312 | if is_cross_attention and past_key_value is not None: 313 | # reuse k,v, cross_attentions 314 | key_layer = past_key_value[0] 315 | value_layer = past_key_value[1] 316 | attention_mask = encoder_attention_mask 317 | elif is_cross_attention: 318 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 319 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 320 | attention_mask = encoder_attention_mask 321 | elif past_key_value is not None: 322 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 323 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 324 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 325 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 326 | else: 327 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 328 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 329 | 330 | query_layer = self.transpose_for_scores(mixed_query_layer) 331 | 332 | # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). 333 | # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, 334 | # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original 335 | # ESM code and fix rotary embeddings. 336 | query_layer = query_layer * self.attention_head_size**-0.5 337 | 338 | if self.is_decoder: 339 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 340 | # Further calls to cross_attention layer can then reuse all cross-attention 341 | # key/value_states (first "if" case) 342 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 343 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 344 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 345 | # if encoder bi-directional self-attention `past_key_value` is always `None` 346 | past_key_value = (key_layer, value_layer) 347 | 348 | if self.position_embedding_type == "rotary": 349 | query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) 350 | 351 | # Take the dot product between "query" and "key" to get the raw attention scores. 
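        # The resulting attention_scores have shape (batch, num_heads, query_len, key_len); the
        # usual 1/sqrt(head_dim) factor was already folded into query_layer above, so no extra
        # scaling is applied to the logits here.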
352 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 353 | 354 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 355 | seq_length = hidden_states.size()[1] 356 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 357 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 358 | distance = position_ids_l - position_ids_r 359 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 360 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 361 | 362 | if self.position_embedding_type == "relative_key": 363 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 364 | attention_scores = attention_scores + relative_position_scores 365 | elif self.position_embedding_type == "relative_key_query": 366 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 367 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 368 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 369 | 370 | if attention_mask is not None: 371 | # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) 372 | attention_scores = attention_scores + attention_mask 373 | 374 | # Normalize the attention scores to probabilities. 375 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 376 | 377 | # This is actually dropping out entire tokens to attend to, which might 378 | # seem a bit unusual, but is taken from the original Transformer paper. 
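        # During training, dropout zeroes individual attend-to weights and rescales the survivors
        # by 1/(1 - p), so rows of attention_probs no longer sum exactly to 1.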
379 | attention_probs = self.dropout(attention_probs) 380 | 381 | # Mask heads if we want to 382 | if head_mask is not None: 383 | attention_probs = attention_probs * head_mask 384 | 385 | context_layer = torch.matmul(attention_probs, value_layer) 386 | 387 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 388 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 389 | context_layer = context_layer.view(new_context_layer_shape) 390 | 391 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 392 | 393 | if self.is_decoder: 394 | outputs = outputs + (past_key_value,) 395 | return outputs 396 | 397 | 398 | class EsmSelfOutput(nn.Module): 399 | def __init__(self, config): 400 | super().__init__() 401 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 402 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 403 | 404 | def forward(self, hidden_states, input_tensor): 405 | hidden_states = self.dense(hidden_states) 406 | hidden_states = self.dropout(hidden_states) 407 | hidden_states = hidden_states + input_tensor 408 | return hidden_states 409 | 410 | 411 | class EsmAttention(nn.Module): 412 | def __init__(self, config): 413 | super().__init__() 414 | self.self = EsmSelfAttention(config) 415 | self.output = EsmSelfOutput(config) 416 | self.pruned_heads = set() 417 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 418 | 419 | def prune_heads(self, heads): 420 | if len(heads) == 0: 421 | return 422 | heads, index = find_pruneable_heads_and_indices( 423 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 424 | ) 425 | 426 | # Prune linear layers 427 | self.self.query = prune_linear_layer(self.self.query, index) 428 | self.self.key = prune_linear_layer(self.self.key, index) 429 | self.self.value = prune_linear_layer(self.self.value, index) 430 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 431 | 432 | # Update hyper params and store pruned heads 433 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 434 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 435 | self.pruned_heads = self.pruned_heads.union(heads) 436 | 437 | def forward( 438 | self, 439 | hidden_states, 440 | attention_mask=None, 441 | head_mask=None, 442 | encoder_hidden_states=None, 443 | encoder_attention_mask=None, 444 | past_key_value=None, 445 | output_attentions=False, 446 | ): 447 | hidden_states_ln = self.LayerNorm(hidden_states) 448 | self_outputs = self.self( 449 | hidden_states_ln, 450 | attention_mask, 451 | head_mask, 452 | encoder_hidden_states, 453 | encoder_attention_mask, 454 | past_key_value, 455 | output_attentions, 456 | ) 457 | attention_output = self.output(self_outputs[0], hidden_states) 458 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 459 | return outputs 460 | 461 | 462 | class EsmIntermediate(nn.Module): 463 | def __init__(self, config): 464 | super().__init__() 465 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 466 | 467 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 468 | hidden_states = self.dense(hidden_states) 469 | hidden_states = gelu(hidden_states) 470 | return hidden_states 471 | 472 | 473 | class EsmOutput(nn.Module): 474 | def __init__(self, config): 475 | super().__init__() 476 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 477 | self.dropout = 
nn.Dropout(config.hidden_dropout_prob) 478 | 479 | def forward(self, hidden_states, input_tensor): 480 | hidden_states = self.dense(hidden_states) 481 | hidden_states = self.dropout(hidden_states) 482 | hidden_states = hidden_states + input_tensor 483 | return hidden_states 484 | 485 | 486 | class EsmLayer(nn.Module): 487 | def __init__(self, config, layer_id): 488 | super().__init__() 489 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 490 | self.seq_len_dim = 1 491 | self.attention = EsmAttention(config) 492 | self.is_decoder = config.is_decoder 493 | self.add_cross_attention = config.add_cross_attention and (layer_id + 1) % config.adapter_freq == 0 494 | if self.add_cross_attention: 495 | """ 496 | fuck this assertion 497 | if not self.is_decoder: 498 | raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") 499 | """ 500 | adapter_config = copy.deepcopy(config) 501 | adapter_config.intermediate_size = adapter_config.hidden_size // 2 502 | self.crossattention_adapter = EsmAttention(adapter_config) 503 | self.intermediate_adapter = EsmIntermediate(adapter_config) 504 | self.output_adapter = EsmOutput(adapter_config) 505 | self.LayerNorm_adapter = nn.LayerNorm(adapter_config.hidden_size, adapter_config.layer_norm_eps) 506 | self.intermediate = EsmIntermediate(config) 507 | self.output = EsmOutput(config) 508 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 509 | 510 | def forward( 511 | self, 512 | hidden_states, 513 | attention_mask=None, 514 | head_mask=None, 515 | encoder_hidden_states=None, 516 | encoder_attention_mask=None, 517 | past_key_value=None, 518 | output_attentions=False, 519 | ): 520 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 521 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 522 | self_attention_outputs = self.attention( 523 | hidden_states, 524 | attention_mask, 525 | head_mask, 526 | output_attentions=output_attentions, 527 | past_key_value=self_attn_past_key_value, 528 | ) 529 | attention_output = self_attention_outputs[0] 530 | 531 | # if decoder, the last output is tuple of self-attn cache 532 | if self.is_decoder: 533 | outputs = self_attention_outputs[1:-1] 534 | present_key_value = self_attention_outputs[-1] 535 | else: 536 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 537 | 538 | layer_output = self.feed_forward_chunk(attention_output) 539 | 540 | cross_attn_present_key_value = None 541 | if self.add_cross_attention and encoder_hidden_states is not None: 542 | if not hasattr(self, "crossattention_adapter"): 543 | raise AttributeError( 544 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated" 545 | " with cross-attention layers by setting `config.add_cross_attention=True`" 546 | ) 547 | 548 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 549 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 550 | cross_attention_outputs = self.crossattention_adapter( 551 | layer_output, 552 | attention_mask, 553 | head_mask, 554 | encoder_hidden_states, 555 | encoder_attention_mask, 556 | cross_attn_past_key_value, 557 | output_attentions, 558 | ) 559 | attention_output = cross_attention_outputs[0] 560 | layer_output = self.feed_forward_chunk_adapter(attention_output) 561 | 562 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention 
weights 563 | 564 | # add cross-attn cache to positions 3,4 of present_key_value tuple 565 | # cross_attn_present_key_value = cross_attention_outputs[-1] 566 | # present_key_value = present_key_value + cross_attn_present_key_value 567 | 568 | outputs = (layer_output,) + outputs 569 | 570 | # if decoder, return the attn key/values as the last output 571 | if self.is_decoder: 572 | outputs = outputs + (present_key_value,) 573 | return outputs 574 | 575 | def feed_forward_chunk(self, attention_output): 576 | attention_output_ln = self.LayerNorm(attention_output) 577 | intermediate_output = self.intermediate(attention_output_ln) 578 | layer_output = self.output(intermediate_output, attention_output) 579 | return layer_output 580 | 581 | def feed_forward_chunk_adapter(self, attention_output): 582 | attention_output_ln = self.LayerNorm_adapter(attention_output) 583 | intermediate_output = self.intermediate_adapter(attention_output_ln) 584 | layer_output = self.output_adapter(intermediate_output, attention_output) 585 | return layer_output 586 | 587 | 588 | class EsmEncoder(nn.Module): 589 | def __init__(self, config): 590 | super().__init__() 591 | self.config = config 592 | self.layer = nn.ModuleList([EsmLayer(config, i) for i in range(config.num_hidden_layers)]) 593 | self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 594 | self.gradient_checkpointing = False 595 | 596 | def forward( 597 | self, 598 | hidden_states, 599 | attention_mask=None, 600 | head_mask=None, 601 | encoder_hidden_states=None, 602 | encoder_attention_mask=None, 603 | past_key_values=None, 604 | use_cache=None, 605 | output_attentions=False, 606 | output_hidden_states=False, 607 | return_dict=True, 608 | ): 609 | if self.gradient_checkpointing and self.training: 610 | if use_cache: 611 | logger.warning_once( 612 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 613 | "`use_cache=False`..." 
614 | ) 615 | use_cache = False 616 | all_hidden_states = () if output_hidden_states else None 617 | all_self_attentions = () if output_attentions else None 618 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 619 | 620 | next_decoder_cache = () if use_cache else None 621 | for i, layer_module in enumerate(self.layer): 622 | if output_hidden_states: 623 | all_hidden_states = all_hidden_states + (hidden_states,) 624 | 625 | layer_head_mask = head_mask[i] if head_mask is not None else None 626 | past_key_value = past_key_values[i] if past_key_values is not None else None 627 | 628 | if self.gradient_checkpointing and self.training: 629 | layer_outputs = self._gradient_checkpointing_func( 630 | layer_module.__call__, 631 | hidden_states, 632 | attention_mask, 633 | layer_head_mask, 634 | encoder_hidden_states, 635 | encoder_attention_mask, 636 | past_key_value, 637 | output_attentions, 638 | ) 639 | else: 640 | layer_outputs = layer_module( 641 | hidden_states, 642 | attention_mask, 643 | layer_head_mask, 644 | encoder_hidden_states, 645 | encoder_attention_mask, 646 | past_key_value, 647 | output_attentions, 648 | ) 649 | 650 | hidden_states = layer_outputs[0] 651 | if use_cache: 652 | next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) 653 | if output_attentions: 654 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 655 | if self.config.add_cross_attention: 656 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 657 | 658 | if self.emb_layer_norm_after: 659 | hidden_states = self.emb_layer_norm_after(hidden_states) 660 | 661 | if output_hidden_states: 662 | all_hidden_states = all_hidden_states + (hidden_states,) 663 | 664 | if not return_dict: 665 | return tuple( 666 | v 667 | for v in [ 668 | hidden_states, 669 | next_decoder_cache, 670 | all_hidden_states, 671 | all_self_attentions, 672 | all_cross_attentions, 673 | ] 674 | if v is not None 675 | ) 676 | return BaseModelOutputWithPastAndCrossAttentions( 677 | last_hidden_state=hidden_states, 678 | past_key_values=next_decoder_cache, 679 | hidden_states=all_hidden_states, 680 | attentions=all_self_attentions, 681 | cross_attentions=all_cross_attentions, 682 | ) 683 | 684 | 685 | # Copied from transformers.models.bert.modeling_bert.BertPooler 686 | class EsmPooler(nn.Module): 687 | def __init__(self, config): 688 | super().__init__() 689 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 690 | self.activation = nn.Tanh() 691 | 692 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 693 | # We "pool" the model by simply taking the hidden state corresponding 694 | # to the first token. 695 | first_token_tensor = hidden_states[:, 0] 696 | pooled_output = self.dense(first_token_tensor) 697 | pooled_output = self.activation(pooled_output) 698 | return pooled_output 699 | 700 | 701 | class EsmPreTrainedModel(PreTrainedModel): 702 | """ 703 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 704 | models. 
705 | """ 706 | 707 | config_class = EsmConfig 708 | base_model_prefix = "esm" 709 | supports_gradient_checkpointing = True 710 | _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] 711 | 712 | # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights 713 | def _init_weights(self, module): 714 | """Initialize the weights""" 715 | if isinstance(module, nn.Linear): 716 | # Slightly different from the TF version which uses truncated_normal for initialization 717 | # cf https://github.com/pytorch/pytorch/pull/5617 718 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 719 | if module.bias is not None: 720 | module.bias.data.zero_() 721 | elif isinstance(module, nn.Embedding): 722 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 723 | if module.padding_idx is not None: 724 | module.weight.data[module.padding_idx].zero_() 725 | elif isinstance(module, nn.LayerNorm): 726 | module.bias.data.zero_() 727 | module.weight.data.fill_(1.0) 728 | 729 | 730 | ESM_START_DOCSTRING = r""" 731 | 732 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 733 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 734 | etc.) 735 | 736 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 737 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 738 | and behavior. 739 | 740 | Parameters: 741 | config ([`EsmConfig`]): Model configuration class with all the parameters of the 742 | model. Initializing with a config file does not load the weights associated with the model, only the 743 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 744 | """ 745 | 746 | ESM_INPUTS_DOCSTRING = r""" 747 | Args: 748 | input_ids (`torch.LongTensor` of shape `({0})`): 749 | Indices of input sequence tokens in the vocabulary. 750 | 751 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 752 | [`PreTrainedTokenizer.__call__`] for details. 753 | 754 | [What are input IDs?](../glossary#input-ids) 755 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 756 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 757 | 758 | - 1 for tokens that are **not masked**, 759 | - 0 for tokens that are **masked**. 760 | 761 | [What are attention masks?](../glossary#attention-mask) 762 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 763 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 764 | config.max_position_embeddings - 1]`. 765 | 766 | [What are position IDs?](../glossary#position-ids) 767 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 768 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 769 | 770 | - 1 indicates the head is **not masked**, 771 | - 0 indicates the head is **masked**. 772 | 773 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 774 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This 775 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 776 | model's internal embedding lookup matrix. 777 | output_attentions (`bool`, *optional*): 778 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 779 | tensors for more detail. 780 | output_hidden_states (`bool`, *optional*): 781 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 782 | more detail. 783 | return_dict (`bool`, *optional*): 784 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 785 | """ 786 | 787 | 788 | @add_start_docstrings( 789 | "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", 790 | ESM_START_DOCSTRING, 791 | ) 792 | class EsmModel(EsmPreTrainedModel): 793 | """ 794 | 795 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 796 | cross-attention is added between the self-attention layers, following the architecture described in [Attention is 797 | all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 798 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 799 | 800 | To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set 801 | to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and 802 | `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 803 | """ 804 | 805 | def __init__(self, config, add_pooling_layer=True): 806 | config.add_cross_attention = True 807 | config.adapter_freq = 33 808 | super().__init__(config) 809 | self.config = config 810 | 811 | self.embeddings = EsmEmbeddings(config) 812 | self.encoder = EsmEncoder(config) 813 | 814 | self.pooler = EsmPooler(config) if add_pooling_layer else None 815 | 816 | self.contact_head = EsmContactPredictionHead( 817 | in_features=config.num_hidden_layers * config.num_attention_heads, bias=True 818 | ) 819 | 820 | # Initialize weights and apply final processing 821 | self.post_init() 822 | 823 | def get_input_embeddings(self): 824 | return self.embeddings.word_embeddings 825 | 826 | def set_input_embeddings(self, value): 827 | self.embeddings.word_embeddings = value 828 | 829 | def _prune_heads(self, heads_to_prune): 830 | """ 831 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 832 | class PreTrainedModel 833 | """ 834 | for layer, heads in heads_to_prune.items(): 835 | self.encoder.layer[layer].attention.prune_heads(heads) 836 | 837 | @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) 838 | @add_code_sample_docstrings( 839 | checkpoint=_CHECKPOINT_FOR_DOC, 840 | output_type=BaseModelOutputWithPoolingAndCrossAttentions, 841 | config_class=_CONFIG_FOR_DOC, 842 | ) 843 | def forward( 844 | self, 845 | input_ids: Optional[torch.Tensor] = None, 846 | attention_mask: Optional[torch.Tensor] = None, 847 | position_ids: Optional[torch.Tensor] = None, 848 | head_mask: Optional[torch.Tensor] = None, 849 | inputs_embeds: Optional[torch.Tensor] = None, 850 | encoder_hidden_states: Optional[torch.Tensor] = None, 851 | encoder_attention_mask: Optional[torch.Tensor] = None, 852 | past_key_values: Optional[List[torch.FloatTensor]] = None, 853 | use_cache: Optional[bool] = None, 854 | output_attentions: Optional[bool] = None, 855 | output_hidden_states: Optional[bool] = None, 856 | return_dict: Optional[bool] = None, 857 | ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: 858 | r""" 859 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 860 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 861 | the model is configured as a decoder. 862 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 863 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 864 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 865 | 866 | - 1 for tokens that are **not masked**, 867 | - 0 for tokens that are **masked**. 868 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 869 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 870 | 871 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 872 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 873 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 874 | use_cache (`bool`, *optional*): 875 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 876 | `past_key_values`). 
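        Example (an illustrative sketch; the tiny config below is an arbitrary placeholder, given
        33 layers only so that the last layer carries a cross-attention adapter, mirroring the
        hard-coded `config.adapter_freq = 33` set in `__init__` for the 33-layer ESM-2 650M
        backbone):

            >>> cfg = EsmConfig(vocab_size=33, hidden_size=64, num_hidden_layers=33,
            ...                 num_attention_heads=4, intermediate_size=128,
            ...                 pad_token_id=1, mask_token_id=32)
            >>> model = EsmModel(cfg)
            >>> input_ids = torch.randint(4, 24, (1, 12))
            >>> text_feats = torch.randn(1, 8, cfg.hidden_size)   # e.g. projected LLM features
            >>> out = model(input_ids, encoder_hidden_states=text_feats, return_dict=True)
            >>> out.last_hidden_state.shape
            torch.Size([1, 12, 64])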
877 | """ 878 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 879 | output_hidden_states = ( 880 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 881 | ) 882 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 883 | 884 | if self.config.is_decoder: 885 | use_cache = use_cache if use_cache is not None else self.config.use_cache 886 | else: 887 | use_cache = False 888 | 889 | if input_ids is not None and inputs_embeds is not None: 890 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 891 | elif input_ids is not None: 892 | self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) 893 | input_shape = input_ids.size() 894 | elif inputs_embeds is not None: 895 | input_shape = inputs_embeds.size()[:-1] 896 | else: 897 | raise ValueError("You have to specify either input_ids or inputs_embeds") 898 | 899 | batch_size, seq_length = input_shape 900 | device = input_ids.device if input_ids is not None else inputs_embeds.device 901 | 902 | # past_key_values_length 903 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 904 | 905 | if attention_mask is None: 906 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 907 | 908 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 909 | # ourselves in which case we just need to make it broadcastable to all heads. 910 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) 911 | 912 | # If a 2D or 3D attention mask is provided for the cross-attention 913 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 914 | if self.config.is_decoder and encoder_hidden_states is not None: 915 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 916 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 917 | if encoder_attention_mask is None: 918 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 919 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 920 | else: 921 | encoder_extended_attention_mask = None 922 | 923 | # Prepare head mask if needed 924 | # 1.0 in head_mask indicate we keep the head 925 | # attention_probs has shape bsz x n_heads x N x N 926 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 927 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 928 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 929 | 930 | embedding_output = self.embeddings( 931 | input_ids=input_ids, 932 | position_ids=position_ids, 933 | attention_mask=attention_mask, 934 | inputs_embeds=inputs_embeds, 935 | past_key_values_length=past_key_values_length, 936 | ) 937 | encoder_outputs = self.encoder( 938 | embedding_output, 939 | attention_mask=extended_attention_mask, 940 | head_mask=head_mask, 941 | encoder_hidden_states=encoder_hidden_states, 942 | encoder_attention_mask=encoder_extended_attention_mask, 943 | past_key_values=past_key_values, 944 | use_cache=use_cache, 945 | output_attentions=output_attentions, 946 | output_hidden_states=output_hidden_states, 947 | return_dict=return_dict, 948 | ) 949 | sequence_output = encoder_outputs[0] 950 | pooled_output = 
self.pooler(sequence_output) if self.pooler is not None else None 951 | 952 | if not return_dict: 953 | return (sequence_output, pooled_output) + encoder_outputs[1:] 954 | 955 | return BaseModelOutputWithPoolingAndCrossAttentions( 956 | last_hidden_state=sequence_output, 957 | pooler_output=pooled_output, 958 | past_key_values=encoder_outputs.past_key_values, 959 | hidden_states=encoder_outputs.hidden_states, 960 | attentions=encoder_outputs.attentions, 961 | cross_attentions=encoder_outputs.cross_attentions, 962 | ) 963 | 964 | def predict_contacts(self, tokens, attention_mask): 965 | attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions 966 | attns = torch.stack(attns, dim=1) # Matches the original model layout 967 | # In the original model, attentions for padding tokens are completely zeroed out. 968 | # This makes no difference most of the time because the other tokens won't attend to them, 969 | # but it does for the contact prediction task, which takes attentions as input, 970 | # so we have to mimic that here. 971 | attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) 972 | attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) 973 | return self.contact_head(tokens, attns) 974 | 975 | 976 | @add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) 977 | class EsmForMaskedLM(EsmPreTrainedModel): 978 | _tied_weights_keys = ["lm_head.decoder.weight"] 979 | 980 | def __init__(self, config): 981 | super().__init__(config) 982 | 983 | if config.is_decoder: 984 | logger.warning( 985 | "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " 986 | "bi-directional self-attention." 987 | ) 988 | 989 | self.esm = EsmModel(config, add_pooling_layer=False) 990 | self.lm_head = EsmLMHead(config) 991 | 992 | self.init_weights() 993 | 994 | def get_output_embeddings(self): 995 | return self.lm_head.decoder 996 | 997 | def set_output_embeddings(self, new_embeddings): 998 | self.lm_head.decoder = new_embeddings 999 | 1000 | @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1001 | @add_code_sample_docstrings( 1002 | checkpoint=_CHECKPOINT_FOR_DOC, 1003 | output_type=MaskedLMOutput, 1004 | config_class=_CONFIG_FOR_DOC, 1005 | mask="", 1006 | ) 1007 | def forward( 1008 | self, 1009 | input_ids: Optional[torch.LongTensor] = None, 1010 | attention_mask: Optional[torch.Tensor] = None, 1011 | position_ids: Optional[torch.LongTensor] = None, 1012 | head_mask: Optional[torch.Tensor] = None, 1013 | inputs_embeds: Optional[torch.FloatTensor] = None, 1014 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 1015 | encoder_attention_mask: Optional[torch.Tensor] = None, 1016 | labels: Optional[torch.LongTensor] = None, 1017 | output_attentions: Optional[bool] = None, 1018 | output_hidden_states: Optional[bool] = None, 1019 | return_dict: Optional[bool] = None, 1020 | ) -> Union[Tuple, MaskedLMOutput]: 1021 | r""" 1022 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1023 | Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., 1024 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the 1025 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 1026 | kwargs (`Dict[str, any]`, optional, defaults to *{}*): 1027 | Used to hide legacy arguments that have been deprecated. 
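        Note that this variant builds its loss with `CrossEntropyLoss(reduction='none')` (unlike
        the stock HuggingFace implementation, which returns the mean), so `MaskedLMOutput.loss` is
        a flat per-token vector of length `batch_size * sequence_length`, with zeros at positions
        whose label is -100. A caller that wants the usual scalar can reduce it manually; a sketch
        with placeholder tensor names:

            >>> out = model(input_ids, attention_mask=attention_mask, labels=labels)
            >>> valid = labels.view(-1) != -100
            >>> scalar_loss = out.loss[valid].mean()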
1028 | """ 1029 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1030 | 1031 | outputs = self.esm( 1032 | input_ids, 1033 | attention_mask=attention_mask, 1034 | position_ids=position_ids, 1035 | head_mask=head_mask, 1036 | inputs_embeds=inputs_embeds, 1037 | encoder_hidden_states=encoder_hidden_states, 1038 | encoder_attention_mask=encoder_attention_mask, 1039 | output_attentions=output_attentions, 1040 | output_hidden_states=output_hidden_states, 1041 | return_dict=return_dict, 1042 | ) 1043 | sequence_output = outputs[0] 1044 | prediction_scores = self.lm_head(sequence_output) 1045 | 1046 | masked_lm_loss = None 1047 | if labels is not None: 1048 | loss_fct = CrossEntropyLoss(reduction='none') 1049 | 1050 | labels = labels.to(prediction_scores.device) 1051 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1052 | 1053 | if not return_dict: 1054 | output = (prediction_scores,) + outputs[2:] 1055 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1056 | 1057 | return MaskedLMOutput( 1058 | loss=masked_lm_loss, 1059 | logits=prediction_scores, 1060 | hidden_states=outputs.hidden_states, 1061 | attentions=outputs.attentions, 1062 | ) 1063 | 1064 | def predict_contacts(self, tokens, attention_mask): 1065 | return self.esm.predict_contacts(tokens, attention_mask=attention_mask) 1066 | 1067 | 1068 | class EsmLMHead(nn.Module): 1069 | """ESM Head for masked language modeling.""" 1070 | 1071 | def __init__(self, config): 1072 | super().__init__() 1073 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1074 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 1075 | 1076 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1077 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 1078 | 1079 | def forward(self, features, **kwargs): 1080 | x = self.dense(features) 1081 | x = gelu(x) 1082 | x = self.layer_norm(x) 1083 | 1084 | # project back to size of vocabulary with bias 1085 | x = self.decoder(x) + self.bias 1086 | return x 1087 | 1088 | 1089 | @add_start_docstrings( 1090 | """ 1091 | ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled 1092 | output) e.g. for GLUE tasks. 
1093 | """, 1094 | ESM_START_DOCSTRING, 1095 | ) 1096 | class EsmForSequenceClassification(EsmPreTrainedModel): 1097 | def __init__(self, config): 1098 | super().__init__(config) 1099 | self.num_labels = config.num_labels 1100 | self.config = config 1101 | 1102 | self.esm = EsmModel(config, add_pooling_layer=False) 1103 | self.classifier = EsmClassificationHead(config) 1104 | 1105 | self.init_weights() 1106 | 1107 | @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1108 | @add_code_sample_docstrings( 1109 | checkpoint=_CHECKPOINT_FOR_DOC, 1110 | output_type=SequenceClassifierOutput, 1111 | config_class=_CONFIG_FOR_DOC, 1112 | ) 1113 | def forward( 1114 | self, 1115 | input_ids: Optional[torch.LongTensor] = None, 1116 | attention_mask: Optional[torch.Tensor] = None, 1117 | position_ids: Optional[torch.LongTensor] = None, 1118 | head_mask: Optional[torch.Tensor] = None, 1119 | inputs_embeds: Optional[torch.FloatTensor] = None, 1120 | labels: Optional[torch.LongTensor] = None, 1121 | output_attentions: Optional[bool] = None, 1122 | output_hidden_states: Optional[bool] = None, 1123 | return_dict: Optional[bool] = None, 1124 | ) -> Union[Tuple, SequenceClassifierOutput]: 1125 | r""" 1126 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1127 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1128 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1129 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1130 | """ 1131 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1132 | 1133 | outputs = self.esm( 1134 | input_ids, 1135 | attention_mask=attention_mask, 1136 | position_ids=position_ids, 1137 | head_mask=head_mask, 1138 | inputs_embeds=inputs_embeds, 1139 | output_attentions=output_attentions, 1140 | output_hidden_states=output_hidden_states, 1141 | return_dict=return_dict, 1142 | ) 1143 | sequence_output = outputs[0] 1144 | logits = self.classifier(sequence_output) 1145 | 1146 | loss = None 1147 | if labels is not None: 1148 | labels = labels.to(logits.device) 1149 | 1150 | if self.config.problem_type is None: 1151 | if self.num_labels == 1: 1152 | self.config.problem_type = "regression" 1153 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1154 | self.config.problem_type = "single_label_classification" 1155 | else: 1156 | self.config.problem_type = "multi_label_classification" 1157 | 1158 | if self.config.problem_type == "regression": 1159 | loss_fct = MSELoss() 1160 | if self.num_labels == 1: 1161 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1162 | else: 1163 | loss = loss_fct(logits, labels) 1164 | elif self.config.problem_type == "single_label_classification": 1165 | loss_fct = CrossEntropyLoss() 1166 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1167 | elif self.config.problem_type == "multi_label_classification": 1168 | loss_fct = BCEWithLogitsLoss() 1169 | loss = loss_fct(logits, labels) 1170 | 1171 | if not return_dict: 1172 | output = (logits,) + outputs[2:] 1173 | return ((loss,) + output) if loss is not None else output 1174 | 1175 | return SequenceClassifierOutput( 1176 | loss=loss, 1177 | logits=logits, 1178 | hidden_states=outputs.hidden_states, 1179 | attentions=outputs.attentions, 1180 | ) 1181 | 1182 | 1183 | @add_start_docstrings( 1184 | """ 1185 | ESM Model with a token 
classification head on top (a linear layer on top of the hidden-states output) e.g. for 1186 | Named-Entity-Recognition (NER) tasks. 1187 | """, 1188 | ESM_START_DOCSTRING, 1189 | ) 1190 | class EsmForTokenClassification(EsmPreTrainedModel): 1191 | def __init__(self, config): 1192 | super().__init__(config) 1193 | self.num_labels = config.num_labels 1194 | 1195 | self.esm = EsmModel(config, add_pooling_layer=False) 1196 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1197 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1198 | 1199 | self.init_weights() 1200 | 1201 | @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1202 | @add_code_sample_docstrings( 1203 | checkpoint=_CHECKPOINT_FOR_DOC, 1204 | output_type=TokenClassifierOutput, 1205 | config_class=_CONFIG_FOR_DOC, 1206 | ) 1207 | def forward( 1208 | self, 1209 | input_ids: Optional[torch.LongTensor] = None, 1210 | attention_mask: Optional[torch.Tensor] = None, 1211 | position_ids: Optional[torch.LongTensor] = None, 1212 | head_mask: Optional[torch.Tensor] = None, 1213 | inputs_embeds: Optional[torch.FloatTensor] = None, 1214 | labels: Optional[torch.LongTensor] = None, 1215 | output_attentions: Optional[bool] = None, 1216 | output_hidden_states: Optional[bool] = None, 1217 | return_dict: Optional[bool] = None, 1218 | ) -> Union[Tuple, TokenClassifierOutput]: 1219 | r""" 1220 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1221 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 1222 | """ 1223 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1224 | 1225 | outputs = self.esm( 1226 | input_ids, 1227 | attention_mask=attention_mask, 1228 | position_ids=position_ids, 1229 | head_mask=head_mask, 1230 | inputs_embeds=inputs_embeds, 1231 | output_attentions=output_attentions, 1232 | output_hidden_states=output_hidden_states, 1233 | return_dict=return_dict, 1234 | ) 1235 | attention_mask = attention_mask.to() 1236 | sequence_output = outputs[0] 1237 | 1238 | sequence_output = self.dropout(sequence_output) 1239 | logits = self.classifier(sequence_output) 1240 | 1241 | loss = None 1242 | if labels is not None: 1243 | loss_fct = CrossEntropyLoss() 1244 | 1245 | labels = labels.to(logits.device) 1246 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1247 | 1248 | if not return_dict: 1249 | output = (logits,) + outputs[2:] 1250 | return ((loss,) + output) if loss is not None else output 1251 | 1252 | return TokenClassifierOutput( 1253 | loss=loss, 1254 | logits=logits, 1255 | hidden_states=outputs.hidden_states, 1256 | attentions=outputs.attentions, 1257 | ) 1258 | 1259 | 1260 | class EsmClassificationHead(nn.Module): 1261 | """Head for sentence-level classification tasks.""" 1262 | 1263 | def __init__(self, config): 1264 | super().__init__() 1265 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1266 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1267 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 1268 | 1269 | def forward(self, features, **kwargs): 1270 | x = features[:, 0, :] # take token (equiv. 
to [CLS]) 1271 | x = self.dropout(x) 1272 | x = self.dense(x) 1273 | x = torch.tanh(x) 1274 | x = self.dropout(x) 1275 | x = self.out_proj(x) 1276 | return x 1277 | 1278 | class EsmForMutationDesign(EsmPreTrainedModel): 1279 | def __init__(self, config): 1280 | config.add_cross_attention = True 1281 | config.adapter_freq = 3 1282 | super().__init__(config) 1283 | 1284 | self.esm = EsmModel(config, add_pooling_layer=False) 1285 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1286 | self.mutation_classifier = nn.Linear(config.hidden_size, 2) 1287 | self.lm_head = EsmLMHead(config) 1288 | 1289 | self.init_weights() 1290 | 1291 | def forward( 1292 | self, 1293 | input_ids: Optional[torch.LongTensor] = None, 1294 | attention_mask: Optional[torch.Tensor] = None, 1295 | position_ids: Optional[torch.LongTensor] = None, 1296 | head_mask: Optional[torch.Tensor] = None, 1297 | inputs_embeds: Optional[torch.FloatTensor] = None, 1298 | mutation_position: Optional[torch.LongTensor] = None, 1299 | mutation_aa: Optional[torch.LongTensor] = None, 1300 | batch_idx: Optional[torch.LongTensor] = None, 1301 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 1302 | encoder_attention_mask: Optional[torch.LongTensor] = None, 1303 | output_attentions: Optional[bool] = None, 1304 | output_hidden_states: Optional[bool] = None, 1305 | return_dict: Optional[bool] = None, 1306 | ) -> Union[Tuple, ModelOutput]: 1307 | r""" 1308 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1309 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 1310 | """ 1311 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1312 | 1313 | outputs = self.esm( 1314 | input_ids, 1315 | attention_mask=attention_mask, 1316 | position_ids=position_ids, 1317 | head_mask=head_mask, 1318 | inputs_embeds=inputs_embeds, 1319 | output_attentions=output_attentions, 1320 | output_hidden_states=output_hidden_states, 1321 | encoder_hidden_states=encoder_hidden_states, 1322 | encoder_attention_mask=encoder_attention_mask, 1323 | return_dict=return_dict, 1324 | ) 1325 | sequence_output = outputs[0] 1326 | sequence_output = self.dropout(sequence_output) 1327 | mutation_pos_logits = self.mutation_classifier(sequence_output) 1328 | # extended_attention_mask = (1.0 - attention_mask) * torch.finfo(sequence_output.dtype).min 1329 | # mutation_pos_logits += extended_attention_mask 1330 | 1331 | extended_mutation_position = batch_idx * sequence_output.shape[1] + mutation_position 1332 | mutation_aa_logits = self.lm_head(sequence_output.view(-1, sequence_output.shape[2])[extended_mutation_position]) 1333 | 1334 | # mutation_position_label = torch.eye(input_ids.shape[1], dtype=torch.long).to(input_ids.device)[mutation_position] - 100 * (1 - attention_mask) 1335 | mutation_position_label = torch.zeros(input_ids.shape, dtype=torch.long).to(input_ids.device) 1336 | mutation_position_label[batch_idx, mutation_position] = 1 1337 | mutation_position_label -= 100 * (1 - attention_mask) 1338 | 1339 | loss_fct_pos = CrossEntropyLoss(weight=torch.tensor([1.0, 50.0]).to(input_ids.device)) 1340 | loss_fct_aa = CrossEntropyLoss() 1341 | loss_pos = loss_fct_pos(mutation_pos_logits.view(-1, 2), mutation_position_label.view(-1)) 1342 | loss_aa = loss_fct_aa(mutation_aa_logits.view(-1, self.config.vocab_size), mutation_aa.view(-1)) 1343 | 1344 | if not return_dict: 1345 | output = (mutation_pos_logits, mutation_aa_logits,) + outputs[2:] 1346 | 
return ((loss_pos, loss_aa) + output) 1347 | 1348 | return ModelOutput( 1349 | loss_pos=loss_pos, 1350 | loss_aa=loss_aa, 1351 | logits_pos=mutation_pos_logits, 1352 | logits_aa=mutation_aa_logits, 1353 | hidden_states=outputs.hidden_states, 1354 | attentions=outputs.attentions, 1355 | ) 1356 | 1357 | @torch.no_grad() 1358 | def lm_design(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask): 1359 | outputs = self.esm( 1360 | input_ids, 1361 | attention_mask=attention_mask, 1362 | encoder_hidden_states=encoder_hidden_states, 1363 | encoder_attention_mask=encoder_attention_mask, 1364 | return_dict=True 1365 | ) 1366 | sequence_output = outputs[0] 1367 | mutation_pos_prob = F.softmax(self.mutation_classifier(sequence_output) / 0.5, dim=-1)[:, :, 1] * attention_mask 1368 | mutation_pos_prob[:, 0] = 0 1369 | mutation_pos_prob[input_ids == 2] = 0 1370 | mutation_aa_prob = F.softmax(self.lm_head(sequence_output) / 0.5, dim=-1) 1371 | mutation_prob = mutation_pos_prob.unsqueeze(2) * mutation_aa_prob 1372 | 1373 | return mutation_prob 1374 | 1375 | 1376 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): 1377 | """ 1378 | Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols 1379 | are ignored. This is modified from fairseq's `utils.make_positions`. 1380 | 1381 | Args: 1382 | x: torch.Tensor x: 1383 | 1384 | Returns: torch.Tensor 1385 | """ 1386 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 1387 | mask = input_ids.ne(padding_idx).int() 1388 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 1389 | return incremental_indices.long() + padding_idx 1390 | -------------------------------------------------------------------------------- /model/mutaplm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import contextlib 3 | import re 4 | import sys 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from transformers import LlamaTokenizer, LlamaConfig, LlamaForCausalLM, EsmTokenizer 10 | from collections import OrderedDict 11 | from peft import get_peft_model, LoraConfig, TaskType 12 | from torch.nn import CrossEntropyLoss 13 | from model.modeling_esm import EsmForMaskedLM, EsmForMutationDesign 14 | 15 | class MutaPLM(nn.Module): 16 | def __init__( 17 | self, 18 | protein_model=None, 19 | llama_ckpt=None, 20 | llama_pretrained_ckpt=None, 21 | num_query_tokens_protein1=64, 22 | num_query_tokens_protein2=64, 23 | ca_num_head=8, 24 | protein_maxlen=1024, 25 | text_maxlen=256, 26 | func_maxlen=512, 27 | test_mode=False, 28 | resume=False, 29 | device=None, 30 | m2t=True, 31 | t2m=True, 32 | pretrain=False, 33 | ): 34 | super(MutaPLM, self).__init__() 35 | self.device = device 36 | self.num_query_tokens_protein1 = num_query_tokens_protein1 37 | self.num_query_tokens_protein2 = num_query_tokens_protein2 38 | self.ca_num_head = ca_num_head 39 | self.protein_maxlen = protein_maxlen 40 | self.text_maxlen = text_maxlen 41 | self.func_maxlen = func_maxlen 42 | self.m2t = m2t 43 | self.t2m = t2m 44 | self.pretrain = pretrain 45 | 46 | # load esm 47 | print("*** loading protein model...") 48 | if pretrain: 49 | self.protein_model = EsmForMaskedLM.from_pretrained(protein_model, torch_dtype=torch.bfloat16) 50 | self.forward_fn = self.forward_pt 51 | self.loss_names = [] 52 | if self.m2t: 53 | self.loss_names.append("loss_p2t") 
54 | if self.t2m: 55 | self.loss_names.append("loss_t2p") 56 | else: 57 | self.protein_model = EsmForMutationDesign.from_pretrained(protein_model, torch_dtype=torch.bfloat16) # delta decoder is here 58 | self.forward_fn = self.forward_ft 59 | self.loss_names = [] 60 | if self.m2t: 61 | self.loss_names.append("loss_m2t") 62 | if self.t2m: 63 | self.loss_names += (["loss_pos", "loss_aa"]) 64 | self.protein_tokenizer = EsmTokenizer.from_pretrained(protein_model) 65 | print("*** freezing protein model...") 66 | for name, param in self.protein_model.named_parameters(): 67 | if "_adapter" not in name and "mutation_classifier" not in name and "lm_head" not in name: 68 | param.requires_grad = False 69 | 70 | # load llm 71 | print("*** loading llm tokenizer...") 72 | self.llm_tokenizer = LlamaTokenizer.from_pretrained(llama_ckpt, truncation_side="left") 73 | self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 74 | self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'}) 75 | self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'}) 76 | self.llm_tokenizer.add_special_tokens({'unk_token': '<unk>'}) 77 | print(f"*** loading llm from {llama_ckpt}...") 78 | if pretrain: 79 | self.llm = LlamaForCausalLM.from_pretrained(llama_ckpt, torch_dtype=torch.bfloat16) 80 | else: 81 | cfg = LlamaConfig.from_pretrained(llama_ckpt) 82 | self.llm = LlamaForCausalLM(cfg) 83 | self.llm.resize_token_embeddings(len(self.llm_tokenizer)) 84 | 85 | # add lora 86 | print("*** adding LoRA...") 87 | lora_config = LoraConfig( 88 | peft_type=TaskType.CAUSAL_LM, 89 | inference_mode=test_mode, 90 | r=16, lora_alpha=16, 91 | lora_dropout=0.05, 92 | target_modules=["v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], 93 | ) 94 | self.llm = get_peft_model(self.llm, lora_config) 95 | self.llm.print_trainable_parameters() 96 | 97 | # delta encoder with cross attention 98 | print("*** building delta encoder...") 99 | self.query_protein1 = nn.Parameter( 100 | torch.zeros(1, num_query_tokens_protein1, self.protein_model.config.hidden_size) 101 | ) 102 | nn.init.normal_(self.query_protein1, 0, 0.02) 103 | self.query_protein2 = nn.Parameter( 104 | torch.zeros(1, num_query_tokens_protein2, self.protein_model.config.hidden_size) 105 | ) 106 | nn.init.normal_(self.query_protein2, 0, 0.02) 107 | self.pooler_protein1 = nn.MultiheadAttention( 108 | embed_dim=self.protein_model.config.hidden_size, 109 | num_heads=self.ca_num_head, 110 | batch_first=True 111 | ) 112 | self.pooler_protein2 = nn.MultiheadAttention( 113 | embed_dim=self.protein_model.config.hidden_size, 114 | num_heads=self.ca_num_head, 115 | batch_first=True 116 | ) 117 | 118 | self.bop_embeds = nn.Parameter(torch.zeros(1, 1, self.llm.config.hidden_size)) 119 | self.eop_embeds = nn.Parameter(torch.zeros(1, 1, self.llm.config.hidden_size)) 120 | self.bom_embeds = nn.Parameter(torch.zeros(1, 1, self.llm.config.hidden_size)) 121 | self.eom_embeds = nn.Parameter(torch.zeros(1, 1, self.llm.config.hidden_size)) 122 | self.soft_tokens = nn.Parameter(torch.zeros(1, num_query_tokens_protein2, self.llm.config.hidden_size)) 123 | nn.init.normal_(self.bop_embeds, 0, 0.02) 124 | nn.init.normal_(self.eop_embeds, 0, 0.02) 125 | nn.init.normal_(self.bom_embeds, 0, 0.02) 126 | nn.init.normal_(self.eom_embeds, 0, 0.02) 127 | nn.init.normal_(self.soft_tokens, 0, 0.02) 128 | 129 | # build proj 130 | self.proj_protein1 = nn.Linear(self.protein_model.config.hidden_size, self.llm.config.hidden_size) 131 | self.proj_protein2 = nn.Linear(self.protein_model.config.hidden_size,
self.llm.config.hidden_size) 132 | self.proj_text = nn.Linear(self.llm.config.hidden_size, self.protein_model.config.hidden_size) 133 | 134 | if not pretrain and llama_pretrained_ckpt is not None: 135 | print(f"*** loading pretrained llm from {llama_pretrained_ckpt}...") 136 | ckpt = torch.load(llama_pretrained_ckpt, map_location=torch.device("cpu"))["model"] 137 | print(self.load_state_dict(self.convert_params(ckpt), strict=False)) 138 | del ckpt 139 | 140 | if not m2t: 141 | print("*** freeze m2t parameters") 142 | self.freeze_m2t_params() 143 | print("*** model built successfully.") 144 | 145 | def freeze_m2t_params(self): 146 | for param in self.pooler_protein1.parameters(): 147 | param.requires_grad = False 148 | for param in self.pooler_protein2.parameters(): 149 | param.requires_grad = False 150 | for param in self.proj_protein1.parameters(): 151 | param.requires_grad = False 152 | for param in self.proj_protein2.parameters(): 153 | param.requires_grad = False 154 | self.query_protein1.requires_grad = False 155 | self.query_protein2.requires_grad = False 156 | self.bop_embeds.requires_grad = False 157 | self.eop_embeds.requires_grad = False 158 | self.bom_embeds.requires_grad = False 159 | self.eom_embeds.requires_grad = False 160 | 161 | 162 | def convert_params(self, ckpt): 163 | # Initialize parameters for fine-tuning 164 | # pooler_protein -> pooler_protein 1&2 165 | # query_protein -> query_protein 1&2 166 | new_ckpt = OrderedDict() 167 | for k, v in ckpt.items(): 168 | if "pooler_protein" in k: 169 | new_ckpt[k.replace("pooler_protein", "pooler_protein1")] = v 170 | new_ckpt[k.replace("pooler_protein", "pooler_protein2")] = v 171 | elif k.startswith("proj"): 172 | new_ckpt[k.replace("proj", "proj_protein1")] = v 173 | new_ckpt[k.replace("proj", "proj_protein2")] = v 174 | elif "query_protein" in k: 175 | new_ckpt[k.replace("query_protein", "query_protein1")] = v 176 | new_ckpt[k.replace("query_protein", "query_protein2")] = v 177 | elif "bop_embeds" in k: 178 | new_ckpt[k] = v 179 | new_ckpt[k.replace("bop_embeds", "bom_embeds")] = v 180 | elif "eop_embeds" in k: 181 | new_ckpt[k] = v 182 | new_ckpt[k.replace("eop_embeds", "eom_embeds")] = v 183 | else: 184 | new_ckpt[k] = v 185 | 186 | return new_ckpt 187 | 188 | 189 | def maybe_autocast(self, dtype=torch.bfloat16): 190 | # if on cpu, don't use autocast 191 | # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 192 | enable_autocast = self.device != torch.device("cpu") 193 | 194 | if enable_autocast: 195 | return torch.cuda.amp.autocast(dtype=dtype) 196 | else: 197 | return contextlib.nullcontext() 198 | 199 | 200 | def _encode_protein(self, protein1, protein2): 201 | batch_size = len(protein1) 202 | protein1 = self.protein_tokenizer( 203 | protein1, 204 | max_length=self.protein_maxlen, 205 | padding=True, 206 | truncation=True, 207 | add_special_tokens=True, 208 | return_tensors='pt' 209 | ).to(self.device) # input_ids: [bs, prot_len] 210 | if protein2 is not None: 211 | protein2 = self.protein_tokenizer( 212 | protein2, 213 | max_length=self.protein_maxlen, 214 | padding=True, 215 | truncation=True, 216 | add_special_tokens=True, 217 | return_tensors='pt' 218 | ).to(self.device) 219 | 220 | with self.maybe_autocast(): 221 | protein_feature1 = self.protein_model.esm(**protein1) # last_hidden_states: [bs, prot_len, esm_hidden_size] 222 | query_protein1 = self.query_protein1.expand(batch_size, -1, -1) 223 | attn_mask_1 = (1 - protein1.attention_mask.repeat(self.ca_num_head, 1).unsqueeze(1).expand(-1, 
self.num_query_tokens_protein1, -1)).to(bool) 224 | p_feature1 = self.pooler_protein1( 225 | query_protein1, 226 | protein_feature1[0], 227 | protein_feature1[0], 228 | attn_mask = attn_mask_1 229 | ) 230 | protein1_embeds = self.proj_protein1(p_feature1[0]) 231 | 232 | if protein2 is not None: 233 | p_feature2 = self.protein_model.esm(**protein2) 234 | query_protein2 = self.query_protein2.expand(batch_size, -1, -1) 235 | attn_mask_2 = (1 - protein2.attention_mask.repeat(self.ca_num_head, 1).unsqueeze(1).expand(-1, self.num_query_tokens_protein2, -1)).to(bool) 236 | delta_feature = p_feature2[0] - protein_feature1[0] 237 | p_feature2 = self.pooler_protein2( 238 | query_protein2, 239 | delta_feature, 240 | delta_feature, 241 | attn_mask = attn_mask_2 242 | ) 243 | protein2_embeds = self.proj_protein2(p_feature2[0]) 244 | 245 | if protein2 is not None: 246 | return protein1_embeds, protein2_embeds 247 | else: 248 | return protein1_embeds 249 | 250 | 251 | def add_padding(self, wrapped_embeds, wrapped_attention_mask=None, targets=None, regress_ids=None, padding="right"): 252 | assert (targets is None) or (regress_ids is None) 253 | batch_size = len(wrapped_embeds) 254 | max_length_batch = max([x.shape[1] for x in wrapped_embeds]) 255 | for i in range(batch_size): 256 | pad_len = max_length_batch - wrapped_embeds[i].shape[1] 257 | if padding == "right": 258 | wrapped_embeds[i] = torch.cat(( 259 | wrapped_embeds[i], 260 | torch.zeros((1, pad_len, wrapped_embeds[i].shape[2]), dtype=wrapped_embeds[i].dtype).to(wrapped_embeds[i].device) 261 | ), dim=1) 262 | if wrapped_attention_mask: 263 | wrapped_attention_mask[i] = torch.cat(( 264 | wrapped_attention_mask[i], 265 | torch.zeros((1, pad_len), dtype=wrapped_attention_mask[i].dtype).to(wrapped_attention_mask[i].device) 266 | ), dim=1) 267 | if targets: 268 | targets[i] = torch.cat(( 269 | targets[i], 270 | torch.ones((1, pad_len), dtype=targets[i].dtype).to(targets[i].device).fill_(-100) 271 | ), dim=1) 272 | if regress_ids: 273 | regress_ids[i] = torch.cat(( 274 | regress_ids[i], 275 | torch.zeros((pad_len), dtype=regress_ids[i].dtype).to(regress_ids[i].device) 276 | ), dim=0) 277 | else: 278 | wrapped_embeds[i] = torch.cat(( 279 | torch.zeros((1, pad_len, wrapped_embeds[i].shape[2]), dtype=wrapped_embeds[i].dtype).to(wrapped_embeds[i].device), 280 | wrapped_embeds[i], 281 | ), dim=1) 282 | if wrapped_attention_mask: 283 | wrapped_attention_mask[i] = torch.cat(( 284 | torch.zeros((1, pad_len), dtype=wrapped_attention_mask[i].dtype).to(wrapped_attention_mask[i].device), 285 | wrapped_attention_mask[i], 286 | ), dim=1) 287 | if targets: 288 | targets[i] = torch.cat(( 289 | torch.ones((1, pad_len), dtype=targets[i].dtype).to(targets[i].device).fill_(-100), 290 | targets[i], 291 | ), dim=1) 292 | if regress_ids: 293 | regress_ids[i] = torch.cat(( 294 | torch.zeros((pad_len), dtype=regress_ids[i].dtype).to(regress_ids[i].device), 295 | regress_ids[i] 296 | ), dim=0) 297 | 298 | if targets: 299 | return torch.cat(wrapped_embeds, dim=0), torch.cat(wrapped_attention_mask, dim=0), torch.cat(targets, dim=0) 300 | if regress_ids: 301 | return torch.cat(wrapped_embeds, dim=0), torch.cat(wrapped_attention_mask, dim=0), torch.stack(regress_ids, dim=0) 302 | if wrapped_attention_mask is None: 303 | return torch.cat(wrapped_embeds, dim=0) 304 | else: 305 | return torch.cat(wrapped_embeds, dim=0), torch.cat(wrapped_attention_mask, dim=0) 306 | 307 | def _wrapped_sentence_pt(self, protein, text): 308 | if self.t2m: 309 | soft_embeds = 
self.soft_tokens.to(self.device) 310 | batched_embeds2, batched_attn_mask2, batched_soft_ids = [], [], [] 311 | 312 | with self.maybe_autocast(): 313 | batch_size = len(protein) 314 | protein = self.protein_tokenizer( 315 | protein, 316 | max_length=self.protein_maxlen, 317 | padding=True, 318 | truncation=True, 319 | add_special_tokens=True, 320 | return_tensors='pt' 321 | ).to(self.device) 322 | p_feature = self.protein_model.esm(**protein) 323 | 324 | query_protein = self.query_protein1.expand(batch_size, -1, -1) 325 | attn_mask_ca = (1 - protein.attention_mask.repeat(self.ca_num_head, 1).unsqueeze(1).expand(-1, self.num_query_tokens_protein1, -1)).to(bool) 326 | pooled_feature = self.pooler_protein1( 327 | query_protein, 328 | p_feature[0], 329 | p_feature[0], 330 | attn_mask = attn_mask_ca 331 | ) 332 | protein_embeds = self.proj_protein1(pooled_feature[0]) 333 | 334 | input_emb = self.llm.get_input_embeddings() 335 | bos_tokens = self.llm_tokenizer('<s>', return_tensors='pt', add_special_tokens=False).to(self.device).input_ids.expand(batch_size, -1) 336 | bos_embeds = input_emb(bos_tokens) 337 | bop_embeds = self.bop_embeds.expand(batch_size, -1, -1) 338 | eop_embeds = self.eop_embeds.expand(batch_size, -1, -1) 339 | 340 | text = [t+"</s>" for t in text] 341 | text_tokens = self.llm_tokenizer( 342 | text, 343 | max_length=self.text_maxlen, 344 | padding=True, 345 | truncation=True, 346 | return_tensors='pt', 347 | add_special_tokens=False 348 | ).to(self.device) 349 | text_embeds = input_emb(text_tokens.input_ids) 350 | 351 | wrapped_embeds = torch.cat([bos_embeds, bop_embeds, protein_embeds, eop_embeds, text_embeds], dim=1) 352 | attention_mask = torch.ones((batch_size, bos_embeds.shape[1] + bop_embeds.shape[1] + protein_embeds.shape[1] + eop_embeds.shape[1]), dtype=torch.long, device=self.device) 353 | wrapped_attention_mask = torch.cat([attention_mask, text_tokens.attention_mask], dim=1) 354 | labels = text_tokens.input_ids.masked_fill(~text_tokens.attention_mask.bool(), -100) 355 | wrapped_labels = torch.cat([attention_mask * -100, labels], dim=1) 356 | 357 | if self.t2m: 358 | for t in text: 359 | tokens = self.llm_tokenizer( 360 | [t.rstrip("</s>")], 361 | max_length=self.text_maxlen, 362 | padding=False, 363 | truncation=True, 364 | return_tensors='pt', 365 | add_special_tokens=False 366 | ).to(self.device) 367 | text_embeds = input_emb(tokens.input_ids) 368 | # regression loss 369 | regress_start_id = text_embeds.shape[1] + 2 370 | wrapped_embeds2 = torch.cat([ 371 | bos_embeds[0].unsqueeze(0), text_embeds, bop_embeds[0].unsqueeze(0), soft_embeds 372 | ], dim=1) 373 | wrapped_attn_mask2 = torch.ones((1, wrapped_embeds2.shape[1]), dtype=torch.long, device=self.device) 374 | regress_ids = torch.cat([ 375 | torch.zeros(regress_start_id, dtype=torch.long, device=self.device), 376 | torch.ones(self.num_query_tokens_protein2, dtype=torch.long, device=self.device), 377 | ], dim=0).bool() 378 | batched_embeds2.append(wrapped_embeds2) 379 | batched_attn_mask2.append(wrapped_attn_mask2) 380 | batched_soft_ids.append(regress_ids) 381 | batched_embeds2, batched_attn_mask2, batched_soft_ids = self.add_padding( 382 | batched_embeds2, batched_attn_mask2, targets=None, regress_ids=batched_soft_ids 383 | ) 384 | return wrapped_embeds, wrapped_attention_mask, wrapped_labels, batched_embeds2, batched_attn_mask2, batched_soft_ids 385 | 386 | return wrapped_embeds, wrapped_attention_mask, wrapped_labels 387 | 388 | def _wrapped_sentence_ft(self, protein1_embeds, protein2_embeds, mut_entry, p_function,
muta_prompt, text): 389 | assert text is not None 390 | batch_size = protein1_embeds.shape[0] 391 | input_emb = self.llm.get_input_embeddings() 392 | bos_tokens = self.llm_tokenizer('<s>', return_tensors='pt', add_special_tokens=False).to(self.device).input_ids 393 | bos_embeds = input_emb(bos_tokens) # [1, 1, 4096] 394 | bop_embeds = self.bop_embeds.to(self.device) 395 | eop_embeds = self.eop_embeds.to(self.device) 396 | bom_embeds = self.bom_embeds.to(self.device) 397 | eom_embeds = self.eom_embeds.to(self.device) 398 | 399 | if self.t2m: 400 | soft_embeds = self.soft_tokens.to(self.device) 401 | batched_embeds2, batched_attn_mask2, batched_regress_ids = [], [], [] 402 | 403 | batched_embeds1, batched_attn_mask1, batched_labels = [], [], [] 404 | p_function = [t+"</s>" for t in p_function] 405 | text = [t+"</s>" for t in text] 406 | sys_prompt_tokens = self.llm_tokenizer( 407 | "You are an expert at biology and life science. Now a user gives you several protein sequences and mutations. Please follow user instructions and answer their questions. Based on the following protein sequence, please describe its function.", 408 | max_length=self.func_maxlen, 409 | padding=False, 410 | truncation=True, 411 | return_tensors='pt', 412 | add_special_tokens=False, 413 | ).to(self.device).input_ids 414 | sys_embeds = input_emb(sys_prompt_tokens) 415 | for i in range(batch_size): 416 | function_tokens = self.llm_tokenizer( 417 | p_function[i], 418 | max_length=self.func_maxlen, 419 | padding=False, 420 | truncation=True, 421 | return_tensors='pt', 422 | add_special_tokens=False, 423 | ).to(self.device) 424 | mutation_tokens = self.llm_tokenizer( 425 | muta_prompt[i], 426 | max_length=self.text_maxlen, 427 | padding=False, 428 | truncation=True, 429 | return_tensors='pt', 430 | add_special_tokens=False, 431 | ).to(self.device) 432 | text_tokens = self.llm_tokenizer( 433 | text[i], 434 | max_length=self.text_maxlen, 435 | padding=False, 436 | truncation=True, 437 | return_tensors='pt', 438 | add_special_tokens=False, 439 | ).to(self.device) 440 | func_embeds = input_emb(function_tokens.input_ids) 441 | muta_embeds = input_emb(mutation_tokens.input_ids) 442 | text_embeds = input_emb(text_tokens.input_ids) 443 | 444 | # understanding loss 445 | wrapped_embeds1 = torch.cat([ 446 | bos_embeds, sys_embeds, bop_embeds, protein1_embeds[i].unsqueeze(0), eop_embeds, 447 | func_embeds, muta_embeds, 448 | bom_embeds, protein2_embeds[i].unsqueeze(0), eom_embeds, 449 | text_embeds 450 | ], dim=1) 451 | wrapped_attn_mask1 = torch.ones((1, wrapped_embeds1.shape[1]), dtype=torch.long, device=self.device) 452 | wrapped_labels = torch.cat([ 453 | torch.ones((1, 3 + sys_embeds.shape[1] + protein2_embeds.shape[1]), dtype=torch.long, device=self.device) * -100, 454 | function_tokens.input_ids, 455 | torch.ones((1, muta_embeds.shape[1] + 2 + protein2_embeds.shape[1]), dtype=torch.long, device=self.device) * -100, 456 | text_tokens.input_ids 457 | ], dim=1) 458 | batched_embeds1.append(wrapped_embeds1) 459 | batched_attn_mask1.append(wrapped_attn_mask1) 460 | batched_labels.append(wrapped_labels) 461 | 462 | if self.t2m: 463 | regress_start_id = sys_embeds.shape[1] + self.num_query_tokens_protein1 + 3 + func_embeds.shape[1] + text_embeds.shape[1] 464 | wrapped_embeds2 = torch.cat([ 465 | bos_embeds, sys_embeds, bop_embeds, protein1_embeds[i].unsqueeze(0), eop_embeds, 466 | func_embeds, text_embeds[:, :-1, :], 467 | bom_embeds, soft_embeds, eom_embeds, text_embeds[:, -1:, :] 468 | ], dim=1) 469 | wrapped_attn_mask2 = torch.ones((1,
wrapped_embeds2.shape[1]), dtype=torch.long, device=self.device) 470 | regress_ids = torch.cat([ 471 | torch.zeros(regress_start_id, dtype=torch.long, device=self.device), 472 | torch.ones(self.num_query_tokens_protein2, dtype=torch.long, device=self.device), 473 | torch.zeros(2, dtype=torch.long, device=self.device), 474 | ], dim=0).bool() 475 | batched_embeds2.append(wrapped_embeds2) 476 | batched_attn_mask2.append(wrapped_attn_mask2) 477 | batched_regress_ids.append(regress_ids) 478 | 479 | batched_embeds1, batched_attn_mask1, batched_labels = self.add_padding( 480 | batched_embeds1, batched_attn_mask1, targets=batched_labels, regress_ids=None) 481 | if self.t2m: 482 | mut_pos = torch.tensor([int(x[1:-1]) for x in mut_entry], dtype=torch.long).to(self.device) 483 | mut_aa = self.protein_tokenizer( 484 | [x[-1] for x in mut_entry], 485 | padding=False, 486 | truncation=True, 487 | max_length=self.protein_maxlen, 488 | return_tensors='pt', 489 | add_special_tokens=False, 490 | ).input_ids.to(self.device) 491 | batched_embeds2, batched_attn_mask2, batched_regress_ids = self.add_padding( 492 | batched_embeds2, batched_attn_mask2, targets=None, regress_ids=batched_regress_ids) 493 | return batched_embeds1, batched_attn_mask1, batched_labels, batched_embeds2, batched_attn_mask2, batched_regress_ids, mut_pos, mut_aa 494 | else: 495 | return batched_embeds1, batched_attn_mask1, batched_labels 496 | 497 | def _wrapped_sentence_inference(self, protein1_embeds, protein2_embeds, muta_prompt, predict_function=None, mut_text=None): 498 | batch_size = protein1_embeds.shape[0] 499 | input_emb = self.llm.get_input_embeddings() 500 | bos_tokens = self.llm_tokenizer('<s>', return_tensors='pt', add_special_tokens=False).to(self.device).input_ids 501 | bos_embeds = input_emb(bos_tokens) # [1, 1, 4096] 502 | sys_prompt_tokens = self.llm_tokenizer( 503 | "You are an expert at biology and life science. Now a user gives you several protein sequences and mutations. 
Please follow user instructions and answer their questions.", 504 | max_length=self.func_maxlen, 505 | padding=False, 506 | truncation=True, 507 | return_tensors='pt', 508 | add_special_tokens=False, 509 | ).to(self.device).input_ids 510 | sys_embeds = input_emb(sys_prompt_tokens) 511 | if predict_function is None: # CoT stage 1 512 | sys_embeds = sys_embeds.expand(batch_size, -1, -1) 513 | bos_embeds = bos_embeds.expand(batch_size, -1, -1) 514 | bop_embeds = self.bop_embeds.expand(batch_size, -1, -1) 515 | eop_embeds = self.eop_embeds.expand(batch_size, -1, -1) 516 | bom_embeds = self.bom_embeds.expand(batch_size, -1, -1) 517 | eom_embeds = self.eom_embeds.expand(batch_size, -1, -1) 518 | wrapped_embeds = torch.cat([bos_embeds, sys_embeds, bop_embeds, protein1_embeds, eop_embeds], dim=1) 519 | attention_mask = torch.ones((batch_size, wrapped_embeds.shape[1]), dtype=torch.long, device=self.device) 520 | return wrapped_embeds, attention_mask 521 | 522 | else: # CoT stage 2 523 | bop_embeds = self.bop_embeds.to(self.device) 524 | eop_embeds = self.eop_embeds.to(self.device) 525 | bom_embeds = self.bom_embeds.to(self.device) 526 | eom_embeds = self.eom_embeds.to(self.device) 527 | batched_embeds, batched_attn_mask = [], [] 528 | if mut_text is not None: 529 | batched_regress_ids = [] 530 | predict_function = [t+"</s>" for t in predict_function] 531 | for i in range(batch_size): 532 | function_tokens = self.llm_tokenizer( 533 | predict_function[i], 534 | max_length=self.func_maxlen, padding=False, truncation=True, 535 | return_tensors='pt', add_special_tokens=False, 536 | ).to(self.device) 537 | mutation_tokens = self.llm_tokenizer( 538 | muta_prompt[i], 539 | max_length=self.text_maxlen, padding=False, truncation=True, 540 | return_tensors='pt', add_special_tokens=False, 541 | ).to(self.device) 542 | func_embeds = input_emb(function_tokens.input_ids) 543 | muta_embeds = input_emb(mutation_tokens.input_ids) 544 | if mut_text is not None: 545 | mut_eff = self.llm_tokenizer( 546 | mut_text[i], 547 | max_length=self.text_maxlen, padding=False, truncation=True, 548 | return_tensors='pt', add_special_tokens=False, 549 | ).to(self.device) 550 | mut_eff_embeds = input_emb(mut_eff.input_ids) 551 | soft_embeds = self.soft_tokens.to(self.device) 552 | regress_start_id = sys_embeds.shape[1] + self.num_query_tokens_protein1 + 4 + func_embeds.shape[1] + mut_eff_embeds.shape[1] 553 | wrapped_embeds = torch.cat([ 554 | bos_embeds, sys_embeds, bop_embeds, protein1_embeds[i].unsqueeze(0), eop_embeds, 555 | func_embeds, mut_eff_embeds, 556 | bom_embeds, soft_embeds 557 | ], dim=1) 558 | regress_ids = torch.cat([ 559 | torch.zeros(regress_start_id, dtype=torch.long, device=self.device), 560 | torch.ones(self.num_query_tokens_protein2, dtype=torch.long, device=self.device), 561 | ], dim=0).bool() 562 | batched_regress_ids.append(regress_ids) 563 | else: 564 | wrapped_embeds = torch.cat([ 565 | bos_embeds, sys_embeds, bop_embeds, protein1_embeds[i].unsqueeze(0), eop_embeds, 566 | func_embeds, muta_embeds, 567 | bom_embeds, protein2_embeds[i].unsqueeze(0), eom_embeds, 568 | ], dim=1) 569 | wrapped_attn_mask = torch.ones((1, wrapped_embeds.shape[1]), dtype=torch.long, device=self.device) 570 | batched_embeds.append(wrapped_embeds) 571 | batched_attn_mask.append(wrapped_attn_mask) 572 | 573 | if mut_text is None: 574 | batched_embeds, batched_attn_mask = self.add_padding( 575 | batched_embeds, batched_attn_mask, targets=None, regress_ids=None, padding="left") 576 | return batched_embeds, batched_attn_mask 577 | else: 578
| batched_embeds, batched_attn_mask, batched_regress_ids = self.add_padding( 579 | batched_embeds, batched_attn_mask, targets=None, regress_ids=batched_regress_ids, padding="left") 580 | return batched_embeds, batched_attn_mask, batched_regress_ids 581 | 582 | 583 | def protein_mask(self, protein, mask_ratio=0.15): 584 | protein = self.protein_tokenizer( 585 | protein, 586 | add_special_tokens=True, 587 | truncation=True, 588 | padding=True, 589 | max_length=self.protein_maxlen, 590 | return_tensors='pt' 591 | ).to(self.device) 592 | labels = protein.input_ids.clone() 593 | masked_indices = torch.bernoulli(torch.full(labels.shape, mask_ratio)).bool() 594 | masked_indices[labels == self.protein_tokenizer.pad_token_id] = False 595 | masked_indices[labels == self.protein_tokenizer.cls_token_id] = False 596 | masked_indices[labels == self.protein_tokenizer.eos_token_id] = False 597 | protein.input_ids[masked_indices] = self.protein_tokenizer.mask_token_id 598 | labels[~masked_indices] = -100 599 | return protein, labels 600 | 601 | def forward_pt(self, protein, text): 602 | if self.t2m: 603 | input_embeds_p2t, attn_mask_p2t, labels_p2t, input_embeds_t2p, attn_mask_t2p, soft_ids = self._wrapped_sentence_pt(protein, text) 604 | else: 605 | input_embeds_p2t, attn_mask_p2t, labels_p2t = self._wrapped_sentence_pt(protein, text) 606 | with self.maybe_autocast(): 607 | if self.m2t: 608 | loss_p2t = self.llm( 609 | inputs_embeds=input_embeds_p2t, 610 | attention_mask=attn_mask_p2t, 611 | labels=labels_p2t, 612 | return_dict=True 613 | ).loss 614 | if self.t2m: 615 | masked_protein, masked_labels = self.protein_mask(protein) 616 | outputs = self.llm( 617 | inputs_embeds=input_embeds_t2p, 618 | attention_mask=attn_mask_t2p, 619 | output_hidden_states=True, 620 | return_dict=True, 621 | ).hidden_states[-1] 622 | soft_embeds = outputs[soft_ids].contiguous() 623 | soft_embeds = self.proj_text(soft_embeds.view(len(protein), self.num_query_tokens_protein2, -1)) 624 | loss_t2p = torch.mean(self.protein_model( 625 | input_ids=masked_protein.input_ids, 626 | attention_mask=masked_protein.attention_mask, 627 | encoder_hidden_states=soft_embeds, 628 | encoder_attention_mask=torch.ones(soft_embeds.shape[:-1], dtype=torch.long).to(self.device), 629 | labels=masked_labels, 630 | return_dict=True 631 | ).loss) 632 | 633 | if self.m2t and self.t2m: 634 | return loss_p2t + loss_t2p, {"loss_p2t": loss_p2t, "loss_t2p": loss_t2p} 635 | elif self.m2t: 636 | return loss_p2t, {"loss_p2t": loss_p2t} 637 | else: 638 | return loss_t2p, {"loss_t2p": loss_t2p} 639 | 640 | def forward_ft(self, protein1, protein2, mut_entry, text, p_function, muta_prompt): 641 | protein1_embeds, protein2_embeds = self._encode_protein(protein1, protein2) 642 | if self.t2m: 643 | input_embeds_m2t, attn_mask_m2t, labels_m2t, input_embeds_t2m, attn_mask_t2m, soft_ids_t2m, mut_pos, mut_aa = self._wrapped_sentence_ft(protein1_embeds, protein2_embeds, mut_entry, p_function, muta_prompt, text) 644 | else: 645 | input_embeds_m2t, attn_mask_m2t, labels_m2t = self._wrapped_sentence_ft(protein1_embeds, protein2_embeds, mut_entry, p_function, muta_prompt, text) 646 | 647 | with self.maybe_autocast(): 648 | if self.m2t: 649 | loss_m2t = self.llm( 650 | inputs_embeds=input_embeds_m2t, 651 | attention_mask=attn_mask_m2t, 652 | labels=labels_m2t, 653 | return_dict=True 654 | ).loss 655 | 656 | if self.t2m: 657 | outputs = self.llm( 658 | inputs_embeds=input_embeds_t2m, 659 | attention_mask=attn_mask_t2m, 660 | output_hidden_states=True, 661 | return_dict=True 662 | 
).hidden_states[-1] 663 | soft_output = outputs[soft_ids_t2m].contiguous() 664 | soft_output = self.proj_text(soft_output.view(len(protein1), self.num_query_tokens_protein2, -1)) 665 | protein = self.protein_tokenizer( 666 | protein1, 667 | add_special_tokens=True, 668 | truncation=True, 669 | padding=True, 670 | max_length=self.protein_maxlen, 671 | return_tensors='pt' 672 | ).to(self.device) 673 | outputs = self.protein_model( 674 | input_ids=protein.input_ids, 675 | attention_mask=protein.attention_mask, 676 | mutation_position=mut_pos, 677 | mutation_aa=mut_aa, 678 | batch_idx=torch.arange(len(protein1)).to(self.device), 679 | encoder_hidden_states=soft_output, 680 | encoder_attention_mask=torch.ones(soft_output.shape[:-1], dtype=torch.long).to(self.device), 681 | return_dict=True 682 | ) 683 | 684 | if self.m2t and self.t2m: 685 | return loss_m2t + outputs.loss_pos + 0.2 * outputs.loss_aa, {"loss_m2t": loss_m2t, "loss_pos": outputs.loss_pos, "loss_aa": outputs.loss_aa} 686 | elif self.m2t: 687 | return loss_m2t, {"loss_m2t": loss_m2t} 688 | else: 689 | return outputs.loss_pos + 0.2 * outputs.loss_aa, {"loss_pos": outputs.loss_pos, "loss_aa": outputs.loss_aa} 690 | 691 | 692 | @torch.no_grad() 693 | def generate( 694 | self, 695 | protein1, 696 | protein2, 697 | muta_prompt, 698 | pfunction=None, 699 | use_gt_function=False, 700 | use_nucleus_sampling=True, 701 | num_beams=2, 702 | max_length=256, 703 | min_length=1, 704 | top_p=0.9, 705 | repetition_penalty=1.5, 706 | length_penalty=1, 707 | num_captions=1, 708 | temperature=1, 709 | ): 710 | with self.maybe_autocast(): 711 | # stage 1 712 | protein1_embeds, protein2_embeds = self._encode_protein(protein1, protein2) 713 | if not use_gt_function: 714 | input_embeds, attn_mask = self._wrapped_sentence_inference(protein1_embeds, protein2_embeds, muta_prompt, predict_function=None) 715 | outputs_function = self.llm.generate( 716 | inputs_embeds=input_embeds, 717 | attention_mask=attn_mask, 718 | do_sample=use_nucleus_sampling, 719 | top_p=top_p, 720 | temperature=temperature, 721 | num_beams=num_beams, 722 | max_length=max_length, 723 | min_length=min_length, 724 | repetition_penalty=repetition_penalty, 725 | length_penalty=length_penalty, 726 | num_return_sequences=num_captions, 727 | eos_token_id=self.llm_tokenizer.eos_token_id, 728 | pad_token_id=self.llm_tokenizer.pad_token_id, 729 | ) 730 | outputs_function[outputs_function == 0] = 2 # convert output id 0 to 2 (eos_token_id) 731 | output_function_text = self.llm_tokenizer.batch_decode(outputs_function, skip_special_tokens=True) 732 | output_function_text = [text.strip() for text in output_function_text] 733 | else: # use ground truth protein function directly 734 | output_function_text = pfunction 735 | 736 | # stage 2 737 | input_embeds, attn_mask = self._wrapped_sentence_inference(protein1_embeds, protein2_embeds, muta_prompt, predict_function=output_function_text) 738 | outputs_effect = self.llm.generate( 739 | inputs_embeds=input_embeds, 740 | attention_mask=attn_mask, 741 | do_sample=use_nucleus_sampling, 742 | top_p=top_p, 743 | temperature=temperature, 744 | num_beams=num_beams, 745 | max_length=max_length, 746 | min_length=min_length, 747 | repetition_penalty=repetition_penalty, 748 | length_penalty=length_penalty, 749 | num_return_sequences=num_captions, 750 | eos_token_id=self.llm_tokenizer.eos_token_id, 751 | pad_token_id=self.llm_tokenizer.pad_token_id, 752 | ) 753 | outputs_effect[outputs_effect == 0] = 2 # convert output id 0 to 2 (eos_token_id) 754 | output_effect_text =
self.llm_tokenizer.batch_decode(outputs_effect, skip_special_tokens=True) 755 | output_effect_text = [text.strip() for text in output_effect_text] 756 | 757 | return output_function_text, output_effect_text 758 | 759 | @torch.no_grad() 760 | def lm_design(self, 761 | protein, 762 | text, 763 | muta_prompt=None, 764 | pfunction=None, 765 | use_gt_function=True, 766 | use_nucleus_sampling=True, 767 | num_beams=2, 768 | max_length=256, 769 | min_length=1, 770 | top_p=0.9, 771 | repetition_penalty=1.5, 772 | length_penalty=1, 773 | num_captions=1, 774 | temperature=1, 775 | ): 776 | protein_embeds = self._encode_protein(protein, None) 777 | if not use_gt_function: 778 | input_embeds, attn_mask = self._wrapped_sentence_inference(protein_embeds, None, None, predict_function=None) 779 | outputs_function = self.llm.generate( 780 | inputs_embeds=input_embeds, 781 | attention_mask=attn_mask, 782 | do_sample=use_nucleus_sampling, 783 | top_p=top_p, 784 | temperature=temperature, 785 | num_beams=num_beams, 786 | max_length=max_length, 787 | min_length=min_length, 788 | repetition_penalty=repetition_penalty, 789 | length_penalty=length_penalty, 790 | num_return_sequences=num_captions, 791 | eos_token_id=self.llm_tokenizer.eos_token_id, 792 | pad_token_id=self.llm_tokenizer.pad_token_id, 793 | ) 794 | outputs_function[outputs_function == 0] = 2 # convert output id 0 to 2 (eos_token_id) 795 | output_function_text = self.llm_tokenizer.batch_decode(outputs_function, skip_special_tokens=True) 796 | output_function_text = [text.strip() for text in output_function_text] 797 | else: 798 | output_function_text = pfunction 799 | input_embeds, attn_mask, soft_ids = self._wrapped_sentence_inference(protein_embeds, None, muta_prompt, predict_function=output_function_text, mut_text=text) 800 | soft_output = self.llm.model( 801 | inputs_embeds=input_embeds, 802 | attention_mask=attn_mask, 803 | output_hidden_states=True, 804 | return_dict=True 805 | ).hidden_states[-1] 806 | soft_output = soft_output[soft_ids].contiguous() 807 | soft_output = self.proj_text(soft_output.view(len(protein), self.num_query_tokens_protein2, -1)) 808 | protein = self.protein_tokenizer( 809 | protein, 810 | add_special_tokens=True, 811 | truncation=True, 812 | padding='max_length', 813 | max_length=self.protein_maxlen, 814 | return_tensors='pt' 815 | ).to(self.device) 816 | return self.protein_model.lm_design( 817 | input_ids=protein.input_ids, 818 | attention_mask=protein.attention_mask, 819 | encoder_hidden_states=soft_output, 820 | encoder_attention_mask=torch.ones(soft_output.shape[:-1], dtype=torch.long).to(self.device) 821 | ) -------------------------------------------------------------------------------- /model/vanllina_esm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformers import EsmTokenizer, EsmForMaskedLM, BertTokenizer, BertForMaskedLM 8 | 9 | valid_aa = ['A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 10 | 'O', 'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z'] 11 | 12 | class RandomModel(nn.Module): 13 | def __init__(self, path, device) -> None: 14 | super().__init__() 15 | self.protein_tokenizer = EsmTokenizer.from_pretrained(path) 16 | self.device = device 17 | 18 | def lm_design(self, protein, text, **kwargs): 19 | all_logits = [] 20 | for p in protein: 21 | random_logits = torch.randn(len(p), 33) 22 | random_logits[:, :4] = -100 23 | 
random_logits[:, 24:] = -100 24 | all_logits.append(F.softmax(torch.cat( 25 | [torch.ones(1, 33) * -100, random_logits, torch.ones(1025 - len(p), 33) * -100], 26 | dim=0 27 | ), dim=0)) 28 | return torch.stack(all_logits, dim=0).to(self.device) 29 | 30 | 31 | class VanllinaEsm(nn.Module): 32 | def __init__(self, path, protein_maxlen, device, lambd=0.1, ontoprotein=False): 33 | super().__init__() 34 | self.protein_maxlen = protein_maxlen 35 | if ontoprotein: 36 | self.esm = BertForMaskedLM.from_pretrained(path) 37 | self.protein_tokenizer = BertTokenizer.from_pretrained(path) 38 | else: 39 | self.esm = EsmForMaskedLM.from_pretrained(path) 40 | self.protein_tokenizer = EsmTokenizer.from_pretrained(path) 41 | 42 | self.ontoprotein = ontoprotein 43 | self.device = device 44 | self.lambd = lambd 45 | self.loss_names = ["loss_reward", "loss_kl"] 46 | 47 | def forward(self, protein_mut, protein_wild, text, scores): 48 | outputs = self.predict_fitness(protein_mut, protein_wild, text, return_dict=True) 49 | pred_scores = outputs["score"] 50 | loss_reward = torch.tensor(0.).to(self.device) 51 | for i in range(len(protein_mut)): 52 | for j in range(i): 53 | if scores[i] > scores[j]: 54 | loss_reward += -F.logsigmoid(pred_scores[i] - pred_scores[j]) 55 | else: 56 | loss_reward += -F.logsigmoid(pred_scores[j] - pred_scores[i]) 57 | 58 | loss_fn = nn.KLDivLoss(reduction='none') 59 | logits = outputs["logits"] 60 | logits_orig = self.frozen_lm_head(self.protein_model.esm(outputs["mask_input_ids"], outputs["attention_mask"])[0]) 61 | targets = F.softmax(logits_orig, dim=-1) 62 | i_indices = torch.arange(logits.shape[0]).unsqueeze(1).unsqueeze(2).expand(-1, logits.shape[1], logits.shape[2]) 63 | j_indices = torch.arange(logits.shape[1]).unsqueeze(0).unsqueeze(2).expand(logits.shape[0], -1, logits.shape[2]) 64 | k_indices = outputs["orig_input_ids"].unsqueeze(2).expand(-1, -1, logits.shape[2]) 65 | loss_kl = torch.mean(loss_fn(logits[i_indices, j_indices, k_indices], targets[i_indices, j_indices, k_indices]) * outputs["attn_mask"]) 66 | 67 | return loss_reward + self.lambd * loss_kl, {"loss_reward": loss_reward.detach(), "loss_kl": loss_kl.detach()} 68 | 69 | def validate_fn(self, protein_mut, protein_wild, text, scores): 70 | preds = self.predict_fitness(protein_mut, protein_wild) 71 | return preds, scores 72 | 73 | @torch.no_grad() 74 | def lm_design(self, protein, text, **kwargs): 75 | if self.ontoprotein: 76 | protein = [" ".join(list(p)) for p in protein] 77 | protein = self.protein_tokenizer( 78 | protein, 79 | add_special_tokens=True, 80 | truncation=True, 81 | padding='max_length', 82 | max_length=self.protein_maxlen, 83 | return_tensors='pt' 84 | ).to(self.device) 85 | logits = self.esm(protein.input_ids, protein.attention_mask, return_dict=True).logits 86 | 87 | i_indices = torch.arange(logits.shape[0]).unsqueeze(1).unsqueeze(2).expand(-1, logits.shape[1], logits.shape[2]) 88 | j_indices = torch.arange(logits.shape[1]).unsqueeze(0).unsqueeze(2).expand(logits.shape[0], -1, logits.shape[2]) 89 | k_indices = protein.input_ids.unsqueeze(2).expand(-1, -1, logits.shape[2]) 90 | logits -= logits[i_indices, j_indices, k_indices] 91 | logits[torch.where(protein.input_ids == self.protein_tokenizer.cls_token_id)] = -1000 92 | if not self.ontoprotein: 93 | logits[torch.where(protein.input_ids == self.protein_tokenizer.eos_token_id)] = -1000 94 | else: 95 | logits[torch.where(protein.input_ids == self.protein_tokenizer.sep_token_id)] = -1000 96 | for i in range(logits.shape[0]): 97 | for j in 
range(logits.shape[1]): 98 | logits[i, j, protein.input_ids[i][j]] = -1000 99 | logits[(1 - protein.attention_mask).bool()] = -1000 100 | if not self.ontoprotein: 101 | logits[:, :, :4] = -1000 102 | logits[:, :, 24:] = -1000 103 | else: 104 | logits[:, :, :5] = -1000 105 | logits[:, :, 25:] = -1000 106 | return F.softmax(logits, dim=-1) 107 | 108 | 109 | def predict_fitness(self, protein, wild_type, *kwargs): 110 | mut_i_index, mut_j_index, mut_k_wt_index, mut_k_mt_index = [], [], [], [] 111 | for i in range(len(protein)): 112 | assert(len(protein[i]) == len(wild_type[0])) 113 | for j in range(len(protein[i])): 114 | if protein[i][j] != wild_type[0][j]: 115 | mut_i_index.append(i) 116 | mut_j_index.append(j + 1) 117 | mut_k_wt_index.append(wild_type[0][j]) 118 | mut_k_mt_index.append(protein[i][j]) 119 | inp_protein = self.protein_tokenizer( 120 | protein, 121 | add_special_tokens=True, 122 | truncation=True, 123 | padding=True, 124 | max_length=self.protein_maxlen, 125 | return_tensors='pt' 126 | ).to(self.device) 127 | mut_i_index = torch.LongTensor(mut_i_index).to(self.device) 128 | mut_j_index = torch.LongTensor(mut_j_index).to(self.device) 129 | mut_k_wt_index = self.protein_tokenizer.encode(mut_k_wt_index, add_special_tokens=False, return_tensors='pt').squeeze().to(self.device) 130 | mut_k_mt_index = self.protein_tokenizer.encode(mut_k_mt_index, add_special_tokens=False, return_tensors='pt').squeeze().to(self.device) 131 | # print(mut_i_index, mut_j_index, mut_k_wt_index, mut_k_mt_index) 132 | 133 | inp_protein.input_ids[mut_i_index, mut_j_index] = self.protein_tokenizer.mask_token_id 134 | 135 | logits = self.esm(**inp_protein, return_dict=True).logits 136 | logits = logits[mut_i_index, mut_j_index, mut_k_mt_index] - logits[mut_i_index, mut_j_index, mut_k_wt_index] 137 | mask = (mut_i_index.unsqueeze(1) == torch.arange(len(protein)).unsqueeze(0).to(self.device)).transpose(0, 1).float() 138 | score = logits * mask 139 | return score.sum(dim=-1) 140 | -------------------------------------------------------------------------------- /outputs/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PharMolix/MutaPLM/495815b2069419d19d9449a59070d0fb24b596d1/outputs/.placeholder -------------------------------------------------------------------------------- /scripts/optimize/evoprotgrad.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | datasets=("AAV" "AMIE" "avGFP" "E4B" "LGK" "UBE2I") 4 | ncandidates=20 5 | nrounds=10 6 | 7 | rm ./outputs/score_esm.txt 8 | 9 | for dataset in "${datasets[@]}"; 10 | do 11 | python eval.py \ 12 | --fitness_optimize \ 13 | --evo_prot_grad \ 14 | --num_rounds $nrounds \ 15 | --num_candidates $ncandidates \ 16 | --dataset_name $dataset \ 17 | --dataset_path ./data/fitness/$dataset \ 18 | --surrogate_path ./ckpts/landscape_params/esm1b_landscape/$dataset/decoder.pt \ 19 | --model_name esm \ 20 | --model_config_path ./configs/esm.yaml \ 21 | --score_save_path ./outputs/score_esm.txt \ 22 | --batch_size 64 \ 23 | --device 0 24 | done -------------------------------------------------------------------------------- /scripts/optimize/mutaplm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | datasets=("AAV" "AMIE" "avGFP" "E4B" "LGK" "UBE2I") 4 | ncandidates=20 5 | nrounds=10 6 | 7 | rm ./outputs/score_mutaplm.txt 8 | 9 
| for dataset in "${datasets[@]}"; 10 | do 11 | python eval.py \ 12 | --fitness_optimize \ 13 | --num_rounds $nrounds \ 14 | --num_candidates $ncandidates \ 15 | --dataset_name $dataset \ 16 | --dataset_path ./data/fitness/$dataset \ 17 | --surrogate_path ./ckpts/landscape_params/esm1b_landscape/$dataset/decoder.pt \ 18 | --model_name e_esm \ 19 | --model_config_path ./configs/mutaplm_inference.yaml \ 20 | --model_checkpoint ./ckpts/finetune/checkpoint_9.pth \ 21 | --score_save_path ./outputs/score_mutaplm.txt \ 22 | --batch_size 4 \ 23 | --device 0 24 | done -------------------------------------------------------------------------------- /scripts/optimize/random.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | datasets=("AAV" "AMIE" "avGFP" "E4B" "LGK" "UBE2I") 4 | ncandidates=20 5 | nrounds=10 6 | 7 | rm ./outputs/score_random.txt 8 | 9 | for dataset in "${datasets[@]}"; 10 | do 11 | python eval.py \ 12 | --fitness_optimize \ 13 | --num_rounds $nrounds \ 14 | --num_candidates $ncandidates \ 15 | --dataset_name $dataset \ 16 | --dataset_path ./data/fitness/$dataset \ 17 | --surrogate_path ./ckpts/landscape_params/esm1b_landscape/$dataset/decoder.pt \ 18 | --model_name random \ 19 | --model_config_path ./configs/random.yaml \ 20 | --score_save_path ./outputs/score_random.txt \ 21 | --batch_size 64 \ 22 | --device 0 23 | done -------------------------------------------------------------------------------- /scripts/test/mutaplm_engineer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | levels=("easy" "medium" "hard") 4 | 5 | for level in "${levels[@]}"; 6 | do 7 | python eval.py \ 8 | --dataset_name mutadescribe \ 9 | --dataset_path ./data/mutadescribe/test_$level.csv \ 10 | --muta_engineer \ 11 | --model_name mutaplm \ 12 | --model_config_path ./configs/mutaplm_inference.yaml \ 13 | --model_checkpoint ./ckpts/mutaplm/model_checkpoint1.pth \ 14 | --pred_save_path ./outputs/mutaplm.txt \ 15 | --batch_size 4 \ 16 | --device 0 17 | done -------------------------------------------------------------------------------- /scripts/test/mutaplm_explain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | levels=("easy" "medium" "hard") 4 | 5 | for level in "${levels[@]}"; 6 | do 7 | python eval.py \ 8 | --dataset_name mutadescribe \ 9 | --dataset_path ./data/mutadescribe/test_${level}.csv \ 10 | --muta_explain \ 11 | --model_name mutaplm \ 12 | --model_config_path ./configs/mutaplm_inference.yaml \ 13 | --model_checkpoint ./ckpts/mutaplm/model_checkpoint.pth \ 14 | --pred_save_path ./outputs/mutaplm_${level}.txt \ 15 | --batch_size 4 \ 16 | --device 0 17 | done -------------------------------------------------------------------------------- /scripts/train/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3 3 | NUM_GPUS=4 4 | 5 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port 2345 train.py \ 6 | --dataset_name mutadescribe \ 7 | --dataset_path ./data/mutadescribe/train.csv \ 8 | --model_name mutaplm \ 9 | --model_config_path ./configs/mutaplm_ft.yaml \ 10 | --model_checkpoint ./ckpts/pretrain/checkpoint_9.pth \ 11 | --epochs 20 \ 12 | --save_epochs 5 \ 13 | --warmup_steps 1000 \ 14 | --batch_size 1 \ 15 | 
--grad_accu_steps 6 \ 16 | --lr 1e-4 \ 17 | --distributed \ 18 | --save_path ./ckpts/mutaplm -------------------------------------------------------------------------------- /scripts/train/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3 3 | NUM_GPUS=4 4 | 5 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port 2345 train.py \ 6 | --dataset_name literature \ 7 | --dataset_path ./data/pubs \ 8 | --model_name mutaplm \ 9 | --model_config_path ./configs/mutaplm_pt.yaml \ 10 | --epochs 10 \ 11 | --save_epochs 5 \ 12 | --warmup_steps 1000 \ 13 | --batch_size 2 \ 14 | --grad_accu_steps 4 \ 15 | --lr 1e-4 \ 16 | --distributed \ 17 | --save_path ./ckpts/pretrain -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | Trainer.add_arguments(parser) 10 | args = parser.parse_args() 11 | 12 | trainer = Trainer(args) 13 | trainer.train() -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | nowtime = time.localtime() 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | import argparse 9 | from collections import OrderedDict 10 | import numpy as np 11 | import random 12 | import json 13 | import yaml 14 | import copy 15 | 16 | import torch 17 | from torch.cuda.amp import autocast 18 | import torch.distributed as dist 19 | from torch.utils.data import DataLoader, RandomSampler 20 | from torch.utils.data.distributed import DistributedSampler 21 | 22 | from dataset import dataset_name2cls 23 | from model import model_name2cls 24 | from utils import init_distributed_mode, get_rank, is_main_process, concat_gather, MetricLogger, SmoothedValue 25 | from metrics import name2metric 26 | 27 | class Trainer(object): 28 | @staticmethod 29 | def add_arguments(parser): 30 | # path & data params 31 | parser.add_argument("--dataset_name", type=str, default="mut_eff") 32 | parser.add_argument("--dataset_path", type=str, default="./data/") 33 | parser.add_argument("--nshot", type=int, default=None) 34 | parser.add_argument("--exclude", type=str, nargs='*', default=[]) 35 | parser.add_argument("--model_name", type=str, default="bert_esm") 36 | parser.add_argument("--model_config_path", type=str, default="./configs/bert_esm.yaml") 37 | parser.add_argument("--model_checkpoint", type=str, default=None) 38 | parser.add_argument("--save_path", type=str, default="./ckpts/fusion_ckpts/bert_esm") 39 | parser.add_argument("--log_path", type=str, default=f"./logs/{nowtime[1]}-{nowtime[2]}_{nowtime[3]}-{nowtime[4]}.logger.info") 40 | parser.add_argument("--resume", action="store_true") 41 | parser.add_argument("--resume_checkpoint", type=str, default="./ckpts/") 42 | parser.add_argument("--data_percent", type=float, default=1.0) 43 | 44 | # training params 45 | parser.add_argument("--seed", type=int, default=42) 46 | parser.add_argument("--epochs", type=int, default=30) 47 | parser.add_argument("--warmup_steps", type=int, default=5000) 48 | parser.add_argument("--batch_size", type=int, default=2) 49 | 
parser.add_argument("--num_workers", type=int, default=1) 50 | parser.add_argument("--lr", type=float, default=1e-4) 51 | parser.add_argument("--weight_decay", type=float, default=1e-2) 52 | parser.add_argument("--clip_grad_norm", type=bool, default=True) 53 | parser.add_argument("--save_epochs", type=int, default=1) 54 | parser.add_argument("--log_steps", type=int, default=10) 55 | parser.add_argument("--grad_accu_steps", type=int, default=1) 56 | 57 | # validation params 58 | parser.add_argument("--validate", action="store_true") 59 | parser.add_argument("--patience", type=int, default=5) 60 | parser.add_argument("--metric", type=str, default="spearmanr") 61 | parser.add_argument("--lower", action="store_true") 62 | 63 | # distributed params 64 | parser.add_argument("--distributed", action="store_true") 65 | parser.add_argument('--world_size', type=int, default=2, help='number of distributed processes') 66 | parser.add_argument('--local-rank', type=int, default=0) 67 | 68 | return parser 69 | 70 | def __init__(self, args): 71 | super().__init__() 72 | init_distributed_mode(args) 73 | 74 | self.args = args 75 | self.local_rank = get_rank() 76 | self.device = torch.device("cuda", self.local_rank) 77 | logging.basicConfig( 78 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 79 | datefmt="%m/%d/%Y %H:%M:%S", 80 | #level=logging.INFO 81 | level=logging.INFO if is_main_process() else logging.ERROR, 82 | ) 83 | logger.info("Rank: %d" % (self.local_rank)) 84 | self._setup_seed(self.args.seed + self.local_rank) 85 | self._setup_data() 86 | self._setup_model() 87 | 88 | def _setup_seed(self, seed): 89 | self.seed = seed 90 | random.seed(seed) 91 | np.random.seed(seed) 92 | torch.manual_seed(seed) 93 | 94 | def _setup_data(self): 95 | logger.info("Loading dataset...") 96 | self.train_dataset = dataset_name2cls[self.args.dataset_name](self.args.dataset_path, split="train", name=self.args.dataset_name, nshot=self.args.nshot, exclude=self.args.exclude) 97 | logger.info(f"Num Train Samples: {len(self.train_dataset)}") 98 | if hasattr(self.train_dataset, "get_example"): 99 | for i, example in enumerate(self.train_dataset.get_example()): 100 | if i >= 2: 101 | break 102 | logger.info(example) 103 | 104 | if self.args.distributed: 105 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 106 | else: 107 | self.train_sampler = RandomSampler(self.train_dataset) 108 | self.train_dataloader = DataLoader( 109 | self.train_dataset, 110 | batch_size=self.args.batch_size, 111 | sampler=self.train_sampler, 112 | num_workers=self.args.num_workers, 113 | ) 114 | 115 | 116 | if self.args.validate: 117 | self.valid_dataset = dataset_name2cls[self.args.dataset_name](self.args.dataset_path, split="valid", name=self.args.dataset_name, nshot=self.args.nshot, exclude=self.args.exclude) 118 | logger.info(f"Num Valid Samples: {len(self.valid_dataset)}") 119 | if self.args.distributed: 120 | self.valid_sampler = DistributedSampler(self.valid_dataset, seed=self.seed, shuffle=True) 121 | else: 122 | self.valid_sampler = RandomSampler(self.valid_dataset) 123 | 124 | self.valid_dataloader = DataLoader( 125 | self.valid_dataset, 126 | batch_size=self.args.batch_size, 127 | sampler=self.valid_sampler, 128 | num_workers=self.args.num_workers, 129 | ) 130 | self.patience = 0 131 | 132 | def _setup_model(self): 133 | logger.info("Loading model...") 134 | model_cls = model_name2cls[self.args.model_name] 135 | model_cfg = yaml.load(open(self.args.model_config_path, "r"), Loader=yaml.Loader) 136 | 
model_cfg["device"] = self.device 137 | self.model = model_cls(**model_cfg).to(self.device) 138 | 139 | logger.info(f"Trainable params: {sum([p.numel() if p.requires_grad else 0 for p in self.model.parameters()])/1000000}M") 140 | logger.info(f"Total params: {sum([p.numel() for p in self.model.parameters()])/1000000}M") 141 | 142 | if self.args.model_checkpoint is not None: 143 | logger.info(f"Load model checkpoint from {self.args.model_checkpoint}") 144 | state_dict = torch.load(open(self.args.model_checkpoint, "rb"), map_location=torch.device("cpu")) 145 | # NOTE: change back to state_dict["model"] 146 | self.model.load_state_dict(state_dict["model"], strict=True) 147 | self.optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 148 | self.schedular = torch.optim.lr_scheduler.OneCycleLR( 149 | self.optimizer, 150 | max_lr=self.args.lr, 151 | total_steps=int(self.args.epochs * len(self.train_dataloader)), 152 | epochs=self.args.epochs, 153 | pct_start=self.args.warmup_steps * 1.0 / self.args.epochs / len(self.train_dataloader), 154 | anneal_strategy='cos', 155 | final_div_factor=1e2 156 | ) 157 | logger.info(f"Epochs = {self.args.epochs}, Dataloader Length = {len(self.train_dataloader)}, world size = {self.args.world_size}") 158 | 159 | # continue training 160 | if self.args.resume: 161 | logger.info(f"resume from {self.args.resume_checkpoint}...") 162 | ckpt = torch.load(self.args.resume_checkpoint, map_location=torch.device("cpu")) 163 | self.model.load_state_dict(ckpt["model"]) 164 | self.optimizer.load_state_dict(ckpt["optimizer"]) 165 | self.schedular.load_state_dict(ckpt["schedular"]) 166 | self.start_epoch = ckpt["epoch"] + 1 167 | del ckpt 168 | logger.info("Load model successfully.") 169 | else: 170 | logger.info("Train from scratch") 171 | self.start_epoch = 0 172 | 173 | if self.args.distributed: 174 | logger.info("Parallizing model...") 175 | self.model = torch.nn.parallel.DistributedDataParallel( 176 | self.model, 177 | device_ids=[self.local_rank], 178 | output_device=self.local_rank, 179 | find_unused_parameters=True, 180 | broadcast_buffers=False 181 | ) 182 | self.model_without_ddp = self.model.module 183 | else: 184 | self.model_without_ddp = self.model 185 | 186 | def train(self): 187 | self.all_steps = 0 188 | logger.info("Start training!") 189 | if self.args.validate: 190 | best_result = -1e7 191 | for epoch in range(self.start_epoch, self.args.epochs): 192 | logger.info(f"Epoch {epoch + 1} / {self.args.epochs}") 193 | if self.args.distributed: 194 | self.train_sampler.set_epoch(epoch) 195 | 196 | train_stats = self.train_epoch(epoch) 197 | 198 | if self.args.validate: 199 | result = self.validate_epoch() 200 | if self.args.lower ^ (result > best_result): 201 | best_result = result 202 | self.patience = 0 203 | logger.info(f"Best ckpt at epoch {epoch}...") 204 | save_dict = { 205 | "model": copy.deepcopy(self.model_without_ddp.state_dict()), 206 | } 207 | else: 208 | self.patience += 1 209 | logger.info(f"Remaining patience: {self.args.patience - self.patience}/{self.args.patience}") 210 | if self.patience >= self.args.patience: 211 | break 212 | elif is_main_process() and (epoch + 1) % self.args.save_epochs == 0: 213 | logger.info(f"Saving ckpt at epoch {epoch}...") 214 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch,} 215 | save_dict = { 216 | "epoch": epoch, 217 | "model": self.model_without_ddp.state_dict(), 218 | "optimizer": self.optimizer.state_dict(), 219 | 
"schedular": self.schedular.state_dict(), 220 | } 221 | if not os.path.exists(self.args.save_path): 222 | os.mkdir(self.args.save_path) 223 | torch.save(save_dict, os.path.join(self.args.save_path, "checkpoint_%d.pth" % epoch)) 224 | logger.info(json.dumps(log_stats)) 225 | 226 | if self.args.distributed: 227 | dist.barrier() 228 | if self.args.validate: 229 | if not os.path.exists(self.args.save_path): 230 | os.mkdir(self.args.save_path) 231 | torch.save(save_dict, os.path.join(self.args.save_path, "best_model.pth")) 232 | 233 | def train_epoch(self, epoch): 234 | metric_logger = MetricLogger(self.args, delimiter=" ") 235 | metric_logger.add_meter('lr', SmoothedValue(window_size=50, fmt='{value:.6f}')) 236 | for k in self.model_without_ddp.loss_names: 237 | metric_logger.add_meter(k, SmoothedValue(window_size=200, fmt='{avg:.4f}')) 238 | header = 'Train Epoch: [{}]'.format(epoch) 239 | 240 | self.model.train() 241 | for i, data in enumerate(metric_logger.log_every(self.train_dataloader, 5, header)): 242 | with autocast(dtype=torch.bfloat16): 243 | if not hasattr(self.model_without_ddp, "forward_fn"): 244 | loss, output = self.model(*data) 245 | else: 246 | loss, output = self.model_without_ddp.forward_fn(*data) 247 | 248 | loss /= self.args.grad_accu_steps 249 | loss.backward() 250 | if self.args.clip_grad_norm: 251 | torch.nn.utils.clip_grad_norm_(list(self.model.parameters()), max_norm=10) 252 | if (i + 1) % self.args.grad_accu_steps == 0: 253 | self.all_steps += 1 254 | if self.all_steps in [5000, 20000] and is_main_process(): 255 | logger.info(f"Best ckpt at step {self.all_steps}...") 256 | save_dict = { 257 | "model": copy.deepcopy(self.model_without_ddp.state_dict()), 258 | } 259 | torch.save(save_dict, os.path.join(self.args.save_path, "step_%dK.pth" % (self.all_steps // 1000))) 260 | if self.args.distributed: 261 | dist.barrier() 262 | self.optimizer.step() 263 | self.optimizer.zero_grad() 264 | self.schedular.step() 265 | 266 | metric_logger.update(lr = self.optimizer.param_groups[0]["lr"]) 267 | metric_logger.update(**output) 268 | 269 | metric_logger.synchronize_between_processes() 270 | logger.info(f"Averaged stats: {metric_logger.global_avg()}") 271 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 272 | 273 | def validate_epoch(self): 274 | logger.info("Validating...") 275 | self.model.eval() 276 | all_preds, all_gts = [], [] 277 | for i, data in enumerate(self.valid_dataloader): 278 | with torch.no_grad(): 279 | with autocast(dtype=torch.bfloat16): 280 | if not hasattr(self.model_without_ddp, "validate_fn"): 281 | preds, _ = self.model(*data) 282 | gt = torch.zeros(1).to(self.device) 283 | else: 284 | if len(data) == 2: 285 | preds, gt = self.model_without_ddp.validate_fn(data[0], [self.train_dataset.wild_type], [self.train_dataset.prompt], data[1]) 286 | else: 287 | preds, gt = self.model_without_ddp.validate_fn(*data) 288 | all_preds.append(preds.cpu()) 289 | all_gts.append(gt.cpu()) 290 | 291 | if i % 50 == 0: 292 | logger.info(f"Validation step {i}/{len(self.valid_dataloader)}") 293 | all_preds = torch.cat(all_preds, dim=0) 294 | all_gts = torch.cat(all_gts, dim=0) 295 | all_preds = concat_gather(all_preds) 296 | all_gts = concat_gather(all_gts) 297 | if self.args.dataset_name == "mixfitness": 298 | scores = [] 299 | for i in range(self.valid_dataset.n_datasets): 300 | st, ed = i * self.valid_dataset.nshot, (i + 1) * self.valid_dataset.nshot 301 | scores.append(name2metric[self.args.metric](all_gts[st:ed].numpy(), 
all_preds[st:ed].numpy())) 302 | print(scores) 303 | score = np.mean(scores) 304 | else: 305 | score = name2metric[self.args.metric](all_gts.numpy(), all_preds.numpy()) 306 | logger.info(f"Validation result: {self.args.metric} = {score}") 307 | return score -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | from collections import defaultdict, deque 5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | def is_dist_avail_and_initialized(): 13 | return dist.is_available() and dist.is_initialized() 14 | 15 | 16 | def is_main_process(): 17 | return get_rank() == 0 18 | 19 | 20 | def get_rank(): 21 | return dist.get_rank() if is_dist_avail_and_initialized() else 0 22 | 23 | 24 | def LOG(info, args): 25 | if is_main_process(): 26 | t = time.strftime("%m-%d %H:%M:%S", time.localtime()) 27 | t = "["+str(t)+"] " 28 | print(t + info) 29 | with open(args.log_path, "a") as f: 30 | f.write(t+str(info)+"\n") 31 | 32 | 33 | def setup_for_distributed(is_master): 34 | """ 35 | This function disables printing when not in master process (rank 0) 36 | """ 37 | import builtins as __builtin__ 38 | builtin_print = __builtin__.print 39 | 40 | def print(*args, **kwargs): 41 | force = kwargs.pop('force', False) 42 | if is_master or force: 43 | builtin_print(*args, **kwargs) 44 | 45 | __builtin__.print = print 46 | 47 | 48 | def init_distributed_mode(args): 49 | if not args.distributed: 50 | args.rank = 0 51 | args.world_size = 1 52 | return 53 | 54 | dist.init_process_group(backend='nccl') 55 | args.rank = get_rank() 56 | args.world_size = dist.get_world_size() 57 | torch.cuda.set_device(args.rank) 58 | logger.info(f'| distributed init (rank {args.rank} / {args.world_size})') 59 | 60 | dist.barrier() 61 | setup_for_distributed(args.rank == 0) 62 | 63 | 64 | @torch.no_grad() 65 | def concat_gather(tensor): 66 | if not is_dist_avail_and_initialized(): 67 | return tensor 68 | 69 | gather_tensor = [torch.zeros_like(tensor) for i in range(dist.get_world_size())] 70 | dist.all_gather(gather_tensor, tensor, async_op=False) 71 | return torch.cat(gather_tensor, dim=0) 72 | 73 | 74 | def concat_gather_with_grad(tensor): 75 | if not is_dist_avail_and_initialized(): 76 | return tensor 77 | 78 | gather_tensor = GatherLayer.apply(tensor) 79 | return torch.cat(gather_tensor, dim=0) 80 | 81 | 82 | class GatherLayer(torch.autograd.Function): 83 | """ 84 | Gather tensors from all workers with support for backward propagation: 85 | This implementation does not cut the gradients as torch.distributed.all_gather does. 86 | """ 87 | 88 | @staticmethod 89 | def forward(ctx, x): 90 | output = [ 91 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 92 | ] 93 | torch.distributed.all_gather(output, x) 94 | return tuple(output) 95 | 96 | @staticmethod 97 | def backward(ctx, *grads): 98 | all_gradients = torch.stack(grads) 99 | torch.distributed.all_reduce(all_gradients) 100 | return all_gradients[torch.distributed.get_rank()] 101 | 102 | 103 | class SmoothedValue(object): 104 | """Track a series of values and provide access to smoothed values over a 105 | window or the global series average. 
106 | """ 107 | 108 | def __init__(self, window_size=20, fmt=None): 109 | if fmt is None: 110 | fmt = "{median:.4f} ({global_avg:.4f})" 111 | self.deque = deque(maxlen=window_size) 112 | self.total = 0.0 113 | self.count = 0 114 | self.fmt = fmt 115 | 116 | def update(self, value, n=1): 117 | self.deque.append(value) 118 | self.count += n 119 | self.total += value * n 120 | 121 | def synchronize_between_processes(self): 122 | """ 123 | Warning: does not synchronize the deque! 124 | """ 125 | if not is_dist_avail_and_initialized(): 126 | return 127 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 128 | dist.barrier() 129 | dist.all_reduce(t) 130 | t = t.tolist() 131 | self.count = int(t[0]) 132 | self.total = t[1] 133 | 134 | @property 135 | def median(self): 136 | d = torch.tensor(list(self.deque)) 137 | return d.median().item() 138 | 139 | @property 140 | def avg(self): 141 | d = torch.tensor(list(self.deque), dtype=torch.float32) 142 | return d.mean().item() 143 | 144 | @property 145 | def global_avg(self): 146 | return self.total / self.count 147 | 148 | @property 149 | def max(self): 150 | return max(self.deque) 151 | 152 | @property 153 | def value(self): 154 | return self.deque[-1] 155 | 156 | def __str__(self): 157 | return self.fmt.format( 158 | median=self.median, 159 | avg=self.avg, 160 | global_avg=self.global_avg, 161 | max=self.max, 162 | value=self.value) 163 | 164 | 165 | class MetricLogger(object): 166 | def __init__(self, args, delimiter="\t"): 167 | self.meters = defaultdict(SmoothedValue) 168 | self.delimiter = delimiter 169 | self.args = args 170 | 171 | def update(self, **kwargs): 172 | for k, v in kwargs.items(): 173 | if isinstance(v, torch.Tensor): 174 | v = v.item() 175 | assert isinstance(v, (float, int)) 176 | self.meters[k].update(v) 177 | 178 | def __getattr__(self, attr): 179 | if attr in self.meters: 180 | return self.meters[attr] 181 | if attr in self.__dict__: 182 | return self.__dict__[attr] 183 | raise AttributeError("'{}' object has no attribute '{}'".format( 184 | type(self).__name__, attr)) 185 | 186 | def __str__(self): 187 | loss_str = [] 188 | for name, meter in self.meters.items(): 189 | loss_str.append( 190 | "{}: {}".format(name, str(meter)) 191 | ) 192 | return self.delimiter.join(loss_str) 193 | 194 | def global_avg(self): 195 | loss_str = [] 196 | for name, meter in self.meters.items(): 197 | loss_str.append( 198 | "{}: {:.4f}".format(name, meter.global_avg) 199 | ) 200 | return self.delimiter.join(loss_str) 201 | 202 | def synchronize_between_processes(self): 203 | for meter in self.meters.values(): 204 | meter.synchronize_between_processes() 205 | 206 | def add_meter(self, name, meter): 207 | self.meters[name] = meter 208 | 209 | def log_every(self, iterable, print_freq, header=None): 210 | i = 1 211 | if not header: 212 | header = '' 213 | start_time = time.time() 214 | end = time.time() 215 | iter_time = SmoothedValue(fmt='{avg:.4f}') 216 | data_time = SmoothedValue(fmt='{avg:.4f}') 217 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 218 | log_msg = [ 219 | header, 220 | '[{0' + space_fmt + '}/{1}]', 221 | 'eta: {eta}', 222 | '{meters}', 223 | 'time: {time}', 224 | 'data: {data}' 225 | ] 226 | if torch.cuda.is_available(): 227 | log_msg.append('max mem: {memory:.0f}') 228 | log_msg = self.delimiter.join(log_msg) 229 | MB = 1024.0 * 1024.0 230 | for obj in iterable: 231 | data_time.update(time.time() - end) 232 | yield obj 233 | iter_time.update(time.time() - end) 234 | if i % print_freq == 0 or i == 
len(iterable) - 1: 235 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 236 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 237 | if torch.cuda.is_available(): 238 | LOG(log_msg.format( 239 | i, len(iterable), eta=eta_string, 240 | meters=str(self), 241 | time=str(iter_time), data=str(data_time), 242 | memory=torch.cuda.max_memory_allocated() / MB), args=self.args) 243 | else: 244 | LOG(log_msg.format( 245 | i, len(iterable), eta=eta_string, 246 | meters=str(self), 247 | time=str(iter_time), data=str(data_time)), args=self.args) 248 | i += 1 249 | end = time.time() 250 | total_time = time.time() - start_time 251 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 252 | LOG('{} Total time: {} ({:.4f} s / it)'.format( 253 | header, total_time_str, total_time / len(iterable)), args = self.args) --------------------------------------------------------------------------------