├── prot2text_model ├── __init__.py ├── tokenization_prot2text.py ├── Encoder.py ├── Model.py └── utils.py ├── prot2text_dataset ├── __init__.py ├── utils_convert.py ├── utils_dataset.py ├── Inememorydataset.py ├── torch_geometric_loader.py ├── pdb2graph.py ├── conversion.py └── graphs.py ├── __init__.py ├── Prot2Text.drawio.png ├── Requirements.txt ├── generate_description.py ├── train_prot2text.slurm ├── prepare_dataset.py ├── evaluate_prot2text.py ├── README.md ├── train.py └── LICENSE /prot2text_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prot2text_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from . import prot2text_dataset, prot2text_model -------------------------------------------------------------------------------- /Prot2Text.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hadi-abdine/Prot2Text/HEAD/Prot2Text.drawio.png -------------------------------------------------------------------------------- /prot2text_model/tokenization_prot2text.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Tokenizer 2 | 3 | class Prot2TextTokenizer(GPT2Tokenizer): # GPT2Tokenizer subclass; its constructor only forwards its arguments (plus **kwargs) unchanged to GPT2Tokenizer.__init__. Presumably it exists so saved checkpoints can name a Prot2Text-specific tokenizer class -- confirm 4 | def __init__( 5 | self, 6 | vocab_file, 7 | merges_file, 8 | errors="replace", 9 | unk_token="<|endoftext|>", 10 | bos_token="<|endoftext|>", 11 | eos_token="<|endoftext|>", 12 | pad_token=None, 13 | add_prefix_space=False, 14 | add_bos_token=False, 15 | **kwargs, 16 | ): 17 | super().__init__( # pass every argument through to GPT2Tokenizer by keyword 18 | vocab_file=vocab_file, 19 | merges_file=merges_file, 20 | errors=errors, 21 |
unk_token=unk_token, 22 | bos_token=bos_token, 23 | eos_token=eos_token, 24 | pad_token=pad_token, 25 | add_prefix_space=add_prefix_space, 26 | add_bos_token=add_bos_token, 27 | **kwargs, 28 | ) 29 | -------------------------------------------------------------------------------- /Requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | accelerate==0.20.3 3 | arrow==1.2.2 4 | asttokens==2.0.5 5 | async-timeout==4.0.1 6 | bert-score==0.3.11 7 | biopandas== 0.5.0.dev0 8 | biopython==1.81 9 | bioservices==1.11.2 10 | bitarray==2.4.1 11 | cachetools 12 | Cython==0.29.28 13 | datasets==2.12.0 14 | datashader==0.13.0 15 | datashape 16 | distributed==2022.2.1 17 | evaluate==0.4.0 18 | huggingface-hub==0.15.1 19 | hvplot==0.7.3 20 | hyperlink==21.0.0 21 | json5==0.9.6 22 | jsonschema==4.4.0 23 | matplotlib==3.5.1 24 | matplotlib-inline==0.1.2 25 | networkx==2.5 26 | nlp 27 | nltk 28 | numba 29 | numpy==1.23.5 30 | numpydoc==1.2 31 | pandas==1.4.2 32 | pandoc==2.3 33 | pandocfilters==1.5.0 34 | plotly==5.6.0 35 | py==1.11.0 36 | pyarrow==9.0.0 37 | PyYAML 38 | rope==0.22.0 39 | rouge-score==0.1.2 40 | rsa==4.7.2 41 | scikit-image 42 | scikit-learn 43 | scikit-learn-intelex 44 | scipy 45 | Scrapy 46 | seaborn 47 | setuptools==61.2.0 48 | tables==3.6.1 49 | tabulate==0.8.9 50 | TBB 51 | tensorboard==2.9.1 52 | tensorboard-data-server==0.6.1 53 | tensorboard-plugin-wit==1.8.1 54 | tensorboardX==2.5.1 55 | text-unidecode==1.3 56 | textdistance==4.2.1 57 | tqdm==4.64.0 58 | transformers==4.40.2 59 | typed-ast==1.4.3 60 | typeguard==4.0.0 61 | typing_extensions==4.5.0 62 | ujson==5.1.0 63 | Unidecode==1.2.0 64 | url-normalize==1.4.3 65 | urllib3==1.26.9 66 | virtualenv==20.24.0 67 | w3lib==1.21.0 68 | watchdog==2.1.6 69 | wget==3.2 70 | wheel==0.37.1 -------------------------------------------------------------------------------- /generate_description.py:
-------------------------------------------------------------------------------- 1 | from transformers import GPT2Tokenizer, Seq2SeqTrainingArguments 2 | from prot2text_dataset.torch_geometric_loader import Prot2TextDataset 3 | from prot2text_model.utils import Prot2TextTrainer 4 | from prot2text_model.Model import Prot2TextModel 5 | from prot2text_model.tokenization_prot2text import Prot2TextTokenizer 6 | import torch 7 | import os 8 | import argparse 9 | 10 | argParser = argparse.ArgumentParser() 11 | argParser.add_argument("--model_path", help="path to the prot2text model") 12 | argParser.add_argument("--protein_alphafold_id", default=None, help="the AlphaFold ID of the protein") 13 | argParser.add_argument("--protein_sequence", default=None, help="the amino-acid seuqence of the protein") 14 | 15 | # usage: 16 | # python generate_description.py \ 17 | # --model_path ./models/prot2text_base \ 18 | # --protein_alphafold_id \ 19 | # --protein_sequence 20 | 21 | 22 | args = argParser.parse_args() 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | tokenizer = Prot2TextTokenizer.from_pretrained(args.model_path) 26 | model = Prot2TextModel.from_pretrained(args.model_path) 27 | 28 | descrpition = model.generate_protein_description(protein_pdbID=args.protein_alphafold_id, 29 | protein_sequence=args.protein_sequence, 30 | tokenizer=tokenizer, 31 | device=device) 32 | print() 33 | print(descrpition) 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /train_prot2text.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=prot2text_base # name of job 3 | #SBATCH --output=prot2text%j.out # output file (%j = job ID) 4 | #SBATCH --error=prot2text%j.err # error file (%j = job ID) 5 | #SBATCH --constraint=v100-32g # reserve GPUs with 32 GB of RAM 6 | #SBATCH --nodes=16 7 | #SBATCH --ntasks-per-node=1 # reserve 4 tasks (or 
processes) 8 | #SBATCH --gres=gpu:4 # reserve 4 GPUs 9 | #SBATCH --cpus-per-task=10 # reserve 10 CPUs per task (and associated memory) 10 | 11 | set -x 12 | export GPUS_PER_NODE=4 13 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 14 | export MASTER_PORT=9901 15 | 16 | srun --jobid $SLURM_JOBID bash -c 'python -u -m torch.distributed.run \ 17 | --nproc_per_node $GPUS_PER_NODE \ 18 | --nnodes $SLURM_NNODES \ 19 | --node_rank $SLURM_PROCID \ 20 | --master_addr $MASTER_ADDR \ 21 | --master_port $MASTER_PORT \ 22 | train.py \ 23 | --decoder_path gpt2 \ 24 | --esm_model_path facebook/esm2_t12_35M_UR50D \ 25 | --use_plm \ 26 | --use_rgcn \ 27 | --warmup_esm \ 28 | --warmup_gpt \ 29 | --data_path ./data//dataset/ \ 30 | --train_csv_path ./data/train.csv \ 31 | --eval_csv_path ./data/eval.csv \ 32 | --batch_per_device 4 \ 33 | --nb_epochs 25 \ 34 | --nb_gpus \ 35 | --gradient_accumulation 1 \ 36 | --lr 2e-4 \ 37 | --save_model_path ./models/prot2text_base/ \ 38 | --bleu_evaluation' 39 | -------------------------------------------------------------------------------- /prot2text_dataset/utils_convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from biopandas.pdb import PandasPdb 3 | 4 | pdb_order = [ 5 | "record_name", 6 | "atom_number", 7 | "blank_1", 8 | "atom_name", 9 | "alt_loc", 10 | "residue_name", 11 | "blank_2", 12 | "chain_id", 13 | "residue_number", 14 | "insertion", 15 | "blank_3", 16 | "x_coord", 17 | "y_coord", 18 | "z_coord", 19 | "occupancy", 20 | "b_factor", 21 | "blank_4", 22 | "segment_id", 23 | "element_symbol", 24 | "charge", 25 | "line_idx", 26 | ] 27 | mmcif_read = { 28 | "group_PDB": "record_name", 29 | "id": "atom_number", 30 | "auth_atom_id": "atom_name", 31 | "auth_comp_id": "residue_name", 32 | "auth_asym_id": "chain_id", 33 | "auth_seq_id": "residue_number", 34 | "Cartn_x": "x_coord", 35 | "Cartn_y": "y_coord", 36 | "Cartn_z": "z_coord", 37 | 
"occupancy": "occupancy", 38 | "B_iso_or_equiv": "b_factor", 39 | "type_symbol": "element_symbol", 40 | } 41 | 42 | nonefields = [ 43 | "blank_1", 44 | "alt_loc", 45 | "blank_2", 46 | "insertion", 47 | "blank_3", 48 | "blank_4", 49 | "segment_id", 50 | "charge", 51 | "line_idx", 52 | ] 53 | 54 | 55 | def biopandas_mmcif2pdb(pandasmmcif, model_index = 1): 56 | """ 57 | Converts the ATOM and HETATM dataframes of PandasMmcif() to PandasPdb() format. 58 | """ 59 | pandaspdb = PandasPdb() 60 | for a in ["ATOM", "HETATM"]: 61 | dfa = pandasmmcif.df[a] 62 | dfa = dfa.loc[dfa.pdbx_PDB_model_num == model_index] 63 | if a =='ATOM': 64 | if len(dfa) == 0: 65 | raise ValueError(f"No model found for index: {model_index}") 66 | # keep only those fields found in pdb 67 | dfa = dfa[mmcif_read.keys()] 68 | # rename fields 69 | dfa = dfa.rename(columns=mmcif_read) 70 | # add empty fields 71 | for i in nonefields: 72 | dfa[i] = "" 73 | dfa["charge"] = np.nan 74 | # reorder columns to PandasPdb order 75 | dfa = dfa[pdb_order] 76 | pandaspdb.df[a] = dfa 77 | 78 | # update line_idx 79 | pandaspdb.df["ATOM"]["line_idx"] = pandaspdb.df["ATOM"].index.values 80 | pandaspdb.df["HETATM"]["line_idx"] = pandaspdb.df["HETATM"].index 81 | 82 | return pandaspdb -------------------------------------------------------------------------------- /prot2text_dataset/utils_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | 4 | def load_GO_annot(filename): 5 | # Load GO annotations 6 | onts = ['mf', 'bp', 'cc'] 7 | prot2annot = {} 8 | goterms = {ont: [] for ont in onts} 9 | gonames = {ont: [] for ont in onts} 10 | with open(filename, mode='r') as tsvfile: 11 | reader = csv.reader(tsvfile, delimiter='\t') 12 | 13 | # molecular function 14 | next(reader, None) # skip the headers 15 | goterms[onts[0]] = next(reader) 16 | next(reader, None) # skip the headers 17 | gonames[onts[0]] = next(reader) 18 | 19 | # biological process 
20 | next(reader, None) # skip the headers 21 | goterms[onts[1]] = next(reader) 22 | next(reader, None) # skip the headers 23 | gonames[onts[1]] = next(reader) 24 | 25 | # cellular component 26 | next(reader, None) # skip the headers 27 | goterms[onts[2]] = next(reader) 28 | next(reader, None) # skip the headers 29 | gonames[onts[2]] = next(reader) 30 | 31 | next(reader, None) # skip the headers 32 | counts = {ont: np.zeros(len(goterms[ont]), dtype=float) for ont in onts} 33 | for row in reader: 34 | prot, prot_goterms = row[0], row[1:] 35 | prot2annot[prot] = {ont: [] for ont in onts} 36 | for i in range(3): 37 | goterm_indices = [goterms[onts[i]].index(goterm) for goterm in prot_goterms[i].split(',') if goterm != ''] 38 | prot2annot[prot][onts[i]] = np.zeros(len(goterms[onts[i]])) 39 | prot2annot[prot][onts[i]][goterm_indices] = 1.0 40 | counts[onts[i]][goterm_indices] += 1.0 41 | return prot2annot, goterms, gonames, counts 42 | 43 | 44 | def load_EC_annot(filename): 45 | # Load EC annotations """ 46 | prot2annot = {} 47 | with open(filename, mode='r') as tsvfile: 48 | reader = csv.reader(tsvfile, delimiter='\t') 49 | 50 | # molecular function 51 | next(reader, None) # skip the headers 52 | ec_numbers = {'ec': next(reader)} 53 | next(reader, None) # skip the headers 54 | counts = {'ec': np.zeros(len(ec_numbers['ec']), dtype=float)} 55 | for row in reader: 56 | prot, prot_ec_numbers = row[0], row[1] 57 | ec_indices = [ec_numbers['ec'].index(ec_num) for ec_num in prot_ec_numbers.split(',')] 58 | prot2annot[prot] = {'ec': np.zeros(len(ec_numbers['ec']), dtype=np.int64)} 59 | prot2annot[prot]['ec'][ec_indices] = 1.0 60 | counts['ec'][ec_indices] += 1 61 | -------------------------------------------------------------------------------- /prot2text_model/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 
| from torch_geometric.nn import GINConv,GATv2Conv,TAGConv,ARMAConv,APPNP,MFConv, GMMConv,HypergraphConv,LEConv,PNAConv, GCNConv,SAGEConv, RGCNConv 6 | from torch_scatter import scatter_add, scatter_mean 7 | from typing import Optional, Tuple, Union 8 | from torch_geometric.nn import global_add_pool,global_mean_pool 9 | from torch.nn import init 10 | import random 11 | from torch_geometric.nn import MessagePassing 12 | from torch_geometric.nn import aggr 13 | from torch_geometric.utils import sort_edge_index 14 | import torch_geometric 15 | from transformers.modeling_utils import PreTrainedModel, PretrainedConfig 16 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP, GPT2PreTrainedModel, PARALLELIZE_DOCSTRING, DEPARALLELIZE_DOCSTRING 17 | import transformers 18 | from transformers.utils import ( 19 | ModelOutput, 20 | add_code_sample_docstrings, 21 | add_start_docstrings, 22 | add_start_docstrings_to_model_forward, 23 | logging, 24 | replace_return_docstrings, 25 | ) 26 | 27 | class EncoderRGCN(PreTrainedModel): 28 | ''' 29 | This class implement the RGCN encoder to encode the protein structure 30 | ''' 31 | def __init__(self, input_dim, hidden_dim=512, n_layers=6, emb_dim=512, dropout=0.2, num_relation=7, prot2text_version='1.0'): 32 | super(EncoderRGCN, self).__init__(PretrainedConfig(name='RGCN')) 33 | self.n_layers = n_layers 34 | self.output_dim = emb_dim 35 | self.prot2text_version = prot2text_version 36 | 37 | self.fc0 = nn.Linear(input_dim, hidden_dim) 38 | self.batchnorm_final = nn.BatchNorm1d(hidden_dim) 39 | 40 | self.batch_norms = nn.ModuleList() 41 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 42 | lst = list() 43 | 44 | lst.append(RGCNConv(hidden_dim, hidden_dim, num_relations=num_relation)) 45 | 46 | for i in range(n_layers-1): 47 | lst.append(RGCNConv(hidden_dim,hidden_dim, num_relations=num_relation)) 48 | 49 | self.conv = nn.ModuleList(lst) 50 | 51 | self.fc1 = nn.Linear(hidden_dim, hidden_dim) 52 | self.fc2 = 
nn.Linear(hidden_dim, self.output_dim) 53 | 54 | self.dropout = nn.Dropout(p=dropout) 55 | self.relu = nn.LeakyReLU() 56 | self.batchnorm = nn.BatchNorm1d(hidden_dim) 57 | self.main_input_name = 'nothing' 58 | 59 | def forward(self, x:Optional[torch.FloatTensor] = None, 60 | edge_index:Optional[torch.LongTensor] = None, 61 | edge_type:Optional[torch.LongTensor] = None, 62 | batch:Optional[torch.LongTensor] = None, 63 | **kargs): 64 | #construct pyg edge index shape (2, num_edges) from edge_list 65 | x = self.relu(self.fc0(x)) 66 | 67 | for i in range(self.n_layers): 68 | x = self.conv[i](x, edge_index, edge_type) 69 | 70 | out = global_mean_pool(x, batch) 71 | out = self.relu(self.fc1(out)) 72 | out = self.relu(self.fc2(out)) 73 | 74 | return out.unsqueeze(1) -------------------------------------------------------------------------------- /prot2text_dataset/Inememorydataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Dataset, InMemoryDataset, Data 3 | from tqdm import tqdm 4 | from graphein.ml.conversion import convert_nx_to_pyg_data 5 | import json 6 | import numpy as np 7 | from functools import partial 8 | import multiprocessing 9 | import os 10 | 11 | def _mp_graph_constructor(args, features, map_ss): 12 | result = torch.load(args) 13 | g = Data(edge_index = result.edge_index, num_nodes = len(result.node_id), node_id = result.node_id, 14 | edge_attr = torch.cat((torch.FloatTensor(result.distance).reshape(-1,1), torch.tensor(result.ohe_kind)), dim=1), y = result.label, sequence = result.sequence_A, 15 | name = result.name) 16 | x = torch.cat((torch.FloatTensor(result.coords[0]), torch.FloatTensor(result.amino_acid_one_hot)), dim=1) 17 | for feat in features: 18 | if feat == 'ss': 19 | feature = np.zeros((x.shape[0],8)) 20 | for i in range(x.shape[0]): 21 | feature[i][map_ss[result[feat][i]]] = 1 22 | else: 23 | feature = np.array(result[feat]) 24 | if 
len(feature.shape)==1: 25 | feature = feature.reshape(-1,1) 26 | x = torch.cat((x, torch.FloatTensor(feature)), dim = 1) 27 | g.x = x 28 | return g 29 | 30 | class ARGDataset(InMemoryDataset): 31 | def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, features=[], edges=[]): 32 | with open(os.path.join(root, "list_elements_coala.json"), "r") as fp: 33 | self.structures = json.load(fp) 34 | self.edges = edges 35 | self.features = features 36 | super().__init__(root, transform, pre_transform, pre_filter) 37 | self.data, self.slices = torch.load(self.processed_paths[0]) 38 | 39 | @property 40 | def raw_file_names(self) : 41 | """Names of raw files in the dataset.""" 42 | return [f"data_{pdb}t10.pt" for pdb in self.structures] 43 | 44 | @property 45 | def processed_file_names(self): 46 | """Names of processed files to look for""" 47 | feat_name = "_".join(self.features) 48 | edge_name = "_".join(self.edges) 49 | return [f'data_{feat_name}_{edge_name}.pt'] 50 | 51 | def download(self): 52 | # Download to `self.raw_dir`. 
53 | pass 54 | 55 | 56 | def process(self): 57 | map_ss = {'-':0, 'H':1, 'B':2, 'E':3, 'G':4, 'I':5, 'T':6, 'S':7} 58 | pool = multiprocessing.Pool(16) 59 | graphs = [] 60 | y = [] 61 | constructor = partial( 62 | _mp_graph_constructor, features=self.features, map_ss=map_ss 63 | ) 64 | for result in tqdm(pool.imap_unordered(constructor, self.raw_paths), total=len(self.raw_paths)): 65 | graphs.append(result) 66 | 67 | pool.close() 68 | pool.join() 69 | print("Converting Networkx graphs to PyG...") 70 | 71 | print("Saving Data...") 72 | data, slices = self.collate(graphs) 73 | torch.save((data, slices), self.processed_paths[0]) 74 | print("Done!") 75 | 76 | if __name__=='__main__': 77 | dataset= ARGDataset('../../PDBFiles', features = ['phi', 'psi','rsa', 'asa', 'b_factor', 'ss','hbond_acceptors', 'hbond_donors', 'expasy']) 78 | 79 | -------------------------------------------------------------------------------- /prot2text_dataset/torch_geometric_loader.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path as osp 3 | import os 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import torch 8 | import json 9 | import os 10 | import pickle 11 | import random 12 | import time 13 | import warnings 14 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 15 | import scipy.sparse as sp 16 | import numpy as np 17 | import scipy 18 | from tqdm import tqdm 19 | from transformers.tokenization_utils import PreTrainedTokenizer 20 | from transformers.utils import logging 21 | import pandas as pd 22 | from transformers import DataCollatorForLanguageModeling, GPT2Tokenizer 23 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 24 | from transformers.utils import PaddingStrategy 25 | from torch_geometric.data import Dataset, download_url 26 | import torch 27 | from torch_geometric.loader import DataListLoader, DataLoader 28 | from torch_geometric.nn import 
DataParallel 29 | from tqdm import tqdm 30 | 31 | class Prot2TextDataset(Dataset): 32 | def __init__(self, 33 | root, 34 | tokenizer: PreTrainedTokenizer, 35 | file_path: str, 36 | block_size: int, 37 | split: str = "train", 38 | transform=None, 39 | pre_transform=None, 40 | pre_filter=None, 41 | esmtokenizer: PreTrainedTokenizer = None,): 42 | 43 | 44 | self.split_path = split 45 | self.files_names_folder = os.path.join(root, split, 'raw') 46 | self.files_names = os.listdir(self.files_names_folder) 47 | self.uniprot_csv = pd.read_csv(file_path) 48 | 49 | self.tokenizer = tokenizer 50 | self.esmtokenizer = esmtokenizer 51 | self.block_size = block_size 52 | print('dataset loading:') 53 | 54 | self.length = len(self.files_names) 55 | 56 | super().__init__(root, transform, pre_transform, pre_filter) 57 | 58 | @property 59 | def raw_file_names(self): 60 | return self.files_names 61 | 62 | @property 63 | def processed_file_names(self): 64 | return [f'data_{i}.pt' for i in range(self.length)] 65 | 66 | @property 67 | def raw_dir(self) -> str: 68 | return osp.join(self.root, self.split_path, 'raw') 69 | 70 | @property 71 | def processed_dir(self) -> str: 72 | return osp.join(self.root, self.split_path, 'processed') 73 | 74 | 75 | def process(self): 76 | idx = 0 77 | print("length:",len(self.files_names)) 78 | for i in tqdm(range(len(self.files_names))): 79 | try: 80 | graph = torch.load(os.path.join(self.root, self.split_path, "raw",self.files_names[i])) 81 | function = '<|graph_token|> '+self.uniprot_csv.loc[self.uniprot_csv['accession'] == self.files_names[i].split("-")[1]]["function"].values[0]+' <|stop_token|> ' 82 | sequence = self.uniprot_csv.loc[self.uniprot_csv['accession'] == self.files_names[i].split("-")[1]]["sequence"].values[0] 83 | 84 | text = self.tokenizer([function], add_special_tokens=True, truncation=True, max_length=self.block_size, padding='max_length', return_tensors="pt") 85 | seq = self.esmtokenizer([sequence], add_special_tokens=True, 
truncation=True, max_length=1021, padding='max_length', return_tensors="pt") 86 | 87 | graph.encoder_input_ids = seq['input_ids'] 88 | graph.attention_mask = seq['attention_mask'] 89 | graph.decoder_input_ids = text['input_ids'] 90 | graph.decoder_attention_mask = text['attention_mask'] 91 | labels = text['input_ids'].clone() 92 | labels[labels == self.tokenizer.pad_token_id] = -100 93 | graph.labels = labels 94 | graph.edge_type = graph.edge_type.transpose(0,1) 95 | 96 | torch.save(graph, osp.join(self.processed_dir ,f'data_{idx}.pt')) 97 | idx = idx + 1 98 | except: 99 | print('error loading ', self.files_names[i]) 100 | print("don't forget to delete it from raw files to avoid error") 101 | 102 | 103 | def len(self): 104 | return len(self.processed_file_names) 105 | 106 | def __len__(self): 107 | return len(self.processed_file_names) 108 | 109 | def get(self, idx): 110 | data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt')) 111 | return data 112 | 113 | def download(self): 114 | pass 115 | 116 | def __cat_dim__(self, key, value, args, *kwargs): 117 | if 'index' in key or 'face' in key or 'edge_type' in key: 118 | return 1 119 | else: 120 | return 0 -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | from prot2text_dataset.pdb2graph import * 2 | from prot2text_dataset.utils_dataset import * 3 | import prot2text_dataset.graphs 4 | import wget 5 | from tqdm import tqdm 6 | import os 7 | import argparse 8 | from functools import partial 9 | from transformers import AutoTokenizer 10 | from prot2text_dataset.torch_geometric_loader import Prot2TextDataset 11 | from graphein.protein.config import ProteinGraphConfig, DSSPConfig 12 | from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor 13 | from 
graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure 14 | from graphein.protein.edges.distance import (add_peptide_bonds, 15 | add_hydrogen_bond_interactions, 16 | add_disulfide_interactions, 17 | add_ionic_interactions, 18 | add_delaunay_triangulation, 19 | add_distance_threshold, 20 | add_sequence_distance_edges, 21 | add_k_nn_edges) 22 | 23 | argParser = argparse.ArgumentParser() 24 | argParser.add_argument("--data_save_path", help="folder to save the dataset") 25 | argParser.add_argument("--csv_path", help="csv containing the protein dataset") 26 | argParser.add_argument("--split", help="train, test or eval csv?") 27 | argParser.add_argument("--plm_model", help="protein model to use (from hugging face)") 28 | argParser.add_argument("--decoder_model", help="language model to use (from hugging face)") 29 | 30 | # usage: 31 | # python prepare_dataset.py \ 32 | # --data_save_path ./data/dataset/ \ 33 | # --split test --csv_path ./data/test.csv \ 34 | # --plm_model facebook/esm2_t12_35M_UR50D \ 35 | # --decoder_model gpt2 36 | 37 | args = argParser.parse_args() 38 | 39 | 40 | # step 1: download the PDB files from AlphaFoldDB 41 | isExist = os.path.exists(os.path.join(args.data_save_path, args.split)) 42 | if not isExist: 43 | os.makedirs(os.path.join(args.data_save_path, args.split, 'pdb')) 44 | os.makedirs(os.path.join(args.data_save_path, args.split, 'raw')) 45 | os.makedirs(os.path.join(args.data_save_path, args.split, 'processed')) 46 | 47 | print('downloading the data:\n') 48 | 49 | df = pd.read_csv(args.csv_path) 50 | 51 | pdb_path = os.path.join(args.data_save_path, args.split, 'pdb') 52 | for prot in tqdm(set(df.AlphaFoldDB)): 53 | if os.path.exists(os.path.join(pdb_path, 'AF-'+str(prot)+'-F1-model_v4.pdb')): 54 | continue 55 | download_alphafold_structure(uniprot_id=str(prot), out_dir=pdb_path) 56 | 57 | # step 2: construct graphs from the pdb files 58 | print('constructing the graphs:\n') 59 | if 
len(os.listdir(os.path.join(args.data_save_path, args.split, 'raw'))) == len(os.listdir(os.path.join(args.data_save_path, args.split, 'pdb'))): 60 | print('graphs already created') 61 | else: 62 | config = {"node_metadata_functions": [amino_acid_one_hot, 63 | expasy_protein_scale, 64 | meiler_embedding, 65 | hydrogen_bond_acceptor, 66 | hydrogen_bond_donor 67 | ], 68 | "edge_construction_functions": [add_peptide_bonds, 69 | add_hydrogen_bond_interactions, 70 | partial(add_distance_threshold, 71 | long_interaction_threshold=3, 72 | threshold=10.),], 73 | "graph_metadata_functions":[asa, phi, psi, secondary_structure, rsa], 74 | "dssp_config": DSSPConfig(),} 75 | config = ProteinGraphConfig(**config) 76 | PDB2Graph(root = pdb_path, 77 | output_folder = os.path.join(args.data_save_path, args.split, 'raw'), 78 | config=config, n_processors=32).process() 79 | 80 | # step 3: process the dataset 81 | esm_tokenizer = AutoTokenizer.from_pretrained(args.plm_model) 82 | tokenizer = AutoTokenizer.from_pretrained(args.decoder_model) 83 | SPECIAL_TOKEN = '<|graph_token|>' 84 | tokenizer.pad_token = tokenizer.eos_token 85 | tokenizer.pad_token = 50256 86 | tokenizer.add_tokens([SPECIAL_TOKEN]) 87 | SPECIAL_TOKEN = '<|stop_token|>' 88 | tokenizer.add_tokens([SPECIAL_TOKEN]) 89 | tokenizer.eos_token = '<|stop_token|>' 90 | tokenizer.eos_token_id = 50258 91 | tokenizer.bos_token_id = 50257 92 | 93 | dataset = Prot2TextDataset(root=args.data_save_path, 94 | tokenizer=tokenizer, 95 | file_path=args.csv_path, 96 | block_size=256, 97 | split=args.split, 98 | esmtokenizer=esm_tokenizer) 99 | -------------------------------------------------------------------------------- /evaluate_prot2text.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Tokenizer, Seq2SeqTrainingArguments 2 | from prot2text_dataset.torch_geometric_loader import Prot2TextDataset 3 | from prot2text_model.utils import Prot2TextTrainer 4 | from 
prot2text_model.Model import Prot2TextModel 5 | from prot2text_model.tokenization_prot2text import Prot2TextTokenizer 6 | import evaluate 7 | from torch_geometric.loader import DataLoader 8 | import pandas as pd 9 | from transformers.utils import logging 10 | from tqdm import tqdm 11 | import torch 12 | import os 13 | import argparse 14 | 15 | argParser = argparse.ArgumentParser() 16 | argParser.add_argument("--model_path", help="path to the prot2text model") 17 | argParser.add_argument("--data_path", help="root folder of the data") 18 | argParser.add_argument("--csv_path", help="csv containing the protein dataset to evaluate") 19 | argParser.add_argument("--split", help="train, test or eval csv?") 20 | argParser.add_argument("--batch_per_device", help="batch size for each device") 21 | argParser.add_argument("--save_results_path", help="path to save the generated description") 22 | 23 | # usage for single GPU: 24 | # python evaluate_prot2text.py \ 25 | # --model_path ./models/prot2text_base \ 26 | # --data_path ./data/dataset/ \ 27 | # --split test \ 28 | # --csv_path ./data/test.csv \ 29 | # --batch_per_device 4 \ 30 | # --save_results_path ./results/prot2text_base_results.csv 31 | 32 | # usage for multiple GPUs: 33 | # python -u -m torch.distributed.run --nproc_per_node --nnodes --node_rank 0 evaluate_prot2text.py \ 34 | # --model_path ./models/prot2text_base \ 35 | # --data_path ./data/dataset/ \ 36 | # --split test \ 37 | # --csv_path ./data/test.csv \ 38 | # --batch_per_device 4 \ 39 | # --save_results_path ./results/prot2text_base_results.csv 40 | 41 | args = argParser.parse_args() 42 | 43 | tokenizer = Prot2TextTokenizer.from_pretrained(args.model_path) 44 | 45 | model = Prot2TextModel.from_pretrained(args.model_path) 46 | eval_dataset = Prot2TextDataset(root=args.data_path, 47 | tokenizer=tokenizer, 48 | file_path=args.csv_path, 49 | block_size=256, 50 | split=args.split) 51 | print('eval set loaded') 52 | 53 | batch_size = int(args.batch_per_device) 54 | 
model.eval() 55 | bleu = evaluate.load("bleu") 56 | rouge = evaluate.load("rouge") 57 | bert_score = evaluate.load("bertscore") 58 | 59 | args_seq = Seq2SeqTrainingArguments(output_dir='./', per_device_eval_batch_size=batch_size) 60 | trainer = Prot2TextTrainer(model=model, args=args_seq, eval_dataset=eval_dataset) 61 | 62 | d = trainer.get_eval_dataloader() 63 | 64 | if torch.distributed.is_initialized(): 65 | if torch.distributed.get_rank()==0: 66 | if os.path.exists(args.save_results_path): 67 | os.remove(args.save_results_path) 68 | else: 69 | if os.path.exists(args.save_results_path): 70 | os.remove(args.save_results_path) 71 | 72 | names = list() 73 | generated = list() 74 | functions = list() 75 | 76 | for inputs in tqdm(d): 77 | inputs = inputs.to_dict() 78 | inputs['edge_type'] = torch.cat([torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0) 79 | inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1) 80 | names += inputs['name'] 81 | functions += tokenizer.batch_decode(inputs['decoder_input_ids'], skip_special_tokens=True) 82 | inputs['decoder_input_ids'] = inputs['decoder_input_ids'][:,0:1] 83 | inputs["decoder_attention_mask"] = torch.ones(inputs['decoder_input_ids'].shape[0], 1) 84 | inputs = {k: v.to(device=torch.cuda.current_device(), non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()} 85 | tok_ids = model.generate(inputs=None, **inputs, 86 | num_beams=1, 87 | early_stopping=False, 88 | no_repeat_ngram_size=None, 89 | length_penalty=1.0) 90 | generated += tokenizer.batch_decode(tok_ids, skip_special_tokens=True) 91 | 92 | data= {'name':names, 'generated': generated, 'function':functions} 93 | df = pd.DataFrame(data) 94 | df.to_csv(args.save_results_path, index=False, mode='a') 95 | 96 | if torch.distributed.is_initialized(): 97 | torch.distributed.barrier() 98 | if torch.distributed.get_rank() > 0: 99 | exit(0) 100 | res = pd.read_csv(args.save_results_path).drop_duplicates() 101 | res = 
res.drop(res[res['name'] == 'name'].index) 102 | 103 | res_bleu = bleu.compute(predictions=res['generated'].tolist(), references=res['function'].tolist()) 104 | res_rouge = rouge.compute(predictions=res['generated'].tolist(), references=res['function'].tolist()) 105 | res_bertscore = bert_score.compute(predictions=res['generated'].tolist(), references=res['function'].tolist(), 106 | model_type="dmis-lab/biobert-large-cased-v1.1", num_layers=24) 107 | print(res_bleu) 108 | print(res_rouge) 109 | def Average(lst): 110 | return sum(lst) / len(lst) 111 | print('Bert Score: ', Average(res_bertscore['f1'])) -------------------------------------------------------------------------------- /prot2text_dataset/pdb2graph.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from tqdm import tqdm 4 | from sklearn.preprocessing import MultiLabelBinarizer 5 | 6 | from torch_geometric.data import Data 7 | import torch 8 | 9 | import numpy as np 10 | 11 | from .conversion import convert_nx_to_pyg_data 12 | from graphein.protein.config import ProteinGraphConfig, DSSPConfig 13 | from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor 14 | from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure 15 | from graphein.protein.edges.distance import (add_peptide_bonds, 16 | add_hydrogen_bond_interactions, 17 | add_disulfide_interactions, 18 | add_ionic_interactions, 19 | add_delaunay_triangulation, 20 | add_distance_threshold, 21 | add_sequence_distance_edges, 22 | add_k_nn_edges) 23 | 24 | from functools import partial 25 | from .graphs import * 26 | from .utils_dataset import * 27 | import os 28 | import sys 29 | import subprocess 30 | import wget 31 | 32 | 33 | class PDB2Graph(): 34 | def __init__(self, root, output_folder, config, n_processors=int(multiprocessing.cpu_count())): 
35 | self.root = root 36 | self.output_folder = output_folder 37 | self.map_secondary_structure = {'-':0, 'H':1, 'B':2, 'E':3, 'G':4, 'I':5, 'T':6, 'S':7} 38 | self.init_ohe_edge_type() 39 | self.config = config 40 | self.features = ['phi', 'psi', 'rsa', 'asa', 'ss', 'expasy'] 41 | self.n_processors = n_processors 42 | self.raw_dir = root 43 | self.processed_dir = self._processed_dir() 44 | self.raw_file_names = self._raw_file_names() 45 | self.processed_file_names = self._processed_file_names() 46 | 47 | 48 | def _processed_dir(self): 49 | #processed_dir = os.path.join(os.path.split(self.root)[0], "processed_new") 50 | if not os.path.exists(self.output_folder): 51 | os.makedirs(self.output_folder) 52 | return self.output_folder 53 | 54 | def _raw_file_names(self): 55 | return os.listdir(self.raw_dir) 56 | 57 | def _processed_file_names(self): 58 | return [self.pdb2pathdata(pdb_path.split(".")[0]) for pdb_path in self.raw_file_names] 59 | 60 | def create_nx_graph(self, path_to_structure): 61 | return construct_graph(self.config, pdb_path = path_to_structure) 62 | 63 | def create_pyg_graph(self, path_to_structure): 64 | pyg_graph = convert_nx_to_pyg_data(self.create_nx_graph(path_to_structure)) 65 | 66 | graph = Data(edge_index = pyg_graph.edge_index, 67 | num_nodes = len(pyg_graph.node_id), 68 | node_id = pyg_graph.node_id, 69 | name = pyg_graph.name[0], 70 | sequence = getattr(pyg_graph, f"sequence_{pyg_graph.chain_id[0]}"), 71 | distance_matrix = pyg_graph.dist_mat, 72 | distance = pyg_graph.distance, 73 | coordinates = torch.FloatTensor(np.array(pyg_graph.coords[0]))) 74 | #create the features 75 | x = np.array([np.argmax(pyg_graph.amino_acid_one_hot, axis=1)]).reshape(-1,1) 76 | for feat in self.features: 77 | if feat == "ss": 78 | feature = np.array([[self.map_secondary_structure.get(feat_node, 0)] \ 79 | for feat_node in pyg_graph[feat]]) 80 | else: 81 | feature = np.array(pyg_graph[feat]) 82 | if len(feature.shape) == 1: 83 | feature = feature.reshape(-1,1) 
84 | x = np.concatenate((x, feature), axis = 1) 85 | graph.edge_type = self.mlb.transform(pyg_graph.kind) 86 | graph.x = torch.FloatTensor(x) 87 | # y = self.annotations[graph.name.split("_")[0]] 88 | # if self.task == 'GeneOntology' : 89 | # graph.y_mf = torch.FloatTensor(y["mf"]) 90 | # graph.y_cc = torch.FloatTensor(y["cc"]) 91 | # graph.y_bp = torch.FloatTensor(y["bp"]) 92 | # else: 93 | # graph.y_ec = torch.FloatTensor(y["ec"]) 94 | return graph 95 | 96 | def init_ohe_edge_type(self): 97 | self.mlb = MultiLabelBinarizer(classes = ['peptide_bond', 'sequence_distance_2', 'sequence_distance_3' 98 | , 'distance_threshold', 'delaunay', 'hbond', 'k_nn']) 99 | self.mlb.fit([['peptide_bond', 'sequence_distance_2', 'sequence_distance_3' 100 | , 'distance_threshold', 'delaunay', 'hbond', 'k_nn']]) 101 | 102 | def process(self): 103 | """Convert the PDB files into torch geometric graphs""" 104 | # self.pdb2graph = PDB2Graph(self.config) 105 | to_be_processed = self.get_files_to_process() 106 | 107 | # pool = multiprocessing.Pool(self.n_processors) 108 | # for _ in tqdm(pool.imap_unordered(self.graph_creation, to_be_processed), total=len(to_be_processed)): 109 | # continue 110 | # pool.close() 111 | # pool.join() 112 | 113 | 114 | 115 | processes = [] 116 | for prot in tqdm(to_be_processed): 117 | p = multiprocessing.Process(target=self.graph_creation, args=(prot,)) 118 | processes.append(p) 119 | p.start() 120 | 121 | for process in processes: 122 | process.join() 123 | 124 | 125 | def graph_creation(self, pdb): 126 | """Create a graph from the PDB file""" 127 | 128 | # Define the path_to_structure from the pdb name file 129 | path_to_structure = self.pdb2pathstructure(pdb) 130 | 131 | # Convert the structure into a graph 132 | g = self.create_pyg_graph(path_to_structure) 133 | # Save the graph 134 | torch.save(g, os.path.join(self.output_folder, self.pdb2pathdata(pdb))) 135 | 136 | return None 137 | 138 | def pdb2pathdata(self, pdb): 139 | return pdb+'.pt' 140 | 141 | 
def pdb2pathstructure(self, pdb): 142 | return os.path.join(self.raw_dir, pdb+'.pdb') 143 | 144 | def get_files_to_process(self): 145 | RAW_FILES = self.processed_file_names 146 | PROCESSED_FILES = os.listdir(self.processed_dir) 147 | to_be_processed = set(RAW_FILES).difference(set(PROCESSED_FILES)) 148 | to_be_processed = [path.split('.')[0] for path in to_be_processed] 149 | return to_be_processed 150 | 151 | def download_alphafold_structure( 152 | uniprot_id: str, 153 | out_dir: str, 154 | version: int = 4 155 | ): 156 | 157 | BASE_URL = "https://alphafold.ebi.ac.uk/files/" 158 | uniprot_id = uniprot_id.upper() 159 | 160 | query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.pdb" 161 | structure_filename = os.path.join(out_dir, f"AF-{uniprot_id}-F1-model_v{version}.pdb") 162 | if os.path.exists(structure_filename): 163 | return structure_filename 164 | try: 165 | structure_filename = wget.download(query_url, out=out_dir) 166 | except: 167 | print('Error.. could not download: ', f"AF-{uniprot_id}-F1-model_v{version}.pdb") 168 | return None 169 | return structure_filename 170 | 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prot2Text: Multimodal Protein’s Function Generation with GNNs and Transformers 2 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 3 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa] 4 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa] 5 | 6 | [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png 8 | [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg 9 | 10 | September 2024: Prot2Text is now on [HuggingFace](https://huggingface.co/collections/habdine/prot2text-suite-66e48fe3596fcff3e41be4e7) 11 | 12 | This repository 
contains the code to reproduce the results of the paper "Prot2Text: Multimodal Protein's Function Generation with GNNs and Transformers" by Hadi Abdine, Michail Chatzianastasis, Costas Bouyioukos and Michalis Vazirgiannis, accepted at AAAI 2024. [Preprint Link](https://arxiv.org/abs/2307.14367) 13 | 14 | Preliminary versions of the paper were accepted as a spotlight at [DGM4H@NeurIPS 2023](https://sites.google.com/ethz.ch/dgm4h-neurips2023/home?authuser=0) and [AI4Science@NeurIPS 2023](https://ai4sciencecommunity.github.io/neurips23.html). 15 | 16 | A demo web app for protein description generation is also available at [nlp.polytechnique.fr/prot2text](http://nlp.polytechnique.fr/prot2text#proteins). 17 | 18 | ![](Prot2Text.drawio.png) 19 | 20 | ## Setup 21 | #### Environment Setup 22 | 23 | The recommended environment is Python >= 3.8 and PyTorch 1.13, although other versions of Python and PyTorch may also work. 24 | 25 | To prepare the environment, follow these steps: 26 | 1- Install PyTorch 1.13.*, PyTorch Geometric and its optional dependencies according to your CUDA version, using the following links (newer versions also work): 27 | - pytorch: [https://pytorch.org/get-started/previous-versions/](https://pytorch.org/get-started/previous-versions/) 28 | - pytorch-geometric and its optional dependencies: [https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) 29 | 30 | 2- Install the DSSP library (v3.0.0): [https://anaconda.org/salilab/dssp](https://anaconda.org/salilab/dssp), or a newer version: [https://ssbio.readthedocs.io/en/latest/instructions/dssp.html](https://ssbio.readthedocs.io/en/latest/instructions/dssp.html) (the experiments in this study were done using DSSP v3.0.0). 31 | 32 | 3- Install the rest of the requirements: 33 | ```bash 34 | pip install -r Requirements.txt 35 | ``` 36 | The models were trained and tested using transformers v4.30.1.
However, the repository has since been updated to also work with transformers v4.40.2. 37 | 38 | 4- Install the graphein library from source, or version >1.7.6 (if released): 39 | ```bash 40 | git clone https://github.com/a-r-j/graphein.git 41 | pip install -e graphein/ 42 | ``` 43 | 44 | #### Datasets Preparation 45 | | Dataset | Size | Link | 46 | |:-----------:|:-------:|:------------:| 47 | | Train | 248 315 | [Download](https://nuage.lix.polytechnique.fr/index.php/s/QbPoM9fcAsQPRez) | 48 | | Validation | 4 172 | [Download](https://nuage.lix.polytechnique.fr/index.php/s/CwmKBe3gF6sY7XE) | 49 | | Test | 4 203 | [Download](https://nuage.lix.polytechnique.fr/index.php/s/64s7PM9aPFZENCT) | 50 | 51 | After downloading the CSV file for each split, the PDB files will be downloaded from AlphaFoldDB and then preprocessed to extract the graph, the tokenized amino-acid sequence and the tokenized protein description for each PDB file. 52 | Example for the test split: 53 | 54 | ```bash 55 | python prepare_dataset.py \ 56 | --data_save_path ./data/dataset/ \ 57 | --split test \ 58 | --csv_path ./data/test.csv \ 59 | --plm_model facebook/esm2_t12_35M_UR50D \ 60 | --decoder_model gpt2 61 | ``` 62 | where: 63 | * `--data_save_path`: the path of the folder where the PDB files will be downloaded from AlphaFold, and where the graphs, the tokenized text sequences and the amino-acid sequences will be stored. 64 | * `--split`: which split to preprocess (`train`, `eval`, or `test`). 65 | * `--csv_path`: the path to a CSV file downloaded from the previous table. 66 | * `--plm_model`: the HuggingFace protein language model path (used for amino-acid sequence tokenization). 67 | * `--decoder_model`: the HuggingFace decoder model path (GPT-like model) (used for natural text tokenization).
68 | 69 | ## Models 70 | 71 | | Model | #params | BLEU Score | BERT Score | Link | 72 | |:--------------------------:|:--------:|:-----------:|:-----------:|:------------:| 73 | | Prot2TextSMALL | 256M | 30.01 | 82.60 | [Download v1.0](https://1drv.ms/u/s!AhcBGHWGY2muke8KlDdP__DNxHhB1g?e=5HHhtn) [Download v1.1](https://1drv.ms/u/s!AhcBGHWGY2muke5r13_Ew0mP_XkmCw?e=Daud3y) | 74 | | Prot2TextBASE | 283M | 35.11 | 84.30 | [Download v1.0](https://1drv.ms/u/s!AhcBGHWGY2muke8JOhbc2e4zGCNCiA?e=XGsl7R) [Download v1.1](https://1drv.ms/u/s!AhcBGHWGY2muke5pg1mohZpURZNgSQ?e=53BjOq) | 75 | | Prot2TextMEDIUM| 398M | 36.51 | 84.83 | [Download v1.0](https://1drv.ms/u/s!AhcBGHWGY2muke8LWGEjPfeIb0B-iA?e=4swFgn) [Download v1.1](https://1drv.ms/u/s!AhcBGHWGY2muke5s7RQUW45fK8BAjQ?e=zn7cvn) | 76 | | Prot2TextLARGE | 898M | 36.29 | 85.20 | [Download v1.0](https://1drv.ms/u/s!AhcBGHWGY2muke8MfBS7eC-NOHoKQQ?e=jHU9PT) [Download v1.1](https://1drv.ms/u/s!AhcBGHWGY2muke5tU-jgrZmNpOAXlg?e=p61QAr) | 77 | | Esm2TextBASE | 225M | 32.11 | 83.21 | [Download v1.0](https://1drv.ms/u/s!AhcBGHWGY2muke8Icx929RSAmsLRNw?e=jxxW3g) [Download v1.1](https://1drv.ms/u/s!AhcBGHWGY2muke5qine4uqzRO_nDAQ?e=TSwrCv) | 78 | 79 | The reported results are computed using v1.0 (from the original paper). v1.1 uses the same architecture, it is only trained with some fixed bugs in the code. It has a similar performance to v1.0. 
80 | 81 | #### Protein Description Generation 82 | To generate the description of a protein using any Prot2Text model (using `--model_path`), you need to specify the protein's AlphaFoldDB ID (using `--protein_alphafold_id`) and have an internet connection in order to download the structure: 83 | ``` 84 | python generate_description.py \ 85 | --model_path ./models/prot2text_base \ 86 | --protein_alphafold_id P36108 87 | ``` 88 | 89 | You can also use the Esm2Text model to generate a protein description based only on the amino-acid sequence (using `--protein_sequence`): 90 | ``` 91 | python generate_description.py \ 92 | --model_path ./models/esm2text_base \ 93 | --protein_sequence AEQAERYEEMVEFMEKL 94 | ``` 95 | 96 | #### Training Prot2Text 97 | To train a Prot2Text model on a single GPU: 98 | ``` 99 | python train.py \ 100 | --decoder_path gpt2 \ 101 | --esm_model_path facebook/esm2_t12_35M_UR50D \ 102 | --use_plm \ 103 | --use_rgcn \ 104 | --warmup_esm \ 105 | --warmup_gpt \ 106 | --data_path ./data/dataset/ \ 107 | --train_csv_path ./data/train.csv \ 108 | --eval_csv_path ./data/eval.csv \ 109 | --batch_per_device 4 \ 110 | --nb_epochs 25 \ 111 | --nb_gpus 1 \ 112 | --gradient_accumulation 64 \ 113 | --lr 2e-4 \ 114 | --save_model_path ./models/prot2text_base/ \ 115 | --bleu_evaluation 116 | ``` 117 | where: 118 | * `--decoder_path`: the HuggingFace text decoder model path (GPT-like model, e.g. `gpt2`, `gpt2-medium`) whose architecture is used for the decoder. 119 | * `--esm_model_path`: the HuggingFace protein language model path (e.g. `facebook/esm2_t12_35M_UR50D`, `facebook/esm2_t30_150M_UR50D`) whose architecture is used for the PLM. 120 | * `--use_plm`: whether to use a PLM in the encoder. If set, you need to pass the PLM path using `--esm_model_path`. 121 | * `--use_rgcn`: whether to use an RGCN to encode the structure of the protein. At least one of `--use_rgcn` and `--use_plm` must be used.
122 | * `--warmup_esm`: if set, the PLM weights will be initialized from the HuggingFace checkpoint; otherwise they will be initialized randomly. 123 | * `--warmup_gpt`: if set, the decoder weights will be initialized from the HuggingFace checkpoint; otherwise they will be initialized randomly. 124 | * `--data_path`: the path used when preparing the dataset (where the PDB files were downloaded and preprocessed). This requires running `prepare_dataset.py` on both the `train` and `eval` splits. 125 | * `--train_csv_path`: the path to the training CSV file downloaded from the first table. 126 | * `--eval_csv_path`: the path to the validation CSV file downloaded from the first table. 127 | * `--batch_per_device`: the batch size to be used on each GPU. 128 | * `--nb_epochs`: the number of training epochs. 129 | * `--nb_gpus`: the total number of GPUs used during training (used to compute the number of warm-up steps). 130 | * `--gradient_accumulation`: the number of gradient accumulation steps per optimization step. 131 | * `--lr`: the learning rate to be used. 132 | * `--save_model_path`: the path to save the model. 133 | * `--bleu_evaluation`: if set, model selection is based on the best BLEU score on the validation dataset; otherwise the CLM (causal language modeling) loss is used for model selection.
134 | 135 | To train a Prot2Text model on multiple GPUs: 136 | ``` 137 | python -u -m torch.distributed.run --nproc_per_node <number_of_gpus_per_node> --nnodes <number_of_nodes> --node_rank 0 train.py \ 138 | --decoder_path gpt2 \ 139 | --esm_model_path facebook/esm2_t12_35M_UR50D \ 140 | --use_plm \ 141 | --use_rgcn \ 142 | --warmup_esm \ 143 | --warmup_gpt \ 144 | --data_path ./data/dataset/ \ 145 | --train_csv_path ./data/train.csv \ 146 | --eval_csv_path ./data/eval.csv \ 147 | --batch_per_device 4 \ 148 | --nb_epochs 25 \ 149 | --nb_gpus <total_number_of_gpus> \ 150 | --gradient_accumulation 1 \ 151 | --lr 2e-4 \ 152 | --save_model_path ./models/prot2text_base/ \ 153 | --bleu_evaluation 154 | ``` 155 | An example script for distributed training using SLURM can also be found in this repository. 156 | 157 | 158 | #### Evaluation 159 | 160 | To evaluate a Prot2Text model (using `--model_path`) on the test set using a single GPU: 161 | ``` 162 | python evaluate_prot2text.py \ 163 | --model_path ./models/prot2text_base \ 164 | --data_path ./data/dataset/ \ 165 | --split test \ 166 | --csv_path ./data/test.csv \ 167 | --batch_per_device 4 \ 168 | --save_results_path ./results/prot2text_base_results.csv 169 | ``` 170 | 171 | 172 | To evaluate a Prot2Text model on multiple GPUs: 173 | ``` 174 | python -u -m torch.distributed.run --nproc_per_node <number_of_gpus_per_node> --nnodes <number_of_nodes> --node_rank 0 evaluate_prot2text.py \ 175 | --model_path ./models/prot2text_base \ 176 | --data_path ./data/dataset/ \ 177 | --split test \ 178 | --csv_path ./data/test.csv \ 179 | --batch_per_device 4 \ 180 | --save_results_path ./results/prot2text_base_results.csv 181 | ``` 182 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from transformers import DataCollatorForLanguageModeling 2 | from transformers import GPT2Tokenizer 3 | from typing import Optional, Tuple, Union, TYPE_CHECKING, Any, Callable, Dict, List 4 | import torch 5 | from transformers import
GPT2Config, AutoConfig, AutoTokenizer 6 | from transformers import GPT2LMHeadModel, GPT2Model, PretrainedConfig 7 | import transformers 8 | from prot2text_model.Encoder import EncoderRGCN 9 | from prot2text_dataset.torch_geometric_loader import Prot2TextDataset 10 | from prot2text_model.utils import Prot2TextTrainer, CABlock, _GPT2LMHeadModel 11 | from prot2text_model.Model import Prot2TextModel 12 | from prot2text_model.tokenization_prot2text import Prot2TextTokenizer 13 | import torch.nn as nn 14 | from transformers import EvalPrediction, Seq2SeqTrainingArguments 15 | from transformers.trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType 16 | import evaluate 17 | from torch.utils.data.distributed import DistributedSampler 18 | from torch_geometric.data import Dataset, download_url 19 | from torch_geometric.loader import DataListLoader, DataLoader 20 | from torch_geometric.nn import DataParallel 21 | from torch.nn.parallel import DistributedDataParallel 22 | from torch.utils.data.distributed import DistributedSampler 23 | import pandas as pd 24 | from transformers.trainer_utils import PredictionOutput 25 | from transformers.utils import logging 26 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP 27 | import os 28 | import argparse 29 | 30 | argParser = argparse.ArgumentParser() 31 | argParser.add_argument("--decoder_path", type=str, help="path to the gpt2 model to use (hugging face). options: gpt2, gpt2-medium, gpt2-large..") 32 | argParser.add_argument("--esm_model_path", type=str, help="path to esm model to use. example: facebook/esm2_t12_35M_UR50D") 33 | argParser.add_argument("--use_plm", action='store_true', help="True or False. (use or not protein language model in the encoder)") 34 | argParser.add_argument("--use_rgcn", action='store_true', help="True or False. 
(use or not RGCN in the encoder)") 35 | argParser.add_argument("--warmup_esm", action='store_true', help="True or False.") 36 | argParser.add_argument("--warmup_gpt", action='store_true', help="True or False.") 37 | argParser.add_argument("--data_path", type=str, default='./data//dataset/', help="root folder of the data") 38 | argParser.add_argument("--train_csv_path", type=str, default='./data/train.csv', help="csv containing the protein dataset for training") 39 | argParser.add_argument("--eval_csv_path", type=str, default='./data/eval.csv', help="csv containing the protein dataset for evaluation") 40 | argParser.add_argument("--batch_per_device", type=int, default=4, help="batch size for each device") 41 | argParser.add_argument("--nb_epochs", type=int, default=1, help="number of epochs") 42 | argParser.add_argument("--nb_gpus", type=int, default=1, help="number of GPUs") 43 | argParser.add_argument("--gradient_accumulation", default=1, help="gradient accumuluation") 44 | argParser.add_argument("--lr", type=float, default=2e-4, help="learning rate") 45 | argParser.add_argument("--save_model_path", type=str, default='./models/model_test/', help="path to save the model and the checkpoints") 46 | argParser.add_argument("--bleu_evaluation", action='store_true', help="True or False") 47 | 48 | # usage for single GPU: 49 | # python train.py \ 50 | # --decoder_path gpt2 \ 51 | # --esm_model_path facebook/esm2_t12_35M_UR50D \ 52 | # --use_plm \ 53 | # --use_rgcn \ 54 | # --warmup_esm \ 55 | # --warmup_gpt \ 56 | # --data_path ./data/dataset/ \ 57 | # --train_csv_path ./data/train.csv \ 58 | # --eval_csv_path ./data/eval.csv \ 59 | # --batch_per_device 4 \ 60 | # --nb_epochs 25 \ 61 | # --nb_gpus 1 \ 62 | # --gradient_accumulation 64 \ 63 | # --lr 2e-4 \ 64 | # --save_model_path ./models/prot2text_base/ \ 65 | # --bleu_evaluation \ 66 | 67 | 68 | # usage for multiple GPUs: 69 | # python -u -m torch.distributed.run --nproc_per_node --nnodes --node_rank 0 train.py \ 70 | # 
--decoder_path gpt2 \ 71 | # --esm_model_path facebook/esm2_t12_35M_UR50D \ 72 | # --use_plm \ 73 | # --use_rgcn \ 74 | # --warmup_esm \ 75 | # --warmup_gpt \ 76 | # --data_path ./data/dataset/ \ 77 | # --train_csv_path ./data/train.csv \ 78 | # --eval_csv_path ./data/eval.csv \ 79 | # --batch_per_device 4 \ 80 | # --nb_epochs 25 \ 81 | # --nb_gpus \ 82 | # --gradient_accumulation 1 \ 83 | # --lr 2e-4 \ 84 | # --save_model_path ./models/prot2text_base/ \ 85 | # --bleu_evaluation \ 86 | 87 | args = argParser.parse_args() 88 | 89 | if args.decoder_path is None: 90 | raise ValueError( 91 | "You need to specify a GPT like model path that is compatible with Hugging Face. Please pass the path of a Hugging Face decoder model using --decoder_path." 92 | ) 93 | if args.use_plm and args.esm_model_path is None: 94 | raise ValueError( 95 | "You want to use protein language model in the encoder, however you did not specify any PLM path. Please pass the path of a Hugging Face PLM using --esm_model_path." 96 | ) 97 | if not args.use_plm and not args.use_rgcn: 98 | raise ValueError( 99 | "You did not choose which type of encoder to use. Please set --use_plm to train a PLM encoder, --use_rgcn for an RGCN encoder or both for Prot2Text architecture." 100 | ) 101 | if not args.use_plm and args.warmup_esm: 102 | raise ValueError( 103 | "You chose to warmup the protein language model however you chose not to use a PLM in the encoder. Please remove --warmup_esm or use --use_plm and specify a PLM using --esm_model_path." 
104 | ) 105 | 106 | model_name = args.decoder_path 107 | tokenizer = Prot2TextTokenizer.from_pretrained(model_name) 108 | SPECIAL_TOKEN = '<|graph_token|>' 109 | tokenizer.pad_token = tokenizer.eos_token 110 | tokenizer.pad_token = "<|endoftext|>" 111 | tokenizer.add_tokens([SPECIAL_TOKEN]) 112 | SPECIAL_TOKEN = '<|stop_token|>' 113 | tokenizer.add_tokens([SPECIAL_TOKEN]) 114 | tokenizer.eos_token = '<|stop_token|>' 115 | tokenizer.eos_token_id = 50258 116 | tokenizer.bos_token_id = 50257 117 | 118 | esm_tokenizer = AutoTokenizer.from_pretrained(args.esm_model_path) 119 | 120 | config_model = PretrainedConfig( 121 | _name_or_path='prot2text', 122 | prot2text_version="1.1", 123 | cross_esm_graph=args.use_plm & args.use_rgcn, 124 | esm=args.use_plm, 125 | esm_model_name=args.esm_model_path, 126 | gpt_model_name=model_name, 127 | rgcn=args.use_rgcn, 128 | rgcn_input_dim = 67, 129 | rgcn_n_layers = 6, 130 | decoder_start_token_id = 50257, 131 | eos_token_id = 50258, 132 | max_new_tokens = 256, 133 | no_repeat_ngram_size = 3, 134 | early_stopping = True, 135 | length_penalty = 2.0, 136 | num_beams = 1, 137 | pad_token_id = 50256, 138 | bos_token_id = 50257 139 | ) 140 | esm_config = AutoConfig.from_pretrained(config_model.esm_model_name).to_dict() 141 | config_model.esm_config = esm_config 142 | gpt_config = GPT2Config.from_pretrained(config_model.gpt_model_name, 143 | _name_or_path= config_model.gpt_model_name, 144 | is_encoder_decoder=True, 145 | use_cache=False, 146 | add_cross_attention=True, 147 | bos_token_id=config_model.bos_token_id, 148 | decoder_start_token_id=config_model.decoder_start_token_id, 149 | eos_token_id=config_model.eos_token_id, 150 | max_new_tokens=config_model.max_new_tokens, 151 | pad_token_id=50256, 152 | vocab_size=50259, 153 | num_beams=1, 154 | max_length=256, 155 | min_length=1) 156 | gpt_config.max_new_tokens = 256 157 | gpt_config.prot2text_version = config_model.prot2text_version 158 | config_model.gpt_config = gpt_config.to_dict() 159 
| 160 | model = Prot2TextModel(config=config_model) 161 | if args.warmup_esm and args.warmup_gpt: 162 | model.warm_up(gpt_model=args.decoder_path, esm_model=args.esm_model_path) 163 | elif args.warmup_esm: 164 | model.warm_up(esm_model=args.esm_model_path) 165 | elif args.warmup_gpt: 166 | model.warm_up(gpt_model=args.decoder_path) 167 | 168 | train_dataset = Prot2TextDataset(root=args.data_path, 169 | tokenizer=tokenizer, 170 | file_path=args.train_csv_path, 171 | block_size=256, 172 | split='train', 173 | esmtokenizer=esm_tokenizer) 174 | print('train set loaded') 175 | eval_dataset = Prot2TextDataset(root=args.data_path, 176 | tokenizer=tokenizer, 177 | file_path=args.eval_csv_path, 178 | block_size=256, 179 | split='eval', 180 | esmtokenizer=esm_tokenizer) 181 | print('eval set loaded') 182 | 183 | num_gpus = int(args.nb_gpus) 184 | train_size = len(train_dataset) 185 | num_epochs = int(args.nb_epochs) 186 | grad_accumulation = int(args.gradient_accumulation) 187 | batch_size = int(args.batch_per_device) 188 | warmup = 0.06 * num_epochs * train_size / (num_gpus * batch_size * grad_accumulation) 189 | model_save_name = args.save_model_path 190 | lr = args.lr 191 | 192 | def compute_metrics(pred): 193 | labels_ids = pred.label_ids 194 | pred_ids = pred.predictions 195 | 196 | pred_ids[pred_ids == -100] = tokenizer.eos_token_id 197 | pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) 198 | labels_ids[labels_ids == -100] = tokenizer.eos_token_id 199 | label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) 200 | 201 | try: 202 | res = bleu.compute(predictions=pred_str, references=label_str) 203 | return {'eval_bleu': res['bleu']} 204 | except: 205 | return {'eval_bleu': 0.0} 206 | 207 | if args.bleu_evaluation: 208 | prediction_loss_only = False 209 | metric_for_best_model = 'eval_bleu' 210 | greater_is_better = True 211 | predict_with_generate = True 212 | load_best_model_at_end = True 213 | bleu = evaluate.load("bleu") 214 | 
do_predict = True 215 | else: 216 | compute_metrics = None 217 | prediction_loss_only = True 218 | metric_for_best_model = 'loss' 219 | greater_is_better = False 220 | predict_with_generate = False 221 | load_best_model_at_end = False 222 | do_predict = False 223 | 224 | training_args = Seq2SeqTrainingArguments( 225 | output_dir=model_save_name, 226 | overwrite_output_dir=True, 227 | num_train_epochs=num_epochs, 228 | per_device_train_batch_size=batch_size, 229 | per_device_eval_batch_size=batch_size, 230 | gradient_accumulation_steps=grad_accumulation, 231 | eval_accumulation_steps=None, 232 | evaluation_strategy=IntervalStrategy.STEPS, 233 | save_steps=500, 234 | eval_steps=500, 235 | logging_steps=500, 236 | save_total_limit=15, 237 | weight_decay=0.1, 238 | warmup_steps=warmup, 239 | lr_scheduler_type="cosine", 240 | learning_rate=lr, 241 | do_train=True, 242 | do_eval=True, 243 | do_predict=do_predict, 244 | prediction_loss_only=prediction_loss_only, 245 | metric_for_best_model=metric_for_best_model, 246 | greater_is_better=greater_is_better, 247 | predict_with_generate=predict_with_generate, 248 | load_best_model_at_end=load_best_model_at_end, 249 | ) 250 | 251 | trainer = Prot2TextTrainer( 252 | model=model, 253 | args=training_args, 254 | data_collator=None, 255 | train_dataset=train_dataset, 256 | eval_dataset=eval_dataset, 257 | tokenizer=tokenizer, 258 | compute_metrics=compute_metrics, 259 | ) 260 | 261 | trainer.train() 262 | 263 | if torch.distributed.is_initialized(): 264 | if torch.distributed.get_rank()==0: 265 | model.save_pretrained(os.path.join(model_save_name,'model/')) 266 | tokenizer.save_pretrained(os.path.join(model_save_name,'model/')) 267 | else: 268 | model.save_pretrained(os.path.join(model_save_name,'model/')) 269 | tokenizer.save_pretrained(os.path.join(model_save_name,'model/')) -------------------------------------------------------------------------------- /prot2text_dataset/conversion.py: 
-------------------------------------------------------------------------------- 1 | """Utilities for converting Graphein Networks to Geometric Deep Learning formats. 2 | """ 3 | # %% 4 | # Graphein 5 | # Author: Kexin Huang, Arian Jamasb 6 | # License: MIT 7 | # Project Website: https://github.com/a-r-j/graphein 8 | # Code Repository: https://github.com/a-r-j/graphein 9 | from __future__ import annotations 10 | 11 | from typing import List, Optional 12 | 13 | import networkx as nx 14 | import numpy as np 15 | import torch 16 | 17 | from graphein.utils.dependencies import import_message 18 | 19 | try: 20 | import torch_geometric 21 | from torch_geometric.data import Data 22 | except ImportError: 23 | import_message( 24 | submodule="graphein.ml.conversion", 25 | package="torch_geometric", 26 | pip_install=True, 27 | conda_channel="rusty1s", 28 | ) 29 | 30 | try: 31 | import dgl 32 | except ImportError: 33 | import_message( 34 | submodule="graphein.ml.conversion", 35 | package="dgl", 36 | pip_install=True, 37 | conda_channel="dglteam", 38 | ) 39 | 40 | try: 41 | import jax.numpy as jnp 42 | except ImportError: 43 | import_message( 44 | submodule="graphein.ml.conversion", 45 | package="jax", 46 | pip_install=True, 47 | conda_channel="conda-forge", 48 | ) 49 | try: 50 | import jraph 51 | except ImportError: 52 | import_message( 53 | submodule="graphein.ml.conversion", 54 | package="jraph", 55 | pip_install=True, 56 | conda_channel="conda-forge", 57 | ) 58 | 59 | 60 | SUPPORTED_FORMATS = ["nx", "pyg", "dgl", "jraph"] 61 | """Supported conversion formats. 
62 | 63 | ``"nx"``: NetworkX graph 64 | 65 | ``"pyg"``: PyTorch Geometric Data object 66 | 67 | ``"dgl"``: DGL graph 68 | 69 | ``"Jraph"``: Jraph GraphsTuple 70 | """ 71 | 72 | SUPPORTED_VERBOSITY = ["gnn", "default", "all_info"] 73 | """Supported verbosity levels for preserving graph features in conversion.""" 74 | 75 | 76 | class GraphFormatConvertor: 77 | """ 78 | Provides conversion utilities between NetworkX Graphs and geometric deep learning library destination formats. 79 | Currently, we provide support for converstion from ``nx.Graph`` to ``dgl.DGLGraph`` and ``pytorch_geometric.Data``. Supported conversion 80 | formats can be retrieved from :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`. 81 | 82 | :param src_format: The type of graph you'd like to convert from. Supported formats are available in :const:`~graphein.ml.conversion.SUPPORTED_FORMATS` 83 | :type src_format: Literal["nx", "pyg", "dgl", "jraph"] 84 | :param dst_format: The type of graph format you'd like to convert to. 
Supported formats are available in: 85 | ``graphein.ml.conversion.SUPPORTED_FORMATS`` 86 | :type dst_format: Literal["nx", "pyg", "dgl", "jraph"] 87 | :param verbose: Select from ``"gnn"``, ``"default"``, ``"all_info"`` to determine how much information is preserved (features) 88 | as some are unsupported by various downstream frameworks 89 | :type verbose: graphein.ml.conversion.SUPPORTED_VERBOSITY 90 | :param columns: List of columns in the node features to retain 91 | :type columns: List[str], optional 92 | """ 93 | 94 | def __init__( 95 | self, 96 | src_format: str, 97 | dst_format: str, 98 | verbose: SUPPORTED_VERBOSITY = "gnn", 99 | columns: Optional[List[str]] = None, 100 | ): 101 | if (src_format not in SUPPORTED_FORMATS) or ( 102 | dst_format not in SUPPORTED_FORMATS 103 | ): 104 | raise ValueError( 105 | "Please specify from supported format, " 106 | + "/".join(SUPPORTED_FORMATS) 107 | ) 108 | self.src_format = src_format 109 | self.dst_format = dst_format 110 | 111 | # supported_verbose_format = ["gnn", "default", "all_info"] 112 | if (columns is None) and (verbose not in SUPPORTED_VERBOSITY): 113 | raise ValueError( 114 | "Please specify the supported verbose mode (" 115 | + "/".join(SUPPORTED_VERBOSITY) 116 | + ") or specify column names!" 
117 | ) 118 | 119 | if columns is None: 120 | if verbose == "gnn": 121 | columns = [ 122 | "edge_index", 123 | "coords", 124 | "dist_mat", 125 | "name", 126 | "node_id", 127 | ] 128 | elif verbose == "default": 129 | columns = [ 130 | "b_factor", 131 | "chain_id", 132 | "coords", 133 | "dist_mat", 134 | "edge_index", 135 | "kind", 136 | "name", 137 | "node_id", 138 | "residue_name", 139 | ] 140 | elif verbose == "all_info": 141 | columns = [ 142 | "atom_type", 143 | "b_factor", 144 | "chain_id", 145 | "chain_ids", 146 | "config", 147 | "coords", 148 | "dist_mat", 149 | "edge_index", 150 | "element_symbol", 151 | "kind", 152 | "name", 153 | "node_id", 154 | "node_type", 155 | "pdb_df", 156 | "raw_pdb_df", 157 | "residue_name", 158 | "residue_number", 159 | "rgroup_df", 160 | "sequence_A", 161 | "sequence_B", 162 | ] 163 | self.columns = columns 164 | 165 | self.type2form = { 166 | "atom_type": "str", 167 | "b_factor": "float", 168 | "chain_id": "str", 169 | "coords": "np.array", 170 | "dist_mat": "np.array", 171 | "element_symbol": "str", 172 | "node_id": "str", 173 | "residue_name": "str", 174 | "residue_number": "int", 175 | "edge_index": "torch.tensor", 176 | "kind": "str", 177 | } 178 | 179 | def convert_nx_to_dgl(self, G: nx.Graph) -> dgl.DGLGraph: 180 | """ 181 | Converts ``NetworkX`` graph to ``DGL`` 182 | 183 | :param G: ``nx.Graph`` to convert to ``DGLGraph`` 184 | :type G: nx.Graph 185 | :return: ``DGLGraph`` object version of input ``NetworkX`` graph 186 | :rtype: dgl.DGLGraph 187 | """ 188 | g = dgl.DGLGraph() 189 | node_id = list(G.nodes()) 190 | G = nx.convert_node_labels_to_integers(G) 191 | 192 | ## add node level feat 193 | 194 | node_dict = {} 195 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)): 196 | for key, value in feat_dict.items(): 197 | if str(key) in self.columns: 198 | node_dict[str(key)] = ( 199 | [value] if i == 0 else node_dict[str(key)] + [value] 200 | ) 201 | 202 | string_dict = {} 203 | node_dict_transformed = {} 204 | for 
i, j in node_dict.items(): 205 | if i == "coords": 206 | node_dict_transformed[i] = torch.Tensor(np.asarray(j)).type( 207 | "torch.FloatTensor" 208 | ) 209 | elif i == "dist_mat": 210 | node_dict_transformed[i] = torch.Tensor( 211 | np.asarray(j[0].values) 212 | ).type("torch.FloatTensor") 213 | elif self.type2form[i] == "str": 214 | string_dict[i] = j 215 | elif self.type2form[i] in ["float", "int"]: 216 | node_dict_transformed[i] = torch.Tensor(np.array(j)) 217 | g.add_nodes( 218 | len(node_id), 219 | node_dict_transformed, 220 | ) 221 | 222 | edge_dict = {} 223 | edge_index = torch.LongTensor(list(G.edges)).t().contiguous() 224 | 225 | # add edge level features 226 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): 227 | for key, value in feat_dict.items(): 228 | if str(key) in self.columns: 229 | edge_dict[str(key)] = ( 230 | list(value) 231 | if i == 0 232 | else edge_dict[str(key)] + list(value) 233 | ) 234 | 235 | edge_transform_dict = {} 236 | for i, j in node_dict.items(): 237 | if self.type2form[i] == "str": 238 | string_dict[i] = j 239 | elif self.type2form[i] in ["float", "int"]: 240 | edge_transform_dict[i] = torch.Tensor(np.array(j)) 241 | g.add_edges(edge_index[0], edge_index[1], edge_transform_dict) 242 | 243 | # add graph level features 244 | graph_dict = { 245 | str(feat_name): [G.graph[feat_name]] 246 | for feat_name in G.graph 247 | if str(feat_name) in self.columns 248 | } 249 | 250 | return g 251 | 252 | def convert_nx_to_pyg(self, G: nx.Graph) -> Data: 253 | """ 254 | Converts ``NetworkX`` graph to ``pytorch_geometric.data.Data`` object. Requires ``PyTorch Geometric`` (https://pytorch-geometric.readthedocs.io/en/latest/) to be installed. 
255 | 256 | :param G: ``nx.Graph`` to convert to PyTorch Geometric ``Data`` object 257 | :type G: nx.Graph 258 | :return: ``Data`` object containing networkx graph data 259 | :rtype: pytorch_geometric.data.Data 260 | """ 261 | 262 | # Initialise dict used to construct Data object & Assign node ids as a feature 263 | data = {"node_id": list(G.nodes())} 264 | G = nx.convert_node_labels_to_integers(G) 265 | 266 | # Construct Edge Index 267 | edge_index = torch.LongTensor(list(G.edges)).t().contiguous() 268 | 269 | # Add node features 270 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)): 271 | for key, value in feat_dict.items(): 272 | if str(key) in self.columns: 273 | data[str(key)] = ( 274 | [value] if i == 0 else data[str(key)] + [value] 275 | ) 276 | 277 | # Add edge features 278 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): 279 | for key, value in feat_dict.items(): 280 | if str(key) in self.columns: 281 | data[str(key)] = ( 282 | list(value) if i == 0 else data[str(key)] + list(value) 283 | ) 284 | 285 | # Add graph-level features 286 | for feat_name in G.graph: 287 | if str(feat_name) in self.columns: 288 | data[str(feat_name)] = [G.graph[feat_name]] 289 | 290 | if "edge_index" in self.columns: 291 | data["edge_index"] = edge_index.view(2, -1) 292 | 293 | data = Data.from_dict(data) 294 | data.num_nodes = G.number_of_nodes() 295 | return data 296 | 297 | @staticmethod 298 | def convert_nx_to_nx(G: nx.Graph) -> nx.Graph: 299 | """ 300 | Converts NetworkX graph (``nx.Graph``) to NetworkX graph (``nx.Graph``) object. Redundant - returns itself. 301 | 302 | :param G: NetworkX Graph 303 | :type G: nx.Graph 304 | :return: NetworkX Graph 305 | :rtype: nx.Graph 306 | """ 307 | return G 308 | 309 | @staticmethod 310 | def convert_dgl_to_nx(G: dgl.DGLGraph) -> nx.Graph: 311 | """ 312 | Converts a DGL Graph (``dgl.DGLGraph``) to a NetworkX (``nx.Graph``) object. Preserves node and edge attributes. 
313 | 314 | :param G: ``dgl.DGLGraph`` to convert to ``NetworkX`` graph. 315 | :type G: dgl.DGLGraph 316 | :return: NetworkX graph object. 317 | :rtype: nx.Graph 318 | """ 319 | node_attrs = G.node_attr_schemes().keys() 320 | edge_attrs = G.edge_attr_schemes().keys() 321 | return dgl.to_networkx(G, node_attrs, edge_attrs) 322 | 323 | @staticmethod 324 | def convert_pyg_to_nx(G: Data) -> nx.Graph: 325 | """Converts PyTorch Geometric ``Data`` object to NetworkX graph (``nx.Graph``). 326 | 327 | :param G: Pytorch Geometric Data. 328 | :type G: torch_geometric.data.Data 329 | :returns: NetworkX graph. 330 | :rtype: nx.Graph 331 | """ 332 | return torch_geometric.utils.to_networkx(G) 333 | 334 | def convert_nx_to_jraph(self, G: nx.Graph) -> jraph.GraphsTuple: 335 | """Converts NetworkX graph (``nx.Graph``) to Jraph GraphsTuple graph. Requires ``jax`` and ``Jraph``. 336 | 337 | :param G: Networkx graph to convert. 338 | :type G: nx.Graph 339 | :return: Jraph GraphsTuple graph. 340 | :rtype: jraph.GraphsTuple 341 | """ 342 | G = nx.convert_node_labels_to_integers(G) 343 | 344 | n_node = len(G) 345 | n_edge = G.number_of_edges() 346 | edge_list = list(G.edges()) 347 | senders, receivers = zip(*edge_list) 348 | senders, receivers = jnp.array(senders), jnp.array(receivers) 349 | 350 | # Add node features 351 | node_features = {} 352 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)): 353 | for key, value in feat_dict.items(): 354 | if str(key) in self.columns: 355 | # node_features[str(key)] = ( 356 | # [value] 357 | # if i == 0 358 | # else node_features[str(key)] + [value] 359 | # ) 360 | feat = ( 361 | [value] 362 | if i == 0 363 | else node_features[str(key)] + [value] 364 | ) 365 | try: 366 | feat = torch.tensor(feat) 367 | node_features[str(key)] = feat 368 | except TypeError: 369 | node_features[str(key)] = feat 370 | 371 | # Add edge features 372 | edge_features = {} 373 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): 374 | for key, value in 
feat_dict.items(): 375 | if str(key) in self.columns: 376 | edge_features[str(key)] = ( 377 | list(value) 378 | if i == 0 379 | else edge_features[str(key)] + list(value) 380 | ) 381 | 382 | # Add graph features 383 | global_context = { 384 | str(feat_name): [G.graph[feat_name]] 385 | for feat_name in G.graph 386 | if str(feat_name) in self.columns 387 | } 388 | 389 | return jraph.GraphsTuple( 390 | nodes=node_features, 391 | senders=senders, 392 | receivers=receivers, 393 | edges=edge_features, 394 | n_node=n_node, 395 | n_edge=n_edge, 396 | globals=global_context, 397 | ) 398 | 399 | def __call__(self, G: nx.Graph): 400 | nx_g = eval("self.convert_" + self.src_format + "_to_nx(G)") 401 | dst_g = eval("self.convert_nx_to_" + self.dst_format + "(nx_g)") 402 | return dst_g 403 | 404 | 405 | # def convert_nx_to_pyg_data(G: nx.Graph) -> Data: 406 | # # Initialise dict used to construct Data object 407 | # data = {"node_id": list(G.nodes())} 408 | 409 | # G = nx.convert_node_labels_to_integers(G) 410 | 411 | # # Construct Edge Index 412 | # edge_index = torch.LongTensor(list(G.edges)).t().contiguous() 413 | 414 | # # Add node features 415 | # for i, (_, feat_dict) in enumerate(G.nodes(data=True)): 416 | # for key, value in feat_dict.items(): 417 | # data[str(key)] = [value] if i == 0 else data[str(key)] + [value] 418 | 419 | # # Add edge features 420 | # for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): 421 | # for key, value in feat_dict.items(): 422 | # data[str(key)] = ( 423 | # list(value) if i == 0 else data[str(key)] + list(value) 424 | # ) 425 | 426 | # # Add graph-level features 427 | # for feat_name in G.graph: 428 | # data[str(feat_name)] = [G.graph[feat_name]] 429 | 430 | # data["edge_index"] = edge_index.view(2, -1) 431 | # data = Data.from_dict(data) 432 | # data.num_nodes = G.number_of_nodes() 433 | 434 | # return data 435 | def convert_nx_to_pyg_data(G: nx.Graph) -> Data: 436 | # Initialise dict used to construct Data object 437 | data = 
{"node_id": list(G.nodes())} 438 | 439 | G = nx.convert_node_labels_to_integers(G) 440 | 441 | # Construct Edge Index 442 | edge_index = torch.LongTensor(list(G.edges)).t().contiguous() 443 | 444 | # Add node features 445 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)): 446 | for key, value in feat_dict.items(): 447 | data[str(key)] = [value] if i == 0 else data[str(key)] + [value] 448 | 449 | 450 | # Add edge features 451 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): 452 | for key, value in feat_dict.items(): 453 | if key == 'distance': 454 | data[str(key)] = ( 455 | [value] if i == 0 else data[str(key)] + [value] 456 | ) 457 | else: 458 | data[str(key)] = ( 459 | [list(value)] if i == 0 else data[str(key)] + [list(value)] 460 | ) 461 | 462 | # Add graph-level features 463 | for feat_name in G.graph: 464 | data[str(feat_name)] = [G.graph[feat_name]] 465 | 466 | data["edge_index"] = edge_index.view(2, -1) 467 | data = Data.from_dict(data) 468 | data.num_nodes = G.number_of_nodes() 469 | 470 | return data 471 | -------------------------------------------------------------------------------- /prot2text_model/Model.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Config, AutoConfig, AutoTokenizer, GPT2Config 2 | from transformers import GPT2LMHeadModel, GPT2Model, PretrainedConfig, PreTrainedModel 3 | import transformers 4 | from .Encoder import EncoderRGCN 5 | from typing import Optional, Tuple, Union, Callable 6 | import torch 7 | import torch.nn as nn 8 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions 9 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP, GPT2PreTrainedModel 10 | from transformers.modeling_utils import PreTrainedModel, PretrainedConfig 11 | from .utils import CABlock, _GPT2LMHeadModel 12 | import os 13 | import sys 14 | import numpy as np 15 | from transformers.generation.configuration_utils import 
GenerationConfig 16 | from transformers.generation.logits_process import LogitsProcessorList 17 | from transformers.generation.stopping_criteria import StoppingCriteriaList 18 | sys.path.append('../prot2text_dataset') 19 | from prot2text_dataset.pdb2graph import PDB2Graph, download_alphafold_structure 20 | from prot2text_dataset.graphs import * 21 | from prot2text_dataset.utils_dataset import * 22 | from graphein.protein.config import ProteinGraphConfig, DSSPConfig 23 | from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale, hydrogen_bond_acceptor, hydrogen_bond_donor 24 | from graphein.protein.features.nodes.dssp import phi, psi, asa, rsa, secondary_structure 25 | from graphein.protein.edges.distance import (add_peptide_bonds, 26 | add_hydrogen_bond_interactions, 27 | add_disulfide_interactions, 28 | add_ionic_interactions, 29 | add_delaunay_triangulation, 30 | add_distance_threshold, 31 | add_sequence_distance_edges, 32 | add_k_nn_edges) 33 | 34 | class Prot2TextModel(PreTrainedModel): 35 | config_class = PretrainedConfig 36 | _keys_to_ignore_on_load_missing = [r"transformer"] 37 | base_model_prefix = "decoder" 38 | def __init__(self, config): 39 | super().__init__(config) 40 | 41 | self.gpt_config = GPT2Config.from_dict(config.gpt_config) 42 | 43 | # if we are using RGCN to encode the protein's structure, define the RGCN encoder 44 | if config.rgcn: 45 | self.encoder = EncoderRGCN(input_dim=config.rgcn_input_dim, hidden_dim=self.gpt_config.n_embd, n_layers=config.rgcn_n_layers, emb_dim=self.gpt_config.n_embd, prot2text_version=self.config.prot2text_version) 46 | 47 | # define the GPT2 decoder 48 | self.decoder = _GPT2LMHeadModel(self.gpt_config) 49 | 50 | # if using ESM to encode protein's sequence, define the ESM layer, the Projection layer and the fusion layer 51 | if config.esm: 52 | self.esm_config = PretrainedConfig.from_dict(config.esm_config) 53 | self.esm = transformers.EsmModel(self.esm_config) 
54 | self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd) 55 | if config.cross_esm_graph and config.rgcn: 56 | self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)]) 57 | self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon) 58 | 59 | self.config = config 60 | 61 | 62 | def get_encoder(self): 63 | return self.encoder 64 | 65 | def get_decoder(self): 66 | return self.decoder 67 | 68 | def get_input_embeddings(self): 69 | if hasattr(self, "transformer"): 70 | return self.transformer.wte 71 | return self.decoder.transformer.wte 72 | 73 | def warm_up(self, gpt_model=None, esm_model=None): 74 | if esm_model is not None: 75 | self.esm = transformers.EsmModel.from_pretrained(esm_model) 76 | if gpt_model is not None: 77 | self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False) 78 | self.decoder.resize_token_embeddings(self.gpt_config.vocab_size) 79 | self.decoder.config = self.gpt_config 80 | 81 | 82 | def forward(self, 83 | encoder_input_ids: Optional[torch.LongTensor] = None, 84 | edge_index: Optional[torch.LongTensor] = None, 85 | batch: Optional[torch.LongTensor] = None, 86 | x: Optional[torch.FloatTensor] = None, 87 | edge_type: Optional[torch.LongTensor] = None, 88 | decoder_input_ids: Optional[torch.LongTensor] = None, 89 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 90 | past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None, 91 | decoder_attention_mask: Optional[torch.FloatTensor] = None, 92 | attention_mask: Optional[torch.FloatTensor] = None, 93 | token_type_ids: Optional[torch.LongTensor] = None, 94 | position_ids: Optional[torch.LongTensor] = None, 95 | head_mask: Optional[torch.FloatTensor] = None, 96 | inputs_embeds: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.Tensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | labels: 
Optional[torch.LongTensor] = None, 100 | use_cache: Optional[bool] = None, 101 | output_attentions: Optional[bool] = None, 102 | output_hidden_states: Optional[bool] = None, 103 | return_dict: Optional[bool] = None, 104 | get_graph_emb: Optional[bool] = False, 105 | **delete_args, 106 | ): 107 | use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache 108 | return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict 109 | 110 | 111 | if decoder_input_ids is not None and len(decoder_input_ids.size()) == 3: 112 | decoder_input_ids = decoder_input_ids.squeeze(0) 113 | 114 | if x is not None and self.config.rgcn: 115 | graph_emb = self.encoder(x, edge_index, edge_type, batch) 116 | graph_mask = None 117 | 118 | if self.config.esm: 119 | if self.config.prot2text_version=='1.0': 120 | if encoder_input_ids.size()[1] != 1021: 121 | raise ValueError("For this version of the model you need to PAD/Truncate the amino acid sequence for the ESM model to 1021") 122 | 123 | esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=return_dict).last_hidden_state 124 | esm_emb = self.to_embedding(esm_emb) 125 | if not self.config.cross_esm_graph and self.config.rgcn: 126 | graph_emb = torch.cat((graph_emb, esm_emb), dim=1) 127 | t_add = torch.ones((attention_mask.size(0), 1)).to(attention_mask.get_device()) 128 | attention_mask = torch.cat((t_add, attention_mask), dim=1) 129 | elif self.config.cross_esm_graph and self.config.rgcn: 130 | if past_key_values_graph_esm is None: 131 | past_length = 0 132 | past_key_values_graph_esm = tuple([None] * len(self.h)) 133 | else: 134 | past_length = past_key_values_graph_esm[0][0].size(-2) 135 | output_shape = esm_emb.size() 136 | 137 | all_self_attentions = () if output_attentions else None 138 | all_cross_attentions = () if output_attentions and self.gpt_config.add_cross_attention else None 139 | all_hidden_states = () if output_hidden_states else None 140 | 
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values_graph_esm)): 141 | outputs = block( 142 | esm_emb, 143 | layer_past=layer_past, 144 | attention_mask=attention_mask, 145 | encoder_hidden_states=graph_emb, 146 | encoder_attention_mask=graph_mask, 147 | use_cache=use_cache, 148 | output_attentions=False, 149 | ) 150 | esm_emb = outputs[0] 151 | 152 | esm_emb = self.ln_f(esm_emb) 153 | esm_emb = esm_emb.view(output_shape) 154 | graph_emb = esm_emb 155 | else: 156 | graph_emb = esm_emb 157 | else: 158 | attention_mask = None 159 | if self.config.prot2text_version=='1.0': 160 | attention_mask = None 161 | if get_graph_emb: 162 | return graph_emb 163 | 164 | transformer_outputs = self.decoder(input_ids=decoder_input_ids, 165 | past_key_values=past_key_values, 166 | attention_mask=decoder_attention_mask, 167 | token_type_ids=token_type_ids, 168 | position_ids=position_ids, 169 | head_mask=head_mask, 170 | inputs_embeds=inputs_embeds, 171 | encoder_hidden_states=graph_emb, 172 | encoder_attention_mask=attention_mask, 173 | labels=labels, 174 | use_cache=use_cache, 175 | output_attentions=output_attentions, 176 | output_hidden_states=output_hidden_states, 177 | return_dict=return_dict, 178 | ) 179 | 180 | return transformer_outputs 181 | 182 | @torch.no_grad() 183 | def generate_protein_description(self, 184 | protein_pdbID=None, 185 | protein_sequence=None, 186 | edge_index: Optional[torch.LongTensor] = None, 187 | x: Optional[torch.FloatTensor] = None, 188 | edge_type: Optional[torch.LongTensor] = None, 189 | tokenizer=None, 190 | device='cpu' 191 | ): 192 | 193 | if self.config.esm and not self.config.rgcn and protein_sequence==None: 194 | raise ValueError( 195 | "The model you are trying to use is based only on protein sequence, please provide an amino-acid protein_sequence" 196 | ) 197 | if self.config.rgcn and protein_pdbID==None and (x==None or edge_index==None or edge_type==None): 198 | raise ValueError( 199 | "The model you are trying to use is 
based on protein structure, please provide a AlphaFold ID (you must have to have internet connection using protein_pdbID, or provide the triplet inputs: x (node features), edge_index and edge_type" 200 | ) 201 | if self.config.esm: 202 | esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name) 203 | 204 | if protein_pdbID==None and protein_sequence==None: 205 | raise ValueError( 206 | "you need to provide either a protein AlphaFold Id or an amino-acid sequence" 207 | ) 208 | 209 | if protein_pdbID!=None: 210 | config = {"node_metadata_functions": [amino_acid_one_hot, 211 | expasy_protein_scale, 212 | meiler_embedding, 213 | hydrogen_bond_acceptor, hydrogen_bond_donor 214 | ], 215 | "edge_construction_functions": [add_peptide_bonds, 216 | add_hydrogen_bond_interactions, 217 | partial(add_distance_threshold, long_interaction_threshold=3, threshold=10.),], 218 | "graph_metadata_functions":[asa,phi, psi, secondary_structure, rsa], 219 | "dssp_config": DSSPConfig()} 220 | config = ProteinGraphConfig(**config) 221 | 222 | PATH_TO_DATA = f"./.tmp/pdb/pdb" 223 | OUTPUT_FOLDER = f"./.tmp/pdb/raw" 224 | save_dir = f"./.tmp/pdb/" 225 | isExist = os.path.exists(PATH_TO_DATA) 226 | if not isExist: 227 | os.makedirs(PATH_TO_DATA) 228 | isExist = os.path.exists(OUTPUT_FOLDER) 229 | if not isExist: 230 | os.makedirs(OUTPUT_FOLDER) 231 | isExist = os.path.exists(save_dir+'processed') 232 | if not isExist: 233 | os.makedirs(save_dir+'processed') 234 | 235 | structure_filename = download_alphafold_structure(uniprot_id=protein_pdbID, out_dir=PATH_TO_DATA) 236 | if structure_filename is None: 237 | raise ValueError("Error! 
the ID does not exist in AlphaFoldDB or you do not have internet connection") 238 | graph_filename = structure_filename.split('/') 239 | graph_filename[-2] = 'raw' 240 | graph_filename[-1] = graph_filename[-1].replace('.pdb', '.pt') 241 | graph_filename = '/'.join(graph_filename) 242 | process_filename = structure_filename.split('/') 243 | process_filename[-2] = 'processed' 244 | process_filename[-1] = process_filename[-1].replace('.pdb', '.pt') 245 | process_filename = '/'.join(process_filename) 246 | try: 247 | gpdb = PDB2Graph(root = PATH_TO_DATA, output_folder = OUTPUT_FOLDER, config=config, n_processors=1).create_pyg_graph(structure_filename) 248 | seq = esmtokenizer(gpdb.sequence, add_special_tokens=True, truncation=True, max_length=1021, padding='max_length',return_tensors="pt") # 249 | torch.save(gpdb, graph_filename) 250 | gpdb.edge_type = [np.array(gpdb.edge_type.transpose(0,1))] 251 | gpdb.encoder_input_ids = seq['input_ids'] 252 | gpdb.attention_mask = seq['attention_mask'] 253 | torch.save(gpdb, process_filename) 254 | except: 255 | os.remove(structure_filename) 256 | raise ValueError('creating graphs did not work, probably the pdb file of alphaFold is damaged') 257 | 258 | self.eval() 259 | inputs = gpdb 260 | inputs = inputs.to_dict() 261 | 262 | inputs['edge_type'] = torch.cat([torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0) 263 | inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1) 264 | for key in ['num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates']: 265 | inputs.pop(key) 266 | inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone() 267 | inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id 268 | inputs["decoder_attention_mask"] = torch.ones(inputs['decoder_input_ids'].shape[0], 1) 269 | self.to(device) 270 | inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()} 271 | encoder_state = 
dict() 272 | encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True) 273 | encoder_state['attentions'] = inputs['attention_mask'] 274 | for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']: 275 | inputs.pop(key) 276 | tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'], 277 | encoder_outputs=encoder_state, 278 | use_cache=True, 279 | output_attentions=True, 280 | output_scores=True, 281 | return_dict_in_generate=True, 282 | encoder_attention_mask=inputs['attention_mask'], 283 | length_penalty=1.0, 284 | no_repeat_ngram_size=None, 285 | early_stopping=False, 286 | num_beams=1) 287 | 288 | generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True) 289 | print(tok_ids.get('scores')[0].size()) 290 | m = torch.nn.Softmax() 291 | att_w = [] 292 | print(len(gpdb.sequence[0])) 293 | score = 0 294 | for i in range(len(tok_ids.get('cross_attentions'))): 295 | att_w.append(torch.mul(tok_ids.get('cross_attentions')[i][-1].squeeze().mean(dim=0), inputs['attention_mask'][-1].squeeze())[:len(gpdb.sequence[0])].tolist()) 296 | score += np.log(torch.max(m(tok_ids.get('scores')[i]).squeeze()).item()) 297 | score = score / len(tok_ids.get('cross_attentions')) 298 | # print(str(score)) 299 | 300 | # import seaborn as sns 301 | # import matplotlib.pylab as plt 302 | # plt.figure().set_figwidth(150) 303 | # ax = sns.heatmap(att_w, cmap="YlGnBu", robust=True, xticklabels=gpdb.sequence[0])#, yticklabels=generated[0]) 304 | # plt.savefig("seaborn_plot.png") 305 | 306 | os.remove(structure_filename) 307 | os.remove(graph_filename) 308 | os.remove(process_filename) 309 | 310 | return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '') 311 | 312 | else: 313 | seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt") 314 | inputs={} 315 | inputs['encoder_input_ids'] = seq['input_ids'] 316 | 
inputs['attention_mask'] = seq['attention_mask'] 317 | inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:,0:1].clone() 318 | inputs['decoder_input_ids'][:,0] = tokenizer.bos_token_id 319 | 320 | self.to(device) 321 | inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()} 322 | encoder_state = dict() 323 | encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True) 324 | generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True) 325 | 326 | return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '') 327 | 328 | @torch.no_grad() 329 | def generate(self, 330 | inputs: Optional[torch.Tensor] = None, 331 | generation_config: Optional[GenerationConfig] = None, 332 | logits_processor: Optional[LogitsProcessorList] = None, 333 | stopping_criteria: Optional[StoppingCriteriaList] = None, 334 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 335 | synced_gpus: Optional[bool] = None, 336 | assistant_model: Optional["PreTrainedModel"] = None, 337 | streamer: Optional["BaseStreamer"] = None, 338 | **kwargs, 339 | ): 340 | encoder_state = self(**kwargs, get_graph_emb=True) 341 | input_ids = kwargs['decoder_input_ids'] 342 | attention_mask = kwargs['decoder_attention_mask'] 343 | kwargs['encoder_attention_mask'] = kwargs['attention_mask'] 344 | if not self.config.cross_esm_graph and self.config.rgcn and self.config.esm: 345 | t_add = torch.ones((kwargs['encoder_attention_mask'].size(0), 1)).to(kwargs['encoder_attention_mask'].get_device()) 346 | kwargs['encoder_attention_mask'] = torch.cat((t_add, kwargs['encoder_attention_mask']), dim=1) 347 | for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids', 'decoder_attention_mask', 'batch', 'attention_mask', 'max_length', 348 | '_num_nodes', 
'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates', 'ptr', 'num_nodes',]: 349 | if key in kwargs.keys(): 350 | kwargs.pop(key) 351 | return self.decoder.generate(input_ids=input_ids, 352 | generation_config=generation_config, 353 | logits_processor=logits_processor, 354 | stopping_criteria=stopping_criteria, 355 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 356 | synced_gpus=synced_gpus, 357 | assistant_model=assistant_model, 358 | streamer=streamer, 359 | encoder_outputs={'hidden_states': encoder_state, 'attentions':0}, 360 | **kwargs 361 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 
23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. 
Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. 
For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. 
Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. 
identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. 
You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 
371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. 
Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 
438 | -------------------------------------------------------------------------------- /prot2text_model/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP 3 | from typing import Optional, Tuple, Union, TYPE_CHECKING, Any, Callable, Dict, List 4 | from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, GPT2LMHeadModel 5 | from torch_geometric.loader import DataListLoader, DataLoader 6 | from torch_geometric.nn import DataParallel 7 | from torch.nn.parallel import DistributedDataParallel 8 | from torch.utils.data.distributed import DistributedSampler 9 | import torch 10 | from torch_geometric.data import Dataset 11 | from transformers.deepspeed import is_deepspeed_zero3_enabled 12 | from transformers.generation.logits_process import LogitsProcessorList 13 | from transformers.generation.stopping_criteria import StoppingCriteriaList 14 | # from transformers.generation.utils import validate_stopping_criteria 15 | from transformers.generation.utils import GreedySearchOutput, GreedySearchEncoderDecoderOutput, BeamSearchOutput, BeamSearchEncoderDecoderOutput 16 | from transformers.generation.beam_search import BeamScorer 17 | 18 | class _GPT2LMHeadModel(GPT2LMHeadModel): 19 | def _init_(self, config): 20 | super(GPT2LMHeadModel, self).init_(config) 21 | self.config = config 22 | 23 | 24 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, encoder_outputs=None, **kwargs): 25 | ''' 26 | This function is an edited version of the prepare_inputs_for_generation function from HuggingFace's transformers 27 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 28 | ''' 29 | token_type_ids = kwargs.get("token_type_ids", None) 30 | # only last token for inputs_ids if past is defined in kwargs 31 | if past_key_values: 32 | input_ids = input_ids[:, 
-1].unsqueeze(-1) 33 | if token_type_ids is not None: 34 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 35 | 36 | attention_mask = kwargs.get("attention_mask", None) 37 | position_ids = kwargs.get("position_ids", None) 38 | if self.config.prot2text_version=="1.1" or self.config.prot2text_version=="1.2": 39 | encoder_attention_mask = kwargs.get("encoder_attention_mask", None) 40 | elif self.config.prot2text_version=="1.0": 41 | encoder_attention_mask = None 42 | 43 | if attention_mask is not None and position_ids is None: 44 | position_ids = attention_mask.long().cumsum(-1) - 1 45 | position_ids.masked_fill_(attention_mask == 0, 1) 46 | if past_key_values: 47 | position_ids = position_ids[:, -1].unsqueeze(-1) 48 | else: 49 | position_ids = None 50 | 51 | model_specific_kwargs = { 52 | "encoder_hidden_states": encoder_outputs['hidden_states'], 53 | } 54 | 55 | return { 56 | "input_ids": input_ids, 57 | "past_key_values": past_key_values, 58 | "use_cache": kwargs.get("use_cache"), 59 | "position_ids": position_ids, 60 | "attention_mask": attention_mask, 61 | "token_type_ids": token_type_ids, 62 | "encoder_attention_mask": encoder_attention_mask, 63 | **model_specific_kwargs 64 | } 65 | 66 | 67 | def greedy_search( 68 | self, 69 | input_ids: torch.LongTensor, 70 | logits_processor: Optional[LogitsProcessorList] = None, 71 | stopping_criteria: Optional[StoppingCriteriaList] = None, 72 | max_length: Optional[int] = None, 73 | pad_token_id: Optional[int] = None, 74 | eos_token_id: Optional[Union[int, List[int]]] = None, 75 | output_attentions: Optional[bool] = None, 76 | output_hidden_states: Optional[bool] = None, 77 | output_scores: Optional[bool] = None, 78 | return_dict_in_generate: Optional[bool] = None, 79 | synced_gpus: bool = False, 80 | streamer: Optional["BaseStreamer"] = None, 81 | **model_kwargs, 82 | ) -> Union[GreedySearchOutput, torch.LongTensor]: 83 | ''' 84 | This function is an edited version of the greedy_search function from HuggingFace's 
transformers 85 | https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py 86 | ''' 87 | 88 | # init values 89 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 90 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 91 | if max_length is not None: 92 | warnings.warn( 93 | "`max_length` is deprecated in this function, use" 94 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", 95 | UserWarning, 96 | ) 97 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 98 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 99 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 100 | if isinstance(eos_token_id, int): 101 | eos_token_id = [eos_token_id] 102 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 103 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 104 | output_attentions = ( 105 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 106 | ) 107 | output_hidden_states = ( 108 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 109 | ) 110 | return_dict_in_generate = ( 111 | return_dict_in_generate 112 | if return_dict_in_generate is not None 113 | else self.generation_config.return_dict_in_generate 114 | ) 115 | 116 | # init attention / hidden states / scores tuples 117 | scores = () if (return_dict_in_generate and output_scores) else None 118 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 119 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 120 | decoder_hidden_states = () if 
(return_dict_in_generate and output_hidden_states) else None 121 | 122 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 123 | if return_dict_in_generate and self.config.is_encoder_decoder: 124 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 125 | encoder_hidden_states = ( 126 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 127 | ) 128 | 129 | # keep track of which sequences are already finished 130 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 131 | 132 | this_peer_finished = False # used by synced_gpus only 133 | while True: 134 | if synced_gpus: 135 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 136 | # The following logic allows an early break if all peers finished generating their sequence 137 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 138 | # send 0.0 if we finished, 1.0 otherwise 139 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 140 | # did all peers finish? 
the reduced sum will be 0.0 then 141 | if this_peer_finished_flag.item() == 0.0: 142 | break 143 | 144 | # prepare model inputs 145 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 146 | 147 | # forward pass to get next token 148 | outputs = self( 149 | **model_inputs, 150 | return_dict=True, 151 | output_attentions=output_attentions, 152 | output_hidden_states=output_hidden_states, 153 | ) 154 | 155 | if synced_gpus and this_peer_finished: 156 | continue # don't waste resources running the code we don't need 157 | 158 | next_token_logits = outputs.logits[:, -1, :] 159 | 160 | # pre-process distribution 161 | next_tokens_scores = logits_processor(input_ids, next_token_logits) 162 | 163 | # Store scores, attentions and hidden_states when required 164 | if return_dict_in_generate: 165 | if output_scores: 166 | scores += (next_tokens_scores,) 167 | if output_attentions: 168 | decoder_attentions += ( 169 | (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,) 170 | ) 171 | if self.config.is_encoder_decoder: 172 | cross_attentions += (outputs.cross_attentions,) 173 | 174 | if output_hidden_states: 175 | decoder_hidden_states += ( 176 | (outputs.decoder_hidden_states,) 177 | if self.config.is_encoder_decoder 178 | else (outputs.hidden_states,) 179 | ) 180 | 181 | # argmax 182 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 183 | 184 | # finished sentences should have their next token be a padding token 185 | if eos_token_id is not None: 186 | if pad_token_id is None: 187 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 188 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 189 | 190 | # update generated ids, model inputs, and length for next step 191 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 192 | if streamer is not None: 193 | streamer.put(next_tokens.cpu()) 194 | model_kwargs = 
self._update_model_kwargs_for_generation( 195 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 196 | ) 197 | 198 | # if eos_token was found in one sentence, set sentence to finished 199 | if eos_token_id_tensor is not None: 200 | unfinished_sequences = unfinished_sequences.mul( 201 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 202 | ) 203 | 204 | # stop when each sentence is finished 205 | if unfinished_sequences.max() == 0: 206 | this_peer_finished = True 207 | 208 | # stop if we exceed the maximum length 209 | try: 210 | if stopping_criteria(input_ids, scores): 211 | this_peer_finished = True 212 | except: 213 | if all(stopping_criteria(input_ids, scores)): 214 | this_peer_finished = True 215 | 216 | if this_peer_finished and not synced_gpus: 217 | break 218 | 219 | if streamer is not None: 220 | streamer.end() 221 | 222 | if return_dict_in_generate: 223 | if self.config.is_encoder_decoder: 224 | return GreedySearchEncoderDecoderOutput( 225 | sequences=input_ids, 226 | scores=scores, 227 | encoder_attentions=encoder_attentions, 228 | encoder_hidden_states=encoder_hidden_states, 229 | decoder_attentions=decoder_attentions, 230 | cross_attentions=cross_attentions, 231 | decoder_hidden_states=decoder_hidden_states, 232 | ) 233 | else: 234 | return GreedySearchDecoderOnlyOutput( 235 | sequences=input_ids, 236 | scores=scores, 237 | attentions=decoder_attentions, 238 | hidden_states=decoder_hidden_states, 239 | ) 240 | else: 241 | return input_ids 242 | 243 | def _greedy_search( 244 | self, 245 | input_ids: torch.LongTensor, 246 | logits_processor: Optional[LogitsProcessorList] = None, 247 | stopping_criteria: Optional[StoppingCriteriaList] = None, 248 | max_length: Optional[int] = None, 249 | pad_token_id: Optional[int] = None, 250 | eos_token_id: Optional[Union[int, List[int]]] = None, 251 | output_attentions: Optional[bool] = None, 252 | output_hidden_states: Optional[bool] = None, 253 
| output_scores: Optional[bool] = None, 254 | return_dict_in_generate: Optional[bool] = None, 255 | synced_gpus: bool = False, 256 | streamer: Optional["BaseStreamer"] = None, 257 | **model_kwargs, 258 | ) -> Union[GreedySearchOutput, torch.LongTensor]: 259 | 260 | return self.greedy_search( 261 | input_ids, 262 | logits_processor, 263 | stopping_criteria, 264 | max_length, 265 | pad_token_id, 266 | eos_token_id, 267 | output_attentions, 268 | output_hidden_states, 269 | output_scores, 270 | return_dict_in_generate, 271 | synced_gpus, 272 | streamer, 273 | **model_kwargs, 274 | ) 275 | def _beam_search( 276 | self, 277 | input_ids: torch.LongTensor, 278 | beam_scorer: BeamScorer, 279 | logits_processor: Optional[LogitsProcessorList] = None, 280 | stopping_criteria: Optional[StoppingCriteriaList] = None, 281 | max_length: Optional[int] = None, 282 | pad_token_id: Optional[int] = None, 283 | eos_token_id: Optional[Union[int, List[int]]] = None, 284 | output_attentions: Optional[bool] = None, 285 | output_hidden_states: Optional[bool] = None, 286 | output_scores: Optional[bool] = None, 287 | return_dict_in_generate: Optional[bool] = None, 288 | synced_gpus: bool = False, 289 | **model_kwargs, 290 | ) -> Union[BeamSearchOutput, torch.LongTensor]: 291 | 292 | return self.beam_search( 293 | input_ids, 294 | beam_scorer, 295 | logits_processor, 296 | stopping_criteria, 297 | max_length, 298 | pad_token_id, 299 | eos_token_id, 300 | output_attentions, 301 | output_hidden_states, 302 | output_scores, 303 | return_dict_in_generate, 304 | synced_gpus, 305 | **model_kwargs, 306 | ) 307 | 308 | def beam_search( 309 | self, 310 | input_ids: torch.LongTensor, 311 | beam_scorer: BeamScorer, 312 | logits_processor: Optional[LogitsProcessorList] = None, 313 | stopping_criteria: Optional[StoppingCriteriaList] = None, 314 | max_length: Optional[int] = None, 315 | pad_token_id: Optional[int] = None, 316 | eos_token_id: Optional[Union[int, List[int]]] = None, 317 | output_attentions: 
Optional[bool] = None, 318 | output_hidden_states: Optional[bool] = None, 319 | output_scores: Optional[bool] = None, 320 | return_dict_in_generate: Optional[bool] = None, 321 | synced_gpus: bool = False, 322 | **model_kwargs, 323 | ) -> Union[BeamSearchOutput, torch.LongTensor]: 324 | ''' 325 | This function is an edited version of the beam_search function from HuggingFace's transformers 326 | https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py 327 | ''' 328 | # init values 329 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 330 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 331 | if max_length is not None: 332 | warnings.warn( 333 | "`max_length` is deprecated in this function, use" 334 | " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", 335 | UserWarning, 336 | ) 337 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 338 | if len(stopping_criteria) == 0: 339 | warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) 340 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 341 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 342 | if isinstance(eos_token_id, int): 343 | eos_token_id = [eos_token_id] 344 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 345 | output_attentions = ( 346 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 347 | ) 348 | output_hidden_states = ( 349 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 350 | ) 351 | return_dict_in_generate = ( 352 | return_dict_in_generate 353 | if return_dict_in_generate is not None 354 | else 
self.generation_config.return_dict_in_generate 355 | ) 356 | 357 | batch_size = len(beam_scorer._beam_hyps) 358 | num_beams = beam_scorer.num_beams 359 | 360 | batch_beam_size, cur_len = input_ids.shape 361 | 362 | if num_beams * batch_size != batch_beam_size: 363 | raise ValueError( 364 | f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 365 | ) 366 | 367 | # init attention / hidden states / scores tuples 368 | scores = () if (return_dict_in_generate and output_scores) else None 369 | beam_indices = ( 370 | tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None 371 | ) 372 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 373 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 374 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 375 | 376 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 377 | if return_dict_in_generate and self.config.is_encoder_decoder: 378 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 379 | encoder_hidden_states = ( 380 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 381 | ) 382 | 383 | # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens 384 | # of the first beam are considered to avoid sampling the exact same tokens across all beams. 385 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) 386 | beam_scores[:, 1:] = -1e9 387 | beam_scores = beam_scores.view((batch_size * num_beams,)) 388 | 389 | this_peer_finished = False # used by synced_gpus only 390 | while True: 391 | if synced_gpus: 392 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
393 | # The following logic allows an early break if all peers finished generating their sequence 394 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 395 | # send 0.0 if we finished, 1.0 otherwise 396 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 397 | # did all peers finish? the reduced sum will be 0.0 then 398 | if this_peer_finished_flag.item() == 0.0: 399 | break 400 | 401 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 402 | 403 | outputs = self( 404 | **model_inputs, 405 | return_dict=True, 406 | output_attentions=output_attentions, 407 | output_hidden_states=output_hidden_states, 408 | ) 409 | 410 | if synced_gpus and this_peer_finished: 411 | cur_len = cur_len + 1 412 | continue # don't waste resources running the code we don't need 413 | 414 | next_token_logits = outputs.logits[:, -1, :] 415 | # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` 416 | # cannot be generated both before and after the `nn.functional.log_softmax` operation. 
417 | # next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) 418 | next_token_scores = nn.functional.log_softmax( 419 | next_token_logits, dim=-1 420 | ) # (batch_size * num_beams, vocab_size) 421 | 422 | next_token_scores_processed = logits_processor(input_ids, next_token_scores) 423 | # next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) 424 | next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( 425 | next_token_scores_processed 426 | ) 427 | 428 | # Store scores, attentions and hidden_states when required 429 | if return_dict_in_generate: 430 | if output_scores: 431 | scores += (next_token_scores_processed,) 432 | if output_attentions: 433 | decoder_attentions += ( 434 | (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,) 435 | ) 436 | if self.config.is_encoder_decoder: 437 | cross_attentions += (outputs.cross_attentions,) 438 | 439 | if output_hidden_states: 440 | decoder_hidden_states += ( 441 | (outputs.decoder_hidden_states,) 442 | if self.config.is_encoder_decoder 443 | else (outputs.hidden_states,) 444 | ) 445 | 446 | # reshape for beam search 447 | vocab_size = next_token_scores.shape[-1] 448 | next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) 449 | 450 | 451 | 452 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) 453 | next_token_scores, next_tokens = torch.topk( 454 | next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True 455 | ) 456 | 457 | next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") 458 | next_tokens = next_tokens % vocab_size 459 | 460 | # stateless 461 | beam_outputs = beam_scorer.process( 462 | input_ids, 463 | next_token_scores, 464 | next_tokens, 465 | next_indices, 466 | pad_token_id=pad_token_id, 467 | eos_token_id=eos_token_id, 468 | beam_indices=beam_indices, 469 | 
) 470 | 471 | beam_scores = beam_outputs["next_beam_scores"] 472 | beam_next_tokens = beam_outputs["next_beam_tokens"] 473 | beam_idx = beam_outputs["next_beam_indices"] 474 | 475 | input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) 476 | 477 | model_kwargs = self._update_model_kwargs_for_generation( 478 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 479 | ) 480 | if model_kwargs["past_key_values"] is not None: 481 | model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) 482 | 483 | if return_dict_in_generate and output_scores: 484 | beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) 485 | 486 | # increase cur_len 487 | cur_len = cur_len + 1 488 | 489 | try: 490 | if beam_scorer.is_done or stopping_criteria(input_ids, scores): 491 | if not synced_gpus: 492 | break 493 | else: 494 | this_peer_finished = True 495 | except: 496 | if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): 497 | if not synced_gpus: 498 | break 499 | else: 500 | this_peer_finished = True 501 | 502 | 503 | sequence_outputs = beam_scorer.finalize( 504 | input_ids, 505 | beam_scores, 506 | next_tokens, 507 | next_indices, 508 | pad_token_id=pad_token_id, 509 | eos_token_id=eos_token_id, 510 | max_length=stopping_criteria.max_length, 511 | beam_indices=beam_indices, 512 | ) 513 | 514 | if return_dict_in_generate: 515 | if not output_scores: 516 | sequence_outputs["sequence_scores"] = None 517 | 518 | if self.config.is_encoder_decoder: 519 | return BeamSearchEncoderDecoderOutput( 520 | sequences=sequence_outputs["sequences"], 521 | sequences_scores=sequence_outputs["sequence_scores"], 522 | scores=scores, 523 | beam_indices=sequence_outputs["beam_indices"], 524 | encoder_attentions=encoder_attentions, 525 | encoder_hidden_states=encoder_hidden_states, 526 | decoder_attentions=decoder_attentions, 527 | 
cross_attentions=cross_attentions, 528 | decoder_hidden_states=decoder_hidden_states, 529 | ) 530 | else: 531 | return BeamSearchDecoderOnlyOutput( 532 | sequences=sequence_outputs["sequences"], 533 | sequences_scores=sequence_outputs["sequence_scores"], 534 | scores=scores, 535 | beam_indices=sequence_outputs["beam_indices"], 536 | attentions=decoder_attentions, 537 | hidden_states=decoder_hidden_states, 538 | ) 539 | else: 540 | return sequence_outputs["sequences"] 541 | 542 | 543 | class CABlock(nn.Module): 544 | ''' 545 | This function is an edited version of the gpt2 decoder block function from HuggingFace's transformers 546 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 547 | ''' 548 | def __init__(self, config, layer_idx=None): 549 | super().__init__() 550 | hidden_size = config.hidden_size 551 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 552 | 553 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 554 | 555 | self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) 556 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 557 | 558 | self.mlp = GPT2MLP(inner_dim, config) 559 | 560 | def forward( 561 | self, 562 | hidden_states: Optional[Tuple[torch.FloatTensor]], 563 | layer_past: Optional[Tuple[torch.Tensor]] = None, 564 | attention_mask: Optional[torch.FloatTensor] = None, 565 | head_mask: Optional[torch.FloatTensor] = None, 566 | encoder_hidden_states: Optional[torch.Tensor] = None, 567 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 568 | use_cache: Optional[bool] = False, 569 | output_attentions: Optional[bool] = False, 570 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: 571 | 572 | 573 | residual = hidden_states 574 | hidden_states = self.ln_cross_attn(hidden_states) 575 | cross_attn_outputs = self.crossattention( 576 | 
hidden_states, 577 | attention_mask=attention_mask, 578 | head_mask=head_mask, 579 | encoder_hidden_states=encoder_hidden_states, 580 | encoder_attention_mask=encoder_attention_mask, 581 | output_attentions=output_attentions, 582 | ) 583 | attn_output = cross_attn_outputs[0] 584 | # residual connection 585 | hidden_states = residual + attn_output 586 | 587 | residual = hidden_states 588 | hidden_states = self.ln_2(hidden_states) 589 | feed_forward_hidden_states = self.mlp(hidden_states) 590 | # residual connection 591 | hidden_states = residual + feed_forward_hidden_states 592 | 593 | return (hidden_states,) 594 | 595 | class Prot2TextTrainer(Seq2SeqTrainer): 596 | ''' 597 | This function is an edited version of the Seq2SeqTrainer from HuggingFace's transformers 598 | ''' 599 | def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: 600 | if self.args.world_size > 1: 601 | eval_sampler = DistributedSampler(self.eval_dataset, num_replicas=self.args.world_size, rank=self.args.process_index) 602 | else: 603 | eval_sampler = None 604 | return DataLoader( 605 | self.eval_dataset, 606 | batch_size=self.args.eval_batch_size, 607 | collate_fn=None, 608 | num_workers=self.args.dataloader_num_workers, 609 | pin_memory=self.args.dataloader_pin_memory, 610 | sampler=eval_sampler, 611 | ) 612 | def get_train_dataloader(self) -> DataLoader: 613 | if self.args.world_size > 1: 614 | train_sampler = DistributedSampler(self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index) 615 | else: 616 | train_sampler = None 617 | return DataLoader( 618 | self.train_dataset, 619 | batch_size=self.args.per_device_train_batch_size, 620 | collate_fn=None, 621 | num_workers=self.args.dataloader_num_workers, 622 | pin_memory=self.args.dataloader_pin_memory, 623 | sampler=train_sampler, 624 | ) 625 | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: 626 | """ 627 | Prepare `inputs` 
before feeding them to the model, converting them to tensors if they are not already and 628 | handling potential state. 629 | """ 630 | inputs = self._prepare_input(inputs) 631 | if len(inputs) == 0: 632 | raise ValueError( 633 | "The batch received was empty, your model won't be able to train on it. Double-check that your " 634 | f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." 635 | ) 636 | if self.args.past_index >= 0 and self._past is not None: 637 | inputs["mems"] = self._past 638 | 639 | inputs = inputs.to_dict() 640 | inputs['edge_type'] = torch.cat([torch.tensor(inputs['edge_type'][i]) for i in range(len(inputs['edge_type']))], dim=0) 641 | inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1) 642 | inputs = {k: v.to(device=self.args.device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()} 643 | return inputs 644 | 645 | def prediction_step( 646 | self, 647 | model: nn.Module, 648 | inputs: Dict[str, Union[torch.Tensor, Any]], 649 | prediction_loss_only: bool, 650 | ignore_keys: Optional[List[str]] = None, 651 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 652 | """ 653 | Perform an evaluation step on `model` using `inputs`. 654 | 655 | Subclass and override to inject custom behavior. 656 | 657 | Args: 658 | model (`nn.Module`): 659 | The model to evaluate. 660 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 661 | The inputs and targets of the model. 662 | 663 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 664 | argument `labels`. Check your model's documentation for all accepted arguments. 665 | prediction_loss_only (`bool`): 666 | Whether or not to return the loss only. 667 | 668 | Return: 669 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 670 | labels (each being optional). 
671 | """ 672 | 673 | if not self.args.predict_with_generate or prediction_loss_only: 674 | return super().prediction_step( 675 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 676 | ) 677 | 678 | has_labels = "labels" in inputs 679 | inputs = self._prepare_inputs(inputs) 680 | 681 | # XXX: adapt synced_gpus for fairscale as well 682 | gen_kwargs = self._gen_kwargs.copy() 683 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 684 | gen_kwargs["max_length"] = self.model.config.max_length 685 | gen_kwargs["num_beams"] = ( 686 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 687 | ) 688 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 689 | gen_kwargs["synced_gpus"] = ( 690 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 691 | ) 692 | 693 | if "attention_mask" in inputs: 694 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 695 | if "global_attention_mask" in inputs: 696 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 697 | 698 | generation_inputs = None 699 | gen_kwargs['x'] = inputs.get('x', None) 700 | gen_kwargs['edge_index'] = inputs.get('edge_index', None) 701 | gen_kwargs['edge_type'] = inputs.get('edge_type', None) 702 | gen_kwargs['batch'] = inputs.get('batch', None) 703 | gen_kwargs['encoder_input_ids'] = inputs.get('encoder_input_ids', None) 704 | gen_kwargs['decoder_input_ids'] = inputs.get('decoder_input_ids', None)[:,0:1] 705 | gen_kwargs["decoder_attention_mask"] = torch.ones(gen_kwargs['decoder_input_ids'].shape[0], 1).to(self.args.device) 706 | 707 | generated_tokens = self.model.generate( 708 | generation_inputs, 709 | **gen_kwargs, 710 | ) 711 | # in case the batch is shorter than max length, the output should be padded 712 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < 
gen_kwargs["max_length"]: 713 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 714 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 715 | gen_kwargs["max_new_tokens"] + 1 716 | ): 717 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 718 | 719 | with torch.no_grad(): 720 | if has_labels: 721 | with self.compute_loss_context_manager(): 722 | outputs = model(**inputs) 723 | if self.label_smoother is not None: 724 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 725 | else: 726 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 727 | else: 728 | loss = None 729 | 730 | if self.args.prediction_loss_only: 731 | return (loss, None, None) 732 | 733 | if has_labels: 734 | labels = inputs["labels"] 735 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 736 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 737 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 738 | gen_kwargs["max_new_tokens"] + 1 739 | ): 740 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 741 | else: 742 | labels = None 743 | 744 | return (loss, generated_tokens, labels) -------------------------------------------------------------------------------- /prot2text_dataset/graphs.py: -------------------------------------------------------------------------------- 1 | """Functions for working with Protein Structure Graphs.""" 2 | # %% 3 | # Graphein 4 | # Author: Arian Jamasb , Eric Ma, Charlie Harris 5 | # License: MIT 6 | # Project Website: https://github.com/a-r-j/graphein 7 | # Code Repository: https://github.com/a-r-j/graphein 8 | from __future__ import annotations 9 | 10 | import logging 11 | import traceback 12 | from functools import partial 13 | from typing import Any, Callable, 
Dict, List, Optional, Tuple, Union 14 | 15 | import networkx as nx 16 | import numpy as np 17 | import pandas as pd 18 | # from Bio.PDB.Polypeptide import three_to_one 19 | from biopandas.pdb import PandasPdb 20 | from biopandas.mmcif import PandasMmcif 21 | from rich.progress import Progress 22 | from tqdm.contrib.concurrent import process_map 23 | 24 | from graphein.protein.config import ( 25 | DSSPConfig, 26 | GetContactsConfig, 27 | ProteinGraphConfig, 28 | ) 29 | from graphein.protein.edges.distance import ( 30 | add_distance_to_edges, 31 | compute_distmat, 32 | ) 33 | from graphein.protein.resi_atoms import BACKBONE_ATOMS, RESI_THREE_TO_1 34 | from graphein.protein.subgraphs import extract_subgraph_from_chains 35 | from graphein.protein.utils import ( 36 | ProteinGraphConfigurationError, 37 | compute_rgroup_dataframe, 38 | filter_dataframe, 39 | get_protein_name_from_filename, 40 | three_to_one_with_mods, 41 | ) 42 | from graphein.rna.constants import RNA_ATOMS 43 | from graphein.utils.utils import ( 44 | annotate_edge_metadata, 45 | annotate_graph_metadata, 46 | annotate_node_metadata, 47 | compute_edges, 48 | ) 49 | 50 | from .utils_convert import biopandas_mmcif2pdb 51 | 52 | # logging.basicConfig(level="DEBUG") 53 | log = logging.getLogger(__name__) 54 | 55 | 56 | 57 | def subset_structure_to_rna( 58 | df: pd.DataFrame, 59 | ) -> pd.DataFrame: 60 | """ 61 | Return a subset of atomic dataframe that contains only certain atom names relevant for RNA structures. 
62 | 63 | :param df: Protein Structure dataframe to subset 64 | :type df: pd.DataFrame 65 | :returns: Subsetted protein structure dataframe 66 | :rtype: pd.DataFrame 67 | """ 68 | return filter_dataframe( 69 | df, by_column="atom_name", list_of_values=RNA_ATOMS, boolean=True 70 | ) 71 | 72 | 73 | def read_pdb_to_dataframe( 74 | pdb_path: Optional[str] = None, 75 | pdb_code: Optional[str] = None, 76 | uniprot_id: Optional[str] = None, 77 | model_index: int = 1, 78 | ) -> pd.DataFrame: 79 | """ 80 | Reads PDB file to ``PandasPDB`` object. 81 | 82 | Returns ``atomic_df``, which is a dataframe enumerating all atoms and their cartesian coordinates in 3D space. Also 83 | contains associated metadata from the PDB file. 84 | 85 | :param pdb_path: path to PDB file. Defaults to ``None``. 86 | :type pdb_path: str, optional 87 | :param pdb_code: 4-character PDB accession. Defaults to ``None``. 88 | :type pdb_code: str, optional 89 | :param uniprot_id: UniProt ID to build graph from AlphaFoldDB. Defaults to ``None``. 90 | :type uniprot_id: str, optional 91 | :param model_index: Index of model to read. Only relevant for structures containing ensembles. Defaults to ``1``. 92 | :type model_index: int, optional 93 | :param verbose: print dataframe? 94 | :type verbose: bool 95 | :param granularity: Specifies granularity of dataframe. See :class:`~graphein.protein.config.ProteinGraphConfig` for further 96 | details. 97 | :type granularity: str 98 | :returns: ``pd.DataFrame`` containing protein structure 99 | :rtype: pd.DataFrame 100 | """ 101 | if pdb_code is None and pdb_path is None and uniprot_id is None: 102 | raise NameError( 103 | "One of pdb_code, pdb_path or uniprot_id must be specified!" 
104 | ) 105 | 106 | if pdb_path is not None: 107 | if pdb_path.endswith('cif'): 108 | atomic_df = PandasMmcif().read_mmcif(pdb_path) 109 | atomic_df = biopandas_mmcif2pdb(atomic_df, model_index) 110 | else: 111 | atomic_df = PandasPdb().read_pdb(pdb_path) 112 | else: 113 | if uniprot_id is not None: 114 | atomic_df = PandasPdb().fetch_pdb( 115 | uniprot_id=uniprot_id, source="alphafold2-v2" 116 | ) 117 | else: 118 | atomic_df = PandasPdb().fetch_pdb(pdb_code) 119 | 120 | atomic_df = atomic_df.get_model(model_index) 121 | if len(atomic_df.df["ATOM"]) == 0: 122 | raise ValueError(f"No model found for index: {model_index}") 123 | 124 | return pd.concat([atomic_df.df["ATOM"], atomic_df.df["HETATM"]]) 125 | 126 | 127 | def label_node_id(df: pd.DataFrame, granularity: str) -> pd.DataFrame: 128 | df["node_id"] = ( 129 | df["chain_id"].apply(str) 130 | + ":" 131 | + df["residue_name"] 132 | + ":" 133 | + df["residue_number"].apply(str) 134 | ) 135 | df["residue_id"] = df["node_id"] 136 | if granularity == "atom": 137 | df["node_id"] = df["node_id"] + ":" + df["atom_name"] 138 | elif granularity in {"rna_atom", "rna_centroid"}: 139 | df["node_id"] = ( 140 | df["node_id"] 141 | + ":" 142 | + df["atom_number"].apply(str) 143 | + ":" 144 | + df["atom_name"] 145 | ) 146 | return df 147 | 148 | 149 | def deprotonate_structure(df: pd.DataFrame) -> pd.DataFrame: 150 | """Remove protons from PDB dataframe. 151 | 152 | :param df: Atomic dataframe. 153 | :type df: pd.DataFrame 154 | :returns: Atomic dataframe with all ``atom_name == "H"`` removed. 155 | :rtype: pd.DataFrame 156 | """ 157 | log.debug( 158 | "Deprotonating protein. This removes H atoms from the pdb_df dataframe" 159 | ) 160 | return filter_dataframe( 161 | df, by_column="element_symbol", list_of_values=["H"], boolean=False 162 | ) 163 | 164 | 165 | def convert_structure_to_centroids(df: pd.DataFrame) -> pd.DataFrame: 166 | """Overwrite existing ``(x, y, z)`` coordinates with centroids of the amino acids. 
167 | 168 | :param df: Pandas Dataframe protein structure to convert into a dataframe of centroid positions. 169 | :type df: pd.DataFrame 170 | :return: pd.DataFrame with atoms/residues positions converted into centroid positions. 171 | :rtype: pd.DataFrame 172 | """ 173 | log.debug( 174 | "Converting dataframe to centroids. This averages XYZ coords of the atoms in a residue" 175 | ) 176 | 177 | centroids = calculate_centroid_positions(df) 178 | df = df.loc[df["atom_name"] == "CA"].reset_index(drop=True) 179 | df["x_coord"] = centroids["x_coord"] 180 | df["y_coord"] = centroids["y_coord"] 181 | df["z_coord"] = centroids["z_coord"] 182 | 183 | return df 184 | 185 | 186 | def subset_structure_to_atom_type( 187 | df: pd.DataFrame, granularity: str 188 | ) -> pd.DataFrame: 189 | """ 190 | Return a subset of atomic dataframe that contains only certain atom names. 191 | 192 | :param df: Protein Structure dataframe to subset. 193 | :type df: pd.DataFrame 194 | :returns: Subsetted protein structure dataframe. 195 | :rtype: pd.DataFrame 196 | """ 197 | return filter_dataframe( 198 | df, by_column="atom_name", list_of_values=[granularity], boolean=True 199 | ) 200 | 201 | 202 | def remove_insertions(df: pd.DataFrame, keep: str = "first") -> pd.DataFrame: 203 | """ 204 | This function removes insertions from PDB dataframes. 205 | 206 | :param df: Protein Structure dataframe to remove insertions from. 207 | :type df: pd.DataFrame 208 | :param keep: Specifies which insertion to keep. Options are ``"first"`` or ``"last"``. 
209 | Default is ``"first"`` 210 | :type keep: str 211 | :return: Protein structure dataframe with insertions removed 212 | :rtype: pd.DataFrame 213 | """ 214 | # Catches unnamed insertions 215 | duplicates = df.duplicated( 216 | subset=["chain_id", "residue_number", "atom_name"], keep=keep 217 | ) 218 | df = df[~duplicates] 219 | 220 | # Catches explicit insertions 221 | df = filter_dataframe( 222 | df, by_column="insertion", list_of_values=[""], boolean=True 223 | ) 224 | 225 | # Remove alt_locs 226 | df = filter_dataframe( 227 | df, by_column="alt_loc", list_of_values=["", "A"], boolean=True 228 | ) 229 | 230 | return df 231 | 232 | 233 | def filter_hetatms( 234 | df: pd.DataFrame, keep_hets: List[str] 235 | ) -> List[pd.DataFrame]: 236 | """Return hetatms of interest. 237 | 238 | :param df: Protein Structure dataframe to filter hetatoms from. 239 | :type df: pd.DataFrame 240 | :param keep_hets: List of hetero atom names to keep. 241 | :returns: Protein structure dataframe with heteroatoms removed 242 | :rtype: pd.DataFrame 243 | """ 244 | return [df.loc[df["residue_name"] == hetatm] for hetatm in keep_hets] 245 | 246 | 247 | def process_dataframe( 248 | protein_df: pd.DataFrame, 249 | atom_df_processing_funcs: Optional[List[Callable]] = None, 250 | hetatom_df_processing_funcs: Optional[List[Callable]] = None, 251 | granularity: str = "centroids", 252 | chain_selection: str = "all", 253 | insertions: bool = False, 254 | deprotonate: bool = True, 255 | keep_hets: List[str] = [], 256 | verbose: bool = False, 257 | ) -> pd.DataFrame: 258 | """ 259 | Process ATOM and HETATM dataframes to produce singular dataframe used for graph construction. 260 | 261 | :param protein_df: Dataframe to process. 262 | Should be the object returned from :func:`~graphein.protein.graphs.read_pdb_to_dataframe`. 263 | :type protein_df: pd.DataFrame 264 | :param atom_df_processing_funcs: List of functions to process dataframe. These must take in a dataframe and return a 265 | dataframe. 
Defaults to None. 266 | :type atom_df_processing_funcs: List[Callable], optional 267 | :param hetatom_df_processing_funcs: List of functions to process the hetatom dataframe. These must take in a dataframe and return a dataframe 268 | :type hetatom_df_processing_funcs: List[Callable], optional 269 | :param granularity: The level of granularity for the graph. This determines the node definition. 270 | Acceptable values include: ``"centroids"``, ``"atoms"``, 271 | any of the atom_names in the PDB file (e.g. ``"CA"``, ``"CB"``, ``"OG"``, etc.). 272 | See: :const:`~graphein.protein.config.GRAPH_ATOMS` and :const:`~graphein.protein.config.GRANULARITY_OPTS`. 273 | :type granularity: str 274 | :param insertions: Whether or not to keep insertions. 275 | :param insertions: bool 276 | :param deprotonate: Whether or not to remove hydrogen atoms (i.e. deprotonation). 277 | :type deprotonate: bool 278 | :param keep_hets: Hetatoms to keep. Defaults to an empty list. 279 | To keep a hetatom, pass it inside a list of hetatom names to keep. 280 | :type keep_hets: List[str] 281 | :param verbose: Verbosity level. 282 | :type verbose: bool 283 | :param chain_selection: Which protein chain to select. Defaults to ``"all"``. Eg can use ``"ACF"`` 284 | to select 3 chains (``A``, ``C`` & ``F``) 285 | :type chain_selection: str 286 | :return: A protein dataframe that can be consumed by 287 | other graph construction functions. 288 | :rtype: pd.DataFrame 289 | """ 290 | protein_df = label_node_id(protein_df, granularity=granularity) 291 | # TODO: Need to properly define what "granularity" is supposed to do. 
292 | atoms = filter_dataframe( 293 | protein_df, 294 | by_column="record_name", 295 | list_of_values=["ATOM"], 296 | boolean=True, 297 | ) 298 | hetatms = filter_dataframe( 299 | protein_df, 300 | by_column="record_name", 301 | list_of_values=["HETATM"], 302 | boolean=True, 303 | ) 304 | 305 | # This block enables processing via a list of supplied functions operating on the atom and hetatom dataframes 306 | # If these are provided, the dataframe returned will be computed only from these and the default workflow 307 | # below this block will not execute. 308 | if atom_df_processing_funcs is not None: 309 | for func in atom_df_processing_funcs: 310 | atoms = func(atoms) 311 | if hetatom_df_processing_funcs is None: 312 | return atoms 313 | 314 | if hetatom_df_processing_funcs is not None: 315 | for func in hetatom_df_processing_funcs: 316 | hetatms = func(hetatms) 317 | return pd.concat([atoms, hetatms]) 318 | 319 | if keep_hets: 320 | hetatms_to_keep = filter_hetatms(hetatms, keep_hets) 321 | atoms = pd.concat([atoms] + hetatms_to_keep) 322 | 323 | # Deprotonate structure by removing H atoms 324 | if deprotonate: 325 | atoms = deprotonate_structure(atoms) 326 | 327 | # Restrict DF to desired granularity 328 | if granularity == "atom": 329 | pass 330 | elif granularity in {"centroids", "rna_centroid"}: 331 | atoms = convert_structure_to_centroids(atoms) 332 | elif granularity == "rna_atom": 333 | atoms = subset_structure_to_rna(atoms) 334 | else: 335 | atoms = subset_structure_to_atom_type(atoms, granularity) 336 | 337 | protein_df = atoms 338 | 339 | # Remove alt_loc residues 340 | if not insertions: 341 | protein_df = remove_insertions(protein_df) 342 | 343 | # perform chain selection 344 | protein_df = select_chains( 345 | protein_df, chain_selection=chain_selection, verbose=verbose 346 | ) 347 | 348 | log.debug(f"Detected {len(protein_df)} total nodes") 349 | 350 | # Sort dataframe to place HETATMs 351 | protein_df = sort_dataframe(protein_df) 352 | 353 | return 
protein_df 354 | 355 | 356 | def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame: 357 | """Sorts a protein dataframe by chain->residue number->atom number 358 | 359 | This is useful for distributing hetatms/modified residues through the DF. 360 | 361 | :param df: Protein dataframe to sort. 362 | :type df: pd.DataFrame 363 | :return: Sorted protein dataframe. 364 | :rtype: pd.DataFrame 365 | """ 366 | return df.sort_values(by=["chain_id", "residue_number", "atom_number"]) 367 | 368 | 369 | def assign_node_id_to_dataframe( 370 | protein_df: pd.DataFrame, granularity: str 371 | ) -> pd.DataFrame: 372 | """ 373 | Assigns the node ID back to the ``pdb_df`` dataframe 374 | 375 | :param protein_df: Structure Dataframe 376 | :type protein_df: pd.DataFrame 377 | :param granularity: Granularity of graph. Atom-level, 378 | residue (e.g. ``CA``) or ``centroids``. 379 | See: :const:`~graphein.protein.config.GRAPH_ATOMS` 380 | and :const:`~graphein.protein.config.GRANULARITY_OPTS`. 381 | :type granularity: str 382 | :return: Returns dataframe with added ``node_ids`` 383 | :rtype: pd.DataFrame 384 | """ 385 | protein_df["node_id"] = ( 386 | protein_df["chain_id"].apply(str) 387 | + ":" 388 | + protein_df["residue_name"] 389 | + ":" 390 | + protein_df["residue_number"].apply(str) 391 | ) 392 | if granularity in {"atom", "rna_atom"}: 393 | protein_df[ 394 | "node_id" 395 | ] = f'{protein_df["node_id"]}:{protein_df["atom_name"]}' 396 | 397 | 398 | def select_chains( 399 | protein_df: pd.DataFrame, chain_selection: str, verbose: bool = False 400 | ) -> pd.DataFrame: 401 | """ 402 | Extracts relevant chains from ``protein_df``. 403 | 404 | :param protein_df: pandas dataframe of PDB subsetted to relevant atoms 405 | (``CA``, ``CB``). 406 | :type protein_df: pd.DataFrame 407 | :param chain_selection: Specifies chains that should be extracted from 408 | the larger complexed structure. 409 | :type chain_selection: str 410 | :param verbose: Print dataframe? 
411 | :type verbose: bool 412 | :return: Protein structure dataframe containing only entries in the 413 | chain selection. 414 | :rtype: pd.DataFrame 415 | """ 416 | if chain_selection != "all": 417 | protein_df = filter_dataframe( 418 | protein_df, 419 | by_column="chain_id", 420 | list_of_values=list(chain_selection), 421 | boolean=True, 422 | ) 423 | 424 | return protein_df 425 | 426 | 427 | def initialise_graph_with_metadata( 428 | protein_df: pd.DataFrame, 429 | raw_pdb_df: pd.DataFrame, 430 | granularity: str, 431 | name: Optional[str] = None, 432 | pdb_code: Optional[str] = None, 433 | pdb_path: Optional[str] = None, 434 | ) -> nx.Graph: 435 | """ 436 | Initializes the nx Graph object with initial metadata. 437 | 438 | :param protein_df: Processed Dataframe of protein structure. 439 | :type protein_df: pd.DataFrame 440 | :param raw_pdb_df: Unprocessed dataframe of protein structure for comparison and traceability downstream. 441 | :type raw_pdb_df: pd.DataFrame 442 | :param granularity: Granularity of the graph (eg ``"atom"``, ``"CA"``, ``"CB"`` etc or ``"centroid"``). 443 | See: :const:`~graphein.protein.config.GRAPH_ATOMS` and :const:`~graphein.protein.config.GRANULARITY_OPTS`. 444 | :type granularity: str 445 | :param name: specified given name for the graph. If None, the PDB code or the file name will be used to name the graph. 446 | :type name: Optional[str], defaults to ``None`` 447 | :param pdb_code: PDB ID / Accession code, if the PDB is available on the PDB database. 448 | :type pdb_code: Optional[str], defaults to ``None`` 449 | :param pdb_path: path to local PDB file, if constructing a graph from a local file. 450 | :type pdb_path: Optional[str], defaults to ``None`` 451 | :return: Returns initial protein structure graph with metadata. 
452 | :rtype: nx.Graph 453 | """ 454 | 455 | # Get name for graph if no name was provided 456 | if name is None: 457 | if pdb_path is not None: 458 | name = get_protein_name_from_filename(pdb_path) 459 | else: 460 | name = pdb_code 461 | 462 | G = nx.Graph( 463 | name=name, 464 | pdb_code=pdb_code, 465 | pdb_path=pdb_path, 466 | chain_ids=list(protein_df["chain_id"].unique()), 467 | pdb_df=protein_df, 468 | raw_pdb_df=raw_pdb_df, 469 | rgroup_df=compute_rgroup_dataframe(remove_insertions(raw_pdb_df)), 470 | coords=np.asarray(protein_df[["x_coord", "y_coord", "z_coord"]]), 471 | ) 472 | 473 | # Create graph and assign intrinsic graph-level metadata 474 | G.graph["node_type"] = granularity 475 | 476 | # Add Sequences to graph metadata 477 | for c in G.graph["chain_ids"]: 478 | if granularity == "rna_atom": 479 | sequence = protein_df.loc[protein_df["chain_id"] == c][ 480 | "residue_name" 481 | ].str.cat() 482 | else: 483 | sequence = ( 484 | protein_df.loc[protein_df["chain_id"] == c]["residue_name"] 485 | .apply(three_to_one_with_mods) 486 | .str.cat() 487 | ) 488 | G.graph[f"sequence_{c}"] = sequence 489 | return G 490 | 491 | 492 | def add_nodes_to_graph( 493 | G: nx.Graph, 494 | protein_df: Optional[pd.DataFrame] = None, 495 | verbose: bool = False, 496 | ) -> nx.Graph: 497 | """Add nodes into protein graph. 498 | 499 | :param G: ``nx.Graph`` with metadata to populate with nodes. 500 | :type G: nx.Graph 501 | :protein_df: DataFrame of protein structure containing nodes & initial node metadata to add to the graph. 502 | :type protein_df: pd.DataFrame, optional 503 | :param verbose: Controls verbosity of this step. 504 | :type verbose: bool 505 | :returns: nx.Graph with nodes added. 
506 | :rtype: nx.Graph 507 | """ 508 | 509 | # If no protein dataframe is supplied, use the one stored in the Graph object 510 | if protein_df is None: 511 | protein_df = G.graph["pdb_df"] 512 | # Assign intrinsic node attributes 513 | chain_id = protein_df["chain_id"].apply(str) 514 | residue_name = protein_df["residue_name"] 515 | residue_number = protein_df["residue_number"] # .apply(str) 516 | coords = np.asarray(protein_df[["x_coord", "y_coord", "z_coord"]]) 517 | b_factor = protein_df["b_factor"] 518 | atom_type = protein_df["atom_name"] 519 | nodes = protein_df["node_id"] 520 | element_symbol = protein_df["element_symbol"] 521 | G.add_nodes_from(nodes) 522 | 523 | # Set intrinsic node attributes 524 | nx.set_node_attributes(G, dict(zip(nodes, chain_id)), "chain_id") 525 | nx.set_node_attributes(G, dict(zip(nodes, residue_name)), "residue_name") 526 | nx.set_node_attributes( 527 | G, dict(zip(nodes, residue_number)), "residue_number" 528 | ) 529 | nx.set_node_attributes(G, dict(zip(nodes, atom_type)), "atom_type") 530 | nx.set_node_attributes( 531 | G, dict(zip(nodes, element_symbol)), "element_symbol" 532 | ) 533 | nx.set_node_attributes(G, dict(zip(nodes, coords)), "coords") 534 | nx.set_node_attributes(G, dict(zip(nodes, b_factor)), "b_factor") 535 | 536 | # TODO: include charge, line_idx for traceability? 537 | if verbose: 538 | print(nx.info(G)) 539 | print(G.nodes()) 540 | 541 | return G 542 | 543 | 544 | def calculate_centroid_positions( 545 | atoms: pd.DataFrame, verbose: bool = False 546 | ) -> pd.DataFrame: 547 | """ 548 | Calculates position of sidechain centroids. 549 | 550 | :param atoms: ATOM df of protein structure. 551 | :type atoms: pd.DataFrame 552 | :param verbose: bool controlling verbosity. 553 | :type verbose: bool 554 | :return: centroids (df). 
555 | :rtype: pd.DataFrame 556 | """ 557 | centroids = ( 558 | atoms.groupby("residue_number") 559 | .mean()[["x_coord", "y_coord", "z_coord"]] 560 | .reset_index() 561 | ) 562 | if verbose: 563 | print(f"Calculated {len(centroids)} centroid nodes") 564 | log.debug(f"Calculated {len(centroids)} centroid nodes") 565 | return centroids 566 | 567 | 568 | def compute_edges( 569 | G: nx.Graph, 570 | funcs: List[Callable], 571 | get_contacts_config: Optional[GetContactsConfig] = None, 572 | ) -> nx.Graph: 573 | """ 574 | Computes edges for the protein structure graph. Will compute a pairwise 575 | distance matrix between nodes which is 576 | added to the graph metadata to facilitate some edge computations. 577 | 578 | :param G: nx.Graph with nodes to add edges to. 579 | :type G: nx.Graph 580 | :param funcs: List of edge construction functions. 581 | :type funcs: List[Callable] 582 | :param get_contacts_config: Config object for ``GetContacts`` if 583 | intramolecular edges are being used. 584 | :type get_contacts_config: graphein.protein.config.GetContactsConfig 585 | :return: Graph with added edges. 
586 | :rtype: nx.Graph 587 | """ 588 | # This control flow prevents unnecessary computation of the distance matrices 589 | if "config" in G.graph: 590 | if G.graph["config"].granularity == "atom": 591 | G.graph["atomic_dist_mat"] = compute_distmat(G.graph["pdb_df"]) 592 | else: 593 | G.graph["dist_mat"] = compute_distmat(G.graph["pdb_df"]) 594 | 595 | for func in funcs: 596 | func(G) 597 | 598 | return add_distance_to_edges(G) 599 | 600 | 601 | def construct_graph( 602 | config: Optional[ProteinGraphConfig] = None, 603 | name: Optional[str] = None, 604 | pdb_path: Optional[str] = None, 605 | uniprot_id: Optional[str] = None, 606 | pdb_code: Optional[str] = None, 607 | chain_selection: str = "all", 608 | model_index: int = 1, 609 | df_processing_funcs: Optional[List[Callable]] = None, 610 | edge_construction_funcs: Optional[List[Callable]] = None, 611 | edge_annotation_funcs: Optional[List[Callable]] = None, 612 | node_annotation_funcs: Optional[List[Callable]] = None, 613 | graph_annotation_funcs: Optional[List[Callable]] = None, 614 | ) -> nx.Graph: 615 | """ 616 | Constructs protein structure graph from a ``pdb_code`` or ``pdb_path``. 617 | 618 | Users can provide a :class:`~graphein.protein.config.ProteinGraphConfig` 619 | object to specify construction parameters. 620 | 621 | However, config parameters can be overridden by passing arguments directly to the function. 622 | 623 | :param config: :class:`~graphein.protein.config.ProteinGraphConfig` object. If None, defaults to config in ``graphein.protein.config``. 624 | :type config: graphein.protein.config.ProteinGraphConfig, optional 625 | :param name: an optional given name for the graph. the PDB ID or PDB file name will be used if not specified. 626 | :type name: str, optional 627 | :param pdb_path: Path to ``pdb_file`` when constructing a graph from a local pdb file. Default is ``None``. 
628 | :type pdb_path: Optional[str], defaults to ``None`` 629 | :param pdb_code: A 4-character PDB ID / accession to be used to construct the graph, if available. Default is ``None``. 630 | :type pdb_code: Optional[str], defaults to ``None`` 631 | :param uniprot_id: UniProt accession ID to build graph from AlphaFold2DB. Default is ``None``. 632 | :type uniprot_id: str, optional 633 | :param chain_selection: String of polypeptide chains to include in graph. E.g ``"ABDF"`` or ``"all"``. Default is ``"all"``. 634 | :type chain_selection: str 635 | :param model_index: Index of model to use in the case of structural ensembles. Default is ``1``. 636 | :type model_index: int 637 | :param df_processing_funcs: List of dataframe processing functions. Default is ``None``. 638 | :type df_processing_funcs: List[Callable], optional 639 | :param edge_construction_funcs: List of edge construction functions. Default is ``None``. 640 | :type edge_construction_funcs: List[Callable], optional 641 | :param edge_annotation_funcs: List of edge annotation functions. Default is ``None``. 642 | :type edge_annotation_funcs: List[Callable], optional 643 | :param node_annotation_funcs: List of node annotation functions. Default is ``None``. 644 | :type node_annotation_funcs: List[Callable], optional 645 | :param graph_annotation_funcs: List of graph annotation function. Default is ``None``. 
646 | :type graph_annotation_funcs: List[Callable] 647 | :return: Protein Structure Graph 648 | :rtype: nx.Graph 649 | """ 650 | 651 | if pdb_code is None and pdb_path is None and uniprot_id is None: 652 | raise ValueError( 653 | "Either a PDB ID, UniProt ID or a path to a local PDB file" 654 | " must be specified to construct a graph" 655 | ) 656 | 657 | # If no config is provided, use default 658 | if config is None: 659 | config = ProteinGraphConfig() 660 | with Progress(transient=True) as progress: 661 | task1 = progress.add_task("Reading PDB file...", total=1) 662 | # Get name from pdb_file is no pdb_code is provided 663 | # if pdb_path and (pdb_code is None and uniprot_id is None): 664 | # pdb_code = get_protein_name_from_filename(pdb_path) 665 | # pdb_code = pdb_code if len(pdb_code) == 4 else None 666 | progress.advance(task1) 667 | 668 | # If config params are provided, overwrite them 669 | config.protein_df_processing_functions = ( 670 | df_processing_funcs 671 | if config.protein_df_processing_functions is None 672 | else config.protein_df_processing_functions 673 | ) 674 | config.edge_construction_functions = ( 675 | edge_construction_funcs 676 | if config.edge_construction_functions is None 677 | else config.edge_construction_functions 678 | ) 679 | config.node_metadata_functions = ( 680 | node_annotation_funcs 681 | if config.node_metadata_functions is None 682 | else config.node_metadata_functions 683 | ) 684 | config.graph_metadata_functions = ( 685 | graph_annotation_funcs 686 | if config.graph_metadata_functions is None 687 | else config.graph_metadata_functions 688 | ) 689 | config.edge_metadata_functions = ( 690 | edge_annotation_funcs 691 | if config.edge_metadata_functions is None 692 | else config.edge_metadata_functions 693 | ) 694 | 695 | raw_df = read_pdb_to_dataframe( 696 | pdb_path, 697 | pdb_code, 698 | uniprot_id, 699 | model_index=model_index, 700 | ) 701 | 702 | 703 | task2 = progress.add_task("Processing PDB dataframe...", total=1) 
704 | # raw_df = label_node_id(raw_df, granularity=config.granularity) 705 | # raw_df.df["ATOM"] = label_node_id( 706 | # raw_df.df["ATOM"], granularity=config.granularity 707 | # ) 708 | # raw_df.df["HETATM"] = label_node_id( 709 | # raw_df.df["HETATM"], granularity=config.granularity 710 | # ) 711 | raw_df = sort_dataframe(raw_df) 712 | protein_df = process_dataframe( 713 | raw_df, 714 | chain_selection=chain_selection, 715 | granularity=config.granularity, 716 | insertions=config.insertions, 717 | keep_hets=config.keep_hets, 718 | ) 719 | progress.advance(task2) 720 | 721 | task3 = progress.add_task("Initializing graph...", total=1) 722 | # Initialise graph with metadata 723 | g = initialise_graph_with_metadata( 724 | protein_df=protein_df, 725 | raw_pdb_df=raw_df, 726 | name=name, 727 | pdb_code=pdb_code, 728 | pdb_path=pdb_path, 729 | granularity=config.granularity, 730 | ) 731 | # Add nodes to graph 732 | g = add_nodes_to_graph(g) 733 | # Add config to graph 734 | g.graph["config"] = config 735 | g.graph["path"] = g.graph["pdb_path"] 736 | 737 | # Annotate additional node metadata 738 | if config.node_metadata_functions is not None: 739 | g = annotate_node_metadata(g, config.node_metadata_functions) 740 | progress.advance(task3) 741 | task4 = progress.add_task("Constructing edges...", total=1) 742 | # Compute graph edges 743 | g = compute_edges( 744 | g, 745 | funcs=config.edge_construction_functions, 746 | get_contacts_config=None, 747 | ) 748 | progress.advance(task4) 749 | 750 | # Annotate additional graph metadata 751 | # print(g.graph['dssp_df']) 752 | if config.graph_metadata_functions is not None: 753 | g = annotate_graph_metadata(g, config.graph_metadata_functions) 754 | 755 | # Annotate additional edge metadata 756 | if config.edge_metadata_functions is not None: 757 | g = annotate_edge_metadata(g, config.edge_metadata_functions) 758 | 759 | return g 760 | 761 | 762 | def _mp_graph_constructor( 763 | args: Tuple[str, str, int], source: str, config: 
ProteinGraphConfig 764 | ) -> Union[nx.Graph, None]: 765 | """ 766 | Protein graph constructor for use in multiprocessing several protein structure graphs. 767 | 768 | :param args: Tuple of pdb code/path and the chain selection for that PDB. 769 | :type args: Tuple[str, str] 770 | :param use_pdb_code: Whether we are using ``"pdb_code"``s, ``pdb_path``s or ``"uniprot_id"``s. 771 | :type use_pdb_code: bool 772 | :param config: Protein structure graph construction config (see: :class:`graphein.protein.config.ProteinGraphConfig`). 773 | :type config: ProteinGraphConfig 774 | :return: Protein structure graph or ``None`` if an error is encountered. 775 | :rtype: Union[nx.Graph, None] 776 | """ 777 | log.info( 778 | f"Constructing graph for: {args[0]}. Chain selection: {args[1]}. Model index: {args[2]}" 779 | ) 780 | func = partial(construct_graph, config=config) 781 | try: 782 | if source == "pdb_code": 783 | return func( 784 | pdb_code=args[0], chain_selection=args[1], model_index=args[2] 785 | ) 786 | elif source == "pdb_path": 787 | return func( 788 | pdb_path=args[0], chain_selection=args[1], model_index=args[2] 789 | ) 790 | elif source == "uniprot_id": 791 | return func( 792 | uniprot_id=args[0], 793 | chain_selection=args[1], 794 | model_index=args[2], 795 | ) 796 | 797 | except Exception as ex: 798 | log.info( 799 | f"Graph construction error (PDB={args[0]})! 
{traceback.format_exc()}" 800 | ) 801 | log.info(ex) 802 | return None 803 | 804 | 805 | def construct_graphs_mp( 806 | pdb_code_it: Optional[List[str]] = None, 807 | pdb_path_it: Optional[List[str]] = None, 808 | uniprot_id_it: Optional[List[str]] = None, 809 | chain_selections: Optional[List[str]] = None, 810 | model_indices: Optional[List[str]] = None, 811 | config: ProteinGraphConfig = ProteinGraphConfig(), 812 | num_cores: int = 16, 813 | return_dict: bool = True, 814 | out_path: Optional[str] = None, 815 | ) -> Union[List[nx.Graph], Dict[str, nx.Graph]]: 816 | """ 817 | Constructs protein graphs for a list of pdb codes or pdb paths using multiprocessing. 818 | 819 | :param pdb_code_it: List of pdb codes to use for protein graph construction 820 | :type pdb_code_it: Optional[List[str]], defaults to ``None`` 821 | :param pdb_path_it: List of paths to PDB files to use for protein graph construction 822 | :type pdb_path_it: Optional[List[str]], defaults to ``None`` 823 | :param chain_selections: List of chains to select from the protein structures (e.g. ``["ABC", "A", "L", "CD"...]``) 824 | :type chain_selections: Optional[List[str]], defaults to ``None`` 825 | :param model_indices: List of model indices to use for protein graph construction. Only relevant for structures containing ensembles of models. 826 | :type model_indices: Optional[List[str]], defaults to ``None`` 827 | :param config: ProteinGraphConfig to use. 828 | :type config: graphein.protein.config.ProteinGraphConfig, defaults to default config params 829 | :param num_cores: Number of cores to use for multiprocessing. The more the merrier 830 | :type num_cores: int, defaults to ``16`` 831 | :param return_dict: Whether or not to return a dictionary (indexed by pdb codes/paths) or a list of graphs. 832 | :type return_dict: bool, default to ``True`` 833 | :param out_path: Path to save the graphs to. If None, graphs are not saved. 
834 | :type out_path: Optional[str], defaults to ``None`` 835 | :return: Iterable of protein graphs. None values indicate there was a problem in constructing the graph for this particular pdb 836 | :rtype: Union[List[nx.Graph], Dict[str, nx.Graph]] 837 | """ 838 | assert ( 839 | pdb_code_it is not None or pdb_path_it is not None 840 | ), "Iterable of pdb codes, pdb paths or uniprot IDs required." 841 | 842 | if pdb_code_it is not None: 843 | pdbs = pdb_code_it 844 | source = "pdb_code" 845 | 846 | if pdb_path_it is not None: 847 | pdbs = pdb_path_it 848 | source = "pdb_path" 849 | 850 | if uniprot_id_it is not None: 851 | pdbs = uniprot_id_it 852 | source = "uniprot_id" 853 | 854 | if chain_selections is None: 855 | chain_selections = ["all"] * len(pdbs) 856 | 857 | if model_indices is None: 858 | model_indices = [1] * len(pdbs) 859 | 860 | constructor = partial(_mp_graph_constructor, source=source, config=config) 861 | 862 | graphs = list( 863 | process_map( 864 | constructor, 865 | [ 866 | (pdb, chain_selections[i], model_indices[i]) 867 | for i, pdb in enumerate(pdbs) 868 | ], 869 | max_workers=num_cores, 870 | ) 871 | ) 872 | if out_path is not None: 873 | [ 874 | nx.write_gpickle( 875 | g, str(f"{out_path}/" + f"{g.graph['name']}.pickle") 876 | ) 877 | for g in graphs 878 | ] 879 | 880 | if return_dict: 881 | graphs = {pdb: graphs[i] for i, pdb in enumerate(pdbs)} 882 | 883 | return graphs 884 | 885 | 886 | def compute_chain_graph( 887 | g: nx.Graph, 888 | chain_list: Optional[List[str]] = None, 889 | remove_self_loops: bool = False, 890 | return_weighted_graph: bool = False, 891 | ) -> Union[nx.Graph, nx.MultiGraph]: 892 | """Computes a chain-level graph from a protein structure graph. 893 | 894 | This graph features nodes as individual chains in a complex and edges as 895 | the interactions between constituent nodes in each chain. 
You have the 896 | option of returning an unweighted graph (multigraph, 897 | ``return_weighted_graph=False``) or a weighted graph 898 | (``return_weighted_graph=True``). The difference between these is the 899 | unweighted graph features and edge for each interaction between chains 900 | (ie the number of edges will be equal to the number of edges in the input 901 | protein structure graph), while the weighted graph sums these interactions 902 | to a single edge between chains with the counts stored as features. 903 | 904 | :param g: A protein structure graph to compute the chain graph of. 905 | :type g: nx.Graph 906 | :param chain_list: A list of chains to extract from the input graph. 907 | If ``None``, all chains will be used. This is provided as input to 908 | ``extract_subgraph_from_chains``. Default is ``None``. 909 | :type chain_list: Optional[List[str]] 910 | :param remove_self_loops: Whether to remove self-loops from the graph. 911 | Default is False. 912 | :type remove_self_loops: bool 913 | :return: A chain-level graph. 914 | :rtype: Union[nx.Graph, nx.MultiGraph] 915 | """ 916 | # If we are extracting specific chains, do it here. 
917 | if chain_list is not None: 918 | g = extract_subgraph_from_chains(g, chain_list) 919 | 920 | # Initialise new graph with Metadata 921 | h = nx.MultiGraph() 922 | h.graph = g.graph 923 | h.graph["node_type"] = "chain" 924 | 925 | # Set nodes 926 | nodes_per_chain = {chain: 0 for chain in g.graph["chain_ids"]} 927 | sequences = {chain: "" for chain in g.graph["chain_ids"]} 928 | for n, d in g.nodes(data=True): 929 | nodes_per_chain[d["chain_id"]] += 1 930 | sequences[d["chain_id"]] += RESI_THREE_TO_1[d["residue_name"]] 931 | 932 | h.add_nodes_from(g.graph["chain_ids"]) 933 | 934 | for n, d in h.nodes(data=True): 935 | d["num_residues"] = nodes_per_chain[n] 936 | d["sequence"] = sequences[n] 937 | 938 | # Add edges 939 | for u, v, d in g.edges(data=True): 940 | h.add_edge( 941 | g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"] 942 | ) 943 | # Remove self-loops if necessary. Checks for equality between nodes in a given edge. 944 | if remove_self_loops: 945 | edges_to_remove: List[Tuple[str]] = [ 946 | (u, v) for u, v in h.edges() if u == v 947 | ] 948 | h.remove_edges_from(edges_to_remove) 949 | 950 | # Compute a weighted graph if required. 951 | if return_weighted_graph: 952 | return compute_weighted_graph_from_multigraph(h) 953 | return h 954 | 955 | 956 | def compute_weighted_graph_from_multigraph(g: nx.MultiGraph) -> nx.Graph: 957 | """Computes a weighted graph from a multigraph. 958 | 959 | This function is used to convert a multigraph to a weighted graph. The 960 | weights of the edges are the number of interactions between the nodes. 961 | 962 | :param g: A multigraph. 963 | :type g: nx.MultiGraph 964 | :return: A weighted graph. 
965 | :rtype: nx.Graph 966 | """ 967 | H = nx.Graph() 968 | H.graph = g.graph 969 | H.add_nodes_from(g.nodes(data=True)) 970 | for u, v, d in g.edges(data=True): 971 | if H.has_edge(u, v): 972 | H[u][v]["weight"] += len(d["kind"]) 973 | H[u][v]["kind"].update(d["kind"]) 974 | for kind in list(d["kind"]): 975 | try: 976 | H[u][v][kind] += 1 977 | except KeyError: 978 | H[u][v][kind] = 1 979 | else: 980 | H.add_edge(u, v, weight=len(d["kind"]), kind=d["kind"]) 981 | for kind in list(d["kind"]): 982 | H[u][v][kind] = 1 983 | return H 984 | 985 | 986 | def number_groups_of_runs(list_of_values: List[Any]) -> List[str]: 987 | """Numbers groups of runs in a list of values. 988 | 989 | E.g. ``["A", "A", "B", "A", "A", "A", "B", "B"] -> 990 | ["A1", "A1", "B1", "A2", "A2", "A2", "B2", "B2"]`` 991 | 992 | :param list_of_values: List of values to number. 993 | :type list_of_values: List[Any] 994 | :return: List of numbered values. 995 | :rtype: List[str] 996 | """ 997 | df = pd.DataFrame({"val": list_of_values}) 998 | df["idx"] = df["val"].shift() != df["val"] 999 | df["sum"] = df.groupby("val")["idx"].cumsum() 1000 | return list(df["val"].astype(str) + df["sum"].astype(str)) 1001 | 1002 | 1003 | def compute_secondary_structure_graph( 1004 | g: nx.Graph, 1005 | allowable_ss_elements: Optional[List[str]] = None, 1006 | remove_non_ss: bool = True, 1007 | remove_self_loops: bool = False, 1008 | return_weighted_graph: bool = False, 1009 | ) -> Union[nx.Graph, nx.MultiGraph]: 1010 | """Computes a secondary structure graph from a protein structure graph. 1011 | 1012 | :param g: A protein structure graph to compute the secondary structure 1013 | graph of. 1014 | :type g: nx.Graph 1015 | :param remove_non_ss: Whether to remove non-secondary structure nodes from 1016 | the graph. These are denoted as ``"-"`` by DSSP. Default is True. 1017 | :type remove_non_ss: bool 1018 | :param remove_self_loops: Whether to remove self-loops from the graph. 1019 | Default is ``False``. 
1020 | :type remove_self_loops: bool 1021 | :param return_weighted_graph: Whether to return a weighted graph. 1022 | Default is False. 1023 | :type return_weighted_graph: bool 1024 | :raises ProteinGraphConfigurationError: If the protein structure graph is 1025 | not configured correctly with secondary structure assignments on all 1026 | nodes. 1027 | :return: A secondary structure graph. 1028 | :rtype: Union[nx.Graph, nx.MultiGraph] 1029 | """ 1030 | # Initialise list of secondary structure elements we use to build the graph 1031 | ss_list: List[str] = [] 1032 | 1033 | # Check nodes have secondary structure assignment & store them in list 1034 | for _, d in g.nodes(data=True): 1035 | if "ss" not in d.keys(): 1036 | raise ProteinGraphConfigurationError( 1037 | "Secondary structure not defined for all nodes." 1038 | ) 1039 | ss_list.append(d["ss"]) 1040 | 1041 | # Number SS elements 1042 | ss_list = pd.Series(number_groups_of_runs(ss_list)) 1043 | ss_list.index = list(g.nodes()) 1044 | 1045 | # Remove unstructured elements if necessary 1046 | if remove_non_ss: 1047 | ss_list = ss_list[~ss_list.str.contains("-")] 1048 | # Subset to only allowable SS elements if necessary 1049 | if allowable_ss_elements: 1050 | ss_list = ss_list[ 1051 | ss_list.str.contains("|".join(allowable_ss_elements)) 1052 | ] 1053 | 1054 | constituent_residues: Dict[str, List[str]] = ss_list.index.groupby( 1055 | ss_list.values 1056 | ) 1057 | constituent_residues = { 1058 | k: list(v) for k, v in constituent_residues.items() 1059 | } 1060 | residue_counts: Dict[str, int] = ss_list.groupby(ss_list).count().to_dict() 1061 | 1062 | # Add Nodes from secondary structure list 1063 | h = nx.MultiGraph() 1064 | h.add_nodes_from(ss_list) 1065 | nx.set_node_attributes(h, residue_counts, "residue_counts") 1066 | nx.set_node_attributes(h, constituent_residues, "constituent_residues") 1067 | # Assign ss 1068 | for n, d in h.nodes(data=True): 1069 | d["ss"] = n[0] 1070 | 1071 | # Add graph-level metadata 
1072 | h.graph = g.graph 1073 | h.graph["node_type"] = "secondary_structure" 1074 | 1075 | # Iterate over edges in source graph and add SS-SS edges to new graph. 1076 | for u, v, d in g.edges(data=True): 1077 | try: 1078 | h.add_edge( 1079 | ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}" 1080 | ) 1081 | except KeyError as e: 1082 | log.debug( 1083 | f"Edge {u}-{v} not added to secondary structure graph. \ 1084 | Reason: {e} not in graph" 1085 | ) 1086 | 1087 | # Remove self-loops if necessary. 1088 | # Checks for equality between nodes in a given edge. 1089 | if remove_self_loops: 1090 | edges_to_remove: List[Tuple[str]] = [ 1091 | (u, v) for u, v in h.edges() if u == v 1092 | ] 1093 | h.remove_edges_from(edges_to_remove) 1094 | 1095 | # Create weighted graph from h 1096 | if return_weighted_graph: 1097 | return compute_weighted_graph_from_multigraph(h) 1098 | return h 1099 | 1100 | 1101 | def compute_line_graph(g: nx.Graph, repopulate_data: bool = True) -> nx.Graph: 1102 | """Computes the line graph of a graph. 1103 | 1104 | The line graph of a graph G has a node for each edge in G and an edge 1105 | joining those nodes if the two edges in G share a common node. For directed 1106 | graphs, nodes are adjacent exactly when the edges they represent form a 1107 | directed path of length two. 1108 | 1109 | The nodes of the line graph are 2-tuples of nodes in the original graph (or 1110 | 3-tuples for multigraphs, with the key of the edge as the third element). 1111 | 1112 | :param g: Graph to compute the line graph of. 1113 | :type g: nx.Graph 1114 | :param repopulate_data: Whether or not to map node and edge data to edges 1115 | and nodes of the line graph, defaults to True 1116 | :type repopulate_data: bool, optional 1117 | :return: Line graph of g. 
1118 | :rtype: nx.Graph 1119 | """ 1120 | l_g = nx.generators.line_graph(g) 1121 | l_g.graph = g.graph 1122 | 1123 | if repopulate_data: 1124 | source_edge_data = {(u, v): d for u, v, d in g.edges(data=True)} 1125 | nx.set_node_attributes(l_g, source_edge_data) 1126 | 1127 | node_list = {} 1128 | for u, v, d in l_g.edges(data=True): 1129 | node_union = u + v 1130 | for n in node_union: 1131 | if node_union.count(n) > 1: 1132 | node_list[(u, v)] = n 1133 | break 1134 | 1135 | source_node_data = {k: g.nodes[v] for k, v in node_list.items()} 1136 | nx.set_edge_attributes(l_g, source_node_data) 1137 | return l_g 1138 | --------------------------------------------------------------------------------