├── scripts ├── __init__.py ├── utils_argparse.py ├── benchmark.py ├── generate_legacy.py ├── generate_instruct.py ├── generate_instruct_light.py ├── train_legacy.py └── train_instruct.py ├── figures ├── model.png └── training.png ├── .gitignore ├── dataset ├── __init__.py ├── utils_dataset.py ├── utils_pdb2nx.py ├── nx2pyg.py ├── dataloader_light.py ├── dataloader.py ├── dataloader_derived.py └── dataset.py ├── models ├── __init__.py ├── configuration_esm2rgcn2llama_instruct.py ├── configuration_esm2llama_instruct.py ├── configuration_esm2llama_legacy.py ├── modeling_esm2llama_instruct.py └── modeling_esm2rgcn2llama_instruct.py ├── LICENSE └── README.md /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ColinFX/Prot2Text-V2/HEAD/figures/model.png -------------------------------------------------------------------------------- /figures/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ColinFX/Prot2Text-V2/HEAD/figures/training.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.csv 3 | *.err 4 | *.log 5 | *.ipynb 6 | *.pt 7 | *.pth 8 | *tmp* 9 | .vscode 10 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Prot2TextInstructDataset 2 | from .dataloader import Prot2TextInstructDataLoader 3 | from .dataloader_derived import Prot2TextDerivedDataLoader 4 | from .dataloader_light import Prot2TextLightDataset, Prot2TextLightCollater 5 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_esm2llama_instruct import ModalityAdapterConfig, Esm2LlamaInstructConfig 2 | from .configuration_esm2llama_legacy import EsmEncoderConfig, Esm2LlamaConfig 3 | from .configuration_esm2rgcn2llama_instruct import RgcnAdapterConfig, Esm2Rgcn2LlamaInstructConfig 4 | from .modeling_esm2llama_instruct import ModalityAdapter, Esm2LlamaInstructForCausalLM 5 | from .modeling_esm2llama_legacy import EsmEncoderModel, Esm2LlamaForCausalLM 6 | from .modeling_esm2rgcn2llama_instruct import RgcnAdapter, Esm2Rgcn2LlamaInstructForCausalLM 7 | from .modeling_reward import RewardModel 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Xiao FEI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset/utils_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from graphein.protein.config import DSSPConfig, ProteinGraphConfig 4 | from graphein.protein.edges.distance import ( 5 | add_peptide_bonds, 6 | add_hydrogen_bond_interactions, 7 | add_distance_threshold 8 | ) 9 | from graphein.protein.features.nodes.amino_acid import ( 10 | amino_acid_one_hot, 11 | meiler_embedding, 12 | expasy_protein_scale, 13 | hydrogen_bond_acceptor, 14 | hydrogen_bond_donor 15 | ) 16 | from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure 17 | 18 | 19 | default_graph_process_config = ProteinGraphConfig( 20 | **{ 21 | "node_metadata_functions": [ 22 | amino_acid_one_hot, 23 | expasy_protein_scale, 24 | meiler_embedding, 25 | hydrogen_bond_acceptor, 26 | hydrogen_bond_donor 27 | ], 28 | "edge_construction_functions": [ 29 | add_peptide_bonds, 30 | add_hydrogen_bond_interactions, 31 | partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.), 32 | ], 33 | "graph_metadata_functions": [asa, phi, psi, secondary_structure, rsa], 34 | "dssp_config": DSSPConfig(), 35 | } 36 | ) 37 | -------------------------------------------------------------------------------- /scripts/utils_argparse.py: -------------------------------------------------------------------------------- 1 | """Utilities for argparse configuration.""" 2 | 3 | import torch 4 | 5 | 6 | def str2bool(string: str = "") -> bool: 7 | """ 8 | Override the default built-in bool() function to facilitate argparse configuration. 9 | 10 | Example of usage: 11 | >>> import argparse 12 | >>> from scripts.utils_argparse import str2bool 13 | >>> argParser = argparse.ArgumentParser() 14 | >>> argParser.add_argument( 15 | "--debug_model", 16 | default=False, 17 | type=str2bool, 18 | help="boolean flag to debug model" 19 | ) 20 | """ 21 | negative_keywords = ["false", "no", "none", "negative", "off", "disable", "f", "0"] 22 | if not string or any([string.lower() == keyword for keyword in negative_keywords]): 23 | return False 24 | return True 25 | 26 | 27 | def str2dtype(string: str = "") -> torch.dtype: 28 | """ 29 | Convert a string to the corresponding torch datatype to facilitate argparse 30 | configuration for autocast dtype.
31 | """ 32 | if not string: 33 | return torch.float32 34 | string = string.lower() 35 | bf16_keywords = ["torch.bfloat16", "bfloat16", "bf16"] 36 | fp16_keywords = ["torch.float16", "float16", "fp16", "16", "half"] 37 | int8_keywords = ["torch.int8", "int8", "8"] 38 | int4_keywords = ["torch.int4", "int4", "4"] 39 | if any([string == keyword for keyword in bf16_keywords]): 40 | return torch.bfloat16 41 | elif any([string == keyword for keyword in fp16_keywords]): 42 | return torch.float16 43 | elif any([string == keyword for keyword in int8_keywords]): 44 | return torch.int8 45 | elif any([string == keyword for keyword in int4_keywords]): 46 | return torch.int4 47 | else: 48 | return torch.float32 49 | -------------------------------------------------------------------------------- /models/configuration_esm2rgcn2llama_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration class for the assembled Esm2Rgcn2LlamaInstructForCausalLM model. 3 | 4 | Esm2Rgcn2LlamaInstructConfig = EsmConfig + RgcnAdapterConfig + LlamaConfig 5 | """ 6 | 7 | 8 | from transformers import EsmConfig, LlamaConfig, PretrainedConfig 9 | 10 | 11 | class RgcnAdapterConfig(PretrainedConfig): 12 | """Configuration class of the Relational Graph Convolutional Network adapter.""" 13 | model_type = "rgcn_adapter" # unique identifier of the model 14 | 15 | def __init__( 16 | self, 17 | input_dim: int, 18 | intermediate_dim: int, 19 | output_dim: int, 20 | n_relations: int = 7, 21 | n_layers: int = 6, 22 | dropout_rate: float = 0.2, 23 | **kwargs 24 | ): 25 | super().__init__(**kwargs) 26 | self.input_dim = input_dim 27 | self.intermediate_dim = intermediate_dim 28 | self.output_dim = output_dim 29 | self.n_relations = n_relations 30 | self.n_layers = n_layers 31 | self.dropout_rate = dropout_rate 32 | 33 | 34 | class Esm2Rgcn2LlamaInstructConfig(PretrainedConfig): 35 | """ 36 | Configuration class of Esm2Rgcn2LlamaInstructForCausalLM model. 37 | placeholder_id: Token id in chat template to be replaced by ESM embeddings. 38 | """ 39 | model_type = "esm2rgcn2llama_instruct" # unique identifier of the model 40 | 41 | def __init__( 42 | self, 43 | # model components 44 | esm_config: EsmConfig, 45 | adapter_config: RgcnAdapterConfig, 46 | llama_config: LlamaConfig, 47 | # standalone attributes 48 | placeholder_id: int = 128003, 49 | **kwargs 50 | ): 51 | super().__init__(**kwargs) 52 | self.esm_config = esm_config 53 | self.adapter_config = adapter_config 54 | self.llama_config = llama_config 55 | self.placeholder_id = placeholder_id 56 | -------------------------------------------------------------------------------- /dataset/utils_pdb2nx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for PDB/mmCIF structure handling during protein graph construction.
3 | """ 4 | import numpy as np 5 | from biopandas.pdb import PandasPdb 6 | 7 | 8 | pdb_order = [ 9 | "record_name", 10 | "atom_number", 11 | "blank_1", 12 | "atom_name", 13 | "alt_loc", 14 | "residue_name", 15 | "blank_2", 16 | "chain_id", 17 | "residue_number", 18 | "insertion", 19 | "blank_3", 20 | "x_coord", 21 | "y_coord", 22 | "z_coord", 23 | "occupancy", 24 | "b_factor", 25 | "blank_4", 26 | "segment_id", 27 | "element_symbol", 28 | "charge", 29 | "line_idx", 30 | ] 31 | mmcif_read = { 32 | "group_PDB": "record_name", 33 | "id": "atom_number", 34 | "auth_atom_id": "atom_name", 35 | "auth_comp_id": "residue_name", 36 | "auth_asym_id": "chain_id", 37 | "auth_seq_id": "residue_number", 38 | "Cartn_x": "x_coord", 39 | "Cartn_y": "y_coord", 40 | "Cartn_z": "z_coord", 41 | "occupancy": "occupancy", 42 | "B_iso_or_equiv": "b_factor", 43 | "type_symbol": "element_symbol", 44 | } 45 | 46 | nonefields = [ 47 | "blank_1", 48 | "alt_loc", 49 | "blank_2", 50 | "insertion", 51 | "blank_3", 52 | "blank_4", 53 | "segment_id", 54 | "charge", 55 | "line_idx", 56 | ] 57 | 58 | 59 | def biopandas_mmcif2pdb(pandasmmcif, model_index=1): 60 | """Converts the ATOM and HETATM dataframes of PandasMmcif() to PandasPdb() format.""" 61 | pandaspdb = PandasPdb() 62 | for a in ["ATOM", "HETATM"]: 63 | dfa = pandasmmcif.df[a] 64 | dfa = dfa.loc[dfa.pdbx_PDB_model_num == model_index] 65 | if a == 'ATOM': 66 | if len(dfa) == 0: 67 | raise ValueError(f"No model found for index: {model_index}") 68 | # keep only those fields found in pdb 69 | dfa = dfa[mmcif_read.keys()] 70 | # rename fields 71 | dfa = dfa.rename(columns=mmcif_read) 72 | # add empty fields 73 | for i in nonefields: 74 | dfa[i] = "" 75 | dfa["charge"] = np.nan 76 | # reorder columns to PandasPdb order 77 | dfa = dfa[pdb_order] 78 | pandaspdb.df[a] = dfa 79 | 80 | # update line_idx 81 | pandaspdb.df["ATOM"]["line_idx"] = pandaspdb.df["ATOM"].index.values 82 | pandaspdb.df["HETATM"]["line_idx"] = pandaspdb.df["HETATM"].index 83 | 84 | return pandaspdb 85 | -------------------------------------------------------------------------------- /models/configuration_esm2llama_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration class for the assembled Esm2LlamaInstructForCausalLM model. 3 | 4 | Esm2LlamaInstructConfig = EsmConfig + ModalityAdapterConfig + LlamaConfig 5 | """ 6 | 7 | from typing import Dict, Optional, Union 8 | from transformers import EsmConfig, LlamaConfig, PretrainedConfig 9 | 10 | 11 | class ModalityAdapterConfig(PretrainedConfig): 12 | """Configuration class of the 2-layer non-linear adapter.""" 13 | model_type = "modality_adapter" # unique identifier of the model 14 | 15 | def __init__( 16 | self, 17 | input_dim: int, 18 | intermediate_dim: int, 19 | output_dim: int, 20 | dropout_rate: float = 0.3, 21 | **kwargs 22 | ): 23 | super().__init__(**kwargs) 24 | self.input_dim = input_dim 25 | self.intermediate_dim = intermediate_dim 26 | self.output_dim = output_dim 27 | self.dropout_rate = dropout_rate 28 | 29 | 30 | class Esm2LlamaInstructConfig(PretrainedConfig): 31 | """ 32 | Configuration class of Esm2LlamaInstructForCausalLM model. 33 | placeholder_id: Token id in chat template to be replaced by ESM embeddings. 
34 | """ 35 | model_type = "esm2llama_instruct" # unique identifier of the model 36 | 37 | def __init__( 38 | self, 39 | # model components 40 | esm_config: Optional[Union[EsmConfig, Dict]] = None, 41 | adapter_config: Optional[Union[ModalityAdapterConfig, Dict]] = None, 42 | llama_config: Optional[Union[LlamaConfig, Dict]] = None, 43 | # standalone attributes 44 | placeholder_id: int = 128003, 45 | **kwargs 46 | ): 47 | super().__init__(**kwargs) 48 | 49 | if isinstance(esm_config, dict): 50 | self.esm_config = EsmConfig(**esm_config) 51 | else: 52 | self.esm_config = esm_config 53 | 54 | if isinstance(llama_config, dict): 55 | self.llama_config = LlamaConfig(**llama_config) 56 | else: 57 | self.llama_config = llama_config 58 | 59 | if isinstance(adapter_config, dict): 60 | self.adapter_config = ModalityAdapterConfig(**adapter_config) 61 | else: 62 | self.adapter_config = adapter_config 63 | 64 | self.placeholder_id = placeholder_id 65 | -------------------------------------------------------------------------------- /dataset/nx2pyg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for converting Protein Structure Graphs to standard Data object. 3 | """ 4 | import networkx as nx 5 | import numpy as np 6 | import torch 7 | import torch_geometric 8 | 9 | 10 | graph_features = ['phi', 'psi', 'rsa', 'asa', 'ss', 'expasy'] 11 | map_secondary_structure = {'-': 0, 'H': 1, 'B': 2, 'E': 3, 'G': 4, 'I': 5, 'T': 6, 'S': 7} 12 | map_edge_types = { 13 | 'peptide_bond': 0, 14 | 'sequence_distance_2': 1, 15 | 'sequence_distance_3': 2, 16 | 'distance_threshold': 3, 17 | 'delaunay': 4, 18 | 'hbond': 5, 19 | 'k_nn': 6 20 | } 21 | 22 | 23 | def convert_nx_to_pyg(nx_graph: nx.Graph) -> torch_geometric.data.Data: 24 | """ 25 | Converting Graphein Networks from `networkx.Graph` (nx) format to `torch_geometric.data.Data` (pytorch geometric, 26 | PyG) format. 
27 | """ 28 | # Initialise dict used to construct Data object 29 | data_dict = {"node_id": list(nx_graph.nodes())} 30 | nx_graph = nx.convert_node_labels_to_integers(nx_graph) 31 | 32 | # Construct Edge Index 33 | edge_index = torch.LongTensor(list(nx_graph.edges)).t().contiguous() 34 | 35 | # Add node features 36 | for i, (_, feat_dict) in enumerate(nx_graph.nodes(data=True)): 37 | for key, value in feat_dict.items(): 38 | data_dict[str(key)] = [value] if i == 0 else data_dict[str(key)] + [value] 39 | 40 | # Add edge features 41 | for i, (_, _, feat_dict) in enumerate(nx_graph.edges(data=True)): 42 | for key, value in feat_dict.items(): 43 | if key == 'distance': 44 | data_dict[str(key)] = [value] if i == 0 else data_dict[str(key)] + [value] 45 | else: 46 | data_dict[str(key)] = [list(value)] if i == 0 else data_dict[str(key)] + [list(value)] 47 | 48 | # Add graph-level features 49 | for feat_name in nx_graph.graph: 50 | data_dict[str(feat_name)] = [nx_graph.graph[feat_name]] 51 | 52 | data_dict["edge_index"] = edge_index.view(2, -1) 53 | data = torch_geometric.data.Data.from_dict(data_dict) 54 | data.num_nodes = nx_graph.number_of_nodes() 55 | 56 | # remove useless intermediate data and add features for deep learning models 57 | reformat_data = torch_geometric.data.Data( 58 | edge_index=data.edge_index, 59 | num_nodes=len(data.node_id), 60 | node_id=data.node_id, 61 | name=data.name[0], 62 | sequence=getattr(data, f"sequence_{data.chain_id[0]}"), 63 | distance_matrix=data.dist_mat, 64 | distance=data.distance, 65 | coordinates=torch.FloatTensor(np.array(data.coords[0])) 66 | ) 67 | 68 | x = np.array([np.argmax(data.amino_acid_one_hot, axis=1)]).reshape(-1, 1) 69 | for feat in graph_features: 70 | if feat == "ss": 71 | feature = np.array([[map_secondary_structure.get(feat_node, 0)] for feat_node in data[feat]]) 72 | else: 73 | feature = np.array(data[feat]) 74 | if len(feature.shape) == 1: 75 | feature = feature.reshape(-1, 1) 76 | x = np.concatenate((x, feature), axis=1) 77 | reformat_data.x = torch.FloatTensor(x) 78 | reformat_data.edge_type = torch.LongTensor([map_edge_types[kind[0]] for kind in data.kind]) 79 | 80 | return reformat_data 81 | -------------------------------------------------------------------------------- /models/configuration_esm2llama_legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Legacy configuration classes for Esm2LlamaModel. Migrated from previous projects 3 | on protein function description under a decoder-only base language decoder 4 | structure. 5 | 6 | Esm2LlamaConfig = LlamaConfig (+ EsmEncoderConfig as additional attribute) 7 | """ 8 | 9 | 10 | import os 11 | from typing import Any, Dict, Optional, Union 12 | 13 | from transformers import EsmConfig, LlamaConfig 14 | 15 | 16 | class EsmEncoderConfig(EsmConfig): 17 | """Configuration class of EsmEncoderModel model.""" 18 | 19 | def __init__( 20 | self, 21 | *args, 22 | decoder_hidden_size: Optional[int] = None, 23 | **kwargs 24 | ): 25 | super().__init__(*args, **kwargs) 26 | self.decoder_hidden_size: int = decoder_hidden_size 27 | 28 | 29 | class Esm2LlamaConfig(LlamaConfig): 30 | """ 31 | Configuration class of Esm2LlamaModel model. 32 | 33 | Args: 34 | esm_config: 35 | Configuration to be used with the EsmEncoderConfig encoder 36 | configuration. 
The value is either an instance of EsmEncoderConfig 37 | or a dict of parameters to be passed to initialize the 38 | EsmEncoderConfig (read the documentation from `EsmEncoderConfig` 39 | for more information in this case). If not given, a default 40 | configuration is used. 41 | kwargs: 42 | Keyword arguments for initialization of LlamaConfig, read the 43 | documentation from `LlamaConfig` for more information. Parameters 44 | controlling the model outputs can also be passed, read the 45 | documentation from `PretrainedConfig` for more information. 46 | """ 47 | def __init__( 48 | self, 49 | *args, 50 | esm_config: Optional[Union[EsmEncoderConfig, Dict[str, Any]]] = None, 51 | **kwargs 52 | ): 53 | # normal initialization of LlamaConfig with keyword arguments 54 | super().__init__(*args, **kwargs) 55 | 56 | # add self.esm_config: EsmEncoderConfig as extra attribute to LlamaConfig 57 | if esm_config is None or isinstance(esm_config, dict): 58 | self.esm_config = EsmEncoderConfig(**esm_config if esm_config else {}) 59 | elif isinstance(esm_config, EsmEncoderConfig): 60 | self.esm_config = esm_config 61 | else: 62 | raise ValueError( 63 | "esm_config must be a EsmEncoderConfig, or a dict of " 64 | "initialization parameters. Use from_pretrained method instead " 65 | "if the esm_config shall be loaded from a pretrained model " 66 | "name or path. " 67 | ) 68 | 69 | @classmethod 70 | def from_pretrained( 71 | cls, 72 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, 73 | pretrained_esm_model_name_or_path: Optional[Union[str, os.PathLike]] = None, 74 | pretrained_llama_model_name_or_path: Optional[Union[str, os.PathLike]] = None, 75 | esm_kwargs: Optional[Dict[str, Any]] = None, 76 | **kwargs, 77 | ) -> "Esm2LlamaConfig": 78 | """ 79 | Instantiates a Esm2LlamaConfig from either (1) a pretrained 80 | Esm-to-Llama model, or (2) a pretrained LlamaForCausalLM and/or a 81 | pretrained EsmModel. The configuration of any unspecified parts of 82 | the model will be initialized with default values. 83 | 84 | return_unused_kwargs is currently not supported in the loading 85 | behavior. # TODO complete this case. 86 | 87 | Args: 88 | pretrained_model_name_or_path: 89 | Esm-to-Llama model name or path to load predefined Esm-to-Llama 90 | model configuration. If given, pretrained EsmModel and 91 | LlamaForCausalLM name or path will be ignored. 92 | pretrained_esm_model_name_or_path: 93 | Esm model name or path to load predefined EsmModel configuration 94 | as encoder part of the whole model. 95 | pretrained_llama_model_name_or_path: 96 | Llama model name or path to load predefined LlamaForCausalLM 97 | model configuration as decoder part of the whole model. 98 | esm_kwargs: 99 | Configuration attributes to override values in EsmEncoderConfig 100 | which is either loaded from pretrained or initialized with 101 | default values. Behavior concerning key/value pairs whose keys 102 | are not configuration attributes is controlled by the 103 | return_unused_kwargs keyword parameter. Parameters controlling 104 | the loading behaviors of Esm configuration such as `cache_dir` 105 | and `force_download` can also be passed if 106 | `pretrained_esm_model_name_or_path` is given, read the 107 | documentation from `PretrainedConfig.from_pretrained` for more 108 | information. 109 | kwargs: 110 | Configuration attributes to override the loaded values in 111 | Esm2LlamaConfig or LlamaConfig. 
Parameters controlling the 112 | loading behaviors of Esm-to-Llama or Llama configuration such 113 | as `cache_dir` and `force_download` can also be passed if 114 | `pretrained_model_name_or_path` or 115 | `pretrained_llama_model_name_or_path` is given. 116 | """ 117 | # case (1): instantiate from a pretrained Esm-to-Llama model 118 | if pretrained_model_name_or_path is not None: 119 | config = super().from_pretrained( 120 | pretrained_model_name_or_path=pretrained_model_name_or_path, 121 | **kwargs 122 | ) 123 | if esm_kwargs: 124 | config.esm_config.update(esm_kwargs) 125 | 126 | # case (2-1): instantiate from pretrained LlamaModel and EsmModel 127 | elif ( 128 | pretrained_esm_model_name_or_path is not None 129 | and pretrained_llama_model_name_or_path is not None 130 | ): 131 | config = super().from_pretrained( 132 | pretrained_model_name_or_path=pretrained_llama_model_name_or_path, 133 | **kwargs 134 | ) 135 | config.esm_config = EsmEncoderConfig.from_pretrained( 136 | pretrained_model_name_or_path=pretrained_esm_model_name_or_path, 137 | **esm_kwargs if esm_kwargs else {} 138 | ) 139 | 140 | # case (2-2): instantiate from a pretrained EsmModel 141 | elif pretrained_esm_model_name_or_path is not None: 142 | esm_config = EsmEncoderConfig.from_pretrained( 143 | pretrained_model_name_or_path=pretrained_esm_model_name_or_path, 144 | **esm_kwargs if esm_kwargs else {} 145 | ) 146 | config = cls(esm_config=esm_config, **kwargs) 147 | 148 | # case (2-3): instantiate from a pretrained LlamaModel 149 | elif pretrained_llama_model_name_or_path is not None: 150 | config = super().from_pretrained( 151 | pretrained_model_name_or_path=pretrained_llama_model_name_or_path, 152 | **kwargs 153 | ) 154 | config.esm_config = EsmEncoderConfig(**esm_kwargs if esm_kwargs else {}) 155 | 156 | else: 157 | raise ValueError( 158 | "Either pretrained name or path of Esm-to-Llama model, EsmModel " 159 | "or LlamaForCausalLM should be passed. Use initialization " 160 | "method instead if none of the above three can be provided. " 161 | ) 162 | return config 163 | -------------------------------------------------------------------------------- /scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Single-thread script to compute metrics including BLEU, ROUGE and BERT scores. 3 | Metrics will be computed on JSON files generated by `generate_instruct.py`. 4 | The script is designed for single-GPU computation. 
5 | """ 6 | 7 | import argparse 8 | import json 9 | import os 10 | import re 11 | from typing import Any, Dict, List 12 | 13 | import evaluate 14 | from transformers import BertTokenizer, RobertaTokenizer 15 | import scripts.utils_argparse as utils_argparse 16 | 17 | 18 | argParser = argparse.ArgumentParser() 19 | 20 | argParser.add_argument("--read_generation_dir", type=str) 21 | argParser.add_argument("--read_file_identifier", type=str, help="Postfix identifier or timestamp to filter files.") 22 | 23 | argParser.add_argument("--evaluate_exact_match", type=utils_argparse.str2bool) 24 | argParser.add_argument("--evaluate_bleu", type=utils_argparse.str2bool) 25 | argParser.add_argument("--evaluate_rouge", type=utils_argparse.str2bool) 26 | argParser.add_argument("--evaluate_bert_score", type=utils_argparse.str2bool) 27 | argParser.add_argument("--verbose", type=utils_argparse.str2bool) 28 | 29 | 30 | def compute_exact_match(predictions: List[str], references: List[str]) -> float: 31 | """Compute exact match ratio allowing for case and punctuation differences.""" 32 | def normalize(text: str) -> str: 33 | """Normalize text by lowercasing and removing punctuation.""" 34 | text = text.lower() 35 | text = re.sub(r'[^\w]', '', text) 36 | return text 37 | 38 | exact_match = 0 39 | for pred, ref in zip(predictions, references): 40 | if normalize(pred) == normalize(ref): 41 | exact_match += 1 42 | return exact_match / len(predictions) 43 | 44 | 45 | def compute_bleu2(predictions: List[str], references: List[str]) -> Dict[str, Any]: 46 | bleu = evaluate.load("bleu") 47 | return bleu.compute(predictions=predictions, references=references, max_order=2) 48 | 49 | 50 | def compute_bleu4(predictions: List[str], references: List[str]) -> Dict[str, Any]: 51 | bleu = evaluate.load("bleu") 52 | return bleu.compute(predictions=predictions, references=references) 53 | 54 | 55 | def compute_rouge(predictions: List[str], references: List[str]) -> Dict[str, Any]: 56 | rouge = evaluate.load("rouge") 57 | return rouge.compute(predictions=predictions, references=references) 58 | 59 | 60 | def compute_bert_score(predictions: List[str], references: List[str]) -> Dict[str, Dict[str, Any]]: 61 | """Compute BERT score on roberta-large and biobert-large respectively.""" 62 | results: Dict[str, Dict[str, Any]] = {} 63 | 64 | tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-large") 65 | retokenized_predictions = tokenizer( 66 | predictions, padding="max_length", truncation=True, max_length=495, return_tensors="pt" 67 | )["input_ids"] 68 | truncated_predictions = tokenizer.batch_decode(retokenized_predictions, skip_special_tokens=True) 69 | retokenized_labels = tokenizer( 70 | references, padding="max_length", truncation=True, max_length=495, return_tensors="pt" 71 | )["input_ids"] 72 | truncated_labels = tokenizer.batch_decode(retokenized_labels, skip_special_tokens=True) 73 | 74 | bert = evaluate.load("bertscore") 75 | roberta_results = bert.compute(predictions=truncated_predictions, references=truncated_labels, lang="en") 76 | results["roberta-large"] = { 77 | "precision": sum(roberta_results["precision"]) / len(roberta_results["precision"]), 78 | "recall": sum(roberta_results["recall"]) / len(roberta_results["recall"]), 79 | "f1": sum(roberta_results["f1"]) / len(roberta_results["f1"]) 80 | } 81 | 82 | # truncate sentences to fit max_position_embeddings=512 of biobert 83 | tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-large-cased-v1.1") 84 | retokenized_predictions = tokenizer( 85 | 
predictions, padding="max_length", truncation=True, max_length=495, return_tensors="pt" 86 | )["input_ids"] 87 | truncated_predictions = tokenizer.batch_decode(retokenized_predictions, skip_special_tokens=True) 88 | retokenized_labels = tokenizer( 89 | references, padding="max_length", truncation=True, max_length=495, return_tensors="pt" 90 | )["input_ids"] 91 | truncated_labels = tokenizer.batch_decode(retokenized_labels, skip_special_tokens=True) 92 | 93 | biobert_results = bert.compute( 94 | predictions=truncated_predictions, 95 | references=truncated_labels, 96 | model_type="dmis-lab/biobert-large-cased-v1.1", 97 | num_layers=24, 98 | ) 99 | results["biobert-large"] = { 100 | "precision": sum(biobert_results["precision"]) / len(biobert_results["precision"]), 101 | "recall": sum(biobert_results["recall"]) / len(biobert_results["recall"]), 102 | "f1": sum(biobert_results["f1"]) / len(biobert_results["f1"]) 103 | } 104 | 105 | return results 106 | 107 | 108 | def compute_metrics(predictions: List[str], references: List[str], args: Dict[str, Any]) -> Dict[str, Any]: 109 | """Compute BLEU, ROUGE, BERT scores and exact match ratio on given texts.""" 110 | gathered_results: Dict[str, Dict[str, Any]] = {} 111 | 112 | if args["evaluate_exact_match"]: 113 | exact_match = compute_exact_match(predictions=predictions, references=references) 114 | gathered_results["exact_match"] = exact_match 115 | if args["verbose"]: 116 | print(f"EXACT match ratio: {exact_match}") 117 | 118 | if args["evaluate_bleu"]: 119 | bleu_results = compute_bleu2(predictions=predictions, references=references) 120 | gathered_results["bleu2"] = bleu_results 121 | if args["verbose"]: 122 | print(f"BLEU-2 score: {bleu_results}") 123 | bleu_results = compute_bleu4(predictions=predictions, references=references) 124 | gathered_results["bleu4"] = bleu_results 125 | if args["verbose"]: 126 | print(f"BLEU-4 score: {bleu_results}") 127 | 128 | if args["evaluate_rouge"]: 129 | rouge_results = compute_rouge(predictions=predictions, references=references) 130 | gathered_results["rouge"] = rouge_results 131 | if args["verbose"]: 132 | print(f"ROUGE score: {rouge_results}") 133 | 134 | if args["evaluate_bert_score"]: 135 | bert_results = compute_bert_score(predictions=predictions, references=references) 136 | gathered_results["bert"] = bert_results 137 | if args["verbose"]: 138 | for model_name, model_results in bert_results.items(): 139 | print(f"BERT score with {model_name}: {model_results}") 140 | 141 | return gathered_results 142 | 143 | 144 | def benchmark(args: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: 145 | """ 146 | Evaluate generation results on JSON files produced by `generate_instruct.py`. 147 | 1) Gather predictions and labels from JSON files. 148 | 2) Compute metrics including BLEU, ROUGE and BERT scores and print results.
149 | """ 150 | read_generation_paths = [] 151 | for file_name in os.listdir(args["read_generation_dir"]): 152 | full_path = os.path.join(args["read_generation_dir"], file_name) 153 | if os.path.isfile(full_path) and args["read_file_identifier"] in full_path: 154 | read_generation_paths.append(full_path) 155 | 156 | gathered_predictions = [] 157 | gathered_labels = [] 158 | for read_path in read_generation_paths: 159 | with open(read_path, "r") as file: 160 | results = json.load(file) 161 | local_predictions = [results[name]["pred"] for name in results.keys()] 162 | local_labels = [results[name]["true"] for name in results.keys()] 163 | gathered_predictions.extend(local_predictions) 164 | gathered_labels.extend(local_labels) 165 | print(f"Reading {read_path}") 166 | 167 | return compute_metrics(predictions=gathered_predictions, references=gathered_labels, args=args) 168 | 169 | 170 | if __name__ == "__main__": 171 | parsed_args = argParser.parse_args() 172 | 173 | print("####################") 174 | for key, value in parsed_args.__dict__.items(): 175 | print(f"{key}: {value}") 176 | print("####################") 177 | 178 | benchmark(parsed_args.__dict__) 179 | -------------------------------------------------------------------------------- /scripts/generate_legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | DistributedDataParallel generation script implemented from scratch. 3 | Generation results will be saved to separate JSON files, and metrics can be further computed with `benchmark.py`. 4 | The script is designed for multi-GPU parallelism on single node. 5 | """ 6 | 7 | import argparse 8 | from datetime import datetime 9 | import json 10 | import os 11 | from typing import Any, Dict, List 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch.utils.data.distributed import DistributedSampler 17 | from tqdm import tqdm 18 | from transformers import AutoTokenizer, PreTrainedTokenizer 19 | 20 | from dataset import Prot2TextInstructDataset, Prot2TextDerivedDataLoader 21 | from .train_legacy import load_model, setup, cleanup 22 | 23 | 24 | argParser = argparse.ArgumentParser() 25 | 26 | argParser.add_argument("--esm_path", type=str) 27 | argParser.add_argument("--llama_path", type=str) 28 | argParser.add_argument("--root_dataset_dir", type=str) 29 | argParser.add_argument("--root_csv_dir", type=str) 30 | argParser.add_argument("--save_generation_dir", type=str) 31 | argParser.add_argument("--save_generation_postfix_identifier", type=str, default=None) 32 | argParser.add_argument("--load_general_checkpoint_path", type=str, default="") 33 | 34 | argParser.add_argument("--batch_size_per_device", type=int) 35 | argParser.add_argument("--random_seed", type=int) 36 | argParser.add_argument("--generate_split", type=str) 37 | argParser.add_argument("--debug_trim_generate_split", type=int, default=None) 38 | argParser.add_argument("--max_sequence_length", type=int, default=None) 39 | argParser.add_argument("--max_generation_length", type=int) 40 | argParser.add_argument("--num_beams", type=int, default=1) 41 | argParser.add_argument("--length_penalty", type=float, default=1.0) 42 | 43 | 44 | def iterative_generation_loop( 45 | rank: int, 46 | model: torch.nn.Module, 47 | data_batch: Dict[str, Any], 48 | max_generation_length: int, 49 | num_beams: int, 50 | length_penalty: float 51 | ) -> torch.Tensor: 52 | """ 53 | Standard API for different models. 
Used in `inference_epoch`. 54 | 1) Prepare inputs for the generation cycle with inference using data_batch from dataloader. 55 | 2) Execute the generation cycle and return the direct output. 56 | Returned output is a `torch.Tensor` of the generated tokens. 57 | """ 58 | if isinstance(model, DistributedDataParallel): 59 | model = model.module # for wrapper models, get the inner model for generation 60 | 61 | return model.generate( 62 | inputs=data_batch["input_ids"].to(rank), 63 | attention_mask=data_batch["attention_mask"].to(rank), 64 | protein_input_ids=data_batch["protein_input_ids"].to(rank), 65 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank), 66 | max_new_tokens=max_generation_length, 67 | eos_token_id=128001, 68 | pad_token_id=128002, 69 | return_dict_in_generate=False, 70 | num_beams=num_beams, 71 | length_penalty=length_penalty, 72 | ) 73 | 74 | 75 | def inference_epoch( 76 | rank: int, 77 | model: DistributedDataParallel, 78 | dataloader: Prot2TextDerivedDataLoader, 79 | llama_tokenizer: PreTrainedTokenizer, 80 | args: Dict[str, Any] 81 | ): 82 | """ 83 | Iterate over all batches for inference with iterative loop. 84 | Generation results will be saved to JSON files. 85 | """ 86 | model.eval() 87 | local_names: List[str] = [] 88 | local_predictions: List[str] = [] 89 | local_labels: List[str] = [] 90 | 91 | # core loop for batches 92 | t = tqdm(iter(dataloader)) 93 | for data_batch in t: 94 | with torch.no_grad(): 95 | output = iterative_generation_loop( 96 | rank=rank, 97 | model=model, 98 | data_batch=data_batch, 99 | max_generation_length=args["max_generation_length"], 100 | num_beams=args["num_beams"], 101 | length_penalty=args["length_penalty"] 102 | ) 103 | local_names.extend(data_batch["name"]) 104 | predicted_texts = llama_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True) 105 | local_predictions.extend(predicted_texts) 106 | label_texts = llama_tokenizer.batch_decode(data_batch["description_input_ids"], skip_special_tokens=True) 107 | local_labels.extend(label_texts) 108 | t.set_postfix({ 109 | "mode": "inference", 110 | "batch_maxlen_gen": output.shape[1], 111 | "device": f"rank:{rank}" 112 | }) 113 | 114 | local_json_path = os.path.join( 115 | args["save_generation_dir"], 116 | f"generation_{args['save_generation_postfix_identifier']}_rank{rank}.json" 117 | ) 118 | with open(local_json_path, "w") as file: 119 | json_dict = { 120 | name: {"true": label, "pred": prediction} 121 | for name, label, prediction in zip(local_names, local_labels, local_predictions) 122 | } 123 | json.dump(json_dict, file, indent=4) 124 | print(f"Saving {local_json_path}") 125 | 126 | 127 | 128 | def inference_on_device(rank: int, world_size: int, args: Dict[str, Any]): 129 | """Core generation process for every device with batches over the whole dataset""" 130 | setup(rank, world_size) 131 | 132 | # prepare dataset and dataloader 133 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"]) 134 | llama_tokenizer = AutoTokenizer.from_pretrained(args["llama_path"], pad_token='<|reserved_special_token_0|>') 135 | generate_dataset = Prot2TextInstructDataset( 136 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['generate_split']}"), 137 | csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv"), 138 | sequence_tokenizer=esm_tokenizer, 139 | description_tokenizer=llama_tokenizer, 140 | skip_reload=True, 141 | skip_download=True, 142 | ignore_graph_features=True, 143 | max_sequence_length=args["max_sequence_length"], 144 | 
max_description_length=None, # fetch complete labels for metric computation 145 | ) 146 | if args["debug_trim_generate_split"]: 147 | generate_dataset.usable_file_names = generate_dataset.usable_file_names[:args["debug_trim_generate_split"]] 148 | generate_sampler = DistributedSampler(generate_dataset, rank=rank, num_replicas=world_size, shuffle=False) 149 | generate_loader = Prot2TextDerivedDataLoader( 150 | generate_dataset, 151 | mode="inference", 152 | batch_size=args["batch_size_per_device"], 153 | sampler=generate_sampler, 154 | num_workers=2, 155 | pin_memory=True, 156 | shuffle=False, 157 | drop_last=True, 158 | ) 159 | 160 | # load base model and then the checkpoint 161 | model = load_model( 162 | esm_path=args["esm_path"], 163 | llama_path=args["llama_path"], 164 | load_general_checkpoint_path=args["load_general_checkpoint_path"] 165 | ) 166 | model = model.to(rank) 167 | model = DistributedDataParallel(model) 168 | print(f"Model loaded on rank{rank}") 169 | 170 | inference_epoch( 171 | rank=rank, 172 | model=model, 173 | dataloader=generate_loader, 174 | llama_tokenizer=llama_tokenizer, 175 | args=args 176 | ) 177 | # use a barrier to make sure that all processes have finished writing their JSON files 178 | dist.barrier() 179 | 180 | cleanup() 181 | 182 | 183 | def inference_distributed(args: Dict[str, Any]): 184 | """Core generation process across multiple devices with batches over the whole dataset""" 185 | torch.multiprocessing.spawn( 186 | inference_on_device, 187 | args=(args["world_size"], args), 188 | nprocs=args["world_size"], 189 | join=True 190 | ) 191 | 192 | 193 | if __name__ == "__main__": 194 | # suppress messages from AutoTokenizer parallelism and Graphein respectively 195 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 196 | os.environ["LOGURU_LEVEL"] = "INFO" 197 | 198 | parsed_args = argParser.parse_args() 199 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs 200 | parsed_args.world_size = torch.cuda.device_count() # use up all available devices across nodes 201 | 202 | torch.manual_seed(parsed_args.random_seed) 203 | torch.cuda.manual_seed(parsed_args.random_seed) 204 | 205 | # prepare for saving path 206 | if not os.path.exists(parsed_args.save_generation_dir): 207 | os.makedirs(parsed_args.save_generation_dir) 208 | 209 | start_timestamp = datetime.now().strftime("%y%m%d_%H%M%S") 210 | if parsed_args.save_generation_postfix_identifier: 211 | parsed_args.save_generation_postfix_identifier = ( 212 | f"{start_timestamp}_[{parsed_args.save_generation_postfix_identifier}]" 213 | ) 214 | else: 215 | parsed_args.save_generation_postfix_identifier = start_timestamp 216 | 217 | print("####################") 218 | for key, value in parsed_args.__dict__.items(): 219 | print(f"{key}: {value}") 220 | print("####################") 221 | 222 | # do inference and save to separate JSON files, rank index always starts from zero regardless of CUDA indices 223 | inference_distributed(parsed_args.__dict__) 224 | -------------------------------------------------------------------------------- /scripts/generate_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | DistributedDataParallel generation script implemented from scratch. 3 | Generation results will be saved to separate JSON files, and metrics can be further computed with `benchmark.py`. 4 | The script is designed for multi-GPU parallelism on single node.
5 | """ 6 | 7 | import argparse 8 | from datetime import datetime 9 | import json 10 | import os 11 | from typing import Any, Dict, List 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch.utils.data.distributed import DistributedSampler 17 | from tqdm import tqdm 18 | from transformers import AutoTokenizer, PreTrainedTokenizer 19 | 20 | from dataset import Prot2TextInstructDataset, Prot2TextInstructDataLoader 21 | from .train_instruct import load_model, setup, cleanup 22 | import scripts.utils_argparse as utils_argparse 23 | 24 | 25 | argParser = argparse.ArgumentParser() 26 | 27 | argParser.add_argument("--esm_path", type=str) 28 | argParser.add_argument("--llama_path", type=str) 29 | argParser.add_argument("--root_dataset_dir", type=str) 30 | argParser.add_argument("--root_csv_dir", type=str) 31 | argParser.add_argument("--save_generation_dir", type=str) 32 | argParser.add_argument("--save_generation_postfix_identifier", type=str, default=None) 33 | argParser.add_argument("--load_model_checkpoint_path", type=str, default="") 34 | argParser.add_argument("--load_adapter_checkpoint_dir", type=str, default="") 35 | 36 | argParser.add_argument("--torch_dtype", type=utils_argparse.str2dtype) 37 | argParser.add_argument("--batch_size_per_device", type=int) 38 | argParser.add_argument("--random_seed", type=int) 39 | argParser.add_argument("--generate_split", type=str) 40 | argParser.add_argument("--debug_trim_generate_split", type=int, default=None) 41 | argParser.add_argument("--max_sequence_length", type=int, default=None) 42 | argParser.add_argument("--max_generation_length", type=int) 43 | argParser.add_argument("--num_beams", type=int, default=1) 44 | argParser.add_argument("--length_penalty", type=float, default=1.0) 45 | argParser.add_argument("--temperature", type=float, default=1.0) 46 | argParser.add_argument("--do_sample", type=utils_argparse.str2bool, default=False) 47 | argParser.add_argument("--top_p", type=float, default=1.0) 48 | argParser.add_argument("--top_k", type=int, default=50) 49 | 50 | 51 | def iterative_generation_loop( 52 | rank: int, 53 | model: torch.nn.Module, 54 | data_batch: Dict[str, Any], 55 | max_generation_length: int, 56 | num_beams: int, 57 | length_penalty: float, 58 | temperature: float, 59 | do_sample: bool, 60 | top_p: float, 61 | top_k: int 62 | ) -> torch.Tensor: 63 | """ 64 | Standard API for different models. Used in `inference_epoch`. 65 | 1) Prepare inputs for the generation cycle with inference using data_batch from dataloader. 66 | 2) Execute the generation cycle and return the direct output. 67 | Returned output is a `torch.Tensor` of the generated tokens. 
68 | """ 69 | if isinstance(model, DistributedDataParallel): 70 | model = model.module # for wrapper models, get the inner model for generation 71 | 72 | return model.generate( 73 | inputs=data_batch["input_ids"].to(rank), 74 | attention_mask=data_batch["attention_mask"].to(rank), 75 | protein_input_ids=data_batch["protein_input_ids"].to(rank), 76 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank), 77 | max_new_tokens=max_generation_length, 78 | eos_token_id=128009, 79 | pad_token_id=128002, 80 | return_dict_in_generate=False, 81 | num_beams=num_beams, 82 | length_penalty=length_penalty, 83 | temperature=temperature, 84 | do_sample=do_sample, 85 | top_p=top_p, 86 | top_k=top_k 87 | ) 88 | 89 | 90 | def inference_epoch( 91 | rank: int, 92 | model: DistributedDataParallel, 93 | dataloader: Prot2TextInstructDataLoader, 94 | llama_tokenizer: PreTrainedTokenizer, 95 | args: Dict[str, Any] 96 | ): 97 | """ 98 | Iterate over all batches for inference with iterative loop. 99 | Generation results will be saved to JSON files. 100 | """ 101 | model.eval() 102 | local_names: List[str] = [] 103 | local_predictions: List[str] = [] 104 | local_labels: List[str] = [] 105 | 106 | # core loop for batches 107 | t = tqdm(iter(dataloader)) 108 | for data_batch in t: 109 | with torch.no_grad(): 110 | output = iterative_generation_loop( 111 | rank=rank, 112 | model=model, 113 | data_batch=data_batch, 114 | max_generation_length=args["max_generation_length"], 115 | num_beams=args["num_beams"], 116 | length_penalty=args["length_penalty"], 117 | temperature=args["temperature"], 118 | do_sample=args["do_sample"], 119 | top_p=args["top_p"], 120 | top_k=args["top_k"] 121 | ) 122 | local_names.extend(data_batch["name"]) 123 | predicted_texts = llama_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True) 124 | local_predictions.extend(predicted_texts) 125 | label_texts = llama_tokenizer.batch_decode(data_batch["description_input_ids"], skip_special_tokens=True) 126 | local_labels.extend(label_texts) 127 | t.set_postfix({ 128 | "mode": "inference", 129 | "batch_maxlen_gen": output.shape[1], 130 | "device": f"rank:{rank}" 131 | }) 132 | 133 | local_json_path = os.path.join( 134 | args["save_generation_dir"], 135 | f"generation_{args['save_generation_postfix_identifier']}_rank{rank}.json" 136 | ) 137 | with open(local_json_path, "w") as file: 138 | json_dict = { 139 | name: {"true": label, "pred": prediction} 140 | for name, label, prediction in zip(local_names, local_labels, local_predictions) 141 | } 142 | json.dump(json_dict, file, indent=4) 143 | print(f"Saving {local_json_path}") 144 | 145 | 146 | def inference_on_device(rank: int, world_size: int, args: Dict[str, Any]): 147 | """Core generation process for every device with batches over the whole dataset""" 148 | setup(rank, world_size) 149 | 150 | # prepare dataset and dataloader 151 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"]) 152 | llama_tokenizer = AutoTokenizer.from_pretrained( 153 | args["llama_path"], 154 | pad_token='<|reserved_special_token_0|>' 155 | ) 156 | 157 | generate_dataset = Prot2TextInstructDataset( 158 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['generate_split']}"), 159 | csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv"), 160 | sequence_tokenizer=esm_tokenizer, 161 | description_tokenizer=llama_tokenizer, 162 | skip_reload=True, 163 | skip_download=True, 164 | ignore_graph_features=True, 165 | ) 166 | if args["debug_trim_generate_split"]: 167 | 
generate_dataset.usable_file_names = generate_dataset.usable_file_names[ 168 | :args["debug_trim_generate_split"] 169 | ] 170 | generate_sampler = DistributedSampler( 171 | generate_dataset, 172 | rank=rank, 173 | num_replicas=world_size, 174 | shuffle=False 175 | ) 176 | generate_loader = Prot2TextInstructDataLoader( 177 | generate_dataset, 178 | mode="inference", 179 | batch_size=args["batch_size_per_device"], 180 | sampler=generate_sampler, 181 | num_workers=2, 182 | pin_memory=True, 183 | shuffle=False, 184 | drop_last=True, 185 | ) 186 | 187 | # load base model and then the checkpoint and the adapter 188 | model = load_model(args=args) 189 | 190 | # merge peft adapter for inference 191 | model = model.merge_and_unload() 192 | 193 | model = model.to(rank) 194 | model = DistributedDataParallel(model) 195 | print(f"DDP model loaded on rank{rank}") 196 | 197 | inference_epoch( 198 | rank=rank, 199 | model=model, 200 | dataloader=generate_loader, 201 | llama_tokenizer=llama_tokenizer, 202 | args=args 203 | ) 204 | # use a barrier to make sure that all processes have finished writing their JSON files 205 | dist.barrier() 206 | 207 | cleanup() 208 | 209 | 210 | def inference_distributed(args: Dict[str, Any]): 211 | """Core generation process across multiple devices with batches over the whole dataset""" 212 | torch.multiprocessing.spawn( 213 | inference_on_device, 214 | args=(args["world_size"], args), 215 | nprocs=args["world_size"], 216 | join=True 217 | ) 218 | 219 | 220 | if __name__ == "__main__": 221 | # suppress messages from AutoTokenizer parallelism and Graphein respectively 222 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 223 | os.environ["LOGURU_LEVEL"] = "INFO" 224 | 225 | parsed_args = argParser.parse_args() 226 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs 227 | parsed_args.world_size = torch.cuda.device_count() # use up all available devices across nodes 228 | 229 | torch.manual_seed(parsed_args.random_seed) 230 | torch.cuda.manual_seed(parsed_args.random_seed) 231 | 232 | # prepare for saving path 233 | if not os.path.exists(parsed_args.save_generation_dir): 234 | os.makedirs(parsed_args.save_generation_dir) 235 | 236 | start_timestamp = datetime.now().strftime("%y%m%d_%H%M%S") 237 | if parsed_args.save_generation_postfix_identifier: 238 | parsed_args.save_generation_postfix_identifier = ( 239 | f"{start_timestamp}_[{parsed_args.save_generation_postfix_identifier}]" 240 | ) 241 | else: 242 | parsed_args.save_generation_postfix_identifier = start_timestamp 243 | 244 | print("####################") 245 | for key, value in parsed_args.__dict__.items(): 246 | print(f"{key}: {value}") 247 | print("####################") 248 | 249 | # do inference and save to separate JSON files, rank index always starts from zero regardless cuda indices 250 | inference_distributed(parsed_args.__dict__) 251 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prot2Text-V2: Protein Function Prediction with Multimodal Contrastive Alignment 2 | 3 |
4 | transformers 5 | based-on-esm 6 | based-on-llama 7 | ai4biology 8 | license-mit 9 |
10 | 11 |
12 | | 📃 Preprint | 13 | 📜 NeurIPS | 14 | 🤗 Server | 15 | 🤗 Model | 16 | 🤗 Dataset | 17 |
18 | 19 |
20 | 21 | This is the official repository for the paper "**Prot2Text-V2: Protein Function Prediction with Multimodal Contrastive Alignment**" by Xiao Fei, Michail Chatzianastasis, Sarah Almeida Carneiro, Hadi Abdine, Lawrence P. Petalidis, and Michalis Vazirgiannis. 22 | 23 | We're excited to share that our paper has been accepted to 🎉 **NeurIPS 2025**! An online server, the trained model weights, and the dataset are now publicly available on Hugging Face. 24 | 25 | ## About the Project 26 | 27 | Proteins are written in a code made of amino acids, but what if we could actually read that code like a language? 28 | 29 | **Prot2Text-V2** treats a protein sequence as if it were another language, and then translates it into English. The model takes the raw amino acid sequence as input and generates a clear, human-readable paragraph describing what the protein does. 30 | 31 |
32 | ![Model Architecture](figures/model.png) 33 |
34 | 35 | The instruction-based Prot2Text-V2 model is an innovative fusion of three key components: 36 | 37 | * Protein language model as sequence encoder: `facebook/esm2_t36_3B_UR50D` 38 | * Modality adapter as a unique and lightweight component that bridges the gap between protein embeddings and the language model. 39 | * Natural language decoder for generating articulate textual descriptions utilizing the sequence embeddings: `meta-llama/Llama-3.1-8B-Instruct` 40 | 41 |
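To make the assembly concrete, here is a minimal sketch of how these exported classes could be wired together, assuming the usual Hugging Face pattern of constructing a model from its config; the `intermediate_dim` value is purely illustrative and not taken from the released checkpoints:

```python
from transformers import EsmConfig, LlamaConfig

from models import (
    ModalityAdapterConfig,
    Esm2LlamaInstructConfig,
    Esm2LlamaInstructForCausalLM,
)

# component configs for the encoder and the decoder
esm_config = EsmConfig.from_pretrained("facebook/esm2_t36_3B_UR50D")
llama_config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# the adapter maps ESM hidden states into the LLaMA embedding space
adapter_config = ModalityAdapterConfig(
    input_dim=esm_config.hidden_size,      # 2560 for esm2_t36_3B_UR50D
    intermediate_dim=2048,                 # illustrative bottleneck size
    output_dim=llama_config.hidden_size,   # 4096 for Llama-3.1-8B-Instruct
)

config = Esm2LlamaInstructConfig(
    esm_config=esm_config,
    adapter_config=adapter_config,
    llama_config=llama_config,
    placeholder_id=128003,  # chat-template token replaced by protein embeddings
)
model = Esm2LlamaInstructForCausalLM(config)
```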
42 | ![Training Stages](figures/training.png) 43 |
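The first stage shown above aligns protein and text representations before any description is generated. As a rough, generic illustration of such an objective, a symmetric InfoNCE loss over paired protein/text embeddings could look like the sketch below; the exact loss, pooling, and segmenting used by `train_contrast.py` (see its `--contrastive_num_segments` option) may differ:

```python
import torch
import torch.nn.functional as F

def contrastive_alignment_loss(protein_emb: torch.Tensor,
                               text_emb: torch.Tensor,
                               temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE loss over a batch of paired (protein, text) embeddings."""
    protein_emb = F.normalize(protein_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = protein_emb @ text_emb.t() / temperature  # (batch, batch) cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    # each protein should be closest to its own description, and vice versa
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```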
44 | 45 | A clever alignment step first captures the semantic meaning of the sequence, after which supervised fine-tuning trains the decoder to generate articulate descriptions. 46 | 47 | For backward compatibility, the repository also includes our legacy base model, `Esm2LlamaForCausalLM`, along with its specialized dataloader. 48 | 49 | ## Getting Started 50 | 51 | ✅ Verified on Ubuntu-22.04-LTS with 2 x NVIDIA RTX A6000 52 | 53 | ✅ Verified on RHEL-9.4 with 8 x NVIDIA A100 54 | 55 | * Install NVIDIA `cuda-toolkit=12.1.1`, see the official website for detailed information. 56 | 57 | * Install `dssp=4.0.4` for protein dataset preprocessing: 58 | 59 | ```shell 60 | sudo apt-get install dssp=4.0.4 61 | ``` 62 | 63 | * Create environment with `conda` then install packages with `pip`: 64 | 65 | ```shell 66 | conda create -n prot2text-pip python=3.8 67 | 68 | pip3 install torch torchvision torchaudio # torch==2.3.0 69 | pip3 install torch_geometric 70 | pip3 install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html 71 | 72 | pip3 install graphein==1.7.7 73 | 74 | pip3 install transformers==4.40.2 tokenizers==0.19.1 accelerate==0.29.3 sentencepiece==0.2.0 75 | pip3 install peft==0.10.0 76 | pip3 install biopython==1.81 77 | pip3 install networkx==2.5 78 | pip3 install chardet==5.2.0 charset-normalizer==2.0.4 79 | pip3 install multiprocess==0.70.16 80 | pip3 install tensorboard==2.14.0 81 | pip3 install evaluate==0.4.2 82 | pip3 install mpi4py==3.1.6 83 | 84 | sudo apt install libaio-dev 85 | DS_BUILD_FUSED_ADAM=1 pip3 install deepspeed==0.14.2 86 | 87 | pip3 install nltk==3.8.1 rouge_score==0.1.2 jiwer==3.0.4 88 | ``` 89 | 90 | ## Dataset Preparation 91 | 92 | * Download CSV files from [HuggingFace](https://huggingface.co/datasets/habdine/Prot2Text-Data) and place under `./data`. 93 | 94 | * Download PDB files from AlphaFoldDB (for RGCN only) then preprocess graph and text features: 95 | 96 | ```python 97 | import os 98 | from transformers import AutoTokenizer 99 | from dataset import Prot2TextInstructDataset 100 | SPLIT = "train" # run script for "eval" and "test" as well 101 | CSV_DIR = "./data" 102 | DATA_ROOT_DIR = "/data/Prot2Text-Llama3-Data" 103 | LLAMA_DIR = "meta-llama/Meta-Llama-3.1-8B-Instruct-hf" 104 | ESM_DIR = "facebook/esm2_t36_3B_UR50D" 105 | 106 | split_dataset = Prot2TextInstructDataset( 107 | root_dir=os.path.join(DATA_ROOT_DIR, SPLIT), 108 | csv_path=os.path.join(CSV_DIR, f"{SPLIT}.csv"), 109 | sequence_tokenizer=AutoTokenizer.from_pretrained(ESM_DIR), 110 | description_tokenizer=AutoTokenizer.from_pretrained(LLAMA_DIR, pad_token='<|reserved_special_token_0|>'), 111 | skip_download=False, 112 | skip_reload=False, 113 | ) 114 | ``` 115 | 116 | * [Optional] In case of applying a new language tokenizer to a preprocessed dataset, run the following to avoid processing graphs again: 117 | 118 | ```python 119 | NEW_LLAMA_DIR = "/data/Llama-3.2-1B" 120 | 121 | split_dataset = Prot2TextInstructDataset( 122 | root_dir=os.path.join(DATA_ROOT_DIR, SPLIT), 123 | csv_path=os.path.join(CSV_DIR, f"{SPLIT}.csv"), 124 | sequence_tokenizer=AutoTokenizer.from_pretrained(ESM_DIR), 125 | description_tokenizer=AutoTokenizer.from_pretrained(NEW_LLAMA_DIR, pad_token='<|reserved_special_token_0|>'), 126 | skip_download=True, 127 | skip_reload=True, 128 | ) 129 | split_dataset.process_text() 130 | ``` 131 | 132 | ## Model Training Pipeline 133 | 134 | ### 1.
Contrastive Learning Stage 135 | `./scripts/train_contrast.py` performs contrastive learning to align protein representations with textual descriptions. This stage helps the model learn meaningful cross-modal embeddings. 136 | 137 | **Arguments:** 138 | - **Model Paths:** 139 | - `--esm_path`: Path to pretrained ESM protein language model 140 | - `--llama_path`: Path to pretrained LLaMA language model 141 | - **Data Directories:** 142 | - `--root_dataset_dir`: Root directory containing protein datasets 143 | - `--root_csv_dir`: Directory containing CSV metadata files 144 | - **Checkpoint Handling:** 145 | - `--save_checkpoint_dir`: Directory to save model checkpoints 146 | - `--load_model_checkpoint_path`: Path to load full model checkpoint (optional) 147 | - `--load_optimizer_scheduler_checkpoint_path`: Path to load optimizer/scheduler state (optional) 148 | - **Training Parameters:** 149 | - `--torch_dtype`: PyTorch data type for training (e.g., float16, float32) 150 | - `--batch_size_per_device`: Batch size per GPU/device 151 | - `--num_epochs`: Total number of training epochs 152 | - `--save_every_epochs`: Frequency of checkpoint saving (in epochs) 153 | - `--gradient_accumulation_steps`: Number of steps for gradient accumulation 154 | - `--learning_rate`: Initial learning rate 155 | - `--gradient_clipping`: Gradient clipping value (optional) 156 | - `--scheduler_gamma`: Learning rate scheduler gamma value 157 | - `--random_seed`: Random seed for reproducibility 158 | - `--contrastive_num_segments`: Number of segments for contrastive learning 159 | - **Data Splits:** 160 | - `--train_split`: Name of training split 161 | - `--eval_split`: Name of evaluation split 162 | - `--debug_trim_train_split`: Trim training set for sanity check (optional) 163 | - `--debug_trim_eval_split`: Trim evaluation set for sanity check (optional) 164 | 165 | ### 2. Supervised Fine-Tuning Stage 166 | After contrastive learning, run `./scripts/train_instruct.py` for instruction fine-tuning on the training set. 167 | 168 | **Additional/Modified Arguments:** 169 | - **Adapter Configuration:** 170 | - `--load_adapter_checkpoint_dir`: Directory to load adapter checkpoints 171 | - `--fix_modality_adapter`: Whether to freeze modality adapter weights 172 | - `--lora_rank`: Rank for LoRA adapter layers 173 | - **Text Field Handling:** 174 | - `--include_text_fields`: Whether to include text fields in input 175 | - `--name_dropout`: Dropout rate for protein names 176 | - `--taxonomy_dropout`: Dropout rate for taxonomy information 177 | 178 | ## Performance Evaluation 179 | 180 | ### 1. Generation (`generate_instruct.py`) 181 | Generates answers for proteins in the test set using a trained model. 182 | 183 | **Key Arguments:** 184 | - **Generation Parameters:** 185 | - `--max_generation_length`: Maximum length of generated text 186 | - `--num_beams`: Number of beams for beam search 187 | - `--temperature`: Sampling temperature 188 | - `--do_sample`: Whether to use sampling 189 | - `--top_p`: Nucleus sampling probability 190 | - `--top_k`: Top-k sampling value 191 | - **Output Control:** 192 | - `--save_generation_postfix_identifier`: Identifier for output files 193 | - `--max_sequence_length`: Maximum input sequence length 194 | 195 | ### 2. Benchmarking (`benchmark.py`) 196 | Evaluates generated outputs using various metrics. 
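A typical invocation might look like the following (run from the repository root; the directory and identifier values are placeholders), with the supported flags described below:

```shell
python -m scripts.benchmark \
    --read_generation_dir ./generations \
    --read_file_identifier 250101_120000 \
    --evaluate_exact_match true \
    --evaluate_bleu true \
    --evaluate_rouge true \
    --evaluate_bert_score false \
    --verbose true
```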
197 | 198 | **Evaluation Options:** 199 | - `--evaluate_exact_match`: Compute exact match accuracy 200 | - `--evaluate_bleu`: Compute BLEU scores 201 | - `--evaluate_rouge`: Compute ROUGE scores 202 | - `--evaluate_bert_score`: Compute BERTScore 203 | - `--read_file_identifier`: Filter generated files by this identifier 204 | - `--verbose`: Print detailed evaluation results 205 | 206 | ## Usage Notes: 207 | 1. For the full training pipeline, first run `train_contrast.py`, then `train_instruct.py` 208 | 2. Generation should use the same data splits used during evaluation 209 | 3. Benchmarking can be customized to compute only relevant metrics 210 | 4. Debug arguments allow for faster iteration during development 211 | 212 | The pipeline supports both full fine-tuning and parameter-efficient approaches (LoRA, adapter layers) through the various adapter-related arguments. 213 | 214 | ## Ⓒ Citation 215 | 216 | If you find our research helpful, feel free to 🖋️ cite our work or ⭐️ star the repository: 217 | 218 | ```bibtex 219 | @misc{prot2textv2, 220 | title={Prot2Text-V2: Protein Function Prediction with Multimodal Contrastive Alignment}, 221 | author={Xiao Fei and Michail Chatzianastasis and Sarah Almeida Carneiro and Hadi Abdine and Lawrence P. Petalidis and Michalis Vazirgiannis}, 222 | year={2025}, 223 | eprint={2505.11194}, 224 | archivePrefix={arXiv}, 225 | primaryClass={cs.CE}, 226 | url={https://arxiv.org/abs/2505.11194}, 227 | } 228 | ``` 229 | -------------------------------------------------------------------------------- /scripts/generate_instruct_light.py: -------------------------------------------------------------------------------- 1 | """ 2 | DistributedDataParallel generation script implemented from scratch. 3 | Using Prot2TextLightDataset instead of Prot2TextInstructDataset. 4 | Generation results will be saved to separate JSON files, and metrics can be further computed with `benchmark.py`. 5 | The script is designed for multi-GPU parallelism on single node. 
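Example launch from the repository root (illustrative paths and values; the
script spawns one worker per visible CUDA device by itself, so no external
launcher such as torchrun is required):

    python -m scripts.generate_instruct_light \
        --esm_path facebook/esm2_t36_3B_UR50D \
        --llama_path meta-llama/Meta-Llama-3.1-8B-Instruct-hf \
        --root_csv_dir ./data \
        --save_generation_dir ./generations \
        --torch_dtype bfloat16 --batch_size_per_device 4 --random_seed 42 \
        --generate_split test --max_generation_length 512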
6 | """ 7 | 8 | import argparse 9 | from datetime import datetime 10 | import json 11 | import os 12 | from typing import Any, Dict, List 13 | 14 | import torch 15 | import torch.distributed as dist 16 | from torch.nn.parallel import DistributedDataParallel 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.distributed import DistributedSampler 19 | from tqdm import tqdm 20 | from transformers import AutoTokenizer, PreTrainedTokenizer 21 | 22 | # from dataset import Prot2TextInstructDataset, Prot2TextInstructDataLoader 23 | from dataset import Prot2TextLightDataset, Prot2TextLightCollater 24 | from .train_instruct import load_model, setup, cleanup 25 | import scripts.utils_argparse as utils_argparse 26 | 27 | 28 | argParser = argparse.ArgumentParser() 29 | 30 | argParser.add_argument("--esm_path", type=str) 31 | argParser.add_argument("--llama_path", type=str) 32 | argParser.add_argument("--root_dataset_dir", type=str) 33 | argParser.add_argument("--root_csv_dir", type=str) 34 | argParser.add_argument("--save_generation_dir", type=str) 35 | argParser.add_argument("--save_generation_postfix_identifier", type=str, default=None) 36 | argParser.add_argument("--load_model_checkpoint_path", type=str, default="") 37 | argParser.add_argument("--load_adapter_checkpoint_dir", type=str, default="") 38 | 39 | argParser.add_argument("--torch_dtype", type=utils_argparse.str2dtype) 40 | argParser.add_argument("--batch_size_per_device", type=int) 41 | argParser.add_argument("--random_seed", type=int) 42 | argParser.add_argument("--generate_split", type=str) 43 | argParser.add_argument("--debug_trim_generate_split", type=int, default=None) 44 | argParser.add_argument("--max_description_length", type=int, default=1021) # NEW 45 | argParser.add_argument("--max_sequence_length", type=int, default=512) 46 | argParser.add_argument("--max_generation_length", type=int) 47 | argParser.add_argument("--num_beams", type=int, default=1) 48 | argParser.add_argument("--length_penalty", type=float, default=1.0) 49 | argParser.add_argument("--temperature", type=float, default=1.0) 50 | argParser.add_argument("--do_sample", type=utils_argparse.str2bool, default=False) 51 | argParser.add_argument("--top_p", type=float, default=1.0) 52 | argParser.add_argument("--top_k", type=int, default=50) 53 | 54 | 55 | def iterative_generation_loop( 56 | rank: int, 57 | model: torch.nn.Module, 58 | data_batch: Dict[str, Any], 59 | max_generation_length: int, 60 | num_beams: int, 61 | length_penalty: float, 62 | temperature: float, 63 | do_sample: bool, 64 | top_p: float, 65 | top_k: int 66 | ) -> torch.Tensor: 67 | """ 68 | Standard API for different models. Used in `inference_epoch`. 69 | 1) Prepare inputs for the generation cycle with inference using data_batch from dataloader. 70 | 2) Execute the generation cycle and return the direct output. 71 | Returned output is a `torch.Tensor` of the generated tokens. 
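    Note that because the decoder consumes `inputs_embeds` rather than raw ids,
    the returned tensor contains only the newly generated tokens (roughly of
    shape (batch_size, <= max_generation_length)); the input prompt is not
    repeated in the output.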
72 | """ 73 | if isinstance(model, DistributedDataParallel): 74 | model = model.module # for wrapper models, get the inner model for generation 75 | 76 | return model.generate( 77 | inputs=data_batch["input_ids"].to(rank), 78 | attention_mask=data_batch["attention_mask"].to(rank), 79 | protein_input_ids=data_batch["protein_input_ids"].to(rank), 80 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank), 81 | max_new_tokens=max_generation_length, 82 | eos_token_id=128009, 83 | pad_token_id=128002, 84 | return_dict_in_generate=False, 85 | num_beams=num_beams, 86 | length_penalty=length_penalty, 87 | temperature=temperature, 88 | do_sample=do_sample, 89 | top_p=top_p, 90 | top_k=top_k 91 | ) 92 | 93 | 94 | def inference_epoch( 95 | rank: int, 96 | model: DistributedDataParallel, 97 | # dataloader: Prot2TeqxtInstructDataLoader, 98 | dataloader: DataLoader, 99 | llama_tokenizer: PreTrainedTokenizer, 100 | args: Dict[str, Any] 101 | ): 102 | """ 103 | Iterate over all batches for inference with iterative loop. 104 | Generation results will be saved to JSON files. 105 | """ 106 | model.eval() 107 | local_names: List[str] = [] 108 | local_predictions: List[str] = [] 109 | local_labels: List[str] = [] 110 | 111 | # core loop for batches 112 | t = tqdm(iter(dataloader)) 113 | for data_batch in t: 114 | with torch.no_grad(): 115 | output = iterative_generation_loop( 116 | rank=rank, 117 | model=model, 118 | data_batch=data_batch, 119 | max_generation_length=args["max_generation_length"], 120 | num_beams=args["num_beams"], 121 | length_penalty=args["length_penalty"], 122 | temperature=args["temperature"], 123 | do_sample=args["do_sample"], 124 | top_p=args["top_p"], 125 | top_k=args["top_k"] 126 | ) 127 | local_names.extend(data_batch["name"]) 128 | predicted_texts = llama_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True) 129 | local_predictions.extend(predicted_texts) 130 | label_texts = llama_tokenizer.batch_decode(data_batch["description_input_ids"], skip_special_tokens=True) 131 | local_labels.extend(label_texts) 132 | t.set_postfix({ 133 | "mode": "inference", 134 | "batch_maxlen_gen": output.shape[1], 135 | "device": f"rank:{rank}" 136 | }) 137 | 138 | local_json_path = os.path.join( 139 | args["save_generation_dir"], 140 | f"generation_{args['save_generation_postfix_identifier']}_rank{rank}.json" 141 | ) 142 | with open(local_json_path, "w") as file: 143 | json_dict = { 144 | name: {"true": label, "pred": prediction} 145 | for name, label, prediction in zip(local_names, local_labels, local_predictions) 146 | } 147 | json.dump(json_dict, file, indent=4) 148 | print(f"Saving {local_json_path}") 149 | 150 | 151 | def inference_on_device(rank: int, world_size: int, args: Dict[str, Any]): 152 | """Core generation process for every device with batches over the whole dataset""" 153 | setup(rank, world_size) 154 | 155 | # prepare dataset and dataloader 156 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"]) 157 | llama_tokenizer = AutoTokenizer.from_pretrained( 158 | args["llama_path"], 159 | pad_token='<|reserved_special_token_0|>' 160 | ) 161 | 162 | generate_dataset = Prot2TextLightDataset( 163 | csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv") 164 | ) 165 | 166 | generate_collater = Prot2TextLightCollater( 167 | sequence_tokenizer=esm_tokenizer, 168 | description_tokenizer=llama_tokenizer, 169 | mode="inference", 170 | include_text_fields=True, 171 | name_dropout=0.0, 172 | taxonomy_dropout=0.0, 173 | ) 174 | 175 | 
#generate_dataset = Prot2TextInstructDataset( 176 | # root_dir=os.path.join(args["root_dataset_dir"], f"{args['generate_split']}"), 177 | # csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv"), 178 | # sequence_tokenizer=esm_tokenizer, 179 | # description_tokenizer=llama_tokenizer, 180 | # skip_reload=True, 181 | # skip_download=True, 182 | # ignore_graph_features=True, 183 | #) 184 | # if args["debug_trim_generate_split"]: 185 | # generate_dataset.usable_file_names = generate_dataset.usable_file_names[ 186 | # :args["debug_trim_generate_split"] 187 | # ] 188 | generate_sampler = DistributedSampler( 189 | generate_dataset, 190 | rank=rank, 191 | num_replicas=world_size, 192 | shuffle=False 193 | ) 194 | generate_loader = DataLoader( 195 | generate_dataset, 196 | batch_size=args["batch_size_per_device"], 197 | sampler=generate_sampler, 198 | num_workers=2, 199 | pin_memory=True, 200 | shuffle=False, 201 | drop_last=True, 202 | collate_fn=generate_collater 203 | ) 204 | #generate_loader = Prot2TextInstructDataLoader( 205 | # generate_dataset, 206 | # mode="inference", 207 | # batch_size=args["batch_size_per_device"], 208 | # sampler=generate_sampler, 209 | # num_workers=2, 210 | # pin_memory=True, 211 | # shuffle=False, 212 | # drop_last=True, 213 | #) 214 | 215 | # load base model and then the checkpoint and the adapter 216 | model = load_model(args=args) 217 | 218 | # if rank == 0: 219 | # for name, param in model.named_parameters(): 220 | # print(name, param.requires_grad) 221 | 222 | # merge peft adapter for inference 223 | model = model.merge_and_unload() 224 | model.train() 225 | model.adapter.fc1.requires_grad = True 226 | 227 | model = model.to(rank) 228 | model = DistributedDataParallel(model) 229 | print(f"DDP model loaded on rank{rank}") 230 | 231 | inference_epoch( 232 | rank=rank, 233 | model=model, 234 | dataloader=generate_loader, 235 | llama_tokenizer=llama_tokenizer, 236 | args=args 237 | ) 238 | # use a barrier to make sure that all processes have finished writing their JSON files 239 | dist.barrier() 240 | 241 | cleanup() 242 | 243 | 244 | def inference_distributed(args: Dict[str, Any]): 245 | """Core generation process across multiple devices with batches over the whole dataset""" 246 | torch.multiprocessing.spawn( 247 | inference_on_device, 248 | args=(args["world_size"], args), 249 | nprocs=args["world_size"], 250 | join=True 251 | ) 252 | 253 | 254 | if __name__ == "__main__": 255 | # suppress messages from AutoTokenizer parallelism and Graphein respectively 256 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 257 | os.environ["LOGURU_LEVEL"] = "INFO" 258 | 259 | parsed_args = argParser.parse_args() 260 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs 261 | parsed_args.world_size = torch.cuda.device_count() # use up all available devices across nodes 262 | 263 | torch.manual_seed(parsed_args.random_seed) 264 | torch.cuda.manual_seed(parsed_args.random_seed) 265 | 266 | # prepare for saving path 267 | if not os.path.exists(parsed_args.save_generation_dir): 268 | os.makedirs(parsed_args.save_generation_dir) 269 | 270 | start_timestamp = datetime.now().strftime("%y%m%d_%H%M%S") 271 | if parsed_args.save_generation_postfix_identifier: 272 | parsed_args.save_generation_postfix_identifier = ( 273 | f"{start_timestamp}_[{parsed_args.save_generation_postfix_identifier}]" 274 | ) 275 | else: 276 | parsed_args.save_generation_postfix_identifier = start_timestamp 277 | 278 | print("####################") 279 | 
for key, value in parsed_args.__dict__.items(): 280 | print(f"{key}: {value}") 281 | print("####################") 282 | 283 | # do inference and save to separate JSON files, rank index always starts from zero regardless cuda indices 284 | inference_distributed(parsed_args.__dict__) 285 | -------------------------------------------------------------------------------- /dataset/dataloader_light.py: -------------------------------------------------------------------------------- 1 | """ 2 | Light-weight dataset and data collater class for protein function prediction 3 | instruction tuning. To be used with Esm2LlamaInstructForCausalLM. 4 | 5 | Such flexible implementation is designed to fetch raw text data from a CSV file 6 | and perform tokenization and padding on-the-fly. This is useful when the default 7 | user message and chat template is not suitable for the task at hand. 8 | 9 | Can only be used if the model is not requiring graph-related data. 10 | 11 | Every batch from DataLoader will contain following attributes: 12 | * Training mode (train-eval with teacher-forcing): 13 | - graph related features: 14 | None 15 | - amino-acid sequence: 16 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens 17 | - protein_attention_mask (bsz, max_seq_len+2) # right padding 18 | - concatenated chat: 19 | - input_ids (bsz, max_prompt_len+max_text_len+1) 20 | - attention_mask (bsz, max_prompt_len+max_text_len+1) 21 | - labels (bsz, max_prompt_len+max_text_len+1) 22 | - standalone description for contrastive learning: 23 | - description_input_ids (bsz, max_text_len+1) # eos token only 24 | - description_attention_mask (bsz, max_text_len+1) # right padding 25 | 26 | ids = [left-pad + bos + prompt & description + eot + right-pad] 27 | mask = [0s + 1 + 1s & 1s + 1 + 0s ] 28 | labels = [-100s + -100 + -100s & description + eot + -100s ] 29 | desc_ids = [ & description + eot + right-pad] 30 | desc_mask = [ & 1s + 1 + 0s ] 31 | 32 | * Inference mode (iterative generation): 33 | - graph related features: 34 | None 35 | - amino-acid sequence: 36 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens 37 | - protein_attention_mask (bsz, max_seq_len+2) # right padding 38 | - prompt chat: 39 | - input_ids (bsz, max_prompt_len) 40 | - attention_mask (bsz, max_prompt_len) 41 | - description_input_ids (bsz, max_text_len+1) # for evaluation 42 | 43 | ids = [left-pad + bos + prompt & ] 44 | mask = [0s + 1 + 1s & ] 45 | desc_ids = [ & description + eot + right-pad] 46 | 47 | Example of usage: 48 | >>> from torch.utils.data import DataLoader 49 | >>> from transformers import AutoTokenizer 50 | >>> from dataset import Prot2TextLightDataset, Prot2TextLightCollater 51 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D") 52 | >>> llama_tokenizer = AutoTokenizer.from_pretrained( 53 | "/data/Meta-Llama-3.1-8B-Instruct-hf", 54 | pad_token='<|reserved_special_token_0|>' 55 | ) 56 | >>> train_dataset = Prot2TextLightDataset("./data/train.csv") 57 | >>> train_collater = Prot2TextLightCollater( 58 | sequence_tokenizer=esm_tokenizer, 59 | description_tokenizer=llama_tokenizer, 60 | mode="train" 61 | ) 62 | >>> train_dataloader = DataLoader( 63 | train_dataset, 64 | batch_size=4, 65 | shuffle=True, 66 | num_workers=4, 67 | collate_fn=train_collater, 68 | pin_memory=True, 69 | drop_last=True 70 | ) 71 | """ 72 | 73 | import random 74 | from typing import Dict, List, Literal, Optional 75 | 76 | import pandas as pd 77 | import torch 78 | import torch.utils.data 79 | from transformers import 
PreTrainedTokenizer 80 | 81 | 82 | class Prot2TextLightDataset(torch.utils.data.Dataset): 83 | """Dataset class loading directly from single CSV file.""" 84 | def __init__(self, csv_path: str): 85 | super().__init__() 86 | self.data: pd.DataFrame = pd.read_csv(csv_path) 87 | 88 | def __len__(self) -> int: 89 | return len(self.data) 90 | 91 | def __getitem__(self, idx: int) -> Dict[str, str]: 92 | return { 93 | column_name: self.data.iloc[idx][column_name] 94 | for column_name in self.data.columns 95 | } 96 | 97 | 98 | class Prot2TextLightCollater: 99 | def __init__( 100 | self, 101 | sequence_tokenizer: PreTrainedTokenizer, 102 | description_tokenizer: PreTrainedTokenizer, 103 | mode: Literal["train", "inference"] = "train", 104 | include_text_fields: bool = True, 105 | name_dropout: float = 0.8, 106 | taxonomy_dropout: float = 0.8, 107 | max_sequence_length: Optional[int] = 1021, 108 | max_description_length: Optional[int] = 512, 109 | system_message: str = ( 110 | "You are a scientific assistant specialized in protein function " 111 | "predictions. Given the sequence embeddings and other information " 112 | "of a protein, describe its function clearly and concisely in " 113 | "professional language. " 114 | ), 115 | placeholder_token: str = '<|reserved_special_token_1|>', 116 | ): 117 | self.sequence_tokenizer = sequence_tokenizer 118 | self.description_tokenizer = description_tokenizer 119 | self.mode = mode 120 | 121 | self.include_text_fields = include_text_fields 122 | self.name_dropout = name_dropout 123 | self.taxonomy_dropout = taxonomy_dropout 124 | 125 | self.max_sequence_length = max_sequence_length 126 | self.max_description_length = max_description_length 127 | self.system_message = system_message 128 | self.placeholder_token = placeholder_token 129 | 130 | def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]: 131 | # group data across batch 132 | accessions = [item["AlphaFoldDB"] for item in batch] 133 | fullnames = [item["Full Name"] for item in batch] 134 | taxons = [item["taxon"] for item in batch] 135 | sequences = [item["sequence"] for item in batch] 136 | descriptions = [item["function"] for item in batch] 137 | 138 | # replace nan in name and taxon with unknown 139 | fullnames = [ 140 | fullname 141 | if isinstance(fullname, str) and random.random() > self.name_dropout 142 | else "unknown" 143 | for fullname in fullnames 144 | ] 145 | taxons = [ 146 | taxon 147 | if isinstance(taxon, str) and random.random() > self.taxonomy_dropout 148 | else "unknown" 149 | for taxon in taxons 150 | ] 151 | 152 | # for each sequence in sequences 153 | # if the sequence is origianlly longer than max_sequence_length, take a segment of that length randomly 154 | # else do nothing 155 | for i in range(len(sequences)): 156 | if len(sequences[i]) > self.max_sequence_length: 157 | start = random.randint(0, len(sequences[i]) - self.max_sequence_length) 158 | sequences[i] = sequences[i][start:start + self.max_sequence_length] 159 | 160 | # truncate and tokenize sequences 161 | self.sequence_tokenizer.padding_side = "right" 162 | tokenized_sequences = self.sequence_tokenizer( 163 | sequences, 164 | truncation=True, 165 | padding="longest", 166 | max_length=self.max_sequence_length + 2, # including bos and eos tokens of esm tokenizer 167 | return_tensors="pt" 168 | ) 169 | sequence_input_ids = tokenized_sequences["input_ids"] 170 | sequence_attention_mask = tokenized_sequences["attention_mask"] 171 | 172 | # apply chat template 173 | sequence_lens = 
sequence_attention_mask.sum(dim=1).tolist() 174 | 175 | if self.include_text_fields: 176 | user_messages = [ 177 | # ( 178 | # (f"Protein name: {fullname}; " if fullname != "unknown" else "") 179 | # + (f"Taxon: {taxon}; " if taxon != "unknown" else "") 180 | # + "Sequence embeddings: " + self.placeholder_token * sequence_len 181 | # ) 182 | ( 183 | f"Protein name: {fullname}; Taxon: {taxon}; " 184 | + "Sequence embeddings: " + self.placeholder_token * sequence_len 185 | ) 186 | for fullname, taxon, sequence_len in zip(fullnames, taxons, sequence_lens) 187 | ] 188 | else: 189 | user_messages = [ 190 | "Sequence embeddings: " + self.placeholder_token * sequence_lens 191 | for sequence_lens in sequence_lens 192 | ] 193 | 194 | prompt_conversations = [ 195 | [ 196 | {"role": "system", "content": self.system_message}, 197 | {"role": "user", "content": user_message} 198 | ] 199 | for user_message in user_messages 200 | ] 201 | 202 | # tokenize prompts 203 | self.description_tokenizer.padding_side = "left" 204 | tokenized_prompts = self.description_tokenizer.apply_chat_template( 205 | prompt_conversations, 206 | add_generation_prompt=True, 207 | tokenize=True, 208 | padding="longest", 209 | return_tensors="pt", 210 | return_dict=True 211 | ) 212 | prompt_input_ids = tokenized_prompts["input_ids"] 213 | prompt_attention_mask = tokenized_prompts["attention_mask"] 214 | 215 | # tokenize descriptions 216 | self.description_tokenizer.padding_side = "right" 217 | tokenized_descriptions = self.description_tokenizer( 218 | [description + self.description_tokenizer.eos_token for description in descriptions], 219 | add_special_tokens=False, # do not add bos token to the beginning 220 | truncation=True, 221 | padding="longest", 222 | max_length=self.max_description_length, 223 | return_tensors="pt" 224 | ) 225 | description_input_ids = tokenized_descriptions["input_ids"] 226 | description_attention_mask = tokenized_descriptions["attention_mask"] 227 | 228 | # truncate descriptions 229 | if description_input_ids.size(1) > self.max_description_length: 230 | description_input_ids = description_input_ids[:, :self.max_description_length] 231 | description_attention_mask = description_attention_mask[:, :self.max_description_length] 232 | 233 | # prepare labels 234 | labels = description_input_ids.clone() 235 | labels[description_attention_mask == 0] = -100 236 | 237 | # assemble 238 | if self.mode == "train": 239 | return { 240 | "name": accessions, 241 | "protein_input_ids": sequence_input_ids, 242 | "protein_attention_mask": sequence_attention_mask, 243 | "input_ids": torch.cat([ 244 | prompt_input_ids, 245 | description_input_ids, 246 | ], dim=1), 247 | "attention_mask": torch.cat([ 248 | prompt_attention_mask, 249 | description_attention_mask, 250 | ], dim=1), 251 | "labels": torch.cat([ 252 | torch.full_like( 253 | prompt_input_ids, 254 | fill_value=-100, 255 | ), 256 | labels, 257 | ], dim=1), 258 | "description_input_ids": description_input_ids, 259 | "description_attention_mask": description_attention_mask 260 | } 261 | 262 | elif self.mode == "inference": 263 | return { 264 | "name": accessions, 265 | "protein_input_ids": sequence_input_ids, 266 | "protein_attention_mask": sequence_attention_mask, 267 | "input_ids": prompt_input_ids, 268 | "attention_mask": prompt_attention_mask, 269 | "description_input_ids": description_input_ids, 270 | } 271 | 272 | else: 273 | raise ValueError(f"Invalid mode: {self.mode}") 274 | -------------------------------------------------------------------------------- 
/models/modeling_esm2llama_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration class for the assembled Esm2LlamaInstructForCausalLM model. 3 | 4 | Esm2LlamaInstructForCausalLM = EsmModel + ModalityAdapter + LlamaForCausalLM 5 | 6 | For training/evaluation under teacher-forcing scenario, the model `forward` 7 | function shall take following arguments: 8 | * input_ids: (bsz, prompt_len+description_len) # whole chat template 9 | * attention_mask: (bsz, prompt_len+description_len) # left & right padding 10 | * position_ids: (bsz, prompt_len+description_len) # optional 11 | * past_key_values: None 12 | * labels: (bsz, prompt_len+description_len) # -100 for padding & prompt 13 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds 14 | * protein_attention_mask: (bsz, prot_seq_len) # right padding 15 | * protein_position_ids: (bsz, prot_seq_len) # optional 16 | * protein_head_mask: (num_heads,) or (num_layers, num_heads) # optional 17 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional 18 | * use_cache: False 19 | * return_decoder_inputs: False 20 | 21 | For inference, the model `generate` function shall take following arguments: 22 | * inputs: (bsz, prompt_len) # prompt part of chat template 23 | * attention_mask: (bsz, prompt_len) # left padding 24 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds 25 | * protein_attention_mask: (bsz, prot_seq_len) # right padding 26 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional 27 | """ 28 | 29 | 30 | from typing import Optional, Tuple, Union 31 | 32 | import torch 33 | from transformers import Cache, PreTrainedModel 34 | from transformers.generation.utils import GenerateOutput 35 | from transformers.modeling_outputs import CausalLMOutputWithPast 36 | from transformers.models.esm.modeling_esm import EsmModel 37 | from transformers.models.llama import LlamaForCausalLM 38 | 39 | from .configuration_esm2llama_instruct import ( 40 | ModalityAdapterConfig, 41 | Esm2LlamaInstructConfig 42 | ) 43 | 44 | 45 | class ModalityAdapter(PreTrainedModel): 46 | """2-layer adapter to match the hidden size of different modalities.""" 47 | config_class = ModalityAdapterConfig # configuration class for this model 48 | 49 | def __init__(self, config: ModalityAdapterConfig): 50 | super().__init__(config) 51 | self.config = config 52 | self.fc1 = torch.nn.Linear(config.input_dim, config.intermediate_dim) 53 | self.fc2 = torch.nn.Linear(config.intermediate_dim, config.output_dim) 54 | self.activation = torch.nn.GELU() 55 | self.dropout = torch.nn.Dropout(p=config.dropout_rate) 56 | self.ln1 = torch.nn.LayerNorm(normalized_shape=config.intermediate_dim) # DEPRECATED 57 | self.ln2 = torch.nn.LayerNorm(normalized_shape=config.output_dim) # DEPRECATED 58 | self.post_init() # initialize weights and apply final processing 59 | 60 | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: 61 | # input: (bsz, seq_len, input_dim) 62 | hidden_states = self.activation(self.fc1(hidden_states)) 63 | hidden_states = self.dropout(hidden_states) 64 | # interm: (bsz, seq_len, interm_dim) 65 | hidden_states = self.activation(self.fc2(hidden_states)) 66 | hidden_states = self.dropout(hidden_states) 67 | hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1) 68 | return hidden_states # (bsz, seq_len, output_dim) 69 | 70 | 71 | class Esm2LlamaInstructForCausalLM(PreTrainedModel): 72 | """ 73 | Esm2LlamaInstructForCausalLM model for 
protein function prediction. 74 | Similar to `EncoderDecoderModel` but with more complicated architecture. 75 | Initialize with either a configuration OR all three components. 76 | `kwargs` can override standalone attributes in `Esm2LlamaInstructConfig`. 77 | """ 78 | config_class = Esm2LlamaInstructConfig # configuration class for this model 79 | 80 | def __init__( 81 | self, 82 | config: Optional[Esm2LlamaInstructConfig] = None, 83 | esm_encoder: Optional[EsmModel] = None, 84 | adapter: Optional[ModalityAdapter] = None, 85 | llama_decoder: Optional[LlamaForCausalLM] = None, 86 | **kwargs 87 | ): 88 | if config is not None: # components ignored if config is provided 89 | super().__init__(config) 90 | self.esm_encoder = EsmModel( 91 | config.esm_config, 92 | add_pooling_layer=False 93 | ) 94 | self.adapter = ModalityAdapter(config.adapter_config) 95 | self.llama_decoder = LlamaForCausalLM(config.llama_config) 96 | else: 97 | config = Esm2LlamaInstructConfig( 98 | esm_config=esm_encoder.config, 99 | adapter_config=adapter.config, 100 | llama_config=llama_decoder.config, 101 | **kwargs # override standalone attributes 102 | ) 103 | super().__init__(config) 104 | self.esm_encoder = esm_encoder 105 | self.adapter = adapter 106 | self.llama_decoder = llama_decoder 107 | 108 | def prepare_decoder_inputs( 109 | self, 110 | input_ids: torch.LongTensor, 111 | encoder_hidden_states: torch.FloatTensor, 112 | attention_mask: Optional[torch.LongTensor] = None, 113 | encoder_attention_mask: Optional[torch.LongTensor] = None, 114 | ): 115 | """ 116 | Embed and replace placeholder in `input_ids` by encoder hidden states. 117 | `input_ids` must be passed to locate placeholder for replacement. 118 | """ 119 | # preparation 120 | batch_size, seq_len = input_ids.size() 121 | _, encoder_seq_len, _ = encoder_hidden_states.size() 122 | if attention_mask is None: 123 | attention_mask = torch.ones( 124 | (batch_size, seq_len), 125 | dtype=torch.long, 126 | device=input_ids.device 127 | ) 128 | if encoder_attention_mask is None: 129 | encoder_attention_mask = torch.ones( 130 | (batch_size, encoder_seq_len), 131 | dtype=torch.long, 132 | device=encoder_hidden_states.device 133 | ) 134 | inputs_embeds = self.llama_decoder.get_input_embeddings()(input_ids) 135 | # replacement 136 | placeholder_mask = input_ids == self.config.placeholder_id 137 | encoder_mask = encoder_attention_mask.bool() 138 | inputs_embeds[placeholder_mask] = encoder_hidden_states[encoder_mask] 139 | return inputs_embeds, attention_mask 140 | 141 | def forward( 142 | self, 143 | # chat template text inputs 144 | input_ids: Optional[torch.LongTensor] = None, 145 | attention_mask: Optional[torch.LongTensor] = None, 146 | position_ids: Optional[torch.LongTensor] = None, 147 | past_key_values: Optional[Cache] = None, 148 | labels: Optional[torch.LongTensor] = None, 149 | # protein amino-acid sequence inputs 150 | protein_input_ids: Optional[torch.LongTensor] = None, 151 | protein_attention_mask: Optional[torch.LongTensor] = None, 152 | protein_position_ids: Optional[torch.LongTensor] = None, 153 | protein_head_mask: Optional[torch.LongTensor] = None, 154 | protein_inputs_embeds: Optional[torch.FloatTensor] = None, 155 | # behavior control arguments 156 | use_cache: Optional[bool] = None, 157 | output_attentions: Optional[bool] = None, 158 | output_hidden_states: Optional[bool] = None, 159 | return_dict: Optional[bool] = None, 160 | return_encoder_outputs: bool = False, 161 | return_adapter_outputs: bool = False, 162 | return_decoder_inputs: bool = 
False, 163 | cache_position: Optional[torch.LongTensor] = None 164 | ) -> Union[Tuple, CausalLMOutputWithPast]: 165 | """ 166 | Compute encoder and adapter outputs, then pass to decoder. 167 | `input_ids` is expected to be [prompt + description] in teacher-forcing 168 | scenario and [prompt] only in first iteration of inference (with 169 | return_decoder_inputs=True). 170 | Attention: possible concatenation of the mask and labels should be 171 | handled before calling this method. 172 | `inputs_embeds` not allowed due to placeholder replacement scheme. 173 | """ 174 | # esm_encoder forward 175 | encoder_output = self.esm_encoder( 176 | input_ids=protein_input_ids, 177 | attention_mask=protein_attention_mask, 178 | position_ids=protein_position_ids, 179 | head_mask=protein_head_mask, 180 | inputs_embeds=protein_inputs_embeds, 181 | use_cache=False, # because config.esm_config.is_decoder=False 182 | output_attentions=output_attentions, 183 | output_hidden_states=output_hidden_states, 184 | return_dict=return_dict 185 | ) 186 | encoder_hidden_states = encoder_output[0] 187 | encoder_attention_mask = protein_attention_mask 188 | if return_encoder_outputs: 189 | return encoder_output 190 | # adapter forward 191 | adapter_output = self.adapter(encoder_hidden_states) 192 | if return_adapter_outputs: 193 | return adapter_output, encoder_attention_mask 194 | # decoder input preparation 195 | inputs_embeds, attention_mask = self.prepare_decoder_inputs( 196 | input_ids=input_ids, 197 | encoder_hidden_states=adapter_output, 198 | attention_mask=attention_mask, 199 | encoder_attention_mask=encoder_attention_mask, 200 | ) 201 | if return_decoder_inputs: 202 | return inputs_embeds, attention_mask 203 | # llama_decoder forward 204 | return self.llama_decoder.forward( 205 | input_ids=None, 206 | attention_mask=attention_mask, 207 | position_ids=position_ids, 208 | past_key_values=past_key_values, 209 | inputs_embeds=inputs_embeds, 210 | labels=labels, 211 | use_cache=use_cache, 212 | output_attentions=output_attentions, 213 | return_dict=return_dict, 214 | cache_position=cache_position 215 | ) 216 | 217 | def generate( 218 | self, 219 | inputs: torch.LongTensor, # alias of `input_ids` 220 | attention_mask: Optional[torch.LongTensor] = None, 221 | protein_input_ids: Optional[torch.LongTensor] = None, 222 | protein_attention_mask: Optional[torch.LongTensor] = None, 223 | protein_inputs_embeds: Optional[torch.FloatTensor] = None, 224 | **kwargs 225 | ) -> Union[GenerateOutput, torch.LongTensor]: 226 | """ 227 | Do inference based on given input prompt. 228 | `inputs` is expected to be [prompt] only. 229 | Output will not keep the input prompt due to input in form of embeds. 230 | Generation behavior can be controlled by `args` and `kwargs`, read 231 | `GenerationMixin.generate` for more info. 
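        A minimal illustrative call, with tensor names as placeholders:
        >>> generated_ids = model.generate(
        ...     inputs=prompt_input_ids,
        ...     attention_mask=prompt_attention_mask,
        ...     protein_input_ids=protein_input_ids,
        ...     protein_attention_mask=protein_attention_mask,
        ...     max_new_tokens=256,
        ... )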
232 | """ 233 | # get decoder inputs 234 | prompt_inputs_embeds, prompt_attention_mask = self( 235 | input_ids=inputs, 236 | attention_mask=attention_mask, 237 | protein_input_ids=protein_input_ids, 238 | protein_attention_mask=protein_attention_mask, 239 | protein_inputs_embeds=protein_inputs_embeds, 240 | use_cache=False, 241 | output_attentions=False, 242 | output_hidden_states=False, 243 | return_dict=False, 244 | return_decoder_inputs=True 245 | ) 246 | # do generate on llama_decoder 247 | return self.llama_decoder.generate( 248 | inputs_embeds=prompt_inputs_embeds, 249 | attention_mask=prompt_attention_mask, 250 | **kwargs 251 | ) 252 | 253 | def gradient_checkpointing_enable(self): 254 | """ 255 | Enable gradient checkpointing for all submodules that support it. 256 | Attention! Model need to be in train mode before calling this method. 257 | """ 258 | if hasattr(self.esm_encoder, "gradient_checkpointing_enable"): 259 | self.esm_encoder.gradient_checkpointing_enable() 260 | if hasattr(self.llama_decoder, "gradient_checkpointing_enable"): 261 | self.llama_decoder.gradient_checkpointing_enable() 262 | # simple adapter no need to implement gradient checkpointing 263 | 264 | def gradient_checkpointing_disable(self): 265 | if hasattr(self.esm_encoder, "gradient_checkpointing_disable"): 266 | self.esm_encoder.gradient_checkpointing_disable() 267 | if hasattr(self.llama_decoder, "gradient_checkpointing_disable"): 268 | self.llama_decoder.gradient_checkpointing_disable() 269 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | DataLoader class for protein function prediction instruction tuning. To be used 3 | with Prot2TextInstructDataset for Esm2LlamaInstructForCausalLM. 
4 | 5 | Every batch from DataLoader will contain following attributes: 6 | * Training mode (train-eval with teacher-forcing): 7 | - graph related features: 8 | - x: (sum_num_nodes, num_node_features) 9 | - edge_index: (2, sum_num_edges) 10 | - edge_type: (sum_num_edges,) 11 | - batch: (sum_num_nodes,) 12 | - amino-acid sequence: 13 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens 14 | - protein_attention_mask (bsz, max_seq_len+2) # right padding 15 | - concatenated chat: 16 | - input_ids (bsz, max_prompt_len+max_text_len+1) 17 | - attention_mask (bsz, max_prompt_len+max_text_len+1) 18 | - labels (bsz, max_prompt_len+max_text_len+1) 19 | - standalone description for contrastive learning: 20 | - description_input_ids (bsz, max_text_len+1) # eos token only 21 | - description_attention_mask (bsz, max_text_len+1) # right padding 22 | 23 | ids = [left-pad + bos + prompt & description + eot + right-pad] 24 | mask = [0s + 1 + 1s & 1s + 1 + 0s ] 25 | labels = [-100s + -100 + -100s & description + eot + -100s ] 26 | desc_ids = [ & description + eot + right-pad] 27 | desc_mask = [ & 1s + 1 + 0s ] 28 | 29 | * Inference mode (iterative generation): 30 | - graph related features: 31 | - x: (sum_num_nodes, num_node_features) 32 | - edge_index: (2, sum_num_edges) 33 | - edge_type: (sum_num_edges,) 34 | - batch: (sum_num_nodes,) 35 | - amino-acid sequence: 36 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens 37 | - protein_attention_mask (bsz, max_seq_len+2) # right padding 38 | - prompt chat: 39 | - input_ids (bsz, max_prompt_len) 40 | - attention_mask (bsz, max_prompt_len) 41 | - description_input_ids (bsz, max_text_len+1) # for evaluation 42 | 43 | ids = [left-pad + bos + prompt & ] 44 | mask = [0s + 1 + 1s & ] 45 | desc_ids = [ & description + eot + right-pad] 46 | 47 | Example of usage: 48 | >>> from transformers import AutoTokenizer 49 | >>> from dataset import Prot2TextInstructDataset, Prot2TextInstructDataLoader 50 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D") 51 | >>> llama_tokenizer = AutoTokenizer.from_pretrained( 52 | "/data/Meta-Llama-3.1-8B-Instruct-hf", 53 | pad_token='<|reserved_special_token_0|>' 54 | ) 55 | >>> train_dataset = Prot2TextInstructDataset( 56 | root_dir="/data/Prot2Text-Llama3-Data/train", 57 | csv_path="./data/train.csv", 58 | sequence_tokenizer=esm_tokenizer, 59 | description_tokenizer=llama_tokenizer, 60 | skip_download=True, # assume data is already downloaded 61 | skip_reload=True, # assume data is already preprocessed 62 | ) 63 | >>> train_dataloader = Prot2TextInstructDataLoader( 64 | dataset=train_dataset, 65 | mode="train", 66 | batch_size=2, 67 | shuffle=True, 68 | ) 69 | """ 70 | 71 | 72 | from typing import Dict, List, Literal, Optional, Union 73 | 74 | import torch 75 | import torch.utils.data 76 | import torch_geometric 77 | import torch_geometric.data 78 | import torch_geometric.loader.dataloader 79 | from transformers import PreTrainedTokenizer 80 | 81 | from .dataset import Prot2TextInstructDataset 82 | 83 | 84 | class Prot2TextInstructCollater(torch_geometric.loader.dataloader.Collater): 85 | def __init__( 86 | self, 87 | dataset: Prot2TextInstructDataset, 88 | tokenizer: PreTrainedTokenizer, 89 | mode: Literal["train", "inference"], 90 | **kwargs, 91 | ): 92 | super().__init__(dataset=dataset, **kwargs) 93 | self.tokenizer = tokenizer 94 | self.mode = mode 95 | self.seq_pad_token_id = self.dataset.sequence_tokenizer.pad_token_id 96 | self.text_pad_token_id = tokenizer.pad_token_id 97 | 98 | def 
__call__( 99 | self, 100 | batch: List[Dict[str, Union[str, torch.Tensor]]] 101 | ) -> Dict[str, torch.Tensor]: 102 | # prepare graph related features and name 103 | data_batch = torch_geometric.data.Batch.from_data_list( 104 | batch, 105 | exclude_keys=[ 106 | "sequence_input_ids", 107 | "prompt_input_ids", 108 | "description_input_ids", 109 | ] 110 | ) 111 | 112 | # prepare attn mask, right pad and stack sequences 113 | sequence_input_ids = [data["sequence_input_ids"][0] for data in batch] 114 | pad_sequence_input_ids = self._pad_sequence( 115 | sequence_input_ids, 116 | padding_value=self.seq_pad_token_id, 117 | padding_side="right" 118 | ) 119 | pad_sequence_attention_mask = self._pad_sequence( 120 | [torch.ones_like(data["sequence_input_ids"][0]) for data in batch], 121 | padding_value=0, 122 | padding_side="right" 123 | ) 124 | 125 | # prepare attn mask, left pad and stack prompts 126 | prompt_input_ids = [data["prompt_input_ids"][0] for data in batch] 127 | pad_prompt_input_ids = self._pad_sequence( 128 | prompt_input_ids, 129 | padding_value=self.text_pad_token_id, 130 | padding_side="left" 131 | ) 132 | pad_prompt_attention_mask = self._pad_sequence( 133 | [torch.ones_like(data["prompt_input_ids"][0]) for data in batch], 134 | padding_value=0, 135 | padding_side="left" 136 | ) 137 | 138 | # prepare attn mask, right pad and stack descriptions 139 | description_input_ids = [data["description_input_ids"][0] for data in batch] 140 | pad_description_input_ids = self._pad_sequence( 141 | description_input_ids, 142 | padding_value=self.text_pad_token_id, 143 | padding_side="right" 144 | ) 145 | pad_description_attention_mask = self._pad_sequence( 146 | [torch.ones_like(data["description_input_ids"][0]) for data in batch], 147 | padding_value=0, 148 | padding_side="right" 149 | ) 150 | pad_labels = self._pad_sequence( 151 | description_input_ids, 152 | padding_value=-100, 153 | padding_side="right" 154 | ) 155 | 156 | # update text features 157 | if self.mode == "train": 158 | data_batch.update({ 159 | "input_ids": torch.cat([ 160 | pad_prompt_input_ids, 161 | pad_description_input_ids 162 | ], dim=1), 163 | "attention_mask": torch.cat([ 164 | pad_prompt_attention_mask, 165 | pad_description_attention_mask 166 | ], dim=1), 167 | "labels": torch.cat([ 168 | torch.full_like( 169 | pad_prompt_input_ids, 170 | fill_value=-100 171 | ), 172 | pad_labels 173 | ], dim=1), 174 | "protein_input_ids": pad_sequence_input_ids, 175 | "protein_attention_mask": pad_sequence_attention_mask, 176 | "description_input_ids": pad_description_input_ids, 177 | "description_attention_mask": pad_description_attention_mask, 178 | }) 179 | elif self.mode == "inference": 180 | data_batch.update({ 181 | "input_ids": pad_prompt_input_ids, 182 | "attention_mask": pad_prompt_attention_mask, 183 | "description_input_ids": pad_description_input_ids, 184 | "protein_input_ids": pad_sequence_input_ids, 185 | "protein_attention_mask": pad_sequence_attention_mask, 186 | }) 187 | else: 188 | raise ValueError(f"Invalid mode: {self.mode}") 189 | 190 | # remove excluded keys 191 | if self.exclude_keys: 192 | data_batch = { 193 | k: v for k, v in data_batch.items() 194 | if k not in self.exclude_keys 195 | } 196 | 197 | return data_batch 198 | 199 | @staticmethod 200 | def _pad_sequence( 201 | sequences: List[torch.Tensor], 202 | padding_value: Union[float, int], 203 | padding_side: Literal["left", "right"] = "right", 204 | ) -> torch.Tensor: 205 | """ 206 | Modified version of torch.nn.utils.rnn.pad_sequence with optional 207 | 
padding side. 208 | Such feature is naturally supported by PyTorch 2.6.0+ and it's 209 | recommended to use the built-in version for better efficiency. 210 | 211 | * sequences as input must be a list of 1D tensors. 212 | """ 213 | max_len = max(sequence.shape[-1] for sequence in sequences) 214 | padded_sequences = [] 215 | for sequence in sequences: 216 | padding = torch.full( 217 | size=(max_len - sequence.shape[-1],), 218 | fill_value=padding_value, 219 | dtype=sequence.dtype, 220 | device=sequence.device, 221 | ) 222 | if padding_side == "left": 223 | padded_sequences.append(torch.cat([padding, sequence], dim=-1)) 224 | elif padding_side == "right": 225 | padded_sequences.append(torch.cat([sequence, padding], dim=-1)) 226 | else: 227 | raise ValueError(f"Invalid padding side: {padding_side}") 228 | return torch.stack(padded_sequences, dim=0) 229 | 230 | 231 | class Prot2TextInstructDataLoader(torch.utils.data.DataLoader): 232 | """ 233 | DataLoader class proteins, forming batch inputs. 234 | 235 | (1) Compose graph related features, 236 | (2) dynamically pad sequences, prompts and descriptions, then 237 | (3) stack then concatenate these text features under different modes. 238 | 239 | Args: 240 | dataset: 241 | `Prot2TextInstructDataset` class to load data from. 242 | mode: 243 | - "train": training-evaluation with teacher-forcing. Input ids will 244 | be concatenated with labels (system + user + assistant) for 245 | training. 246 | - "inference": iterative generation. Input ids will only contain 247 | prompt (system + user) for generation. 248 | batch_size: 249 | Number of samples per batch. 250 | shuffle: 251 | Whether to shuffle the data. If a sampler is provided, the shuffling 252 | behavior should be controlled by the sampler and this argument 253 | should be set to False. 254 | follow_batch: 255 | PyG specific feature to be passed to Collater class. When working 256 | with batched graph data, follow_batch is used to indicate which 257 | attributes should an extra batch indices tensor be created for. If 258 | not set, an extra `batch` tensor will be automatically created to 259 | indicate which graph each node belongs to. 260 | exclude_keys: 261 | List of keys to exclude from the batch. The exclusion will be applied 262 | at the very end of the collation process. 263 | kwargs: 264 | Additional arguments to pass to DataLoader class. 
265 | """ 266 | def __init__( 267 | self, 268 | dataset: Prot2TextInstructDataset, 269 | mode: Literal["train", "inference"] = "train", 270 | batch_size: int = 1, 271 | shuffle: bool = True, 272 | follow_batch: Optional[List[str]] = None, 273 | exclude_keys: Optional[List[str]] = None, 274 | **kwargs, 275 | ): 276 | # override collate_fn 277 | kwargs.pop("collate_fn", None) 278 | self.follow_batch = follow_batch 279 | self.exclude_keys = exclude_keys 280 | collater = Prot2TextInstructCollater( 281 | dataset=dataset, 282 | tokenizer=dataset.description_tokenizer, 283 | mode=mode, 284 | follow_batch=follow_batch, 285 | exclude_keys=exclude_keys 286 | ) 287 | super().__init__( 288 | dataset, 289 | batch_size=batch_size, 290 | shuffle=shuffle, 291 | collate_fn=collater, 292 | **kwargs 293 | ) 294 | 295 | def __len__(self): 296 | """Mark class as Sized.""" 297 | return super().__len__() 298 | 299 | def __iter__(self): 300 | """Mark class as Iterable.""" 301 | return super().__iter__() 302 | -------------------------------------------------------------------------------- /dataset/dataloader_derived.py: -------------------------------------------------------------------------------- 1 | """ 2 | Derived DataLoader class for protein function prediction supervised fine tuning. 3 | To be used with Prot2TextInstructDataset for Esm2LlamaForCausalLM. 4 | 5 | The `Prot2TextInstructDataset` is designed to be preprocessed with an instruct 6 | model (ex. `Meta-Llama-3.1-8B-Instruct-hf`), but can be adapted to other base 7 | models (ex. `Llama-3.2-1B`) if tokenizers of both models share the same 8 | vocabulary. 9 | 10 | This derived version (`Prot2TextDerivedDataLoader`) is thus designed to adapt 11 | datasets that are already preprocessed for the instruct model. It replaces 12 | special tokens and reorganize tokenized input ids, making them suitable for the 13 | base language model. 
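For example, assuming the dataset was preprocessed with the Meta-Llama-3.1-Instruct
tokenizer, the end-of-turn token <|eot_id|> (id 128009) stored in the cached
description ids is rewritten to the eos token of the base tokenizer at collation
time, and the chat-template prompt is replaced by a plain prompt sentence wrapped
between bos tokens.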
14 | 15 | Every batch from DataLoader will contain following attributes: 16 | * Training mode (train-eval with teacher-forcing): 17 | - graph related features: 18 | - x: (sum_num_nodes, num_node_features) 19 | - edge_index: (2, sum_num_edges) 20 | - edge_type: (sum_num_edges,) 21 | - batch: (sum_num_nodes,) 22 | - amino-acid sequence: 23 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens 24 | - protein_attention_mask (bsz, max_seq_len+2) # left padding 25 | - concatenated prompt and description: 26 | - input_ids (bsz, prompt_len+1+max_text_len+1) # bos and eos tokens 27 | - attention_mask (bsz, prompt_len+1+max_text_len+1) # right padding 28 | - labels (bsz, prompt_len+1+max_text_len+1) 29 | - standalone description for reward model training: 30 | - description_input_ids (bsz, max_text_len+1) # eos token only 31 | - description_attention_mask (bsz, max_text_len+1) # right padding 32 | 33 | ids = [bos + prompt + bos & description + eos + right-pad] 34 | mask = [1 + 1s + 1 & 1s + 1 + 0s ] 35 | labels = [-100 + -100s + -100 & description + eos + -100s ] 36 | desc_ids = [ & description + eos + right-pad] 37 | desc_mask = [ & 1s + 1 + 0s ] 38 | 39 | * Inference mode (iterative generation): 40 | - graph related features: 41 | - x: (sum_num_nodes, num_node_features) 42 | - edge_index: (2, sum_num_edges) 43 | - edge_type: (sum_num_edges,) 44 | - batch: (sum_num_nodes,) 45 | - amino-acid sequence: 46 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens 47 | - protein_attention_mask (bsz, max_seq_len+2) # left padding 48 | - prompt with inference head: 49 | - input_ids (bsz, prompt_len+1) # bos token at the end 50 | - attention_mask (bsz, prompt_len+1) 51 | - standalone description for evaluation: 52 | - description_input_ids (bsz, max_text_len+1) 53 | - description_attention_mask (bsz, max_text_len+1) 54 | 55 | ids = [bos + prompt + bos & ] 56 | mask = [1 + 1s + 1 & ] 57 | desc_ids = [ & description + eos + right-pad] 58 | desc_mask = [ & 1s + 1 + 0s ] 59 | 60 | Example of usage: 61 | >>> from transformers import AutoTokenizer 62 | >>> from dataset import Prot2TextInstructDataset, Prot2TextDerivedDataLoader 63 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D") 64 | >>> llama_tokenizer = AutoTokenizer.from_pretrained( 65 | "/data/Llama-3.2-1B", 66 | pad_token='<|reserved_special_token_0|>' 67 | ) 68 | >>> train_dataset = Prot2TextInstructDataset( 69 | root_dir="/data/Prot2Text-Llama3-Data/train", 70 | csv_path="./data/train.csv", 71 | sequence_tokenizer=esm_tokenizer, 72 | description_tokenizer=llama_tokenizer, # pass the base model tokenizer 73 | skip_download=True, # assume data is already downloaded 74 | skip_reload=True, # assume data is already preprocessed 75 | ) 76 | >>> train_dataloader = Prot2TextDerivedDataLoader( 77 | dataset=train_dataset, 78 | mode="train", 79 | batch_size=2, 80 | shuffle=True, 81 | ) 82 | """ 83 | 84 | 85 | from typing import Dict, List, Literal, Optional, Union 86 | 87 | import torch 88 | import torch.utils.data 89 | import torch_geometric 90 | import torch_geometric.data 91 | import torch_geometric.loader.dataloader 92 | from transformers import PreTrainedTokenizer 93 | 94 | from .dataset import Prot2TextInstructDataset 95 | 96 | 97 | class Prot2TextDerivedCollater(torch_geometric.loader.dataloader.Collater): 98 | def __init__( 99 | self, 100 | dataset: Prot2TextInstructDataset, 101 | tokenizer: PreTrainedTokenizer, 102 | mode: Literal["train", "inference"], 103 | original_eos_token_id: int, 104 | prompt_sentence: 
str, 105 | **kwargs, 106 | ): 107 | super().__init__(dataset=dataset, **kwargs) 108 | self.tokenizer = tokenizer 109 | self.mode = mode 110 | self.prompt_sentence = prompt_sentence 111 | 112 | self.prompt_input_ids = tokenizer( 113 | [tokenizer.bos_token + prompt_sentence + tokenizer.bos_token], 114 | add_special_tokens=False, 115 | return_tensors="pt", 116 | return_attention_mask=False, 117 | )["input_ids"] 118 | 119 | self.seq_pad_token_id = dataset.sequence_tokenizer.pad_token_id 120 | self.text_pad_token_id = tokenizer.pad_token_id 121 | self.old_text_eos_token_id = original_eos_token_id 122 | self.new_text_eos_token_id = tokenizer.eos_token_id 123 | 124 | def __call__( 125 | self, 126 | batch: List[Dict[str, Union[str, torch.Tensor]]] 127 | ) -> Dict[str, torch.Tensor]: 128 | # prepare graph related features and name 129 | data_batch = torch_geometric.data.Batch.from_data_list( 130 | batch, 131 | exclude_keys=[ 132 | "sequence_input_ids", 133 | "prompt_input_ids", 134 | "description_input_ids", 135 | ] 136 | ) 137 | 138 | # prepare attn mask, left pad and stack sequences 139 | pad_sequence_input_ids = self._pad_sequence( 140 | [data["sequence_input_ids"][0] for data in batch], 141 | padding_value=self.seq_pad_token_id, 142 | padding_side="left" 143 | ) 144 | pad_sequence_attention_mask = self._pad_sequence( 145 | [torch.ones_like(data["sequence_input_ids"][0]) for data in batch], 146 | padding_value=0, 147 | padding_side="left" 148 | ) 149 | 150 | # prepare attn mask and expand prompts 151 | pad_prompt_input_ids = self.prompt_input_ids.repeat(len(batch), 1).to( 152 | pad_sequence_input_ids.device 153 | ) 154 | pad_prompt_attention_mask = torch.ones_like(pad_prompt_input_ids) 155 | 156 | # prepare attn mask, right pad and stack descriptions 157 | description_input_ids = [data["description_input_ids"][0] for data in batch] 158 | pad_description_input_ids = self._pad_sequence( 159 | description_input_ids, 160 | padding_value=self.text_pad_token_id, 161 | padding_side="right" 162 | ) 163 | pad_description_attention_mask = self._pad_sequence( 164 | [torch.ones_like(data["description_input_ids"][0]) for data in batch], 165 | padding_value=0, 166 | padding_side="right" 167 | ) 168 | pad_labels = self._pad_sequence( 169 | description_input_ids, 170 | padding_value=-100, 171 | padding_side="right" 172 | ) 173 | 174 | # replace special tokens in descriptions 175 | pad_description_input_ids.masked_fill_( 176 | pad_description_input_ids == self.old_text_eos_token_id, 177 | self.new_text_eos_token_id 178 | ) 179 | pad_labels.masked_fill_( 180 | pad_labels == self.old_text_eos_token_id, 181 | self.new_text_eos_token_id 182 | ) 183 | 184 | # decode back the descriptions in text 185 | descriptions = self.tokenizer.batch_decode( 186 | description_input_ids, 187 | skip_special_tokens=True, 188 | ) 189 | 190 | # update text features 191 | if self.mode == "train": 192 | data_batch.update({ 193 | "input_ids": torch.cat([ 194 | pad_prompt_input_ids, 195 | pad_description_input_ids 196 | ], dim=1), 197 | "attention_mask": torch.cat([ 198 | pad_prompt_attention_mask, 199 | pad_description_attention_mask 200 | ], dim=1), 201 | "labels": torch.cat([ 202 | torch.full_like( 203 | pad_prompt_input_ids, 204 | fill_value=-100 205 | ), 206 | pad_labels 207 | ], dim=1), 208 | "protein_input_ids": pad_sequence_input_ids, 209 | "protein_attention_mask": pad_sequence_attention_mask, 210 | "description_input_ids": pad_description_input_ids, 211 | "description_attention_mask": pad_description_attention_mask, 212 | 
"descriptions": descriptions, 213 | }) 214 | elif self.mode == "inference": 215 | data_batch.update({ 216 | "input_ids": pad_prompt_input_ids, 217 | "attention_mask": pad_prompt_attention_mask, 218 | "description_input_ids": pad_description_input_ids, 219 | "description_attention_mask": pad_description_attention_mask, 220 | "protein_input_ids": pad_sequence_input_ids, 221 | "protein_attention_mask": pad_sequence_attention_mask, 222 | }) 223 | else: 224 | raise ValueError(f"Invalid mode: {self.mode}") 225 | 226 | return data_batch 227 | 228 | @staticmethod 229 | def _pad_sequence( 230 | sequences: List[torch.Tensor], 231 | padding_value: Union[float, int], 232 | padding_side: Literal["left", "right"] = "right", 233 | ) -> torch.Tensor: 234 | """ 235 | Modified version of torch.nn.utils.rnn.pad_sequence with optional 236 | padding side. 237 | Such feature is naturally supported by PyTorch 2.6.0+ and it's 238 | recommended to use the built-in version for better efficiency. 239 | 240 | * sequences as input must be a list of 1D tensors. 241 | """ 242 | max_len = max(sequence.shape[-1] for sequence in sequences) 243 | padded_sequences = [] 244 | for sequence in sequences: 245 | padding = torch.full( 246 | size=(max_len - sequence.shape[-1],), 247 | fill_value=padding_value, 248 | dtype=sequence.dtype, 249 | device=sequence.device, 250 | ) 251 | if padding_side == "left": 252 | padded_sequences.append(torch.cat([padding, sequence], dim=-1)) 253 | elif padding_side == "right": 254 | padded_sequences.append(torch.cat([sequence, padding], dim=-1)) 255 | else: 256 | raise ValueError(f"Invalid padding side: {padding_side}") 257 | return torch.stack(padded_sequences, dim=0) 258 | 259 | 260 | class Prot2TextDerivedDataLoader(torch.utils.data.DataLoader): 261 | """ 262 | DataLoader class proteins, forming batch inputs. 263 | 264 | (1) Compose graph related features with PyG's Batch.from_data_list; 265 | (2) dynamically pad sequences, prompts and descriptions; 266 | (3) replace special tokens and reorganize tokenized input ids; 267 | (4) stack then concatenate these text features under different modes. 268 | 269 | Args: 270 | dataset: 271 | `Prot2TextInstructDataset` class to load data from. Both 272 | `Prot2TextInstructDataLoader` and `Prot2TextDerivedDataLoader` should 273 | use the same dataset, but with different sequence tokenizers. Read 274 | the docstring of `Prot2TextInstructDataset` for more details. 275 | mode: 276 | - "train": training-evaluation with teacher-forcing. Input ids will 277 | be concatenated with labels (prompt + description) for 278 | training. 279 | - "inference": iterative generation. Input ids will only contain 280 | prompt (prompt + bos) for generation. 281 | batch_size: 282 | Number of samples per batch. 283 | shuffle: 284 | Whether to shuffle the data. If a sampler is provided, the shuffling 285 | behavior should be controlled by the sampler and this argument 286 | should be set to False. 287 | follow_batch: 288 | PyG specific feature to be passed to Collater class. When working 289 | with batched graph data, follow_batch is used to indicate which 290 | attributes should an extra batch indices tensor be created for. If 291 | not set, an extra `batch` tensor will be automatically created to 292 | indicate which graph each node belongs to. 293 | exclude_keys: 294 | List of keys to exclude from the batch. The exclusion will be applied 295 | at the very end of the collation process. 
296 | original_eos_token_id: 297 | End-of-sequence token id of the instruct model tokenizer that is used 298 | in the preprocessing stage of the dataset. Such token will be 299 | replaced by the eos token id of the base model tokenizer that is 300 | given in the derived scenario. 301 | prompt_sentence: 302 | Prompt sentence to be used in the derived scenario. 303 | kwargs: 304 | Additional arguments to pass to DataLoader class. 305 | """ 306 | def __init__( 307 | self, 308 | dataset: Prot2TextInstructDataset, 309 | mode: Literal["train", "inference"] = "train", 310 | batch_size: int = 1, 311 | shuffle: bool = True, 312 | follow_batch: Optional[List[str]] = None, 313 | exclude_keys: Optional[List[str]] = None, 314 | original_eos_token_id: int = 128009, 315 | prompt_sentence: str = ( 316 | "Predict protein description based on the amino-acid sequence embeddings." 317 | ), 318 | **kwargs, 319 | ): 320 | # override collate_fn 321 | kwargs.pop("collate_fn", None) 322 | self.follow_batch = follow_batch 323 | self.exclude_keys = exclude_keys 324 | collater = Prot2TextDerivedCollater( 325 | dataset=dataset, 326 | tokenizer=dataset.description_tokenizer, 327 | mode=mode, 328 | follow_batch=follow_batch, 329 | exclude_keys=exclude_keys, 330 | original_eos_token_id=original_eos_token_id, 331 | prompt_sentence=prompt_sentence, 332 | ) 333 | super().__init__( 334 | dataset, 335 | batch_size=batch_size, 336 | shuffle=shuffle, 337 | collate_fn=collater, 338 | **kwargs 339 | ) 340 | 341 | def __len__(self): 342 | """Mark class as Sized.""" 343 | return super().__len__() 344 | 345 | def __iter__(self): 346 | """Mark class as Iterable.""" 347 | return super().__iter__() 348 | -------------------------------------------------------------------------------- /scripts/train_legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | FullyShardedDataParallel / DistributedDataParallel training script implemented 3 | from scratch. 4 | 5 | The script currently supports gradient accumulation, AutoMixedPrecision, 6 | and inter-epoch evaluation. 7 | 8 | The script currently does not support save/load pretrained or gradient checkpointing. 9 | 10 | reference for FSDP: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html 11 | reference for AMP: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html 12 | 13 | * The script is designed for multi-GPU parallelism on single node. 14 | * On the cluster, print(...) will go to stdout and tqdm(...) will go to stderr. 
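Example launch from the repository root (illustrative values; model paths,
dataset locations and hyperparameters are placeholders):

    python -m scripts.train_legacy \
        --esm_path facebook/esm2_t33_650M_UR50D \
        --llama_path /data/Llama-3.2-1B \
        --root_dataset_dir /data/Prot2Text-Llama3-Data --root_csv_dir ./data \
        --save_checkpoint_dir ./checkpoints/legacy \
        --wrap_model True --autocast_dtype bfloat16 \
        --batch_size_per_device 2 --num_epochs 10 --save_every_epochs 1 \
        --gradient_accumulation_steps 8 --learning_rate 1e-4 \
        --scheduler_gamma 0.9 --random_seed 42 \
        --train_split train --eval_split eval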
15 | """ 16 | 17 | import argparse 18 | from datetime import datetime 19 | import functools 20 | import os 21 | from typing import Any, Dict, Optional, Tuple, Union 22 | 23 | import torch 24 | from torch.amp import GradScaler 25 | from torch.cuda.amp import autocast 26 | import torch.distributed as dist 27 | from torch.distributed.fsdp import ( 28 | FullyShardedDataParallel, 29 | FullStateDictConfig, 30 | StateDictType 31 | ) 32 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy 33 | from torch.nn.parallel import DistributedDataParallel 34 | from torch.optim import AdamW, Optimizer 35 | from torch.optim.lr_scheduler import StepLR 36 | from torch.utils.data.distributed import DistributedSampler 37 | from tqdm import tqdm 38 | from transformers import AutoTokenizer, PreTrainedModel 39 | 40 | from dataset import Prot2TextInstructDataset, Prot2TextDerivedDataLoader 41 | from models import Esm2LlamaForCausalLM 42 | import scripts.utils_argparse as utils_argparse 43 | 44 | 45 | argParser = argparse.ArgumentParser() 46 | 47 | argParser.add_argument("--esm_path", type=str) 48 | argParser.add_argument("--llama_path", type=str) 49 | argParser.add_argument("--root_dataset_dir", type=str) 50 | argParser.add_argument("--root_csv_dir", type=str) 51 | argParser.add_argument("--save_checkpoint_dir", type=str) 52 | argParser.add_argument("--load_general_checkpoint_path", type=str, default="") 53 | 54 | argParser.add_argument("--wrap_model", type=utils_argparse.str2bool) 55 | argParser.add_argument("--autocast_dtype", type=utils_argparse.str2dtype) 56 | argParser.add_argument("--batch_size_per_device", type=int) 57 | argParser.add_argument("--num_epochs", type=int) 58 | argParser.add_argument("--save_every_epochs", type=int) 59 | argParser.add_argument("--gradient_accumulation_steps", type=int) 60 | argParser.add_argument("--learning_rate", type=float) 61 | argParser.add_argument("--gradient_clipping", type=float, default=None) 62 | argParser.add_argument("--scheduler_gamma", type=float) 63 | argParser.add_argument("--random_seed", type=int) 64 | argParser.add_argument("--train_split", type=str) 65 | argParser.add_argument("--eval_split", type=str) 66 | argParser.add_argument("--debug_trim_train_split", type=int, default=None) 67 | argParser.add_argument("--debug_trim_eval_split", type=int, default=None) 68 | argParser.add_argument("--max_sequence_length", type=int, default=None) 69 | argParser.add_argument("--max_description_length", type=int, default=None) 70 | 71 | 72 | def load_model( 73 | esm_path: Union[str, os.PathLike], 74 | llama_path: Union[str, os.PathLike], 75 | load_general_checkpoint_path: Optional[Union[str, os.PathLike]] = None 76 | ) -> PreTrainedModel: 77 | """ 78 | Standard API for different models. Used in both `train` and `generate`. 79 | Load base model of the given name, and load weights from the checkpoint path if provided. 80 | Returned model is on CPU by default. 81 | `load_general_checkpoint_path` will be ignored if `load_checkpoint_path` is provided. 
82 | """ 83 | model = Esm2LlamaForCausalLM.from_pretrained( 84 | pretrained_esm_model_name_or_path=esm_path, 85 | pretrained_llama_model_name_or_path=llama_path, 86 | esm_kwargs={"decoder_hidden_size": 2048} 87 | ) 88 | 89 | # load checkpoint if any 90 | if load_general_checkpoint_path: 91 | print(f"Loading {load_general_checkpoint_path}") 92 | checkpoint_state_dicts = torch.load(load_general_checkpoint_path, weights_only=True) 93 | model_state_dict = checkpoint_state_dicts["model_state_dict"] 94 | model.load_state_dict(model_state_dict) 95 | 96 | return model 97 | 98 | 99 | def teacher_forcing_forward_pass( 100 | rank: int, 101 | model: torch.nn.Module, 102 | data_batch: Dict[str, Any], 103 | ) -> Tuple[torch.Tensor]: 104 | """ 105 | Standard API for different models. Used in both `train_epoch` and `eval_epoch`. 106 | 1) Prepare inputs for the forward pass with teacher forcing using data_batch from dataloader. 107 | 2) Execute the forward pass and return the direct output. 108 | Returned loss is not scaled with gradient accumulation steps. 109 | Returned logits are un-normalized predictions representing the scores for each token in the vocabulary. 110 | """ 111 | return model( 112 | input_ids=data_batch["input_ids"].to(rank), 113 | attention_mask=data_batch["attention_mask"].to(rank), 114 | labels=data_batch["labels"].to(rank), 115 | protein_input_ids=data_batch["protein_input_ids"].to(rank), 116 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank), 117 | use_cache=False, 118 | output_attentions=False, 119 | output_hidden_states=False, 120 | return_dict=False, # force return tuple (loss, logits) 121 | return_encoder_output=False, 122 | ) 123 | 124 | 125 | def setup(rank: int, world_size: int): 126 | """ 127 | Initialize processes for distributed training before first epoch. 128 | Fetch from job script or launcher to set the IP address and the port of the master node. 129 | """ 130 | os.environ['MASTER_ADDR'] = os.getenv('MASTER_ADDR', 'localhost') 131 | os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '9901') 132 | # initialize the process group 133 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 134 | 135 | 136 | def cleanup(): 137 | """End processes for distributed training after last epoch""" 138 | dist.destroy_process_group() 139 | 140 | 141 | def train_epoch( 142 | rank: int, 143 | current_epoch: int, 144 | model: Union[DistributedDataParallel, FullyShardedDataParallel], 145 | dataloader: Prot2TextDerivedDataLoader, 146 | optimizer: Optimizer, 147 | scaler: GradScaler, 148 | args: Dict[str, Any] 149 | ): 150 | """Iterate over all batches for one epoch in training with teacher forcing""" 151 | model.train() 152 | ddp_loss = torch.zeros(2).to(rank) # [0] for acc. loss and [1] for num. of seen batches 153 | ddp_gradnorm = torch.zeros(2).to(rank) # [0] for acc. gradnorm and [1] for num. 
of passed steps 154 | optimizer.zero_grad() # erase accumulated gradients from last epoch 155 | 156 | t = tqdm(iter(dataloader)) 157 | for batch_idx, data_batch in enumerate(t): 158 | # with autocast, logits will be in AUTOCAST_DTYPE and loss will be re-casted to torch.float32 159 | with autocast(dtype=args["autocast_dtype"]): 160 | output = teacher_forcing_forward_pass( 161 | rank=rank, 162 | model=model, 163 | data_batch=data_batch, 164 | ) 165 | 166 | # rescale loss for consistency with different gradient accumulation steps 167 | loss = output[0] / args["gradient_accumulation_steps"] 168 | 169 | # summary current batch 170 | t.set_postfix({ 171 | "mode": "train", 172 | "epoch": f"{current_epoch}/{args['num_epochs']}", 173 | "batch_loss": loss.item() * args["gradient_accumulation_steps"], 174 | "device": f"rank:{rank}" 175 | }) 176 | ddp_loss[0] += loss.item() * args["gradient_accumulation_steps"] 177 | ddp_loss[1] += 1 # the loss is the weighted mean of the output of every batch 178 | 179 | # scale the loss up by a large factor to prevent them from becoming too small, then accumulate the scaled grads 180 | scaler.scale(loss).backward() # backward out of autocast, but still uses same dtype as for forward 181 | 182 | # update weights by loss if accumulation step is reached 183 | if (batch_idx + 1) % args["gradient_accumulation_steps"] == 0: # Perform optimizer step after accumulation 184 | scaler.unscale_(optimizer) # unscale gradients for gradient examination and clipping 185 | gradnorm = torch.nn.utils.clip_grad_norm_( 186 | model.parameters(), 187 | max_norm=float("inf") if args["gradient_clipping"] is None else args["gradient_clipping"] 188 | ) 189 | ddp_gradnorm[0] += gradnorm 190 | ddp_gradnorm[1] += 1 191 | 192 | scaler.step(optimizer) # first unscale the gradients, then do step only if no INF or NaN is in grad 193 | scaler.update() 194 | optimizer.zero_grad(set_to_none=True) 195 | 196 | # summary current epoch 197 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) 198 | if rank == 0: 199 | print( 200 | f"[epoch={current_epoch}/{args['num_epochs']}, " 201 | f"train_loss={ddp_loss[0] / ddp_loss[1]}, " 202 | f"epoch_lr={optimizer.param_groups[0]['lr']}, " 203 | f"epoch_gradnorm={ddp_gradnorm[0] / ddp_gradnorm[1]}]" 204 | ) 205 | # NaN detection 206 | if ddp_loss[0] != ddp_loss[0]: 207 | raise ValueError("NaN detected in the training loss of the epoch, training interrupted.") 208 | 209 | 210 | def eval_epoch( 211 | rank: int, 212 | current_epoch: int, 213 | model: Union[DistributedDataParallel, FullyShardedDataParallel], 214 | dataloader: Prot2TextDerivedDataLoader, 215 | args: Dict[str, Any] 216 | ): 217 | """Iterate over all batches in evaluation with teacher forcing""" 218 | model.eval() 219 | ddp_loss = torch.zeros(2).to(rank) # [0] for acc. loss and [1] for num. 
of seen batches 220 | 221 | t = tqdm(iter(dataloader)) 222 | for data_batch in t: 223 | with torch.no_grad(): 224 | output = teacher_forcing_forward_pass( 225 | rank=rank, 226 | model=model, 227 | data_batch=data_batch, 228 | ) 229 | 230 | loss = output[0] 231 | t.set_postfix({ 232 | "mode": "eval", 233 | "epoch": f"{current_epoch}/{args['num_epochs']}", 234 | "batch_loss": loss.item(), 235 | "device": f"rank:{rank}" 236 | }) 237 | ddp_loss[0] += loss.item() 238 | ddp_loss[1] += 1 239 | 240 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) 241 | if rank == 0: 242 | print(f"[epoch={current_epoch}/{args['num_epochs']}, eval_loss={ddp_loss[0] / ddp_loss[1]}]") 243 | 244 | 245 | def train_on_device( 246 | rank: int, 247 | world_size: int, 248 | args: Dict[str, Any] 249 | ): 250 | """Training and evaluation process for each device, including epochs of training with teacher forcing""" 251 | setup(rank, world_size) 252 | 253 | # prepare datasets and dataloaders 254 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"]) 255 | llama_tokenizer = AutoTokenizer.from_pretrained(args["llama_path"], pad_token='<|reserved_special_token_0|>') 256 | 257 | train_dataset = Prot2TextInstructDataset( 258 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['train_split']}"), 259 | csv_path=os.path.join(args["root_csv_dir"], f"{args['train_split']}.csv"), 260 | sequence_tokenizer=esm_tokenizer, 261 | description_tokenizer=llama_tokenizer, 262 | skip_reload=True, 263 | skip_download=True, 264 | ignore_graph_features=False, 265 | max_sequence_length=args["max_sequence_length"], 266 | max_description_length=args["max_description_length"], 267 | ) 268 | if args["debug_trim_train_split"]: 269 | train_dataset.usable_file_names = train_dataset.usable_file_names[:args["debug_trim_train_split"]] 270 | train_sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True) 271 | train_loader = Prot2TextDerivedDataLoader( 272 | train_dataset, 273 | batch_size=args["batch_size_per_device"], 274 | sampler=train_sampler, 275 | num_workers=4, # parallel CPU cores used for data loading 276 | pin_memory=True, # enable page-locked memory allocation for faster data transfer to GPUs 277 | shuffle=False, 278 | drop_last=True, # avoid incomplete batch at the end 279 | ) 280 | 281 | eval_dataset = Prot2TextInstructDataset( 282 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['eval_split']}"), 283 | csv_path=os.path.join(args["root_csv_dir"], f"{args['eval_split']}.csv"), 284 | sequence_tokenizer=esm_tokenizer, 285 | description_tokenizer=llama_tokenizer, 286 | skip_reload=True, 287 | skip_download=True, 288 | ignore_graph_features=False, 289 | max_sequence_length=args["max_sequence_length"], 290 | max_description_length=args["max_description_length"], 291 | ) 292 | if args["debug_trim_eval_split"]: 293 | eval_dataset.usable_file_names = eval_dataset.usable_file_names[:args["debug_trim_eval_split"]] 294 | eval_sampler = DistributedSampler(eval_dataset, rank=rank, num_replicas=world_size, shuffle=False) 295 | eval_loader = Prot2TextDerivedDataLoader( 296 | eval_dataset, 297 | batch_size=args["batch_size_per_device"], 298 | sampler=eval_sampler, 299 | num_workers=4, 300 | pin_memory=True, 301 | shuffle=False, 302 | drop_last=True, 303 | ) 304 | 305 | torch.cuda.set_device(rank) 306 | 307 | model = load_model( 308 | esm_path=args["esm_path"], 309 | llama_path=args["llama_path"], 310 | load_general_checkpoint_path=args["load_general_checkpoint_path"], 311 | ) 312 | model = model.to(rank) 
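Condensed, the interplay of GradScaler, gradient accumulation and clipping in `train_epoch` above follows this pattern (a schematic sketch, not a drop-in replacement; the autocast dtype, accumulation step count and clipping norm are illustrative):

# schematic AMP + gradient-accumulation loop, mirroring train_epoch above
optimizer.zero_grad()
for batch_idx, batch in enumerate(dataloader):
    with autocast(dtype=torch.bfloat16):                      # autocast dtype is illustrative
        loss = teacher_forcing_forward_pass(rank, model, batch)[0]
    loss = loss / accumulation_steps                          # rescale per micro-batch
    scaler.scale(loss).backward()                             # accumulate scaled gradients
    if (batch_idx + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                            # real-valued grads for clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)                                # skipped when grads contain inf/NaN
        scaler.update()
        optimizer.zero_grad(set_to_none=True)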
313 | 314 | if args["wrap_model"]: 315 | # shard all layers with size of parameters greater than min_num_params 316 | my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=10000) 317 | model = FullyShardedDataParallel(model, auto_wrap_policy=my_auto_wrap_policy) 318 | print(f"FSDP model loaded on rank:{rank}") 319 | else: 320 | model = DistributedDataParallel(model, find_unused_parameters=True) 321 | print(f"DDP model loaded on rank:{rank}") 322 | 323 | # initialization of the optimizer after wrapping the model 324 | optimizer = AdamW(model.parameters(), lr=args["learning_rate"]) 325 | scheduler = StepLR(optimizer, step_size=1, gamma=args["scheduler_gamma"]) 326 | if args["load_general_checkpoint_path"]: 327 | checkpoint_state_dicts = torch.load(args["load_general_checkpoint_path"], weights_only=True) 328 | optimizer_state_dict = checkpoint_state_dicts["optimizer_state_dict"] 329 | scheduler_state_dict = checkpoint_state_dicts["scheduler_state_dict"] 330 | optimizer.load_state_dict(optimizer_state_dict) 331 | scheduler.load_state_dict(scheduler_state_dict) 332 | 333 | # initialization of scaler for mixed precision and control of gradient accumulation 334 | grad_scaler = GradScaler("cuda") 335 | 336 | # core loop of epochs 337 | for epoch_idx in range(1, args["num_epochs"] + 1): 338 | # shuffle data differently at each epoch across all processes 339 | train_sampler.set_epoch(epoch=epoch_idx) 340 | 341 | train_epoch( 342 | rank=rank, 343 | current_epoch=epoch_idx, 344 | model=model, 345 | dataloader=train_loader, 346 | optimizer=optimizer, 347 | scaler=grad_scaler, 348 | args=args 349 | ) 350 | scheduler.step() 351 | dist.barrier() # use a barrier to make sure training is done on all ranks 352 | 353 | eval_epoch( 354 | rank=rank, 355 | model=model, 356 | current_epoch=epoch_idx, 357 | dataloader=eval_loader, 358 | args=args 359 | ) 360 | dist.barrier() 361 | 362 | # save model checkpoint with CPU offload to avoid CUDA OOM, save/load_pretrained not available in FSDP 363 | if epoch_idx == 1 or epoch_idx == args["num_epochs"] or epoch_idx % args["save_every_epochs"] == 0: 364 | if args["wrap_model"]: # FSDP 365 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 366 | with FullyShardedDataParallel.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): 367 | model_state_dict = model.state_dict() 368 | else: # DDP 369 | model_state_dict = model.module.state_dict() 370 | if rank == 0: 371 | checkpoint_path = os.path.join(args["save_checkpoint_dir"], f"general_checkpoint_{epoch_idx}.pt") 372 | torch.save( 373 | { 374 | "model_state_dict": model_state_dict, 375 | "optimizer_state_dict": optimizer.state_dict(), 376 | "scheduler_state_dict": scheduler.state_dict(), 377 | }, 378 | checkpoint_path 379 | ) 380 | print(f"Saving {checkpoint_path}") 381 | dist.barrier() 382 | 383 | cleanup() 384 | 385 | 386 | def train_distributed(args: Dict[str, Any]): 387 | """ 388 | Core training process across multiple devices with epochs of training and inter-epoch evaluation. 389 | Use args: Dict[str, Any] instead of **kwargs for compatibility with torch.multiprocessing.spawn. 
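Since each `general_checkpoint_*.pt` written above bundles the model, optimizer and scheduler state dicts under fixed keys, resuming only requires reading back the same keys after re-creating those objects; a minimal sketch (the path is illustrative):

# sketch: resume from a previously saved general checkpoint
ckpt = torch.load("checkpoints_240101_120000/general_checkpoint_5.pt", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])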
390 | """ 391 | torch.multiprocessing.spawn( 392 | train_on_device, 393 | args=(args["world_size"], args), 394 | nprocs=args["world_size"], 395 | join=True 396 | ) 397 | 398 | 399 | if __name__ == '__main__': 400 | # suppress messages from AutoTokenizer parallelism and Graphein respectively 401 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 402 | os.environ["LOGURU_LEVEL"] = "INFO" 403 | 404 | parsed_args = argParser.parse_args() 405 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs 406 | parsed_args.world_size = torch.cuda.device_count() # use up all available devices across nodes 407 | 408 | torch.manual_seed(parsed_args.random_seed) 409 | torch.cuda.manual_seed(parsed_args.random_seed) 410 | 411 | # initialize checkpoint directory 412 | timestamp = datetime.now().strftime("%y%m%d_%H%M%S") 413 | parsed_args.save_checkpoint_dir = os.path.join(parsed_args.save_checkpoint_dir, f"checkpoints_{timestamp}") 414 | if not os.path.exists(parsed_args.save_checkpoint_dir): 415 | os.mkdir(parsed_args.save_checkpoint_dir) 416 | 417 | print("####################") 418 | for key, value in parsed_args.__dict__.items(): 419 | print(f"{key}: {value}") 420 | print("####################") 421 | 422 | train_distributed(parsed_args.__dict__) 423 | -------------------------------------------------------------------------------- /scripts/train_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stage 2 - instruction tuning training script for ESM-LLAMA protein description 3 | generation on Esm2LlamaInstructForCausalLM model. 4 | 5 | With LoRA. 6 | 7 | DistributedDataParallel training script implemented from scratch. 8 | 9 | The script currently supports gradient accumulation, AutoMixedPrecision, 10 | inter-epoch evaluation. 11 | 12 | The script currently does not support save/load pretrained, gradient checkpointing 13 | or generation under FSDP. 14 | 15 | reference for AMP: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html 16 | 17 | * The script is designed for multi-GPU parallelism on single node. 18 | * On the cluster, print(...) will go to stdout and tqdm(...) will go to stderr. 
19 | """ 20 | 21 | import argparse 22 | from datetime import datetime 23 | import os 24 | from typing import Any, Dict, Union 25 | 26 | from peft import get_peft_model, LoraConfig 27 | from peft.peft_model import PeftModel 28 | import torch 29 | import torch.distributed as dist 30 | from torch.distributed.fsdp import FullyShardedDataParallel 31 | from torch.nn.parallel import DistributedDataParallel 32 | from torch.optim import Adam, AdamW, Optimizer 33 | from torch.optim.lr_scheduler import StepLR 34 | from torch.utils.data import DataLoader 35 | from torch.utils.data.distributed import DistributedSampler 36 | from tqdm import tqdm 37 | from transformers import AutoTokenizer 38 | from transformers import EsmModel, LlamaForCausalLM 39 | 40 | from dataset import Prot2TextLightDataset, Prot2TextLightCollater 41 | from models import ( 42 | ModalityAdapter, 43 | ModalityAdapterConfig, 44 | Esm2LlamaInstructForCausalLM 45 | ) 46 | import scripts.utils_argparse as utils_argparse 47 | 48 | 49 | argParser = argparse.ArgumentParser() 50 | 51 | argParser.add_argument("--esm_path", type=str) 52 | argParser.add_argument("--llama_path", type=str) 53 | # argParser.add_argument("--root_dataset_dir", type=str) 54 | argParser.add_argument("--root_csv_dir", type=str) 55 | argParser.add_argument("--save_checkpoint_dir", type=str) 56 | argParser.add_argument("--load_model_checkpoint_path", type=str, default="") 57 | argParser.add_argument("--load_adapter_checkpoint_dir", type=str, default="") 58 | argParser.add_argument("--load_optimizer_scheduler_checkpoint_path", type=str, default="") 59 | 60 | argParser.add_argument("--torch_dtype", type=utils_argparse.str2dtype) 61 | argParser.add_argument("--batch_size_per_device", type=int) 62 | argParser.add_argument("--num_epochs", type=int) 63 | argParser.add_argument("--save_every_epochs", type=int) 64 | argParser.add_argument("--gradient_accumulation_steps", type=int) 65 | argParser.add_argument("--learning_rate", type=float) 66 | argParser.add_argument("--gradient_clipping", type=float, default=None) 67 | argParser.add_argument("--scheduler_gamma", type=float) 68 | argParser.add_argument("--random_seed", type=int) 69 | argParser.add_argument("--fix_modality_adapter", type=utils_argparse.str2bool) 70 | argParser.add_argument("--lora_rank", type=int) 71 | 72 | argParser.add_argument("--include_text_fields", type=utils_argparse.str2bool) 73 | argParser.add_argument("--name_dropout", type=float) 74 | argParser.add_argument("--taxonomy_dropout", type=float) 75 | 76 | argParser.add_argument("--train_split", type=str) 77 | argParser.add_argument("--eval_split", type=str) 78 | argParser.add_argument("--debug_trim_train_split", type=int, default=None) 79 | argParser.add_argument("--debug_trim_eval_split", type=int, default=None) 80 | 81 | 82 | def load_model(args: Dict[str, Any]) -> PeftModel: 83 | """ 84 | Standard API for different models. Used in both `train` and `generate`. 85 | Load base model of the given name, and load weights from the checkpoint path 86 | if provided. 
87 | """ 88 | esm_encoder = EsmModel.from_pretrained( 89 | args["esm_path"], 90 | add_pooling_layer=False, 91 | torch_dtype=args["torch_dtype"], 92 | device_map="cpu" 93 | ) 94 | llama_decoder = LlamaForCausalLM.from_pretrained( 95 | args["llama_path"], 96 | torch_dtype=args["torch_dtype"], 97 | device_map="cpu" 98 | ) 99 | 100 | adapter_config = ModalityAdapterConfig( 101 | input_dim=esm_encoder.config.hidden_size, 102 | intermediate_dim=2048, 103 | output_dim=llama_decoder.config.hidden_size, 104 | ) 105 | adapter = ModalityAdapter(adapter_config) 106 | adapter.to(args["torch_dtype"]) 107 | 108 | model = Esm2LlamaInstructForCausalLM( 109 | esm_encoder=esm_encoder, 110 | adapter=adapter, 111 | llama_decoder=llama_decoder, 112 | ) 113 | 114 | # overwrite weights of base model if checkpoint path is provided 115 | if args["load_model_checkpoint_path"]: 116 | print(f"Loading {args['load_model_checkpoint_path']}") 117 | model_state_dict = torch.load( 118 | args["load_model_checkpoint_path"], 119 | weights_only=True, 120 | map_location="cpu" # load to CPU first 121 | # will be loaded to where the weights were saved from if not specified 122 | ) 123 | model.load_state_dict(model_state_dict) 124 | 125 | # wrap by lora either with pretrained adapter or with initialized adapter 126 | if args["load_adapter_checkpoint_dir"]: 127 | print(f"Loading {args['load_adapter_checkpoint_dir']}") 128 | model = PeftModel.from_pretrained( 129 | model, 130 | args["load_adapter_checkpoint_dir"], 131 | is_trainable=True 132 | ) 133 | else: 134 | print("Initializing LoRA adapter") 135 | lora_config = LoraConfig( 136 | r=args["lora_rank"], 137 | lora_alpha=args["lora_rank"] * 2, 138 | lora_dropout=0.1, 139 | bias="none", 140 | init_lora_weights=True, 141 | target_modules=[ 142 | "self_attn.q_proj", 143 | "self_attn.k_proj", 144 | "self_attn.v_proj", 145 | "self_attn.o_proj", 146 | "mlp.gate_proj", 147 | "mlp.up_proj", 148 | "mlp.down_proj" 149 | ], # for llama_decoder 150 | modules_to_save=( 151 | ["adapter.fc1", "adapter.fc2"] 152 | if not args["fix_modality_adapter"] 153 | else None 154 | ) 155 | ) 156 | model = get_peft_model(model, lora_config) 157 | 158 | model.print_trainable_parameters() 159 | 160 | return model 161 | 162 | 163 | def teacher_forcing_forward_pass( 164 | rank: int, 165 | model: Union[DistributedDataParallel, FullyShardedDataParallel], 166 | data_batch: Dict[str, Any], 167 | ) -> torch.Tensor: # loss 168 | """ 169 | Standard API for different models. Used in both `train_epoch` and `eval_epoch`. 170 | Prepare inputs from dataloader, migrate variable to the same device as the model, 171 | and execute the forward pass with teacher forcing. 172 | 173 | Returned loss is not scaled with gradient accumulation steps. 174 | """ 175 | return model( 176 | input_ids=data_batch["input_ids"].to(rank), 177 | attention_mask=data_batch["attention_mask"].to(rank), 178 | labels=data_batch["labels"].to(rank), 179 | protein_input_ids=data_batch["protein_input_ids"].to(rank), 180 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank), 181 | use_cache=False, 182 | output_attentions=False, 183 | output_hidden_states=False, 184 | return_dict=False, 185 | )[0] 186 | 187 | 188 | def setup(rank: int, world_size: int): 189 | """ 190 | Initialize processes for distributed training before first epoch. 191 | Fetch from job script or launcher to set the IP address and the port of the 192 | master node. 
193 | """ 194 | os.environ['MASTER_ADDR'] = os.getenv('MASTER_ADDR', 'localhost') 195 | os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '9901') 196 | # initialize the process group 197 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 198 | 199 | 200 | def cleanup(): 201 | """End processes for distributed training after last epoch""" 202 | dist.destroy_process_group() 203 | 204 | 205 | def train_epoch( 206 | rank: int, 207 | current_epoch: int, 208 | model: Union[DistributedDataParallel, FullyShardedDataParallel], 209 | dataloader: DataLoader, 210 | optimizer: Optimizer, 211 | args: Dict[str, Any] 212 | ): 213 | """Iterate over all batches for one epoch in training with teacher forcing""" 214 | model.train() 215 | ddp_loss = torch.zeros(2).to(rank) 216 | # [0] for acc. loss and [1] for num. of seen batches 217 | ddp_gradnorm = torch.zeros(2).to(rank) 218 | # [0] for acc. gradnorm and [1] for num. of passed steps 219 | optimizer.zero_grad() # erase accumulated gradients from last epoch 220 | 221 | t = tqdm(iter(dataloader)) 222 | for batch_idx, data_batch in enumerate(t): 223 | # with autocast, logits will be in AUTOCAST_DTYPE 224 | # but loss will be re-casted to torch.float32 225 | # and model weights will stay in torch.float32 226 | loss = teacher_forcing_forward_pass( 227 | rank=rank, 228 | model=model, 229 | data_batch=data_batch, 230 | ) 231 | 232 | # rescale loss for consistency with different gradient accumulation steps 233 | loss = loss / args["gradient_accumulation_steps"] 234 | 235 | # summary current batch 236 | t.set_postfix({ 237 | "mode": "train", 238 | "epoch": f"{current_epoch}/{args['num_epochs']}", 239 | "batch_loss": loss.item() * args["gradient_accumulation_steps"], 240 | "device": f"rank:{rank}" 241 | }) 242 | ddp_loss[0] += loss.item() * args["gradient_accumulation_steps"] 243 | ddp_loss[1] += 1 # the loss is the weighted mean of the output of every batch 244 | 245 | # scale the loss up by a large factor to prevent them from becoming too small 246 | # then accumulate the scaled grads 247 | loss.backward() 248 | # backward out of autocast, but still uses same dtype as for forward 249 | 250 | # update weights by loss if accumulation step is reached 251 | if (batch_idx + 1) % args["gradient_accumulation_steps"] == 0: 252 | gradnorm = torch.nn.utils.clip_grad_norm_( 253 | model.parameters(), 254 | max_norm=( 255 | float("inf") 256 | if args["gradient_clipping"] is None 257 | else args["gradient_clipping"] 258 | ) 259 | ) 260 | ddp_gradnorm[0] += gradnorm 261 | ddp_gradnorm[1] += 1 262 | 263 | optimizer.step() 264 | optimizer.zero_grad(set_to_none=True) 265 | 266 | # summary current epoch 267 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) 268 | if rank == 0: 269 | print( 270 | f"[epoch={current_epoch}/{args['num_epochs']}, " 271 | f"train_loss={ddp_loss[0] / ddp_loss[1]}, " 272 | f"epoch_lr={optimizer.param_groups[0]['lr']}, " 273 | f"epoch_gradnorm={ddp_gradnorm[0] / ddp_gradnorm[1]}]" 274 | ) 275 | # NaN detection 276 | if ddp_loss[0] != ddp_loss[0]: 277 | raise ValueError( 278 | "NaN detected in the training loss of the epoch, training interrupted." 279 | ) 280 | 281 | 282 | def eval_epoch( 283 | rank: int, 284 | current_epoch: int, 285 | model: Union[DistributedDataParallel, FullyShardedDataParallel], 286 | dataloader: DataLoader, 287 | args: Dict[str, Any] 288 | ): 289 | """Iterate over all batches in evaluation with teacher forcing""" 290 | model.eval() 291 | ddp_loss = torch.zeros(2).to(rank) 292 | # [0] for acc. 
loss and [1] for num. of seen batches 293 | 294 | t = tqdm(iter(dataloader)) 295 | for data_batch in t: 296 | with torch.no_grad(): 297 | loss = teacher_forcing_forward_pass( 298 | rank=rank, 299 | model=model, 300 | data_batch=data_batch, 301 | ) 302 | 303 | t.set_postfix({ 304 | "mode": "eval", 305 | "epoch": f"{current_epoch}/{args['num_epochs']}", 306 | "batch_loss": loss.item(), 307 | "device": f"rank:{rank}" 308 | }) 309 | ddp_loss[0] += loss.item() 310 | ddp_loss[1] += 1 311 | 312 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) 313 | if rank == 0: 314 | print( 315 | f"[epoch={current_epoch}/{args['num_epochs']}, " 316 | f"eval_loss={ddp_loss[0] / ddp_loss[1]}]" 317 | ) 318 | 319 | 320 | def train_on_device( 321 | rank: int, 322 | world_size: int, 323 | args: Dict[str, Any] 324 | ): 325 | """ 326 | Training and evaluation process for each device, including epochs of training 327 | with teacher forcing. 328 | """ 329 | setup(rank, world_size) 330 | 331 | # prepare datasets and dataloaders 332 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"]) 333 | llama_tokenizer = AutoTokenizer.from_pretrained( 334 | args["llama_path"], 335 | pad_token='<|reserved_special_token_0|>' 336 | ) 337 | 338 | train_dataset = Prot2TextLightDataset( 339 | csv_path=os.path.join(args["root_csv_dir"], f"{args['train_split']}.csv"), 340 | ) 341 | if args["debug_trim_train_split"]: 342 | train_dataset.data = train_dataset.data[:args["debug_trim_train_split"]] 343 | train_sampler = DistributedSampler( 344 | train_dataset, 345 | rank=rank, 346 | num_replicas=world_size, 347 | shuffle=True 348 | ) 349 | 350 | train_collater = Prot2TextLightCollater( 351 | sequence_tokenizer=esm_tokenizer, 352 | description_tokenizer=llama_tokenizer, 353 | mode="train", 354 | include_text_fields=args["include_text_fields"], 355 | name_dropout=args["name_dropout"], 356 | taxonomy_dropout=args["taxonomy_dropout"], 357 | ) 358 | train_loader = DataLoader( 359 | train_dataset, 360 | batch_size=args["batch_size_per_device"], 361 | sampler=train_sampler, 362 | collate_fn=train_collater, 363 | num_workers=4, # parallel CPU cores used for data loading 364 | pin_memory=True, # enable page-locked memory allocation for faster data transfer to GPUs 365 | shuffle=False, # avoid shuffling twice with DistributedSampler 366 | drop_last=True, # avoid incomplete batch at the end 367 | ) 368 | print(f"Train dataset loaded on rank:{rank}") 369 | 370 | eval_dataset = Prot2TextLightDataset( 371 | csv_path=os.path.join(args["root_csv_dir"], f"{args['eval_split']}.csv"), 372 | ) 373 | if args["debug_trim_eval_split"]: 374 | eval_dataset.data = eval_dataset.data[:args["debug_trim_eval_split"]] 375 | eval_sampler = DistributedSampler( 376 | eval_dataset, 377 | rank=rank, 378 | num_replicas=world_size, 379 | shuffle=False 380 | ) 381 | eval_loader = DataLoader( 382 | eval_dataset, 383 | batch_size=args["batch_size_per_device"], 384 | sampler=eval_sampler, 385 | collate_fn=train_collater, 386 | num_workers=4, 387 | pin_memory=True, 388 | shuffle=False, 389 | drop_last=True, 390 | ) 391 | print(f"Eval dataset loaded on rank:{rank}") 392 | 393 | torch.cuda.set_device(rank) 394 | 395 | model = load_model(args=args) 396 | model = model.to(rank) 397 | 398 | model = DistributedDataParallel( 399 | model, 400 | # find_unused_parameters=True # suppress error for unused parameters in wrapped model 401 | ) 402 | print(f"DDP model loaded on rank:{rank}") 403 | 404 | # initialization of the optimizer after wrapping the model 405 | optimizer = 
Adam(model.parameters(), lr=args["learning_rate"]) 406 | scheduler = StepLR(optimizer, step_size=1, gamma=args["scheduler_gamma"]) 407 | if args["load_optimizer_scheduler_checkpoint_path"]: 408 | print(f"Loading {args['load_optimizer_scheduler_checkpoint_path']}") 409 | checkpoint_state_dicts = torch.load( 410 | args["load_optimizer_scheduler_checkpoint_path"], 411 | weights_only=True 412 | ) 413 | optimizer_state_dict = checkpoint_state_dicts["optimizer_state_dict"] 414 | scheduler_state_dict = checkpoint_state_dicts["scheduler_state_dict"] 415 | optimizer.load_state_dict(optimizer_state_dict) 416 | scheduler.load_state_dict(scheduler_state_dict) 417 | 418 | # core loop of epochs 419 | for epoch_idx in range(1, args["num_epochs"] + 1): 420 | # shuffle data differently at each epoch across all processes 421 | train_sampler.set_epoch(epoch=epoch_idx) 422 | 423 | train_epoch( 424 | rank=rank, 425 | current_epoch=epoch_idx, 426 | model=model, 427 | dataloader=train_loader, 428 | optimizer=optimizer, 429 | args=args 430 | ) 431 | scheduler.step() 432 | dist.barrier() # use a barrier to make sure training is done on all ranks 433 | 434 | eval_epoch( 435 | rank=rank, 436 | model=model, 437 | current_epoch=epoch_idx, 438 | dataloader=eval_loader, 439 | args=args 440 | ) 441 | dist.barrier() 442 | 443 | if ( 444 | epoch_idx == 1 445 | or epoch_idx == args["num_epochs"] 446 | or epoch_idx % args["save_every_epochs"] == 0 447 | ): 448 | if rank == 0: 449 | adapter_checkpoint_dir = os.path.join( 450 | args["save_checkpoint_dir"], 451 | f"adapter_checkpoint_{epoch_idx}" 452 | ) 453 | model.module.save_pretrained(adapter_checkpoint_dir) 454 | print(f"Saving {adapter_checkpoint_dir}") 455 | 456 | optimizer_scheduler_checkpoint_path = os.path.join( 457 | args["save_checkpoint_dir"], 458 | f"optimizer_scheduler_checkpoint_{epoch_idx}.pt" 459 | ) 460 | torch.save( 461 | { 462 | "optimizer_state_dict": optimizer.state_dict(), 463 | "scheduler_state_dict": scheduler.state_dict(), 464 | }, 465 | optimizer_scheduler_checkpoint_path 466 | ) 467 | print(f"Saving {optimizer_scheduler_checkpoint_path}") 468 | 469 | dist.barrier() 470 | 471 | cleanup() 472 | 473 | 474 | def train_distributed( 475 | args: Dict[str, Any] # replace **kwargs for compatibility with spawn 476 | ): 477 | """ 478 | Core training process across multiple devices with epochs of training and 479 | inter-epoch evaluation. 
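Each epoch therefore leaves two artifacts, a PEFT adapter directory and an optimizer/scheduler state file; resuming is a matter of pointing the corresponding `--load_*` arguments at a matching pair (the timestamped directory and epoch below are illustrative):

# sketch: arguments that restore a run from epoch 4 of a previous session
resume = {
    "load_adapter_checkpoint_dir":
        "checkpoints_240101_120000/adapter_checkpoint_4",
    "load_optimizer_scheduler_checkpoint_path":
        "checkpoints_240101_120000/optimizer_scheduler_checkpoint_4.pt",
}
# load_model(...) then re-attaches the adapter via PeftModel.from_pretrained(..., is_trainable=True),
# and train_on_device(...) restores optimizer/scheduler state from the .pt file.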
480 | """ 481 | torch.multiprocessing.spawn( 482 | train_on_device, 483 | args=(args["world_size"], args), 484 | nprocs=args["world_size"], 485 | join=True 486 | ) 487 | 488 | 489 | if __name__ == '__main__': 490 | # suppress messages from AutoTokenizer parallelism and Graphein respectively 491 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 492 | os.environ["LOGURU_LEVEL"] = "INFO" 493 | 494 | parsed_args = argParser.parse_args() 495 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict GPU visibility 496 | parsed_args.world_size = torch.cuda.device_count() # use up all visible GPUs 497 | 498 | torch.manual_seed(parsed_args.random_seed) 499 | torch.cuda.manual_seed(parsed_args.random_seed) 500 | 501 | # initialize checkpoint directory 502 | timestamp = datetime.now().strftime("%y%m%d_%H%M%S") 503 | parsed_args.save_checkpoint_dir = os.path.join( 504 | parsed_args.save_checkpoint_dir, 505 | f"checkpoints_{timestamp}" 506 | ) 507 | if not os.path.exists(parsed_args.save_checkpoint_dir): 508 | os.mkdir(parsed_args.save_checkpoint_dir) 509 | 510 | print("####################") 511 | for key, value in parsed_args.__dict__.items(): 512 | print(f"{key}: {value}") 513 | print("####################") 514 | 515 | train_distributed(parsed_args.__dict__) 516 | -------------------------------------------------------------------------------- /models/modeling_esm2rgcn2llama_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration class for the assembled Esm2LlamaInstructForCausalLM model. 3 | 4 | Esm2Rgcn2LlamaInstructForCausalLM = EsmModel + RgcnAdapter + LlamaForCausalLM 5 | 6 | For training/evaluation under teacher-forcing scenario, the model `forward` 7 | function shall take following arguments: 8 | * input_ids: (bsz, prompt_len+description_len) # whole chat template 9 | * attention_mask: (bsz, prompt_len+description_len) # left & right padding 10 | * position_ids: (bsz, prompt_len+description_len) # optional 11 | * past_key_values: None 12 | * labels: (bsz, prompt_len+description_len) # -100 for padding & prompt 13 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds 14 | * protein_attention_mask: (bsz, prot_seq_len) # left padding 15 | * protein_position_ids: (bsz, prot_seq_len) # optional 16 | * protein_head_mask: (num_heads,) or (num_layers, num_heads) # optional 17 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional 18 | * graph_edge_index: (2, sum_num_edges) 19 | * graph_edge_type: (sum_num_edges,) 20 | * graph_batch: (sum_num_nodes,) # optional 21 | * use_cache: False 22 | * return_decoder_inputs: False 23 | 24 | For inference, the model `generate` function shall take following arguments: 25 | * inputs: (bsz, prompt_len) # prompt part of chat template 26 | * attention_mask: (bsz, prompt_len) # left padding 27 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds 28 | * protein_attention_mask: (bsz, prot_seq_len) # left padding 29 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional 30 | * graph_edge_index: (2, sum_num_edges) 31 | * graph_edge_type: (sum_num_edges,) 32 | * graph_batch: (sum_num_nodes,) # optional 33 | """ 34 | 35 | 36 | from typing import Optional, Tuple, Union 37 | 38 | import torch 39 | import torch_geometric 40 | import torch_geometric.backend 41 | import torch_geometric.nn 42 | from torch_geometric.nn.conv.rgcn_conv import masked_edge_index 43 | from torch_geometric.typing import Adj, OptTensor, SparseTensor 44 | from torch_geometric.utils import 
index_sort, scatter 45 | from torch_geometric.utils.sparse import index2ptr 46 | from transformers import Cache, PreTrainedModel 47 | from transformers.generation.utils import GenerateOutput 48 | from transformers.modeling_outputs import CausalLMOutputWithPast 49 | from transformers.models.esm.modeling_esm import EsmModel 50 | from transformers.models.llama import LlamaForCausalLM 51 | 52 | from .configuration_esm2rgcn2llama_instruct import ( 53 | RgcnAdapterConfig, 54 | Esm2Rgcn2LlamaInstructConfig 55 | ) 56 | 57 | 58 | class RgcnConvLayer(torch_geometric.nn.RGCNConv): 59 | """Modified `torch_geometric.nn.RGCNConv` Layer for Flexible Precision.""" 60 | def forward( 61 | self, 62 | x: Union[OptTensor, Tuple[OptTensor, torch.Tensor]], 63 | edge_index: Adj, 64 | edge_type: OptTensor = None 65 | ): 66 | """Forward pass of the RGCN layer.""" 67 | # handle input node features 68 | x_l = x[0] if isinstance(x, tuple) else x # left node features 69 | if x_l is None: # default to indices if no features provided 70 | x_l = torch.arange(self.in_channels_l, device=self.weight.device) 71 | x_r = x[1] if isinstance(x, tuple) else x_l # right node features 72 | 73 | # define the size of the input graph 74 | size = (x_l.size(0), x_r.size(0)) 75 | # extract edge types for SparseTensor input 76 | if isinstance(edge_index, SparseTensor): 77 | edge_type = edge_index.storage.value() 78 | assert edge_type is not None 79 | 80 | # initialize the output tensor, specify additional dtype for flexible precision 81 | out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device, dtype=x.dtype) 82 | 83 | # handle weight decomposition (basis or block diagonal) 84 | weight = self.weight 85 | if self.num_bases is not None: # Basis-decomposition 86 | weight = (self.comp @ weight.view(self.num_bases, -1)).view( 87 | self.num_relations, self.in_channels_l, self.out_channels 88 | ) 89 | 90 | if self.num_blocks is not None: # Block-diagonal-decomposition 91 | if not torch.is_floating_point(x_r) and self.num_blocks is not None: 92 | raise ValueError( 93 | 'Block-diagonal decomposition not supported for non-continuous input features.' 
94 | ) 95 | for i in range(self.num_relations): 96 | tmp = masked_edge_index(edge_index, i == edge_type) 97 | h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size) 98 | h = h.view(-1, weight.size(1), weight.size(2)) 99 | h = torch.einsum('abc,bcd->abd', h, weight[i]) 100 | out = out + h.contiguous().view(-1, self.out_channels) 101 | 102 | else: # no decomposition (standard weight handling) 103 | use_segment_matmul = torch_geometric.backend.use_segment_matmul 104 | 105 | # heuristic for enabling `segment_matmul` optimization 106 | if use_segment_matmul is None: 107 | segment_count = scatter( 108 | torch.ones_like(edge_type), 109 | edge_type, 110 | dim_size=self.num_relations 111 | ) 112 | self._use_segment_matmul_heuristic_output = ( 113 | torch_geometric.backend.use_segment_matmul_heuristic( 114 | num_segments=self.num_relations, 115 | max_segment_size=int(segment_count.max()), 116 | in_channels=self.weight.size(1), 117 | out_channels=self.weight.size(2), 118 | ) 119 | ) 120 | assert self._use_segment_matmul_heuristic_output is not None 121 | use_segment_matmul = self._use_segment_matmul_heuristic_output 122 | 123 | # segment matmul optimization 124 | if ( 125 | use_segment_matmul 126 | and torch_geometric.typing.WITH_SEGMM 127 | and not torch_geometric.is_compiling() 128 | and self.num_bases is None 129 | and x_l.is_floating_point() 130 | and isinstance(edge_index, torch.Tensor) 131 | ): 132 | if not self.is_sorted: 133 | if (edge_type[1:] < edge_type[:-1]).any(): 134 | edge_type, perm = index_sort(edge_type, max_value=self.num_relations) 135 | edge_index = edge_index[:, perm] 136 | edge_type_ptr = index2ptr(edge_type, self.num_relations) 137 | out = self.propagate(edge_index, x=x_l, edge_type_ptr=edge_type_ptr, size=size) 138 | 139 | else: # loop through relations without optimization 140 | for i in range(self.num_relations): 141 | tmp = masked_edge_index(edge_index, i == edge_type) 142 | 143 | if not torch.is_floating_point(x_r): 144 | out = out + self.propagate( 145 | tmp, 146 | x=weight[i, x_l], 147 | edge_type_ptr=None, 148 | size=size 149 | ) 150 | else: 151 | h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size) 152 | out = out + (h @ weight[i]) 153 | 154 | # incorporate root embeddings 155 | root = self.root 156 | if root is not None: 157 | if not torch.is_floating_point(x_r): 158 | out = out + root[x_r] 159 | else: 160 | out = out + x_r @ root 161 | 162 | # add bias if applicable 163 | if self.bias is not None: 164 | out = out + self.bias 165 | 166 | return out 167 | 168 | def edge_update(self) -> torch.Tensor: 169 | """Placeholder for edge feature updates (if applicable).""" 170 | pass 171 | 172 | 173 | class RgcnAdapter(PreTrainedModel): 174 | """ 175 | Relational Graph Convolutional Network adapter to match the hidden size of 176 | different modalities. 
177 | """ 178 | config_class = RgcnAdapterConfig # configuration class for this model 179 | 180 | def __init__(self, config: RgcnAdapterConfig): 181 | super().__init__(config) 182 | self.config = config 183 | self.activation = torch.nn.GELU() 184 | self.dropout = torch.nn.Dropout(p=config.dropout_rate) 185 | self.fc1 = torch.nn.Linear(config.input_dim, config.intermediate_dim) 186 | self.rgcn_layers = torch.nn.ModuleList([ 187 | RgcnConvLayer( 188 | in_channels=config.intermediate_dim, 189 | out_channels=config.intermediate_dim, 190 | num_relations=config.n_relations, 191 | ) 192 | for _ in range(config.n_layers) 193 | ]) 194 | self.fc2 = torch.nn.Linear(config.intermediate_dim, config.output_dim) 195 | self.post_init() # initialize weights and apply final processing 196 | 197 | def forward( 198 | self, 199 | hidden_states: torch.FloatTensor, # (bsz, seq_len, input_dim) 200 | attention_mask: torch.LongTensor, # (bsz, seq_len) 201 | edge_index: torch.LongTensor, # (2, sum_num_edges) 202 | edge_type: torch.LongTensor, # (sum_num_edges,) 203 | batch: Optional[torch.LongTensor] = None # (sum_num_nodes,) 204 | ) -> torch.FloatTensor: 205 | hidden_states = self.activation(self.fc1(hidden_states)) 206 | hidden_states = self.dropout(hidden_states) # (bsz, seq_len, interm_dim) 207 | 208 | # create mask for node embeddings to be updated by RGCN layers 209 | nodes_mask = attention_mask.clone().bool() # (bsz, seq_len) 210 | nodes_mask[:, 0] = False # exclude bos token 211 | batch_size = hidden_states.size(0) 212 | batch_indices = torch.arange(0, batch_size) # (bsz,) 213 | eos_indices = attention_mask.sum(dim=1) - 1 # (bsz,) 214 | nodes_mask[batch_indices, eos_indices] = False # exclude eos token 215 | 216 | # compute RGCN layers on node embeddings only 217 | nodes_hidden_states = hidden_states[nodes_mask] # (sum_num_nodes, interm_dim) 218 | for layer in self.rgcn_layers: 219 | nodes_hidden_states = layer(nodes_hidden_states, edge_index, edge_type) 220 | nodes_hidden_states = self.activation(nodes_hidden_states) 221 | nodes_hidden_states = self.dropout(nodes_hidden_states) 222 | 223 | # update node hidden states with RGCN outputs 224 | hidden_states[nodes_mask] = nodes_hidden_states 225 | 226 | hidden_states = self.activation(self.fc2(hidden_states)) 227 | hidden_states = self.dropout(hidden_states) 228 | hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1) 229 | return hidden_states # (bsz, seq_len, output_dim) 230 | 231 | 232 | class Esm2Rgcn2LlamaInstructForCausalLM(PreTrainedModel): 233 | """ 234 | Esm2Rgcn2LlamaInstructForCausalLM model for protein function prediction. 235 | Similar to `EncoderDecoderModel` but with more complicated architecture. 236 | Initialize with either a configuration OR all three components. 237 | `kwargs` can override standalone attributes in `Esm2Rgcn2LlamaInstructConfig`. 
238 | """ 239 | config_class = Esm2Rgcn2LlamaInstructConfig # configuration class for this model 240 | 241 | def __init__( 242 | self, 243 | config: Optional[Esm2Rgcn2LlamaInstructConfig] = None, 244 | esm_encoder: Optional[EsmModel] = None, 245 | adapter: Optional[RgcnAdapter] = None, 246 | llama_decoder: Optional[LlamaForCausalLM] = None, 247 | **kwargs 248 | ): 249 | if config is not None: # components ignored if config is provided 250 | super().__init__(config) 251 | self.esm_encoder = EsmModel( 252 | config.esm_config, 253 | add_pooling_layer=False 254 | ) 255 | self.adapter = RgcnAdapter(config.adapter_config) 256 | self.llama_decoder = LlamaForCausalLM(config.llama_config) 257 | else: 258 | config = Esm2Rgcn2LlamaInstructConfig( 259 | esm_config=esm_encoder.config, 260 | adapter_config=adapter.config, 261 | llama_config=llama_decoder.config, 262 | **kwargs # override standalone attributes 263 | ) 264 | super().__init__(config) 265 | self.esm_encoder = esm_encoder 266 | self.adapter = adapter 267 | self.llama_decoder = llama_decoder 268 | 269 | def prepare_decoder_inputs( 270 | self, 271 | input_ids: torch.LongTensor, 272 | encoder_hidden_states: torch.FloatTensor, 273 | attention_mask: Optional[torch.LongTensor] = None, 274 | encoder_attention_mask: Optional[torch.LongTensor] = None, 275 | ): 276 | """ 277 | Embed and replace placeholder in `input_ids` by encoder hidden states. 278 | `input_ids` must be passed to locate placeholder for replacement. 279 | """ 280 | # preparation 281 | batch_size, seq_len = input_ids.size() 282 | _, encoder_seq_len, _ = encoder_hidden_states.size() 283 | if attention_mask is None: 284 | attention_mask = torch.ones( 285 | (batch_size, seq_len), 286 | dtype=torch.long, 287 | device=input_ids.device 288 | ) 289 | if encoder_attention_mask is None: 290 | encoder_attention_mask = torch.ones( 291 | (batch_size, encoder_seq_len), 292 | dtype=torch.long, 293 | device=encoder_hidden_states.device 294 | ) 295 | inputs_embeds = self.llama_decoder.get_input_embeddings()(input_ids) 296 | # replacement 297 | placeholder_mask = input_ids == self.config.placeholder_id 298 | encoder_mask = encoder_attention_mask.bool() 299 | inputs_embeds[placeholder_mask] = encoder_hidden_states[encoder_mask] 300 | return inputs_embeds, attention_mask 301 | 302 | def forward( 303 | self, 304 | # chat template text inputs 305 | input_ids: Optional[torch.LongTensor] = None, 306 | attention_mask: Optional[torch.LongTensor] = None, 307 | position_ids: Optional[torch.LongTensor] = None, 308 | past_key_values: Optional[Cache] = None, 309 | labels: Optional[torch.LongTensor] = None, 310 | # protein amino-acid sequence inputs 311 | protein_input_ids: Optional[torch.LongTensor] = None, 312 | protein_attention_mask: Optional[torch.LongTensor] = None, 313 | protein_position_ids: Optional[torch.LongTensor] = None, 314 | protein_head_mask: Optional[torch.LongTensor] = None, 315 | protein_inputs_embeds: Optional[torch.FloatTensor] = None, 316 | # graph-related inputs 317 | graph_edge_index: Optional[torch.LongTensor] = None, 318 | graph_edge_type: Optional[torch.LongTensor] = None, 319 | graph_batch: Optional[torch.LongTensor] = None, 320 | # behavior control arguments 321 | use_cache: Optional[bool] = None, 322 | output_attentions: Optional[bool] = None, 323 | output_hidden_states: Optional[bool] = None, 324 | return_dict: Optional[bool] = None, 325 | return_encoder_outputs: bool = False, 326 | return_adapter_outputs: bool = False, 327 | return_decoder_inputs: bool = False, 328 | cache_position: 
Optional[torch.LongTensor] = None 329 | ) -> Union[Tuple, CausalLMOutputWithPast]: 330 | """ 331 | Compute encoder and adapter outputs, then pass to decoder. 332 | `input_ids` is expected to be [prompt + description] in teacher-forcing 333 | scenario and [prompt] only in first iteration of inference (with 334 | return_decoder_inputs=True). 335 | Attention: possible concatenation of the mask and labels should be 336 | handled before calling this method. 337 | `inputs_embeds` not allowed due to placeholder replacement scheme. 338 | """ 339 | # esm_encoder forward 340 | encoder_output = self.esm_encoder( 341 | input_ids=protein_input_ids, 342 | attention_mask=protein_attention_mask, 343 | position_ids=protein_position_ids, 344 | head_mask=protein_head_mask, 345 | inputs_embeds=protein_inputs_embeds, 346 | use_cache=False, # because config.esm_config.is_decoder=False 347 | output_attentions=output_attentions, 348 | output_hidden_states=output_hidden_states, 349 | return_dict=return_dict 350 | ) 351 | encoder_hidden_states = encoder_output[0] 352 | encoder_attention_mask = protein_attention_mask 353 | if return_encoder_outputs: 354 | return encoder_output 355 | # adapter forward 356 | adapter_output = self.adapter( 357 | hidden_states=encoder_hidden_states, 358 | attention_mask=encoder_attention_mask, 359 | edge_index=graph_edge_index, 360 | edge_type=graph_edge_type, 361 | batch=graph_batch 362 | ) 363 | if return_adapter_outputs: 364 | return adapter_output, encoder_attention_mask 365 | # decoder input preparation 366 | inputs_embeds, attention_mask = self.prepare_decoder_inputs( 367 | input_ids=input_ids, 368 | encoder_hidden_states=adapter_output, 369 | attention_mask=attention_mask, 370 | encoder_attention_mask=encoder_attention_mask, 371 | ) 372 | if return_decoder_inputs: 373 | return inputs_embeds, attention_mask 374 | # llama_decoder forward 375 | return self.llama_decoder.forward( 376 | input_ids=None, 377 | attention_mask=attention_mask, 378 | position_ids=position_ids, 379 | past_key_values=past_key_values, 380 | inputs_embeds=inputs_embeds, 381 | labels=labels, 382 | use_cache=use_cache, 383 | output_attentions=output_attentions, 384 | return_dict=return_dict, 385 | cache_position=cache_position 386 | ) 387 | 388 | def generate( 389 | self, 390 | inputs: torch.LongTensor, # alias of `input_ids` 391 | attention_mask: Optional[torch.LongTensor] = None, 392 | protein_input_ids: Optional[torch.LongTensor] = None, 393 | protein_attention_mask: Optional[torch.LongTensor] = None, 394 | protein_inputs_embeds: Optional[torch.FloatTensor] = None, 395 | graph_edge_index: Optional[torch.LongTensor] = None, 396 | graph_edge_type: Optional[torch.LongTensor] = None, 397 | graph_batch: Optional[torch.LongTensor] = None, 398 | **kwargs 399 | ) -> Union[GenerateOutput, torch.LongTensor]: 400 | """ 401 | Do inference based on given input prompt. 402 | `inputs` is expected to be [prompt] only. 403 | Output will not keep the input prompt due to input in form of embeds. 404 | Generation behavior can be controlled by `args` and `kwargs`, read 405 | `GenerationMixin.generate` for more info. 
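Putting it together, inference follows the two-step scheme described above: the prompt embeddings with spliced-in protein/graph representations are prepared first, then generation continues from them. A hedged sketch, assuming `batch` comes from an inference-mode dataloader and that the graph tensors carry the usual PyG batch attribute names:

generated_ids = model.generate(
    inputs=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
    protein_input_ids=batch["protein_input_ids"].to(device),
    protein_attention_mask=batch["protein_attention_mask"].to(device),
    graph_edge_index=batch["edge_index"].to(device),          # PyG-style key names assumed
    graph_edge_type=batch["edge_type"].to(device),
    graph_batch=batch["batch"].to(device),
    max_new_tokens=256,                                       # generation settings are illustrative
    do_sample=False,
)
# the prompt is consumed as embeddings, so the output contains only the generated description
descriptions = llama_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)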
406 | """ 407 | # get decoder inputs 408 | prompt_inputs_embeds, prompt_attention_mask = self( 409 | input_ids=inputs, 410 | attention_mask=attention_mask, 411 | protein_input_ids=protein_input_ids, 412 | protein_attention_mask=protein_attention_mask, 413 | protein_inputs_embeds=protein_inputs_embeds, 414 | graph_edge_index=graph_edge_index, 415 | graph_edge_type=graph_edge_type, 416 | graph_batch=graph_batch, 417 | use_cache=False, 418 | output_attentions=False, 419 | output_hidden_states=False, 420 | return_dict=False, 421 | return_decoder_inputs=True 422 | ) 423 | # do generate on llama_decoder 424 | return self.llama_decoder.generate( 425 | inputs_embeds=prompt_inputs_embeds, 426 | attention_mask=prompt_attention_mask, 427 | **kwargs 428 | ) 429 | 430 | def gradient_checkpointing_enable(self): 431 | """ 432 | Enable gradient checkpointing for all submodules that support it. 433 | Attention! Model need to be in train mode before calling this method. 434 | """ 435 | if hasattr(self.esm_encoder, "gradient_checkpointing_enable"): 436 | self.esm_encoder.gradient_checkpointing_enable() 437 | if hasattr(self.llama_decoder, "gradient_checkpointing_enable"): 438 | self.llama_decoder.gradient_checkpointing_enable() 439 | # simple adapter no need to implement gradient checkpointing 440 | 441 | def gradient_checkpointing_disable(self): 442 | if hasattr(self.esm_encoder, "gradient_checkpointing_disable"): 443 | self.esm_encoder.gradient_checkpointing_disable() 444 | if hasattr(self.llama_decoder, "gradient_checkpointing_disable"): 445 | self.llama_decoder.gradient_checkpointing_disable() 446 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset class for protein function prediction instruction training. 3 | 4 | Every sample from the dataset is of following attributes: 5 | - graph related features if not ignored: 6 | - x: (num_nodes, num_node_features) 7 | - edge_index: (2, num_edges) 8 | - edge_type: (num_edges, ) 9 | - input ids: 10 | - sequence_input_ids: (1, sequence_length+2) # bos and eos tokens 11 | - prompt_input_ids: (1, prompt_length) # with generation prompt head 12 | - description_input_ids: (1, description_length+1) # eos token only 13 | - other attributes: 14 | - name: str 15 | 16 | * No attention mask will be produced at this stage and a dynamic padding will be 17 | applied in the collate function of the dataloader. 18 | 19 | A chat template is applied in such process to assemble full name and taxon of 20 | every protein and leave space for sequence embeddings with placeholders. 21 | The assembled prompt will be of following structure: 22 | ( 23 | <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n 24 | You are a scientific assistant specialized in ... 25 | <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n 26 | Name: ; Taxon: ; Sequence embeddings: ... 27 | <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n 28 | ) 29 | And the description of the protein will be of following structure: 30 | ( 31 | Involved in ... Required for ... <|eot_id|> 32 | ) 33 | The template is designed for Meta-Llama-3.1-8B-Instruct-hf and can be adapted 34 | to other models by calling `process_text` function with the new tokenizer. 
35 | 36 | Example of usage: 37 | >>> from transformers import AutoTokenizer 38 | >>> from dataset import Prot2TextInstructDataset 39 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D") 40 | >>> llama_tokenizer = AutoTokenizer.from_pretrained( 41 | "/data/Meta-Llama-3.1-8B-Instruct-hf", 42 | pad_token='<|reserved_special_token_0|>' 43 | ) 44 | >>> train_dataset = Prot2TextInstructDataset( 45 | root_dir="/data/Prot2Text-Llama3-Data/train", 46 | csv_path="./data/train.csv", 47 | sequence_tokenizer=esm_tokenizer, 48 | description_tokenizer=llama_tokenizer, 49 | skip_download=False, # download PDB files 50 | skip_reload=False, # construct graph and tokenize text 51 | ) 52 | """ 53 | 54 | import multiprocessing as mp 55 | import os 56 | import sys 57 | from typing import Any, Dict, List, Optional, Union 58 | 59 | from graphein.protein.config import ProteinGraphConfig 60 | import pandas as pd 61 | import torch 62 | import torch.utils.data 63 | import torch_geometric 64 | import torch_geometric.data 65 | from transformers import PreTrainedTokenizer 66 | from tqdm import tqdm 67 | import wget 68 | 69 | from .pdb2nx import construct_nx_graph 70 | from .nx2pyg import convert_nx_to_pyg 71 | from .utils_dataset import default_graph_process_config 72 | 73 | 74 | class Prot2TextInstructDataset(torch_geometric.data.Dataset): 75 | """ 76 | Dataset class for proteins. 77 | 78 | (1) Download PDB files from AlphaFoldDB if not skipped, then 79 | (2) preprocess graph and textual features if not skipped to prepare for the 80 | formation of :class:`Prot2TextDataLoader`. 81 | 82 | * Use multiple :class:`Prot2TextDataset` instead of one for multiple 83 | splits of the dataset. 84 | * root_dir should be the root directory of each split. 85 | 86 | Args: 87 | root_dir: 88 | Root directory of the dataset where root_dir/raw and 89 | root_dir/processed will be prepared. If the dataset is 90 | of a particular split of the whole dataset, the directory 91 | of the split should be passed instead. 92 | csv_path: 93 | Path of the csv file containing the uniprot ids, amino-acid 94 | sequence and the description. 95 | sequence_tokenizer: 96 | Tokenizer for protein amino-acid sequences. Tokenization 97 | will be done in preprocessing without pad token, and padding 98 | will be added in dataloader when forming data batch. The 99 | tokenizer should have :attr:`pad_token` and :attr:`pad_token_id` 100 | attributes which will be used in collate function of the 101 | dataloader for dynamic padding. 102 | description_tokenizer: 103 | Tokenizer for protein functionality descriptions. The 104 | tokenizer should add bos token but not add eos token by 105 | itself for the consistency with prompt tokenization. However, 106 | the tokenizer should have :attr:`eos_token` and 107 | :attr:`eos_token_id` attributes which will be used in 108 | preprocessing to add eos token at the end of the description 109 | and the label. The tokenizer should also have :attr:`pad_token` 110 | and :attr:`pad_token_id` attributes which will be used in 111 | collate function of the dataloader for dynamic padding. 112 | alphafold_base_url: 113 | Base download URL of AlphaFoldDB to download PDB files from. 114 | The full download link will be :url:`{alphafold_base_url}/AF- 115 | {uniprot_id}-F1-model_v{alphafold_version}.pdb`. 116 | alphafold_version: 117 | Version of the AlphaFoldDB to download PDB files from. 
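For a concrete example of the download URL described above (the UniProt accession is illustrative):

# sketch of the per-protein download URL used by the download() method below
alphafold_base_url = "https://alphafold.ebi.ac.uk/files/"
alphafold_version = 4
uniprot_id = "P12345"                                         # illustrative accession
url = f"{alphafold_base_url}AF-{uniprot_id}-F1-model_v{alphafold_version}.pdb"
# -> https://alphafold.ebi.ac.uk/files/AF-P12345-F1-model_v4.pdb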
118 |         graph_process_config:
119 |             Configuration specifying features of nodes, edges and the
120 |             whole graph to be used for graph construction from PDB files.
121 |             If not passed, the default configuration will be used.
122 |         skip_download:
123 |             Force preprocessing to skip the downloading procedure and use
124 |             existing PDB files only. Otherwise, downloading
125 |             will be skipped only if every UniProt id from the CSV file has a
126 |             corresponding PDB file under `root_dir/raw`.
127 |         skip_reload:
128 |             Force preprocessing to skip the formation of graphs from PDB files
129 |             and use existing PyG tensor files only. Otherwise, processing
130 |             will be skipped only if every PDB file under
131 |             `root_dir/raw` has a corresponding PyG tensor file under
132 |             `root_dir/processed`.
133 |         num_processes:
134 |             Number of parallel processes used for graph construction. If
135 |             not specified, all logical CPU threads will be used by default.
136 |         ignore_graph_features:
137 |             Ignore graph-related features while loading data. This does
138 |             not affect the preprocessing behavior: graph-related features
139 |             will still be processed and saved.
140 |         max_sequence_length:
141 |             Maximum length of the protein amino-acid sequence. Samples with
142 |             longer sequences will be trimmed if such a value is passed. This
143 |             loses part of the information but can avoid
144 |             out-of-memory errors in training. It does not affect the
145 |             preprocessing behavior: whole sequences will be processed
146 |             and saved.
147 |         max_description_length:
148 |             Maximum length of the protein description. Samples with longer
149 |             descriptions will be trimmed if such a value is passed. This
150 |             loses part of the information but can avoid
151 |             out-of-memory errors in training. It does not affect the
152 |             preprocessing behavior: whole descriptions will be processed
153 |             and saved.
154 |         system_message:
155 |             The system message to be used in the chat template.
156 |         placeholder_token:
157 |             The placeholder token used in the chat template that will be
158 |             replaced by the encoder output serving as sequence embeddings.
159 |         kwargs:
160 |             Additional arguments controlling the behavior of the PyG
161 |             dataset, inherited from the base dataset class. Read the
162 |             documentation of :class:`torch_geometric.data.Dataset`
163 |             for more information.
164 |     """
165 |     def __init__(
166 |         self,
167 |         root_dir: Union[str, os.PathLike],
168 |         csv_path: Union[str, os.PathLike],
169 |         sequence_tokenizer: PreTrainedTokenizer,
170 |         description_tokenizer: PreTrainedTokenizer,
171 |         alphafold_base_url: str = "https://alphafold.ebi.ac.uk/files/",
172 |         alphafold_version: int = 4,
173 |         graph_process_config: ProteinGraphConfig = default_graph_process_config,
174 |         skip_download: bool = False,
175 |         skip_reload: bool = False,
176 |         num_processes: Optional[int] = None,
177 |         ignore_graph_features: bool = False,
178 |         max_sequence_length: Optional[int] = 1021,
179 |         max_description_length: Optional[int] = 512,
180 |         system_message: str = (
181 |             "You are a scientific assistant specialized in protein function "
182 |             "predictions. Given the sequence embeddings and other information "
183 |             "of a protein, describe its function clearly and concisely in "
184 |             "professional language. "
185 |         ),
186 |         placeholder_token: str = '<|reserved_special_token_1|>',
187 |         **kwargs,
188 |     ):
189 |         self.root_dir = root_dir
190 |         self.uniprot_df = pd.read_csv(csv_path)
191 |         self.sequence_tokenizer = sequence_tokenizer
192 |         self.description_tokenizer = description_tokenizer
193 |         self.alphafold_base_url = alphafold_base_url
194 |         self.alphafold_version = alphafold_version
195 |         self.graph_process_config = graph_process_config
196 |         self.skip_download = skip_download
197 |         self.skip_reload = skip_reload
198 |         self.num_processes = num_processes
199 |         self.ignore_graph_features = ignore_graph_features
200 |         self.max_sequence_length = max_sequence_length
201 |         self.max_description_length = max_description_length
202 |         self.system_message = system_message
203 |         self.placeholder_token = placeholder_token
204 | 
205 |         self.usable_file_names: List[Union[str, os.PathLike]] = []
206 |         super().__init__(root=root_dir, **kwargs)  # first download then process
207 |         self.update_usable_file_names()
208 | 
209 |     def download(self, overwrite_existing: bool = False):
210 |         """
211 |         Downloads PDB files from AlphaFoldDB to the :attr:`self.raw_dir` folder.
212 |         If a file already exists, its download will be skipped. An unsuccessful
213 |         download attempt will not crash the script, but the exception will be
214 |         reported.
215 |         """
216 |         if self.skip_download:
217 |             return
218 |         assert self.alphafold_base_url is not None, (
219 |             "Downloading requested but base URL of AlphaFoldDB is not set. "
220 |         )
221 |         assert self.alphafold_version is not None, (
222 |             "Downloading requested but version of AlphaFoldDB is not set. "
223 |         )
224 |         for raw_file_name in tqdm(self.raw_file_names):
225 |             # build the local path and remote URL, then download if needed
226 |             raw_file_path = os.path.join(self.raw_dir, raw_file_name)
227 |             full_url = self.alphafold_base_url + raw_file_name
228 |             if overwrite_existing or not os.path.exists(raw_file_path):
229 |                 try:
230 |                     wget.download(full_url, out=raw_file_path)
231 |                 except Exception as exception:
232 |                     if self.log:
233 |                         print(
234 |                             f"Download of {raw_file_name} failed due to exception: "
235 |                             f"{exception.__class__.__name__}: {exception}",
236 |                             file=sys.stderr
237 |                         )
238 | 
239 |     def process(self, overwrite_existing: bool = False):
240 |         """
241 |         Preprocesses and converts PDB files to PyTorch tensor files. Graph
242 |         construction is parallelized with multiprocessing. An unsuccessful
243 |         graph processing attempt will not crash the script, but the exception
244 |         will be reported. Textual features will be encoded and added to the
245 |         tensor files at the end without parallelism.
246 |         """
247 |         if self.skip_reload:
248 |             return
249 |         assert self.sequence_tokenizer is not None, (
250 |             "Processing requested but sequence tokenizer is not set."
251 |         )
252 |         assert self.description_tokenizer is not None, (
253 |             "Processing requested but description tokenizer is not set."
254 |         )
255 | 
256 |         # graph construction: convert PDB files to tensor files
257 |         with mp.Pool(processes=self.num_processes) as pool:
258 |             for raw_file_name in self.raw_file_names:
259 |                 raw_file_path = os.path.join(self.raw_dir, raw_file_name)
260 |                 processed_file_path = os.path.join(
261 |                     self.processed_dir,
262 |                     os.path.splitext(raw_file_name)[0] + ".pt"
263 |                 )
264 |                 if overwrite_existing or not os.path.exists(processed_file_path):
265 |                     pool.apply_async(
266 |                         self.process_graph,
267 |                         args=(raw_file_path, processed_file_path)
268 |                     )
269 |             pool.close()
270 |             pool.join()
271 |         self.update_usable_file_names()
272 | 
273 |         # textual feature tokenization: add tokenized features to tensor files
274 |         self.process_text()
275 |         self.update_usable_file_names()
276 | 
277 |     def process_graph(
278 |         self,
279 |         raw_file_path: Union[str, os.PathLike],
280 |         processed_file_path: Union[str, os.PathLike]
281 |     ):
282 |         """
283 |         Preprocesses and converts a single PDB file to a PyTorch tensor file.
284 |         Unit function for multiprocessing.
285 |         """
286 |         uniprot_id = os.path.split(raw_file_path)[1].split("-")[1]
287 |         try:
288 |             nx_graph = construct_nx_graph(
289 |                 config=self.graph_process_config,
290 |                 pdb_path=raw_file_path
291 |             )
292 |             data = convert_nx_to_pyg(nx_graph)
293 |             torch.save(data, processed_file_path)
294 |         except Exception as exception:
295 |             if self.log:
296 |                 print(
297 |                     f"Graph processing of {uniprot_id} failed due to exception: "
298 |                     f"{exception.__class__.__name__}: {exception}",
299 |                     file=sys.stderr
300 |                 )
301 | 
302 |     def process_text(
303 |         self,
304 |         new_sequence_tokenizer: Optional[PreTrainedTokenizer] = None,
305 |         new_description_tokenizer: Optional[PreTrainedTokenizer] = None
306 |     ):
307 |         """
308 |         Applies the tokenizers to add tokenized protein amino-acid sequences and
309 |         protein functionality descriptions to the processed PyTorch tensor files.
310 |         This will not download and process graphs again, but only
311 |         add/modify the textual features in the processed Data objects. Can be
312 |         used to change tokenizers on an existing processed dataset.
313 |         """
314 |         if new_sequence_tokenizer is not None:
315 |             self.sequence_tokenizer = new_sequence_tokenizer
316 |         if new_description_tokenizer is not None:
317 |             self.description_tokenizer = new_description_tokenizer
318 |         for processed_file_name in tqdm(self.usable_file_names):
319 |             processed_file_path = os.path.join(self.processed_dir, processed_file_name)
320 |             try:
321 |                 uniprot_id = os.path.split(processed_file_name)[1].split("-")[1]
322 |                 data = torch.load(processed_file_path)
323 |                 data.update(self._compose_and_tokenize_chat(uniprot_id))
324 |                 torch.save(data, processed_file_path)
325 |             except Exception as exception:
326 |                 if self.log:
327 |                     print(
328 |                         f"Text processing of {processed_file_name} failed due to exception: "
329 |                         f"{exception.__class__.__name__}: {exception}",
330 |                         file=sys.stderr
331 |                     )
332 | 
333 |     def _compose_and_tokenize_chat(self, uniprot_id: str) -> Dict[str, torch.Tensor]:
334 |         """
335 |         (1) Matches the corresponding row in the CSV DataFrame for the given
336 |         UniProt id, (2) trims the sequence and description to avoid OOM errors,
337 |         (3) applies the chat template to form the prompt with placeholders, and
338 |         (4) tokenizes the amino-acid sequence, prompt and description.
339 | 
340 |         The eos token will be appended to the end of the description (which
341 |         serves as the label) before tokenization.
342 |         """
343 |         # (1) match row and extract features
344 |         filtered_df = self.uniprot_df.loc[
345 |             self.uniprot_df['AlphaFoldDB'] == uniprot_id
346 |         ]
347 |         sequence = filtered_df["sequence"].values[0]
348 |         description = filtered_df["function"].values[0]
349 |         fullname = filtered_df["Full Name"].values[0]
350 |         taxon = filtered_df["taxon"].values[0]
351 |         fullname = "unknown" if pd.isna(fullname) else fullname
352 |         taxon = "unknown" if pd.isna(taxon) else taxon
353 | 
354 |         # (2) trim sequence and description
355 |         if self.max_description_length is not None:
356 |             description_ids = self.description_tokenizer(
357 |                 [description],
358 |                 add_special_tokens=False,
359 |                 return_tensors="pt"
360 |             )["input_ids"]
361 |             if description_ids.size(-1) > self.max_description_length:
362 |                 description_ids = description_ids[:, :self.max_description_length]
363 |                 description = self.description_tokenizer.decode(description_ids[0])
364 |         if self.max_sequence_length is not None:
365 |             if len(sequence) > self.max_sequence_length:
366 |                 sequence = sequence[:self.max_sequence_length]
367 | 
368 |         # (3) apply chat template then tokenize the prompt
369 |         user_message = (
370 |             "Protein name: " + fullname
371 |             + " ; Taxon: " + taxon
372 |             + " ; Sequence embeddings: "
373 |             + self.placeholder_token * (len(sequence) + 2)  # bos and eos tokens
374 |         )
375 |         prompt_conversation = [
376 |             {"role": "system", "content": self.system_message},
377 |             {"role": "user", "content": user_message}
378 |         ]
379 |         prompt_ids = self.description_tokenizer.apply_chat_template(
380 |             prompt_conversation,
381 |             add_generation_prompt=True,
382 |             tokenize=True,
383 |             padding=False,
384 |             return_tensors="pt"
385 |         )
386 | 
387 |         # (4) tokenize sequence and description
388 |         sequence_ids = self.sequence_tokenizer(
389 |             [sequence],
390 |             add_special_tokens=True,
391 |             return_attention_mask=False,
392 |             return_tensors="pt"
393 |         )["input_ids"]
394 |         description_ids = self.description_tokenizer(
395 |             [description + self.description_tokenizer.eos_token],
396 |             add_special_tokens=False,
397 |             return_attention_mask=False,
398 |             return_tensors="pt"
399 |         )["input_ids"]
400 | 
401 |         # tensors are kept of shape (1, seq_length) so that the dataloader can
402 |         # form batches with dynamic padding
403 |         return {
404 |             "sequence_input_ids": sequence_ids,
405 |             "prompt_input_ids": prompt_ids,
406 |             "description_input_ids": description_ids,
407 |         }
408 | 
409 |     @property
410 |     def raw_file_names(self) -> List[str]:
411 |         """
412 |         The names of files in the `self.raw_dir` folder that must be present in
413 |         order to skip downloading. Required in the parent class.
414 |         """
415 |         uniprot_ids = set(self.uniprot_df.AlphaFoldDB)
416 |         return [
417 |             f"AF-{uniprot_id}-F1-model_v{self.alphafold_version}.pdb"
418 |             for uniprot_id in uniprot_ids
419 |         ]
420 | 
421 |     @property
422 |     def processed_file_names(self) -> List[str]:
423 |         """
424 |         The names of files in the `self.processed_dir` folder that must be present
425 |         in order to skip processing. Required in the parent class.
426 |         """
427 |         return [
428 |             os.path.splitext(raw_file_name)[0] + ".pt"
429 |             for raw_file_name in os.listdir(self.raw_dir)
430 |         ]
431 | 
432 |     def update_usable_file_names(self):
433 |         """
434 |         Updates the stored names of files in the `self.processed_dir` folder that
435 |         are ready to be used by the dataset. Usable file names should be updated
436 |         during initialization and preprocessing, but updating them during actual
437 |         usage is not recommended, to keep the order of file names consistent.
438 |         """
439 |         existing_file_names = os.listdir(self.processed_dir)
440 |         for special_file_name in ['pre_transform.pt', 'pre_filter.pt']:
441 |             if special_file_name in existing_file_names:
442 |                 existing_file_names.remove(special_file_name)
443 |         self.usable_file_names = sorted(existing_file_names)
444 | 
445 |     def len(self) -> int:
446 |         """Required in the parent class."""
447 |         return len(self.usable_file_names)
448 | 
449 |     def __len__(self) -> int:
450 |         return self.len()
451 | 
452 |     def get(self, idx: int, debug_mode: bool = False) -> Any:
453 |         data = torch.load(os.path.join(self.processed_dir, self.usable_file_names[idx]))
454 |         if debug_mode:  # return all items stored in the file
455 |             return data
456 |         if self.ignore_graph_features:
457 |             return torch_geometric.data.Data(
458 |                 sequence_input_ids=data['sequence_input_ids'],
459 |                 prompt_input_ids=data["prompt_input_ids"],
460 |                 description_input_ids=data['description_input_ids'],
461 |                 name=data["name"],
462 |             )
463 |         else:
464 |             return torch_geometric.data.Data(
465 |                 x=data['x'],
466 |                 edge_index=data['edge_index'],
467 |                 edge_type=data['edge_type'],
468 |                 sequence_input_ids=data['sequence_input_ids'],
469 |                 prompt_input_ids=data["prompt_input_ids"],
470 |                 description_input_ids=data['description_input_ids'],
471 |                 name=data["name"],
472 |             )
473 | 
--------------------------------------------------------------------------------
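
For reference, below is a minimal collate sketch, not part of the repository (the repository's own batching lives in Prot2TextInstructDataLoader under dataset/dataloader.py), showing how the per-sample (1, L) id tensors produced by Prot2TextInstructDataset could be dynamically padded into batch tensors. Graph features are ignored here, right-padding is assumed for simplicity, and the helper names and output keys (`protein_input_ids`, `protein_attention_mask`, ...) are illustrative choices rather than the project's actual interface.

from typing import Any, Dict, Sequence

import torch


def pad_and_stack(ids: Sequence[torch.Tensor], pad_id: int) -> Dict[str, torch.Tensor]:
    """Right-pad a list of (1, L_i) id tensors to the longest L_i and stack them."""
    max_len = max(t.size(-1) for t in ids)
    input_ids = torch.full((len(ids), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(ids), max_len), dtype=torch.long)
    for i, t in enumerate(ids):
        length = t.size(-1)
        input_ids[i, :length] = t[0]
        attention_mask[i, :length] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def naive_collate(
    samples: Sequence[Any], esm_pad_id: int, llama_pad_id: int
) -> Dict[str, torch.Tensor]:
    """Collate dataset samples into padded batch tensors, ignoring graph features."""
    sequences = pad_and_stack([s.sequence_input_ids for s in samples], esm_pad_id)
    prompts = pad_and_stack([s.prompt_input_ids for s in samples], llama_pad_id)
    descriptions = pad_and_stack([s.description_input_ids for s in samples], llama_pad_id)
    return {
        "protein_input_ids": sequences["input_ids"],
        "protein_attention_mask": sequences["attention_mask"],
        "prompt_input_ids": prompts["input_ids"],
        "prompt_attention_mask": prompts["attention_mask"],
        "description_input_ids": descriptions["input_ids"],
        "description_attention_mask": descriptions["attention_mask"],
    }

A batch could then be built as, e.g., `naive_collate([train_dataset.get(i) for i in range(4)], esm_tokenizer.pad_token_id, llama_tokenizer.pad_token_id)`. A full collater would additionally batch the graph tensors (x, edge_index, edge_type) with torch_geometric and may pad prompts on the left before generation.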