├── scripts
│   ├── __init__.py
│   ├── utils_argparse.py
│   ├── benchmark.py
│   ├── generate_legacy.py
│   ├── generate_instruct.py
│   ├── generate_instruct_light.py
│   ├── train_legacy.py
│   └── train_instruct.py
├── figures
│   ├── model.png
│   └── training.png
├── .gitignore
├── dataset
│   ├── __init__.py
│   ├── utils_dataset.py
│   ├── utils_pdb2nx.py
│   ├── nx2pyg.py
│   ├── dataloader_light.py
│   ├── dataloader.py
│   ├── dataloader_derived.py
│   └── dataset.py
├── models
│   ├── __init__.py
│   ├── configuration_esm2rgcn2llama_instruct.py
│   ├── configuration_esm2llama_instruct.py
│   ├── configuration_esm2llama_legacy.py
│   ├── modeling_esm2llama_instruct.py
│   └── modeling_esm2rgcn2llama_instruct.py
├── LICENSE
└── README.md
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ColinFX/Prot2Text-V2/HEAD/figures/model.png
--------------------------------------------------------------------------------
/figures/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ColinFX/Prot2Text-V2/HEAD/figures/training.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.csv
3 | *.err
4 | *.log
5 | *.ipynb
6 | *.pt
7 | *.pth
8 | *tmp*
9 | .vscode
10 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import Prot2TextInstructDataset
2 | from .dataloader import Prot2TextInstructDataLoader
3 | from .dataloader_derived import Prot2TextDerivedDataLoader
4 | from .dataloader_light import Prot2TextLightDataset, Prot2TextLightCollater
5 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_esm2llama_instruct import ModalityAdapterConfig, Esm2LlamaInstructConfig
2 | from .configuration_esm2llama_legacy import EsmEncoderConfig, Esm2LlamaConfig
3 | from .configuration_esm2rgcn2llama_instruct import RgcnAdapterConfig, Esm2Rgcn2LlamaInstructConfig
4 | from .modeling_esm2llama_instruct import ModalityAdapter, Esm2LlamaInstructForCausalLM
5 | from .modeling_esm2llama_legacy import EsmEncoderModel, Esm2LlamaForCausalLM
6 | from .modeling_esm2rgcn2llama_instruct import RgcnAdapter, Esm2Rgcn2LlamaInstructForCausalLM
7 | from .modeling_reward import RewardModel
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Xiao FEI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/dataset/utils_dataset.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from graphein.protein.config import DSSPConfig, ProteinGraphConfig
4 | from graphein.protein.edges.distance import (
5 | add_peptide_bonds,
6 | add_hydrogen_bond_interactions,
7 | add_distance_threshold
8 | )
9 | from graphein.protein.features.nodes.amino_acid import (
10 | amino_acid_one_hot,
11 | meiler_embedding,
12 | expasy_protein_scale,
13 | hydrogen_bond_acceptor,
14 | hydrogen_bond_donor
15 | )
16 | from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure
17 |
18 |
19 | default_graph_process_config = ProteinGraphConfig(
20 | **{
21 | "node_metadata_functions": [
22 | amino_acid_one_hot,
23 | expasy_protein_scale,
24 | meiler_embedding,
25 | hydrogen_bond_acceptor,
26 | hydrogen_bond_donor
27 | ],
28 | "edge_construction_functions": [
29 | add_peptide_bonds,
30 | add_hydrogen_bond_interactions,
31 | partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.),
32 | ],
33 | "graph_metadata_functions": [asa, phi, psi, secondary_structure, rsa],
34 | "dssp_config": DSSPConfig(),
35 | }
36 | )
37 |
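# Illustrative note (not part of the original file): `default_graph_process_config` is
# meant to be passed as the `config` argument of Graphein's
# `graphein.protein.graphs.construct_graph` when building residue-level protein graphs
# from PDB/mmCIF files during dataset preprocessing.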
--------------------------------------------------------------------------------
/scripts/utils_argparse.py:
--------------------------------------------------------------------------------
1 | """Utilities for argparse configuration."""
2 |
3 | import torch
4 |
5 |
6 | def str2bool(string: str = "") -> bool:
7 | """
8 |     Override the default built-in bool() function to facilitate argparse configuration.
9 |
10 | Example of usage:
11 | >>> import argparse
12 |     >>> from scripts.utils_argparse import str2bool
13 | >>> argParser = argparse.ArgumentParser()
14 | >>> argParser.add_argument(
15 | "--debug_model",
16 | default=False,
17 | type=str2bool,
18 | help="boolean flag to debug model"
19 | )
20 | """
21 | negative_keywords = ["false", "no", "none", "negative", "off", "disable", "f", "0"]
22 | if not string or any([string.lower() == keyword for keyword in negative_keywords]):
23 | return False
24 | return True
25 |
26 |
27 | def str2dtype(string: str = "") -> torch.dtype:
28 | """
29 | Convert string to corresponding torch datatype to facilitate argparse
30 | configuration for autocast dtype.
31 | """
32 | if not string:
33 | return torch.float32
34 | string = string.lower()
35 | bf16_keywords = ["torch.bfloat16", "bfloat16", "bf16"]
36 | fp16_keywords = ["torch.float16", "float16", "fp16", "16", "half"]
37 | int8_keywords = ["torch.int8", "int8", "8"]
38 | int4_keywords = ["torch.int4", "int4", "4"]
39 | if any([string == keyword for keyword in bf16_keywords]):
40 | return torch.bfloat16
41 | elif any([string == keyword for keyword in fp16_keywords]):
42 | return torch.float16
43 | elif any([string == keyword for keyword in int8_keywords]):
44 | return torch.int8
45 | elif any([string == keyword for keyword in int4_keywords]):
46 | return torch.int4
47 | else:
48 | return torch.float32
49 |
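# Illustrative usage (not part of the original file), mirroring how the training and
# generation scripts register these converters with argparse:
#   argParser.add_argument("--debug_model", type=str2bool, default=False)
#   argParser.add_argument("--torch_dtype", type=str2dtype)  # e.g. "bf16" -> torch.bfloat16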
--------------------------------------------------------------------------------
/models/configuration_esm2rgcn2llama_instruct.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration class for the assembled Esm2Rgcn2LlamaInstructForCausalLM model.
3 |
4 | Esm2Rgcn2LlamaInstructConfig = EsmConfig + RgcnAdapterConfig + LlamaConfig
5 | """
6 |
7 |
8 | from transformers import EsmConfig, LlamaConfig, PretrainedConfig
9 |
10 |
11 | class RgcnAdapterConfig(PretrainedConfig):
12 | """Configuration class of the Relational Graph Convolutional Network adapter."""
13 | model_type = "rgcn_adapter" # unique identifier of the model
14 |
15 | def __init__(
16 | self,
17 | input_dim: int,
18 | intermediate_dim: int,
19 | output_dim: int,
20 |         n_relations: int = 7,  # default matches the 7 edge types defined in dataset/nx2pyg.py
21 | n_layers: int = 6,
22 | dropout_rate: float = 0.2,
23 | **kwargs
24 | ):
25 | super().__init__(**kwargs)
26 | self.input_dim = input_dim
27 | self.intermediate_dim = intermediate_dim
28 | self.output_dim = output_dim
29 | self.n_relations = n_relations
30 | self.n_layers = n_layers
31 | self.dropout_rate = dropout_rate
32 |
33 |
34 | class Esm2Rgcn2LlamaInstructConfig(PretrainedConfig):
35 | """
36 | Configuration class of Esm2Rgcn2LlamaInstructForCausalLM model.
37 | placeholder_id: Token id in chat template to be replaced by ESM embeddings.
38 | """
39 | model_type = "esm2rgcn2llama_instruct" # unique identifier of the model
40 |
41 | def __init__(
42 | self,
43 | # model components
44 | esm_config: EsmConfig,
45 | adapter_config: RgcnAdapterConfig,
46 | llama_config: LlamaConfig,
47 | # standalone attributes
48 | placeholder_id: int = 128003,
49 | **kwargs
50 | ):
51 | super().__init__(**kwargs)
52 | self.esm_config = esm_config
53 | self.adapter_config = adapter_config
54 | self.llama_config = llama_config
55 | self.placeholder_id = placeholder_id
56 |
--------------------------------------------------------------------------------
/dataset/utils_pdb2nx.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for parsing PDB/mmCIF protein structures during graph dataset preprocessing.
3 | """
4 | import numpy as np
5 | from biopandas.pdb import PandasPdb
6 |
7 |
8 | pdb_order = [
9 | "record_name",
10 | "atom_number",
11 | "blank_1",
12 | "atom_name",
13 | "alt_loc",
14 | "residue_name",
15 | "blank_2",
16 | "chain_id",
17 | "residue_number",
18 | "insertion",
19 | "blank_3",
20 | "x_coord",
21 | "y_coord",
22 | "z_coord",
23 | "occupancy",
24 | "b_factor",
25 | "blank_4",
26 | "segment_id",
27 | "element_symbol",
28 | "charge",
29 | "line_idx",
30 | ]
31 | mmcif_read = {
32 | "group_PDB": "record_name",
33 | "id": "atom_number",
34 | "auth_atom_id": "atom_name",
35 | "auth_comp_id": "residue_name",
36 | "auth_asym_id": "chain_id",
37 | "auth_seq_id": "residue_number",
38 | "Cartn_x": "x_coord",
39 | "Cartn_y": "y_coord",
40 | "Cartn_z": "z_coord",
41 | "occupancy": "occupancy",
42 | "B_iso_or_equiv": "b_factor",
43 | "type_symbol": "element_symbol",
44 | }
45 |
46 | nonefields = [
47 | "blank_1",
48 | "alt_loc",
49 | "blank_2",
50 | "insertion",
51 | "blank_3",
52 | "blank_4",
53 | "segment_id",
54 | "charge",
55 | "line_idx",
56 | ]
57 |
58 |
59 | def biopandas_mmcif2pdb(pandasmmcif, model_index=1):
60 | """Converts the ATOM and HETATM dataframes of PandasMmcif() to PandasPdb() format."""
61 | pandaspdb = PandasPdb()
62 | for a in ["ATOM", "HETATM"]:
63 | dfa = pandasmmcif.df[a]
64 | dfa = dfa.loc[dfa.pdbx_PDB_model_num == model_index]
65 | if a == 'ATOM':
66 | if len(dfa) == 0:
67 | raise ValueError(f"No model found for index: {model_index}")
68 | # keep only those fields found in pdb
69 | dfa = dfa[mmcif_read.keys()]
70 | # rename fields
71 | dfa = dfa.rename(columns=mmcif_read)
72 | # add empty fields
73 | for i in nonefields:
74 | dfa[i] = ""
75 | dfa["charge"] = np.nan
76 | # reorder columns to PandasPdb order
77 | dfa = dfa[pdb_order]
78 | pandaspdb.df[a] = dfa
79 |
80 | # update line_idx
81 | pandaspdb.df["ATOM"]["line_idx"] = pandaspdb.df["ATOM"].index.values
82 | pandaspdb.df["HETATM"]["line_idx"] = pandaspdb.df["HETATM"].index
83 |
84 | return pandaspdb
85 |
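# Illustrative usage (not part of the original file); the file name is a placeholder
# following the AlphaFoldDB naming pattern:
#   from biopandas.mmcif import PandasMmcif
#   pandaspdb = biopandas_mmcif2pdb(PandasMmcif().read_mmcif("AF-XXXXXX-F1-model_v4.cif"))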
--------------------------------------------------------------------------------
/models/configuration_esm2llama_instruct.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration class for the assembled Esm2LlamaInstructForCausalLM model.
3 |
4 | Esm2LlamaInstructConfig = EsmConfig + ModalityAdapterConfig + LlamaConfig
5 | """
6 |
7 | from typing import Dict, Optional, Union
8 | from transformers import EsmConfig, LlamaConfig, PretrainedConfig
9 |
10 |
11 | class ModalityAdapterConfig(PretrainedConfig):
12 | """Configuration class of the 2-layer non-linear adapter."""
13 | model_type = "modality_adapter" # unique identifier of the model
14 |
15 | def __init__(
16 | self,
17 | input_dim: int,
18 | intermediate_dim: int,
19 | output_dim: int,
20 | dropout_rate: float = 0.3,
21 | **kwargs
22 | ):
23 | super().__init__(**kwargs)
24 | self.input_dim = input_dim
25 | self.intermediate_dim = intermediate_dim
26 | self.output_dim = output_dim
27 | self.dropout_rate = dropout_rate
28 |
29 |
30 | class Esm2LlamaInstructConfig(PretrainedConfig):
31 | """
32 | Configuration class of Esm2LlamaInstructForCausalLM model.
33 | placeholder_id: Token id in chat template to be replaced by ESM embeddings.
34 | """
35 | model_type = "esm2llama_instruct" # unique identifier of the model
36 |
37 | def __init__(
38 | self,
39 | # model components
40 | esm_config: Optional[Union[EsmConfig, Dict]] = None,
41 | adapter_config: Optional[Union[ModalityAdapterConfig, Dict]] = None,
42 | llama_config: Optional[Union[LlamaConfig, Dict]] = None,
43 | # standalone attributes
44 | placeholder_id: int = 128003,
45 | **kwargs
46 | ):
47 | super().__init__(**kwargs)
48 |
49 | if isinstance(esm_config, dict):
50 | self.esm_config = EsmConfig(**esm_config)
51 | else:
52 | self.esm_config = esm_config
53 |
54 | if isinstance(llama_config, dict):
55 | self.llama_config = LlamaConfig(**llama_config)
56 | else:
57 | self.llama_config = llama_config
58 |
59 | if isinstance(adapter_config, dict):
60 | self.adapter_config = ModalityAdapterConfig(**adapter_config)
61 | else:
62 | self.adapter_config = adapter_config
63 |
64 | self.placeholder_id = placeholder_id
65 |
--------------------------------------------------------------------------------
/dataset/nx2pyg.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for converting protein structure graphs from NetworkX format to PyTorch Geometric `Data` objects.
3 | """
4 | import networkx as nx
5 | import numpy as np
6 | import torch
7 | import torch_geometric
8 |
9 |
10 | graph_features = ['phi', 'psi', 'rsa', 'asa', 'ss', 'expasy']
11 | map_secondary_structure = {'-': 0, 'H': 1, 'B': 2, 'E': 3, 'G': 4, 'I': 5, 'T': 6, 'S': 7}
12 | map_edge_types = {
13 | 'peptide_bond': 0,
14 | 'sequence_distance_2': 1,
15 | 'sequence_distance_3': 2,
16 | 'distance_threshold': 3,
17 | 'delaunay': 4,
18 | 'hbond': 5,
19 | 'k_nn': 6
20 | }
21 |
22 |
23 | def convert_nx_to_pyg(nx_graph: nx.Graph) -> torch_geometric.data.Data:
24 | """
25 | Converting Graphein Networks from `networkx.Graph` (nx) format to `torch_geometric.data.Data` (pytorch geometric,
26 | PyG) format.
27 | """
28 | # Initialise dict used to construct Data object
29 | data_dict = {"node_id": list(nx_graph.nodes())}
30 | nx_graph = nx.convert_node_labels_to_integers(nx_graph)
31 |
32 | # Construct Edge Index
33 | edge_index = torch.LongTensor(list(nx_graph.edges)).t().contiguous()
34 |
35 | # Add node features
36 | for i, (_, feat_dict) in enumerate(nx_graph.nodes(data=True)):
37 | for key, value in feat_dict.items():
38 | data_dict[str(key)] = [value] if i == 0 else data_dict[str(key)] + [value]
39 |
40 | # Add edge features
41 | for i, (_, _, feat_dict) in enumerate(nx_graph.edges(data=True)):
42 | for key, value in feat_dict.items():
43 | if key == 'distance':
44 | data_dict[str(key)] = [value] if i == 0 else data_dict[str(key)] + [value]
45 | else:
46 | data_dict[str(key)] = [list(value)] if i == 0 else data_dict[str(key)] + [list(value)]
47 |
48 | # Add graph-level features
49 | for feat_name in nx_graph.graph:
50 | data_dict[str(feat_name)] = [nx_graph.graph[feat_name]]
51 |
52 | data_dict["edge_index"] = edge_index.view(2, -1)
53 | data = torch_geometric.data.Data.from_dict(data_dict)
54 | data.num_nodes = nx_graph.number_of_nodes()
55 |
56 | # remove useless intermediate data and add features for deep learning models
57 | reformat_data = torch_geometric.data.Data(
58 | edge_index=data.edge_index,
59 | num_nodes=len(data.node_id),
60 | node_id=data.node_id,
61 | name=data.name[0],
62 | sequence=getattr(data, f"sequence_{data.chain_id[0]}"),
63 | distance_matrix=data.dist_mat,
64 | distance=data.distance,
65 | coordinates=torch.FloatTensor(np.array(data.coords[0]))
66 | )
67 |
68 | x = np.array([np.argmax(data.amino_acid_one_hot, axis=1)]).reshape(-1, 1)
69 | for feat in graph_features:
70 | if feat == "ss":
71 | feature = np.array([[map_secondary_structure.get(feat_node, 0)] for feat_node in data[feat]])
72 | else:
73 | feature = np.array(data[feat])
74 | if len(feature.shape) == 1:
75 | feature = feature.reshape(-1, 1)
76 | x = np.concatenate((x, feature), axis=1)
77 | reformat_data.x = torch.FloatTensor(x)
78 | reformat_data.edge_type = torch.LongTensor([map_edge_types[kind[0]] for kind in data.kind])
79 |
80 | return reformat_data
81 |
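# Illustrative note (not part of the original file): a Graphein graph built with
# `default_graph_process_config` from dataset/utils_dataset.py can be converted with
#   pyg_data = convert_nx_to_pyg(nx_graph)
# yielding node features `x`, relational `edge_type` and `edge_index` suitable for an RGCN encoder.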
--------------------------------------------------------------------------------
/models/configuration_esm2llama_legacy.py:
--------------------------------------------------------------------------------
1 | """
2 | Legacy configuration classes for Esm2LlamaModel, migrated from previous protein
3 | function description projects built on a decoder-only base language model
4 | architecture.
5 |
6 | Esm2LlamaConfig = LlamaConfig (+ EsmEncoderConfig as additional attribute)
7 | """
8 |
9 |
10 | import os
11 | from typing import Any, Dict, Optional, Union
12 |
13 | from transformers import EsmConfig, LlamaConfig
14 |
15 |
16 | class EsmEncoderConfig(EsmConfig):
17 | """Configuration class of EsmEncoderModel model."""
18 |
19 | def __init__(
20 | self,
21 | *args,
22 | decoder_hidden_size: Optional[int] = None,
23 | **kwargs
24 | ):
25 | super().__init__(*args, **kwargs)
26 | self.decoder_hidden_size: int = decoder_hidden_size
27 |
28 |
29 | class Esm2LlamaConfig(LlamaConfig):
30 | """
31 | Configuration class of Esm2LlamaModel model.
32 |
33 | Args:
34 | esm_config:
35 | Configuration to be used with the EsmEncoderConfig encoder
36 | configuration. The value is either an instance of EsmEncoderConfig
37 | or a dict of parameters to be passed to initialize the
38 | EsmEncoderConfig (read the documentation from `EsmEncoderConfig`
39 | for more information in this case). If not given, a default
40 | configuration is used.
41 | kwargs:
42 | Keyword arguments for initialization of LlamaConfig, read the
43 | documentation from `LlamaConfig` for more information. Parameters
44 | controlling the model outputs can also be passed, read the
45 | documentation from `PretrainedConfig` for more information.
46 | """
47 | def __init__(
48 | self,
49 | *args,
50 | esm_config: Optional[Union[EsmEncoderConfig, Dict[str, Any]]] = None,
51 | **kwargs
52 | ):
53 | # normal initialization of LlamaConfig with keyword arguments
54 | super().__init__(*args, **kwargs)
55 |
56 | # add self.esm_config: EsmEncoderConfig as extra attribute to LlamaConfig
57 | if esm_config is None or isinstance(esm_config, dict):
58 | self.esm_config = EsmEncoderConfig(**esm_config if esm_config else {})
59 | elif isinstance(esm_config, EsmEncoderConfig):
60 | self.esm_config = esm_config
61 | else:
62 | raise ValueError(
63 |                 "esm_config must be an EsmEncoderConfig, or a dict of "
64 | "initialization parameters. Use from_pretrained method instead "
65 | "if the esm_config shall be loaded from a pretrained model "
66 | "name or path. "
67 | )
68 |
69 | @classmethod
70 | def from_pretrained(
71 | cls,
72 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
73 | pretrained_esm_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
74 | pretrained_llama_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
75 | esm_kwargs: Optional[Dict[str, Any]] = None,
76 | **kwargs,
77 | ) -> "Esm2LlamaConfig":
78 | """
79 |         Instantiates an Esm2LlamaConfig from either (1) a pretrained
80 | Esm-to-Llama model, or (2) a pretrained LlamaForCausalLM and/or a
81 | pretrained EsmModel. The configuration of any unspecified parts of
82 | the model will be initialized with default values.
83 |
84 | return_unused_kwargs is currently not supported in the loading
85 | behavior. # TODO complete this case.
86 |
87 | Args:
88 | pretrained_model_name_or_path:
89 | Esm-to-Llama model name or path to load predefined Esm-to-Llama
90 | model configuration. If given, pretrained EsmModel and
91 | LlamaForCausalLM name or path will be ignored.
92 | pretrained_esm_model_name_or_path:
93 | Esm model name or path to load predefined EsmModel configuration
94 | as encoder part of the whole model.
95 | pretrained_llama_model_name_or_path:
96 | Llama model name or path to load predefined LlamaForCausalLM
97 | model configuration as decoder part of the whole model.
98 | esm_kwargs:
99 | Configuration attributes to override values in EsmEncoderConfig
100 | which is either loaded from pretrained or initialized with
101 | default values. Behavior concerning key/value pairs whose keys
102 | are not configuration attributes is controlled by the
103 | return_unused_kwargs keyword parameter. Parameters controlling
104 | the loading behaviors of Esm configuration such as `cache_dir`
105 | and `force_download` can also be passed if
106 | `pretrained_esm_model_name_or_path` is given, read the
107 | documentation from `PretrainedConfig.from_pretrained` for more
108 | information.
109 | kwargs:
110 | Configuration attributes to override the loaded values in
111 | Esm2LlamaConfig or LlamaConfig. Parameters controlling the
112 | loading behaviors of Esm-to-Llama or Llama configuration such
113 | as `cache_dir` and `force_download` can also be passed if
114 | `pretrained_model_name_or_path` or
115 | `pretrained_llama_model_name_or_path` is given.
116 | """
117 | # case (1): instantiate from a pretrained Esm-to-Llama model
118 | if pretrained_model_name_or_path is not None:
119 | config = super().from_pretrained(
120 | pretrained_model_name_or_path=pretrained_model_name_or_path,
121 | **kwargs
122 | )
123 | if esm_kwargs:
124 | config.esm_config.update(esm_kwargs)
125 |
126 | # case (2-1): instantiate from pretrained LlamaModel and EsmModel
127 | elif (
128 | pretrained_esm_model_name_or_path is not None
129 | and pretrained_llama_model_name_or_path is not None
130 | ):
131 | config = super().from_pretrained(
132 | pretrained_model_name_or_path=pretrained_llama_model_name_or_path,
133 | **kwargs
134 | )
135 | config.esm_config = EsmEncoderConfig.from_pretrained(
136 | pretrained_model_name_or_path=pretrained_esm_model_name_or_path,
137 | **esm_kwargs if esm_kwargs else {}
138 | )
139 |
140 | # case (2-2): instantiate from a pretrained EsmModel
141 | elif pretrained_esm_model_name_or_path is not None:
142 | esm_config = EsmEncoderConfig.from_pretrained(
143 | pretrained_model_name_or_path=pretrained_esm_model_name_or_path,
144 | **esm_kwargs if esm_kwargs else {}
145 | )
146 | config = cls(esm_config=esm_config, **kwargs)
147 |
148 | # case (2-3): instantiate from a pretrained LlamaModel
149 | elif pretrained_llama_model_name_or_path is not None:
150 | config = super().from_pretrained(
151 | pretrained_model_name_or_path=pretrained_llama_model_name_or_path,
152 | **kwargs
153 | )
154 | config.esm_config = EsmEncoderConfig(**esm_kwargs if esm_kwargs else {})
155 |
156 | else:
157 | raise ValueError(
158 | "Either pretrained name or path of Esm-to-Llama model, EsmModel "
159 | "or LlamaForCausalLM should be passed. Use initialization "
160 | "method instead if none of the above three can be provided. "
161 | )
162 | return config
163 |
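# Illustrative usage (not part of the original file), following case (2-1) above; the
# decoder_hidden_size value is a hypothetical choice matching the Llama hidden size:
#   config = Esm2LlamaConfig.from_pretrained(
#       pretrained_esm_model_name_or_path="facebook/esm2_t36_3B_UR50D",
#       pretrained_llama_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
#       esm_kwargs={"decoder_hidden_size": 4096},
#   )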
--------------------------------------------------------------------------------
/scripts/benchmark.py:
--------------------------------------------------------------------------------
1 | """
2 | Single-thread script to compute metrics including BLEU, ROUGE and BERT scores.
3 | Metrics will be computed on JSON files generated by `generate_instruct.py`.
4 | The script is designed for single-GPU computation.
5 | """
6 |
7 | import argparse
8 | import json
9 | import os
10 | import re
11 | from typing import Any, Dict, List
12 |
13 | import evaluate
14 | from transformers import BertTokenizer, RobertaTokenizer
15 | import scripts.utils_argparse as utils_argparse
16 |
17 |
18 | argParser = argparse.ArgumentParser()
19 |
20 | argParser.add_argument("--read_generation_dir", type=str)
21 | argParser.add_argument("--read_file_identifier", type=str, help="Postfix identifier or timestamp to filter files.")
22 |
23 | argParser.add_argument("--evaluate_exact_match", type=utils_argparse.str2bool)
24 | argParser.add_argument("--evaluate_bleu", type=utils_argparse.str2bool)
25 | argParser.add_argument("--evaluate_rouge", type=utils_argparse.str2bool)
26 | argParser.add_argument("--evaluate_bert_score", type=utils_argparse.str2bool)
27 | argParser.add_argument("--verbose", type=utils_argparse.str2bool)
28 |
29 |
30 | def compute_exact_match(predictions: List[str], references: List[str]) -> float:
31 | """Compute exact match ratio allowing for case and punctuation differences."""
32 | def normalize(text: str) -> str:
33 | """Normalize text by lowercasing and removing punctuation."""
34 | text = text.lower()
35 | text = re.sub(r'[^\w]', '', text)
36 | return text
37 |
38 | exact_match = 0
39 | for pred, ref in zip(predictions, references):
40 | if normalize(pred) == normalize(ref):
41 | exact_match += 1
42 | return exact_match / len(predictions)
43 |
44 |
45 | def compute_bleu2(predictions: List[str], references: List[str]) -> Dict[str, Any]:
46 | bleu = evaluate.load("bleu")
47 | return bleu.compute(predictions=predictions, references=references, max_order=2)
48 |
49 |
50 | def compute_bleu4(predictions: List[str], references: List[str]) -> Dict[str, Any]:
51 | bleu = evaluate.load("bleu")
52 | return bleu.compute(predictions=predictions, references=references)
53 |
54 |
55 | def compute_rouge(predictions: List[str], references: List[str]) -> Dict[str, Any]:
56 | rouge = evaluate.load("rouge")
57 | return rouge.compute(predictions=predictions, references=references)
58 |
59 |
60 | def compute_bert_score(predictions: List[str], references: List[str]) -> Dict[str, Dict[str, Any]]:
61 | """Compute BERT score on roberta-large and biobert-large respectively."""
62 | results: Dict[str, Dict[str, Any]] = {}
63 |
64 | tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-large")
65 | retokenized_predictions = tokenizer(
66 | predictions, padding="max_length", truncation=True, max_length=495, return_tensors="pt"
67 | )["input_ids"]
68 | truncated_predictions = tokenizer.batch_decode(retokenized_predictions, skip_special_tokens=True)
69 | retokenized_labels = tokenizer(
70 | references, padding="max_length", truncation=True, max_length=495, return_tensors="pt"
71 | )["input_ids"]
72 | truncated_labels = tokenizer.batch_decode(retokenized_labels, skip_special_tokens=True)
73 |
74 | bert = evaluate.load("bertscore")
75 | roberta_results = bert.compute(predictions=truncated_predictions, references=truncated_labels, lang="en")
76 | results["roberta-large"] = {
77 | "precision": sum(roberta_results["precision"]) / len(roberta_results["precision"]),
78 | "recall": sum(roberta_results["recall"]) / len(roberta_results["recall"]),
79 | "f1": sum(roberta_results["f1"]) / len(roberta_results["f1"])
80 | }
81 |
82 | # truncate sentences to fit max_position_embeddings=512 of biobert
83 | tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-large-cased-v1.1")
84 | retokenized_predictions = tokenizer(
85 | predictions, padding="max_length", truncation=True, max_length=495, return_tensors="pt"
86 | )["input_ids"]
87 | truncated_predictions = tokenizer.batch_decode(retokenized_predictions, skip_special_tokens=True)
88 | retokenized_labels = tokenizer(
89 | references, padding="max_length", truncation=True, max_length=495, return_tensors="pt"
90 | )["input_ids"]
91 | truncated_labels = tokenizer.batch_decode(retokenized_labels, skip_special_tokens=True)
92 |
93 | biobert_results = bert.compute(
94 | predictions=truncated_predictions,
95 | references=truncated_labels,
96 | model_type="dmis-lab/biobert-large-cased-v1.1",
97 | num_layers=24,
98 | )
99 | results["biobert-large"] = {
100 | "precision": sum(biobert_results["precision"]) / len(biobert_results["precision"]),
101 | "recall": sum(biobert_results["recall"]) / len(biobert_results["recall"]),
102 | "f1": sum(biobert_results["f1"]) / len(biobert_results["f1"])
103 | }
104 |
105 | return results
106 |
107 |
108 | def compute_metrics(predictions: List[str], references: List[str], args: Dict[str, Any]) -> Dict[str, Any]:
109 | """Compute BLEU, ROUGE, BERT scores and exact match ratio on given texts."""
110 | gathered_results: Dict[str, Dict[str, Any]] = {}
111 |
112 | if args["evaluate_exact_match"]:
113 | exact_match = compute_exact_match(predictions=predictions, references=references)
114 | gathered_results["exact_match"] = exact_match
115 | if args["verbose"]:
116 | print(f"EXACT match ratio: {exact_match}")
117 |
118 | if args["evaluate_bleu"]:
119 | bleu_results = compute_bleu2(predictions=predictions, references=references)
120 | gathered_results["bleu2"] = bleu_results
121 | if args["verbose"]:
122 | print(f"BLEU-2 score: {bleu_results}")
123 | bleu_results = compute_bleu4(predictions=predictions, references=references)
124 | gathered_results["bleu4"] = bleu_results
125 | if args["verbose"]:
126 | print(f"BLEU-4 score: {bleu_results}")
127 |
128 | if args["evaluate_rouge"]:
129 | rouge_results = compute_rouge(predictions=predictions, references=references)
130 | gathered_results["rouge"] = rouge_results
131 | if args["verbose"]:
132 | print(f"ROUGE score: {rouge_results}")
133 |
134 | if args["evaluate_bert_score"]:
135 | bert_results = compute_bert_score(predictions=predictions, references=references)
136 | gathered_results["bert"] = bert_results
137 | if args["verbose"]:
138 | for model_name, model_results in bert_results.items():
139 | print(f"BERT score with {model_name}: {model_results}")
140 |
141 | return gathered_results
142 |
143 |
144 | def benchmark(args: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
145 | """
146 |     Evaluate generation results on JSON files produced by the generation scripts (e.g. `generate_instruct.py`).
147 | 1) Gather predictions and labels from JSON files.
148 | 2) Compute metrics including BLEU, ROUGE and BERT scores and print results.
149 | """
150 | read_generation_paths = []
151 | for file_name in os.listdir(args["read_generation_dir"]):
152 | full_path = os.path.join(args["read_generation_dir"], file_name)
153 | if os.path.isfile(full_path) and args["read_file_identifier"] in full_path:
154 | read_generation_paths.append(full_path)
155 |
156 | gathered_predictions = []
157 | gathered_labels = []
158 | for read_path in read_generation_paths:
159 | with open(read_path, "r") as file:
160 | results = json.load(file)
161 | local_predictions = [results[name]["pred"] for name in results.keys()]
162 | local_labels = [results[name]["true"] for name in results.keys()]
163 | gathered_predictions.extend(local_predictions)
164 | gathered_labels.extend(local_labels)
165 | print(f"Reading {read_path}")
166 |
167 | return compute_metrics(predictions=gathered_predictions, references=gathered_labels, args=args)
168 |
169 |
170 | if __name__ == "__main__":
171 | parsed_args = argParser.parse_args()
172 |
173 | print("####################")
174 | for key, value in parsed_args.__dict__.items():
175 | print(f"{key}: {value}")
176 | print("####################")
177 |
178 | benchmark(parsed_args.__dict__)
179 |
--------------------------------------------------------------------------------
/scripts/generate_legacy.py:
--------------------------------------------------------------------------------
1 | """
2 | DistributedDataParallel generation script implemented from scratch.
3 | Generation results will be saved to separate JSON files, and metrics can be further computed with `benchmark.py`.
4 | The script is designed for multi-GPU parallelism on a single node.
5 | """
6 |
7 | import argparse
8 | from datetime import datetime
9 | import json
10 | import os
11 | from typing import Any, Dict, List
12 |
13 | import torch
14 | import torch.distributed as dist
15 | from torch.nn.parallel import DistributedDataParallel
16 | from torch.utils.data.distributed import DistributedSampler
17 | from tqdm import tqdm
18 | from transformers import AutoTokenizer, PreTrainedTokenizer
19 |
20 | from dataset import Prot2TextInstructDataset, Prot2TextDerivedDataLoader
21 | from .train_legacy import load_model, setup, cleanup
22 |
23 |
24 | argParser = argparse.ArgumentParser()
25 |
26 | argParser.add_argument("--esm_path", type=str)
27 | argParser.add_argument("--llama_path", type=str)
28 | argParser.add_argument("--root_dataset_dir", type=str)
29 | argParser.add_argument("--root_csv_dir", type=str)
30 | argParser.add_argument("--save_generation_dir", type=str)
31 | argParser.add_argument("--save_generation_postfix_identifier", type=str, default=None)
32 | argParser.add_argument("--load_general_checkpoint_path", type=str, default="")
33 |
34 | argParser.add_argument("--batch_size_per_device", type=int)
35 | argParser.add_argument("--random_seed", type=int)
36 | argParser.add_argument("--generate_split", type=str)
37 | argParser.add_argument("--debug_trim_generate_split", type=int, default=None)
38 | argParser.add_argument("--max_sequence_length", type=int, default=None)
39 | argParser.add_argument("--max_generation_length", type=int)
40 | argParser.add_argument("--num_beams", type=int, default=1)
41 | argParser.add_argument("--length_penalty", type=float, default=1.0)
42 |
43 |
44 | def iterative_generation_loop(
45 | rank: int,
46 | model: torch.nn.Module,
47 | data_batch: Dict[str, Any],
48 | max_generation_length: int,
49 | num_beams: int,
50 | length_penalty: float
51 | ) -> torch.Tensor:
52 | """
53 | Standard API for different models. Used in `inference_epoch`.
54 | 1) Prepare inputs for the generation cycle with inference using data_batch from dataloader.
55 | 2) Execute the generation cycle and return the direct output.
56 | Returned output is a `torch.Tensor` of the generated tokens.
57 | """
58 | if isinstance(model, DistributedDataParallel):
59 | model = model.module # for wrapper models, get the inner model for generation
60 |
61 | return model.generate(
62 | inputs=data_batch["input_ids"].to(rank),
63 | attention_mask=data_batch["attention_mask"].to(rank),
64 | protein_input_ids=data_batch["protein_input_ids"].to(rank),
65 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank),
66 | max_new_tokens=max_generation_length,
67 | eos_token_id=128001,
68 | pad_token_id=128002,
69 | return_dict_in_generate=False,
70 | num_beams=num_beams,
71 | length_penalty=length_penalty,
72 | )
73 |
74 |
75 | def inference_epoch(
76 | rank: int,
77 | model: DistributedDataParallel,
78 | dataloader: Prot2TextDerivedDataLoader,
79 | llama_tokenizer: PreTrainedTokenizer,
80 | args: Dict[str, Any]
81 | ):
82 | """
83 | Iterate over all batches for inference with iterative loop.
84 | Generation results will be saved to JSON files.
85 | """
86 | model.eval()
87 | local_names: List[str] = []
88 | local_predictions: List[str] = []
89 | local_labels: List[str] = []
90 |
91 | # core loop for batches
92 | t = tqdm(iter(dataloader))
93 | for data_batch in t:
94 | with torch.no_grad():
95 | output = iterative_generation_loop(
96 | rank=rank,
97 | model=model,
98 | data_batch=data_batch,
99 | max_generation_length=args["max_generation_length"],
100 | num_beams=args["num_beams"],
101 | length_penalty=args["length_penalty"]
102 | )
103 | local_names.extend(data_batch["name"])
104 | predicted_texts = llama_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)
105 | local_predictions.extend(predicted_texts)
106 | label_texts = llama_tokenizer.batch_decode(data_batch["description_input_ids"], skip_special_tokens=True)
107 | local_labels.extend(label_texts)
108 | t.set_postfix({
109 | "mode": "inference",
110 | "batch_maxlen_gen": output.shape[1],
111 | "device": f"rank:{rank}"
112 | })
113 |
114 | local_json_path = os.path.join(
115 | args["save_generation_dir"],
116 | f"generation_{args['save_generation_postfix_identifier']}_rank{rank}.json"
117 | )
118 | with open(local_json_path, "w") as file:
119 | json_dict = {
120 | name: {"true": label, "pred": prediction}
121 | for name, label, prediction in zip(local_names, local_labels, local_predictions)
122 | }
123 | json.dump(json_dict, file, indent=4)
124 | print(f"Saving {local_json_path}")
125 |
126 |
127 |
128 | def inference_on_device(rank: int, world_size: int, args: Dict[str, Any]):
129 | """Core generation process for every device with batches over the whole dataset"""
130 | setup(rank, world_size)
131 |
132 | # prepare dataset and dataloader
133 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"])
134 | llama_tokenizer = AutoTokenizer.from_pretrained(args["llama_path"], pad_token='<|reserved_special_token_0|>')
135 | generate_dataset = Prot2TextInstructDataset(
136 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['generate_split']}"),
137 | csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv"),
138 | sequence_tokenizer=esm_tokenizer,
139 | description_tokenizer=llama_tokenizer,
140 | skip_reload=True,
141 | skip_download=True,
142 | ignore_graph_features=True,
143 | max_sequence_length=args["max_sequence_length"],
144 | max_description_length=None, # fetch complete labels for metric computation
145 | )
146 | if args["debug_trim_generate_split"]:
147 | generate_dataset.usable_file_names = generate_dataset.usable_file_names[:args["debug_trim_generate_split"]]
148 | generate_sampler = DistributedSampler(generate_dataset, rank=rank, num_replicas=world_size, shuffle=False)
149 | generate_loader = Prot2TextDerivedDataLoader(
150 | generate_dataset,
151 | mode="inference",
152 | batch_size=args["batch_size_per_device"],
153 | sampler=generate_sampler,
154 | num_workers=2,
155 | pin_memory=True,
156 | shuffle=False,
157 | drop_last=True,
158 | )
159 |
160 | # load base model and then the checkpoint
161 | model = load_model(
162 | esm_path=args["esm_path"],
163 | llama_path=args["llama_path"],
164 | load_general_checkpoint_path=args["load_general_checkpoint_path"]
165 | )
166 | model = model.to(rank)
167 | model = DistributedDataParallel(model)
168 | print(f"Model loaded on rank{rank}")
169 |
170 | inference_epoch(
171 | rank=rank,
172 | model=model,
173 | dataloader=generate_loader,
174 | llama_tokenizer=llama_tokenizer,
175 | args=args
176 | )
177 | # use a barrier to make sure that all processes have finished writing their JSON files
178 | dist.barrier()
179 |
180 | cleanup()
181 |
182 |
183 | def inference_distributed(args: Dict[str, Any]):
184 | """Core generation process across multiple devices with batches over the whole dataset"""
185 | torch.multiprocessing.spawn(
186 | inference_on_device,
187 | args=(args["world_size"], args),
188 | nprocs=args["world_size"],
189 | join=True
190 | )
191 |
192 |
193 | if __name__ == "__main__":
194 | # suppress messages from AutoTokenizer parallelism and Graphein respectively
195 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
196 | os.environ["LOGURU_LEVEL"] = "INFO"
197 |
198 | parsed_args = argParser.parse_args()
199 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs
200 |     parsed_args.world_size = torch.cuda.device_count() # use all available devices on the node
201 |
202 | torch.manual_seed(parsed_args.random_seed)
203 | torch.cuda.manual_seed(parsed_args.random_seed)
204 |
205 | # prepare for saving path
206 | if not os.path.exists(parsed_args.save_generation_dir):
207 | os.makedirs(parsed_args.save_generation_dir)
208 |
209 | start_timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
210 | if parsed_args.save_generation_postfix_identifier:
211 |         parsed_args.save_generation_postfix_identifier = (
212 |             f"{start_timestamp}_[{parsed_args.save_generation_postfix_identifier}]"
213 |         )
214 | else:
215 | parsed_args.save_generation_postfix_identifier = start_timestamp
216 |
217 | print("####################")
218 | for key, value in parsed_args.__dict__.items():
219 | print(f"{key}: {value}")
220 | print("####################")
221 |
222 |     # do inference and save to separate JSON files, rank index always starts from zero regardless of CUDA indices
223 | inference_distributed(parsed_args.__dict__)
224 |
--------------------------------------------------------------------------------
/scripts/generate_instruct.py:
--------------------------------------------------------------------------------
1 | """
2 | DistributedDataParallel generation script implemented from scratch.
3 | Generation results will be saved to separate JSON files, and metrics can be further computed with `benchmark.py`.
4 | The script is designed for multi-GPU parallelism on a single node.
5 | """
6 |
7 | import argparse
8 | from datetime import datetime
9 | import json
10 | import os
11 | from typing import Any, Dict, List
12 |
13 | import torch
14 | import torch.distributed as dist
15 | from torch.nn.parallel import DistributedDataParallel
16 | from torch.utils.data.distributed import DistributedSampler
17 | from tqdm import tqdm
18 | from transformers import AutoTokenizer, PreTrainedTokenizer
19 |
20 | from dataset import Prot2TextInstructDataset, Prot2TextInstructDataLoader
21 | from .train_instruct import load_model, setup, cleanup
22 | import scripts.utils_argparse as utils_argparse
23 |
24 |
25 | argParser = argparse.ArgumentParser()
26 |
27 | argParser.add_argument("--esm_path", type=str)
28 | argParser.add_argument("--llama_path", type=str)
29 | argParser.add_argument("--root_dataset_dir", type=str)
30 | argParser.add_argument("--root_csv_dir", type=str)
31 | argParser.add_argument("--save_generation_dir", type=str)
32 | argParser.add_argument("--save_generation_postfix_identifier", type=str, default=None)
33 | argParser.add_argument("--load_model_checkpoint_path", type=str, default="")
34 | argParser.add_argument("--load_adapter_checkpoint_dir", type=str, default="")
35 |
36 | argParser.add_argument("--torch_dtype", type=utils_argparse.str2dtype)
37 | argParser.add_argument("--batch_size_per_device", type=int)
38 | argParser.add_argument("--random_seed", type=int)
39 | argParser.add_argument("--generate_split", type=str)
40 | argParser.add_argument("--debug_trim_generate_split", type=int, default=None)
41 | argParser.add_argument("--max_sequence_length", type=int, default=None)
42 | argParser.add_argument("--max_generation_length", type=int)
43 | argParser.add_argument("--num_beams", type=int, default=1)
44 | argParser.add_argument("--length_penalty", type=float, default=1.0)
45 | argParser.add_argument("--temperature", type=float, default=1.0)
46 | argParser.add_argument("--do_sample", type=utils_argparse.str2bool, default=False)
47 | argParser.add_argument("--top_p", type=float, default=1.0)
48 | argParser.add_argument("--top_k", type=int, default=50)
49 |
50 |
51 | def iterative_generation_loop(
52 | rank: int,
53 | model: torch.nn.Module,
54 | data_batch: Dict[str, Any],
55 | max_generation_length: int,
56 | num_beams: int,
57 | length_penalty: float,
58 | temperature: float,
59 | do_sample: bool,
60 | top_p: float,
61 | top_k: int
62 | ) -> torch.Tensor:
63 | """
64 | Standard API for different models. Used in `inference_epoch`.
65 | 1) Prepare inputs for the generation cycle with inference using data_batch from dataloader.
66 | 2) Execute the generation cycle and return the direct output.
67 | Returned output is a `torch.Tensor` of the generated tokens.
68 | """
69 | if isinstance(model, DistributedDataParallel):
70 | model = model.module # for wrapper models, get the inner model for generation
71 |
72 | return model.generate(
73 | inputs=data_batch["input_ids"].to(rank),
74 | attention_mask=data_batch["attention_mask"].to(rank),
75 | protein_input_ids=data_batch["protein_input_ids"].to(rank),
76 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank),
77 | max_new_tokens=max_generation_length,
78 | eos_token_id=128009,
79 | pad_token_id=128002,
80 | return_dict_in_generate=False,
81 | num_beams=num_beams,
82 | length_penalty=length_penalty,
83 | temperature=temperature,
84 | do_sample=do_sample,
85 | top_p=top_p,
86 | top_k=top_k
87 | )
88 |
89 |
90 | def inference_epoch(
91 | rank: int,
92 | model: DistributedDataParallel,
93 | dataloader: Prot2TextInstructDataLoader,
94 | llama_tokenizer: PreTrainedTokenizer,
95 | args: Dict[str, Any]
96 | ):
97 | """
98 | Iterate over all batches for inference with iterative loop.
99 | Generation results will be saved to JSON files.
100 | """
101 | model.eval()
102 | local_names: List[str] = []
103 | local_predictions: List[str] = []
104 | local_labels: List[str] = []
105 |
106 | # core loop for batches
107 | t = tqdm(iter(dataloader))
108 | for data_batch in t:
109 | with torch.no_grad():
110 | output = iterative_generation_loop(
111 | rank=rank,
112 | model=model,
113 | data_batch=data_batch,
114 | max_generation_length=args["max_generation_length"],
115 | num_beams=args["num_beams"],
116 | length_penalty=args["length_penalty"],
117 | temperature=args["temperature"],
118 | do_sample=args["do_sample"],
119 | top_p=args["top_p"],
120 | top_k=args["top_k"]
121 | )
122 | local_names.extend(data_batch["name"])
123 | predicted_texts = llama_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)
124 | local_predictions.extend(predicted_texts)
125 | label_texts = llama_tokenizer.batch_decode(data_batch["description_input_ids"], skip_special_tokens=True)
126 | local_labels.extend(label_texts)
127 | t.set_postfix({
128 | "mode": "inference",
129 | "batch_maxlen_gen": output.shape[1],
130 | "device": f"rank:{rank}"
131 | })
132 |
133 | local_json_path = os.path.join(
134 | args["save_generation_dir"],
135 | f"generation_{args['save_generation_postfix_identifier']}_rank{rank}.json"
136 | )
137 | with open(local_json_path, "w") as file:
138 | json_dict = {
139 | name: {"true": label, "pred": prediction}
140 | for name, label, prediction in zip(local_names, local_labels, local_predictions)
141 | }
142 | json.dump(json_dict, file, indent=4)
143 | print(f"Saving {local_json_path}")
144 |
145 |
146 | def inference_on_device(rank: int, world_size: int, args: Dict[str, Any]):
147 | """Core generation process for every device with batches over the whole dataset"""
148 | setup(rank, world_size)
149 |
150 | # prepare dataset and dataloader
151 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"])
152 | llama_tokenizer = AutoTokenizer.from_pretrained(
153 | args["llama_path"],
154 | pad_token='<|reserved_special_token_0|>'
155 | )
156 |
157 | generate_dataset = Prot2TextInstructDataset(
158 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['generate_split']}"),
159 | csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv"),
160 | sequence_tokenizer=esm_tokenizer,
161 | description_tokenizer=llama_tokenizer,
162 | skip_reload=True,
163 | skip_download=True,
164 | ignore_graph_features=True,
165 | )
166 | if args["debug_trim_generate_split"]:
167 | generate_dataset.usable_file_names = generate_dataset.usable_file_names[
168 | :args["debug_trim_generate_split"]
169 | ]
170 | generate_sampler = DistributedSampler(
171 | generate_dataset,
172 | rank=rank,
173 | num_replicas=world_size,
174 | shuffle=False
175 | )
176 | generate_loader = Prot2TextInstructDataLoader(
177 | generate_dataset,
178 | mode="inference",
179 | batch_size=args["batch_size_per_device"],
180 | sampler=generate_sampler,
181 | num_workers=2,
182 | pin_memory=True,
183 | shuffle=False,
184 | drop_last=True,
185 | )
186 |
187 | # load base model and then the checkpoint and the adapter
188 | model = load_model(args=args)
189 |
190 | # merge peft adapter for inference
191 | model = model.merge_and_unload()
192 |
193 | model = model.to(rank)
194 | model = DistributedDataParallel(model)
195 | print(f"DDP model loaded on rank{rank}")
196 |
197 | inference_epoch(
198 | rank=rank,
199 | model=model,
200 | dataloader=generate_loader,
201 | llama_tokenizer=llama_tokenizer,
202 | args=args
203 | )
204 | # use a barrier to make sure that all processes have finished writing their JSON files
205 | dist.barrier()
206 |
207 | cleanup()
208 |
209 |
210 | def inference_distributed(args: Dict[str, Any]):
211 | """Core generation process across multiple devices with batches over the whole dataset"""
212 | torch.multiprocessing.spawn(
213 | inference_on_device,
214 | args=(args["world_size"], args),
215 | nprocs=args["world_size"],
216 | join=True
217 | )
218 |
219 |
220 | if __name__ == "__main__":
221 | # suppress messages from AutoTokenizer parallelism and Graphein respectively
222 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
223 | os.environ["LOGURU_LEVEL"] = "INFO"
224 |
225 | parsed_args = argParser.parse_args()
226 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs
227 |     parsed_args.world_size = torch.cuda.device_count() # use all available devices on the node
228 |
229 | torch.manual_seed(parsed_args.random_seed)
230 | torch.cuda.manual_seed(parsed_args.random_seed)
231 |
232 | # prepare for saving path
233 | if not os.path.exists(parsed_args.save_generation_dir):
234 | os.makedirs(parsed_args.save_generation_dir)
235 |
236 | start_timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
237 | if parsed_args.save_generation_postfix_identifier:
238 | parsed_args.save_generation_postfix_identifier = (
239 | f"{start_timestamp}_[{parsed_args.save_generation_postfix_identifier}]"
240 | )
241 | else:
242 | parsed_args.save_generation_postfix_identifier = start_timestamp
243 |
244 | print("####################")
245 | for key, value in parsed_args.__dict__.items():
246 | print(f"{key}: {value}")
247 | print("####################")
248 |
249 |     # do inference and save to separate JSON files, rank index always starts from zero regardless of CUDA indices
250 | inference_distributed(parsed_args.__dict__)
251 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prot2Text-V2: Protein Function Prediction with Multimodal Contrastive Alignment
2 |
3 |
21 | This is the official repository for the paper "**Prot2Text-V2: Protein Function Prediction with Multimodal Contrastive Alignment**" by Xiao Fei, Michail Chatzianastasis, Sarah Almeida Carneiro, Hadi Abdine, Lawrence P. Petalidis, and Michalis Vazirgiannis.
22 |
23 | We're excited to share that our paper has been accepted to 🎉 **NeurIPS 2025**! An online server, the trained model weights, and the dataset are now publicly available on Hugging Face.
24 |
25 | ## About the Project
26 |
27 | Proteins are written in a code made of amino acids, but what if we could actually read that code like a language?
28 |
29 | **Prot2Text-V2** treats a protein sequence as if it were another language and translates it into English. The model takes the raw amino acid sequence as input and generates a clear, human-readable paragraph describing what the protein does.
30 |
31 |
32 |

33 |
34 |
35 | The instruction-based Prot2Text-V2 model combines three key components:
36 | 
37 | * Protein language model as the sequence encoder: `facebook/esm2_t36_3B_UR50D`
38 | * Modality adapter: a lightweight module that bridges protein embeddings and the language model's input space
39 | * Natural language decoder that generates textual descriptions from the adapted sequence embeddings: `meta-llama/Llama-3.1-8B-Instruct`
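
The modular design maps directly onto the configuration classes under `./models`. A minimal sketch of assembling the instruct configuration (illustrative only; the adapter's `intermediate_dim` below is an arbitrary placeholder, not the authors' setting):

```python
from transformers import EsmConfig, LlamaConfig

from models.configuration_esm2llama_instruct import (
    ModalityAdapterConfig,
    Esm2LlamaInstructConfig,
)

esm_config = EsmConfig.from_pretrained("facebook/esm2_t36_3B_UR50D")
llama_config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

adapter_config = ModalityAdapterConfig(
    input_dim=esm_config.hidden_size,         # 2560 for esm2_t36_3B
    intermediate_dim=esm_config.hidden_size,  # placeholder choice
    output_dim=llama_config.hidden_size,      # 4096 for Llama-3.1-8B
)

config = Esm2LlamaInstructConfig(
    esm_config=esm_config,
    adapter_config=adapter_config,
    llama_config=llama_config,
)
```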
40 |
41 |
42 |

43 |
44 |
45 | A contrastive alignment step first aligns the protein embeddings with the text embedding space; supervised fine-tuning then trains the decoder to generate fluent descriptions.
46 |
47 | For backward compatibility, the repository also includes our legacy base model, `Esm2LlamaForCausalLM`, along with its specialized dataloader.
48 |
49 | ## Getting Started
50 |
51 | ✅ Verified on Ubuntu-22.04-LTS with 2 x NVIDIA RTX A6000
52 |
53 | ✅ Verified on RHEL-9.4 with 8 x NVIDIA A100
54 |
55 | * Install NVIDIA `cuda-toolkit=12.1.1`, see official website for detailed information.
56 |
57 | * Install `dssp=4.0.4` for protein dataset preprocessing:
58 |
59 | ```shell
60 | sudo apt-get install dssp=4.0.4
61 | ```
62 |
63 | * Create environment with `conda` then install packages with `pip`:
64 |
65 | ```shell
66 | conda create -n prot2text-pip python=3.8
67 |
68 | pip3 install torch torchvision torchaudio # torch==2.3.0
69 | pip3 install torch_geometric
70 | pip3 install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
71 |
72 | pip3 install graphein==1.7.7
73 |
74 | pip3 install transformers==4.40.2 tokenizers==0.19.1 accelerate==0.29.3 sentencepiece==0.2.0
75 | pip3 install peft==0.10.0
76 | pip3 install biopython==1.81
77 | pip3 install networkx==2.5
78 | pip3 install chardet==5.2.0 charset-normalizer==2.0.4
79 | pip3 install multiprocess==0.70.16
80 | pip3 install tensorboard==2.14.0
81 | pip3 install evaluate==0.4.2
82 | pip3 install mpi4py==3.1.6
83 |
84 | sudo apt install libaio-dev
85 | DS_BUILD_FUSED_ADAM=1 pip3 install deepspeed==0.14.2
86 |
87 | pip3 install nltk==3.8.1 rouge_score==0.1.2 jiwer==3.0.4
88 | ```
89 |
90 | ## Dataset Preparation
91 |
92 | * Download CSV files from [HuggingFace](https://huggingface.co/datasets/habdine/Prot2Text-Data) and place under `./data`.
93 |
94 | * Download PDB files from AlphaFoldDB (for RGCN only) then preprocess graph and text features:
95 |
96 | ```python
97 | import os
98 | from transformers import AutoTokenizer
99 | from dataset import Prot2TextInstructDataset
99 |
100 | SPLIT = "train" # run script for "eval" and "test" as well
101 | CSV_DIR = "./data"
102 | DATA_ROOT_DIR = "/data/Prot2Text-Llama3-Data"
103 | LLAMA_DIR = "meta-llama/Meta-Llama-3.1-8B-Instruct-hf"
104 | ESM_DIR = "facebook/esm2_t36_3B_UR50D"
105 |
106 | split_dataset = Prot2TextInstructDataset(
107 | root_dir=os.path.join(DATA_ROOT_DIR, SPLIT),
108 | csv_path=os.path.join(CSV_DIR, f"{SPLIT}.csv"),
109 | sequence_tokenizer=AutoTokenizer.from_pretrained(ESM_DIR),
110 | description_tokenizer=AutoTokenizer.from_pretrained(LLAMA_DIR, pad_token='<|reserved_special_token_0|>'),
111 | skip_download=False,
112 | skip_reload=False,
113 | )
114 | ```
115 |
116 | * [Optional] If applying a new language tokenizer to an already-preprocessed dataset, run the following to avoid reprocessing the graphs:
117 |
118 | ```python
119 | NEW_LLAMA_DIR = "/data/Llama-3.2-1B"
120 |
121 | split_dataset = Prot2TextInstructDataset(
122 | root_dir=os.path.join(DATA_ROOT_DIR, SPLIT),
123 | csv_path=os.path.join(CSV_DIR, f"{SPLIT}.csv"),
124 | sequence_tokenizer=AutoTokenizer.from_pretrained(ESM_DIR),
125 | description_tokenizer=AutoTokenizer.from_pretrained(NEW_LLAMA_DIR, pad_token='<|reserved_special_token_0|>'),
126 | skip_download=True,
127 | skip_reload=True,
128 | )
129 | split_dataset.process_text()
130 | ```
131 |
132 | ## Model Training Pipeline
133 |
134 | ### 1. Contrastive Learning Stage
135 | `./scripts/train_contrast.py` performs contrastive learning to align protein representations with textual descriptions. This stage helps the model learn meaningful cross-modal embeddings.
136 |
137 | **Arguments:**
138 | - **Model Paths:**
139 | - `--esm_path`: Path to pretrained ESM protein language model
140 | - `--llama_path`: Path to pretrained LLaMA language model
141 | - **Data Directories:**
142 | - `--root_dataset_dir`: Root directory containing protein datasets
143 | - `--root_csv_dir`: Directory containing CSV metadata files
144 | - **Checkpoint Handling:**
145 | - `--save_checkpoint_dir`: Directory to save model checkpoints
146 | - `--load_model_checkpoint_path`: Path to load full model checkpoint (optional)
147 | - `--load_optimizer_scheduler_checkpoint_path`: Path to load optimizer/scheduler state (optional)
148 | - **Training Parameters:**
149 | - `--torch_dtype`: PyTorch data type for training (e.g., float16, float32)
150 | - `--batch_size_per_device`: Batch size per GPU/device
151 | - `--num_epochs`: Total number of training epochs
152 | - `--save_every_epochs`: Frequency of checkpoint saving (in epochs)
153 | - `--gradient_accumulation_steps`: Number of steps for gradient accumulation
154 | - `--learning_rate`: Initial learning rate
155 | - `--gradient_clipping`: Gradient clipping value (optional)
156 | - `--scheduler_gamma`: Learning rate scheduler gamma value
157 | - `--random_seed`: Random seed for reproducibility
158 | - `--contrastive_num_segments`: Number of segments for contrastive learning
159 | - **Data Splits:**
160 | - `--train_split`: Name of training split
161 | - `--eval_split`: Name of evaluation split
162 | - `--debug_trim_train_split`: Trim training set for sanity check (optional)
163 | - `--debug_trim_eval_split`: Trim evaluation set for sanity check (optional)
164 |
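For orientation, here is a minimal launch sketch for this stage. All paths and hyperparameter values below are placeholders (assumptions, not recommended settings), and the module path assumes the script is run from the repository root:

```python
# Illustrative launch sketch for the contrastive stage; every value is a placeholder.
import subprocess

subprocess.run(
    [
        "python", "-m", "scripts.train_contrast",
        "--esm_path", "facebook/esm2_t36_3B_UR50D",
        "--llama_path", "meta-llama/Meta-Llama-3.1-8B-Instruct-hf",
        "--root_dataset_dir", "/data/Prot2Text-Llama3-Data",
        "--root_csv_dir", "./data",
        "--save_checkpoint_dir", "./checkpoints/contrast",
        "--torch_dtype", "bfloat16",
        "--batch_size_per_device", "2",
        "--num_epochs", "10",
        "--save_every_epochs", "1",
        "--gradient_accumulation_steps", "8",
        "--learning_rate", "1e-4",
        "--scheduler_gamma", "0.95",
        "--random_seed", "42",
        "--contrastive_num_segments", "4",
        "--train_split", "train",
        "--eval_split", "eval",
    ],
    check=True,
)
```
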
165 | ### 2. Supervised Fine-Tuning Stage
166 | After contrastive learning, run `./scripts/train_instruct.py` for instruction fine-tuning on the training set. An illustrative LoRA configuration sketch follows the argument list below.
167 |
168 | **Additional/Modified Arguments:**
169 | - **Adapter Configuration:**
170 | - `--load_adapter_checkpoint_dir`: Directory to load adapter checkpoints
171 | - `--fix_modality_adapter`: Whether to freeze modality adapter weights
172 | - `--lora_rank`: Rank for LoRA adapter layers
173 | - **Text Field Handling:**
174 | - `--include_text_fields`: Whether to include text fields in input
175 | - `--name_dropout`: Dropout rate for protein names
176 | - `--taxonomy_dropout`: Dropout rate for taxonomy information
177 |
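As background on `--lora_rank`, the sketch below shows one way a LoRA adapter of that rank could be attached to the Llama decoder with the `peft` library. The target modules, alpha, and dropout values are assumptions for illustration only; `train_instruct.py` may wrap the model differently and remains the reference:

```python
# Illustrative only: wiring a LoRA rank into the Llama decoder with peft.
# Target modules and alpha/dropout values are assumptions, not values from train_instruct.py.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,                                  # corresponds to --lora_rank
    lora_alpha=16,                        # assumed scaling factor
    lora_dropout=0.05,                    # assumed dropout
    target_modules=["q_proj", "v_proj"],  # assumed attention projections
    task_type="CAUSAL_LM",
)

# `model` is assumed to be an Esm2LlamaInstructForCausalLM instance.
model.llama_decoder = get_peft_model(model.llama_decoder, lora_config)
model.llama_decoder.print_trainable_parameters()
```
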
178 | ## Performance Evaluation
179 |
180 | ### 1. Generation (`generate_instruct.py`)
181 | Generates functional descriptions for proteins in the test split using a trained model. An illustrative decoding sketch follows the argument list below.
182 |
183 | **Key Arguments:**
184 | - **Generation Parameters:**
185 | - `--max_generation_length`: Maximum length of generated text
186 | - `--num_beams`: Number of beams for beam search
187 | - `--temperature`: Sampling temperature
188 | - `--do_sample`: Whether to use sampling
189 | - `--top_p`: Nucleus sampling probability
190 | - `--top_k`: Top-k sampling value
191 | - **Output Control:**
192 | - `--save_generation_postfix_identifier`: Identifier for output files
193 | - `--max_sequence_length`: Maximum input sequence length
194 |
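These flags map directly onto decoding arguments forwarded to the model's `generate` method. A minimal sketch, assuming `model` is a trained `Esm2LlamaInstructForCausalLM`, `llama_tokenizer` is its text tokenizer, and `batch` comes from an inference-mode dataloader (the numeric values are placeholders):

```python
# Sketch of how the decoding flags reach model.generate; values are placeholders.
output_ids = model.generate(
    inputs=batch["input_ids"],                     # left-padded prompt
    attention_mask=batch["attention_mask"],
    protein_input_ids=batch["protein_input_ids"],  # right-padded amino-acid sequence
    protein_attention_mask=batch["protein_attention_mask"],
    max_new_tokens=512,  # --max_generation_length
    num_beams=1,         # --num_beams
    do_sample=True,      # --do_sample
    temperature=0.7,     # --temperature
    top_p=0.9,           # --top_p
    top_k=50,            # --top_k
)
predictions = llama_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```
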
195 | ### 2. Benchmarking (`benchmark.py`)
196 | Evaluates generated outputs against ground-truth descriptions using various metrics. A minimal scoring sketch follows the option list below.
197 |
198 | **Evaluation Options:**
199 | - `--evaluate_exact_match`: Compute exact match accuracy
200 | - `--evaluate_bleu`: Compute BLEU scores
201 | - `--evaluate_rouge`: Compute ROUGE scores
202 | - `--evaluate_bert_score`: Compute BERTScore
203 | - `--read_file_identifier`: Filter generated files by this identifier
204 | - `--verbose`: Print detailed evaluation results
205 |
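The light generation script, for example, writes one JSON file per rank in the form `{accession: {"true": ..., "pred": ...}}`, so outputs can also be inspected outside `benchmark.py`. A minimal scoring sketch, assuming the generated files live under `./generations` and the `evaluate` package is installed (both are assumptions; `benchmark.py` remains the reference implementation):

```python
# Sketch: gather per-rank generation JSON files and score them with BLEU.
import glob
import json

import evaluate  # assumed dependency for this sketch

predictions, references = [], []
for path in glob.glob("./generations/generation_*_rank*.json"):
    with open(path) as file:
        for entry in json.load(file).values():
            predictions.append(entry["pred"])
            references.append([entry["true"]])  # BLEU expects a list of references per sample

bleu = evaluate.load("bleu")
print(bleu.compute(predictions=predictions, references=references))
```
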
206 | ## Usage Notes
207 | 1. For the full training pipeline, first run `train_contrast.py`, then `train_instruct.py`.
208 | 2. Generation should use the same data splits as were used during evaluation.
209 | 3. Benchmarking can be customized to compute only the relevant metrics.
210 | 4. Debug arguments allow faster iteration during development.
211 |
212 | The pipeline supports both full fine-tuning and parameter-efficient approaches (LoRA, adapter layers) through the various adapter-related arguments.
213 |
214 | ## Ⓒ Citation
215 |
216 | If you find our research helpful, feel free to 🖋️ cite our work or ⭐️ star the repository:
217 |
218 | ```bibtex
219 | @misc{prot2textv2,
220 | title={Prot2Text-V2: Protein Function Prediction with Multimodal Contrastive Alignment},
221 | author={Xiao Fei and Michail Chatzianastasis and Sarah Almeida Carneiro and Hadi Abdine and Lawrence P. Petalidis and Michalis Vazirgiannis},
222 | year={2025},
223 | eprint={2505.11194},
224 | archivePrefix={arXiv},
225 | primaryClass={cs.CE},
226 | url={https://arxiv.org/abs/2505.11194},
227 | }
228 | ```
229 |
--------------------------------------------------------------------------------
/scripts/generate_instruct_light.py:
--------------------------------------------------------------------------------
1 | """
2 | DistributedDataParallel generation script implemented from scratch.
3 | Using Prot2TextLightDataset instead of Prot2TextInstructDataset.
4 | Generation results will be saved to separate JSON files, and metrics can be further computed with `benchmark.py`.
5 | The script is designed for multi-GPU parallelism on a single node.
6 | """
7 |
8 | import argparse
9 | from datetime import datetime
10 | import json
11 | import os
12 | from typing import Any, Dict, List
13 |
14 | import torch
15 | import torch.distributed as dist
16 | from torch.nn.parallel import DistributedDataParallel
17 | from torch.utils.data import DataLoader
18 | from torch.utils.data.distributed import DistributedSampler
19 | from tqdm import tqdm
20 | from transformers import AutoTokenizer, PreTrainedTokenizer
21 |
22 | # from dataset import Prot2TextInstructDataset, Prot2TextInstructDataLoader
23 | from dataset import Prot2TextLightDataset, Prot2TextLightCollater
24 | from .train_instruct import load_model, setup, cleanup
25 | import scripts.utils_argparse as utils_argparse
26 |
27 |
28 | argParser = argparse.ArgumentParser()
29 |
30 | argParser.add_argument("--esm_path", type=str)
31 | argParser.add_argument("--llama_path", type=str)
32 | argParser.add_argument("--root_dataset_dir", type=str)
33 | argParser.add_argument("--root_csv_dir", type=str)
34 | argParser.add_argument("--save_generation_dir", type=str)
35 | argParser.add_argument("--save_generation_postfix_identifier", type=str, default=None)
36 | argParser.add_argument("--load_model_checkpoint_path", type=str, default="")
37 | argParser.add_argument("--load_adapter_checkpoint_dir", type=str, default="")
38 |
39 | argParser.add_argument("--torch_dtype", type=utils_argparse.str2dtype)
40 | argParser.add_argument("--batch_size_per_device", type=int)
41 | argParser.add_argument("--random_seed", type=int)
42 | argParser.add_argument("--generate_split", type=str)
43 | argParser.add_argument("--debug_trim_generate_split", type=int, default=None)
44 | argParser.add_argument("--max_description_length", type=int, default=1021) # NEW
45 | argParser.add_argument("--max_sequence_length", type=int, default=512)
46 | argParser.add_argument("--max_generation_length", type=int)
47 | argParser.add_argument("--num_beams", type=int, default=1)
48 | argParser.add_argument("--length_penalty", type=float, default=1.0)
49 | argParser.add_argument("--temperature", type=float, default=1.0)
50 | argParser.add_argument("--do_sample", type=utils_argparse.str2bool, default=False)
51 | argParser.add_argument("--top_p", type=float, default=1.0)
52 | argParser.add_argument("--top_k", type=int, default=50)
53 |
54 |
55 | def iterative_generation_loop(
56 | rank: int,
57 | model: torch.nn.Module,
58 | data_batch: Dict[str, Any],
59 | max_generation_length: int,
60 | num_beams: int,
61 | length_penalty: float,
62 | temperature: float,
63 | do_sample: bool,
64 | top_p: float,
65 | top_k: int
66 | ) -> torch.Tensor:
67 | """
68 | Standard API for different models. Used in `inference_epoch`.
69 | 1) Prepare inputs for the generation cycle from the `data_batch` given by the dataloader.
70 | 2) Execute the generation cycle and return the direct output.
71 | Returned output is a `torch.Tensor` of the generated tokens.
72 | """
73 | if isinstance(model, DistributedDataParallel):
74 | model = model.module # for wrapper models, get the inner model for generation
75 |
76 | return model.generate(
77 | inputs=data_batch["input_ids"].to(rank),
78 | attention_mask=data_batch["attention_mask"].to(rank),
79 | protein_input_ids=data_batch["protein_input_ids"].to(rank),
80 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank),
81 | max_new_tokens=max_generation_length,
82 | eos_token_id=128009,
83 | pad_token_id=128002,
84 | return_dict_in_generate=False,
85 | num_beams=num_beams,
86 | length_penalty=length_penalty,
87 | temperature=temperature,
88 | do_sample=do_sample,
89 | top_p=top_p,
90 | top_k=top_k
91 | )
92 |
93 |
94 | def inference_epoch(
95 | rank: int,
96 | model: DistributedDataParallel,
97 | # dataloader: Prot2TextInstructDataLoader,
98 | dataloader: DataLoader,
99 | llama_tokenizer: PreTrainedTokenizer,
100 | args: Dict[str, Any]
101 | ):
102 | """
103 | Iterate over all batches for inference with iterative loop.
104 | Generation results will be saved to JSON files.
105 | """
106 | model.eval()
107 | local_names: List[str] = []
108 | local_predictions: List[str] = []
109 | local_labels: List[str] = []
110 |
111 | # core loop for batches
112 | t = tqdm(iter(dataloader))
113 | for data_batch in t:
114 | with torch.no_grad():
115 | output = iterative_generation_loop(
116 | rank=rank,
117 | model=model,
118 | data_batch=data_batch,
119 | max_generation_length=args["max_generation_length"],
120 | num_beams=args["num_beams"],
121 | length_penalty=args["length_penalty"],
122 | temperature=args["temperature"],
123 | do_sample=args["do_sample"],
124 | top_p=args["top_p"],
125 | top_k=args["top_k"]
126 | )
127 | local_names.extend(data_batch["name"])
128 | predicted_texts = llama_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)
129 | local_predictions.extend(predicted_texts)
130 | label_texts = llama_tokenizer.batch_decode(data_batch["description_input_ids"], skip_special_tokens=True)
131 | local_labels.extend(label_texts)
132 | t.set_postfix({
133 | "mode": "inference",
134 | "batch_maxlen_gen": output.shape[1],
135 | "device": f"rank:{rank}"
136 | })
137 |
138 | local_json_path = os.path.join(
139 | args["save_generation_dir"],
140 | f"generation_{args['save_generation_postfix_identifier']}_rank{rank}.json"
141 | )
142 | with open(local_json_path, "w") as file:
143 | json_dict = {
144 | name: {"true": label, "pred": prediction}
145 | for name, label, prediction in zip(local_names, local_labels, local_predictions)
146 | }
147 | json.dump(json_dict, file, indent=4)
148 | print(f"Saving {local_json_path}")
149 |
150 |
151 | def inference_on_device(rank: int, world_size: int, args: Dict[str, Any]):
152 | """Core generation process for every device with batches over the whole dataset"""
153 | setup(rank, world_size)
154 |
155 | # prepare dataset and dataloader
156 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"])
157 | llama_tokenizer = AutoTokenizer.from_pretrained(
158 | args["llama_path"],
159 | pad_token='<|reserved_special_token_0|>'
160 | )
161 |
162 | generate_dataset = Prot2TextLightDataset(
163 | csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv")
164 | )
165 |
166 | generate_collater = Prot2TextLightCollater(
167 | sequence_tokenizer=esm_tokenizer,
168 | description_tokenizer=llama_tokenizer,
169 | mode="inference",
170 | include_text_fields=True,
171 | name_dropout=0.0,
172 | taxonomy_dropout=0.0,
173 | )
174 |
175 | #generate_dataset = Prot2TextInstructDataset(
176 | # root_dir=os.path.join(args["root_dataset_dir"], f"{args['generate_split']}"),
177 | # csv_path=os.path.join(args["root_csv_dir"], f"{args['generate_split']}.csv"),
178 | # sequence_tokenizer=esm_tokenizer,
179 | # description_tokenizer=llama_tokenizer,
180 | # skip_reload=True,
181 | # skip_download=True,
182 | # ignore_graph_features=True,
183 | #)
184 | # if args["debug_trim_generate_split"]:
185 | # generate_dataset.usable_file_names = generate_dataset.usable_file_names[
186 | # :args["debug_trim_generate_split"]
187 | # ]
188 | generate_sampler = DistributedSampler(
189 | generate_dataset,
190 | rank=rank,
191 | num_replicas=world_size,
192 | shuffle=False
193 | )
194 | generate_loader = DataLoader(
195 | generate_dataset,
196 | batch_size=args["batch_size_per_device"],
197 | sampler=generate_sampler,
198 | num_workers=2,
199 | pin_memory=True,
200 | shuffle=False,
201 | drop_last=True,
202 | collate_fn=generate_collater
203 | )
204 | #generate_loader = Prot2TextInstructDataLoader(
205 | # generate_dataset,
206 | # mode="inference",
207 | # batch_size=args["batch_size_per_device"],
208 | # sampler=generate_sampler,
209 | # num_workers=2,
210 | # pin_memory=True,
211 | # shuffle=False,
212 | # drop_last=True,
213 | #)
214 |
215 | # load base model and then the checkpoint and the adapter
216 | model = load_model(args=args)
217 |
218 | # if rank == 0:
219 | # for name, param in model.named_parameters():
220 | # print(name, param.requires_grad)
221 |
222 | # merge peft adapter for inference
223 | model = model.merge_and_unload()
224 | model.train()  # inference_epoch switches back to eval mode before generation
225 | model.adapter.fc1.requires_grad = True  # intended to leave a trainable parameter so DDP wrapping succeeds; requires_grad_() would be needed to affect the weights
226 |
227 | model = model.to(rank)
228 | model = DistributedDataParallel(model)
229 | print(f"DDP model loaded on rank{rank}")
230 |
231 | inference_epoch(
232 | rank=rank,
233 | model=model,
234 | dataloader=generate_loader,
235 | llama_tokenizer=llama_tokenizer,
236 | args=args
237 | )
238 | # use a barrier to make sure that all processes have finished writing their JSON files
239 | dist.barrier()
240 |
241 | cleanup()
242 |
243 |
244 | def inference_distributed(args: Dict[str, Any]):
245 | """Core generation process across multiple devices with batches over the whole dataset"""
246 | torch.multiprocessing.spawn(
247 | inference_on_device,
248 | args=(args["world_size"], args),
249 | nprocs=args["world_size"],
250 | join=True
251 | )
252 |
253 |
254 | if __name__ == "__main__":
255 | # suppress messages from AutoTokenizer parallelism and Graphein respectively
256 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
257 | os.environ["LOGURU_LEVEL"] = "INFO"
258 |
259 | parsed_args = argParser.parse_args()
260 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs
261 | parsed_args.world_size = torch.cuda.device_count()  # use all available devices on this node
262 |
263 | torch.manual_seed(parsed_args.random_seed)
264 | torch.cuda.manual_seed(parsed_args.random_seed)
265 |
266 | # prepare for saving path
267 | if not os.path.exists(parsed_args.save_generation_dir):
268 | os.makedirs(parsed_args.save_generation_dir)
269 |
270 | start_timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
271 | if parsed_args.save_generation_postfix_identifier:
272 | parsed_args.save_generation_postfix_identifier = (
273 | f"{start_timestamp}_[{parsed_args.save_generation_postfix_identifier}]"
274 | )
275 | else:
276 | parsed_args.save_generation_postfix_identifier = start_timestamp
277 |
278 | print("####################")
279 | for key, value in parsed_args.__dict__.items():
280 | print(f"{key}: {value}")
281 | print("####################")
282 |
283 | # do inference and save to separate JSON files; rank index always starts from zero regardless of CUDA indices
284 | inference_distributed(parsed_args.__dict__)
285 |
--------------------------------------------------------------------------------
/dataset/dataloader_light.py:
--------------------------------------------------------------------------------
1 | """
2 | Light-weight dataset and data collater class for protein function prediction
3 | instruction tuning. To be used with Esm2LlamaInstructForCausalLM.
4 |
5 | This flexible implementation fetches raw text data from a CSV file and
6 | performs tokenization and padding on the fly. This is useful when the default
7 | user message and chat template are not suitable for the task at hand.
8 |
9 | It can only be used if the model does not require graph-related data.
10 |
11 | Every batch from DataLoader will contain following attributes:
12 | * Training mode (train-eval with teacher-forcing):
13 | - graph related features:
14 | None
15 | - amino-acid sequence:
16 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens
17 | - protein_attention_mask (bsz, max_seq_len+2) # right padding
18 | - concatenated chat:
19 | - input_ids (bsz, max_prompt_len+max_text_len+1)
20 | - attention_mask (bsz, max_prompt_len+max_text_len+1)
21 | - labels (bsz, max_prompt_len+max_text_len+1)
22 | - standalone description for contrastive learning:
23 | - description_input_ids (bsz, max_text_len+1) # eos token only
24 | - description_attention_mask (bsz, max_text_len+1) # right padding
25 |
26 | ids = [left-pad + bos + prompt & description + eot + right-pad]
27 | mask = [0s + 1 + 1s & 1s + 1 + 0s ]
28 | labels = [-100s + -100 + -100s & description + eot + -100s ]
29 | desc_ids = [ & description + eot + right-pad]
30 | desc_mask = [ & 1s + 1 + 0s ]
31 |
32 | * Inference mode (iterative generation):
33 | - graph related features:
34 | None
35 | - amino-acid sequence:
36 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens
37 | - protein_attention_mask (bsz, max_seq_len+2) # right padding
38 | - prompt chat:
39 | - input_ids (bsz, max_prompt_len)
40 | - attention_mask (bsz, max_prompt_len)
41 | - description_input_ids (bsz, max_text_len+1) # for evaluation
42 |
43 | ids = [left-pad + bos + prompt & ]
44 | mask = [0s + 1 + 1s & ]
45 | desc_ids = [ & description + eot + right-pad]
46 |
47 | Example of usage:
48 | >>> from torch.utils.data import DataLoader
49 | >>> from transformers import AutoTokenizer
50 | >>> from dataset import Prot2TextLightDataset, Prot2TextLightCollater
51 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D")
52 | >>> llama_tokenizer = AutoTokenizer.from_pretrained(
53 | "/data/Meta-Llama-3.1-8B-Instruct-hf",
54 | pad_token='<|reserved_special_token_0|>'
55 | )
56 | >>> train_dataset = Prot2TextLightDataset("./data/train.csv")
57 | >>> train_collater = Prot2TextLightCollater(
58 | sequence_tokenizer=esm_tokenizer,
59 | description_tokenizer=llama_tokenizer,
60 | mode="train"
61 | )
62 | >>> train_dataloader = DataLoader(
63 | train_dataset,
64 | batch_size=4,
65 | shuffle=True,
66 | num_workers=4,
67 | collate_fn=train_collater,
68 | pin_memory=True,
69 | drop_last=True
70 | )
71 | """
72 |
73 | import random
74 | from typing import Dict, List, Literal, Optional
75 |
76 | import pandas as pd
77 | import torch
78 | import torch.utils.data
79 | from transformers import PreTrainedTokenizer
80 |
81 |
82 | class Prot2TextLightDataset(torch.utils.data.Dataset):
83 | """Dataset class loading directly from single CSV file."""
84 | def __init__(self, csv_path: str):
85 | super().__init__()
86 | self.data: pd.DataFrame = pd.read_csv(csv_path)
87 |
88 | def __len__(self) -> int:
89 | return len(self.data)
90 |
91 | def __getitem__(self, idx: int) -> Dict[str, str]:
92 | return {
93 | column_name: self.data.iloc[idx][column_name]
94 | for column_name in self.data.columns
95 | }
96 |
97 |
98 | class Prot2TextLightCollater:
99 | def __init__(
100 | self,
101 | sequence_tokenizer: PreTrainedTokenizer,
102 | description_tokenizer: PreTrainedTokenizer,
103 | mode: Literal["train", "inference"] = "train",
104 | include_text_fields: bool = True,
105 | name_dropout: float = 0.8,
106 | taxonomy_dropout: float = 0.8,
107 | max_sequence_length: Optional[int] = 1021,
108 | max_description_length: Optional[int] = 512,
109 | system_message: str = (
110 | "You are a scientific assistant specialized in protein function "
111 | "predictions. Given the sequence embeddings and other information "
112 | "of a protein, describe its function clearly and concisely in "
113 | "professional language. "
114 | ),
115 | placeholder_token: str = '<|reserved_special_token_1|>',
116 | ):
117 | self.sequence_tokenizer = sequence_tokenizer
118 | self.description_tokenizer = description_tokenizer
119 | self.mode = mode
120 |
121 | self.include_text_fields = include_text_fields
122 | self.name_dropout = name_dropout
123 | self.taxonomy_dropout = taxonomy_dropout
124 |
125 | self.max_sequence_length = max_sequence_length
126 | self.max_description_length = max_description_length
127 | self.system_message = system_message
128 | self.placeholder_token = placeholder_token
129 |
130 | def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
131 | # group data across batch
132 | accessions = [item["AlphaFoldDB"] for item in batch]
133 | fullnames = [item["Full Name"] for item in batch]
134 | taxons = [item["taxon"] for item in batch]
135 | sequences = [item["sequence"] for item in batch]
136 | descriptions = [item["function"] for item in batch]
137 |
138 | # replace nan in name and taxon with unknown
139 | fullnames = [
140 | fullname
141 | if isinstance(fullname, str) and random.random() > self.name_dropout
142 | else "unknown"
143 | for fullname in fullnames
144 | ]
145 | taxons = [
146 | taxon
147 | if isinstance(taxon, str) and random.random() > self.taxonomy_dropout
148 | else "unknown"
149 | for taxon in taxons
150 | ]
151 |
152 | # for each sequence in the batch,
153 | # if the sequence is originally longer than max_sequence_length, take a random segment of that length,
154 | # else do nothing
155 | for i in range(len(sequences)):
156 | if len(sequences[i]) > self.max_sequence_length:
157 | start = random.randint(0, len(sequences[i]) - self.max_sequence_length)
158 | sequences[i] = sequences[i][start:start + self.max_sequence_length]
159 |
160 | # truncate and tokenize sequences
161 | self.sequence_tokenizer.padding_side = "right"
162 | tokenized_sequences = self.sequence_tokenizer(
163 | sequences,
164 | truncation=True,
165 | padding="longest",
166 | max_length=self.max_sequence_length + 2, # including bos and eos tokens of esm tokenizer
167 | return_tensors="pt"
168 | )
169 | sequence_input_ids = tokenized_sequences["input_ids"]
170 | sequence_attention_mask = tokenized_sequences["attention_mask"]
171 |
172 | # apply chat template
173 | sequence_lens = sequence_attention_mask.sum(dim=1).tolist()
174 |
175 | if self.include_text_fields:
176 | user_messages = [
177 | # (
178 | # (f"Protein name: {fullname}; " if fullname != "unknown" else "")
179 | # + (f"Taxon: {taxon}; " if taxon != "unknown" else "")
180 | # + "Sequence embeddings: " + self.placeholder_token * sequence_len
181 | # )
182 | (
183 | f"Protein name: {fullname}; Taxon: {taxon}; "
184 | + "Sequence embeddings: " + self.placeholder_token * sequence_len
185 | )
186 | for fullname, taxon, sequence_len in zip(fullnames, taxons, sequence_lens)
187 | ]
188 | else:
189 | user_messages = [
190 | "Sequence embeddings: " + self.placeholder_token * sequence_lens
191 | for sequence_lens in sequence_lens
192 | ]
193 |
194 | prompt_conversations = [
195 | [
196 | {"role": "system", "content": self.system_message},
197 | {"role": "user", "content": user_message}
198 | ]
199 | for user_message in user_messages
200 | ]
201 |
202 | # tokenize prompts
203 | self.description_tokenizer.padding_side = "left"
204 | tokenized_prompts = self.description_tokenizer.apply_chat_template(
205 | prompt_conversations,
206 | add_generation_prompt=True,
207 | tokenize=True,
208 | padding="longest",
209 | return_tensors="pt",
210 | return_dict=True
211 | )
212 | prompt_input_ids = tokenized_prompts["input_ids"]
213 | prompt_attention_mask = tokenized_prompts["attention_mask"]
214 |
215 | # tokenize descriptions
216 | self.description_tokenizer.padding_side = "right"
217 | tokenized_descriptions = self.description_tokenizer(
218 | [description + self.description_tokenizer.eos_token for description in descriptions],
219 | add_special_tokens=False, # do not add bos token to the beginning
220 | truncation=True,
221 | padding="longest",
222 | max_length=self.max_description_length,
223 | return_tensors="pt"
224 | )
225 | description_input_ids = tokenized_descriptions["input_ids"]
226 | description_attention_mask = tokenized_descriptions["attention_mask"]
227 |
228 | # truncate descriptions
229 | if description_input_ids.size(1) > self.max_description_length:
230 | description_input_ids = description_input_ids[:, :self.max_description_length]
231 | description_attention_mask = description_attention_mask[:, :self.max_description_length]
232 |
233 | # prepare labels
234 | labels = description_input_ids.clone()
235 | labels[description_attention_mask == 0] = -100
236 |
237 | # assemble
238 | if self.mode == "train":
239 | return {
240 | "name": accessions,
241 | "protein_input_ids": sequence_input_ids,
242 | "protein_attention_mask": sequence_attention_mask,
243 | "input_ids": torch.cat([
244 | prompt_input_ids,
245 | description_input_ids,
246 | ], dim=1),
247 | "attention_mask": torch.cat([
248 | prompt_attention_mask,
249 | description_attention_mask,
250 | ], dim=1),
251 | "labels": torch.cat([
252 | torch.full_like(
253 | prompt_input_ids,
254 | fill_value=-100,
255 | ),
256 | labels,
257 | ], dim=1),
258 | "description_input_ids": description_input_ids,
259 | "description_attention_mask": description_attention_mask
260 | }
261 |
262 | elif self.mode == "inference":
263 | return {
264 | "name": accessions,
265 | "protein_input_ids": sequence_input_ids,
266 | "protein_attention_mask": sequence_attention_mask,
267 | "input_ids": prompt_input_ids,
268 | "attention_mask": prompt_attention_mask,
269 | "description_input_ids": description_input_ids,
270 | }
271 |
272 | else:
273 | raise ValueError(f"Invalid mode: {self.mode}")
274 |
--------------------------------------------------------------------------------
/models/modeling_esm2llama_instruct.py:
--------------------------------------------------------------------------------
1 | """
2 | Modeling classes for the assembled Esm2LlamaInstructForCausalLM model.
3 |
4 | Esm2LlamaInstructForCausalLM = EsmModel + ModalityAdapter + LlamaForCausalLM
5 |
6 | For training/evaluation under teacher-forcing scenario, the model `forward`
7 | function shall take following arguments:
8 | * input_ids: (bsz, prompt_len+description_len) # whole chat template
9 | * attention_mask: (bsz, prompt_len+description_len) # left & right padding
10 | * position_ids: (bsz, prompt_len+description_len) # optional
11 | * past_key_values: None
12 | * labels: (bsz, prompt_len+description_len) # -100 for padding & prompt
13 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds
14 | * protein_attention_mask: (bsz, prot_seq_len) # right padding
15 | * protein_position_ids: (bsz, prot_seq_len) # optional
16 | * protein_head_mask: (num_heads,) or (num_layers, num_heads) # optional
17 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional
18 | * use_cache: False
19 | * return_decoder_inputs: False
20 |
21 | For inference, the model `generate` function shall take following arguments:
22 | * inputs: (bsz, prompt_len) # prompt part of chat template
23 | * attention_mask: (bsz, prompt_len) # left padding
24 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds
25 | * protein_attention_mask: (bsz, prot_seq_len) # right padding
26 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional
27 | """
28 |
29 |
30 | from typing import Optional, Tuple, Union
31 |
32 | import torch
33 | from transformers import Cache, PreTrainedModel
34 | from transformers.generation.utils import GenerateOutput
35 | from transformers.modeling_outputs import CausalLMOutputWithPast
36 | from transformers.models.esm.modeling_esm import EsmModel
37 | from transformers.models.llama import LlamaForCausalLM
38 |
39 | from .configuration_esm2llama_instruct import (
40 | ModalityAdapterConfig,
41 | Esm2LlamaInstructConfig
42 | )
43 |
44 |
45 | class ModalityAdapter(PreTrainedModel):
46 | """2-layer adapter to match the hidden size of different modalities."""
47 | config_class = ModalityAdapterConfig # configuration class for this model
48 |
49 | def __init__(self, config: ModalityAdapterConfig):
50 | super().__init__(config)
51 | self.config = config
52 | self.fc1 = torch.nn.Linear(config.input_dim, config.intermediate_dim)
53 | self.fc2 = torch.nn.Linear(config.intermediate_dim, config.output_dim)
54 | self.activation = torch.nn.GELU()
55 | self.dropout = torch.nn.Dropout(p=config.dropout_rate)
56 | self.ln1 = torch.nn.LayerNorm(normalized_shape=config.intermediate_dim) # DEPRECATED
57 | self.ln2 = torch.nn.LayerNorm(normalized_shape=config.output_dim) # DEPRECATED
58 | self.post_init() # initialize weights and apply final processing
59 |
60 | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
61 | # input: (bsz, seq_len, input_dim)
62 | hidden_states = self.activation(self.fc1(hidden_states))
63 | hidden_states = self.dropout(hidden_states)
64 | # interm: (bsz, seq_len, interm_dim)
65 | hidden_states = self.activation(self.fc2(hidden_states))
66 | hidden_states = self.dropout(hidden_states)
67 | hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1)
68 | return hidden_states # (bsz, seq_len, output_dim)
69 |
70 |
71 | class Esm2LlamaInstructForCausalLM(PreTrainedModel):
72 | """
73 | Esm2LlamaInstructForCausalLM model for protein function prediction.
74 | Similar to `EncoderDecoderModel` but with a more complex architecture.
75 | Initialize with either a configuration OR all three components.
76 | `kwargs` can override standalone attributes in `Esm2LlamaInstructConfig`.
77 | """
78 | config_class = Esm2LlamaInstructConfig # configuration class for this model
79 |
80 | def __init__(
81 | self,
82 | config: Optional[Esm2LlamaInstructConfig] = None,
83 | esm_encoder: Optional[EsmModel] = None,
84 | adapter: Optional[ModalityAdapter] = None,
85 | llama_decoder: Optional[LlamaForCausalLM] = None,
86 | **kwargs
87 | ):
88 | if config is not None: # components ignored if config is provided
89 | super().__init__(config)
90 | self.esm_encoder = EsmModel(
91 | config.esm_config,
92 | add_pooling_layer=False
93 | )
94 | self.adapter = ModalityAdapter(config.adapter_config)
95 | self.llama_decoder = LlamaForCausalLM(config.llama_config)
96 | else:
97 | config = Esm2LlamaInstructConfig(
98 | esm_config=esm_encoder.config,
99 | adapter_config=adapter.config,
100 | llama_config=llama_decoder.config,
101 | **kwargs # override standalone attributes
102 | )
103 | super().__init__(config)
104 | self.esm_encoder = esm_encoder
105 | self.adapter = adapter
106 | self.llama_decoder = llama_decoder
107 |
108 | def prepare_decoder_inputs(
109 | self,
110 | input_ids: torch.LongTensor,
111 | encoder_hidden_states: torch.FloatTensor,
112 | attention_mask: Optional[torch.LongTensor] = None,
113 | encoder_attention_mask: Optional[torch.LongTensor] = None,
114 | ):
115 | """
116 | Embed and replace placeholder in `input_ids` by encoder hidden states.
117 | `input_ids` must be passed to locate placeholder for replacement.
118 | """
119 | # preparation
120 | batch_size, seq_len = input_ids.size()
121 | _, encoder_seq_len, _ = encoder_hidden_states.size()
122 | if attention_mask is None:
123 | attention_mask = torch.ones(
124 | (batch_size, seq_len),
125 | dtype=torch.long,
126 | device=input_ids.device
127 | )
128 | if encoder_attention_mask is None:
129 | encoder_attention_mask = torch.ones(
130 | (batch_size, encoder_seq_len),
131 | dtype=torch.long,
132 | device=encoder_hidden_states.device
133 | )
134 | inputs_embeds = self.llama_decoder.get_input_embeddings()(input_ids)
135 | # replacement
136 | placeholder_mask = input_ids == self.config.placeholder_id
137 | encoder_mask = encoder_attention_mask.bool()
138 | inputs_embeds[placeholder_mask] = encoder_hidden_states[encoder_mask]
139 | return inputs_embeds, attention_mask
140 |
141 | def forward(
142 | self,
143 | # chat template text inputs
144 | input_ids: Optional[torch.LongTensor] = None,
145 | attention_mask: Optional[torch.LongTensor] = None,
146 | position_ids: Optional[torch.LongTensor] = None,
147 | past_key_values: Optional[Cache] = None,
148 | labels: Optional[torch.LongTensor] = None,
149 | # protein amino-acid sequence inputs
150 | protein_input_ids: Optional[torch.LongTensor] = None,
151 | protein_attention_mask: Optional[torch.LongTensor] = None,
152 | protein_position_ids: Optional[torch.LongTensor] = None,
153 | protein_head_mask: Optional[torch.LongTensor] = None,
154 | protein_inputs_embeds: Optional[torch.FloatTensor] = None,
155 | # behavior control arguments
156 | use_cache: Optional[bool] = None,
157 | output_attentions: Optional[bool] = None,
158 | output_hidden_states: Optional[bool] = None,
159 | return_dict: Optional[bool] = None,
160 | return_encoder_outputs: bool = False,
161 | return_adapter_outputs: bool = False,
162 | return_decoder_inputs: bool = False,
163 | cache_position: Optional[torch.LongTensor] = None
164 | ) -> Union[Tuple, CausalLMOutputWithPast]:
165 | """
166 | Compute encoder and adapter outputs, then pass to decoder.
167 | `input_ids` is expected to be [prompt + description] in teacher-forcing
168 | scenario and [prompt] only in first iteration of inference (with
169 | return_decoder_inputs=True).
170 | Note: any concatenation of the attention mask and labels must be
171 | handled before calling this method.
172 | `inputs_embeds` not allowed due to placeholder replacement scheme.
173 | """
174 | # esm_encoder forward
175 | encoder_output = self.esm_encoder(
176 | input_ids=protein_input_ids,
177 | attention_mask=protein_attention_mask,
178 | position_ids=protein_position_ids,
179 | head_mask=protein_head_mask,
180 | inputs_embeds=protein_inputs_embeds,
181 | use_cache=False, # because config.esm_config.is_decoder=False
182 | output_attentions=output_attentions,
183 | output_hidden_states=output_hidden_states,
184 | return_dict=return_dict
185 | )
186 | encoder_hidden_states = encoder_output[0]
187 | encoder_attention_mask = protein_attention_mask
188 | if return_encoder_outputs:
189 | return encoder_output
190 | # adapter forward
191 | adapter_output = self.adapter(encoder_hidden_states)
192 | if return_adapter_outputs:
193 | return adapter_output, encoder_attention_mask
194 | # decoder input preparation
195 | inputs_embeds, attention_mask = self.prepare_decoder_inputs(
196 | input_ids=input_ids,
197 | encoder_hidden_states=adapter_output,
198 | attention_mask=attention_mask,
199 | encoder_attention_mask=encoder_attention_mask,
200 | )
201 | if return_decoder_inputs:
202 | return inputs_embeds, attention_mask
203 | # llama_decoder forward
204 | return self.llama_decoder.forward(
205 | input_ids=None,
206 | attention_mask=attention_mask,
207 | position_ids=position_ids,
208 | past_key_values=past_key_values,
209 | inputs_embeds=inputs_embeds,
210 | labels=labels,
211 | use_cache=use_cache,
212 | output_attentions=output_attentions,
213 | return_dict=return_dict,
214 | cache_position=cache_position
215 | )
216 |
217 | def generate(
218 | self,
219 | inputs: torch.LongTensor, # alias of `input_ids`
220 | attention_mask: Optional[torch.LongTensor] = None,
221 | protein_input_ids: Optional[torch.LongTensor] = None,
222 | protein_attention_mask: Optional[torch.LongTensor] = None,
223 | protein_inputs_embeds: Optional[torch.FloatTensor] = None,
224 | **kwargs
225 | ) -> Union[GenerateOutput, torch.LongTensor]:
226 | """
227 | Run inference based on the given input prompt.
228 | `inputs` is expected to be [prompt] only.
229 | The output does not keep the input prompt because the input is passed as embeddings.
230 | Generation behavior can be controlled via `kwargs`; see
231 | `GenerationMixin.generate` for more info.
232 | """
233 | # get decoder inputs
234 | prompt_inputs_embeds, prompt_attention_mask = self(
235 | input_ids=inputs,
236 | attention_mask=attention_mask,
237 | protein_input_ids=protein_input_ids,
238 | protein_attention_mask=protein_attention_mask,
239 | protein_inputs_embeds=protein_inputs_embeds,
240 | use_cache=False,
241 | output_attentions=False,
242 | output_hidden_states=False,
243 | return_dict=False,
244 | return_decoder_inputs=True
245 | )
246 | # do generate on llama_decoder
247 | return self.llama_decoder.generate(
248 | inputs_embeds=prompt_inputs_embeds,
249 | attention_mask=prompt_attention_mask,
250 | **kwargs
251 | )
252 |
253 | def gradient_checkpointing_enable(self):
254 | """
255 | Enable gradient checkpointing for all submodules that support it.
256 | Note: the model needs to be in train mode before calling this method.
257 | """
258 | if hasattr(self.esm_encoder, "gradient_checkpointing_enable"):
259 | self.esm_encoder.gradient_checkpointing_enable()
260 | if hasattr(self.llama_decoder, "gradient_checkpointing_enable"):
261 | self.llama_decoder.gradient_checkpointing_enable()
262 | # simple adapter no need to implement gradient checkpointing
263 |
264 | def gradient_checkpointing_disable(self):
265 | if hasattr(self.esm_encoder, "gradient_checkpointing_disable"):
266 | self.esm_encoder.gradient_checkpointing_disable()
267 | if hasattr(self.llama_decoder, "gradient_checkpointing_disable"):
268 | self.llama_decoder.gradient_checkpointing_disable()
269 |
--------------------------------------------------------------------------------
/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | DataLoader class for protein function prediction instruction tuning. To be used
3 | with Prot2TextInstructDataset for Esm2LlamaInstructForCausalLM.
4 |
5 | Every batch from DataLoader will contain following attributes:
6 | * Training mode (train-eval with teacher-forcing):
7 | - graph related features:
8 | - x: (sum_num_nodes, num_node_features)
9 | - edge_index: (2, sum_num_edges)
10 | - edge_type: (sum_num_edges,)
11 | - batch: (sum_num_nodes,)
12 | - amino-acid sequence:
13 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens
14 | - protein_attention_mask (bsz, max_seq_len+2) # right padding
15 | - concatenated chat:
16 | - input_ids (bsz, max_prompt_len+max_text_len+1)
17 | - attention_mask (bsz, max_prompt_len+max_text_len+1)
18 | - labels (bsz, max_prompt_len+max_text_len+1)
19 | - standalone description for contrastive learning:
20 | - description_input_ids (bsz, max_text_len+1) # eos token only
21 | - description_attention_mask (bsz, max_text_len+1) # right padding
22 |
23 | ids = [left-pad + bos + prompt & description + eot + right-pad]
24 | mask = [0s + 1 + 1s & 1s + 1 + 0s ]
25 | labels = [-100s + -100 + -100s & description + eot + -100s ]
26 | desc_ids = [ & description + eot + right-pad]
27 | desc_mask = [ & 1s + 1 + 0s ]
28 |
29 | * Inference mode (iterative generation):
30 | - graph related features:
31 | - x: (sum_num_nodes, num_node_features)
32 | - edge_index: (2, sum_num_edges)
33 | - edge_type: (sum_num_edges,)
34 | - batch: (sum_num_nodes,)
35 | - amino-acid sequence:
36 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens
37 | - protein_attention_mask (bsz, max_seq_len+2) # right padding
38 | - prompt chat:
39 | - input_ids (bsz, max_prompt_len)
40 | - attention_mask (bsz, max_prompt_len)
41 | - description_input_ids (bsz, max_text_len+1) # for evaluation
42 |
43 | ids = [left-pad + bos + prompt & ]
44 | mask = [0s + 1 + 1s & ]
45 | desc_ids = [ & description + eot + right-pad]
46 |
47 | Example of usage:
48 | >>> from transformers import AutoTokenizer
49 | >>> from dataset import Prot2TextInstructDataset, Prot2TextInstructDataLoader
50 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D")
51 | >>> llama_tokenizer = AutoTokenizer.from_pretrained(
52 | "/data/Meta-Llama-3.1-8B-Instruct-hf",
53 | pad_token='<|reserved_special_token_0|>'
54 | )
55 | >>> train_dataset = Prot2TextInstructDataset(
56 | root_dir="/data/Prot2Text-Llama3-Data/train",
57 | csv_path="./data/train.csv",
58 | sequence_tokenizer=esm_tokenizer,
59 | description_tokenizer=llama_tokenizer,
60 | skip_download=True, # assume data is already downloaded
61 | skip_reload=True, # assume data is already preprocessed
62 | )
63 | >>> train_dataloader = Prot2TextInstructDataLoader(
64 | dataset=train_dataset,
65 | mode="train",
66 | batch_size=2,
67 | shuffle=True,
68 | )
69 | """
70 |
71 |
72 | from typing import Dict, List, Literal, Optional, Union
73 |
74 | import torch
75 | import torch.utils.data
76 | import torch_geometric
77 | import torch_geometric.data
78 | import torch_geometric.loader.dataloader
79 | from transformers import PreTrainedTokenizer
80 |
81 | from .dataset import Prot2TextInstructDataset
82 |
83 |
84 | class Prot2TextInstructCollater(torch_geometric.loader.dataloader.Collater):
85 | def __init__(
86 | self,
87 | dataset: Prot2TextInstructDataset,
88 | tokenizer: PreTrainedTokenizer,
89 | mode: Literal["train", "inference"],
90 | **kwargs,
91 | ):
92 | super().__init__(dataset=dataset, **kwargs)
93 | self.tokenizer = tokenizer
94 | self.mode = mode
95 | self.seq_pad_token_id = self.dataset.sequence_tokenizer.pad_token_id
96 | self.text_pad_token_id = tokenizer.pad_token_id
97 |
98 | def __call__(
99 | self,
100 | batch: List[Dict[str, Union[str, torch.Tensor]]]
101 | ) -> Dict[str, torch.Tensor]:
102 | # prepare graph related features and name
103 | data_batch = torch_geometric.data.Batch.from_data_list(
104 | batch,
105 | exclude_keys=[
106 | "sequence_input_ids",
107 | "prompt_input_ids",
108 | "description_input_ids",
109 | ]
110 | )
111 |
112 | # prepare attn mask, right pad and stack sequences
113 | sequence_input_ids = [data["sequence_input_ids"][0] for data in batch]
114 | pad_sequence_input_ids = self._pad_sequence(
115 | sequence_input_ids,
116 | padding_value=self.seq_pad_token_id,
117 | padding_side="right"
118 | )
119 | pad_sequence_attention_mask = self._pad_sequence(
120 | [torch.ones_like(data["sequence_input_ids"][0]) for data in batch],
121 | padding_value=0,
122 | padding_side="right"
123 | )
124 |
125 | # prepare attn mask, left pad and stack prompts
126 | prompt_input_ids = [data["prompt_input_ids"][0] for data in batch]
127 | pad_prompt_input_ids = self._pad_sequence(
128 | prompt_input_ids,
129 | padding_value=self.text_pad_token_id,
130 | padding_side="left"
131 | )
132 | pad_prompt_attention_mask = self._pad_sequence(
133 | [torch.ones_like(data["prompt_input_ids"][0]) for data in batch],
134 | padding_value=0,
135 | padding_side="left"
136 | )
137 |
138 | # prepare attn mask, right pad and stack descriptions
139 | description_input_ids = [data["description_input_ids"][0] for data in batch]
140 | pad_description_input_ids = self._pad_sequence(
141 | description_input_ids,
142 | padding_value=self.text_pad_token_id,
143 | padding_side="right"
144 | )
145 | pad_description_attention_mask = self._pad_sequence(
146 | [torch.ones_like(data["description_input_ids"][0]) for data in batch],
147 | padding_value=0,
148 | padding_side="right"
149 | )
150 | pad_labels = self._pad_sequence(
151 | description_input_ids,
152 | padding_value=-100,
153 | padding_side="right"
154 | )
155 |
156 | # update text features
157 | if self.mode == "train":
158 | data_batch.update({
159 | "input_ids": torch.cat([
160 | pad_prompt_input_ids,
161 | pad_description_input_ids
162 | ], dim=1),
163 | "attention_mask": torch.cat([
164 | pad_prompt_attention_mask,
165 | pad_description_attention_mask
166 | ], dim=1),
167 | "labels": torch.cat([
168 | torch.full_like(
169 | pad_prompt_input_ids,
170 | fill_value=-100
171 | ),
172 | pad_labels
173 | ], dim=1),
174 | "protein_input_ids": pad_sequence_input_ids,
175 | "protein_attention_mask": pad_sequence_attention_mask,
176 | "description_input_ids": pad_description_input_ids,
177 | "description_attention_mask": pad_description_attention_mask,
178 | })
179 | elif self.mode == "inference":
180 | data_batch.update({
181 | "input_ids": pad_prompt_input_ids,
182 | "attention_mask": pad_prompt_attention_mask,
183 | "description_input_ids": pad_description_input_ids,
184 | "protein_input_ids": pad_sequence_input_ids,
185 | "protein_attention_mask": pad_sequence_attention_mask,
186 | })
187 | else:
188 | raise ValueError(f"Invalid mode: {self.mode}")
189 |
190 | # remove excluded keys
191 | if self.exclude_keys:
192 | data_batch = {
193 | k: v for k, v in data_batch.items()
194 | if k not in self.exclude_keys
195 | }
196 |
197 | return data_batch
198 |
199 | @staticmethod
200 | def _pad_sequence(
201 | sequences: List[torch.Tensor],
202 | padding_value: Union[float, int],
203 | padding_side: Literal["left", "right"] = "right",
204 | ) -> torch.Tensor:
205 | """
206 | Modified version of torch.nn.utils.rnn.pad_sequence with optional
207 | padding side.
208 | This feature is natively supported in PyTorch 2.6.0+, and it is
209 | recommended to use the built-in version there for better efficiency.
210 |
211 | * sequences as input must be a list of 1D tensors.
212 | """
213 | max_len = max(sequence.shape[-1] for sequence in sequences)
214 | padded_sequences = []
215 | for sequence in sequences:
216 | padding = torch.full(
217 | size=(max_len - sequence.shape[-1],),
218 | fill_value=padding_value,
219 | dtype=sequence.dtype,
220 | device=sequence.device,
221 | )
222 | if padding_side == "left":
223 | padded_sequences.append(torch.cat([padding, sequence], dim=-1))
224 | elif padding_side == "right":
225 | padded_sequences.append(torch.cat([sequence, padding], dim=-1))
226 | else:
227 | raise ValueError(f"Invalid padding side: {padding_side}")
228 | return torch.stack(padded_sequences, dim=0)
229 |
230 |
231 | class Prot2TextInstructDataLoader(torch.utils.data.DataLoader):
232 | """
233 | DataLoader class for proteins, forming batch inputs.
234 |
235 | (1) Compose graph related features,
236 | (2) dynamically pad sequences, prompts and descriptions, then
237 | (3) stack then concatenate these text features under different modes.
238 |
239 | Args:
240 | dataset:
241 | `Prot2TextInstructDataset` class to load data from.
242 | mode:
243 | - "train": training-evaluation with teacher-forcing. Input ids will
244 | be concatenated with labels (system + user + assistant) for
245 | training.
246 | - "inference": iterative generation. Input ids will only contain
247 | prompt (system + user) for generation.
248 | batch_size:
249 | Number of samples per batch.
250 | shuffle:
251 | Whether to shuffle the data. If a sampler is provided, the shuffling
252 | behavior should be controlled by the sampler and this argument
253 | should be set to False.
254 | follow_batch:
255 | PyG specific feature to be passed to Collater class. When working
256 | with batched graph data, follow_batch is used to indicate which
257 | attributes should an extra batch indices tensor be created for. If
258 | not set, an extra `batch` tensor will be automatically created to
259 | indicate which graph each node belongs to.
260 | exclude_keys:
261 | List of keys to exclude from the batch. The exclusion will be applied
262 | at the very end of the collation process.
263 | kwargs:
264 | Additional arguments to pass to DataLoader class.
265 | """
266 | def __init__(
267 | self,
268 | dataset: Prot2TextInstructDataset,
269 | mode: Literal["train", "inference"] = "train",
270 | batch_size: int = 1,
271 | shuffle: bool = True,
272 | follow_batch: Optional[List[str]] = None,
273 | exclude_keys: Optional[List[str]] = None,
274 | **kwargs,
275 | ):
276 | # override collate_fn
277 | kwargs.pop("collate_fn", None)
278 | self.follow_batch = follow_batch
279 | self.exclude_keys = exclude_keys
280 | collater = Prot2TextInstructCollater(
281 | dataset=dataset,
282 | tokenizer=dataset.description_tokenizer,
283 | mode=mode,
284 | follow_batch=follow_batch,
285 | exclude_keys=exclude_keys
286 | )
287 | super().__init__(
288 | dataset,
289 | batch_size=batch_size,
290 | shuffle=shuffle,
291 | collate_fn=collater,
292 | **kwargs
293 | )
294 |
295 | def __len__(self):
296 | """Mark class as Sized."""
297 | return super().__len__()
298 |
299 | def __iter__(self):
300 | """Mark class as Iterable."""
301 | return super().__iter__()
302 |
--------------------------------------------------------------------------------
/dataset/dataloader_derived.py:
--------------------------------------------------------------------------------
1 | """
2 | Derived DataLoader class for protein function prediction supervised fine tuning.
3 | To be used with Prot2TextInstructDataset for Esm2LlamaForCausalLM.
4 |
5 | The `Prot2TextInstructDataset` is designed to be preprocessed with an instruct
6 | model (ex. `Meta-Llama-3.1-8B-Instruct-hf`), but can be adapted to other base
7 | models (ex. `Llama-3.2-1B`) if tokenizers of both models share the same
8 | vocabulary.
9 |
10 | This derived version (`Prot2TextDerivedDataLoader`) is thus designed to adapt
11 | datasets that are already preprocessed for the instruct model. It replaces
12 | special tokens and reorganizes tokenized input ids, making them suitable for the
13 | base language model.
14 |
15 | Every batch from DataLoader will contain following attributes:
16 | * Training mode (train-eval with teacher-forcing):
17 | - graph related features:
18 | - x: (sum_num_nodes, num_node_features)
19 | - edge_index: (2, sum_num_edges)
20 | - edge_type: (sum_num_edges,)
21 | - batch: (sum_num_nodes,)
22 | - amino-acid sequence:
23 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens
24 | - protein_attention_mask (bsz, max_seq_len+2) # left padding
25 | - concatenated prompt and description:
26 | - input_ids (bsz, prompt_len+1+max_text_len+1) # bos and eos tokens
27 | - attention_mask (bsz, prompt_len+1+max_text_len+1) # right padding
28 | - labels (bsz, prompt_len+1+max_text_len+1)
29 | - standalone description for reward model training:
30 | - description_input_ids (bsz, max_text_len+1) # eos token only
31 | - description_attention_mask (bsz, max_text_len+1) # right padding
32 |
33 | ids = [bos + prompt + bos & description + eos + right-pad]
34 | mask = [1 + 1s + 1 & 1s + 1 + 0s ]
35 | labels = [-100 + -100s + -100 & description + eos + -100s ]
36 | desc_ids = [ & description + eos + right-pad]
37 | desc_mask = [ & 1s + 1 + 0s ]
38 |
39 | * Inference mode (iterative generation):
40 | - graph related features:
41 | - x: (sum_num_nodes, num_node_features)
42 | - edge_index: (2, sum_num_edges)
43 | - edge_type: (sum_num_edges,)
44 | - batch: (sum_num_nodes,)
45 | - amino-acid sequence:
46 | - protein_input_ids (bsz, max_seq_len+2) # bos and eos tokens
47 | - protein_attention_mask (bsz, max_seq_len+2) # left padding
48 | - prompt with inference head:
49 | - input_ids (bsz, prompt_len+1) # bos token at the end
50 | - attention_mask (bsz, prompt_len+1)
51 | - standalone description for evaluation:
52 | - description_input_ids (bsz, max_text_len+1)
53 | - description_attention_mask (bsz, max_text_len+1)
54 |
55 | ids = [bos + prompt + bos & ]
56 | mask = [1 + 1s + 1 & ]
57 | desc_ids = [ & description + eos + right-pad]
58 | desc_mask = [ & 1s + 1 + 0s ]
59 |
60 | Example of usage:
61 | >>> from transformers import AutoTokenizer
62 | >>> from dataset import Prot2TextInstructDataset, Prot2TextDerivedDataLoader
63 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D")
64 | >>> llama_tokenizer = AutoTokenizer.from_pretrained(
65 | "/data/Llama-3.2-1B",
66 | pad_token='<|reserved_special_token_0|>'
67 | )
68 | >>> train_dataset = Prot2TextInstructDataset(
69 | root_dir="/data/Prot2Text-Llama3-Data/train",
70 | csv_path="./data/train.csv",
71 | sequence_tokenizer=esm_tokenizer,
72 | description_tokenizer=llama_tokenizer, # pass the base model tokenizer
73 | skip_download=True, # assume data is already downloaded
74 | skip_reload=True, # assume data is already preprocessed
75 | )
76 | >>> train_dataloader = Prot2TextDerivedDataLoader(
77 | dataset=train_dataset,
78 | mode="train",
79 | batch_size=2,
80 | shuffle=True,
81 | )
82 | """
83 |
84 |
85 | from typing import Dict, List, Literal, Optional, Union
86 |
87 | import torch
88 | import torch.utils.data
89 | import torch_geometric
90 | import torch_geometric.data
91 | import torch_geometric.loader.dataloader
92 | from transformers import PreTrainedTokenizer
93 |
94 | from .dataset import Prot2TextInstructDataset
95 |
96 |
97 | class Prot2TextDerivedCollater(torch_geometric.loader.dataloader.Collater):
98 | def __init__(
99 | self,
100 | dataset: Prot2TextInstructDataset,
101 | tokenizer: PreTrainedTokenizer,
102 | mode: Literal["train", "inference"],
103 | original_eos_token_id: int,
104 | prompt_sentence: str,
105 | **kwargs,
106 | ):
107 | super().__init__(dataset=dataset, **kwargs)
108 | self.tokenizer = tokenizer
109 | self.mode = mode
110 | self.prompt_sentence = prompt_sentence
111 |
112 | self.prompt_input_ids = tokenizer(
113 | [tokenizer.bos_token + prompt_sentence + tokenizer.bos_token],
114 | add_special_tokens=False,
115 | return_tensors="pt",
116 | return_attention_mask=False,
117 | )["input_ids"]
118 |
119 | self.seq_pad_token_id = dataset.sequence_tokenizer.pad_token_id
120 | self.text_pad_token_id = tokenizer.pad_token_id
121 | self.old_text_eos_token_id = original_eos_token_id
122 | self.new_text_eos_token_id = tokenizer.eos_token_id
123 |
124 | def __call__(
125 | self,
126 | batch: List[Dict[str, Union[str, torch.Tensor]]]
127 | ) -> Dict[str, torch.Tensor]:
128 | # prepare graph related features and name
129 | data_batch = torch_geometric.data.Batch.from_data_list(
130 | batch,
131 | exclude_keys=[
132 | "sequence_input_ids",
133 | "prompt_input_ids",
134 | "description_input_ids",
135 | ]
136 | )
137 |
138 | # prepare attn mask, left pad and stack sequences
139 | pad_sequence_input_ids = self._pad_sequence(
140 | [data["sequence_input_ids"][0] for data in batch],
141 | padding_value=self.seq_pad_token_id,
142 | padding_side="left"
143 | )
144 | pad_sequence_attention_mask = self._pad_sequence(
145 | [torch.ones_like(data["sequence_input_ids"][0]) for data in batch],
146 | padding_value=0,
147 | padding_side="left"
148 | )
149 |
150 | # prepare attn mask and expand prompts
151 | pad_prompt_input_ids = self.prompt_input_ids.repeat(len(batch), 1).to(
152 | pad_sequence_input_ids.device
153 | )
154 | pad_prompt_attention_mask = torch.ones_like(pad_prompt_input_ids)
155 |
156 | # prepare attn mask, right pad and stack descriptions
157 | description_input_ids = [data["description_input_ids"][0] for data in batch]
158 | pad_description_input_ids = self._pad_sequence(
159 | description_input_ids,
160 | padding_value=self.text_pad_token_id,
161 | padding_side="right"
162 | )
163 | pad_description_attention_mask = self._pad_sequence(
164 | [torch.ones_like(data["description_input_ids"][0]) for data in batch],
165 | padding_value=0,
166 | padding_side="right"
167 | )
168 | pad_labels = self._pad_sequence(
169 | description_input_ids,
170 | padding_value=-100,
171 | padding_side="right"
172 | )
173 |
174 | # replace special tokens in descriptions
175 | pad_description_input_ids.masked_fill_(
176 | pad_description_input_ids == self.old_text_eos_token_id,
177 | self.new_text_eos_token_id
178 | )
179 | pad_labels.masked_fill_(
180 | pad_labels == self.old_text_eos_token_id,
181 | self.new_text_eos_token_id
182 | )
183 |
184 | # decode back the descriptions in text
185 | descriptions = self.tokenizer.batch_decode(
186 | description_input_ids,
187 | skip_special_tokens=True,
188 | )
189 |
190 | # update text features
191 | if self.mode == "train":
192 | data_batch.update({
193 | "input_ids": torch.cat([
194 | pad_prompt_input_ids,
195 | pad_description_input_ids
196 | ], dim=1),
197 | "attention_mask": torch.cat([
198 | pad_prompt_attention_mask,
199 | pad_description_attention_mask
200 | ], dim=1),
201 | "labels": torch.cat([
202 | torch.full_like(
203 | pad_prompt_input_ids,
204 | fill_value=-100
205 | ),
206 | pad_labels
207 | ], dim=1),
208 | "protein_input_ids": pad_sequence_input_ids,
209 | "protein_attention_mask": pad_sequence_attention_mask,
210 | "description_input_ids": pad_description_input_ids,
211 | "description_attention_mask": pad_description_attention_mask,
212 | "descriptions": descriptions,
213 | })
214 | elif self.mode == "inference":
215 | data_batch.update({
216 | "input_ids": pad_prompt_input_ids,
217 | "attention_mask": pad_prompt_attention_mask,
218 | "description_input_ids": pad_description_input_ids,
219 | "description_attention_mask": pad_description_attention_mask,
220 | "protein_input_ids": pad_sequence_input_ids,
221 | "protein_attention_mask": pad_sequence_attention_mask,
222 | })
223 | else:
224 | raise ValueError(f"Invalid mode: {self.mode}")
225 |
226 | return data_batch
227 |
228 | @staticmethod
229 | def _pad_sequence(
230 | sequences: List[torch.Tensor],
231 | padding_value: Union[float, int],
232 | padding_side: Literal["left", "right"] = "right",
233 | ) -> torch.Tensor:
234 | """
235 | Modified version of torch.nn.utils.rnn.pad_sequence with optional
236 | padding side.
237 | This feature is natively supported in PyTorch 2.6.0+, and it is
238 | recommended to use the built-in version there for better efficiency.
239 |
240 | * sequences as input must be a list of 1D tensors.
241 | """
242 | max_len = max(sequence.shape[-1] for sequence in sequences)
243 | padded_sequences = []
244 | for sequence in sequences:
245 | padding = torch.full(
246 | size=(max_len - sequence.shape[-1],),
247 | fill_value=padding_value,
248 | dtype=sequence.dtype,
249 | device=sequence.device,
250 | )
251 | if padding_side == "left":
252 | padded_sequences.append(torch.cat([padding, sequence], dim=-1))
253 | elif padding_side == "right":
254 | padded_sequences.append(torch.cat([sequence, padding], dim=-1))
255 | else:
256 | raise ValueError(f"Invalid padding side: {padding_side}")
257 | return torch.stack(padded_sequences, dim=0)
258 |
259 |
260 | class Prot2TextDerivedDataLoader(torch.utils.data.DataLoader):
261 | """
262 | DataLoader class for proteins, forming batch inputs.
263 |
264 | (1) Compose graph related features with PyG's Batch.from_data_list;
265 | (2) dynamically pad sequences, prompts and descriptions;
266 | (3) replace special tokens and reorganize tokenized input ids;
267 | (4) stack then concatenate these text features under different modes.
268 |
269 | Args:
270 | dataset:
271 | `Prot2TextInstructDataset` class to load data from. Both
272 | `Prot2TextInstructDataLoader` and `Prot2TextDerivedDataLoader` should
273 | use the same dataset, but with different description tokenizers. Read
274 | the docstring of `Prot2TextInstructDataset` for more details.
275 | mode:
276 | - "train": training-evaluation with teacher-forcing. Input ids will
277 | be concatenated with labels (prompt + description) for
278 | training.
279 | - "inference": iterative generation. Input ids will only contain
280 | prompt (prompt + bos) for generation.
281 | batch_size:
282 | Number of samples per batch.
283 | shuffle:
284 | Whether to shuffle the data. If a sampler is provided, the shuffling
285 | behavior should be controlled by the sampler and this argument
286 | should be set to False.
287 | follow_batch:
288 | PyG-specific feature to be passed to the Collater class. When working
289 | with batched graph data, follow_batch indicates the attributes for
290 | which an extra batch-index tensor should be created. If not set, an
291 | extra `batch` tensor will still be created automatically to indicate
292 | which graph each node belongs to.
293 | exclude_keys:
294 | List of keys to exclude from the batch. The exclusion will be applied
295 | at the very end of the collation process.
296 | original_eos_token_id:
297 | End-of-sequence token id of the instruct model tokenizer that is used
298 | in the preprocessing stage of the dataset. This token will be
299 | replaced by the eos token id of the base model tokenizer that is
300 | given in the derived scenario.
301 | prompt_sentence:
302 | Prompt sentence to be used in the derived scenario.
303 | kwargs:
304 | Additional arguments to pass to DataLoader class.
305 | """
306 | def __init__(
307 | self,
308 | dataset: Prot2TextInstructDataset,
309 | mode: Literal["train", "inference"] = "train",
310 | batch_size: int = 1,
311 | shuffle: bool = True,
312 | follow_batch: Optional[List[str]] = None,
313 | exclude_keys: Optional[List[str]] = None,
314 | original_eos_token_id: int = 128009,
315 | prompt_sentence: str = (
316 | "Predict protein description based on the amino-acid sequence embeddings."
317 | ),
318 | **kwargs,
319 | ):
320 | # override collate_fn
321 | kwargs.pop("collate_fn", None)
322 | self.follow_batch = follow_batch
323 | self.exclude_keys = exclude_keys
324 | collater = Prot2TextDerivedCollater(
325 | dataset=dataset,
326 | tokenizer=dataset.description_tokenizer,
327 | mode=mode,
328 | follow_batch=follow_batch,
329 | exclude_keys=exclude_keys,
330 | original_eos_token_id=original_eos_token_id,
331 | prompt_sentence=prompt_sentence,
332 | )
333 | super().__init__(
334 | dataset,
335 | batch_size=batch_size,
336 | shuffle=shuffle,
337 | collate_fn=collater,
338 | **kwargs
339 | )
340 |
341 | def __len__(self):
342 | """Mark class as Sized."""
343 | return super().__len__()
344 |
345 | def __iter__(self):
346 | """Mark class as Iterable."""
347 | return super().__iter__()
348 |
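# Minimal usage sketch (hypothetical setup), mirroring the doctest example in
# dataset.py; `train_dataset` is assumed to be a Prot2TextInstructDataset built
# with the base-model tokenizer attached as `description_tokenizer`:
#   loader = Prot2TextDerivedDataLoader(
#       dataset=train_dataset,
#       mode="train",
#       batch_size=4,
#       shuffle=True,
#   )
#   batch = next(iter(loader))
#   # batch["input_ids"] is the padded prompt concatenated with the padded
#   # description, and batch["labels"] masks the prompt positions with -100.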
--------------------------------------------------------------------------------
/scripts/train_legacy.py:
--------------------------------------------------------------------------------
1 | """
2 | FullyShardedDataParallel / DistributedDataParallel training script implemented
3 | from scratch.
4 |
5 | The script currently supports gradient accumulation, AutoMixedPrecision,
6 | and inter-epoch evaluation.
7 |
8 | The script currently does not support save/load pretrained or gradient checkpointing.
9 |
10 | reference for FSDP: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
11 | reference for AMP: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
12 |
13 | * The script is designed for multi-GPU parallelism on a single node.
14 | * On the cluster, print(...) will go to stdout and tqdm(...) will go to stderr.
15 | """
16 |
17 | import argparse
18 | from datetime import datetime
19 | import functools
20 | import os
21 | from typing import Any, Dict, Optional, Tuple, Union
22 |
23 | import torch
24 | from torch.amp import GradScaler
25 | from torch.cuda.amp import autocast
26 | import torch.distributed as dist
27 | from torch.distributed.fsdp import (
28 | FullyShardedDataParallel,
29 | FullStateDictConfig,
30 | StateDictType
31 | )
32 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
33 | from torch.nn.parallel import DistributedDataParallel
34 | from torch.optim import AdamW, Optimizer
35 | from torch.optim.lr_scheduler import StepLR
36 | from torch.utils.data.distributed import DistributedSampler
37 | from tqdm import tqdm
38 | from transformers import AutoTokenizer, PreTrainedModel
39 |
40 | from dataset import Prot2TextInstructDataset, Prot2TextDerivedDataLoader
41 | from models import Esm2LlamaForCausalLM
42 | import scripts.utils_argparse as utils_argparse
43 |
44 |
45 | argParser = argparse.ArgumentParser()
46 |
47 | argParser.add_argument("--esm_path", type=str)
48 | argParser.add_argument("--llama_path", type=str)
49 | argParser.add_argument("--root_dataset_dir", type=str)
50 | argParser.add_argument("--root_csv_dir", type=str)
51 | argParser.add_argument("--save_checkpoint_dir", type=str)
52 | argParser.add_argument("--load_general_checkpoint_path", type=str, default="")
53 |
54 | argParser.add_argument("--wrap_model", type=utils_argparse.str2bool)
55 | argParser.add_argument("--autocast_dtype", type=utils_argparse.str2dtype)
56 | argParser.add_argument("--batch_size_per_device", type=int)
57 | argParser.add_argument("--num_epochs", type=int)
58 | argParser.add_argument("--save_every_epochs", type=int)
59 | argParser.add_argument("--gradient_accumulation_steps", type=int)
60 | argParser.add_argument("--learning_rate", type=float)
61 | argParser.add_argument("--gradient_clipping", type=float, default=None)
62 | argParser.add_argument("--scheduler_gamma", type=float)
63 | argParser.add_argument("--random_seed", type=int)
64 | argParser.add_argument("--train_split", type=str)
65 | argParser.add_argument("--eval_split", type=str)
66 | argParser.add_argument("--debug_trim_train_split", type=int, default=None)
67 | argParser.add_argument("--debug_trim_eval_split", type=int, default=None)
68 | argParser.add_argument("--max_sequence_length", type=int, default=None)
69 | argParser.add_argument("--max_description_length", type=int, default=None)
70 |
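# Example invocation (hypothetical paths and values; the flags map to the
# arguments declared above, so adjust them to your setup):
#   python -m scripts.train_legacy \
#       --esm_path /data/esm2_t33_650M_UR50D \
#       --llama_path /data/Meta-Llama-3.1-8B-Instruct-hf \
#       --root_dataset_dir /data/Prot2Text-Llama3-Data \
#       --root_csv_dir ./data \
#       --save_checkpoint_dir ./checkpoints \
#       --wrap_model true --autocast_dtype bfloat16 \
#       --batch_size_per_device 1 --num_epochs 5 --save_every_epochs 1 \
#       --gradient_accumulation_steps 8 --learning_rate 1e-5 \
#       --scheduler_gamma 0.9 --random_seed 42 \
#       --train_split train --eval_split valid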
71 |
72 | def load_model(
73 | esm_path: Union[str, os.PathLike],
74 | llama_path: Union[str, os.PathLike],
75 | load_general_checkpoint_path: Optional[Union[str, os.PathLike]] = None
76 | ) -> PreTrainedModel:
77 | """
78 | Standard API for different models. Used in both `train` and `generate`.
79 | Load the base model of the given name, and load weights from the checkpoint path if provided.
80 | Returned model is on CPU by default.
81 | The checkpoint is expected to contain a "model_state_dict" entry.
82 | """
83 | model = Esm2LlamaForCausalLM.from_pretrained(
84 | pretrained_esm_model_name_or_path=esm_path,
85 | pretrained_llama_model_name_or_path=llama_path,
86 | esm_kwargs={"decoder_hidden_size": 2048}
87 | )
88 |
89 | # load checkpoint if any
90 | if load_general_checkpoint_path:
91 | print(f"Loading {load_general_checkpoint_path}")
92 | checkpoint_state_dicts = torch.load(load_general_checkpoint_path, weights_only=True)
93 | model_state_dict = checkpoint_state_dicts["model_state_dict"]
94 | model.load_state_dict(model_state_dict)
95 |
96 | return model
97 |
98 |
99 | def teacher_forcing_forward_pass(
100 | rank: int,
101 | model: torch.nn.Module,
102 | data_batch: Dict[str, Any],
103 | ) -> Tuple[torch.Tensor, torch.Tensor]:
104 | """
105 | Standard API for different models. Used in both `train_epoch` and `eval_epoch`.
106 | 1) Prepare inputs for the forward pass with teacher forcing using data_batch from dataloader.
107 | 2) Execute the forward pass and return the direct output.
108 | Returned loss is not scaled with gradient accumulation steps.
109 | Returned logits are un-normalized predictions representing the scores for each token in the vocabulary.
110 | """
111 | return model(
112 | input_ids=data_batch["input_ids"].to(rank),
113 | attention_mask=data_batch["attention_mask"].to(rank),
114 | labels=data_batch["labels"].to(rank),
115 | protein_input_ids=data_batch["protein_input_ids"].to(rank),
116 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank),
117 | use_cache=False,
118 | output_attentions=False,
119 | output_hidden_states=False,
120 | return_dict=False, # force return tuple (loss, logits)
121 | return_encoder_output=False,
122 | )
123 |
124 |
125 | def setup(rank: int, world_size: int):
126 | """
127 | Initialize processes for distributed training before first epoch.
128 | The IP address and the port of the master node are fetched from the job script or launcher environment.
129 | """
130 | os.environ['MASTER_ADDR'] = os.getenv('MASTER_ADDR', 'localhost')
131 | os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '9901')
132 | # initialize the process group
133 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
134 |
135 |
136 | def cleanup():
137 | """End processes for distributed training after last epoch"""
138 | dist.destroy_process_group()
139 |
140 |
141 | def train_epoch(
142 | rank: int,
143 | current_epoch: int,
144 | model: Union[DistributedDataParallel, FullyShardedDataParallel],
145 | dataloader: Prot2TextDerivedDataLoader,
146 | optimizer: Optimizer,
147 | scaler: GradScaler,
148 | args: Dict[str, Any]
149 | ):
150 | """Iterate over all batches for one epoch in training with teacher forcing"""
151 | model.train()
152 | ddp_loss = torch.zeros(2).to(rank) # [0] for acc. loss and [1] for num. of seen batches
153 | ddp_gradnorm = torch.zeros(2).to(rank) # [0] for acc. gradnorm and [1] for num. of passed steps
154 | optimizer.zero_grad() # erase accumulated gradients from last epoch
155 |
156 | t = tqdm(iter(dataloader))
157 | for batch_idx, data_batch in enumerate(t):
158 | # with autocast, logits will be in AUTOCAST_DTYPE and loss will be re-casted to torch.float32
159 | with autocast(dtype=args["autocast_dtype"]):
160 | output = teacher_forcing_forward_pass(
161 | rank=rank,
162 | model=model,
163 | data_batch=data_batch,
164 | )
165 |
166 | # rescale loss for consistency with different gradient accumulation steps
167 | loss = output[0] / args["gradient_accumulation_steps"]
168 |
169 | # summarize the current batch
170 | t.set_postfix({
171 | "mode": "train",
172 | "epoch": f"{current_epoch}/{args['num_epochs']}",
173 | "batch_loss": loss.item() * args["gradient_accumulation_steps"],
174 | "device": f"rank:{rank}"
175 | })
176 | ddp_loss[0] += loss.item() * args["gradient_accumulation_steps"]
177 | ddp_loss[1] += 1 # count batches so the epoch loss can be averaged over them
178 |
179 | # scale the loss up by a large factor to prevent gradients from underflowing, then accumulate the scaled grads
180 | scaler.scale(loss).backward() # backward out of autocast, but still uses same dtype as for forward
181 |
182 | # update weights by loss if accumulation step is reached
183 | if (batch_idx + 1) % args["gradient_accumulation_steps"] == 0: # Perform optimizer step after accumulation
184 | scaler.unscale_(optimizer) # unscale gradients for gradient examination and clipping
185 | gradnorm = torch.nn.utils.clip_grad_norm_(
186 | model.parameters(),
187 | max_norm=float("inf") if args["gradient_clipping"] is None else args["gradient_clipping"]
188 | )
189 | ddp_gradnorm[0] += gradnorm
190 | ddp_gradnorm[1] += 1
191 |
192 | scaler.step(optimizer) # first unscale the gradients, then do step only if no INF or NaN is in grad
193 | scaler.update()
194 | optimizer.zero_grad(set_to_none=True)
195 |
196 | # summarize the current epoch
197 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
198 | if rank == 0:
199 | print(
200 | f"[epoch={current_epoch}/{args['num_epochs']}, "
201 | f"train_loss={ddp_loss[0] / ddp_loss[1]}, "
202 | f"epoch_lr={optimizer.param_groups[0]['lr']}, "
203 | f"epoch_gradnorm={ddp_gradnorm[0] / ddp_gradnorm[1]}]"
204 | )
205 | # NaN detection (NaN != NaN, so this fires only when the epoch loss is NaN)
206 | if ddp_loss[0] != ddp_loss[0]:
207 | raise ValueError("NaN detected in the training loss of the epoch, training interrupted.")
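# Note on effective batch size (illustrative numbers, not from the original
# script): with batch_size_per_device=1, gradient_accumulation_steps=8 and
# world_size=4, each optimizer step corresponds to 1 * 8 * 4 = 32 samples.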
208 |
209 |
210 | def eval_epoch(
211 | rank: int,
212 | current_epoch: int,
213 | model: Union[DistributedDataParallel, FullyShardedDataParallel],
214 | dataloader: Prot2TextDerivedDataLoader,
215 | args: Dict[str, Any]
216 | ):
217 | """Iterate over all batches in evaluation with teacher forcing"""
218 | model.eval()
219 | ddp_loss = torch.zeros(2).to(rank) # [0] for acc. loss and [1] for num. of seen batches
220 |
221 | t = tqdm(iter(dataloader))
222 | for data_batch in t:
223 | with torch.no_grad():
224 | output = teacher_forcing_forward_pass(
225 | rank=rank,
226 | model=model,
227 | data_batch=data_batch,
228 | )
229 |
230 | loss = output[0]
231 | t.set_postfix({
232 | "mode": "eval",
233 | "epoch": f"{current_epoch}/{args['num_epochs']}",
234 | "batch_loss": loss.item(),
235 | "device": f"rank:{rank}"
236 | })
237 | ddp_loss[0] += loss.item()
238 | ddp_loss[1] += 1
239 |
240 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
241 | if rank == 0:
242 | print(f"[epoch={current_epoch}/{args['num_epochs']}, eval_loss={ddp_loss[0] / ddp_loss[1]}]")
243 |
244 |
245 | def train_on_device(
246 | rank: int,
247 | world_size: int,
248 | args: Dict[str, Any]
249 | ):
250 | """Training and evaluation process for each device, including epochs of training with teacher forcing"""
251 | setup(rank, world_size)
252 |
253 | # prepare datasets and dataloaders
254 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"])
255 | llama_tokenizer = AutoTokenizer.from_pretrained(args["llama_path"], pad_token='<|reserved_special_token_0|>')
256 |
257 | train_dataset = Prot2TextInstructDataset(
258 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['train_split']}"),
259 | csv_path=os.path.join(args["root_csv_dir"], f"{args['train_split']}.csv"),
260 | sequence_tokenizer=esm_tokenizer,
261 | description_tokenizer=llama_tokenizer,
262 | skip_reload=True,
263 | skip_download=True,
264 | ignore_graph_features=False,
265 | max_sequence_length=args["max_sequence_length"],
266 | max_description_length=args["max_description_length"],
267 | )
268 | if args["debug_trim_train_split"]:
269 | train_dataset.usable_file_names = train_dataset.usable_file_names[:args["debug_trim_train_split"]]
270 | train_sampler = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
271 | train_loader = Prot2TextDerivedDataLoader(
272 | train_dataset,
273 | batch_size=args["batch_size_per_device"],
274 | sampler=train_sampler,
275 | num_workers=4, # parallel CPU cores used for data loading
276 | pin_memory=True, # enable page-locked memory allocation for faster data transfer to GPUs
277 | shuffle=False,
278 | drop_last=True, # avoid incomplete batch at the end
279 | )
280 |
281 | eval_dataset = Prot2TextInstructDataset(
282 | root_dir=os.path.join(args["root_dataset_dir"], f"{args['eval_split']}"),
283 | csv_path=os.path.join(args["root_csv_dir"], f"{args['eval_split']}.csv"),
284 | sequence_tokenizer=esm_tokenizer,
285 | description_tokenizer=llama_tokenizer,
286 | skip_reload=True,
287 | skip_download=True,
288 | ignore_graph_features=False,
289 | max_sequence_length=args["max_sequence_length"],
290 | max_description_length=args["max_description_length"],
291 | )
292 | if args["debug_trim_eval_split"]:
293 | eval_dataset.usable_file_names = eval_dataset.usable_file_names[:args["debug_trim_eval_split"]]
294 | eval_sampler = DistributedSampler(eval_dataset, rank=rank, num_replicas=world_size, shuffle=False)
295 | eval_loader = Prot2TextDerivedDataLoader(
296 | eval_dataset,
297 | batch_size=args["batch_size_per_device"],
298 | sampler=eval_sampler,
299 | num_workers=4,
300 | pin_memory=True,
301 | shuffle=False,
302 | drop_last=True,
303 | )
304 |
305 | torch.cuda.set_device(rank)
306 |
307 | model = load_model(
308 | esm_path=args["esm_path"],
309 | llama_path=args["llama_path"],
310 | load_general_checkpoint_path=args["load_general_checkpoint_path"],
311 | )
312 | model = model.to(rank)
313 |
314 | if args["wrap_model"]:
316 | # shard every layer with more than min_num_params parameters
316 | my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=10000)
317 | model = FullyShardedDataParallel(model, auto_wrap_policy=my_auto_wrap_policy)
318 | print(f"FSDP model loaded on rank:{rank}")
319 | else:
320 | model = DistributedDataParallel(model, find_unused_parameters=True)
321 | print(f"DDP model loaded on rank:{rank}")
322 |
323 | # initialization of the optimizer after wrapping the model
324 | optimizer = AdamW(model.parameters(), lr=args["learning_rate"])
325 | scheduler = StepLR(optimizer, step_size=1, gamma=args["scheduler_gamma"])
326 | if args["load_general_checkpoint_path"]:
327 | checkpoint_state_dicts = torch.load(args["load_general_checkpoint_path"], weights_only=True)
328 | optimizer_state_dict = checkpoint_state_dicts["optimizer_state_dict"]
329 | scheduler_state_dict = checkpoint_state_dicts["scheduler_state_dict"]
330 | optimizer.load_state_dict(optimizer_state_dict)
331 | scheduler.load_state_dict(scheduler_state_dict)
332 |
333 | # initialization of scaler for mixed precision and control of gradient accumulation
334 | grad_scaler = GradScaler("cuda")
335 |
336 | # core loop of epochs
337 | for epoch_idx in range(1, args["num_epochs"] + 1):
338 | # shuffle data differently at each epoch across all processes
339 | train_sampler.set_epoch(epoch=epoch_idx)
340 |
341 | train_epoch(
342 | rank=rank,
343 | current_epoch=epoch_idx,
344 | model=model,
345 | dataloader=train_loader,
346 | optimizer=optimizer,
347 | scaler=grad_scaler,
348 | args=args
349 | )
350 | scheduler.step()
351 | dist.barrier() # use a barrier to make sure training is done on all ranks
352 |
353 | eval_epoch(
354 | rank=rank,
355 | model=model,
356 | current_epoch=epoch_idx,
357 | dataloader=eval_loader,
358 | args=args
359 | )
360 | dist.barrier()
361 |
362 | # save model checkpoint with CPU offload to avoid CUDA OOM, save/load_pretrained not available in FSDP
363 | if epoch_idx == 1 or epoch_idx == args["num_epochs"] or epoch_idx % args["save_every_epochs"] == 0:
364 | if args["wrap_model"]: # FSDP
365 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
366 | with FullyShardedDataParallel.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
367 | model_state_dict = model.state_dict()
368 | else: # DDP
369 | model_state_dict = model.module.state_dict()
370 | if rank == 0:
371 | checkpoint_path = os.path.join(args["save_checkpoint_dir"], f"general_checkpoint_{epoch_idx}.pt")
372 | torch.save(
373 | {
374 | "model_state_dict": model_state_dict,
375 | "optimizer_state_dict": optimizer.state_dict(),
376 | "scheduler_state_dict": scheduler.state_dict(),
377 | },
378 | checkpoint_path
379 | )
380 | print(f"Saving {checkpoint_path}")
381 | dist.barrier()
382 |
383 | cleanup()
384 |
385 |
386 | def train_distributed(args: Dict[str, Any]):
387 | """
388 | Core training process across multiple devices with epochs of training and inter-epoch evaluation.
389 | Use args: Dict[str, Any] instead of **kwargs for compatibility with torch.multiprocessing.spawn.
390 | """
391 | torch.multiprocessing.spawn(
392 | train_on_device,
393 | args=(args["world_size"], args),
394 | nprocs=args["world_size"],
395 | join=True
396 | )
397 |
398 |
399 | if __name__ == '__main__':
400 | # suppress messages from AutoTokenizer parallelism and Graphein respectively
401 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
402 | os.environ["LOGURU_LEVEL"] = "INFO"
403 |
404 | parsed_args = argParser.parse_args()
405 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict PyTorch to see only the specified GPUs
406 | parsed_args.world_size = torch.cuda.device_count() # use all available GPUs on the node
407 |
408 | torch.manual_seed(parsed_args.random_seed)
409 | torch.cuda.manual_seed(parsed_args.random_seed)
410 |
411 | # initialize checkpoint directory
412 | timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
413 | parsed_args.save_checkpoint_dir = os.path.join(parsed_args.save_checkpoint_dir, f"checkpoints_{timestamp}")
414 | if not os.path.exists(parsed_args.save_checkpoint_dir):
415 | os.mkdir(parsed_args.save_checkpoint_dir)
416 |
417 | print("####################")
418 | for key, value in parsed_args.__dict__.items():
419 | print(f"{key}: {value}")
420 | print("####################")
421 |
422 | train_distributed(parsed_args.__dict__)
423 |
--------------------------------------------------------------------------------
/scripts/train_instruct.py:
--------------------------------------------------------------------------------
1 | """
2 | Stage 2 - instruction tuning training script for ESM-LLAMA protein description
3 | generation on Esm2LlamaInstructForCausalLM model.
4 |
5 | With LoRA.
6 |
7 | DistributedDataParallel training script implemented from scratch.
8 |
9 | The script currently supports gradient accumulation and
10 | inter-epoch evaluation.
11 |
12 | The script currently does not support save/load pretrained, gradient checkpointing
13 | or generation under FSDP.
14 |
15 | reference for AMP: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
16 |
17 | * The script is designed for multi-GPU parallelism on a single node.
18 | * On the cluster, print(...) will go to stdout and tqdm(...) will go to stderr.
19 | """
20 |
21 | import argparse
22 | from datetime import datetime
23 | import os
24 | from typing import Any, Dict, Union
25 |
26 | from peft import get_peft_model, LoraConfig
27 | from peft.peft_model import PeftModel
28 | import torch
29 | import torch.distributed as dist
30 | from torch.distributed.fsdp import FullyShardedDataParallel
31 | from torch.nn.parallel import DistributedDataParallel
32 | from torch.optim import Adam, AdamW, Optimizer
33 | from torch.optim.lr_scheduler import StepLR
34 | from torch.utils.data import DataLoader
35 | from torch.utils.data.distributed import DistributedSampler
36 | from tqdm import tqdm
37 | from transformers import AutoTokenizer
38 | from transformers import EsmModel, LlamaForCausalLM
39 |
40 | from dataset import Prot2TextLightDataset, Prot2TextLightCollater
41 | from models import (
42 | ModalityAdapter,
43 | ModalityAdapterConfig,
44 | Esm2LlamaInstructForCausalLM
45 | )
46 | import scripts.utils_argparse as utils_argparse
47 |
48 |
49 | argParser = argparse.ArgumentParser()
50 |
51 | argParser.add_argument("--esm_path", type=str)
52 | argParser.add_argument("--llama_path", type=str)
53 | # argParser.add_argument("--root_dataset_dir", type=str)
54 | argParser.add_argument("--root_csv_dir", type=str)
55 | argParser.add_argument("--save_checkpoint_dir", type=str)
56 | argParser.add_argument("--load_model_checkpoint_path", type=str, default="")
57 | argParser.add_argument("--load_adapter_checkpoint_dir", type=str, default="")
58 | argParser.add_argument("--load_optimizer_scheduler_checkpoint_path", type=str, default="")
59 |
60 | argParser.add_argument("--torch_dtype", type=utils_argparse.str2dtype)
61 | argParser.add_argument("--batch_size_per_device", type=int)
62 | argParser.add_argument("--num_epochs", type=int)
63 | argParser.add_argument("--save_every_epochs", type=int)
64 | argParser.add_argument("--gradient_accumulation_steps", type=int)
65 | argParser.add_argument("--learning_rate", type=float)
66 | argParser.add_argument("--gradient_clipping", type=float, default=None)
67 | argParser.add_argument("--scheduler_gamma", type=float)
68 | argParser.add_argument("--random_seed", type=int)
69 | argParser.add_argument("--fix_modality_adapter", type=utils_argparse.str2bool)
70 | argParser.add_argument("--lora_rank", type=int)
71 |
72 | argParser.add_argument("--include_text_fields", type=utils_argparse.str2bool)
73 | argParser.add_argument("--name_dropout", type=float)
74 | argParser.add_argument("--taxonomy_dropout", type=float)
75 |
76 | argParser.add_argument("--train_split", type=str)
77 | argParser.add_argument("--eval_split", type=str)
78 | argParser.add_argument("--debug_trim_train_split", type=int, default=None)
79 | argParser.add_argument("--debug_trim_eval_split", type=int, default=None)
80 |
81 |
82 | def load_model(args: Dict[str, Any]) -> PeftModel:
83 | """
84 | Standard API for different models. Used in both `train` and `generate`.
85 | Load base model of the given name, and load weights from the checkpoint path
86 | if provided.
87 | """
88 | esm_encoder = EsmModel.from_pretrained(
89 | args["esm_path"],
90 | add_pooling_layer=False,
91 | torch_dtype=args["torch_dtype"],
92 | device_map="cpu"
93 | )
94 | llama_decoder = LlamaForCausalLM.from_pretrained(
95 | args["llama_path"],
96 | torch_dtype=args["torch_dtype"],
97 | device_map="cpu"
98 | )
99 |
100 | adapter_config = ModalityAdapterConfig(
101 | input_dim=esm_encoder.config.hidden_size,
102 | intermediate_dim=2048,
103 | output_dim=llama_decoder.config.hidden_size,
104 | )
105 | adapter = ModalityAdapter(adapter_config)
106 | adapter.to(args["torch_dtype"])
107 |
108 | model = Esm2LlamaInstructForCausalLM(
109 | esm_encoder=esm_encoder,
110 | adapter=adapter,
111 | llama_decoder=llama_decoder,
112 | )
113 |
114 | # overwrite weights of base model if checkpoint path is provided
115 | if args["load_model_checkpoint_path"]:
116 | print(f"Loading {args['load_model_checkpoint_path']}")
117 | model_state_dict = torch.load(
118 | args["load_model_checkpoint_path"],
119 | weights_only=True,
120 | map_location="cpu" # load to CPU first
121 | # will be loaded to where the weights were saved from if not specified
122 | )
123 | model.load_state_dict(model_state_dict)
124 |
125 | # wrap by lora either with pretrained adapter or with initialized adapter
126 | if args["load_adapter_checkpoint_dir"]:
127 | print(f"Loading {args['load_adapter_checkpoint_dir']}")
128 | model = PeftModel.from_pretrained(
129 | model,
130 | args["load_adapter_checkpoint_dir"],
131 | is_trainable=True
132 | )
133 | else:
134 | print("Initializing LoRA adapter")
135 | lora_config = LoraConfig(
136 | r=args["lora_rank"],
137 | lora_alpha=args["lora_rank"] * 2,
138 | lora_dropout=0.1,
139 | bias="none",
140 | init_lora_weights=True,
141 | target_modules=[
142 | "self_attn.q_proj",
143 | "self_attn.k_proj",
144 | "self_attn.v_proj",
145 | "self_attn.o_proj",
146 | "mlp.gate_proj",
147 | "mlp.up_proj",
148 | "mlp.down_proj"
149 | ], # for llama_decoder
150 | modules_to_save=(
151 | ["adapter.fc1", "adapter.fc2"]
152 | if not args["fix_modality_adapter"]
153 | else None
154 | )
155 | )
156 | model = get_peft_model(model, lora_config)
157 |
158 | model.print_trainable_parameters()
159 |
160 | return model
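# Sketch of the LoRA sizing used above (hypothetical rank): with lora_rank=16,
# lora_alpha is set to 32, so under peft's default scaling (lora_alpha / r) the
# low-rank updates are scaled by 2.0 regardless of the chosen rank.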
161 |
162 |
163 | def teacher_forcing_forward_pass(
164 | rank: int,
165 | model: Union[DistributedDataParallel, FullyShardedDataParallel],
166 | data_batch: Dict[str, Any],
167 | ) -> torch.Tensor: # loss
168 | """
169 | Standard API for different models. Used in both `train_epoch` and `eval_epoch`.
170 | Prepare inputs from dataloader, migrate variable to the same device as the model,
171 | and execute the forward pass with teacher forcing.
172 |
173 | Returned loss is not scaled with gradient accumulation steps.
174 | """
175 | return model(
176 | input_ids=data_batch["input_ids"].to(rank),
177 | attention_mask=data_batch["attention_mask"].to(rank),
178 | labels=data_batch["labels"].to(rank),
179 | protein_input_ids=data_batch["protein_input_ids"].to(rank),
180 | protein_attention_mask=data_batch["protein_attention_mask"].to(rank),
181 | use_cache=False,
182 | output_attentions=False,
183 | output_hidden_states=False,
184 | return_dict=False,
185 | )[0]
186 |
187 |
188 | def setup(rank: int, world_size: int):
189 | """
190 | Initialize processes for distributed training before first epoch.
191 | The IP address and the port of the master node are fetched from the job
192 | script or launcher environment.
193 | """
194 | os.environ['MASTER_ADDR'] = os.getenv('MASTER_ADDR', 'localhost')
195 | os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '9901')
196 | # initialize the process group
197 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
198 |
199 |
200 | def cleanup():
201 | """End processes for distributed training after last epoch"""
202 | dist.destroy_process_group()
203 |
204 |
205 | def train_epoch(
206 | rank: int,
207 | current_epoch: int,
208 | model: Union[DistributedDataParallel, FullyShardedDataParallel],
209 | dataloader: DataLoader,
210 | optimizer: Optimizer,
211 | args: Dict[str, Any]
212 | ):
213 | """Iterate over all batches for one epoch in training with teacher forcing"""
214 | model.train()
215 | ddp_loss = torch.zeros(2).to(rank)
216 | # [0] for acc. loss and [1] for num. of seen batches
217 | ddp_gradnorm = torch.zeros(2).to(rank)
218 | # [0] for acc. gradnorm and [1] for num. of passed steps
219 | optimizer.zero_grad() # erase accumulated gradients from last epoch
220 |
221 | t = tqdm(iter(dataloader))
222 | for batch_idx, data_batch in enumerate(t):
223 | # no autocast wrapper here: the forward pass runs directly in the dtype
224 | # the submodules were loaded with (args["torch_dtype"]), unlike the
225 | # legacy script which wraps the forward in autocast
226 | loss = teacher_forcing_forward_pass(
227 | rank=rank,
228 | model=model,
229 | data_batch=data_batch,
230 | )
231 |
232 | # rescale loss for consistency with different gradient accumulation steps
233 | loss = loss / args["gradient_accumulation_steps"]
234 |
235 | # summarize the current batch
236 | t.set_postfix({
237 | "mode": "train",
238 | "epoch": f"{current_epoch}/{args['num_epochs']}",
239 | "batch_loss": loss.item() * args["gradient_accumulation_steps"],
240 | "device": f"rank:{rank}"
241 | })
242 | ddp_loss[0] += loss.item() * args["gradient_accumulation_steps"]
243 | ddp_loss[1] += 1 # count batches so the epoch loss can be averaged over them
244 |
245 | # accumulate gradients over micro-batches; no GradScaler is used here
246 | # since the model is trained directly in args["torch_dtype"]
247 | loss.backward()
248 | # gradients from each micro-batch add up until the optimizer step
249 |
250 | # update weights by loss if accumulation step is reached
251 | if (batch_idx + 1) % args["gradient_accumulation_steps"] == 0:
252 | gradnorm = torch.nn.utils.clip_grad_norm_(
253 | model.parameters(),
254 | max_norm=(
255 | float("inf")
256 | if args["gradient_clipping"] is None
257 | else args["gradient_clipping"]
258 | )
259 | )
260 | ddp_gradnorm[0] += gradnorm
261 | ddp_gradnorm[1] += 1
262 |
263 | optimizer.step()
264 | optimizer.zero_grad(set_to_none=True)
265 |
266 | # summarize the current epoch
267 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
268 | if rank == 0:
269 | print(
270 | f"[epoch={current_epoch}/{args['num_epochs']}, "
271 | f"train_loss={ddp_loss[0] / ddp_loss[1]}, "
272 | f"epoch_lr={optimizer.param_groups[0]['lr']}, "
273 | f"epoch_gradnorm={ddp_gradnorm[0] / ddp_gradnorm[1]}]"
274 | )
275 | # NaN detection (NaN != NaN, so this fires only when the epoch loss is NaN)
276 | if ddp_loss[0] != ddp_loss[0]:
277 | raise ValueError(
278 | "NaN detected in the training loss of the epoch, training interrupted."
279 | )
280 |
281 |
282 | def eval_epoch(
283 | rank: int,
284 | current_epoch: int,
285 | model: Union[DistributedDataParallel, FullyShardedDataParallel],
286 | dataloader: DataLoader,
287 | args: Dict[str, Any]
288 | ):
289 | """Iterate over all batches in evaluation with teacher forcing"""
290 | model.eval()
291 | ddp_loss = torch.zeros(2).to(rank)
292 | # [0] for acc. loss and [1] for num. of seen batches
293 |
294 | t = tqdm(iter(dataloader))
295 | for data_batch in t:
296 | with torch.no_grad():
297 | loss = teacher_forcing_forward_pass(
298 | rank=rank,
299 | model=model,
300 | data_batch=data_batch,
301 | )
302 |
303 | t.set_postfix({
304 | "mode": "eval",
305 | "epoch": f"{current_epoch}/{args['num_epochs']}",
306 | "batch_loss": loss.item(),
307 | "device": f"rank:{rank}"
308 | })
309 | ddp_loss[0] += loss.item()
310 | ddp_loss[1] += 1
311 |
312 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
313 | if rank == 0:
314 | print(
315 | f"[epoch={current_epoch}/{args['num_epochs']}, "
316 | f"eval_loss={ddp_loss[0] / ddp_loss[1]}]"
317 | )
318 |
319 |
320 | def train_on_device(
321 | rank: int,
322 | world_size: int,
323 | args: Dict[str, Any]
324 | ):
325 | """
326 | Training and evaluation process for each device, including epochs of training
327 | with teacher forcing.
328 | """
329 | setup(rank, world_size)
330 |
331 | # prepare datasets and dataloaders
332 | esm_tokenizer = AutoTokenizer.from_pretrained(args["esm_path"])
333 | llama_tokenizer = AutoTokenizer.from_pretrained(
334 | args["llama_path"],
335 | pad_token='<|reserved_special_token_0|>'
336 | )
337 |
338 | train_dataset = Prot2TextLightDataset(
339 | csv_path=os.path.join(args["root_csv_dir"], f"{args['train_split']}.csv"),
340 | )
341 | if args["debug_trim_train_split"]:
342 | train_dataset.data = train_dataset.data[:args["debug_trim_train_split"]]
343 | train_sampler = DistributedSampler(
344 | train_dataset,
345 | rank=rank,
346 | num_replicas=world_size,
347 | shuffle=True
348 | )
349 |
350 | train_collater = Prot2TextLightCollater(
351 | sequence_tokenizer=esm_tokenizer,
352 | description_tokenizer=llama_tokenizer,
353 | mode="train",
354 | include_text_fields=args["include_text_fields"],
355 | name_dropout=args["name_dropout"],
356 | taxonomy_dropout=args["taxonomy_dropout"],
357 | )
358 | train_loader = DataLoader(
359 | train_dataset,
360 | batch_size=args["batch_size_per_device"],
361 | sampler=train_sampler,
362 | collate_fn=train_collater,
363 | num_workers=4, # parallel CPU cores used for data loading
364 | pin_memory=True, # enable page-locked memory allocation for faster data transfer to GPUs
365 | shuffle=False, # avoid shuffling twice with DistributedSampler
366 | drop_last=True, # avoid incomplete batch at the end
367 | )
368 | print(f"Train dataset loaded on rank:{rank}")
369 |
370 | eval_dataset = Prot2TextLightDataset(
371 | csv_path=os.path.join(args["root_csv_dir"], f"{args['eval_split']}.csv"),
372 | )
373 | if args["debug_trim_eval_split"]:
374 | eval_dataset.data = eval_dataset.data[:args["debug_trim_eval_split"]]
375 | eval_sampler = DistributedSampler(
376 | eval_dataset,
377 | rank=rank,
378 | num_replicas=world_size,
379 | shuffle=False
380 | )
381 | eval_loader = DataLoader(
382 | eval_dataset,
383 | batch_size=args["batch_size_per_device"],
384 | sampler=eval_sampler,
385 | collate_fn=train_collater,
386 | num_workers=4,
387 | pin_memory=True,
388 | shuffle=False,
389 | drop_last=True,
390 | )
391 | print(f"Eval dataset loaded on rank:{rank}")
392 |
393 | torch.cuda.set_device(rank)
394 |
395 | model = load_model(args=args)
396 | model = model.to(rank)
397 |
398 | model = DistributedDataParallel(
399 | model,
400 | # find_unused_parameters=True # suppress error for unused parameters in wrapped model
401 | )
402 | print(f"DDP model loaded on rank:{rank}")
403 |
404 | # initialization of the optimizer after wrapping the model
405 | optimizer = Adam(model.parameters(), lr=args["learning_rate"])
406 | scheduler = StepLR(optimizer, step_size=1, gamma=args["scheduler_gamma"])
407 | if args["load_optimizer_scheduler_checkpoint_path"]:
408 | print(f"Loading {args['load_optimizer_scheduler_checkpoint_path']}")
409 | checkpoint_state_dicts = torch.load(
410 | args["load_optimizer_scheduler_checkpoint_path"],
411 | weights_only=True
412 | )
413 | optimizer_state_dict = checkpoint_state_dicts["optimizer_state_dict"]
414 | scheduler_state_dict = checkpoint_state_dicts["scheduler_state_dict"]
415 | optimizer.load_state_dict(optimizer_state_dict)
416 | scheduler.load_state_dict(scheduler_state_dict)
417 |
418 | # core loop of epochs
419 | for epoch_idx in range(1, args["num_epochs"] + 1):
420 | # shuffle data differently at each epoch across all processes
421 | train_sampler.set_epoch(epoch=epoch_idx)
422 |
423 | train_epoch(
424 | rank=rank,
425 | current_epoch=epoch_idx,
426 | model=model,
427 | dataloader=train_loader,
428 | optimizer=optimizer,
429 | args=args
430 | )
431 | scheduler.step()
432 | dist.barrier() # use a barrier to make sure training is done on all ranks
433 |
434 | eval_epoch(
435 | rank=rank,
436 | model=model,
437 | current_epoch=epoch_idx,
438 | dataloader=eval_loader,
439 | args=args
440 | )
441 | dist.barrier()
442 |
443 | if (
444 | epoch_idx == 1
445 | or epoch_idx == args["num_epochs"]
446 | or epoch_idx % args["save_every_epochs"] == 0
447 | ):
448 | if rank == 0:
449 | adapter_checkpoint_dir = os.path.join(
450 | args["save_checkpoint_dir"],
451 | f"adapter_checkpoint_{epoch_idx}"
452 | )
453 | model.module.save_pretrained(adapter_checkpoint_dir)
454 | print(f"Saving {adapter_checkpoint_dir}")
455 |
456 | optimizer_scheduler_checkpoint_path = os.path.join(
457 | args["save_checkpoint_dir"],
458 | f"optimizer_scheduler_checkpoint_{epoch_idx}.pt"
459 | )
460 | torch.save(
461 | {
462 | "optimizer_state_dict": optimizer.state_dict(),
463 | "scheduler_state_dict": scheduler.state_dict(),
464 | },
465 | optimizer_scheduler_checkpoint_path
466 | )
467 | print(f"Saving {optimizer_scheduler_checkpoint_path}")
468 |
469 | dist.barrier()
470 |
471 | cleanup()
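# To resume training from a saved epoch N (hypothetical paths), relaunch the
# script with --load_adapter_checkpoint_dir pointing at .../adapter_checkpoint_N
# and --load_optimizer_scheduler_checkpoint_path pointing at
# .../optimizer_scheduler_checkpoint_N.pt.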
472 |
473 |
474 | def train_distributed(
475 | args: Dict[str, Any] # replace **kwargs for compatibility with spawn
476 | ):
477 | """
478 | Core training process across multiple devices with epochs of training and
479 | inter-epoch evaluation.
480 | """
481 | torch.multiprocessing.spawn(
482 | train_on_device,
483 | args=(args["world_size"], args),
484 | nprocs=args["world_size"],
485 | join=True
486 | )
487 |
488 |
489 | if __name__ == '__main__':
490 | # suppress messages from AutoTokenizer parallelism and Graphein respectively
491 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
492 | os.environ["LOGURU_LEVEL"] = "INFO"
493 |
494 | parsed_args = argParser.parse_args()
495 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' # restrict GPU visibility
496 | parsed_args.world_size = torch.cuda.device_count() # use up all visible GPUs
497 |
498 | torch.manual_seed(parsed_args.random_seed)
499 | torch.cuda.manual_seed(parsed_args.random_seed)
500 |
501 | # initialize checkpoint directory
502 | timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
503 | parsed_args.save_checkpoint_dir = os.path.join(
504 | parsed_args.save_checkpoint_dir,
505 | f"checkpoints_{timestamp}"
506 | )
507 | if not os.path.exists(parsed_args.save_checkpoint_dir):
508 | os.mkdir(parsed_args.save_checkpoint_dir)
509 |
510 | print("####################")
511 | for key, value in parsed_args.__dict__.items():
512 | print(f"{key}: {value}")
513 | print("####################")
514 |
515 | train_distributed(parsed_args.__dict__)
516 |
--------------------------------------------------------------------------------
/models/modeling_esm2rgcn2llama_instruct.py:
--------------------------------------------------------------------------------
1 | """
2 | Modeling class for the assembled Esm2Rgcn2LlamaInstructForCausalLM model.
3 |
4 | Esm2Rgcn2LlamaInstructForCausalLM = EsmModel + RgcnAdapter + LlamaForCausalLM
5 |
6 | For training/evaluation under teacher-forcing scenario, the model `forward`
7 | function shall take following arguments:
8 | * input_ids: (bsz, prompt_len+description_len) # whole chat template
9 | * attention_mask: (bsz, prompt_len+description_len) # left & right padding
10 | * position_ids: (bsz, prompt_len+description_len) # optional
11 | * past_key_values: None
12 | * labels: (bsz, prompt_len+description_len) # -100 for padding & prompt
13 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds
14 | * protein_attention_mask: (bsz, prot_seq_len) # left padding
15 | * protein_position_ids: (bsz, prot_seq_len) # optional
16 | * protein_head_mask: (num_heads,) or (num_layers, num_heads) # optional
17 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional
18 | * graph_edge_index: (2, sum_num_edges)
19 | * graph_edge_type: (sum_num_edges,)
20 | * graph_batch: (sum_num_nodes,) # optional
21 | * use_cache: False
22 | * return_decoder_inputs: False
23 |
24 | For inference, the model `generate` function shall take following arguments:
25 | * inputs: (bsz, prompt_len) # prompt part of chat template
26 | * attention_mask: (bsz, prompt_len) # left padding
27 | * protein_input_ids: (bsz, prot_seq_len) # either ids or embeds
28 | * protein_attention_mask: (bsz, prot_seq_len) # left padding
29 | * protein_inputs_embeds: (bsz, prot_seq_len, hidden_size) # optional
30 | * graph_edge_index: (2, sum_num_edges)
31 | * graph_edge_type: (sum_num_edges,)
32 | * graph_batch: (sum_num_nodes,) # optional
33 | """
34 |
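# Illustrative call with the shapes listed above (a hedged sketch; the tensor
# names are hypothetical and the components are assumed to be already loaded):
#   model = Esm2Rgcn2LlamaInstructForCausalLM(
#       esm_encoder=esm_encoder, adapter=adapter, llama_decoder=llama_decoder
#   )
#   generated_ids = model.generate(
#       inputs=prompt_ids,                    # (bsz, prompt_len)
#       attention_mask=prompt_mask,           # (bsz, prompt_len), left padding
#       protein_input_ids=protein_ids,        # (bsz, prot_seq_len)
#       protein_attention_mask=protein_mask,  # (bsz, prot_seq_len)
#       graph_edge_index=edge_index,          # (2, sum_num_edges)
#       graph_edge_type=edge_type,            # (sum_num_edges,)
#       max_new_tokens=256,
#   )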
35 |
36 | from typing import Optional, Tuple, Union
37 |
38 | import torch
39 | import torch_geometric
40 | import torch_geometric.backend
41 | import torch_geometric.nn
42 | from torch_geometric.nn.conv.rgcn_conv import masked_edge_index
43 | from torch_geometric.typing import Adj, OptTensor, SparseTensor
44 | from torch_geometric.utils import index_sort, scatter
45 | from torch_geometric.utils.sparse import index2ptr
46 | from transformers import Cache, PreTrainedModel
47 | from transformers.generation.utils import GenerateOutput
48 | from transformers.modeling_outputs import CausalLMOutputWithPast
49 | from transformers.models.esm.modeling_esm import EsmModel
50 | from transformers.models.llama import LlamaForCausalLM
51 |
52 | from .configuration_esm2rgcn2llama_instruct import (
53 | RgcnAdapterConfig,
54 | Esm2Rgcn2LlamaInstructConfig
55 | )
56 |
57 |
58 | class RgcnConvLayer(torch_geometric.nn.RGCNConv):
59 | """Modified `torch_geometric.nn.RGCNConv` Layer for Flexible Precision."""
60 | def forward(
61 | self,
62 | x: Union[OptTensor, Tuple[OptTensor, torch.Tensor]],
63 | edge_index: Adj,
64 | edge_type: OptTensor = None
65 | ):
66 | """Forward pass of the RGCN layer."""
67 | # handle input node features
68 | x_l = x[0] if isinstance(x, tuple) else x # left node features
69 | if x_l is None: # default to indices if no features provided
70 | x_l = torch.arange(self.in_channels_l, device=self.weight.device)
71 | x_r = x[1] if isinstance(x, tuple) else x_l # right node features
72 |
73 | # define the size of the input graph
74 | size = (x_l.size(0), x_r.size(0))
75 | # extract edge types for SparseTensor input
76 | if isinstance(edge_index, SparseTensor):
77 | edge_type = edge_index.storage.value()
78 | assert edge_type is not None
79 |
80 | # initialize the output tensor, specify additional dtype for flexible precision
81 | out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device, dtype=x.dtype)
82 |
83 | # handle weight decomposition (basis or block diagonal)
84 | weight = self.weight
85 | if self.num_bases is not None: # Basis-decomposition
86 | weight = (self.comp @ weight.view(self.num_bases, -1)).view(
87 | self.num_relations, self.in_channels_l, self.out_channels
88 | )
89 |
90 | if self.num_blocks is not None: # Block-diagonal-decomposition
91 | if not torch.is_floating_point(x_r) and self.num_blocks is not None:
92 | raise ValueError(
93 | 'Block-diagonal decomposition not supported for non-continuous input features.'
94 | )
95 | for i in range(self.num_relations):
96 | tmp = masked_edge_index(edge_index, i == edge_type)
97 | h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
98 | h = h.view(-1, weight.size(1), weight.size(2))
99 | h = torch.einsum('abc,bcd->abd', h, weight[i])
100 | out = out + h.contiguous().view(-1, self.out_channels)
101 |
102 | else: # no decomposition (standard weight handling)
103 | use_segment_matmul = torch_geometric.backend.use_segment_matmul
104 |
105 | # heuristic for enabling `segment_matmul` optimization
106 | if use_segment_matmul is None:
107 | segment_count = scatter(
108 | torch.ones_like(edge_type),
109 | edge_type,
110 | dim_size=self.num_relations
111 | )
112 | self._use_segment_matmul_heuristic_output = (
113 | torch_geometric.backend.use_segment_matmul_heuristic(
114 | num_segments=self.num_relations,
115 | max_segment_size=int(segment_count.max()),
116 | in_channels=self.weight.size(1),
117 | out_channels=self.weight.size(2),
118 | )
119 | )
120 | assert self._use_segment_matmul_heuristic_output is not None
121 | use_segment_matmul = self._use_segment_matmul_heuristic_output
122 |
123 | # segment matmul optimization
124 | if (
125 | use_segment_matmul
126 | and torch_geometric.typing.WITH_SEGMM
127 | and not torch_geometric.is_compiling()
128 | and self.num_bases is None
129 | and x_l.is_floating_point()
130 | and isinstance(edge_index, torch.Tensor)
131 | ):
132 | if not self.is_sorted:
133 | if (edge_type[1:] < edge_type[:-1]).any():
134 | edge_type, perm = index_sort(edge_type, max_value=self.num_relations)
135 | edge_index = edge_index[:, perm]
136 | edge_type_ptr = index2ptr(edge_type, self.num_relations)
137 | out = self.propagate(edge_index, x=x_l, edge_type_ptr=edge_type_ptr, size=size)
138 |
139 | else: # loop through relations without optimization
140 | for i in range(self.num_relations):
141 | tmp = masked_edge_index(edge_index, i == edge_type)
142 |
143 | if not torch.is_floating_point(x_r):
144 | out = out + self.propagate(
145 | tmp,
146 | x=weight[i, x_l],
147 | edge_type_ptr=None,
148 | size=size
149 | )
150 | else:
151 | h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
152 | out = out + (h @ weight[i])
153 |
154 | # incorporate root embeddings
155 | root = self.root
156 | if root is not None:
157 | if not torch.is_floating_point(x_r):
158 | out = out + root[x_r]
159 | else:
160 | out = out + x_r @ root
161 |
162 | # add bias if applicable
163 | if self.bias is not None:
164 | out = out + self.bias
165 |
166 | return out
167 |
168 | def edge_update(self) -> torch.Tensor:
169 | """Placeholder for edge feature updates (if applicable)."""
170 | pass
171 |
172 |
173 | class RgcnAdapter(PreTrainedModel):
174 | """
175 | Relational Graph Convolutional Network adapter to match the hidden size of
176 | different modalities.
177 | """
178 | config_class = RgcnAdapterConfig # configuration class for this model
179 |
180 | def __init__(self, config: RgcnAdapterConfig):
181 | super().__init__(config)
182 | self.config = config
183 | self.activation = torch.nn.GELU()
184 | self.dropout = torch.nn.Dropout(p=config.dropout_rate)
185 | self.fc1 = torch.nn.Linear(config.input_dim, config.intermediate_dim)
186 | self.rgcn_layers = torch.nn.ModuleList([
187 | RgcnConvLayer(
188 | in_channels=config.intermediate_dim,
189 | out_channels=config.intermediate_dim,
190 | num_relations=config.n_relations,
191 | )
192 | for _ in range(config.n_layers)
193 | ])
194 | self.fc2 = torch.nn.Linear(config.intermediate_dim, config.output_dim)
195 | self.post_init() # initialize weights and apply final processing
196 |
197 | def forward(
198 | self,
199 | hidden_states: torch.FloatTensor, # (bsz, seq_len, input_dim)
200 | attention_mask: torch.LongTensor, # (bsz, seq_len)
201 | edge_index: torch.LongTensor, # (2, sum_num_edges)
202 | edge_type: torch.LongTensor, # (sum_num_edges,)
203 | batch: Optional[torch.LongTensor] = None # (sum_num_nodes,)
204 | ) -> torch.FloatTensor:
205 | hidden_states = self.activation(self.fc1(hidden_states))
206 | hidden_states = self.dropout(hidden_states) # (bsz, seq_len, interm_dim)
207 |
208 | # create mask for node embeddings to be updated by RGCN layers
209 | nodes_mask = attention_mask.clone().bool() # (bsz, seq_len)
210 | nodes_mask[:, 0] = False # exclude bos token
211 | batch_size = hidden_states.size(0)
212 | batch_indices = torch.arange(0, batch_size) # (bsz,)
213 | eos_indices = attention_mask.sum(dim=1) - 1 # (bsz,)
214 | nodes_mask[batch_indices, eos_indices] = False # exclude eos token
215 |
216 | # compute RGCN layers on node embeddings only
217 | nodes_hidden_states = hidden_states[nodes_mask] # (sum_num_nodes, interm_dim)
218 | for layer in self.rgcn_layers:
219 | nodes_hidden_states = layer(nodes_hidden_states, edge_index, edge_type)
220 | nodes_hidden_states = self.activation(nodes_hidden_states)
221 | nodes_hidden_states = self.dropout(nodes_hidden_states)
222 |
223 | # update node hidden states with RGCN outputs
224 | hidden_states[nodes_mask] = nodes_hidden_states
225 |
226 | hidden_states = self.activation(self.fc2(hidden_states))
227 | hidden_states = self.dropout(hidden_states)
228 | hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1)
229 | return hidden_states # (bsz, seq_len, output_dim)
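# Worked example of the masking above (illustrative lengths): for two proteins
# of 100 and 150 residues, the sequence tokenizer adds bos and eos, so the
# attention masks sum to 102 and 152; excluding bos and eos leaves
# 100 + 150 = 250 node positions, which must match the number of graph nodes
# referenced by edge_index (and graph batch) for the RGCN layers to apply.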
230 |
231 |
232 | class Esm2Rgcn2LlamaInstructForCausalLM(PreTrainedModel):
233 | """
234 | Esm2Rgcn2LlamaInstructForCausalLM model for protein function prediction.
235 | Similar to `EncoderDecoderModel` but with more complicated architecture.
236 | Initialize with either a configuration OR all three components.
237 | `kwargs` can override standalone attributes in `Esm2Rgcn2LlamaInstructConfig`.
238 | """
239 | config_class = Esm2Rgcn2LlamaInstructConfig # configuration class for this model
240 |
241 | def __init__(
242 | self,
243 | config: Optional[Esm2Rgcn2LlamaInstructConfig] = None,
244 | esm_encoder: Optional[EsmModel] = None,
245 | adapter: Optional[RgcnAdapter] = None,
246 | llama_decoder: Optional[LlamaForCausalLM] = None,
247 | **kwargs
248 | ):
249 | if config is not None: # components ignored if config is provided
250 | super().__init__(config)
251 | self.esm_encoder = EsmModel(
252 | config.esm_config,
253 | add_pooling_layer=False
254 | )
255 | self.adapter = RgcnAdapter(config.adapter_config)
256 | self.llama_decoder = LlamaForCausalLM(config.llama_config)
257 | else:
258 | config = Esm2Rgcn2LlamaInstructConfig(
259 | esm_config=esm_encoder.config,
260 | adapter_config=adapter.config,
261 | llama_config=llama_decoder.config,
262 | **kwargs # override standalone attributes
263 | )
264 | super().__init__(config)
265 | self.esm_encoder = esm_encoder
266 | self.adapter = adapter
267 | self.llama_decoder = llama_decoder
268 |
269 | def prepare_decoder_inputs(
270 | self,
271 | input_ids: torch.LongTensor,
272 | encoder_hidden_states: torch.FloatTensor,
273 | attention_mask: Optional[torch.LongTensor] = None,
274 | encoder_attention_mask: Optional[torch.LongTensor] = None,
275 | ):
276 | """
277 | Embed and replace placeholder in `input_ids` by encoder hidden states.
278 | `input_ids` must be passed to locate placeholder for replacement.
279 | """
280 | # preparation
281 | batch_size, seq_len = input_ids.size()
282 | _, encoder_seq_len, _ = encoder_hidden_states.size()
283 | if attention_mask is None:
284 | attention_mask = torch.ones(
285 | (batch_size, seq_len),
286 | dtype=torch.long,
287 | device=input_ids.device
288 | )
289 | if encoder_attention_mask is None:
290 | encoder_attention_mask = torch.ones(
291 | (batch_size, encoder_seq_len),
292 | dtype=torch.long,
293 | device=encoder_hidden_states.device
294 | )
295 | inputs_embeds = self.llama_decoder.get_input_embeddings()(input_ids)
296 | # replacement
297 | placeholder_mask = input_ids == self.config.placeholder_id
298 | encoder_mask = encoder_attention_mask.bool()
299 | inputs_embeds[placeholder_mask] = encoder_hidden_states[encoder_mask]
300 | return inputs_embeds, attention_mask
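# Invariant of the replacement above (illustrative count): if a prompt holds
# 7 placeholder tokens, the corresponding protein must contribute exactly
# 7 unmasked encoder positions, so the boolean-mask assignment overwrites the
# placeholder embeddings one-to-one.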
301 |
302 | def forward(
303 | self,
304 | # chat template text inputs
305 | input_ids: Optional[torch.LongTensor] = None,
306 | attention_mask: Optional[torch.LongTensor] = None,
307 | position_ids: Optional[torch.LongTensor] = None,
308 | past_key_values: Optional[Cache] = None,
309 | labels: Optional[torch.LongTensor] = None,
310 | # protein amino-acid sequence inputs
311 | protein_input_ids: Optional[torch.LongTensor] = None,
312 | protein_attention_mask: Optional[torch.LongTensor] = None,
313 | protein_position_ids: Optional[torch.LongTensor] = None,
314 | protein_head_mask: Optional[torch.LongTensor] = None,
315 | protein_inputs_embeds: Optional[torch.FloatTensor] = None,
316 | # graph-related inputs
317 | graph_edge_index: Optional[torch.LongTensor] = None,
318 | graph_edge_type: Optional[torch.LongTensor] = None,
319 | graph_batch: Optional[torch.LongTensor] = None,
320 | # behavior control arguments
321 | use_cache: Optional[bool] = None,
322 | output_attentions: Optional[bool] = None,
323 | output_hidden_states: Optional[bool] = None,
324 | return_dict: Optional[bool] = None,
325 | return_encoder_outputs: bool = False,
326 | return_adapter_outputs: bool = False,
327 | return_decoder_inputs: bool = False,
328 | cache_position: Optional[torch.LongTensor] = None
329 | ) -> Union[Tuple, CausalLMOutputWithPast]:
330 | """
331 | Compute encoder and adapter outputs, then pass to decoder.
332 | `input_ids` is expected to be [prompt + description] in teacher-forcing
333 | scenario and [prompt] only in first iteration of inference (with
334 | return_decoder_inputs=True).
335 | Attention: possible concatenation of the mask and labels should be
336 | handled before calling this method.
337 | `inputs_embeds` not allowed due to placeholder replacement scheme.
338 | """
339 | # esm_encoder forward
340 | encoder_output = self.esm_encoder(
341 | input_ids=protein_input_ids,
342 | attention_mask=protein_attention_mask,
343 | position_ids=protein_position_ids,
344 | head_mask=protein_head_mask,
345 | inputs_embeds=protein_inputs_embeds,
346 | use_cache=False, # because config.esm_config.is_decoder=False
347 | output_attentions=output_attentions,
348 | output_hidden_states=output_hidden_states,
349 | return_dict=return_dict
350 | )
351 | encoder_hidden_states = encoder_output[0]
352 | encoder_attention_mask = protein_attention_mask
353 | if return_encoder_outputs:
354 | return encoder_output
355 | # adapter forward
356 | adapter_output = self.adapter(
357 | hidden_states=encoder_hidden_states,
358 | attention_mask=encoder_attention_mask,
359 | edge_index=graph_edge_index,
360 | edge_type=graph_edge_type,
361 | batch=graph_batch
362 | )
363 | if return_adapter_outputs:
364 | return adapter_output, encoder_attention_mask
365 | # decoder input preparation
366 | inputs_embeds, attention_mask = self.prepare_decoder_inputs(
367 | input_ids=input_ids,
368 | encoder_hidden_states=adapter_output,
369 | attention_mask=attention_mask,
370 | encoder_attention_mask=encoder_attention_mask,
371 | )
372 | if return_decoder_inputs:
373 | return inputs_embeds, attention_mask
374 | # llama_decoder forward
375 | return self.llama_decoder.forward(
376 | input_ids=None,
377 | attention_mask=attention_mask,
378 | position_ids=position_ids,
379 | past_key_values=past_key_values,
380 | inputs_embeds=inputs_embeds,
381 | labels=labels,
382 | use_cache=use_cache,
383 | output_attentions=output_attentions,
384 | return_dict=return_dict,
385 | cache_position=cache_position
386 | )
387 |
388 | def generate(
389 | self,
390 | inputs: torch.LongTensor, # alias of `input_ids`
391 | attention_mask: Optional[torch.LongTensor] = None,
392 | protein_input_ids: Optional[torch.LongTensor] = None,
393 | protein_attention_mask: Optional[torch.LongTensor] = None,
394 | protein_inputs_embeds: Optional[torch.FloatTensor] = None,
395 | graph_edge_index: Optional[torch.LongTensor] = None,
396 | graph_edge_type: Optional[torch.LongTensor] = None,
397 | graph_batch: Optional[torch.LongTensor] = None,
398 | **kwargs
399 | ) -> Union[GenerateOutput, torch.LongTensor]:
400 | """
401 | Do inference based on given input prompt.
402 | `inputs` is expected to be [prompt] only.
403 | The output will not contain the input prompt because the input is passed
404 | as embeddings. Generation behavior can be controlled by `kwargs`; read
405 | `GenerationMixin.generate` for more info.
406 | """
407 | # get decoder inputs
408 | prompt_inputs_embeds, prompt_attention_mask = self(
409 | input_ids=inputs,
410 | attention_mask=attention_mask,
411 | protein_input_ids=protein_input_ids,
412 | protein_attention_mask=protein_attention_mask,
413 | protein_inputs_embeds=protein_inputs_embeds,
414 | graph_edge_index=graph_edge_index,
415 | graph_edge_type=graph_edge_type,
416 | graph_batch=graph_batch,
417 | use_cache=False,
418 | output_attentions=False,
419 | output_hidden_states=False,
420 | return_dict=False,
421 | return_decoder_inputs=True
422 | )
423 | # do generate on llama_decoder
424 | return self.llama_decoder.generate(
425 | inputs_embeds=prompt_inputs_embeds,
426 | attention_mask=prompt_attention_mask,
427 | **kwargs
428 | )
429 |
430 | def gradient_checkpointing_enable(self):
431 | """
432 | Enable gradient checkpointing for all submodules that support it.
433 | Attention: the model needs to be in train mode before calling this method.
434 | """
435 | if hasattr(self.esm_encoder, "gradient_checkpointing_enable"):
436 | self.esm_encoder.gradient_checkpointing_enable()
437 | if hasattr(self.llama_decoder, "gradient_checkpointing_enable"):
438 | self.llama_decoder.gradient_checkpointing_enable()
439 | # the simple adapter does not need gradient checkpointing
440 |
441 | def gradient_checkpointing_disable(self):
442 | if hasattr(self.esm_encoder, "gradient_checkpointing_disable"):
443 | self.esm_encoder.gradient_checkpointing_disable()
444 | if hasattr(self.llama_decoder, "gradient_checkpointing_disable"):
445 | self.llama_decoder.gradient_checkpointing_disable()
446 |
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset class for protein function prediction instruction training.
3 |
4 | Every sample from the dataset has the following attributes:
5 | - graph-related features if not ignored:
6 | - x: (num_nodes, num_node_features)
7 | - edge_index: (2, num_edges)
8 | - edge_type: (num_edges, )
9 | - input ids:
10 | - sequence_input_ids: (1, sequence_length+2) # bos and eos tokens
11 | - prompt_input_ids: (1, prompt_length) # with generation prompt head
12 | - description_input_ids: (1, description_length+1) # eos token only
13 | - other attributes:
14 | - name: str
15 |
16 | * No attention mask will be produced at this stage; dynamic padding will be
17 | applied in the collate function of the dataloader.
18 |
19 | A chat template is applied during preprocessing to assemble the full name and
20 | taxon of every protein and to reserve space for sequence embeddings with
21 | placeholder tokens. The assembled prompt has the following structure:
22 | (
23 | <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n
24 | You are a scientific assistant specialized in ...
25 | <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n
26 | Name: ; Taxon: ; Sequence embeddings: ...
27 | <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
28 | )
29 | The description of the protein has the following structure:
30 | (
31 | Involved in ... Required for ... <|eot_id|>
32 | )
33 | The template is designed for Meta-Llama-3.1-8B-Instruct-hf and can be adapted
34 | to other models by calling the `process_text` method with the new tokenizer.
35 |
36 | Example of usage:
37 | >>> from transformers import AutoTokenizer
38 | >>> from dataset import Prot2TextInstructDataset
39 | >>> esm_tokenizer = AutoTokenizer.from_pretrained("/data/esm2_t33_650M_UR50D")
40 | >>> llama_tokenizer = AutoTokenizer.from_pretrained(
41 | "/data/Meta-Llama-3.1-8B-Instruct-hf",
42 | pad_token='<|reserved_special_token_0|>'
43 | )
44 | >>> train_dataset = Prot2TextInstructDataset(
45 | root_dir="/data/Prot2Text-Llama3-Data/train",
46 | csv_path="./data/train.csv",
47 | sequence_tokenizer=esm_tokenizer,
48 | description_tokenizer=llama_tokenizer,
49 | skip_download=False, # download PDB files
50 | skip_reload=False, # construct graph and tokenize text
51 | )
52 | """
53 |
54 | import multiprocessing as mp
55 | import os
56 | import sys
57 | from typing import Any, Dict, List, Optional, Union
58 |
59 | from graphein.protein.config import ProteinGraphConfig
60 | import pandas as pd
61 | import torch
62 | import torch.utils.data
63 | import torch_geometric
64 | import torch_geometric.data
65 | from transformers import PreTrainedTokenizer
66 | from tqdm import tqdm
67 | import wget
68 |
69 | from .utils_pdb2nx import construct_nx_graph
70 | from .nx2pyg import convert_nx_to_pyg
71 | from .utils_dataset import default_graph_process_config
72 |
73 |
74 | class Prot2TextInstructDataset(torch_geometric.data.Dataset):
75 | """
76 | Dataset class for proteins.
77 |
78 | (1) Download PDB files from AlphaFoldDB if not skipped, then
79 | (2) preprocess graph and textual features if not skipped to prepare for the
80 | formation of :class:`Prot2TextInstructDataLoader`.
81 |
82 | * Use a separate :class:`Prot2TextInstructDataset` for each split of the
83 | dataset instead of a single one.
84 | * root_dir should be the root directory of each split.
85 |
86 | Args:
87 | root_dir:
88 | Root directory of the dataset where root_dir/raw and
89 | root_dir/processed will be prepared. If the dataset corresponds
90 | to a particular split of the whole dataset, the directory
91 | of the split should be passed instead.
92 | csv_path:
93 | Path of the CSV file containing the uniprot ids, amino-acid
94 | sequences and descriptions.
95 | sequence_tokenizer:
96 | Tokenizer for protein amino-acid sequences. Tokenization
97 | will be done in preprocessing without a pad token, and padding
98 | will be added in the dataloader when forming the data batch. The
99 | tokenizer should have :attr:`pad_token` and :attr:`pad_token_id`
100 | attributes which will be used in collate function of the
101 | dataloader for dynamic padding.
102 | description_tokenizer:
103 | Tokenizer for protein functionality descriptions. The
104 | tokenizer should add bos token but not add eos token by
105 | itself, for consistency with prompt tokenization. However,
106 | the tokenizer should have :attr:`eos_token` and
107 | :attr:`eos_token_id` attributes which will be used in
108 | preprocessing to add eos token at the end of the description
109 | and the label. The tokenizer should also have :attr:`pad_token`
110 | and :attr:`pad_token_id` attributes which will be used in
111 | the collate function of the dataloader for dynamic padding.
112 | alphafold_base_url:
113 | Base download URL of AlphaFoldDB to download PDB files from; it
114 | should end with a trailing slash. The full download link will be
115 | :url:`{alphafold_base_url}AF-{uniprot_id}-F1-model_v{alphafold_version}.pdb`.
116 | alphafold_version:
117 | Version of the AlphaFoldDB to download PDB files from.
118 | graph_process_config:
119 | Configuration specifying features of nodes, edges and the
120 | whole graph to be used for graph construction from PDB files.
121 | If not passed, the default configuration will be used.
122 | skip_download:
123 | Force preprocessing to skip the downloading procedure and to use
124 | existing PDB files only. Otherwise, the downloading procedure
125 | will be skipped only if every uniprot id from the CSV file has a
126 | corresponding PDB file under `root_dir/raw`.
127 | skip_reload:
128 | Force preprocessing to skip formation of graphs from PDB files
129 | and to use existing PyG tensor files only. Otherwise, the processing
130 | procedure will be skipped only if every PDB file under
131 | `root_dir/raw` has a corresponding PyG tensor file under
132 | `root_dir/processed`.
133 | num_processes:
134 | Number of parallel processes in formation of the graph. If
135 | not specified, all logical CPU threads will be used by default.
136 | ignore_graph_features:
137 | Ignore graph-related features while loading data. This does
138 | not affect the preprocessing behavior: graph-related features
139 | will still be processed and saved.
140 | max_sequence_length:
141 | Maximum length of the protein amino-acid sequence. Sequences
142 | longer than this value will be trimmed if it is set. This
143 | loses part of the information but can avoid Out-Of-Memory
144 | errors in training. It does not affect the preprocessing
145 | behavior: whole sequences will still be processed and then
146 | saved.
147 | max_description_length:
148 | Maximum length of the protein description. Descriptions
149 | longer than this value will be trimmed if it is set. This
150 | loses part of the information but can avoid Out-Of-Memory
151 | errors in training. It does not affect the preprocessing
152 | behavior: whole descriptions will still be processed and
153 | then saved.
154 | system_message:
155 | The system message to be used in the chat template.
156 | placeholder_token:
157 | The placeholder token used in the chat template; its positions will
158 | be replaced by the encoder outputs serving as sequence embeddings.
159 | kwargs:
160 | Additional arguments controlling the behavior of the PyG
161 | dataset inherited from the base dataset class. Read the
162 | documentation from :class:`torch_geometric.data.Dataset`
163 | for more information.
164 | """
165 | def __init__(
166 | self,
167 | root_dir: Union[str, os.PathLike],
168 | csv_path: Union[str, os.PathLike],
169 | sequence_tokenizer: PreTrainedTokenizer,
170 | description_tokenizer: PreTrainedTokenizer,
171 | alphafold_base_url: str = "https://alphafold.ebi.ac.uk/files/",
172 | alphafold_version: int = 4,
173 | graph_process_config: ProteinGraphConfig = default_graph_process_config,
174 | skip_download: bool = False,
175 | skip_reload: bool = False,
176 | num_processes: Optional[int] = None,
177 | ignore_graph_features: bool = False,
178 | max_sequence_length: Optional[int] = 1021,
179 | max_description_length: Optional[int] = 512,
180 | system_message: str = (
181 | "You are a scientific assistant specialized in protein function "
182 | "predictions. Given the sequence embeddings and other information "
183 | "of a protein, describe its function clearly and concisely in "
184 | "professional language. "
185 | ),
186 | placeholder_token: str = '<|reserved_special_token_1|>',
187 | **kwargs,
188 | ):
189 | self.root_dir = root_dir
190 | self.uniprot_df = pd.read_csv(csv_path)
191 | self.sequence_tokenizer = sequence_tokenizer
192 | self.description_tokenizer = description_tokenizer
193 | self.alphafold_base_url = alphafold_base_url
194 | self.alphafold_version = alphafold_version
195 | self.graph_process_config = graph_process_config
196 | self.skip_download = skip_download
197 | self.skip_reload = skip_reload
198 | self.num_processes = num_processes
199 | self.ignore_graph_features = ignore_graph_features
200 | self.max_sequence_length = max_sequence_length
201 | self.max_description_length = max_description_length
202 | self.system_message = system_message
203 | self.placeholder_token = placeholder_token
204 |
205 | self.usable_file_names: List[Union[str, os.PathLike]] = []
206 | super().__init__(root=root_dir, **kwargs) # first download then process
207 | self.update_usable_file_names()
208 |
209 | def download(self, overwrite_existing: bool = False):
210 | """
211 | Downloads PDB files from AlphaFoldDB to :attr:`self.raw_dir` folder. If
212 | such file already exists, the download will be skipped. Unsuccessful
213 | attempt of downloading will not crash the script but exceptions will be
214 | noted.
215 | """
216 | if self.skip_download:
217 | return
218 | assert self.alphafold_base_url is not None, (
219 | "Downloading requested but base URL of AlphaFoldDB is not set. "
220 | )
221 | assert self.alphafold_version is not None, (
222 | "Downloading requested but version of AlphaFoldDB is not set. "
223 | )
224 | for raw_file_name in tqdm(self.raw_file_names):
225 | # raw_file_name looks like AF-{uniprot_id}-F1-model_v{version}.pdb
226 | raw_file_path = os.path.join(self.raw_dir, raw_file_name)
227 | full_url = self.alphafold_base_url + raw_file_name
228 | if overwrite_existing or not os.path.exists(raw_file_path):
229 | try:
230 | wget.download(full_url, out=raw_file_path)
231 | except Exception as exception:
232 | if self.log:
233 | print(
234 | f"Download of {raw_file_name} failed due to exception: "
235 | f"{exception.__class__.__name__}: {exception}",
236 | file=sys.stderr
237 | )
238 |
239 | def process(self, overwrite_existing: bool = False):
240 | """
241 | Preprocesses and converts PDB files to PyTorch tensor files. Graph
242 | construction runs in parallel across multiple processes. An unsuccessful
243 | graph-processing attempt will not crash the script, but the exception
244 | will be logged. Textual features are then encoded and added to the
245 | tensor files at the end, without parallelism.
246 | """
247 | if self.skip_reload:
248 | return
249 | assert self.sequence_tokenizer is not None, (
250 | "Processing requested but sequence tokenizer is not set."
251 | )
252 | assert self.description_tokenizer is not None, (
253 | "Processing requested but description tokenizer is not set."
254 | )
255 |
256 | # graph construction: convert PDB files to tensor files
257 | with mp.Pool(processes=self.num_processes) as pool:
258 | for raw_file_name in self.raw_file_names:
259 | raw_file_path = os.path.join(self.raw_dir, raw_file_name)
260 | processed_file_path = os.path.join(
261 | self.processed_dir,
262 | os.path.splitext(raw_file_name)[0] + ".pt"
263 | )
264 | if overwrite_existing or not os.path.exists(processed_file_path):
265 | pool.apply_async(
266 | self.process_graph,
267 | args=(raw_file_path, processed_file_path)
268 | )
269 | pool.close()
270 | pool.join()
271 | self.update_usable_file_names()
272 |
273 | # textual feature tokenization: add tokenized features to tensor files
274 | self.process_text()
275 | self.update_usable_file_names()
276 |
277 | def process_graph(
278 | self,
279 | raw_file_path: Union[str, os.PathLike],
280 | processed_file_path: Union[str, os.PathLike]
281 | ):
282 | """
283 | Preprocesses and converts a single PDB file to a PyTorch tensor file. Unit
284 | function for multiprocessing.
285 | """
286 | uniprot_id = os.path.split(raw_file_path)[1].split("-")[1]
287 | try:
288 | nx_graph = construct_nx_graph(
289 | config=self.graph_process_config,
290 | pdb_path=raw_file_path
291 | )
292 | data = convert_nx_to_pyg(nx_graph)
293 | torch.save(data, processed_file_path)
294 | except Exception as exception:
295 | if self.log:
296 | print(
297 | f"Graph processing of {uniprot_id} failed due to exception: "
298 | f"{exception.__class__.__name__}: {exception}",
299 | file=sys.stderr
300 | )
301 |
302 | def process_text(
303 | self,
304 | new_sequence_tokenizer: Optional[PreTrainedTokenizer] = None,
305 | new_description_tokenizer: Optional[PreTrainedTokenizer] = None
306 | ):
307 | """
308 | Apply tokenizers to add tokenized protein amino-acid sequences and
309 | protein functionality descriptions to processed PyTorch tensor files.
310 | This process will not download or rebuild graphs; it only adds or
311 | modifies the textual features in the processed Data objects. It can be
312 | used to switch tokenizers on an existing processed dataset.
313 | """
314 | if new_sequence_tokenizer is not None:
315 | self.sequence_tokenizer = new_sequence_tokenizer
316 | if new_description_tokenizer is not None:
317 | self.description_tokenizer = new_description_tokenizer
318 | for processed_file_name in tqdm(self.usable_file_names):
319 | processed_file_path = os.path.join(self.processed_dir, processed_file_name)
320 | try:
321 | uniprot_id = os.path.split(processed_file_name)[1].split("-")[1]
322 | data = torch.load(processed_file_path)
323 | data.update(self._compose_and_tokenize_chat(uniprot_id))
324 | torch.save(data, processed_file_path)
325 | except Exception as exception:
326 | if self.log:
327 | print(
328 | f"Text processing of {processed_file_name} failed due to exception: "
329 | f"{exception.__class__.__name__}: {exception}",
330 | file=sys.stderr
331 | )
332 |
333 | def _compose_and_tokenize_chat(self, uniprot_id: str) -> Dict[str, torch.Tensor]:
334 | """
335 | (1) Matches the corresponding row in the CSV DataFrame for the given
336 | uniprot id, (2) trims the sequence and description to avoid OOM errors,
337 | (3) applies the chat template to form the prompt with placeholders, and
338 | (4) tokenizes the amino-acid sequence, prompt and description.
339 |
340 | The eos token will be added at the end of the description and of the
341 | label before tokenization.
342 | """
343 | # (1) match row and extract features
344 | filtered_df = self.uniprot_df.loc[
345 | self.uniprot_df['AlphaFoldDB'] == uniprot_id
346 | ]
347 | sequence = filtered_df["sequence"].values[0]
348 | description = filtered_df["function"].values[0]
349 | fullname = filtered_df["Full Name"].values[0]
350 | taxon = filtered_df["taxon"].values[0]
351 | fullname = "unknown" if pd.isna(fullname) else fullname
352 | taxon = "unknown" if pd.isna(taxon) else taxon
353 |
354 | # (2) trim sequence and description
355 | if self.max_description_length is not None:
356 | description_ids = self.description_tokenizer(
357 | [description],
358 | add_special_tokens=False,
359 | return_tensors="pt"
360 | )["input_ids"]
361 | if description_ids.size(-1) > self.max_description_length:
362 | description_ids = description_ids[:, :self.max_description_length]
363 | description = self.description_tokenizer.decode(description_ids[0])
364 | if self.max_sequence_length is not None:
365 | if len(sequence) > self.max_sequence_length:
366 | sequence = sequence[:self.max_sequence_length]
367 |
368 | # (3) apply chat template then tokenize the prompt
369 | user_message = (
370 | "Protein name: " + fullname
371 | + " ; Taxon: " + taxon
372 | + " ; Sequence embeddings: "
373 | + self.placeholder_token * (len(sequence) + 2) # bos and eos tokens
374 | )
375 | prompt_conversation = [
376 | {"role": "system", "content": self.system_message},
377 | {"role": "user", "content": user_message}
378 | ]
379 | prompt_ids = self.description_tokenizer.apply_chat_template(
380 | prompt_conversation,
381 | add_generation_prompt=True,
382 | tokenize=True,
383 | padding=False,
384 | return_tensors="pt"
385 | )
386 |
387 | # (4) tokenize sequence and description
388 | sequence_ids = self.sequence_tokenizer(
389 | [sequence],
390 | add_special_tokens=True,
391 | return_attention_mask=False,
392 | return_tensors="pt"
393 | )["input_ids"]
394 | description_ids = self.description_tokenizer(
395 | [description + self.description_tokenizer.eos_token],
396 | add_special_tokens=False,
397 | return_attention_mask=False,
398 | return_tensors="pt"
399 | )["input_ids"]
400 |
401 | # tensors are kept with shape (1, seq_length) so that the dataloader
402 | # can form batches
403 | return {
404 | "sequence_input_ids": sequence_ids,
405 | "prompt_input_ids": prompt_ids,
406 | "description_input_ids": description_ids,
407 | }
408 |
409 | @property
410 | def raw_file_names(self) -> List[str]:
411 | """
412 | The names of the files in the `self.raw_dir` folder that must be present
413 | in order to skip downloading. Required by the parent class.
414 | """
415 | uniprot_ids = set(self.uniprot_df.AlphaFoldDB)
416 | return [
417 | f"AF-{uniprot_id}-F1-model_v{self.alphafold_version}.pdb"
418 | for uniprot_id in uniprot_ids
419 | ]
420 |
421 | @property
422 | def processed_file_names(self) -> List[str]:
423 | """
424 | The names of the files in the `self.processed_dir` folder that must be
425 | present in order to skip processing. Required by the parent class.
426 | """
427 | return [
428 | os.path.splitext(raw_file_name)[0]+".pt"
429 | for raw_file_name in os.listdir(self.raw_dir)
430 | ]
431 |
432 | def update_usable_file_names(self):
433 | """
434 | Updates the stored names of the files in the `self.processed_dir` folder
435 | that are ready to be used by the dataset. The usable file names are
436 | updated during initialization and preprocessing; updating them later is
437 | not recommended, to keep the order of file names consistent.
438 | """
439 | existing_file_names = os.listdir(self.processed_dir)
440 | for special_file_name in ['pre_transform.pt', 'pre_filter.pt']:
441 | if special_file_name in existing_file_names:
442 | existing_file_names.remove(special_file_name)
443 | self.usable_file_names = sorted(existing_file_names)
444 |
445 | def len(self) -> int:
446 | """Required in the parent class."""
447 | return len(self.usable_file_names)
448 |
449 | def __len__(self) -> int:
450 | return self.len()
451 |
452 | def get(self, idx: int, debug_mode: bool = False) -> Any:
453 | data = torch.load(os.path.join(self.processed_dir, self.usable_file_names[idx]))
454 | if debug_mode: # return all items stored in the file
455 | return data
456 | if self.ignore_graph_features:
457 | return torch_geometric.data.Data(
458 | sequence_input_ids=data['sequence_input_ids'],
459 | prompt_input_ids=data["prompt_input_ids"],
460 | description_input_ids=data['description_input_ids'],
461 | name=data["name"],
462 | )
463 | else:
464 | return torch_geometric.data.Data(
465 | x=data['x'],
466 | edge_index=data['edge_index'],
467 | edge_type=data['edge_type'],
468 | sequence_input_ids=data['sequence_input_ids'],
469 | prompt_input_ids=data["prompt_input_ids"],
470 | description_input_ids=data['description_input_ids'],
471 | name=data["name"],
472 | )
473 |
--------------------------------------------------------------------------------
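A minimal sketch of inspecting a processed sample and re-tokenizing the textual
features, assuming `train_dataset` was built as in the module docstring above;
the alternative tokenizer path below is a placeholder.

    from transformers import AutoTokenizer

    # each sample is a PyG Data object holding (1, length)-shaped id tensors;
    # no attention mask is stored, padding happens in the dataloader collate fn
    sample = train_dataset.get(0)
    print(sample.sequence_input_ids.shape)     # (1, sequence_length + 2)
    print(sample.prompt_input_ids.shape)       # (1, prompt_length)
    print(sample.description_input_ids.shape)  # (1, description_length + 1)

    # switch chat tokenizers without re-downloading PDB files or rebuilding graphs
    new_tokenizer = AutoTokenizer.from_pretrained(
        "/data/another-instruct-model",  # placeholder path
        pad_token="<|reserved_special_token_0|>",
    )
    train_dataset.process_text(new_description_tokenizer=new_tokenizer)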