├── .gitignore ├── .vscode └── launch.json ├── environment.yml ├── fig └── framework.png ├── model_zoom ├── __init__.py ├── dplm_example.py ├── esm │ ├── __init__.py │ ├── data │ │ ├── ParentChildTreeFile.txt │ │ ├── entry_list_safety_29026.list │ │ ├── interpro_29026_to_keywords_58641.csv │ │ ├── keyword_idf_safety_filtered_58641.npy │ │ └── keyword_vocabulary_safety_filtered_58641.txt │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── blocks.py │ │ ├── codebook.py │ │ ├── ffn.py │ │ ├── geom_attention.py │ │ ├── regression_head.py │ │ ├── rotary.py │ │ ├── structure_proj.py │ │ └── transformer_stack.py │ ├── models │ │ ├── __init__.py │ │ ├── esm3.py │ │ ├── esmc.py │ │ ├── function_decoder.py │ │ └── vqvae.py │ ├── pretrained.py │ ├── sdk │ │ ├── __init__.py │ │ ├── api.py │ │ ├── experimental.py │ │ ├── forge.py │ │ └── sagemaker.py │ ├── tokenization │ │ ├── __init__.py │ │ ├── function_tokenizer.py │ │ ├── residue_tokenizer.py │ │ ├── sasa_tokenizer.py │ │ ├── sequence_tokenizer.py │ │ ├── ss_tokenizer.py │ │ ├── structure_tokenizer.py │ │ └── tokenizer_base.py │ ├── utils │ │ ├── __init__.py │ │ ├── constants │ │ │ ├── api.py │ │ │ ├── esm3.py │ │ │ ├── models.py │ │ │ └── physics.py │ │ ├── decoding.py │ │ ├── encoding.py │ │ ├── forge_context_manager.py │ │ ├── function │ │ │ ├── encode_decode.py │ │ │ ├── interpro.py │ │ │ ├── lsh.py │ │ │ └── tfidf.py │ │ ├── generation.py │ │ ├── generation_test.py │ │ ├── misc.py │ │ ├── misc_test.py │ │ ├── noise_schedules.py │ │ ├── residue_constants.py │ │ ├── sampling.py │ │ ├── sampling_test.py │ │ ├── structure │ │ │ ├── affine3d.py │ │ │ ├── aligner.py │ │ │ ├── metrics.py │ │ │ ├── normalize_coordinates.py │ │ │ ├── predicted_aligned_error.py │ │ │ ├── protein_chain.py │ │ │ ├── protein_complex.py │ │ │ └── protein_structure.py │ │ └── types.py │ └── widgets │ │ ├── components │ │ ├── function_annotator.py │ │ ├── results_visualizer.py │ │ ├── sasa_prompt_selector.py │ │ ├── secondary_structure_prompt_selector.py │ │ ├── sequence_prompt_selector.py │ │ └── structure_prompt_selector.py │ │ ├── utils │ │ ├── clients.py │ │ ├── drawing │ │ │ ├── colors.py │ │ │ ├── draw_category_array.py │ │ │ ├── draw_function_annotations.py │ │ │ └── draw_protein_structure.py │ │ ├── indexing.py │ │ ├── parsing.py │ │ ├── printing.py │ │ ├── prompting.py │ │ ├── protein_import.py │ │ ├── serialization.py │ │ └── types.py │ │ └── views │ │ ├── esm3_generation_launcher.py │ │ ├── esm3_prompt_preview.py │ │ ├── esm3_prompt_selector.py │ │ ├── generation.py │ │ ├── inverse_folding.py │ │ ├── login.py │ │ └── prediction.py ├── esm2_3b_example.py ├── esm2_example.py ├── esm3_example.py ├── esmc_example.py ├── ontoprotein_example.py ├── pglm_example.py ├── pro2text_example.py ├── procyon_example.py ├── progen2_example.py ├── prollama_example.py ├── prosst_2048_example.py ├── prostT5_example.py ├── prost_example.py ├── protgpt2_example.py ├── protrek_example.py ├── prott5_example.py ├── saport_example.py └── venusplm_example.py ├── pyrightconfig.json ├── readme.md ├── run_adapter.sh ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── affine3d.py │ ├── constants │ │ ├── __init__.py │ │ ├── geometry.py │ │ └── sequence.py │ ├── esm │ │ ├── __init__.py │ │ ├── pretrained.py │ │ ├── sdk │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── experimental.py │ │ │ ├── forge.py │ │ │ └── sagemaker.py │ │ ├── tokenization │ │ │ ├── __init__.py │ │ │ ├── function_tokenizer.py │ │ │ ├── residue_tokenizer.py │ │ │ ├── sasa_tokenizer.py │ │ │ ├── 
sequence_tokenizer.py │ │ │ ├── ss_tokenizer.py │ │ │ ├── structure_tokenizer.py │ │ │ └── tokenizer_base.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── constants │ │ │ ├── api.py │ │ │ ├── esm3.py │ │ │ ├── models.py │ │ │ └── physics.py │ │ │ ├── decoding.py │ │ │ ├── encoding.py │ │ │ ├── forge_context_manager.py │ │ │ ├── function │ │ │ ├── encode_decode.py │ │ │ ├── interpro.py │ │ │ ├── lsh.py │ │ │ └── tfidf.py │ │ │ ├── generation.py │ │ │ ├── generation_test.py │ │ │ ├── misc.py │ │ │ ├── misc_test.py │ │ │ ├── noise_schedules.py │ │ │ ├── residue_constants.py │ │ │ ├── sampling.py │ │ │ ├── sampling_test.py │ │ │ ├── structure │ │ │ ├── affine3d.py │ │ │ ├── aligner.py │ │ │ ├── metrics.py │ │ │ ├── normalize_coordinates.py │ │ │ ├── predicted_aligned_error.py │ │ │ ├── protein_chain.py │ │ │ ├── protein_complex.py │ │ │ └── protein_structure.py │ │ │ └── types.py │ ├── geo_featurizer.py │ ├── msa_dataset.py │ ├── protein │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── polyseq.py │ │ ├── protein.py │ │ ├── starparser.py │ │ ├── system.py │ │ └── xcs.py │ ├── protein_dataset.py │ └── proteingym_dataset.py ├── interface │ ├── __init__.py │ ├── data_interface.py │ └── model_interface.py ├── model │ ├── __init__.py │ ├── finetune_model.py │ ├── pretrain_model_interface.py │ ├── pretrain_model_interface_old.py │ └── pretrain_modules.py └── utils │ ├── __init__.py │ ├── bce.py │ ├── logger.py │ ├── metrics.py │ ├── utils.py │ └── zero_to_fp32.py ├── tasks ├── __init__.py ├── configs │ ├── antibiotic_resistance.yaml │ ├── beta_lactamase_activity_prediction.yaml │ ├── binding_db.yaml │ ├── cloning_clf.yaml │ ├── contact_map.yaml │ ├── deepet_topt.yaml │ ├── deeploc2_multi.yaml │ ├── deeploc_binary.yaml │ ├── deeploc_multi.yaml │ ├── deepsol.yaml │ ├── deepsolue.yaml │ ├── enzyme_catalytic_efficiency.yaml │ ├── enzyme_commission_number.yaml │ ├── esol.yaml │ ├── fitness_prediction.yaml │ ├── flip_aav_sampled.yaml │ ├── flip_gb1_sampled.yaml │ ├── fold_prediction.yaml │ ├── go_biological_process.yaml │ ├── go_cellular_component.yaml │ ├── go_molecular_function.yaml │ ├── human_ppi.yaml │ ├── material_production.yaml │ ├── metal_ion_binding.yaml │ ├── optimal_ph.yaml │ ├── pdbind.yaml │ ├── peptide_HLA_MHC_affinity.yaml │ ├── peta_chs_sol.yaml │ ├── peta_lgk_sol.yaml │ ├── peta_tem_sol.yaml │ ├── ppi_affinity.yaml │ ├── prot_solm.yaml │ ├── sorting_signal.yaml │ ├── ssp_q8.yaml │ ├── tape_fluorescence.yaml │ ├── tape_stability.yaml │ ├── tcr_pmhc_affinity.yaml │ ├── temperature_stability.yaml │ └── yeast_ppi.yaml ├── data_interface.py ├── main.py ├── model_interface.py └── tuner.py └── zeroshot ├── data_interface.py ├── model_interface.py ├── msa_kl_light copy.py ├── msa_kl_light.py ├── pglm └── compute_fitness.py ├── proteingym_light.py └── saprot ├── compute_fitness.py └── foldseek_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | # Distribution / packaging 4 | build/ 5 | dist/ 6 | *.egg-info/ 7 | 8 | # binaries 9 | *.so 10 | 11 | # tmp files 12 | *.swp 13 | 14 | # AML workspace config file 15 | config.json 16 | 17 | .coverage_* 18 | *~ 19 | slurm* 20 | logs 21 | 22 | # Data folder 23 | bookcorpus_data/ 24 | 25 | tmp/ 26 | ds-configs/ 27 | .kube/ 28 | wandb/ 29 | events.out.tfevents.* 30 | model_zoom/ 31 | N128/ 32 | old-sota-yamls/ 33 | outputs/ 34 | datasets/ 35 | *.log 36 | results/ 37 | *.pdb 38 | *.pickle 39 | /zeroshot/msa 40 | *config.json 41 | output_model_plots/ 
-------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "env": { 15 | "PYTHONPATH": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom", 16 | "CUDA_VISIBLE_DEVICES": "0", 17 | }, 18 | }, 19 | { 20 | "name": "finetune", 21 | "type": "debugpy", 22 | "request": "launch", 23 | "program": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/tasks/main.py", 24 | "console": "integratedTerminal", 25 | "justMyCode": true, 26 | "args": [ 27 | "--config_name", "ssp_q8", 28 | "--pretrain_model_name", "esm2_35m", 29 | "--num_workers", "1", 30 | ], 31 | "env": { 32 | "PYTHONPATH": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/", 33 | "CUDA_VISIBLE_DEVICES": "1", 34 | }, 35 | }, 36 | { 37 | "name": "zeroshot", 38 | "type": "debugpy", 39 | "request": "launch", 40 | "program": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/zeroshot/msa_kl_light.py", 41 | "console": "integratedTerminal", 42 | "justMyCode": true, 43 | "args": [ 44 | "--config_name", "fitness_prediction", 45 | "--pretrain_model_name", "esm2_650m", 46 | "--offline", "1" 47 | ], 48 | "env": { 49 | "PYTHONPATH": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/", 50 | "CUDA_VISIBLE_DEVICES": "0", 51 | "CONFIG_NAME": "fitness_prediction", 52 | "EXP_NAME": "fitness_prediction", 53 | "OFFLINE": "1" 54 | }, 55 | }, 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /fig/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/fig/framework.png -------------------------------------------------------------------------------- /model_zoom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/model_zoom/__init__.py -------------------------------------------------------------------------------- /model_zoom/dplm_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from byprot.models.lm.dplm import DiffusionProteinLanguageModel 3 | 4 | 5 | model_name = "YOUR WEIGHTS PATH" 6 | dplm = DiffusionProteinLanguageModel.from_pretrained(model_name) 7 | params = torch.load(f'{model_name}/pytorch_model.bin') 8 | dplm.net.load_state_dict(params) 9 | 10 | toy_data = [ 11 | "MKTLLLTLVVVTIVCLDLGYT", # sequence 1 12 | "GAVLFGYTPGGLAAGALYGVK" # sequence 2 13 | ] 14 | 15 | inputs = dplm.net.tokenizer(toy_data, return_tensors="pt", padding=True, truncation=True) 16 | 17 | outputs = dplm.net.esm( 18 | inputs['input_ids'], 19 | attention_mask=inputs['attention_mask'], 20 | return_dict=True, 21 | ) 22 | 23 | embeddings = outputs.last_hidden_state 24 | 25 | # Print embedding shapes 26 | print(f"Embedding shape for sequence 1: 
{embeddings[0].shape}") 27 | print(f"Embedding shape for sequence 2: {embeddings[1].shape}") 28 | -------------------------------------------------------------------------------- /model_zoom/esm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.1.7" 2 | 3 | -------------------------------------------------------------------------------- /model_zoom/esm/data/keyword_idf_safety_filtered_58641.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/model_zoom/esm/data/keyword_idf_safety_filtered_58641.npy -------------------------------------------------------------------------------- /model_zoom/esm/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/model_zoom/esm/layers/__init__.py -------------------------------------------------------------------------------- /model_zoom/esm/layers/codebook.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class EMACodebook(nn.Module): 9 | def __init__( 10 | self, 11 | n_codes, 12 | embedding_dim, 13 | no_random_restart=True, 14 | restart_thres=1.0, 15 | ema_decay=0.99, 16 | ): 17 | super().__init__() 18 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) 19 | self.register_buffer("N", torch.zeros(n_codes)) 20 | self.register_buffer("z_avg", self.embeddings.data.clone()) 21 | 22 | self.n_codes = n_codes 23 | self.embedding_dim = embedding_dim 24 | self._need_init = True 25 | self.no_random_restart = no_random_restart 26 | self.restart_thres = restart_thres 27 | self.freeze_codebook = False 28 | self.ema_decay = ema_decay 29 | 30 | def reset_parameters(self): 31 | # For meta init 32 | pass 33 | 34 | def _tile(self, x): 35 | d, ew = x.shape 36 | if d < self.n_codes: 37 | n_repeats = (self.n_codes + d - 1) // d 38 | std = 0.01 / np.sqrt(ew) 39 | x = x.repeat(n_repeats, 1) 40 | x = x + torch.randn_like(x) * std 41 | return x 42 | 43 | def _init_embeddings(self, z): 44 | # z: [b, t, c] 45 | self._need_init = False 46 | flat_inputs = z.view(-1, self.embedding_dim) 47 | y = self._tile(flat_inputs) 48 | 49 | y.shape[0] 50 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 51 | if dist.is_initialized(): 52 | dist.broadcast(_k_rand, 0) 53 | self.embeddings.data.copy_(_k_rand) 54 | self.z_avg.data.copy_(_k_rand) 55 | self.N.data.copy_(torch.ones(self.n_codes)) 56 | 57 | def forward(self, z): 58 | # z: [b, t, c] 59 | if self._need_init and self.training and not self.freeze_codebook: 60 | self._init_embeddings(z) 61 | # z is of shape [batch_size, sequence length, channels] 62 | flat_inputs = z.view(-1, self.embedding_dim) 63 | distances = ( 64 | (flat_inputs**2).sum(dim=1, keepdim=True) 65 | - 2 * flat_inputs @ self.embeddings.t() 66 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) 67 | ) # [bt, c] 68 | 69 | encoding_indices = torch.argmin(distances, dim=1) 70 | encoding_indices = encoding_indices.view(*z.shape[:2]) # [b, t, ncode] 71 | 72 | embeddings = F.embedding(encoding_indices, self.embeddings) # [b, t, c] 73 | 74 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) 75 | 76 | # EMA codebook update 
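# Note on the code below (added explanation of existing behavior): the EMA update
# branch is stubbed out (it asserts when training with an unfrozen codebook), so in
# practice this forward pass only performs nearest-code lookup plus the commitment
# loss. The straight-through estimator (embeddings - z).detach() + z keeps the
# quantized values in the forward pass while routing gradients back to the encoder
# output z, since argmin is not differentiable.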
77 | if self.training and not self.freeze_codebook: 78 | assert False, "Not implemented" 79 | embeddings_st = (embeddings - z).detach() + z 80 | 81 | return embeddings_st, encoding_indices, commitment_loss 82 | 83 | def dictionary_lookup(self, encodings): 84 | embeddings = F.embedding(encodings, self.embeddings) 85 | return embeddings 86 | 87 | def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor: 88 | return weights @ self.embeddings 89 | -------------------------------------------------------------------------------- /model_zoom/esm/layers/ffn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | 5 | # NOT CURRENTLY USED 6 | 7 | 8 | class SwiGLU(nn.Module): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | 12 | def forward(self, x: Tensor) -> Tensor: 13 | x1, x2 = x.chunk(2, dim=-1) 14 | hidden = F.silu(x1) * x2 15 | return hidden 16 | 17 | 18 | class FFN(nn.Module): 19 | def __init__(self, in_proj, activation, out_proj) -> None: 20 | super().__init__() 21 | self.in_proj = in_proj 22 | self.activation = activation 23 | self.out_proj = out_proj 24 | 25 | def forward(self, x: Tensor) -> Tensor: 26 | x = self.in_proj(x) 27 | x = self.activation(x) 28 | x = self.out_proj(x) 29 | return x 30 | -------------------------------------------------------------------------------- /model_zoom/esm/layers/regression_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def RegressionHead( 5 | d_model: int, output_dim: int, hidden_dim: int | None = None 6 | ) -> nn.Module: 7 | """Single-hidden layer MLP for supervised output. 8 | 9 | Args: 10 | d_model: input dimension 11 | output_dim: dimensionality of the output. 12 | hidden_dim: optional dimension of hidden layer, defaults to d_model. 13 | Returns: 14 | output MLP module. 15 | """ 16 | hidden_dim = hidden_dim if hidden_dim is not None else d_model 17 | return nn.Sequential( 18 | nn.Linear(d_model, hidden_dim), 19 | nn.GELU(), 20 | nn.LayerNorm(hidden_dim), 21 | nn.Linear(hidden_dim, output_dim), 22 | ) 23 | -------------------------------------------------------------------------------- /model_zoom/esm/layers/structure_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from esm.utils.constants.physics import BB_COORDINATES 5 | from esm.utils.structure.affine3d import ( 6 | Affine3D, 7 | RotationMatrix, 8 | ) 9 | 10 | 11 | class Dim6RotStructureHead(nn.Module): 12 | # Normally, AF2 uses quaternions to specify rotations. There's some evidence that 13 | # other representations are more well behaved - the best one according to 14 | # https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_On_the_Continuity_of_Rotation_Representations_in_Neural_Networks_CVPR_2019_paper.pdf 15 | # is using graham schmidt on 2 vectors, which is implemented here. 
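# Illustrative sketch (not part of the original file), assuming the standard 6D
# construction from the paper cited above: two predicted 3-vectors x and y become an
# orthonormal frame roughly as
#     e1 = x / ||x||
#     u2 = y - (e1 . y) * e1
#     e2 = u2 / ||u2||
#     e3 = e1 x e2        # cross product
# with R = [e1 e2 e3]; Affine3D.from_graham_schmidt in forward() builds the frame
# from three points (x + trans, trans, y + trans) in the same spirit.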
16 | def __init__( 17 | self, 18 | input_dim: int, 19 | trans_scale_factor: float = 10, 20 | norm_type: str = "layernorm", 21 | activation_fn: str = "esm_gelu", 22 | predict_torsion_angles: bool = True, 23 | ): 24 | super().__init__() 25 | self.ffn1 = nn.Linear(input_dim, input_dim) 26 | self.activation_fn = nn.GELU() 27 | self.norm = nn.LayerNorm(input_dim) 28 | self.proj = nn.Linear(input_dim, 9 + 7 * 2) 29 | self.trans_scale_factor = trans_scale_factor 30 | self.predict_torsion_angles = predict_torsion_angles 31 | self.bb_local_coords = torch.tensor(BB_COORDINATES).float() 32 | 33 | def forward(self, x, affine, affine_mask, **kwargs): 34 | if affine is None: 35 | rigids = Affine3D.identity( 36 | x.shape[:-1], 37 | dtype=x.dtype, 38 | device=x.device, 39 | requires_grad=self.training, 40 | rotation_type=RotationMatrix, 41 | ) 42 | else: 43 | rigids = affine 44 | 45 | # [*, N] 46 | x = self.ffn1(x) 47 | x = self.activation_fn(x) 48 | x = self.norm(x) 49 | trans, x, y, angles = self.proj(x).split([3, 3, 3, 7 * 2], dim=-1) 50 | trans = trans * self.trans_scale_factor 51 | x = x / (x.norm(dim=-1, keepdim=True) + 1e-5) 52 | y = y / (y.norm(dim=-1, keepdim=True) + 1e-5) 53 | update = Affine3D.from_graham_schmidt(x + trans, trans, y + trans) 54 | rigids = rigids.compose(update.mask(affine_mask)) 55 | affine = rigids.tensor 56 | 57 | # We approximate the positions of the backbone atoms in the global frame by applying the rigid 58 | # transformation to the mean of the backbone atoms in the local frame. 59 | all_bb_coords_local = ( 60 | self.bb_local_coords[None, None, :, :] 61 | .expand(*x.shape[:-1], 3, 3) 62 | .to(x.device) 63 | ) 64 | pred_xyz = rigids[..., None].apply(all_bb_coords_local) 65 | 66 | return affine, pred_xyz 67 | -------------------------------------------------------------------------------- /model_zoom/esm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/model_zoom/esm/models/__init__.py -------------------------------------------------------------------------------- /model_zoom/esm/sdk/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from esm.sdk.forge import ESM3ForgeInferenceClient 4 | from esm.utils.forge_context_manager import ForgeBatchExecutor 5 | 6 | # Note: please do not import ESM3SageMakerClient here since that requires AWS SDK. 7 | 8 | 9 | def client( 10 | model="esm3-sm-open-v1", 11 | url="https://forge.evolutionaryscale.ai", 12 | token=os.environ.get("ESM_API_KEY", ""), 13 | request_timeout=None, 14 | ): 15 | """ 16 | Args: 17 | model: Name of the model to use. 18 | url: URL of a forge server. 19 | token: User's API token. 20 | request_timeout: Amount of time to wait for a request to finish. 21 | Default is wait indefinitely. 22 | """ 23 | return ESM3ForgeInferenceClient(model, url, token, request_timeout) 24 | 25 | 26 | def batch_executor(max_attempts: int = 10): 27 | """ 28 | Args: 29 | max_attempts: Maximum number of attempts to make before giving up. 
30 | 31 | Usage: 32 | with batch_executor() as executor: 33 | executor.execute_batch(fn, **kwargs) 34 | """ 35 | return ForgeBatchExecutor(max_attempts) 36 | -------------------------------------------------------------------------------- /model_zoom/esm/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Protocol 3 | 4 | from esm.utils.constants.models import ( 5 | ESM3_OPEN_SMALL, 6 | normalize_model_name, 7 | ) 8 | 9 | from .function_tokenizer import InterProQuantizedTokenizer 10 | from .residue_tokenizer import ResidueAnnotationsTokenizer 11 | from .sasa_tokenizer import SASADiscretizingTokenizer 12 | from .sequence_tokenizer import EsmSequenceTokenizer 13 | from .ss_tokenizer import SecondaryStructureTokenizer 14 | from .structure_tokenizer import StructureTokenizer 15 | from .tokenizer_base import EsmTokenizerBase 16 | 17 | 18 | class TokenizerCollectionProtocol(Protocol): 19 | sequence: EsmSequenceTokenizer 20 | structure: StructureTokenizer 21 | secondary_structure: SecondaryStructureTokenizer 22 | sasa: SASADiscretizingTokenizer 23 | function: InterProQuantizedTokenizer 24 | residue_annotations: ResidueAnnotationsTokenizer 25 | 26 | 27 | @dataclass 28 | class TokenizerCollection: 29 | sequence: EsmSequenceTokenizer 30 | structure: StructureTokenizer 31 | secondary_structure: SecondaryStructureTokenizer 32 | sasa: SASADiscretizingTokenizer 33 | function: InterProQuantizedTokenizer 34 | residue_annotations: ResidueAnnotationsTokenizer 35 | 36 | 37 | def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection: 38 | if normalize_model_name(model) == ESM3_OPEN_SMALL: 39 | return TokenizerCollection( 40 | sequence=EsmSequenceTokenizer(), 41 | structure=StructureTokenizer(), 42 | secondary_structure=SecondaryStructureTokenizer(kind="ss8"), 43 | sasa=SASADiscretizingTokenizer(), 44 | function=InterProQuantizedTokenizer(), 45 | residue_annotations=ResidueAnnotationsTokenizer(), 46 | ) 47 | else: 48 | raise ValueError(f"Unknown model: {model}") 49 | 50 | 51 | def get_esmc_model_tokenizers() -> EsmSequenceTokenizer: 52 | return EsmSequenceTokenizer() 53 | 54 | 55 | def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]: 56 | if isinstance(tokenizer, EsmSequenceTokenizer): 57 | return [ 58 | tokenizer.mask_token_id, # type: ignore 59 | tokenizer.pad_token_id, # type: ignore 60 | tokenizer.cls_token_id, # type: ignore 61 | tokenizer.eos_token_id, # type: ignore 62 | ] 63 | else: 64 | return [ 65 | tokenizer.mask_token_id, 66 | tokenizer.pad_token_id, 67 | tokenizer.bos_token_id, 68 | tokenizer.eos_token_id, 69 | ] 70 | -------------------------------------------------------------------------------- /model_zoom/esm/tokenization/structure_tokenizer.py: -------------------------------------------------------------------------------- 1 | from esm.tokenization.tokenizer_base import EsmTokenizerBase 2 | from esm.utils.constants import esm3 as C 3 | 4 | 5 | class StructureTokenizer(EsmTokenizerBase): 6 | """A convenince class for accessing special token ids of 7 | the StructureTokenEncoder and StructureTokenDecoder.""" 8 | 9 | def __init__(self, codebook_size: int = C.VQVAE_CODEBOOK_SIZE): 10 | self.vq_vae_special_tokens = { 11 | "MASK": codebook_size, 12 | "EOS": codebook_size + 1, 13 | "BOS": codebook_size + 2, 14 | "PAD": codebook_size + 3, 15 | "CHAINBREAK": codebook_size + 4, 16 | } 17 | 18 | def mask_token(self) -> str: 19 | raise 
NotImplementedError( 20 | "Structure tokens are defined on 3D coordinates, not strings." 21 | ) 22 | 23 | @property 24 | def mask_token_id(self) -> int: 25 | return self.vq_vae_special_tokens["MASK"] 26 | 27 | def bos_token(self) -> str: 28 | raise NotImplementedError( 29 | "Structure tokens are defined on 3D coordinates, not strings." 30 | ) 31 | 32 | @property 33 | def bos_token_id(self) -> int: 34 | return self.vq_vae_special_tokens["BOS"] 35 | 36 | def eos_token(self) -> str: 37 | raise NotImplementedError( 38 | "Structure tokens are defined on 3D coordinates, not strings." 39 | ) 40 | 41 | @property 42 | def eos_token_id(self) -> int: 43 | return self.vq_vae_special_tokens["EOS"] 44 | 45 | def pad_token(self) -> str: 46 | raise NotImplementedError( 47 | "Structure tokens are defined on 3D coordinates, not strings." 48 | ) 49 | 50 | @property 51 | def pad_token_id(self) -> int: 52 | return self.vq_vae_special_tokens["PAD"] 53 | 54 | def chain_break_token(self) -> str: 55 | raise NotImplementedError( 56 | "Structure tokens are defined on 3D coordinates, not strings." 57 | ) 58 | 59 | @property 60 | def chain_break_token_id(self) -> int: 61 | return self.vq_vae_special_tokens["CHAINBREAK"] 62 | 63 | @property 64 | def all_token_ids(self): 65 | return list(range(C.VQVAE_CODEBOOK_SIZE + len(self.vq_vae_special_tokens))) 66 | 67 | @property 68 | def special_token_ids(self): 69 | return self.vq_vae_special_tokens.values() 70 | 71 | def encode(self, *args, **kwargs): 72 | raise NotImplementedError( 73 | "The StructureTokenizer class is provided as a convenience for " 74 | "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" 75 | "Please use them instead." 76 | ) 77 | 78 | def decode(self, *args, **kwargs): 79 | raise NotImplementedError( 80 | "The StructureTokenizer class is provided as a convenience for " 81 | "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" 82 | "Please use them instead." 83 | ) 84 | -------------------------------------------------------------------------------- /model_zoom/esm/tokenization/tokenizer_base.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | 4 | @runtime_checkable 5 | class EsmTokenizerBase(Protocol): 6 | mask_token: str 7 | mask_token_id: int 8 | bos_token: str 9 | bos_token_id: int 10 | eos_token: str 11 | eos_token_id: int 12 | pad_token: str 13 | pad_token_id: int 14 | chain_break_token: str 15 | chain_break_token_id: int 16 | 17 | def encode(self, *args, **kwargs): ... 18 | 19 | def decode(self, *args, **kwargs): ... 20 | 21 | @property 22 | def all_token_ids(self): ... 23 | 24 | @property 25 | def special_token_ids(self): ... 
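# Illustrative usage (not part of the original file): because EsmTokenizerBase is a
# runtime_checkable Protocol, concrete tokenizers can be checked structurally, e.g.
#
#     from esm.tokenization import EsmSequenceTokenizer, EsmTokenizerBase
#     tok = EsmSequenceTokenizer()
#     isinstance(tok, EsmTokenizerBase)  # True when the required members are present
#
# get_invalid_tokenizer_ids() in esm/tokenization/__init__.py uses a similar
# isinstance check to special-case the sequence tokenizer.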
26 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/model_zoom/esm/utils/__init__.py -------------------------------------------------------------------------------- /model_zoom/esm/utils/constants/api.py: -------------------------------------------------------------------------------- 1 | MAX_TOPK_SEQUENCE = 32 2 | MAX_TOPK_STRUCTURE = MAX_TOPK_SEQUENCE 3 | MAX_TOPK_SECONDARY_STRUCTURE = MAX_TOPK_SEQUENCE 4 | MAX_TOPK_SASA = MAX_TOPK_SEQUENCE 5 | MAX_TOPK_FUNCTION = MAX_TOPK_SEQUENCE 6 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/constants/esm3.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import cache 3 | from pathlib import Path 4 | 5 | from huggingface_hub import snapshot_download 6 | 7 | SEQUENCE_BOS_TOKEN = 0 8 | SEQUENCE_PAD_TOKEN = 1 9 | SEQUENCE_EOS_TOKEN = 2 10 | SEQUENCE_CHAINBREAK_TOKEN = 31 11 | SEQUENCE_MASK_TOKEN = 32 12 | 13 | VQVAE_CODEBOOK_SIZE = 4096 14 | VQVAE_SPECIAL_TOKENS = { 15 | "MASK": VQVAE_CODEBOOK_SIZE, 16 | "EOS": VQVAE_CODEBOOK_SIZE + 1, 17 | "BOS": VQVAE_CODEBOOK_SIZE + 2, 18 | "PAD": VQVAE_CODEBOOK_SIZE + 3, 19 | "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4, 20 | } 21 | VQVAE_DIRECTION_LOSS_BINS = 16 22 | VQVAE_PAE_BINS = 64 23 | VQVAE_MAX_PAE_BIN = 31.0 24 | VQVAE_PLDDT_BINS = 50 25 | 26 | STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"] 27 | STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"] 28 | STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"] 29 | STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"] 30 | STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"] 31 | STRUCTURE_UNDEFINED_TOKEN = 955 32 | 33 | SASA_PAD_TOKEN = 0 34 | 35 | SS8_PAD_TOKEN = 0 36 | 37 | INTERPRO_PAD_TOKEN = 0 38 | 39 | RESIDUE_PAD_TOKEN = 0 40 | 41 | CHAIN_BREAK_STR = "|" 42 | 43 | SEQUENCE_BOS_STR = "" 44 | SEQUENCE_EOS_STR = "" 45 | 46 | MASK_STR_SHORT = "_" 47 | SEQUENCE_MASK_STR = "" 48 | SASA_MASK_STR = "" 49 | SS8_MASK_STR = "" 50 | 51 | # fmt: off 52 | SEQUENCE_VOCAB = [ 53 | "", "", "", "", 54 | "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", 55 | "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", 56 | "O", ".", "-", "|", 57 | "", 58 | ] 59 | # fmt: on 60 | 61 | SSE_8CLASS_VOCAB = "GHITEBSC" 62 | SSE_3CLASS_VOCAB = "HEC" 63 | SSE_8CLASS_TO_3CLASS_MAP = { 64 | "G": "H", 65 | "H": "H", 66 | "I": "H", 67 | "T": "C", 68 | "E": "E", 69 | "B": "E", 70 | "S": "C", 71 | "C": "C", 72 | } 73 | 74 | SASA_DISCRETIZATION_BOUNDARIES = [ 75 | 0.8, 76 | 4.0, 77 | 9.6, 78 | 16.4, 79 | 24.5, 80 | 32.9, 81 | 42.0, 82 | 51.5, 83 | 61.2, 84 | 70.9, 85 | 81.6, 86 | 93.3, 87 | 107.2, 88 | 125.4, 89 | 151.4, 90 | ] 91 | 92 | MAX_RESIDUE_ANNOTATIONS = 16 93 | 94 | 95 | TFIDF_VECTOR_SIZE = 58641 96 | 97 | 98 | @staticmethod 99 | @cache 100 | def data_root(model: str): 101 | if "INFRA_PROVIDER" in os.environ: 102 | return Path("") 103 | # Try to download from hugginface if it doesn't exist 104 | if model.startswith("esm3"): 105 | # path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1")) 106 | path = Path('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/esm3-sm-open-v1') 107 | elif model.startswith("esmc-300"): 108 | path = 
Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12")) 109 | elif model.startswith("esmc-600"): 110 | # path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12")) 111 | path = Path('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/esmc-600m-2024-12') 112 | else: 113 | raise ValueError(f"{model=} is an invalid model name.") 114 | return path 115 | 116 | 117 | IN_REPO_DATA_FOLDER = Path(__file__).parents[2] / "data" 118 | 119 | INTERPRO_ENTRY = IN_REPO_DATA_FOLDER / "entry_list_safety_29026.list" 120 | INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt" 121 | INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt" 122 | INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json" 123 | 124 | LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"} 125 | 126 | KEYWORDS_VOCABULARY = ( 127 | IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt" 128 | ) 129 | KEYWORDS_IDF = IN_REPO_DATA_FOLDER / "keyword_idf_safety_filtered_58641.npy" 130 | 131 | RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv" 132 | INTERPRO2KEYWORDS = IN_REPO_DATA_FOLDER / "interpro_29026_to_keywords_58641.csv" 133 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/constants/models.py: -------------------------------------------------------------------------------- 1 | # Model names 2 | ESM3_OPEN_SMALL = "esm3_sm_open_v1" 3 | ESM3_OPEN_SMALL_ALIAS_1 = "esm3-open-2024-03" 4 | ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1" 5 | ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open" 6 | ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0" 7 | ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0" 8 | ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0" 9 | ESMC_600M = "esmc_600m" 10 | ESMC_300M = "esmc_300m" 11 | 12 | 13 | def forge_only_return_single_layer_hidden_states(model_name: str): 14 | return model_name.startswith("esmc-6b") 15 | 16 | 17 | def model_is_locally_supported(x: str): 18 | return x in { 19 | ESM3_OPEN_SMALL, 20 | ESM3_OPEN_SMALL_ALIAS_1, 21 | ESM3_OPEN_SMALL_ALIAS_2, 22 | ESM3_OPEN_SMALL_ALIAS_3, 23 | } 24 | 25 | 26 | def normalize_model_name(x: str): 27 | if x in {ESM3_OPEN_SMALL_ALIAS_1, ESM3_OPEN_SMALL_ALIAS_2, ESM3_OPEN_SMALL_ALIAS_3}: 28 | return ESM3_OPEN_SMALL 29 | return x 30 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/constants/physics.py: -------------------------------------------------------------------------------- 1 | BB_COORDINATES = [ 2 | [0.5256, 1.3612, 0.0000], 3 | [0.0000, 0.0000, 0.0000], 4 | [-1.5251, 0.0000, 0.0000], 5 | ] 6 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/function/lsh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cloudpathlib import AnyPath 3 | 4 | from esm.utils.types import PathLike 5 | 6 | 7 | class LSHTable: 8 | def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): 9 | if hyperplanes is None: 10 | hyperplanes = np.random.randn(n_bits, dim) 11 | hyperplanes = hyperplanes / np.linalg.norm( 12 | hyperplanes, axis=-1, keepdims=True 13 | ) 14 | else: 15 | assert hyperplanes.shape == (n_bits, dim), ( 16 | hyperplanes.shape, 17 | (n_bits, dim), 18 | ) 19 | assert hyperplanes is not None 20 | self.hyperplanes: np.ndarray = hyperplanes 21 | self.values = 1 << np.arange(n_bits) 22 | 23 | def 
__call__(self, array, tokenize: bool = True): 24 | similarity = self.hyperplanes @ array.T 25 | bits = np.where(similarity >= 0, 1, 0) 26 | if tokenize: 27 | tokens = bits.T @ self.values 28 | return tokens 29 | else: 30 | return bits.T 31 | 32 | 33 | class LSHTokenized: 34 | def __init__( 35 | self, 36 | n_bits: int, 37 | dim: int, 38 | num_tables: int = 1, 39 | filepath: PathLike | None = None, 40 | allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes 41 | ): 42 | table_hyperplanes = None 43 | if filepath is not None: 44 | filepath = AnyPath(filepath) 45 | if not filepath.exists(): 46 | raise FileNotFoundError(filepath) 47 | table_hyperplanes = np.load(filepath) # type: ignore 48 | for i in range(num_tables): 49 | assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" 50 | elif not allow_create_hyperplanes: 51 | raise RuntimeError( 52 | "Not allowed to create hyperplanes but no filepath provided" 53 | ) 54 | 55 | self.tables = [ 56 | LSHTable( 57 | n_bits, 58 | dim, 59 | table_hyperplanes[str(i)] if table_hyperplanes is not None else None, 60 | ) 61 | for i in range(num_tables) 62 | ] 63 | 64 | def write_hyperplanes(self, filepath: PathLike): 65 | hyperplanes: dict[str, np.ndarray] = { # type: ignore 66 | str(i): table.hyperplanes for i, table in enumerate(self.tables) 67 | } 68 | np.savez(filepath, **hyperplanes) 69 | 70 | def __call__(self, array): 71 | tokens = np.stack([table(array) for table in self.tables], 1) 72 | return tokens 73 | 74 | 75 | class LSHBitstream: 76 | def __init__( 77 | self, 78 | n_bits: int, 79 | dim: int, 80 | filepath: PathLike | None = None, 81 | allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes 82 | ): 83 | table_hyperplanes = None 84 | if filepath is not None: 85 | filepath = AnyPath(filepath) 86 | if not filepath.exists(): 87 | raise FileNotFoundError(filepath) 88 | table_hyperplanes = np.load(filepath) 89 | elif not allow_create_hyperplanes: 90 | raise RuntimeError( 91 | "Not allowed to create hyperplanes but no filepath provided" 92 | ) 93 | 94 | self.table = LSHTable( 95 | n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None 96 | ) 97 | 98 | def write_hyperplanes(self, filepath: PathLike): 99 | np.save(filepath, self.table.hyperplanes) 100 | 101 | def __call__(self, array): 102 | return self.table(array, tokenize=False) 103 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/function/tfidf.py: -------------------------------------------------------------------------------- 1 | """Term-Frequency / Inverse Document Frequency (TF-IDF) model.""" 2 | 3 | from collections import Counter 4 | from functools import cached_property 5 | 6 | import numpy as np 7 | from cloudpathlib import AnyPath 8 | from scipy import sparse 9 | 10 | from esm.utils.types import PathLike 11 | 12 | 13 | class TFIDFModel: 14 | """Term-Frequency / Inverse Document Frequency (TF-IDF) model. 
15 | Mimics sklearn.feature_extraction.text.TfidfVectorizer with sublinear_tf=True 16 | """ 17 | 18 | def __init__(self, vocabulary_path: PathLike, idf_path: PathLike): 19 | with AnyPath(vocabulary_path).open("r") as f: 20 | self.vocabulary = f.read().strip().split("\n") 21 | 22 | with AnyPath(idf_path).open("rb") as f: 23 | self.idf_ = np.load(f) 24 | 25 | assert self.idf_.ndim == 1 26 | assert ( 27 | len(self.idf_) == len(self.vocabulary) 28 | ), f"IDF size must match vocabulary size, got {len(self.idf_)} and {len(self.vocabulary)}" 29 | 30 | @cached_property 31 | def vocab_to_index(self) -> dict[str, int]: 32 | return {term: index for index, term in enumerate(self.vocabulary)} 33 | 34 | def encode(self, terms: list[str]) -> sparse.csr_matrix: 35 | """Encodes terms as TF-IDF vectors. 36 | 37 | Args: 38 | terms: list of terms to encode. 39 | 40 | Returns: 41 | TF-IDF vector encoded as sparse matrix of shape (1, num_terms) 42 | """ 43 | counter = Counter(filter(self.vocabulary.__contains__, terms)) 44 | indices = [self.vocab_to_index[term] for term in counter] 45 | 46 | tf = np.array([count for term, count in counter.items()]) 47 | idf = np.take(self.idf_, indices) 48 | 49 | values = (1 + np.log(tf)) * idf 50 | values /= np.linalg.norm(values) 51 | 52 | return sparse.csr_matrix( 53 | (values, (np.zeros_like(indices), indices)), shape=(1, len(self.vocabulary)) 54 | ) 55 | 56 | def decode(self, vec: sparse.csr_matrix) -> list[str]: 57 | """Extract terms from TF-IDF.""" 58 | return [self.vocabulary[i] for i in vec.indices] 59 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/misc_test.py: -------------------------------------------------------------------------------- 1 | """Tests for misc.py""" 2 | 3 | from esm.utils.misc import merge_annotations 4 | from esm.utils.types import FunctionAnnotation 5 | 6 | 7 | def test_merge_annotations(): 8 | merged = merge_annotations( 9 | [ 10 | FunctionAnnotation("a", start=1, end=10), 11 | FunctionAnnotation("b", start=5, end=15), 12 | FunctionAnnotation("a", start=10, end=20), 13 | FunctionAnnotation("b", start=2, end=6), 14 | FunctionAnnotation("c", start=4, end=10), 15 | ] 16 | ) 17 | assert len(merged) == 3 18 | assert FunctionAnnotation("a", start=1, end=20) in merged 19 | assert FunctionAnnotation("b", start=2, end=15) in merged 20 | assert FunctionAnnotation("c", start=4, end=10) in merged 21 | 22 | 23 | def test_merge_annotations_gap(): 24 | merged = merge_annotations( 25 | [ 26 | FunctionAnnotation("a", start=1, end=10), 27 | FunctionAnnotation("a", start=13, end=20), # gap is 2 28 | FunctionAnnotation("a", start=24, end=30), 29 | ], 30 | merge_gap_max=2, 31 | ) 32 | 33 | assert len(merged) == 2 34 | assert FunctionAnnotation("a", 1, 20) in merged 35 | assert FunctionAnnotation("a", 24, 30) in merged 36 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/noise_schedules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def cosine_schedule(t: torch.Tensor): 7 | # t is a tensor of size (batch_size,) with values between 0 and 1. 
This is the 8 | # schedule used in the MaskGIT paper 9 | return torch.cos(t * math.pi * 0.5) 10 | 11 | 12 | def cubic_schedule(t): 13 | return 1 - t**3 14 | 15 | 16 | def linear_schedule(t): 17 | return 1 - t 18 | 19 | 20 | def square_root_schedule(t): 21 | return 1 - torch.sqrt(t) 22 | 23 | 24 | def square_schedule(t): 25 | return 1 - t**2 26 | 27 | 28 | NOISE_SCHEDULE_REGISTRY = { 29 | "cosine": cosine_schedule, 30 | "linear": linear_schedule, 31 | "square_root_schedule": square_root_schedule, 32 | "cubic": cubic_schedule, 33 | "square": square_schedule, 34 | } 35 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/residue_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This mapping is used when we need to store atom data in a format that requires 17 | # fixed atom data size for every residue (e.g. a numpy array). 18 | atom_types = [ 19 | "N", 20 | "CA", 21 | "C", 22 | "CB", 23 | "O", 24 | "CG", 25 | "CG1", 26 | "CG2", 27 | "OG", 28 | "OG1", 29 | "SG", 30 | "CD", 31 | "CD1", 32 | "CD2", 33 | "ND1", 34 | "ND2", 35 | "OD1", 36 | "OD2", 37 | "SD", 38 | "CE", 39 | "CE1", 40 | "CE2", 41 | "CE3", 42 | "NE", 43 | "NE1", 44 | "NE2", 45 | "OE1", 46 | "OE2", 47 | "CH2", 48 | "NH1", 49 | "NH2", 50 | "OH", 51 | "CZ", 52 | "CZ2", 53 | "CZ3", 54 | "NZ", 55 | "OXT", 56 | ] 57 | atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} 58 | atom_type_num = len(atom_types) # := 37. 59 | 60 | restype_1to3 = { 61 | "A": "ALA", 62 | "R": "ARG", 63 | "N": "ASN", 64 | "D": "ASP", 65 | "C": "CYS", 66 | "Q": "GLN", 67 | "E": "GLU", 68 | "G": "GLY", 69 | "H": "HIS", 70 | "I": "ILE", 71 | "L": "LEU", 72 | "K": "LYS", 73 | "M": "MET", 74 | "F": "PHE", 75 | "P": "PRO", 76 | "S": "SER", 77 | "T": "THR", 78 | "W": "TRP", 79 | "Y": "TYR", 80 | "V": "VAL", 81 | } 82 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/sampling_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from esm.utils.sampling import sample_logits 5 | 6 | 7 | def test_sample_logits(): 8 | # batched input. temperature != 0.0. 9 | sampled = sample_logits( 10 | logits=torch.randn((64, 8, 4096)), temperature=0.8, valid_ids=list(range(4096)) 11 | ) 12 | assert sampled.shape == (64, 8) 13 | 14 | # batched input. temperature == 0.0. 15 | sampled = sample_logits( 16 | logits=torch.randn((64, 8, 4096)), temperature=0.0, valid_ids=list(range(4096)) 17 | ) 18 | assert sampled.shape == (64, 8) 19 | 20 | # non-batched input. temperature != 0.0. 21 | sampled = sample_logits( 22 | logits=torch.randn((8, 4096)), temperature=0.8, valid_ids=list(range(4096)) 23 | ) 24 | assert sampled.shape == (8,) 25 | 26 | # non-batched input. 
temperature == 0.0. 27 | sampled = sample_logits( 28 | logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=list(range(4096)) 29 | ) 30 | assert sampled.shape == (8,) 31 | 32 | with pytest.raises(ValueError): 33 | sampled = sample_logits( 34 | logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=[] 35 | ) 36 | 37 | 38 | test_sample_logits() 39 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/structure/aligner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import Field, replace 4 | from typing import Any, ClassVar, Protocol, TypeVar 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from esm.utils.structure.protein_structure import ( 10 | compute_affine_and_rmsd, 11 | ) 12 | 13 | 14 | class Alignable(Protocol): 15 | atom37_positions: np.ndarray 16 | atom37_mask: np.ndarray 17 | # Trick to detect whether an object is a dataclass 18 | __dataclass_fields__: ClassVar[dict[str, Field[Any]]] 19 | 20 | def __len__(self) -> int: ... 21 | 22 | 23 | T = TypeVar("T", bound=Alignable) 24 | 25 | 26 | class Aligner: 27 | def __init__( 28 | self, 29 | mobile: Alignable, 30 | target: Alignable, 31 | only_use_backbone: bool = False, 32 | use_reflection: bool = False, 33 | ): 34 | """ 35 | Aligns a mobile protein chain against a target protein chain. 36 | 37 | Args: 38 | mobile (ProteinChain): Protein chain to be aligned. 39 | target (ProteinChain): Protein chain target. 40 | only_use_backbone (bool): Whether to only use backbone atoms. 41 | use_reflection (bool): Whether to align to target reflection. 42 | """ 43 | # Check proteins must have same number of residues 44 | assert len(mobile) == len(target) 45 | 46 | # Determine overlapping atoms 47 | joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype( 48 | bool 49 | ) 50 | 51 | # Backbone atoms are first sites in atom37 representation 52 | if only_use_backbone: 53 | joint_atom37_mask[:, 3:] = False 54 | 55 | # Extract matching atom positions and convert to batched tensors 56 | mobile_atom_tensor = ( 57 | torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0) 58 | ) 59 | target_atom_tensor = ( 60 | torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0) 61 | ) 62 | joint_atom37_mask = ( 63 | torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0) 64 | ) 65 | 66 | # If using reflection flip target 67 | if use_reflection: 68 | target_atom_tensor = -target_atom_tensor 69 | 70 | # Compute alignment and rmsd 71 | affine3D, rmsd = compute_affine_and_rmsd( 72 | mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask 73 | ) 74 | self._affine3D = affine3D 75 | self._rmsd = rmsd.item() 76 | 77 | @property 78 | def rmsd(self): 79 | return self._rmsd 80 | 81 | def apply(self, mobile: T) -> T: 82 | """Apply alignment to a protein chain""" 83 | # Extract atom positions and convert to batched tensors 84 | mobile_atom_tensor = ( 85 | torch.from_numpy(mobile.atom37_positions[mobile.atom37_mask]) 86 | .type(torch.float32) 87 | .unsqueeze(0) 88 | ) 89 | 90 | # Transform atom arrays 91 | aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0) 92 | 93 | # Rebuild atom37 positions 94 | aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan) 95 | aligned_atom37_positions[mobile.atom37_mask] = aligned_atom_tensor 96 | 97 | return replace(mobile, atom37_positions=aligned_atom37_positions) 98 
| -------------------------------------------------------------------------------- /model_zoom/esm/utils/structure/normalize_coordinates.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | 7 | from esm.utils import residue_constants as RC 8 | from esm.utils.structure.affine3d import Affine3D 9 | 10 | ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) 11 | 12 | 13 | def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D: 14 | N, CA, C = bb_positions.unbind(dim=-2) 15 | return Affine3D.from_graham_schmidt(C, CA, N) 16 | 17 | 18 | def index_by_atom_name( 19 | atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 20 | ) -> ArrayOrTensor: 21 | squeeze = False 22 | if isinstance(atom_names, str): 23 | atom_names = [atom_names] 24 | squeeze = True 25 | indices = [RC.atom_order[atom_name] for atom_name in atom_names] 26 | dim = dim % atom37.ndim 27 | index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) 28 | result = atom37[index] # type: ignore 29 | if squeeze: 30 | result = result.squeeze(dim) 31 | return result 32 | 33 | 34 | def get_protein_normalization_frame(coords: Tensor) -> Affine3D: 35 | """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates. 36 | Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame 37 | using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame. 38 | 39 | Args: 40 | coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates 41 | 42 | Returns: 43 | Affine3D: tensor of Affine3D frame 44 | """ 45 | bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2) 46 | coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1) 47 | 48 | average_position_per_n_ca_c = bb_coords.masked_fill( 49 | ~coord_mask[..., None, None], 0 50 | ).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8) 51 | frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float()) 52 | 53 | return frame 54 | 55 | 56 | def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor: 57 | """Given a set of coordinates and a single frame, apply the frame to the coordinates. 
58 | 59 | Args: 60 | coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates 61 | frame (Affine3D): Affine3D frame 62 | 63 | Returns: 64 | torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates 65 | """ 66 | coords_trans_rot = frame[..., None, None].invert().apply(coords) 67 | 68 | # only transform coordinates with frame that have a valid rotation 69 | valid_frame = frame.trans.norm(dim=-1) > 0 70 | 71 | is_inf = torch.isinf(coords) 72 | coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords) 73 | coords.masked_fill_(is_inf, torch.inf) 74 | 75 | return coords 76 | 77 | 78 | def normalize_coordinates(coords: Tensor) -> Tensor: 79 | return apply_frame_to_coords(coords, get_protein_normalization_frame(coords)) 80 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/structure/predicted_aligned_error.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from esm.utils.structure.affine3d import Affine3D 5 | 6 | 7 | def masked_mean( 8 | mask: torch.Tensor, 9 | value: torch.Tensor, 10 | dim: int | None | tuple[int, ...] = None, 11 | eps=1e-10, 12 | ) -> torch.Tensor: 13 | """Compute the mean of `value` where only positions where `mask == true` are 14 | counted. 15 | """ 16 | mask = mask.expand(*value.shape) 17 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 18 | 19 | 20 | def _pae_bins( 21 | max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu") 22 | ): 23 | bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device) 24 | step = max_bin / (num_bins - 2) 25 | bin_centers = bins + step / 2 26 | bin_centers = torch.cat( 27 | [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 28 | ) 29 | return bin_centers 30 | 31 | 32 | def _compute_pae_masks(mask: torch.Tensor): 33 | square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool() 34 | return square_mask 35 | 36 | 37 | def compute_predicted_aligned_error( 38 | logits: torch.Tensor, 39 | aa_mask: torch.Tensor, 40 | sequence_id: torch.Tensor | None = None, 41 | max_bin: float = 31, 42 | ) -> torch.Tensor: 43 | bins = _pae_bins(max_bin, logits.shape[-1], logits.device) 44 | square_mask = _compute_pae_masks(aa_mask) 45 | min_v = torch.finfo(logits.dtype).min 46 | probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) 47 | 48 | return (probs * bins).sum(dim=-1) 49 | 50 | 51 | @torch.no_grad 52 | def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0): 53 | square_mask = _compute_pae_masks(aa_mask) 54 | seqlens = aa_mask.sum(-1, keepdim=True) 55 | bins = _pae_bins(max_bin, logits.shape[-1], logits.device) 56 | d0 = 1.24 * (seqlens.clamp_min(19) - 15) ** (1 / 3) - 1.8 57 | f_d = 1.0 / (1 + (bins / d0.unsqueeze(-1)) ** 2) 58 | 59 | min_v = torch.finfo(logits.dtype).min 60 | probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) 61 | # This is the sum over bins 62 | ptm = (probs * f_d.unsqueeze(-2)).sum(dim=-1) 63 | # This is the mean over residues j 64 | ptm = masked_mean(square_mask, ptm, dim=-1) 65 | # The we do a max over residues i 66 | return ptm.max(dim=-1).values 67 | 68 | 69 | def tm_loss( 70 | logits: torch.Tensor, 71 | pred_affine: torch.Tensor, 72 | targ_affine: torch.Tensor, 73 | targ_mask: torch.Tensor, 74 | tm_mask: torch.Tensor | None = None, 75 | sequence_id: torch.Tensor | None = None, 76 | max_bin: float = 31, 77 | ): 
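# Explanatory note on the loss below (describes the existing code): for every residue
# pair (i, j) it measures the squared distance between residue j expressed in residue
# i's predicted frame and in the corresponding target frame (under torch.no_grad),
# discretizes that error into PAE-style distance bins, and applies a cross-entropy
# between those target bins and the predicted logits, averaged over valid residue
# pairs (and over tm_mask if provided).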
78 | pred = Affine3D.from_tensor(pred_affine) 79 | targ = Affine3D.from_tensor(targ_affine) 80 | 81 | def transform(affine: Affine3D): 82 | pts = affine.trans[..., None, :, :] 83 | return affine.invert()[..., None].apply(pts) 84 | 85 | with torch.no_grad(): 86 | sq_diff = (transform(pred) - transform(targ)).square().sum(dim=-1) 87 | 88 | num_bins = logits.shape[-1] 89 | sq_bins = torch.linspace( 90 | 0, max_bin, num_bins - 1, device=logits.device 91 | ).square() 92 | # Gets the bin id by using a sum. 93 | true_bins = (sq_diff[..., None] > sq_bins).sum(dim=-1).long() 94 | 95 | errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none") 96 | square_mask = _compute_pae_masks(targ_mask) 97 | loss = masked_mean(square_mask, errors, dim=(-1, -2)) 98 | 99 | if tm_mask is not None: 100 | loss = masked_mean(tm_mask, loss, dim=None) 101 | else: 102 | loss = loss.mean() 103 | 104 | return loss 105 | -------------------------------------------------------------------------------- /model_zoom/esm/utils/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | from cloudpathlib import CloudPath 9 | 10 | PathLike = Union[str, Path, CloudPath] 11 | PathOrBuffer = Union[PathLike, io.StringIO] 12 | 13 | 14 | @dataclass 15 | class FunctionAnnotation: 16 | """Represents an annotation of a protein's function over a range of residues. 17 | 18 | Fields: 19 | label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs 20 | start (int): Start index of this annotation. 1-indexed, inclusive. 21 | end (int): End index of this annotation. 1-indexed, inclusive. 22 | """ 23 | 24 | label: str 25 | start: int 26 | end: int 27 | 28 | def to_tuple(self) -> tuple[str, int, int]: 29 | return self.label, self.start, self.end 30 | 31 | def __len__(self) -> int: 32 | """Length of the annotation.""" 33 | return self.end - self.start + 1 34 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/clients.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import huggingface_hub 4 | import huggingface_hub.errors 5 | import torch 6 | 7 | from esm.models.esm3 import ESM3 8 | from esm.sdk import ESM3ForgeInferenceClient 9 | from esm.sdk.api import ESM3InferenceClient 10 | 11 | 12 | def get_local_client() -> ESM3InferenceClient: 13 | try: 14 | huggingface_hub.whoami() 15 | except huggingface_hub.errors.LocalTokenNotFoundError: 16 | raise ValueError("Hugging Face token not found.") 17 | return ESM3.from_pretrained(device=torch.device("cuda")) 18 | 19 | 20 | def get_forge_client(model_name: str) -> ESM3InferenceClient: 21 | forge_token = os.environ.get("ESM_API_KEY", None) 22 | if forge_token is None: 23 | raise ValueError( 24 | "Forge API key not found. Please set the ESM_API_KEY environment variable." 
25 | ) 26 | return ESM3ForgeInferenceClient( 27 | model=model_name, url="https://forge.evolutionaryscale.ai", token=forge_token 28 | ) 29 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/drawing/colors.py: -------------------------------------------------------------------------------- 1 | def hex_to_rgba_tuple(hex_color, alpha=1.0): 2 | hex_color = hex_color.lstrip("#") 3 | r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) 4 | return r, g, b, alpha 5 | 6 | 7 | def rgba_tuple_to_rgba_html_string( 8 | rgba_tuple: tuple[int | float, int | float, int | float, float], 9 | ) -> str: 10 | return f"rgba({rgba_tuple[0]},{rgba_tuple[1]},{rgba_tuple[2]},{rgba_tuple[3]})" 11 | 12 | 13 | def rgba_tuple_to_hex( 14 | rgba: tuple[int | float, int | float, int | float] 15 | | tuple[int | float, int | float, int | float, int | float], 16 | ) -> str: 17 | def float_to_int(f): 18 | return int(f * 255) 19 | 20 | if all([isinstance(c, float) for c in rgba]): 21 | r = float_to_int(rgba[0]) 22 | g = float_to_int(rgba[1]) 23 | b = float_to_int(rgba[2]) 24 | if len(rgba) > 3: 25 | rgba = (r, g, b, rgba[3]) 26 | else: 27 | rgba = (r, g, b) 28 | 29 | if len(rgba) == 4: 30 | rgba_ = (*rgba[:3], float_to_int(rgba[3])) 31 | return "#%02x%02x%02x%02x" % rgba_ 32 | else: 33 | return "#%02x%02x%02x" % rgba 34 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/drawing/draw_function_annotations.py: -------------------------------------------------------------------------------- 1 | import io 2 | from contextlib import contextmanager 3 | 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | from dna_features_viewer import GraphicFeature, GraphicRecord 7 | from ipywidgets import widgets 8 | from matplotlib import colormaps 9 | from PIL import Image 10 | 11 | from esm.sdk.api import FunctionAnnotation 12 | from esm.utils.function.interpro import ( 13 | InterPro, 14 | InterProEntryType, 15 | ) 16 | 17 | 18 | @contextmanager 19 | def use_backend(backend): 20 | original_backend = matplotlib.get_backend() 21 | matplotlib.use(backend, force=True) 22 | try: 23 | yield 24 | finally: 25 | matplotlib.use(original_backend, force=True) 26 | 27 | 28 | def draw_function_annotations( 29 | annotations: list[FunctionAnnotation], sequence_length: int, interpro_=InterPro() 30 | ) -> widgets.Image: 31 | cmap = colormaps["tab10"] 32 | colors = [cmap(i) for i in range(len(InterProEntryType))] 33 | type_colors = dict(zip(InterProEntryType, colors)) 34 | 35 | features = [] 36 | for annotation in annotations: 37 | if annotation.label in interpro_.entries: 38 | entry = interpro_.entries[annotation.label] 39 | label = entry.name 40 | entry_type = entry.type 41 | else: 42 | label = annotation.label 43 | entry_type = InterProEntryType.UNKNOWN 44 | 45 | feature = GraphicFeature( 46 | start=annotation.start - 1, # one index -> zero index 47 | end=annotation.end, 48 | label=label, 49 | color=type_colors[entry_type], # type: ignore 50 | strand=None, 51 | ) 52 | features.append(feature) 53 | 54 | # Initialize plotting backend 55 | temp_output = widgets.Output() 56 | with temp_output: 57 | fig, ax = plt.subplots() 58 | temp_output.clear_output() 59 | 60 | buf = io.BytesIO() 61 | with use_backend("agg"): 62 | fig, ax = plt.subplots() 63 | record = GraphicRecord( 64 | sequence=None, sequence_length=sequence_length, features=features 65 | ) 66 | record.plot(ax=ax, plot_sequence=False) 67 | fig.savefig(buf, 
format="png", dpi=200, bbox_inches="tight") 68 | 69 | # Load the image from the buffer to get its size 70 | image = Image.open(buf) 71 | width, height = image.size 72 | aspect_ratio = width / height 73 | 74 | # Set the maximum height for the image widget 75 | max_height = 300 76 | calculated_width = int(max_height * aspect_ratio) 77 | 78 | buf.seek(0) 79 | 80 | image_widget = widgets.Image( 81 | value=buf.getvalue(), 82 | format="png", 83 | layout=widgets.Layout(width=f"{calculated_width}px", height=f"{max_height}px"), 84 | ) 85 | buf.close() 86 | return image_widget 87 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/drawing/draw_protein_structure.py: -------------------------------------------------------------------------------- 1 | import py3Dmol 2 | from IPython.display import clear_output 3 | from ipywidgets import widgets 4 | 5 | from esm.utils.structure.protein_chain import ProteinChain 6 | 7 | 8 | def draw_protein_structure( 9 | output: widgets.Output, 10 | protein_chain: ProteinChain, 11 | highlighted_ranges: list[tuple[int, int, str]] = [], 12 | ): 13 | pdb_str = protein_chain.to_pdb_string() 14 | with output: 15 | clear_output(wait=True) 16 | view = py3Dmol.view(width=500, height=500) 17 | view.addModel(pdb_str, "pdb") 18 | view.setStyle({"cartoon": {"color": "gray"}}) 19 | 20 | for start, end, color in highlighted_ranges: 21 | view.setStyle( 22 | {"resi": str(start) + "-" + str(end)}, {"cartoon": {"color": color}} 23 | ) 24 | 25 | view.zoomTo() 26 | view.show() 27 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/indexing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from esm.utils.structure.protein_chain import ProteinChain 4 | 5 | ZERO_INDEX = "Zero index" 6 | PDB_INDEX = "PDB index" 7 | 8 | PDB_INDEX_SUFFIX = "[PDB Index]" 9 | 10 | 11 | def get_pdb_index_min_max(protein_chain: ProteinChain) -> tuple[int, int]: 12 | residue_index = protein_chain.residue_index 13 | valid_residue_index = residue_index[residue_index != -1] 14 | return min(valid_residue_index), max(valid_residue_index) 15 | 16 | 17 | def pdb_index_to_zero_index(residue_index: int, protein_chain: ProteinChain) -> int: 18 | # Find the first position equal to residue_index 19 | pos = np.argwhere(residue_index == protein_chain.residue_index) 20 | if len(pos) == 0: 21 | raise ValueError(f"Residue index {residue_index} not found in protein chain") 22 | return pos[0][0] 23 | 24 | 25 | def zero_index_to_pdb_index(zero_index: int, protein_chain: ProteinChain) -> int: 26 | return protein_chain.residue_index[zero_index] 27 | 28 | 29 | def zero_range_to_pdb_range( 30 | zero_range: tuple[int, int], protein_chain: ProteinChain 31 | ) -> tuple[int, int]: 32 | return ( 33 | zero_index_to_pdb_index(zero_range[0], protein_chain), 34 | zero_index_to_pdb_index(zero_range[1], protein_chain), 35 | ) 36 | 37 | 38 | def pdb_range_to_zero_range( 39 | pdb_range: tuple[int, int], protein_chain: ProteinChain 40 | ) -> tuple[int, int]: 41 | return ( 42 | pdb_index_to_zero_index(pdb_range[0], protein_chain), 43 | pdb_index_to_zero_index(pdb_range[1], protein_chain), 44 | ) 45 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/parsing.py: -------------------------------------------------------------------------------- 1 | def convert_range_string_to_list_of_ranges(range_str: str) -> 
list[tuple[int, int]]: 2 | def parse_range(range_str: str) -> list[tuple[int, int]]: 3 | result: list[tuple[int, int]] = [] 4 | for r in range_str.split(","): 5 | if "-" in r: 6 | start, end = map(int, r.split("-")) 7 | result.append((start, end)) 8 | else: 9 | start = end = int(r) 10 | result.append((start, end)) 11 | return result 12 | 13 | return parse_range(range_str) 14 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/printing.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | 3 | 4 | def wrapped_print(text, width=70): 5 | text = str(text) 6 | wrapped_text = textwrap.fill(text, width=width) 7 | print(wrapped_text) 8 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from io import StringIO 4 | from typing import Literal 5 | 6 | from ipywidgets import widgets 7 | 8 | from esm.sdk.api import ESMProtein 9 | 10 | 11 | def protein_to_pdb_buffer(protein: ESMProtein) -> bytes: 12 | pdb_buffer = StringIO() 13 | protein.to_pdb(pdb_buffer) 14 | pdb_buffer.seek(0) 15 | return pdb_buffer.read().encode() 16 | 17 | 18 | def create_download_button_from_buffer( 19 | buffer: bytes, 20 | filename: str, 21 | description: str = "Download", 22 | type: Literal["json", "bytes"] = "bytes", 23 | ) -> widgets.HTML: 24 | b64 = base64.b64encode(buffer).decode() 25 | if type == "json": 26 | payload = f"data:text/json;base64,{b64}" 27 | elif type == "bytes": 28 | payload = f"data:application/octet-stream;base64,{b64}" 29 | html_buttons = f""" 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | """ 41 | download_link = widgets.HTML(html_buttons) 42 | return download_link 43 | 44 | 45 | def create_download_results_button( 46 | protein_list: list[ESMProtein], filename: str 47 | ) -> widgets.HTML: 48 | serialized_proteins = [serialize_protein(p) for p in protein_list] 49 | serialized_data = json.dumps(serialized_proteins, indent=4) 50 | return create_download_button_from_buffer( 51 | buffer=serialized_data.encode(), 52 | filename=filename, 53 | type="json", 54 | description="Download As JSON", 55 | ) 56 | 57 | 58 | def serialize_protein(protein: ESMProtein) -> str: 59 | protein_dict = { 60 | "sequence": protein.sequence, 61 | "coordinates": protein.coordinates.tolist() 62 | if protein.coordinates is not None 63 | else None, 64 | "secondary_structure": protein.secondary_structure, 65 | "sasa": protein.sasa, 66 | "function_annotations": [ 67 | (annotation.label, annotation.start, annotation.end) 68 | for annotation in protein.function_annotations 69 | ] 70 | if protein.function_annotations is not None 71 | else None, 72 | "plddt": protein.plddt.tolist() if protein.plddt is not None else None, 73 | "ptm": protein.ptm.tolist() if protein.ptm is not None else None, 74 | } 75 | return json.dumps(protein_dict, indent=4) 76 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/utils/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Literal, TypedDict 2 | 3 | from esm.sdk.api import ESM3InferenceClient 4 | 5 | 6 | class ClientInitContainerMetadata(TypedDict): 7 | inference_option: Literal["Forge API", "Local"] | None 8 | 9 | 10 | class ClientInitContainer: 11 | 
client_init_callback: Callable[[], ESM3InferenceClient] | None = None 12 | metadata: ClientInitContainerMetadata 13 | 14 | def __init__(self): 15 | self.metadata = ClientInitContainerMetadata(inference_option=None) 16 | 17 | def __call__(self, *args: Any, **kwds: Any) -> ESM3InferenceClient: 18 | if self.client_init_callback is None: 19 | raise ValueError("Client not initialized.") 20 | return self.client_init_callback() 21 | -------------------------------------------------------------------------------- /model_zoom/esm/widgets/views/inverse_folding.py: -------------------------------------------------------------------------------- 1 | from ipywidgets import widgets 2 | 3 | from esm.sdk.api import ( 4 | ESM3InferenceClient, 5 | ESMProtein, 6 | ESMProteinError, 7 | GenerationConfig, 8 | ) 9 | from esm.widgets.components.results_visualizer import ( 10 | create_results_visualizer, 11 | ) 12 | from esm.widgets.utils.printing import wrapped_print 13 | from esm.widgets.utils.protein_import import ProteinImporter 14 | 15 | 16 | def create_inverse_folding_ui(client: ESM3InferenceClient) -> widgets.Widget: 17 | # Alow a single protein and immediately load it from workspace 18 | protein_importer = ProteinImporter(max_proteins=1, autoload=True) 19 | output = widgets.Output() 20 | inverse_folding_ui = widgets.VBox([protein_importer.importer_ui, output]) 21 | 22 | inverse_fold_button = widgets.Button( 23 | description="Inverse Fold", 24 | disabled=True, 25 | tooltip="Click to predict the protein sequence from the structure", 26 | style={"button_color": "lightgreen"}, 27 | ) 28 | 29 | def get_protein() -> ESMProtein: 30 | [first_protein] = protein_importer.protein_list 31 | protein_id, protein_chain = first_protein 32 | protein = ESMProtein.from_protein_chain(protein_chain) 33 | 34 | # NOTE: We ignore all properties except structure 35 | protein.sequence = None 36 | protein.secondary_structure = None 37 | protein.sasa = None 38 | protein.function_annotations = None 39 | return protein 40 | 41 | def on_new_protein(_): 42 | is_protein = len(protein_importer.protein_list) > 0 43 | inverse_fold_button.disabled = not is_protein 44 | inverse_folding_ui.children = [ 45 | protein_importer.importer_ui, 46 | inverse_fold_button, 47 | output, 48 | ] 49 | 50 | def validate_inverse_fold(_): 51 | if len(protein_importer.protein_list) == 0: 52 | inverse_fold_button.disabled = True 53 | else: 54 | inverse_fold_button.disabled = False 55 | 56 | def on_click_inverse_fold(_): 57 | try: 58 | # Reset the output and results 59 | output.clear_output() 60 | inverse_folding_ui.children = [ 61 | protein_importer.importer_ui, 62 | inverse_fold_button, 63 | output, 64 | ] 65 | # Predict the protein's sequence 66 | protein = get_protein() 67 | with output: 68 | print("Predicting the protein sequence from the structure...") 69 | protein = client.generate( 70 | input=protein, 71 | config=GenerationConfig(track="sequence", num_steps=1), 72 | ) 73 | if isinstance(protein, ESMProteinError): 74 | wrapped_print(f"Protein Error: {protein.error_msg}") 75 | elif isinstance(protein, ESMProtein): 76 | sequence_results = create_results_visualizer( 77 | modality="sequence", 78 | samples=[protein], 79 | items_per_page=1, 80 | include_title=False, 81 | ) 82 | output.clear_output() 83 | inverse_folding_ui.children = [ 84 | protein_importer.importer_ui, 85 | inverse_fold_button, 86 | sequence_results, 87 | ] 88 | except Exception as e: 89 | with output: 90 | wrapped_print(e) 91 | 92 | inverse_fold_button.on_click(on_click_inverse_fold) 93 | 
protein_importer.entries_box.observe(on_new_protein, names="children") 94 | protein_importer.register_delete_callback(lambda: validate_inverse_fold(None)) 95 | 96 | return inverse_folding_ui 97 | -------------------------------------------------------------------------------- /model_zoom/esm2_3b_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForMaskedLM 3 | 4 | 5 | tokenizer = AutoTokenizer.from_pretrained("model_zoom/esm2_3b") 6 | pretrain_model = AutoModelForMaskedLM.from_pretrained("model_zoom/esm2_3b") 7 | 8 | toy_data = [ 9 | "MKTLLLTLVVVTIVCLDLGYT", # 序列1 10 | "GAVLFGYTPGGLAAGALYGVK" # 序列2 11 | ] 12 | 13 | # 将蛋白质序列进行分词 14 | inputs = tokenizer(toy_data, return_tensors="pt", padding=True, truncation=True) 15 | 16 | # 进行推理并获取模型输出 17 | with torch.no_grad(): 18 | outputs = pretrain_model.esm( 19 | inputs['input_ids'], 20 | attention_mask=inputs['attention_mask'], 21 | return_dict=True, 22 | ) 23 | embeddings = outputs.last_hidden_state 24 | 25 | # 打印 embedding 形状 26 | print(f"Embedding shape for sequence 1: {embeddings[0].shape}") 27 | print(f"Embedding shape for sequence 2: {embeddings[1].shape}") -------------------------------------------------------------------------------- /model_zoom/esm2_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForMaskedLM 3 | 4 | 5 | tokenizer = AutoTokenizer.from_pretrained("model_zoom/esm2_650m") 6 | pretrain_model = AutoModelForMaskedLM.from_pretrained("model_zoom/esm2_650m") 7 | 8 | toy_data = [ 9 | "MKTLLLTLVVVTIVCLDLGYT", # 序列1 10 | "GAVLFGYTPGGLAAGALYGVK" # 序列2 11 | ] 12 | 13 | # 将蛋白质序列进行分词 14 | inputs = tokenizer(toy_data, return_tensors="pt", padding=True, truncation=True) 15 | 16 | # 进行推理并获取模型输出 17 | with torch.no_grad(): 18 | outputs = pretrain_model.esm( 19 | inputs['input_ids'], 20 | attention_mask=inputs['attention_mask'], 21 | return_dict=True, 22 | ) 23 | embeddings = outputs.last_hidden_state 24 | 25 | # 打印 embedding 形状 26 | print(f"Embedding shape for sequence 1: {embeddings[0].shape}") 27 | print(f"Embedding shape for sequence 2: {embeddings[1].shape}") -------------------------------------------------------------------------------- /model_zoom/esm3_example.py: -------------------------------------------------------------------------------- 1 | from model_zoom.esm.models.esm3 import ESM3 2 | from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig 3 | from esm.sdk.api import ( 4 | ESMCInferenceClient, 5 | ESMProtein, 6 | ESMProteinTensor, 7 | ForwardTrackData, 8 | LogitsConfig, 9 | LogitsOutput, 10 | ) 11 | from esm.utils.sampling import _BatchedESMProteinTensor 12 | from esm.utils.misc import stack_variable_length_tensors 13 | from esm.utils import encoding 14 | 15 | model = ESM3.from_pretrained("esm3-sm-open-v1", ).to("cuda") 16 | 17 | protein1 = ESMProtein.from_pdb('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/N128/8rk2A.pdb') 18 | protein2 = ESMProtein.from_pdb('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/N128/8rkfA.pdb') 19 | pad = model.tokenizers.sequence.pad_token_id 20 | sequence_list = [protein1.sequence, protein2.sequence] 21 | coordinates_list = [protein1.coordinates, protein2.coordinates] 22 | 23 | seq_tokenizer = model.tokenizers.sequence 24 | struct_tokenizer = model.tokenizers.structure 25 | 26 | sequence_tokens = 
stack_variable_length_tensors( 27 | [ 28 | encoding.tokenize_sequence(x, seq_tokenizer, add_special_tokens=True) 29 | for x in sequence_list 30 | ], 31 | constant_value=pad, 32 | ).to(next(model.parameters()).device) 33 | 34 | 35 | structure_tokens_batch = [] 36 | coordinates_batch = [] 37 | for coordinate in coordinates_list: 38 | coordinates, plddt, structure_token = encoding.tokenize_structure(coordinate, model.get_structure_encoder(), struct_tokenizer, add_special_tokens=True) 39 | structure_tokens_batch.append(structure_token) 40 | coordinates_batch.append(coordinates) 41 | 42 | structure_tokens_batch = stack_variable_length_tensors( 43 | structure_tokens_batch, 44 | constant_value=pad, 45 | ).to(next(model.parameters()).device) 46 | 47 | coordinates_batch = stack_variable_length_tensors( 48 | coordinates_batch, 49 | constant_value=pad, 50 | ).to(next(model.parameters()).device) 51 | 52 | 53 | protein_tensor = _BatchedESMProteinTensor(sequence=sequence_tokens, structure=structure_tokens_batch, coordinates=coordinates_batch).to( 54 | next(model.parameters()).device 55 | ) 56 | 57 | output = model.logits( 58 | protein_tensor, 59 | LogitsConfig( 60 | sequence=True, 61 | structure=True, 62 | secondary_structure=True, 63 | sasa=True, 64 | function=True, 65 | residue_annotations=True, 66 | return_embeddings=True, 67 | ), 68 | ) 69 | 70 | print(output.embeddings.shape) -------------------------------------------------------------------------------- /model_zoom/esmc_example.py: -------------------------------------------------------------------------------- 1 | from model_zoom.esm.models.esmc import ESMC 2 | from esm.sdk.api import ESMProtein, LogitsConfig 3 | from esm.utils.sampling import _BatchedESMProteinTensor 4 | 5 | # protein = ESMProtein(sequence="AAAAA") 6 | client = ESMC.from_pretrained("esmc_600m").to("cuda") # or "cpu" 7 | 8 | 9 | sequence_tokens = client._tokenize(["AAAAA", "AAAAAMK"]) 10 | protein_tensor = _BatchedESMProteinTensor(sequence=sequence_tokens).to( 11 | next(client.parameters()).device 12 | ) 13 | logits_output = client.logits( 14 | protein_tensor, LogitsConfig(sequence=True, return_embeddings=True) 15 | ) 16 | print(logits_output.logits, logits_output.embeddings) -------------------------------------------------------------------------------- /model_zoom/ontoprotein_example.py: -------------------------------------------------------------------------------- 1 | # Load model directly 2 | from transformers import AutoTokenizer, AutoModelForMaskedLM 3 | 4 | tokenizer = AutoTokenizer.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/OntoProtein") 5 | model = AutoModelForMaskedLM.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/OntoProtein") 6 | 7 | import re 8 | sequence_Example = "A E T C Z A O" 9 | sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example) 10 | encoded_input = tokenizer(sequence_Example, return_tensors='pt') 11 | output = model.bert( 12 | input_ids=encoded_input['input_ids'], 13 | attention_mask=encoded_input['attention_mask'], 14 | token_type_ids=encoded_input['token_type_ids'], 15 | ) 16 | 17 | print(output.last_hidden_state) -------------------------------------------------------------------------------- /model_zoom/pglm_example.py: -------------------------------------------------------------------------------- 1 | 2 | # Obtain residue embeddings 3 | from transformers import AutoModelForMaskedLM, AutoTokenizer 4 | import torch 5 | 6 | tokenizer = 
AutoTokenizer.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/proteinglm-1b-mlm", trust_remote_code=True, use_fast=True) 7 | model = AutoModelForMaskedLM.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/proteinglm-1b-mlm", trust_remote_code=True, torch_dtype=torch.bfloat16) 8 | if torch.cuda.is_available(): 9 | model = model.cuda() 10 | model.eval() 11 | 12 | seq = 'MILMCQHFSGQFSKYFLAVSSDFCHFVFPIILVSHVNFKQMKRKGFALWNDRAVPFTQGIFTTVMILLQYLHGTG' 13 | output = tokenizer(seq, add_special_tokens=True, return_tensors='pt') 14 | with torch.inference_mode(): 15 | inputs = {"input_ids": output["input_ids"].cuda(), "attention_mask": output["attention_mask"].cuda()} 16 | output_embeddings = model(**inputs, output_hidden_states=True, return_last_hidden_state=True).hidden_states[:-1, 0] # get rid of the token 17 | 18 | print(output_embeddings.shape) # (1, 2048) 19 | 20 | -------------------------------------------------------------------------------- /model_zoom/pro2text_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["HF_MODULES_CACHE"] = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom" 3 | from functools import partial 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from graphein.protein.config import ProteinGraphConfig, DSSPConfig 6 | from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor 7 | from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure 8 | from graphein.protein.edges.distance import ( 9 | add_peptide_bonds, 10 | add_hydrogen_bond_interactions, 11 | add_distance_threshold, 12 | ) 13 | from model_zoom.prot2text.pdb2graph import PDB2Graph, download_alphafold_structure 14 | 15 | 16 | 17 | prot2text_weight = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/prot2text" 18 | tokenizer = AutoTokenizer.from_pretrained(prot2text_weight, 19 | trust_remote_code=True) 20 | model = AutoModelForCausalLM.from_pretrained(prot2text_weight, 21 | trust_remote_code=True) 22 | 23 | config = { 24 | "node_metadata_functions": [ 25 | amino_acid_one_hot, 26 | expasy_protein_scale, 27 | meiler_embedding, 28 | hydrogen_bond_acceptor, 29 | hydrogen_bond_donor 30 | ], 31 | "edge_construction_functions": [ 32 | add_peptide_bonds, 33 | add_hydrogen_bond_interactions, 34 | partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.) 
35 | ], 36 | "graph_metadata_functions":[ 37 | asa, 38 | phi, 39 | psi, 40 | secondary_structure, 41 | rsa 42 | ], 43 | "dssp_config": DSSPConfig() 44 | } 45 | config = ProteinGraphConfig(**config) 46 | 47 | function = model.generate_protein_description(protein_pdbID='Q10MK9', 48 | tokenizer=tokenizer, 49 | device='cuda' # replace with 'mps' to run on a Mac device 50 | ) 51 | 52 | print(function) 53 | 54 | gpdb = PDB2Graph(root = PATH_TO_DATA, output_folder = OUTPUT_FOLDER, config=config, n_processors=1).create_pyg_graph(structure_filename) 55 | seq = esmtokenizer(gpdb.sequence, add_special_tokens=True, truncation=True, max_length=1021, padding='max_length',return_tensors="pt") # 56 | torch.save(gpdb, graph_filename) 57 | gpdb.edge_type = [np.array(gpdb.edge_type.transpose(0,1))] 58 | gpdb.encoder_input_ids = seq['input_ids'] 59 | gpdb.attention_mask = seq['attention_mask'] 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /model_zoom/procyon_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from procyon.model.model_unified import UnifiedProCyon 4 | from procyon.data.inference_utils import ( 5 | create_caption_input_simple, 6 | uniprot_id_to_index, 7 | ) 8 | 9 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 10 | # Initialize tokenizer and model 11 | CKPT_NAME = '/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/procyon/model_weights/ProCyon-Full' 12 | os.environ["LLAMA3_PATH"] = "/nfs_beijing/wanghao/2025-onesystem/vllm/Meta-Llama-3-8B" 13 | model, _ = UnifiedProCyon.from_pretrained(pretrained_weights_dir=CKPT_NAME, checkpoint_dir=CKPT_NAME) 14 | model = model.to(device) 15 | data_args = torch.load(os.path.join(CKPT_NAME, "data_args.pt")) 16 | 17 | # Internally, ProCyon uses integer IDs that have been assigned to UniProt proteins in ProCyon-Instruct. 18 | want_proteins = ["Q5T1N1"] 19 | protein_ids = [uniprot_id_to_index(x) for x in want_proteins] 20 | 21 | input_simple = create_caption_input_simple( 22 | input_aaseq_ids=protein_ids, 23 | data_args=data_args, 24 | # The `instruction_source_dataset` and `instruction_source_relation` here control the style 25 | # of pre-templated instruction used in these queries. In particular, here we query for UniProt-style 26 | # functional descriptions. 27 | instruction_source_dataset="uniprot", 28 | instruction_source_relation="all", 29 | aaseq_type="protein", 30 | task_type="caption", 31 | icl_example_number=1, 32 | device=device, 33 | ) 34 | 35 | text_gen_args = { 36 | "method": "beam", 37 | # Maximum length of generated text. 38 | "max_len": 200, 39 | # Total number of beams maintained per input. `beam_size` / `beam_group_size` = number of phenotypes returned per input. 40 | "beam_size": 20, 41 | # Size of the individual beam groups in DBS. 42 | "beam_group_size": 2, 43 | # Penalty applied to repetition within a beam group. 
44 | "diversity_penalty": 0.8, 45 | } 46 | 47 | out_tokens, log_probs, output_logits, out_text = model.generate( 48 | inputs=input_simple, 49 | aaseq_type="protein", 50 | **text_gen_args 51 | ) 52 | 53 | breakpoint() 54 | print(":)") 55 | -------------------------------------------------------------------------------- /model_zoom/progen2_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from progen2.modeling_progen import ProGenForCausalLM 3 | from tokenizers import Tokenizer 4 | 5 | def create_tokenizer_custom(file): 6 | with open(file, 'r') as f: 7 | return Tokenizer.from_str(f.read()) 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | ckpt = '/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/progen2' 11 | model = ProGenForCausalLM.from_pretrained(ckpt).to(device) 12 | 13 | tokens = '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2' 14 | 15 | 16 | tokenizer = create_tokenizer_custom(file='/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/progen2/tokenizer.json') 17 | 18 | target = torch.tensor(tokenizer.encode(tokens).ids).to(device) 19 | logits = model(target, labels=target).logits 20 | print(logits.shape) -------------------------------------------------------------------------------- /model_zoom/prosst_2048_example.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForMaskedLM 2 | from ProSST.prosst.structure.quantizer import PdbQuantizer 3 | import torch 4 | 5 | 6 | def tokenize_structure_sequence(structure_sequence): 7 | shift_structure_sequence = [i + 3 for i in structure_sequence] 8 | shift_structure_sequence = [1, *shift_structure_sequence, 2] 9 | return torch.tensor( 10 | [ 11 | shift_structure_sequence, 12 | ], 13 | dtype=torch.long, 14 | ) 15 | 16 | 17 | prosst_2048_weight_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProSST/prosst_2048_weight" 18 | device = "cuda:1" 19 | 20 | model = AutoModelForMaskedLM.from_pretrained(prosst_2048_weight_path, trust_remote_code=True).to(device) 21 | 22 | aa_seq = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGLDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH" 23 | pdb_path = "/nfs_beijing/onesystem/share_data/pdbs/glmfold/2d1b76cead45d9d7783547e6ce02d918/ranked_unrelax_0.pdb" 24 | 25 | tokenizer = AutoTokenizer.from_pretrained(prosst_2048_weight_path, trust_remote_code=True) 26 | tokenized_results = tokenizer([aa_seq], return_tensors="pt") 27 | input_ids = tokenized_results["input_ids"].to(device) 28 | attention_mask = tokenized_results["attention_mask"].to(device) 29 | 30 | processor = PdbQuantizer(structure_vocab_size=2048) 31 | ss_seq = processor(pdb_path, return_residue_seq=False)['2048']['ranked_unrelax_0.pdb']["struct"] 32 | ss_input_ids = tokenize_structure_sequence(ss_seq).to(device) 33 | 34 | 35 | 36 | outputs = model( 37 | input_ids=input_ids, 38 | attention_mask=attention_mask, 39 | ss_input_ids=ss_input_ids, 40 | output_hidden_states=True 41 | ) 42 | 43 | print(outputs.hidden_states[-1].shape) # [1, L, 768] 44 | 
-------------------------------------------------------------------------------- /model_zoom/prostT5_example.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer, T5EncoderModel 2 | import torch 3 | import re 4 | from src.data.protein import Protein 5 | import mini3di 6 | encoder = mini3di.Encoder() 7 | 8 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 9 | 10 | # Load the tokenizer 11 | tokenizer = T5Tokenizer.from_pretrained('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProstT5', do_lower_case=False) 12 | 13 | # Load the model 14 | model = T5EncoderModel.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProstT5").to(device) 15 | 16 | # only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower) 17 | model.float() if device.type=='cpu' else model.half() 18 | 19 | # prepare your protein sequences/structures as a list. 20 | # Amino acid sequences are expected to be upper-case ("PRTEINO" below) 21 | # while 3Di-sequences need to be lower-case ("strctr" below). 22 | data = Protein.from_PDB('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/N128/8rk2A.pdb') 23 | X, C, S = data.to_XCS(all_atom=True) 24 | N, CA, C, O, CB = X[0,:, 0], X[0,:, 1], X[0,:, 2], X[0,:, 3], X[0,:, 4] 25 | states = encoder.encode_atoms(ca = CA.numpy(), cb = CB.numpy(), n = N.numpy(), c = C.numpy()) 26 | struct_sequence = encoder.build_sequence(states).lower() 27 | AA_sequence = data.sequence() 28 | sequence_examples = [AA_sequence, struct_sequence] 29 | 30 | # replace all rare/ambiguous amino acids by X (3Di sequences do not have those) and introduce white-space between all sequences (AAs and 3Di) 31 | sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples] 32 | 33 | # The direction of the translation is indicated by two special tokens: 34 | # if you go from AAs to 3Di (or if you want to embed AAs), you need to prepend "" 35 | # if you go from 3Di to AAs (or if you want to embed 3Di), you need to prepend "" 36 | sequence_examples = [ "" + " " + s if s.isupper() else "" + " " + s # this expects 3Di sequences to be already lower-case 37 | for s in sequence_examples 38 | ] 39 | 40 | # tokenize sequences and pad up to the longest sequence in the batch 41 | ids = tokenizer.batch_encode_plus(sequence_examples, 42 | add_special_tokens=True, 43 | padding="longest", 44 | return_tensors='pt').to(device) 45 | 46 | # generate embeddings 47 | with torch.no_grad(): 48 | embedding_repr = model( 49 | ids.input_ids, 50 | attention_mask=ids.attention_mask 51 | ) 52 | 53 | # extract residue embeddings for the first ([0,:]) sequence in the batch and remove padded & special tokens, incl. 
prefix ([0,1:8]) 54 | emb_0 = embedding_repr.last_hidden_state[0,1:8] # shape (7 x 1024) 55 | # same for the second ([1,:]) sequence but taking into account different sequence lengths ([1,:6]) 56 | emb_1 = embedding_repr.last_hidden_state[1,1:6] # shape (5 x 1024) 57 | 58 | # if you want to derive a single representation (per-protein embedding) for the whole protein 59 | emb_0_per_protein = emb_0.mean(dim=0) # shape (1024) -------------------------------------------------------------------------------- /model_zoom/protgpt2_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | sequence='MGEAMGLTQPAVSRAVARLEERVGIRIFNRTARAITLTDEGRRFYEAVAPLLAGIEMHGYR\nVNVEGVAQLLELYARDILAEGRLVQLLPEWAD' 6 | 7 | #Convert the sequence to a string like this 8 | #(note we have to introduce new line characters every 60 amino acids, 9 | #following the FASTA file format). 10 | 11 | sequence = "<|endoftext|>\nMGEAMGLTQPAVSRAVARLEERVGIRIFNRTARAITLTDEGRRFYEAVAPLLAGIEMHGY\nRVNVEGVAQLLELYARDILAEGRLVQLLPEWAD\n<|endoftext|>" 12 | 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | 15 | tokenizer = AutoTokenizer.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProtGPT2") 16 | model = AutoModelForCausalLM.from_pretrained("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProtGPT2") 17 | model = model.to(device) 18 | 19 | input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0) 20 | input_ids = input_ids.to(device) 21 | with torch.no_grad(): 22 | outputs = model(input_ids, labels=input_ids) 23 | loss, logits = outputs[:2] 24 | 25 | print(logits.shape) 26 | 27 | -------------------------------------------------------------------------------- /model_zoom/protrek_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.data.protein import Protein 3 | from model_zoom.ProTrek.model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel 4 | from model_zoom.ProTrek.utils.foldseek_util import get_struc_seq 5 | import mini3di 6 | encoder = mini3di.Encoder() 7 | 8 | # Load model 9 | config = { 10 | "protein_config": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProTrek/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D", 11 | "text_config": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProTrek/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", 12 | "structure_config": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProTrek/ProTrek_650M_UniRef50/foldseek_t30_150M", 13 | "load_protein_pretrained": False, 14 | "load_text_pretrained": False, 15 | "from_checkpoint": "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/ProTrek/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt" 16 | } 17 | 18 | device = "cuda" 19 | model = ProTrekTrimodalModel(**config).eval().to(device) 20 | 21 | # Load protein and text 22 | pdb_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/N128/8rk2A.pdb" 23 | # seqs = get_struc_seq("bin/foldseek", pdb_path, ["A"])["A"] 24 | # aa_seq = seqs[0] 25 | # foldseek_seq = seqs[1].lower() 26 | 27 | data = Protein.from_PDB(pdb_path) 28 | X, C, S = data.to_XCS(all_atom=True) 29 | N, CA, C, O, CB = X[0,:, 0], X[0,:, 1], X[0,:, 2], 
X[0,:, 3], X[0,:, 4] 30 | states = encoder.encode_atoms(ca = CA.numpy(), cb = CB.numpy(), n = N.numpy(), c = C.numpy()) 31 | foldseek_seq = encoder.build_sequence(states).lower() 32 | aa_seq = data.sequence() 33 | text = "Replication initiator in the monomeric form, and autogenous repressor in the dimeric form." 34 | 35 | with torch.no_grad(): 36 | # Obtain protein sequence embedding 37 | seq_embedding = model.get_protein_repr([aa_seq]) 38 | print("Protein sequence embedding shape:", seq_embedding.shape) 39 | 40 | # Obtain protein structure embedding 41 | struc_embedding = model.get_structure_repr([foldseek_seq]) 42 | print("Protein structure embedding shape:", struc_embedding.shape) 43 | 44 | # Obtain text embedding 45 | text_embedding = model.get_text_repr([text]) 46 | print("Text embedding shape:", text_embedding.shape) 47 | 48 | # Calculate similarity score between protein sequence and structure 49 | seq_struc_score = seq_embedding @ struc_embedding.T / model.temperature 50 | print("Similarity score between protein sequence and structure:", seq_struc_score.item()) 51 | 52 | # Calculate similarity score between protein sequence and text 53 | seq_text_score = seq_embedding @ text_embedding.T / model.temperature 54 | print("Similarity score between protein sequence and text:", seq_text_score.item()) 55 | 56 | # Calculate similarity score between protein structure and text 57 | struc_text_score = struc_embedding @ text_embedding.T / model.temperature 58 | print("Similarity score between protein structure and text:", struc_text_score.item()) 59 | 60 | 61 | """ 62 | Protein sequence embedding shape: torch.Size([1, 1024]) 63 | Protein structure embedding shape: torch.Size([1, 1024]) 64 | Text embedding shape: torch.Size([1, 1024]) 65 | Similarity score between protein sequence and structure: 28.506675720214844 66 | Similarity score between protein sequence and text: 17.842409133911133 67 | Similarity score between protein structure and text: 11.866174697875977 68 | """ -------------------------------------------------------------------------------- /model_zoom/prott5_example.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer, T5EncoderModel 2 | import torch 3 | import re 4 | 5 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 6 | 7 | # Load the tokenizer 8 | prott5_weight_path = "/nfs_beijing/kubeflow-user/wanghao/workspace/ai4sci/protein_benchmark_new/protein_benchmark/model_zoom/ProtTrans" 9 | 10 | tokenizer = T5Tokenizer.from_pretrained(prott5_weight_path, do_lower_case=False) 11 | 12 | # Load the model 13 | model = T5EncoderModel.from_pretrained(prott5_weight_path).to(device) 14 | 15 | # only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower) 16 | if device == torch.device("cpu"): 17 | model.to(torch.float32) 18 | 19 | # prepare your protein sequences as a list 20 | sequence_examples = ["PRTEINO", "SEQWENCE"] 21 | 22 | # replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids 23 | sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples] 24 | 25 | # tokenize sequences and pad up to the longest sequence in the batch 26 | ids = tokenizer.encode(sequence_examples, add_special_tokens=True, padding="longest") 27 | 28 | input_ids = torch.tensor(ids['input_ids']).to(device) 29 | attention_mask = torch.tensor(ids['attention_mask']).to(device) 30 | 31 | 
# generate embeddings 32 | with torch.no_grad(): 33 | embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask) 34 | 35 | # extract residue embeddings for the first ([0,:]) sequence in the batch and remove padded & special tokens ([0,:7]) 36 | # emb_0 = embedding_repr.last_hidden_state[0,:7] # shape (7 x 1024) 37 | # same for the second ([1,:]) sequence but taking into account different sequence lengths ([1,:8]) 38 | # emb_1 = embedding_repr.last_hidden_state[1,:8] # shape (8 x 1024) 39 | 40 | # if you want to derive a single representation (per-protein embedding) for the whole protein 41 | # emb_0_per_protein = emb_0.mean(dim=0) # shape (1024) 42 | print(embedding_repr.last_hidden_state.shape) 43 | -------------------------------------------------------------------------------- /model_zoom/saport_example.py: -------------------------------------------------------------------------------- 1 | from transformers import EsmTokenizer, EsmForMaskedLM 2 | from src.data.protein import Protein 3 | import mini3di 4 | encoder = mini3di.Encoder() 5 | 6 | model_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/SaPort/ckpt" # Note this is the directory path of SaProt, not the ".pt" file 7 | tokenizer = EsmTokenizer.from_pretrained(model_path) 8 | model = EsmForMaskedLM.from_pretrained(model_path) 9 | 10 | #################### Example #################### 11 | device = "cuda" 12 | model.to(device) 13 | 14 | data = Protein.from_PDB('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/N128/8rk2A.pdb') 15 | X, C, S = data.to_XCS(all_atom=True) 16 | N, CA, C, O, CB = X[0,:, 0], X[0,:, 1], X[0,:, 2], X[0,:, 3], X[0,:, 4] 17 | states = encoder.encode_atoms(ca = CA.numpy(), cb = CB.numpy(), n = N.numpy(), c = C.numpy()) 18 | struct_sequence = encoder.build_sequence(states) 19 | AA_sequence = data.sequence() 20 | 21 | merged_seq = ''.join(a + b.lower() for a, b in zip(AA_sequence, struct_sequence)) 22 | 23 | # seq = "M#EvVpQpL#VyQdYaKv" # Here "#" represents lower plDDT regions (plddt < 70) 24 | tokens = tokenizer.tokenize(merged_seq) 25 | print(tokens) 26 | 27 | inputs = tokenizer(merged_seq, return_tensors="pt") 28 | inputs = {k: v.to(device) for k, v in inputs.items()} 29 | 30 | outputs = model(**inputs) 31 | print(outputs.logits.shape) 32 | 33 | """ 34 | ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv'] 35 | torch.Size([1, 11, 446]) 36 | """ -------------------------------------------------------------------------------- /model_zoom/venusplm_example.py: -------------------------------------------------------------------------------- 1 | from vplm import TransformerForMaskedLM, TransformerConfig 2 | from vplm import VPLMTokenizer 3 | import torch 4 | 5 | venusplm_weight_path = "YOUR WEIGHTS PATH" 6 | config = TransformerConfig.from_pretrained(venusplm_weight_path, attn_impl="sdpa") # or "flash_attn" if you have installed flash-attn 7 | model = TransformerForMaskedLM.from_pretrained(venusplm_weight_path, config=config) 8 | tokenizer = VPLMTokenizer.from_pretrained(venusplm_weight_path) 9 | 10 | device = "cuda" if torch.cuda.is_available() else "cpu" 11 | sequence = "MALWMRLLPLLALLALWGPDPAAA" 12 | encoded_sequence = tokenizer(sequence, return_tensors="pt").to(device) 13 | 14 | input_ids = encoded_sequence["input_ids"] 15 | attention_mask = encoded_sequence["attention_mask"] 16 | 17 | with torch.no_grad(): 18 | outputs = model( 19 | input_ids=input_ids, 20 | attention_mask=attention_mask, 21 | output_hidden_states=True 22 | ) 23 | 24 
| hidden_states = outputs.hidden_states[-1] 25 | print(hidden_states.shape) # [1, L, 1024] 26 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "**/__pycache__", 4 | "**/build", 5 | "**/dist", 6 | "**/*.egg-info", 7 | "**/*.so", 8 | "**/*.swp", 9 | "**/.kube", 10 | "**/tmp", 11 | "**/ds-configs", 12 | "**/wandb", 13 | "**/logs", 14 | "**/bookcorpus_data", 15 | // "**/model_zoom", 16 | "**/N128", 17 | "**/old-sota-yamls", 18 | "**/outputs", 19 | "**/datasets", 20 | "**/results", 21 | "**/*.log", 22 | "**/*.pdb", 23 | "**/slurm*", 24 | "**/.coverage_*", 25 | "**/events.out.tfevents.*", 26 | "**/*~" 27 | ], 28 | "include": ["."], 29 | "pythonVersion": "3.11", 30 | "typeCheckingMode": "basic", 31 | "reportAttributeAccessIssue": false, 32 | "reportOptionalCall": false, 33 | "reportOptionalMemberAccess": false, 34 | "reportArgumentType": false 35 | } 36 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/constants/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from src.data.constants.geometry import AA_GEOMETRY 16 | from src.data.constants.sequence import * 17 | -------------------------------------------------------------------------------- /src/data/constants/sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Constants used across protein representations. 
16 | 17 | These constants standardize protein tokenization alphabets, ideal structure 18 | geometries and topologies, etc. 19 | """ 20 | from src.data.constants.geometry import AA_GEOMETRY 21 | 22 | # Standard tokenization for Omniprot and Omniprot-interacting models 23 | OMNIPROT_TOKENS = "ABCDEFGHIKLMNOPQRSTUVWYXZ*-#" 24 | POTTS_EXTENDED_TOKENS = "ACDEFGHIKLMNPQRSTVWY-*#" 25 | PAD = "-" 26 | START = "@" 27 | STOP = "*" 28 | MASK = "#" 29 | DNA_TOKENS = "ACGT" 30 | RNA_TOKENS = "AGCU" 31 | PROTEIN_TOKENS = "ACDEFGHIKLMNPQRSTVWY" 32 | 33 | # Minimal 20-letter alphabet and corresponding triplet codes 34 | AA20 = "ACDEFGHIKLMNPQRSTVWY" 35 | AA20_3_TO_1 = { 36 | "ALA": "A", 37 | "ARG": "R", 38 | "ASN": "N", 39 | "ASP": "D", 40 | "CYS": "C", 41 | "GLN": "Q", 42 | "GLU": "E", 43 | "GLY": "G", 44 | "HIS": "H", 45 | "ILE": "I", 46 | "LEU": "L", 47 | "LYS": "K", 48 | "MET": "M", 49 | "PHE": "F", 50 | "PRO": "P", 51 | "SER": "S", 52 | "THR": "T", 53 | "TRP": "W", 54 | "TYR": "Y", 55 | "VAL": "V", 56 | } 57 | AA20_1_TO_3 = { 58 | "A": "ALA", 59 | "R": "ARG", 60 | "N": "ASN", 61 | "D": "ASP", 62 | "C": "CYS", 63 | "Q": "GLN", 64 | "E": "GLU", 65 | "G": "GLY", 66 | "H": "HIS", 67 | "I": "ILE", 68 | "L": "LEU", 69 | "K": "LYS", 70 | "M": "MET", 71 | "F": "PHE", 72 | "P": "PRO", 73 | "S": "SER", 74 | "T": "THR", 75 | "W": "TRP", 76 | "Y": "TYR", 77 | "V": "VAL", 78 | } 79 | AA20_3 = [AA20_1_TO_3[aa] for aa in AA20] 80 | 81 | # Adding noncanonical amino acids 82 | NONCANON_AA = [ 83 | "HSD", 84 | "HSE", 85 | "HSC", 86 | "HSP", 87 | "MSE", 88 | "CSO", 89 | "SEC", 90 | "CSX", 91 | "HIP", 92 | "SEP", 93 | "TPO", 94 | ] 95 | AA31_3 = AA20_3 + NONCANON_AA 96 | 97 | # Chain alphabet for PDB chain naming 98 | CHAIN_ALPHABET = "_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" 99 | 100 | # Standard atom indexing 101 | ATOMS_BB = ["N", "CA", "C", "O"] 102 | 103 | ATOM_SYMMETRIES = { 104 | "ARG": [("NH1", "NH2")], # Correct handling of NH1 and NH2 is relabeling 105 | "ASP": [("OD1", "OD2")], 106 | "GLU": [("OE1", "OE2")], 107 | "PHE": [("CD1", "CD2"), ("CE1", "CE2")], 108 | "TYR": [("CD1", "CD2"), ("CE1", "CE2")], 109 | } 110 | 111 | AA20_NUM_ATOMS = [4 + len(AA_GEOMETRY[aa]["atoms"]) for aa in AA20_3] 112 | AA20_NUM_CHI = [len(AA_GEOMETRY[aa]["chi_indices"]) for aa in AA20_3] 113 | -------------------------------------------------------------------------------- /src/data/esm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.1.7" 2 | 3 | -------------------------------------------------------------------------------- /src/data/esm/sdk/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from src.data.esm.sdk.forge import ESM3ForgeInferenceClient 4 | from src.data.esm.utils.forge_context_manager import ForgeBatchExecutor 5 | 6 | # Note: please do not import ESM3SageMakerClient here since that requires AWS SDK. 7 | 8 | 9 | def client( 10 | model="esm3-sm-open-v1", 11 | url="https://forge.evolutionaryscale.ai", 12 | token=os.environ.get("ESM_API_KEY", ""), 13 | request_timeout=None, 14 | ): 15 | """ 16 | Args: 17 | model: Name of the model to use. 18 | url: URL of a forge server. 19 | token: User's API token. 20 | request_timeout: Amount of time to wait for a request to finish. 21 | Default is wait indefinitely. 
22 | """ 23 | return ESM3ForgeInferenceClient(model, url, token, request_timeout) 24 | 25 | 26 | def batch_executor(max_attempts: int = 10): 27 | """ 28 | Args: 29 | max_attempts: Maximum number of attempts to make before giving up. 30 | 31 | Usage: 32 | with batch_executor() as executor: 33 | executor.execute_batch(fn, **kwargs) 34 | """ 35 | return ForgeBatchExecutor(max_attempts) 36 | -------------------------------------------------------------------------------- /src/data/esm/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Protocol 3 | 4 | from src.data.esm.utils.constants.models import ( 5 | ESM3_OPEN_SMALL, 6 | normalize_model_name, 7 | ) 8 | 9 | from .function_tokenizer import InterProQuantizedTokenizer 10 | from .residue_tokenizer import ResidueAnnotationsTokenizer 11 | from .sasa_tokenizer import SASADiscretizingTokenizer 12 | from .sequence_tokenizer import EsmSequenceTokenizer 13 | from .ss_tokenizer import SecondaryStructureTokenizer 14 | from .structure_tokenizer import StructureTokenizer 15 | from .tokenizer_base import EsmTokenizerBase 16 | 17 | 18 | class TokenizerCollectionProtocol(Protocol): 19 | sequence: EsmSequenceTokenizer 20 | structure: StructureTokenizer 21 | secondary_structure: SecondaryStructureTokenizer 22 | sasa: SASADiscretizingTokenizer 23 | function: InterProQuantizedTokenizer 24 | residue_annotations: ResidueAnnotationsTokenizer 25 | 26 | 27 | @dataclass 28 | class TokenizerCollection: 29 | sequence: EsmSequenceTokenizer 30 | structure: StructureTokenizer 31 | secondary_structure: SecondaryStructureTokenizer 32 | sasa: SASADiscretizingTokenizer 33 | function: InterProQuantizedTokenizer 34 | residue_annotations: ResidueAnnotationsTokenizer 35 | 36 | 37 | def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection: 38 | if normalize_model_name(model) == ESM3_OPEN_SMALL: 39 | return TokenizerCollection( 40 | sequence=EsmSequenceTokenizer(), 41 | structure=StructureTokenizer(), 42 | secondary_structure=SecondaryStructureTokenizer(kind="ss8"), 43 | sasa=SASADiscretizingTokenizer(), 44 | function=InterProQuantizedTokenizer(), 45 | residue_annotations=ResidueAnnotationsTokenizer(), 46 | ) 47 | else: 48 | raise ValueError(f"Unknown model: {model}") 49 | 50 | 51 | def get_esmc_model_tokenizers() -> EsmSequenceTokenizer: 52 | return EsmSequenceTokenizer() 53 | 54 | 55 | def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]: 56 | if isinstance(tokenizer, EsmSequenceTokenizer): 57 | return [ 58 | tokenizer.mask_token_id, # type: ignore 59 | tokenizer.pad_token_id, # type: ignore 60 | tokenizer.cls_token_id, # type: ignore 61 | tokenizer.eos_token_id, # type: ignore 62 | ] 63 | else: 64 | return [ 65 | tokenizer.mask_token_id, 66 | tokenizer.pad_token_id, 67 | tokenizer.bos_token_id, 68 | tokenizer.eos_token_id, 69 | ] 70 | -------------------------------------------------------------------------------- /src/data/esm/tokenization/structure_tokenizer.py: -------------------------------------------------------------------------------- 1 | from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase 2 | from src.data.esm.utils.constants import esm3 as C 3 | 4 | 5 | class StructureTokenizer(EsmTokenizerBase): 6 | """A convenince class for accessing special token ids of 7 | the StructureTokenEncoder and StructureTokenDecoder.""" 8 | 9 | def __init__(self, codebook_size: int = C.VQVAE_CODEBOOK_SIZE): 10 | 
self.vq_vae_special_tokens = { 11 | "MASK": codebook_size, 12 | "EOS": codebook_size + 1, 13 | "BOS": codebook_size + 2, 14 | "PAD": codebook_size + 3, 15 | "CHAINBREAK": codebook_size + 4, 16 | } 17 | 18 | def mask_token(self) -> str: 19 | raise NotImplementedError( 20 | "Structure tokens are defined on 3D coordinates, not strings." 21 | ) 22 | 23 | @property 24 | def mask_token_id(self) -> int: 25 | return self.vq_vae_special_tokens["MASK"] 26 | 27 | def bos_token(self) -> str: 28 | raise NotImplementedError( 29 | "Structure tokens are defined on 3D coordinates, not strings." 30 | ) 31 | 32 | @property 33 | def bos_token_id(self) -> int: 34 | return self.vq_vae_special_tokens["BOS"] 35 | 36 | def eos_token(self) -> str: 37 | raise NotImplementedError( 38 | "Structure tokens are defined on 3D coordinates, not strings." 39 | ) 40 | 41 | @property 42 | def eos_token_id(self) -> int: 43 | return self.vq_vae_special_tokens["EOS"] 44 | 45 | def pad_token(self) -> str: 46 | raise NotImplementedError( 47 | "Structure tokens are defined on 3D coordinates, not strings." 48 | ) 49 | 50 | @property 51 | def pad_token_id(self) -> int: 52 | return self.vq_vae_special_tokens["PAD"] 53 | 54 | def chain_break_token(self) -> str: 55 | raise NotImplementedError( 56 | "Structure tokens are defined on 3D coordinates, not strings." 57 | ) 58 | 59 | @property 60 | def chain_break_token_id(self) -> int: 61 | return self.vq_vae_special_tokens["CHAINBREAK"] 62 | 63 | @property 64 | def all_token_ids(self): 65 | return list(range(C.VQVAE_CODEBOOK_SIZE + len(self.vq_vae_special_tokens))) 66 | 67 | @property 68 | def special_token_ids(self): 69 | return self.vq_vae_special_tokens.values() 70 | 71 | def encode(self, *args, **kwargs): 72 | raise NotImplementedError( 73 | "The StructureTokenizer class is provided as a convenience for " 74 | "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" 75 | "Please use them instead." 76 | ) 77 | 78 | def decode(self, *args, **kwargs): 79 | raise NotImplementedError( 80 | "The StructureTokenizer class is provided as a convenience for " 81 | "accessing special token ids of the StructureTokenEncoder and StructureTokenDecoder.\n" 82 | "Please use them instead." 83 | ) 84 | -------------------------------------------------------------------------------- /src/data/esm/tokenization/tokenizer_base.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | 4 | @runtime_checkable 5 | class EsmTokenizerBase(Protocol): 6 | mask_token: str 7 | mask_token_id: int 8 | bos_token: str 9 | bos_token_id: int 10 | eos_token: str 11 | eos_token_id: int 12 | pad_token: str 13 | pad_token_id: int 14 | chain_break_token: str 15 | chain_break_token_id: int 16 | 17 | def encode(self, *args, **kwargs): ... 18 | 19 | def decode(self, *args, **kwargs): ... 20 | 21 | @property 22 | def all_token_ids(self): ... 23 | 24 | @property 25 | def special_token_ids(self): ... 
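# --- Illustrative sketch (not part of the repository) --------------------------
# Because EsmTokenizerBase is a runtime_checkable Protocol, isinstance() only checks
# that an object structurally exposes these members; no inheritance is required.
# The toy tokenizer below is hypothetical and exists purely to demonstrate that
# structural check.
class _ToyTokenizer:
    mask_token, mask_token_id = "<mask>", 4
    bos_token, bos_token_id = "<bos>", 0
    eos_token, eos_token_id = "<eos>", 2
    pad_token, pad_token_id = "<pad>", 1
    chain_break_token, chain_break_token_id = "|", 3

    def encode(self, text, **kwargs):
        return [ord(c) for c in text]

    def decode(self, ids, **kwargs):
        return "".join(chr(i) for i in ids)

    @property
    def all_token_ids(self):
        return list(range(5))

    @property
    def special_token_ids(self):
        return [0, 1, 2, 3, 4]


assert isinstance(_ToyTokenizer(), EsmTokenizerBase)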
26 | -------------------------------------------------------------------------------- /src/data/esm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/data/esm/utils/__init__.py -------------------------------------------------------------------------------- /src/data/esm/utils/constants/api.py: -------------------------------------------------------------------------------- 1 | MAX_TOPK_SEQUENCE = 32 2 | MAX_TOPK_STRUCTURE = MAX_TOPK_SEQUENCE 3 | MAX_TOPK_SECONDARY_STRUCTURE = MAX_TOPK_SEQUENCE 4 | MAX_TOPK_SASA = MAX_TOPK_SEQUENCE 5 | MAX_TOPK_FUNCTION = MAX_TOPK_SEQUENCE 6 | -------------------------------------------------------------------------------- /src/data/esm/utils/constants/esm3.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import cache 3 | from pathlib import Path 4 | 5 | from huggingface_hub import snapshot_download 6 | 7 | SEQUENCE_BOS_TOKEN = 0 8 | SEQUENCE_PAD_TOKEN = 1 9 | SEQUENCE_EOS_TOKEN = 2 10 | SEQUENCE_CHAINBREAK_TOKEN = 31 11 | SEQUENCE_MASK_TOKEN = 32 12 | 13 | VQVAE_CODEBOOK_SIZE = 4096 14 | VQVAE_SPECIAL_TOKENS = { 15 | "MASK": VQVAE_CODEBOOK_SIZE, 16 | "EOS": VQVAE_CODEBOOK_SIZE + 1, 17 | "BOS": VQVAE_CODEBOOK_SIZE + 2, 18 | "PAD": VQVAE_CODEBOOK_SIZE + 3, 19 | "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4, 20 | } 21 | VQVAE_DIRECTION_LOSS_BINS = 16 22 | VQVAE_PAE_BINS = 64 23 | VQVAE_MAX_PAE_BIN = 31.0 24 | VQVAE_PLDDT_BINS = 50 25 | 26 | STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"] 27 | STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"] 28 | STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"] 29 | STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"] 30 | STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"] 31 | STRUCTURE_UNDEFINED_TOKEN = 955 32 | 33 | SASA_PAD_TOKEN = 0 34 | 35 | SS8_PAD_TOKEN = 0 36 | 37 | INTERPRO_PAD_TOKEN = 0 38 | 39 | RESIDUE_PAD_TOKEN = 0 40 | 41 | CHAIN_BREAK_STR = "|" 42 | 43 | SEQUENCE_BOS_STR = "" 44 | SEQUENCE_EOS_STR = "" 45 | 46 | MASK_STR_SHORT = "_" 47 | SEQUENCE_MASK_STR = "" 48 | SASA_MASK_STR = "" 49 | SS8_MASK_STR = "" 50 | 51 | # fmt: off 52 | SEQUENCE_VOCAB = [ 53 | "", "", "", "", 54 | "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", 55 | "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", 56 | "O", ".", "-", "|", 57 | "", 58 | ] 59 | # fmt: on 60 | 61 | SSE_8CLASS_VOCAB = "GHITEBSC" 62 | SSE_3CLASS_VOCAB = "HEC" 63 | SSE_8CLASS_TO_3CLASS_MAP = { 64 | "G": "H", 65 | "H": "H", 66 | "I": "H", 67 | "T": "C", 68 | "E": "E", 69 | "B": "E", 70 | "S": "C", 71 | "C": "C", 72 | } 73 | 74 | SASA_DISCRETIZATION_BOUNDARIES = [ 75 | 0.8, 76 | 4.0, 77 | 9.6, 78 | 16.4, 79 | 24.5, 80 | 32.9, 81 | 42.0, 82 | 51.5, 83 | 61.2, 84 | 70.9, 85 | 81.6, 86 | 93.3, 87 | 107.2, 88 | 125.4, 89 | 151.4, 90 | ] 91 | 92 | MAX_RESIDUE_ANNOTATIONS = 16 93 | 94 | 95 | TFIDF_VECTOR_SIZE = 58641 96 | 97 | 98 | @staticmethod 99 | @cache 100 | def data_root(model: str): 101 | if "INFRA_PROVIDER" in os.environ: 102 | return Path("") 103 | # Try to download from hugginface if it doesn't exist 104 | if model.startswith("esm3"): 105 | # path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1")) 106 | path = Path('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/esm3-sm-open-v1') 107 | elif model.startswith("esmc-300"): 108 | path = 
Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12")) 109 | elif model.startswith("esmc-600"): 110 | # path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12")) 111 | path = Path('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/esmc-600m-2024-12') 112 | else: 113 | raise ValueError(f"{model=} is an invalid model name.") 114 | return path 115 | 116 | 117 | IN_REPO_DATA_FOLDER = Path(__file__).parents[2] / "data" 118 | 119 | INTERPRO_ENTRY = IN_REPO_DATA_FOLDER / "entry_list_safety_29026.list" 120 | INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt" 121 | INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt" 122 | INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json" 123 | 124 | LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"} 125 | 126 | KEYWORDS_VOCABULARY = ( 127 | IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt" 128 | ) 129 | KEYWORDS_IDF = IN_REPO_DATA_FOLDER / "keyword_idf_safety_filtered_58641.npy" 130 | 131 | RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv" 132 | INTERPRO2KEYWORDS = IN_REPO_DATA_FOLDER / "interpro_29026_to_keywords_58641.csv" 133 | -------------------------------------------------------------------------------- /src/data/esm/utils/constants/models.py: -------------------------------------------------------------------------------- 1 | # Model names 2 | ESM3_OPEN_SMALL = "esm3_sm_open_v1" 3 | ESM3_OPEN_SMALL_ALIAS_1 = "esm3-open-2024-03" 4 | ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1" 5 | ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open" 6 | ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0" 7 | ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0" 8 | ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0" 9 | ESMC_600M = "esmc_600m" 10 | ESMC_300M = "esmc_300m" 11 | 12 | 13 | def forge_only_return_single_layer_hidden_states(model_name: str): 14 | return model_name.startswith("esmc-6b") 15 | 16 | 17 | def model_is_locally_supported(x: str): 18 | return x in { 19 | ESM3_OPEN_SMALL, 20 | ESM3_OPEN_SMALL_ALIAS_1, 21 | ESM3_OPEN_SMALL_ALIAS_2, 22 | ESM3_OPEN_SMALL_ALIAS_3, 23 | } 24 | 25 | 26 | def normalize_model_name(x: str): 27 | if x in {ESM3_OPEN_SMALL_ALIAS_1, ESM3_OPEN_SMALL_ALIAS_2, ESM3_OPEN_SMALL_ALIAS_3}: 28 | return ESM3_OPEN_SMALL 29 | return x 30 | -------------------------------------------------------------------------------- /src/data/esm/utils/constants/physics.py: -------------------------------------------------------------------------------- 1 | BB_COORDINATES = [ 2 | [0.5256, 1.3612, 0.0000], 3 | [0.0000, 0.0000, 0.0000], 4 | [-1.5251, 0.0000, 0.0000], 5 | ] 6 | -------------------------------------------------------------------------------- /src/data/esm/utils/function/lsh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cloudpathlib import AnyPath 3 | 4 | from src.data.esm.utils.types import PathLike 5 | 6 | 7 | class LSHTable: 8 | def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): 9 | if hyperplanes is None: 10 | hyperplanes = np.random.randn(n_bits, dim) 11 | hyperplanes = hyperplanes / np.linalg.norm( 12 | hyperplanes, axis=-1, keepdims=True 13 | ) 14 | else: 15 | assert hyperplanes.shape == (n_bits, dim), ( 16 | hyperplanes.shape, 17 | (n_bits, dim), 18 | ) 19 | assert hyperplanes is not None 20 | self.hyperplanes: np.ndarray = hyperplanes 21 | self.values = 1 << np.arange(n_bits) 22 | 23 | def 
__call__(self, array, tokenize: bool = True): 24 | similarity = self.hyperplanes @ array.T 25 | bits = np.where(similarity >= 0, 1, 0) 26 | if tokenize: 27 | tokens = bits.T @ self.values 28 | return tokens 29 | else: 30 | return bits.T 31 | 32 | 33 | class LSHTokenized: 34 | def __init__( 35 | self, 36 | n_bits: int, 37 | dim: int, 38 | num_tables: int = 1, 39 | filepath: PathLike | None = None, 40 | allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes 41 | ): 42 | table_hyperplanes = None 43 | if filepath is not None: 44 | filepath = AnyPath(filepath) 45 | if not filepath.exists(): 46 | raise FileNotFoundError(filepath) 47 | table_hyperplanes = np.load(filepath) # type: ignore 48 | for i in range(num_tables): 49 | assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" 50 | elif not allow_create_hyperplanes: 51 | raise RuntimeError( 52 | "Not allowed to create hyperplanes but no filepath provided" 53 | ) 54 | 55 | self.tables = [ 56 | LSHTable( 57 | n_bits, 58 | dim, 59 | table_hyperplanes[str(i)] if table_hyperplanes is not None else None, 60 | ) 61 | for i in range(num_tables) 62 | ] 63 | 64 | def write_hyperplanes(self, filepath: PathLike): 65 | hyperplanes: dict[str, np.ndarray] = { # type: ignore 66 | str(i): table.hyperplanes for i, table in enumerate(self.tables) 67 | } 68 | np.savez(filepath, **hyperplanes) 69 | 70 | def __call__(self, array): 71 | tokens = np.stack([table(array) for table in self.tables], 1) 72 | return tokens 73 | 74 | 75 | class LSHBitstream: 76 | def __init__( 77 | self, 78 | n_bits: int, 79 | dim: int, 80 | filepath: PathLike | None = None, 81 | allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes 82 | ): 83 | table_hyperplanes = None 84 | if filepath is not None: 85 | filepath = AnyPath(filepath) 86 | if not filepath.exists(): 87 | raise FileNotFoundError(filepath) 88 | table_hyperplanes = np.load(filepath) 89 | elif not allow_create_hyperplanes: 90 | raise RuntimeError( 91 | "Not allowed to create hyperplanes but no filepath provided" 92 | ) 93 | 94 | self.table = LSHTable( 95 | n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None 96 | ) 97 | 98 | def write_hyperplanes(self, filepath: PathLike): 99 | np.save(filepath, self.table.hyperplanes) 100 | 101 | def __call__(self, array): 102 | return self.table(array, tokenize=False) 103 | -------------------------------------------------------------------------------- /src/data/esm/utils/function/tfidf.py: -------------------------------------------------------------------------------- 1 | """Term-Frequency / Inverse Document Frequency (TF-IDF) model.""" 2 | 3 | from collections import Counter 4 | from functools import cached_property 5 | 6 | import numpy as np 7 | from cloudpathlib import AnyPath 8 | from scipy import sparse 9 | 10 | from src.data.esm.utils.types import PathLike 11 | 12 | 13 | class TFIDFModel: 14 | """Term-Frequency / Inverse Document Frequency (TF-IDF) model. 
15 | Mimics sklearn.feature_extraction.text.TfidfVectorizer with sublinear_tf=True 16 | """ 17 | 18 | def __init__(self, vocabulary_path: PathLike, idf_path: PathLike): 19 | with AnyPath(vocabulary_path).open("r") as f: 20 | self.vocabulary = f.read().strip().split("\n") 21 | 22 | with AnyPath(idf_path).open("rb") as f: 23 | self.idf_ = np.load(f) 24 | 25 | assert self.idf_.ndim == 1 26 | assert ( 27 | len(self.idf_) == len(self.vocabulary) 28 | ), f"IDF size must match vocabulary size, got {len(self.idf_)} and {len(self.vocabulary)}" 29 | 30 | @cached_property 31 | def vocab_to_index(self) -> dict[str, int]: 32 | return {term: index for index, term in enumerate(self.vocabulary)} 33 | 34 | def encode(self, terms: list[str]) -> sparse.csr_matrix: 35 | """Encodes terms as TF-IDF vectors. 36 | 37 | Args: 38 | terms: list of terms to encode. 39 | 40 | Returns: 41 | TF-IDF vector encoded as sparse matrix of shape (1, num_terms) 42 | """ 43 | counter = Counter(filter(self.vocabulary.__contains__, terms)) 44 | indices = [self.vocab_to_index[term] for term in counter] 45 | 46 | tf = np.array([count for term, count in counter.items()]) 47 | idf = np.take(self.idf_, indices) 48 | 49 | values = (1 + np.log(tf)) * idf 50 | values /= np.linalg.norm(values) 51 | 52 | return sparse.csr_matrix( 53 | (values, (np.zeros_like(indices), indices)), shape=(1, len(self.vocabulary)) 54 | ) 55 | 56 | def decode(self, vec: sparse.csr_matrix) -> list[str]: 57 | """Extract terms from TF-IDF.""" 58 | return [self.vocabulary[i] for i in vec.indices] 59 | -------------------------------------------------------------------------------- /src/data/esm/utils/misc_test.py: -------------------------------------------------------------------------------- 1 | """Tests for misc.py""" 2 | 3 | from src.data.esm.utils.misc import merge_annotations 4 | from src.data.esm.utils.types import FunctionAnnotation 5 | 6 | 7 | def test_merge_annotations(): 8 | merged = merge_annotations( 9 | [ 10 | FunctionAnnotation("a", start=1, end=10), 11 | FunctionAnnotation("b", start=5, end=15), 12 | FunctionAnnotation("a", start=10, end=20), 13 | FunctionAnnotation("b", start=2, end=6), 14 | FunctionAnnotation("c", start=4, end=10), 15 | ] 16 | ) 17 | assert len(merged) == 3 18 | assert FunctionAnnotation("a", start=1, end=20) in merged 19 | assert FunctionAnnotation("b", start=2, end=15) in merged 20 | assert FunctionAnnotation("c", start=4, end=10) in merged 21 | 22 | 23 | def test_merge_annotations_gap(): 24 | merged = merge_annotations( 25 | [ 26 | FunctionAnnotation("a", start=1, end=10), 27 | FunctionAnnotation("a", start=13, end=20), # gap is 2 28 | FunctionAnnotation("a", start=24, end=30), 29 | ], 30 | merge_gap_max=2, 31 | ) 32 | 33 | assert len(merged) == 2 34 | assert FunctionAnnotation("a", 1, 20) in merged 35 | assert FunctionAnnotation("a", 24, 30) in merged 36 | -------------------------------------------------------------------------------- /src/data/esm/utils/noise_schedules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def cosine_schedule(t: torch.Tensor): 7 | # t is a tensor of size (batch_size,) with values between 0 and 1. 
This is the 8 | # schedule used in the MaskGIT paper 9 | return torch.cos(t * math.pi * 0.5) 10 | 11 | 12 | def cubic_schedule(t): 13 | return 1 - t**3 14 | 15 | 16 | def linear_schedule(t): 17 | return 1 - t 18 | 19 | 20 | def square_root_schedule(t): 21 | return 1 - torch.sqrt(t) 22 | 23 | 24 | def square_schedule(t): 25 | return 1 - t**2 26 | 27 | 28 | NOISE_SCHEDULE_REGISTRY = { 29 | "cosine": cosine_schedule, 30 | "linear": linear_schedule, 31 | "square_root_schedule": square_root_schedule, 32 | "cubic": cubic_schedule, 33 | "square": square_schedule, 34 | } 35 | -------------------------------------------------------------------------------- /src/data/esm/utils/residue_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This mapping is used when we need to store atom data in a format that requires 17 | # fixed atom data size for every residue (e.g. a numpy array). 18 | atom_types = [ 19 | "N", 20 | "CA", 21 | "C", 22 | "CB", 23 | "O", 24 | "CG", 25 | "CG1", 26 | "CG2", 27 | "OG", 28 | "OG1", 29 | "SG", 30 | "CD", 31 | "CD1", 32 | "CD2", 33 | "ND1", 34 | "ND2", 35 | "OD1", 36 | "OD2", 37 | "SD", 38 | "CE", 39 | "CE1", 40 | "CE2", 41 | "CE3", 42 | "NE", 43 | "NE1", 44 | "NE2", 45 | "OE1", 46 | "OE2", 47 | "CH2", 48 | "NH1", 49 | "NH2", 50 | "OH", 51 | "CZ", 52 | "CZ2", 53 | "CZ3", 54 | "NZ", 55 | "OXT", 56 | ] 57 | atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} 58 | atom_type_num = len(atom_types) # := 37. 59 | 60 | restype_1to3 = { 61 | "A": "ALA", 62 | "R": "ARG", 63 | "N": "ASN", 64 | "D": "ASP", 65 | "C": "CYS", 66 | "Q": "GLN", 67 | "E": "GLU", 68 | "G": "GLY", 69 | "H": "HIS", 70 | "I": "ILE", 71 | "L": "LEU", 72 | "K": "LYS", 73 | "M": "MET", 74 | "F": "PHE", 75 | "P": "PRO", 76 | "S": "SER", 77 | "T": "THR", 78 | "W": "TRP", 79 | "Y": "TYR", 80 | "V": "VAL", 81 | } 82 | -------------------------------------------------------------------------------- /src/data/esm/utils/sampling_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from src.data.esm.utils.sampling import sample_logits 5 | 6 | 7 | def test_sample_logits(): 8 | # batched input. temperature != 0.0. 9 | sampled = sample_logits( 10 | logits=torch.randn((64, 8, 4096)), temperature=0.8, valid_ids=list(range(4096)) 11 | ) 12 | assert sampled.shape == (64, 8) 13 | 14 | # batched input. temperature == 0.0. 15 | sampled = sample_logits( 16 | logits=torch.randn((64, 8, 4096)), temperature=0.0, valid_ids=list(range(4096)) 17 | ) 18 | assert sampled.shape == (64, 8) 19 | 20 | # non-batched input. temperature != 0.0. 21 | sampled = sample_logits( 22 | logits=torch.randn((8, 4096)), temperature=0.8, valid_ids=list(range(4096)) 23 | ) 24 | assert sampled.shape == (8,) 25 | 26 | # non-batched input. 
temperature == 0.0. 27 | sampled = sample_logits( 28 | logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=list(range(4096)) 29 | ) 30 | assert sampled.shape == (8,) 31 | 32 | with pytest.raises(ValueError): 33 | sampled = sample_logits( 34 | logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=[] 35 | ) 36 | 37 | 38 | test_sample_logits() 39 | -------------------------------------------------------------------------------- /src/data/esm/utils/structure/aligner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import Field, replace 4 | from typing import Any, ClassVar, Protocol, TypeVar 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from src.data.esm.utils.structure.protein_structure import ( 10 | compute_affine_and_rmsd, 11 | ) 12 | 13 | 14 | class Alignable(Protocol): 15 | atom37_positions: np.ndarray 16 | atom37_mask: np.ndarray 17 | # Trick to detect whether an object is a dataclass 18 | __dataclass_fields__: ClassVar[dict[str, Field[Any]]] 19 | 20 | def __len__(self) -> int: ... 21 | 22 | 23 | T = TypeVar("T", bound=Alignable) 24 | 25 | 26 | class Aligner: 27 | def __init__( 28 | self, 29 | mobile: Alignable, 30 | target: Alignable, 31 | only_use_backbone: bool = False, 32 | use_reflection: bool = False, 33 | ): 34 | """ 35 | Aligns a mobile protein chain against a target protein chain. 36 | 37 | Args: 38 | mobile (ProteinChain): Protein chain to be aligned. 39 | target (ProteinChain): Protein chain target. 40 | only_use_backbone (bool): Whether to only use backbone atoms. 41 | use_reflection (bool): Whether to align to target reflection. 42 | """ 43 | # Check proteins must have same number of residues 44 | assert len(mobile) == len(target) 45 | 46 | # Determine overlapping atoms 47 | joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype( 48 | bool 49 | ) 50 | 51 | # Backbone atoms are first sites in atom37 representation 52 | if only_use_backbone: 53 | joint_atom37_mask[:, 3:] = False 54 | 55 | # Extract matching atom positions and convert to batched tensors 56 | mobile_atom_tensor = ( 57 | torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0) 58 | ) 59 | target_atom_tensor = ( 60 | torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0) 61 | ) 62 | joint_atom37_mask = ( 63 | torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0) 64 | ) 65 | 66 | # If using reflection flip target 67 | if use_reflection: 68 | target_atom_tensor = -target_atom_tensor 69 | 70 | # Compute alignment and rmsd 71 | affine3D, rmsd = compute_affine_and_rmsd( 72 | mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask 73 | ) 74 | self._affine3D = affine3D 75 | self._rmsd = rmsd.item() 76 | 77 | @property 78 | def rmsd(self): 79 | return self._rmsd 80 | 81 | def apply(self, mobile: T) -> T: 82 | """Apply alignment to a protein chain""" 83 | # Extract atom positions and convert to batched tensors 84 | mobile_atom_tensor = ( 85 | torch.from_numpy(mobile.atom37_positions[mobile.atom37_mask]) 86 | .type(torch.float32) 87 | .unsqueeze(0) 88 | ) 89 | 90 | # Transform atom arrays 91 | aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0) 92 | 93 | # Rebuild atom37 positions 94 | aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan) 95 | aligned_atom37_positions[mobile.atom37_mask] = aligned_atom_tensor 96 | 97 | return replace(mobile, 
atom37_positions=aligned_atom37_positions) 98 | -------------------------------------------------------------------------------- /src/data/esm/utils/structure/normalize_coordinates.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | 7 | from src.data.esm.utils import residue_constants as RC 8 | from src.data.esm.utils.structure.affine3d import Affine3D 9 | 10 | ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) 11 | 12 | 13 | def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D: 14 | N, CA, C = bb_positions.unbind(dim=-2) 15 | return Affine3D.from_graham_schmidt(C, CA, N) 16 | 17 | 18 | def index_by_atom_name( 19 | atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 20 | ) -> ArrayOrTensor: 21 | squeeze = False 22 | if isinstance(atom_names, str): 23 | atom_names = [atom_names] 24 | squeeze = True 25 | indices = [RC.atom_order[atom_name] for atom_name in atom_names] 26 | dim = dim % atom37.ndim 27 | index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) 28 | result = atom37[index] # type: ignore 29 | if squeeze: 30 | result = result.squeeze(dim) 31 | return result 32 | 33 | 34 | def get_protein_normalization_frame(coords: Tensor) -> Affine3D: 35 | """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates. 36 | Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame 37 | using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame. 38 | 39 | Args: 40 | coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates 41 | 42 | Returns: 43 | Affine3D: tensor of Affine3D frame 44 | """ 45 | bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2) 46 | coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1) 47 | 48 | average_position_per_n_ca_c = bb_coords.masked_fill( 49 | ~coord_mask[..., None, None], 0 50 | ).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8) 51 | frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float()) 52 | 53 | return frame 54 | 55 | 56 | def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor: 57 | """Given a set of coordinates and a single frame, apply the frame to the coordinates. 
58 | 59 | Args: 60 | coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates 61 | frame (Affine3D): Affine3D frame 62 | 63 | Returns: 64 | torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates 65 | """ 66 | coords_trans_rot = frame[..., None, None].invert().apply(coords) 67 | 68 | # only transform coordinates with frame that have a valid rotation 69 | valid_frame = frame.trans.norm(dim=-1) > 0 70 | 71 | is_inf = torch.isinf(coords) 72 | coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords) 73 | coords.masked_fill_(is_inf, torch.inf) 74 | 75 | return coords 76 | 77 | 78 | def normalize_coordinates(coords: Tensor) -> Tensor: 79 | return apply_frame_to_coords(coords, get_protein_normalization_frame(coords)) 80 | -------------------------------------------------------------------------------- /src/data/esm/utils/structure/predicted_aligned_error.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from src.data.esm.utils.structure.affine3d import Affine3D 5 | 6 | 7 | def masked_mean( 8 | mask: torch.Tensor, 9 | value: torch.Tensor, 10 | dim: int | None | tuple[int, ...] = None, 11 | eps=1e-10, 12 | ) -> torch.Tensor: 13 | """Compute the mean of `value` where only positions where `mask == true` are 14 | counted. 15 | """ 16 | mask = mask.expand(*value.shape) 17 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 18 | 19 | 20 | def _pae_bins( 21 | max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu") 22 | ): 23 | bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device) 24 | step = max_bin / (num_bins - 2) 25 | bin_centers = bins + step / 2 26 | bin_centers = torch.cat( 27 | [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 28 | ) 29 | return bin_centers 30 | 31 | 32 | def _compute_pae_masks(mask: torch.Tensor): 33 | square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool() 34 | return square_mask 35 | 36 | 37 | def compute_predicted_aligned_error( 38 | logits: torch.Tensor, 39 | aa_mask: torch.Tensor, 40 | sequence_id: torch.Tensor | None = None, 41 | max_bin: float = 31, 42 | ) -> torch.Tensor: 43 | bins = _pae_bins(max_bin, logits.shape[-1], logits.device) 44 | square_mask = _compute_pae_masks(aa_mask) 45 | min_v = torch.finfo(logits.dtype).min 46 | probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) 47 | 48 | return (probs * bins).sum(dim=-1) 49 | 50 | 51 | @torch.no_grad 52 | def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0): 53 | square_mask = _compute_pae_masks(aa_mask) 54 | seqlens = aa_mask.sum(-1, keepdim=True) 55 | bins = _pae_bins(max_bin, logits.shape[-1], logits.device) 56 | d0 = 1.24 * (seqlens.clamp_min(19) - 15) ** (1 / 3) - 1.8 57 | f_d = 1.0 / (1 + (bins / d0.unsqueeze(-1)) ** 2) 58 | 59 | min_v = torch.finfo(logits.dtype).min 60 | probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1) 61 | # This is the sum over bins 62 | ptm = (probs * f_d.unsqueeze(-2)).sum(dim=-1) 63 | # This is the mean over residues j 64 | ptm = masked_mean(square_mask, ptm, dim=-1) 65 | # The we do a max over residues i 66 | return ptm.max(dim=-1).values 67 | 68 | 69 | def tm_loss( 70 | logits: torch.Tensor, 71 | pred_affine: torch.Tensor, 72 | targ_affine: torch.Tensor, 73 | targ_mask: torch.Tensor, 74 | tm_mask: torch.Tensor | None = None, 75 | sequence_id: torch.Tensor | None = None, 76 | max_bin: float = 31, 
77 | ): 78 | pred = Affine3D.from_tensor(pred_affine) 79 | targ = Affine3D.from_tensor(targ_affine) 80 | 81 | def transform(affine: Affine3D): 82 | pts = affine.trans[..., None, :, :] 83 | return affine.invert()[..., None].apply(pts) 84 | 85 | with torch.no_grad(): 86 | sq_diff = (transform(pred) - transform(targ)).square().sum(dim=-1) 87 | 88 | num_bins = logits.shape[-1] 89 | sq_bins = torch.linspace( 90 | 0, max_bin, num_bins - 1, device=logits.device 91 | ).square() 92 | # Gets the bin id by using a sum. 93 | true_bins = (sq_diff[..., None] > sq_bins).sum(dim=-1).long() 94 | 95 | errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none") 96 | square_mask = _compute_pae_masks(targ_mask) 97 | loss = masked_mean(square_mask, errors, dim=(-1, -2)) 98 | 99 | if tm_mask is not None: 100 | loss = masked_mean(tm_mask, loss, dim=None) 101 | else: 102 | loss = loss.mean() 103 | 104 | return loss 105 | -------------------------------------------------------------------------------- /src/data/esm/utils/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | from cloudpathlib import CloudPath 9 | 10 | PathLike = Union[str, Path, CloudPath] 11 | PathOrBuffer = Union[PathLike, io.StringIO] 12 | 13 | 14 | @dataclass 15 | class FunctionAnnotation: 16 | """Represents an annotation of a protein's function over a range of residues. 17 | 18 | Fields: 19 | label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs 20 | start (int): Start index of this annotation. 1-indexed, inclusive. 21 | end (int): End index of this annotation. 1-indexed, inclusive. 
22 | """ 23 | 24 | label: str 25 | start: int 26 | end: int 27 | 28 | def to_tuple(self) -> tuple[str, int, int]: 29 | return self.label, self.start, self.end 30 | 31 | def __len__(self) -> int: 32 | """Length of the annotation.""" 33 | return self.end - self.start + 1 34 | -------------------------------------------------------------------------------- /src/data/geo_featurizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from src.data.affine3d import build_affine3d_from_coordinates 5 | 6 | def rbf_func(D, num_rbf): 7 | dtype = D.dtype 8 | device = D.device 9 | shape = D.shape 10 | D_min, D_max, D_count = 0., 20., num_rbf 11 | D_mu = torch.linspace(D_min, D_max, D_count, dtype=dtype, device=device) 12 | D_mu = D_mu.view([1]*(len(shape))+[-1]) 13 | D_sigma = (D_max - D_min) / D_count 14 | D_expand = torch.unsqueeze(D, -1) 15 | RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2) 16 | return RBF 17 | 18 | def decouple(U, num_rbf): 19 | norm = U.norm(dim=-1, keepdim=True) 20 | mask = norm<1e-4 21 | direct = U/(norm+1e-6) 22 | direct[mask[...,0]] = 0 23 | rbf = rbf_func(norm[...,0], num_rbf) 24 | return torch.cat([direct, rbf], dim=-1) 25 | 26 | class GeoFeaturizer(nn.Module): 27 | def __init__(self, *args, **kwargs) -> None: 28 | super().__init__(*args, **kwargs) 29 | 30 | @classmethod 31 | @torch.no_grad() 32 | def from_X_to_features(self, X): 33 | ''' 34 | X: [N,L,4,3] 35 | ''' 36 | T, mask = build_affine3d_from_coordinates(X[...,:3,:]) 37 | T_ts = T[:,:,None].invert().compose(T[:,None]) 38 | 39 | V, E = self.get_interact_feats(T, T_ts, X) 40 | 41 | return V, E, mask 42 | 43 | @classmethod 44 | @torch.no_grad() 45 | def get_interact_feats(self, T, T_ts, X): 46 | dtype = X.dtype 47 | device = X.device 48 | B, L, m, d = X.shape 49 | 50 | ## =========== node feature =========== 51 | diffX = F.pad(X.view(B,-1,d).diff(dim=1), (0,0,1,0)).view(B,L,m,d) 52 | diffX_proj = T[...,None].invert().rot.apply(diffX) 53 | V = decouple(diffX_proj, 16).view(B, L, -1) 54 | 55 | ## =========== pair feature =========== 56 | diffE = T[:,:,None,None].invert().apply(X[:,None,...]) 57 | diffE = decouple(diffE, 16).view(B,L,L, -1) 58 | 59 | E_quant = T_ts.invert().rot._rots.reshape(B,L,L,9) 60 | E_trans = T_ts.trans 61 | E_trans = decouple(E_trans, 16).view(B,L,L,-1) 62 | 63 | E = torch.cat([diffE, E_quant, E_trans], dim=-1) 64 | return V.to(X.dtype), E.to(X.dtype) 65 | 66 | 67 | if __name__ == '__main__': 68 | X = torch.rand(2, 10, 4, 3) 69 | geofeaturizer = GeoFeaturizer() 70 | V, E, attn_mask = geofeaturizer.from_X_to_features(X) 71 | print() 72 | 73 | -------------------------------------------------------------------------------- /src/data/protein/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/data/protein/.DS_Store -------------------------------------------------------------------------------- /src/data/protein/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Generate Biomedicines, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This package includes io formats and tools for a few common datatypes, 17 | including antibodies, proteins, sequences, and structures. 18 | """ 19 | from src.data.protein.protein import Protein 20 | -------------------------------------------------------------------------------- /src/data/proteingym_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, glob 3 | import pandas as pd 4 | import numpy as np 5 | 6 | 7 | class ProteinGYMDataset(torch.utils.data.Dataset): 8 | def __init__( 9 | self, 10 | dms_csv_dir: str = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DMS_ProteinGym_substitutions", 11 | dms_pdb_dir: str = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DMS_ProteinGym_substitutions/ProteinGym_AF2_structures", 12 | dms_reference_csv_path: str = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DMS_ProteinGym_substitutions/ProteinGym_AF2_structures/DMS_substitutions.csv", 13 | ): 14 | self.dms_csv_dir = dms_csv_dir 15 | self.pdb_dir = dms_pdb_dir 16 | self.dms_reference_csv_path = dms_reference_csv_path 17 | self.dms_reference_df = pd.read_csv(self.dms_reference_csv_path) 18 | self.dms_ids = self.dms_reference_df["DMS_id"].tolist() 19 | self.target_sequences = self.dms_reference_df["target_seq"].tolist() 20 | self.dms_csv_path = [os.path.join(self.dms_csv_dir, ele) for ele in self.dms_reference_df["DMS_filename"].tolist()] 21 | self.pdb_file_path = [os.path.join(self.pdb_dir, ele) for ele in self.dms_reference_df["pdb_file"].tolist()] 22 | self.pdb_file_ranges = [[int(ele.split("-")[0])-1, int(ele.split("-")[-1])] for ele in self.dms_reference_df["pdb_range"].tolist()] # 0-index 23 | 24 | def __len__(self): 25 | return len(self.dms_reference_df) 26 | 27 | def __getitem__( 28 | self, 29 | idx 30 | ): 31 | dms_id = self.dms_ids[idx] 32 | dms_csv_path = self.dms_csv_path[idx] 33 | target_sequence = self.target_sequences[idx] 34 | pdb_file_path = self.pdb_file_path[idx] 35 | pdb_range = self.pdb_file_ranges[idx] 36 | assert len(pdb_range) == 2, f"invalid pdb range: {pdb_range}" 37 | # target_sequence = target_sequence[pdb_range[0]:pdb_range[1]] 38 | 39 | return { 40 | "dms_id": dms_id, 41 | "dms_csv_path": dms_csv_path, 42 | "target_sequence": target_sequence, 43 | "pdb_file_path": pdb_file_path, 44 | "pdb_range": pdb_range, 45 | "max_length": pdb_range[1] - pdb_range[0] 46 | } 47 | 48 | 49 | if __name__ == "__main__": 50 | proteingym = ProteinGYMDataset() 51 | print(f"length of proteingym dataset: {len(proteingym)}...") 52 | -------------------------------------------------------------------------------- /src/interface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/interface/__init__.py -------------------------------------------------------------------------------- /src/interface/data_interface.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import importlib 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class DInterface_base(pl.LightningDataModule): 8 | def __init__(self, num_workers=8, 9 | dataset='', 10 | **kwargs): 11 | super().__init__() 12 | self.save_hyperparameters() 13 | self.num_workers = num_workers 14 | self.dataset = dataset 15 | self.kwargs = kwargs 16 | self.batch_size = kwargs.get('batch_size', 4) 17 | self.task_name = kwargs.get("task_name") 18 | self.finetune_type = kwargs.get("finetune_type") 19 | print("batch_size", self.batch_size) 20 | print("task_name", self.task_name) 21 | 22 | # def setup(self, stage=None): 23 | # # Assign train/val datasets for use in dataloaders 24 | # if stage == 'fit' or stage is None: 25 | # self.trainset = self.instancialize(split = 'train') 26 | # self.valset = self.instancialize(split='valid') 27 | 28 | # # Assign test dataset for use in dataloader(s) 29 | # if stage == 'test' or stage is None: 30 | # self.testset = self.instancialize(split='test') 31 | 32 | def train_dataloader(self): 33 | return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, prefetch_factor=3) 34 | 35 | def val_dataloader(self): 36 | return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 37 | 38 | def test_dataloader(self): 39 | return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 40 | 41 | def load_data_module(self): 42 | name = self.dataset 43 | # Change the `snake_case.py` file name to `CamelCase` class name. 44 | # Please always name your model file name as `snake_case.py` and 45 | # class name corresponding `CamelCase`. 46 | camel_name = ''.join([i.capitalize() for i in name.split('_')]) 47 | try: 48 | self.data_module = getattr(importlib.import_module( 49 | '.'+name, package=__package__), camel_name) 50 | except: 51 | raise ValueError( 52 | f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') 53 | 54 | def instancialize(self, **other_args): 55 | """ Instancialize a model using the corresponding parameters 56 | from self.hparams dictionary. You can also input any args 57 | to overwrite the corresponding value in self.kwargs. 
58 | """ 59 | if other_args['split'] == 'train': 60 | self.data_module = getattr(importlib.import_module( 61 | '.AF2DB_dataset', package='data'), 'Af2dbDataset') 62 | else: 63 | self.data_module = getattr(importlib.import_module( 64 | '.CASP15_dataset', package='data'), 'CASP15Dataset') 65 | 66 | class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:] 67 | inkeys = self.kwargs.keys() 68 | args1 = {} 69 | for arg in class_args: 70 | if arg in inkeys: 71 | args1[arg] = self.kwargs[arg] 72 | args1.update(other_args) 73 | return self.data_module(**args1) 74 | -------------------------------------------------------------------------------- /src/interface/model_interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn as nn 4 | import os 5 | import torch.optim.lr_scheduler as lrs 6 | import inspect 7 | 8 | 9 | class MInterface_base(pl.LightningModule): 10 | def __init__(self, model_name=None, loss=None, lr=None, **kargs): 11 | super().__init__() 12 | self.save_hyperparameters() 13 | self.load_model() 14 | self.configure_loss() 15 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True) 16 | 17 | def forward(self, input): 18 | pass 19 | 20 | 21 | def training_step(self, batch, batch_idx, **kwargs): 22 | pass 23 | 24 | 25 | def validation_step(self, batch, batch_idx): 26 | pass 27 | 28 | def test_step(self, batch, batch_idx): 29 | # Here we just reuse the validation_step for testing 30 | return self.validation_step(batch, batch_idx) 31 | 32 | def on_validation_epoch_end(self): 33 | # Keep the progress bar output on its own line 34 | self.print('') 35 | 36 | def get_schedular(self, optimizer, lr_scheduler='onecycle'): 37 | if lr_scheduler == 'step': 38 | scheduler = lrs.StepLR(optimizer, 39 | step_size=self.hparams.lr_decay_steps, 40 | gamma=self.hparams.lr_decay_rate) 41 | elif lr_scheduler == 'cosine': 42 | scheduler = lrs.CosineAnnealingLR(optimizer, 43 | T_max=self.hparams.lr_decay_steps) 44 | elif lr_scheduler == 'onecycle': 45 | scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, total_steps=self.hparams.lr_decay_steps, three_phase=False) 46 | elif lr_scheduler == 'plateau': 47 | scheduler = lrs.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True) 48 | else: 49 | raise ValueError('Invalid lr_scheduler type!') 50 | 51 | return scheduler 52 | 53 | def configure_optimizers(self): 54 | if hasattr(self.hparams, 'weight_decay'): 55 | weight_decay = self.hparams.weight_decay 56 | else: 57 | weight_decay = 0 58 | 59 | optimizer_g = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.98), eps=1e-8) 60 | 61 | 62 | schecular_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler) 63 | if self.hparams.lr_scheduler == 'plateau': 64 | return [optimizer_g], [{"scheduler": schecular_g, "interval": "epoch", "monitor": "val_loss"}] 65 | else: 66 | return [optimizer_g], [{"scheduler": schecular_g, "interval": "step"}] 67 | 68 | def lr_scheduler_step(self, *args, **kwargs): 69 | scheduler = self.lr_schedulers() 70 | if self.hparams.lr_scheduler != 'plateau': 71 | scheduler.step() 72 | 73 | 74 | def configure_devices(self): 75 | self.device = torch.device(self.hparams.device) 76 | 77 | def configure_loss(self): 78 | self.loss_function = nn.CrossEntropyLoss(reduction='none') 79 | 80 | def load_model(self): 81 | self.model = None 82 | 83 | def 
instancialize(self, Model, **other_args): 84 | """ Instantiate a model using the corresponding parameters 85 | from self.hparams dictionary. You can also input any args 86 | to overwrite the corresponding value in self.hparams. 87 | """ 88 | class_args = inspect.getfullargspec(Model.__init__).args[1:] 89 | inkeys = self.hparams.keys() 90 | args1 = {} 91 | for arg in class_args: 92 | if arg in inkeys: 93 | args1[arg] = getattr(self.hparams, arg) 94 | args1.update(other_args) 95 | return Model(**args1) 96 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/model/__init__.py -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/bce.py: -------------------------------------------------------------------------------- 1 | from baidubce.bce_client_configuration import BceClientConfiguration 2 | from baidubce.auth.bce_credentials import BceCredentials 3 | from baidubce.services.bos.bos_client import BosClient 4 | import hashlib 5 | import glob 6 | 7 | def seq_encoder(sequence, method='md5'): 8 | hasher = eval(f'hashlib.{method}') 9 | return hasher(sequence.encode(encoding='utf-8')).hexdigest() 10 | 11 | 12 | config = BceClientConfiguration( 13 | credentials = BceCredentials( 14 | '35420270cb5c46118d6729b692669e2b', 15 | '35474e577b514954b72a128a53304cab' 16 | ), 17 | endpoint = 'https://bj.bcebos.com' 18 | ) 19 | 20 | bos_client = BosClient(config) 21 | 22 | # response = bos_client.list_buckets() 23 | # for bucket in response.buckets: 24 | # print(bucket.name) 25 | 26 | 27 | if __name__ == "__main__": 28 | import pandas as pd 29 | all_csvs = list( 30 | glob.iglob( 31 | "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/*/*.csv", 32 | recursive=True 33 | ) 34 | ) 35 | all_csvs.sort() 36 | 37 | seqs = [] 38 | pdb_paths = [] 39 | 40 | for _csv in all_csvs: 41 | df = pd.read_csv(_csv) 42 | if "pdb_path" not in df or "aa_seq" not in df: 43 | print(_csv) 44 | continue 45 | _seqs = df["aa_seq"].tolist() 46 | _pdb_paths = df["pdb_path"].tolist() 47 | seqs.extend(_seqs) 48 | pdb_paths.extend(_pdb_paths) 49 | 50 | print(len(pdb_paths)) 51 | print(len(seqs)) 52 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback 2 | import os 3 | import shutil 4 | from omegaconf import OmegaConf 5 | 6 | class SetupCallback(Callback): 7 | def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None): 8 | super().__init__() 9 | self.now = now 10 | self.logdir = logdir 11 | self.ckptdir = ckptdir 12 | self.cfgdir = cfgdir 13 | self.config = config 14 | 15 | self.argv_content = argv_content 16 | 17 | # Called at the start of the pretraining routine. 18 | def on_fit_start(self, trainer, pl_module): 19 | # Create logdirs and save configs 20 | os.makedirs(self.logdir, exist_ok=True) 21 | os.makedirs(self.ckptdir, exist_ok=True) 22 | os.makedirs(self.cfgdir, 
exist_ok=True) 23 | 24 | print("Project config") 25 | print(OmegaConf.to_yaml(self.config)) 26 | OmegaConf.save(self.config, 27 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 28 | 29 | with open(os.path.join(self.logdir, "argv_content.txt"), "w") as f: 30 | f.write(str(self.argv_content)) 31 | 32 | class BackupCodeCallback(Callback): 33 | def __init__(self, source_dir, backup_dir, ignore_patterns=None): 34 | super().__init__() 35 | self.source_dir = source_dir 36 | self.backup_dir = backup_dir 37 | self.ignore_patterns = ignore_patterns 38 | 39 | def on_train_start(self, trainer, pl_module): 40 | try: 41 | os.makedirs(self.backup_dir, exist_ok=True) 42 | if os.path.exists(self.backup_dir+'/code'): 43 | shutil.rmtree(self.backup_dir+'/code') 44 | shutil.copytree(self.source_dir, self.backup_dir+'/code', ignore=self.ignore_patterns) 45 | 46 | print(f"Code files backed up to {self.backup_dir}") 47 | except: 48 | print(f"Failed to back up code files to {self.backup_dir}") -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def spearman_correlation(y_true, y_pred): 4 | # Double argsort to obtain the ranks 5 | rank_true = torch.argsort(torch.argsort(y_true)) 6 | rank_pred = torch.argsort(torch.argsort(y_pred)) 7 | 8 | # Squared rank differences 9 | d_squared = torch.pow(rank_true - rank_pred, 2).float() 10 | 11 | # Number of samples 12 | n = y_true.size(0) 13 | 14 | # Spearman correlation coefficient 15 | spearman_corr = 1 - (6 * torch.sum(d_squared)) / (n * (n**2 - 1)) 16 | return spearman_corr.item() 17 | 18 | 19 | def f1_score_max(pred, target): 20 | """ 21 | F1 score with the optimal threshold. 22 | 23 | This function first enumerates all possible thresholds for deciding positive and negative 24 | samples, and then picks the threshold with the maximal F1 score. 
25 | 26 | Parameters: 27 | pred (Tensor): predictions of shape :math:`(B, N)` 28 | target (Tensor): binary targets of shape :math:`(B, N)` 29 | """ 30 | order = pred.argsort(descending=True, dim=1) 31 | target = target.gather(1, order) 32 | precision = target.cumsum(1) / torch.ones_like(target).cumsum(1) 33 | recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10) 34 | is_start = torch.zeros_like(target).bool() 35 | is_start[:, 0] = 1 36 | is_start = torch.scatter(is_start, 1, order, is_start) 37 | 38 | all_order = pred.flatten().argsort(descending=True) 39 | order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1] 40 | order = order.flatten() 41 | inv_order = torch.zeros_like(order) 42 | inv_order[order] = torch.arange(order.shape[0], device=order.device) 43 | is_start = is_start.flatten()[all_order] 44 | all_order = inv_order[all_order] 45 | precision = precision.flatten() 46 | recall = recall.flatten() 47 | all_precision = precision[all_order] - \ 48 | torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1]) 49 | all_precision = all_precision.cumsum(0) / is_start.cumsum(0) 50 | all_recall = recall[all_order] - \ 51 | torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1]) 52 | all_recall = all_recall.cumsum(0) / pred.shape[0] 53 | all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10) 54 | return all_f1.max() 55 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biomap-research/PFMBench/53758ffcbdf1d79b5d125383e4dd52d6fd59d2a1/tasks/__init__.py -------------------------------------------------------------------------------- /tasks/configs/antibiotic_resistance.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/antibiotic_resistance" 3 | ex_name: "antibiotic_resistance_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/antibiotic_resistance/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/antibiotic_resistance/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/antibiotic_resistance/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm3_1.4b' 23 | finetune_type: 'adapter' 24 | peft_type: 'lora' 25 | lora_r: 8 # 低秩矩阵的秩 26 | lora_alpha: 16 # LoRA 的 alpha 参数 27 | lora_dropout: 0.1 # Dropout 防止过拟合 28 | 29 | Task: 30 | task_name: "antibiotic_resistance" 31 | task_type: "classification" 32 | num_classes: 19 33 | # metric: "valid_acc" 34 | # direction: "max" 35 | 36 | hydra: 37 | output_subdir: null 38 | -------------------------------------------------------------------------------- /tasks/configs/beta_lactamase_activity_prediction.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/beta_lactamase_activity_prediction" 3 | ex_name: "beta_lactamase_activity_prediction_adapter" 4 | offline: 
${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/beta-lactamase-activity-prediction/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/beta-lactamase-activity-prediction/valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/beta-lactamase-activity-prediction/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "beta_lactamase_activity_prediction" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/binding_db.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/binding_db" 3 | ex_name: "binding_db_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | # train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/BindingDB/mmseq_outdir/train.csv" 8 | # val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/BindingDB/mmseq_outdir/val.csv" 9 | # test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/BindingDB/mmseq_outdir/test.csv" 10 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/BindingDB/binding_db_train.csv" 11 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/BindingDB/binding_db_valid.csv" 12 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/BindingDB/binding_db_test.csv" 13 | 14 | Training: 15 | epoch: 50 # end epoch 16 | lr: 1e-4 # Learning rate 17 | lr_scheduler: "cosine" 18 | check_val_every_n_epoch: 1 19 | seed: 2024 20 | batch_size: 8 21 | num_workers: 4 22 | seq_len: 1024 23 | 24 | Model: 25 | pretrain_model_name: 'esm2_650m' 26 | finetune_type: 'adapter' 27 | peft_type: 'lora' 28 | lora_r: 8 # 低秩矩阵的秩 29 | lora_alpha: 16 # LoRA 的 alpha 参数 30 | lora_dropout: 0.1 # Dropout 防止过拟合 31 | 32 | Task: 33 | task_name: "binding_db" 34 | task_type: "regression" 35 | num_classes: 1 36 | # metric: "valid_spearman" 37 | # direction: "max" 38 | 39 | hydra: 40 | output_subdir: null 41 | -------------------------------------------------------------------------------- /tasks/configs/cloning_clf.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/cloning_clf" 3 | ex_name: "cloning_clf_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/cloning_clf/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/cloning_clf/mmseq_outdir/val.csv" 9 | test_data_path: 
"/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/cloning_clf/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "cloning_clf" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/contact_map.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/contact_map" 3 | ex_name: "contact_map_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/contact_prediction_binary/contact_prediction_binary_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/contact_prediction_binary/contact_prediction_binary_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/contact_prediction_binary/contact_prediction_binary_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "contact_map" 30 | task_type: "contact" 31 | num_classes: 1 32 | # metric: "valid_Top(L/5)" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/deepet_topt.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/deepet_topt" 3 | ex_name: "deepet_topt_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepET_Topt/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepET_Topt/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepET_Topt/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'protrek' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "deepet_topt" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 
| -------------------------------------------------------------------------------- /tasks/configs/deeploc2_multi.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/deep_loc2_multi" 3 | ex_name: "deep_loc2_multi_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLoc2Multi/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLoc2Multi/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLoc2Multi/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'saport' 23 | finetune_type: 'adapter' 24 | peft_type: 'lora' 25 | lora_r: 8 # 低秩矩阵的秩 26 | lora_alpha: 16 # LoRA 的 alpha 参数 27 | lora_dropout: 0.1 # Dropout 防止过拟合 28 | 29 | Task: 30 | task_name: "deep_loc2_multi" 31 | task_type: "multi_labels_classification" 32 | num_classes: 10 33 | # metric: "valid_f1_max" 34 | # direction: "max" 35 | 36 | hydra: 37 | output_subdir: null 38 | -------------------------------------------------------------------------------- /tasks/configs/deeploc_binary.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/deep_loc_binary" 3 | ex_name: "deep_loc_binary_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLocBinary/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLocBinary/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLocBinary/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'saport' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "deep_loc_binary" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/deeploc_multi.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/deeploc_multi" 3 | ex_name: "deeploc_multi_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLocMulti/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLocMulti/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepLocMulti/mmseq_outdir/test.csv" 10 | 11 | 
Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "deeploc_multi" 30 | task_type: "classification" 31 | num_classes: 10 32 | # metric: "valid_acc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/deepsol.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/deepsol" 3 | ex_name: "deepsol_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepSol/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepSol/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepSol/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "deepsol" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/deepsolue.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/deepsolue" 3 | ex_name: "deepsolue_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepSoluE/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepSoluE/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/DeepSoluE/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "deepsolue" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/enzyme_catalytic_efficiency.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/enzyme_catalytic_efficiency" 3 | ex_name: 
"enzyme_catalytic_efficiency_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/enzyme_catalytic_efficiency/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/enzyme_catalytic_efficiency/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/enzyme_catalytic_efficiency/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'protrek' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "enzyme_catalytic_efficiency" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/enzyme_commission_number.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/enzyme_commission_number" 3 | ex_name: "enzyme_commission_number_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/enzyme_commission_number/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/enzyme_commission_number/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/enzyme_commission_number/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | peft_type: 'lora' 25 | lora_r: 8 # 低秩矩阵的秩 26 | lora_alpha: 16 # LoRA 的 alpha 参数 27 | lora_dropout: 0.1 # Dropout 防止过拟合 28 | 29 | Task: 30 | task_name: "enzyme_commission_number" 31 | task_type: "multi_labels_classification" 32 | num_classes: 585 33 | # metric: "valid_f1_max" 34 | # direction: "max" 35 | 36 | hydra: 37 | output_subdir: null 38 | -------------------------------------------------------------------------------- /tasks/configs/esol.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/esol" 3 | ex_name: "esol_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/eSOL/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/eSOL/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/eSOL/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | 
batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "eSOL" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/fitness_prediction.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/fitness_prediction" 3 | ex_name: "fitness_prediction_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fitness_prediction/fitness_prediction_with_glmfold_structure_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fitness_prediction/fitness_prediction_with_glmfold_structure_val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fitness_prediction/fitness_prediction_with_glmfold_structure_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'protrek' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "fitness_prediction" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/flip_aav_sampled.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/flip_aav_sampled" 3 | ex_name: "flip_aav_sampled_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/FLIP_AAV_sampled/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/FLIP_AAV_sampled/valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/FLIP_AAV_sampled/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "flip_aav_sampled" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/flip_gb1_sampled.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/flip_gb1_sampled" 3 | 
ex_name: "flip_gb1_sampled_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/FLIP_GB1_sampled/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/FLIP_GB1_sampled/valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/FLIP_GB1_sampled/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "flip_gb1_sampled" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/fold_prediction.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/fold_prediction" 3 | ex_name: "fold_prediction_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fold_prediction/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fold_prediction/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/fold_prediction/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'saport' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "fold_prediction" 30 | task_type: "classification" 31 | num_classes: 1195 32 | # metric: "valid_acc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/go_biological_process.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/go_biological_process" 3 | ex_name: "go_biological_process_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_biological_process/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_biological_process/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_biological_process/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'saport' 23 | 
finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "go_biological_process" 30 | task_type: "multi_labels_classification" 31 | num_classes: 1943 32 | # metric: "valid_f1_max" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/go_cellular_component.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/go_cellular_component" 3 | ex_name: "go_cellular_component_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_cellular_component/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_cellular_component/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_cellular_component/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'saport' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "go_cellular_component" 30 | task_type: "multi_labels_classification" 31 | num_classes: 320 32 | # metric: "valid_f1_max" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/go_molecular_function.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/go_molecular_function" 3 | ex_name: "go_molecular_function_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_molecular_function/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_molecular_function/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/GO_molecular_function/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'prostt5' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "go_molecular_function" 30 | task_type: "multi_labels_classification" 31 | num_classes: 489 32 | # metric: "valid_f1_max" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/human_ppi.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/human_ppi" 3 | ex_name: "human_ppi_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | 
Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/human_ppi/human_ppi_with_pdb_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/human_ppi/human_ppi_with_pdb_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/human_ppi/human_ppi_with_pdb_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'venusplm' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "human_ppi" 30 | task_type: "pair_binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/material_production.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/material_production" 3 | ex_name: "material_production_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/material_production/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/material_production/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/material_production/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: "esm3_1.4b" 23 | finetune_type: 'adapter' 24 | peft_type: 'lora' 25 | lora_r: 8 # 低秩矩阵的秩 26 | lora_alpha: 16 # LoRA 的 alpha 参数 27 | lora_dropout: 0.1 # Dropout 防止过拟合 28 | 29 | Task: 30 | task_name: "material_production" 31 | task_type: "binary_classification" 32 | num_classes: 1 33 | # metric: "valid_acc" 34 | # direction: "max" 35 | 36 | hydra: 37 | output_subdir: null 38 | -------------------------------------------------------------------------------- /tasks/configs/metal_ion_binding.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/metal_ion_binding" 3 | ex_name: "metal_ion_binding_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/metal_ion_binding/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/metal_ion_binding/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/metal_ion_binding/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 
| peft_type: 'lora' 25 | lora_r: 8 # 低秩矩阵的秩 26 | lora_alpha: 16 # LoRA 的 alpha 参数 27 | lora_dropout: 0.1 # Dropout 防止过拟合 28 | 29 | Task: 30 | task_name: "metal_ion_binding" 31 | task_type: "binary_classification" 32 | num_classes: 1 33 | # metric: "valid_acc" 34 | # direction: "max" 35 | 36 | hydra: 37 | output_subdir: null 38 | -------------------------------------------------------------------------------- /tasks/configs/optimal_ph.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/optimal_ph" 3 | ex_name: "optimal_ph_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/optimal_ph/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/optimal_ph/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/optimal_ph/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'protrek' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "optimal_ph" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/pdbind.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/pdbind" 3 | ex_name: "pdbind_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/pdbind/pdbbind_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/pdbind/pdbbind_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/pdbind/pdbbind_test.csv" 10 | # train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/pdbind/mmseq_outdir/train.csv" 11 | # val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/pdbind/mmseq_outdir/val.csv" 12 | # test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/pdbind/mmseq_outdir/test.csv" 13 | 14 | Training: 15 | epoch: 50 # end epoch 16 | lr: 1e-4 # Learning rate 17 | lr_scheduler: "cosine" 18 | check_val_every_n_epoch: 1 19 | seed: 2024 20 | batch_size: 8 21 | num_workers: 4 22 | seq_len: 1024 23 | 24 | Model: 25 | pretrain_model_name: 'esm2_650m' 26 | finetune_type: 'adapter' 27 | peft_type: 'lora' 28 | lora_r: 8 # 低秩矩阵的秩 29 | lora_alpha: 16 # LoRA 的 alpha 参数 30 | lora_dropout: 0.1 # Dropout 防止过拟合 31 | 32 | Task: 33 | task_name: "pdbind" 34 | task_type: "regression" 35 | num_classes: 1 36 | # metric: "valid_spearman" 37 | # direction: "max" 38 | 39 | hydra: 40 | output_subdir: null 41 | -------------------------------------------------------------------------------- 
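Note on the interpolation used throughout these configs: the "offline" field combines OmegaConf's built-in oc.env resolver with a custom eval resolver. The sketch below shows how one of these YAML files could be loaded and resolved standalone; registering the eval resolver here is an assumption (in the repository this would be done by the training entry point, which is not included in this dump), and the config path is just an example.

from omegaconf import OmegaConf

# "${eval:...}" is not a built-in resolver, so it must be registered before the
# config is resolved; "${oc.env:OFFLINE, '1'}" is handled natively by OmegaConf.
OmegaConf.register_new_resolver("eval", eval, replace=True)

cfg = OmegaConf.load("tasks/configs/pdbind.yaml")
OmegaConf.resolve(cfg)

print(cfg.Output.offline)                                 # 1 unless the OFFLINE env var overrides it
print(cfg.Model.pretrain_model_name, cfg.Task.task_type)  # esm2_650m regression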
/tasks/configs/peptide_HLA_MHC_affinity.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/peptide_HLA_MHC_affinity" 3 | ex_name: "peptide_HLA_MHC_affinity_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/peptide_HLA_MHC_affinity/peptide_HLA_MHC_affinity_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/peptide_HLA_MHC_affinity/peptide_HLA_MHC_affinity_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/peptide_HLA_MHC_affinity/peptide_HLA_MHC_affinity_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "peptide_HLA_MHC_affinity" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/peta_chs_sol.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/peta_chs_sol" 3 | ex_name: "peta_chs_sol_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_CHS_Sol/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_CHS_Sol/valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_CHS_Sol/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "peta_chs_sol" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/peta_lgk_sol.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/peta_lgk_sol" 3 | ex_name: "peta_lgk_sol_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_LGK_Sol/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_LGK_Sol/valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_LGK_Sol/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 
| check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "peta_lgk_sol" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/peta_tem_sol.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/peta_tem_sol" 3 | ex_name: "peta_tem_sol_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_TEM_Sol/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_TEM_Sol/valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/PETA_TEM_Sol/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "peta_tem_sol" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/ppi_affinity.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/ppi_affinity" 3 | ex_name: "ppi_affinity_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ppi_affinity/ppi_affinity_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ppi_affinity/ppi_affinity_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ppi_affinity/ppi_affinity_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'prost' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "ppi_affinity" 30 | task_type: "pair_regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/prot_solm.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/protsolm" 3 | ex_name: "protsolm_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: 
"/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ProtSolM/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ProtSolM/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ProtSolM/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "protsolm" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/sorting_signal.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/sorting_signal" 3 | ex_name: "sorting_signal_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/SortingSignal/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/SortingSignal/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/SortingSignal/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'saport' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "sorting_signal" 30 | task_type: "multi_labels_classification" 31 | num_classes: 9 32 | # metric: "valid_f1_max" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/ssp_q8.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/ssp_q8" 3 | ex_name: "ssp_q8_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ssp_q8/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ssp_q8/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/ssp_q8/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | peft_type: 'lora' 25 | lora_r: 8 # 低秩矩阵的秩 26 | lora_alpha: 16 # LoRA 的 alpha 参数 27 | lora_dropout: 0.1 # Dropout 防止过拟合 28 | 29 | Task: 30 | 
task_name: "ssp_q8" 31 | task_type: "residual_classification" 32 | num_classes: 8 33 | # metric: "valid_acc" 34 | # direction: "max" 35 | 36 | hydra: 37 | output_subdir: null 38 | -------------------------------------------------------------------------------- /tasks/configs/tape_fluorescence.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/tape_fluorescence" 3 | ex_name: "tape_fluorescence_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/TAPE_Fluorescence/fluorescence_prediction_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/TAPE_Fluorescence/fluorescence_prediction_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/TAPE_Fluorescence/fluorescence_prediction_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 8 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'protrek' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "tape_fluorescence" 30 | task_type: "regression" 31 | num_classes: 1 32 | # metric: "valid_spearman" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/tape_stability.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/tape_stability" 3 | ex_name: "tape_stability_adapter" # ${oc.env:EXP_NAME, 'default'} 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | # TAPE_Stability 8 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/TAPE_Stability/mmseq_outdir/train.csv" 9 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/TAPE_Stability/mmseq_outdir/val.csv" 10 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/TAPE_Stability/mmseq_outdir/test.csv" 11 | # train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/saprot_tape_stab/train.csv" 12 | # val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/saprot_tape_stab/valid.csv" 13 | # test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/saprot_tape_stab/test.csv" 14 | 15 | Training: 16 | epoch: 50 # end epoch 17 | lr: 1e-4 # Learning rate 18 | lr_scheduler: "cosine" 19 | check_val_every_n_epoch: 1 20 | seed: 2024 21 | batch_size: 8 22 | num_workers: 4 23 | seq_len: 1024 24 | 25 | Model: 26 | pretrain_model_name: 'esm_650m' 27 | finetune_type: 'adapter' 28 | peft_type: 'lora' 29 | lora_r: 8 # 低秩矩阵的秩 30 | lora_alpha: 16 # LoRA 的 alpha 参数 31 | lora_dropout: 0.1 # Dropout 防止过拟合 32 | 33 | Task: 34 | task_name: "tape_stability" 35 | task_type: "regression" 36 | num_classes: 1 37 | # metric: "valid_spearman" 38 | # direction: "max" 39 | 40 | hydra: 41 | output_subdir: null 42 | -------------------------------------------------------------------------------- 
/tasks/configs/tcr_pmhc_affinity.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/tcr_pmhc_affinity" 3 | ex_name: "tcr_pmhc_affinity_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/tcr_pmhc_affinity/tcr_pmhc_affinity_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/tcr_pmhc_affinity/tcr_pmhc_affinity_test.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/tcr_pmhc_affinity/tcr_pmhc_affinity_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # rank of the low-rank update matrices 25 | lora_alpha: 32 # LoRA alpha (scaling) parameter 26 | lora_dropout: 0.1 # dropout to prevent overfitting 27 | 28 | Task: 29 | task_name: "tcr_pmhc_affinity" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/temperature_stability.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/temperature_stability" 3 | ex_name: "temperature_stability_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/temperature_stability/mmseq_outdir/train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/temperature_stability/mmseq_outdir/val.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/temperature_stability/mmseq_outdir/test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 | lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'esm2_650m' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # rank of the low-rank update matrices 25 | lora_alpha: 32 # LoRA alpha (scaling) parameter 26 | lora_dropout: 0.1 # dropout to prevent overfitting 27 | 28 | Task: 29 | task_name: "temperature_stability" 30 | task_type: "binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/configs/yeast_ppi.yaml: -------------------------------------------------------------------------------- 1 | Output: 2 | res_dir: "tasks/results/yeast_ppi" 3 | ex_name: "yeast_ppi_adapter" 4 | offline: ${eval:"int(${oc.env:OFFLINE, '1'})"} 5 | 6 | Data: 7 | train_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/yeast_ppi/yeast_ppi_with_pdb_train.csv" 8 | val_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/yeast_ppi/yeast_ppi_with_pdb_valid.csv" 9 | test_data_path: "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/datasets/yeast_ppi/yeast_ppi_with_pdb_test.csv" 10 | 11 | Training: 12 | epoch: 50 # end epoch 13 |
lr: 1e-4 # Learning rate 14 | lr_scheduler: "cosine" 15 | check_val_every_n_epoch: 1 16 | seed: 2024 17 | batch_size: 64 18 | num_workers: 4 19 | seq_len: 1024 20 | 21 | Model: 22 | pretrain_model_name: 'prost' 23 | finetune_type: 'adapter' 24 | lora_r: 8 # 低秩矩阵的秩 25 | lora_alpha: 32 # LoRA 的 alpha 参数 26 | lora_dropout: 0.1 # Dropout 防止过拟合 27 | 28 | Task: 29 | task_name: "yeast_ppi" 30 | task_type: "pair_binary_classification" 31 | num_classes: 1 32 | # metric: "valid_auroc" 33 | # direction: "max" 34 | 35 | hydra: 36 | output_subdir: null 37 | -------------------------------------------------------------------------------- /tasks/data_interface.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | from torch.utils.data import DataLoader, DistributedSampler 4 | from src.interface.data_interface import DInterface_base 5 | from src.data.protein_dataset import ProteinDataset 6 | from src.model.pretrain_model_interface import PretrainModelInterface 7 | 8 | class DInterface(DInterface_base): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | self.save_hyperparameters() 12 | 13 | def setup(self, stage=None): 14 | pass 15 | 16 | def data_setup(self, target="all"): 17 | pretrain_model_interface = None 18 | if self.finetune_type == "adapter": 19 | pretrain_model_interface = PretrainModelInterface(self.hparams.pretrain_model_name, batch_size=self.hparams.pretrain_batch_size, max_length=self.hparams.seq_len, sequence_only=self.hparams.sequence_only, task_type=self.hparams.task_type) 20 | if target == "all": 21 | self.train_set = ProteinDataset(self.hparams.train_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, pretrain_model_interface=pretrain_model_interface, task_name=self.task_name, task_type=self.hparams.task_type, num_classes=self.hparams.num_classes) 22 | self.val_set = ProteinDataset(self.hparams.val_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, pretrain_model_interface=pretrain_model_interface, task_name=self.task_name, task_type=self.hparams.task_type, num_classes=self.hparams.num_classes) 23 | self.test_set = ProteinDataset(self.hparams.test_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, pretrain_model_interface=pretrain_model_interface, task_name=self.task_name, task_type=self.hparams.task_type, num_classes=self.hparams.num_classes) 24 | elif target == "test": 25 | self.test_set = ProteinDataset(self.hparams.test_data_path, self.hparams.pretrain_model_name, self.hparams.seq_len, pretrain_model_interface=pretrain_model_interface, task_name=self.task_name, task_type=self.hparams.task_type, num_classes=self.hparams.num_classes) 26 | 27 | 28 | def train_dataloader(self): 29 | return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=self.data_process_fn) 30 | 31 | def val_dataloader(self): 32 | return DataLoader(self.val_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=self.data_process_fn) 33 | 34 | def test_dataloader(self): 35 | return DataLoader(self.test_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=self.data_process_fn) 36 | 37 | def data_process_fn(self, data_list): 38 | if self.hparams.finetune_type == 'adapter': 39 | name_list = [] 40 | mask_list = [] 41 | label_list = [] 42 | 
embedding_list = [] 43 | smiles = [] 44 | for data in data_list: 45 | name_list.append(data['name']) 46 | mask_list.append(data['attention_mask']) 47 | label_list.append(data['label']) 48 | embedding_list.append(data['embedding']) 49 | if data.get('smiles') is not None: 50 | smiles.append(data['smiles']) 51 | return {'name': name_list, 52 | 'attention_mask': torch.stack(mask_list, dim=0), 53 | 'label': torch.stack(label_list, dim=0), 54 | 'embedding': torch.stack(embedding_list, dim=0), 55 | 'smiles': torch.stack(smiles, dim=0) if len(smiles) > 0 else None, 56 | } 57 | else: 58 | return data_list 59 | -------------------------------------------------------------------------------- /tasks/tuner.py: -------------------------------------------------------------------------------- 1 | import optuna 2 | from pytorch_lightning import Trainer 3 | from pytorch_lightning.loggers import TensorBoardLogger 4 | 5 | def objective(trial): 6 | # optionally sweep over different dataset / model combinations 7 | data_module = YourDataModule(dataset_name=trial.suggest_categorical("dataset", ["A", "B", "C"])) 8 | model = YourModel(learning_rate=trial.suggest_loguniform("lr", 1e-5, 1e-1)) 9 | 10 | trainer = Trainer( 11 | max_epochs=10, 12 | logger=TensorBoardLogger("logs/", name="optuna"), 13 | callbacks=[], 14 | ) 15 | 16 | trainer.fit(model, datamodule=data_module) 17 | 18 | return trainer.callback_metrics["val_loss"].item() # adapt this to the validation metric you actually log 19 | 20 | study = optuna.create_study(direction="minimize") 21 | study.optimize(objective, n_trials=50) 22 | -------------------------------------------------------------------------------- /zeroshot/data_interface.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | from torch.utils.data import DataLoader, DistributedSampler 4 | from src.interface.data_interface import DInterface_base 5 | from src.model.pretrain_model_interface import PretrainModelInterface 6 | from src.data.proteingym_dataset import ProteinGYMDataset 7 | from src.data.msa_dataset import MSADataset 8 | 9 | class DInterface(DInterface_base): 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | self.save_hyperparameters() 13 | 14 | def setup(self, stage=None): 15 | pass 16 | 17 | def data_setup(self, type="proteingym"): 18 | if type == "proteingym": 19 | self.mut_dataset = ProteinGYMDataset( 20 | dms_csv_dir = self.hparams.dms_csv_dir, 21 | dms_pdb_dir = self.hparams.dms_pdb_dir, 22 | dms_reference_csv_path = self.hparams.dms_reference_csv_path, 23 | ) 24 | elif type == "msa": 25 | self.msa_dataset = MSADataset( 26 | msa_csv_path = self.hparams.msa_csv_path 27 | ) 28 | 29 | def train_dataloader(self): 30 | return DataLoader(self.mut_dataset, batch_size=1, shuffle=True, num_workers=self.hparams.num_workers, pin_memory=True) 31 | 32 | def val_dataloader(self): 33 | return DataLoader(self.mut_dataset, batch_size=1, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True) 34 | 35 | def test_dataloader(self): 36 | return DataLoader(self.mut_dataset, batch_size=1, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True) 37 | -------------------------------------------------------------------------------- /zeroshot/saprot/foldseek_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import re 4 | import sys 5 | 6 | sys.path.append(".") 7 | 8 | 9 | # Get structural seqs from pdb file 10 | def get_struc_seq(foldseek, 11 | path, 12 | chains: list
= None, 13 | process_id: int = 0, 14 | plddt_mask: bool = False, 15 | plddt_threshold: float = 70.) -> dict: 16 | """ 17 | 18 | Args: 19 | foldseek: Binary executable file of foldseek 20 | path: Path to pdb file 21 | chains: Chains to be extracted from pdb file. If None, all chains will be extracted. 22 | process_id: Process ID for temporary files. This is used for parallel processing. 23 | plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file. 24 | plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked. 25 | 26 | Returns: 27 | seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of 28 | (seq, struc_seq, combined_seq). 29 | """ 30 | assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}" 31 | assert os.path.exists(path), f"Pdb file not found: {path}" 32 | 33 | tmp_save_path = f"get_struc_seq_{process_id}.tsv" 34 | cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}" 35 | os.system(cmd) 36 | 37 | seq_dict = {} 38 | name = os.path.basename(path) 39 | with open(tmp_save_path, "r") as r: 40 | for i, line in enumerate(r): 41 | desc, seq, struc_seq = line.split("\t")[:3] 42 | 43 | # Mask low plddt 44 | if plddt_mask: 45 | plddts = extract_plddt(path) 46 | assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}" 47 | 48 | # Mask regions with plddt < threshold 49 | indices = np.where(plddts < plddt_threshold)[0] 50 | np_seq = np.array(list(struc_seq)) 51 | np_seq[indices] = "#" 52 | struc_seq = "".join(np_seq) 53 | 54 | name_chain = desc.split(" ")[0] 55 | chain = name_chain.replace(name, "").split("_")[-1] 56 | 57 | if chains is None or chain in chains: 58 | if chain not in seq_dict: 59 | combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)]) 60 | seq_dict[chain] = (seq, struc_seq, combined_seq) 61 | 62 | os.remove(tmp_save_path) 63 | os.remove(tmp_save_path + ".dbtype") 64 | return seq_dict 65 | 66 | 67 | def extract_plddt(pdb_path: str) -> np.ndarray: 68 | """ 69 | Extract plddt scores from pdb file. 70 | Args: 71 | pdb_path: Path to pdb file. 72 | 73 | Returns: 74 | plddts: plddt scores. 75 | """ 76 | with open(pdb_path, "r") as r: 77 | plddt_dict = {} 78 | for line in r: 79 | line = re.sub(' +', ' ', line).strip() 80 | splits = line.split(" ") 81 | 82 | if splits[0] == "ATOM": 83 | # If position < 1000 84 | if len(splits[4]) == 1: 85 | pos = int(splits[5]) 86 | 87 | # If position >= 1000, the blank will be removed, e.g. "A 999" -> "A1000" 88 | # So the length of splits[4] is not 1 89 | else: 90 | pos = int(splits[4][1:]) 91 | 92 | plddt = float(splits[-2]) 93 | 94 | if pos not in plddt_dict: 95 | plddt_dict[pos] = [plddt] 96 | else: 97 | plddt_dict[pos].append(plddt) 98 | 99 | plddts = np.array([np.mean(v) for v in plddt_dict.values()]) 100 | return plddts 101 | --------------------------------------------------------------------------------
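For reference, get_struc_seq above can be exercised as in the sketch below. The foldseek binary location and the PDB path are placeholders (neither ships with this repository), and the import assumes the repository root is on sys.path.

from zeroshot.saprot.foldseek_util import get_struc_seq

seq_dict = get_struc_seq(
    foldseek="bin/foldseek",   # path to a locally installed foldseek binary
    path="example.pdb",        # any PDB file containing the chain of interest
    chains=["A"],
    plddt_mask=False,          # set True for predicted structures that store pLDDT in the B-factor column
)

aa_seq, struc_seq, combined_seq = seq_dict["A"]
print(combined_seq)  # amino-acid letters interleaved with lower-case 3Di structure tokens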