├── Constants.py ├── Dataset ├── AdjacencyTransform.py ├── Dataset.py ├── __init__.py ├── distanceTransform.py ├── myKnn.py ├── myRadiusGraph.py └── utils.py ├── README.md ├── Sampler ├── ImbalancedDatasetSampler.py └── Smote.py ├── data └── test │ ├── pdbs │ ├── A0A5S9MMK5.pdb.gz │ ├── A3DD66.pdb.gz │ ├── E9Q0S6.pdb.gz │ └── P97084.pdb.gz │ └── test.fasta ├── environment.yml ├── models ├── egnn_clean │ ├── __init__.py │ └── egnn_clean.py └── gnn.py ├── net_utils.py ├── params.py ├── parser.py ├── plots └── protein_structure.py ├── predict.py ├── prep └── utils.py ├── preprocessing ├── create_go.py ├── extract.py ├── generate_msa.py ├── preprocess.py └── utils.py ├── tools ├── hhblits.py ├── jackhmmer.py ├── msa_identifiers.py ├── parsers.py ├── residue_constants.py └── utils.py └── training.py /Constants.py: -------------------------------------------------------------------------------- 1 | residues = { 2 | "A": 1, "C": 2, "D": 3, "E": 4, "F": 5, "G": 6, "H": 7, "I": 8, "K": 9, "L": 10, "M": 11, 3 | "N": 12, "P": 13, "Q": 14, "R": 15, "S": 16, "T": 17, "V": 18, "W": 19, "Y": 20 4 | } 5 | 6 | INVALID_ACIDS = {"U", "O", "B", "Z", "J", "X", "*"} 7 | 8 | amino_acids = { 9 | "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C", "GLN": "Q", "GLU": "E", 10 | "GLY": "G", "HIS": "H", "ILE": "I", "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", 11 | "PRO": "P", "PYL": "O", "SER": "S", "SEC": "U", "THR": "T", "TRP": "W", "TYR": "Y", 12 | "VAL": "V", "ASX": "B", "GLX": "Z", "XAA": "X", "XLE": "J" 13 | } 14 | 15 | root_terms = {"GO:0008150", "GO:0003674", "GO:0005575"} 16 | 17 | exp_evidence_codes = {"EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC"} 18 | exp_evidence_codes = set([ 19 | "EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC", 20 | "HTP", "HDA", "HMP", "HGI", "HEP"]) 21 | 22 | ROOT = "/home/fbqc9/PycharmProjects/TransFunData/data/" 23 | # ROOT = "D:/Workspace/python-3/transfunData/data_bp/" 24 | # ROOT = "/data_bp/pycharm/TransFunData/data_bp/" 25 | 26 | # CAFA4 Targets 27 | CAFA_TARGETS = {"287", "3702", "4577", "6239", "7227", "7955", "9606", "9823", "10090", "10116", "44689", "83333", 28 | "99287", "226900", "243273", "284812", "559292"} 29 | 30 | NAMESPACES = { 31 | "cc": "cellular_component", 32 | "mf": "molecular_function", 33 | "bp": "biological_process" 34 | } 35 | 36 | FUNC_DICT = { 37 | 'cc': 'GO:0005575', 38 | 'mf': 'GO:0003674', 39 | 'bp': 'GO:0008150'} 40 | 41 | BENCH_DICT = { 42 | 'cc': "CCO", 43 | 'mf': 'MFO', 44 | 'bp': 'BPO' 45 | } 46 | 47 | NAMES = { 48 | "cc": "Cellular Component", 49 | "mf": "Molecular Function", 50 | "bp": "Biological Process" 51 | } 52 | 53 | TEST_GROUPS = ["LK_bpo", "LK_mfo", "LK_cco", "NK_bpo", "NK_mfo", "NK_cco"] 54 | 55 | Final_thresholds = { 56 | "cellular_component": 0.50, 57 | "molecular_function": 0.90, 58 | "biological_process": 0.50 59 | } 60 | 61 | TFun_Plus_thresholds = { 62 | "cellular_component": (0.13, 0.87), 63 | "molecular_function": (0.36, 0.61), 64 | "biological_process": (0.38, 0.62) 65 | } 66 | -------------------------------------------------------------------------------- /Dataset/AdjacencyTransform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.transforms import BaseTransform 4 | 5 | 6 | class AdjacencyFeatures(BaseTransform): 7 | r"""Saves the Euclidean distance of linked nodes in its edge attributes. 
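    In this subclass, the edge attribute appended by ``__call__`` is a binary
    adjacency flag: 1 when the two linked residues sit next to each other in the
    sequence (``abs(i - j) == 1``) and 0 otherwise.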
8 | 9 | Args: 10 | norm (bool, optional): If set to :obj:`False`, the output will not be 11 | normalized to the interval :math:`[0, 1]`. (default: :obj:`True`) 12 | max_value (float, optional): If set and :obj:`norm=True`, normalization 13 | will be performed based on this value instead of the maximum value 14 | found in the data. (default: :obj:`None`) 15 | cat (bool, optional): If set to :obj:`False`, all existing edge 16 | attributes will be replaced. (default: :obj:`True`) 17 | """ 18 | 19 | def __init__(self, edge_types, cat=True): 20 | self.cat = cat 21 | self.edge_types = edge_types 22 | 23 | def __call__(self, data): 24 | for edge_type in self.edge_types: 25 | adjacent_edges = [] 26 | (row, col), pseudo = data['atoms', edge_type[1], 'atoms'].edge_index, \ 27 | data['atoms', edge_type[1], 'atoms'].get('edge_attr', None) 28 | 29 | for i, j in zip(row, col): 30 | assert i != j 31 | if abs(i - j) == 1: 32 | adjacent_edges.append(1) 33 | else: 34 | adjacent_edges.append(0) 35 | 36 | adjacent_edges = torch.FloatTensor(adjacent_edges).view(-1, 1) 37 | 38 | if pseudo is not None and self.cat: 39 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo 40 | data['atoms', edge_type[1], 'atoms'].edge_attr = torch.cat([pseudo, adjacent_edges.type_as(pseudo)], 41 | dim=-1) 42 | else: 43 | data['atoms', edge_type[1], 'atoms'].edge_attr = adjacent_edges 44 | 45 | return data 46 | 47 | def __repr__(self) -> str: 48 | return f'{self.__class__.__name__} ' 49 | -------------------------------------------------------------------------------- /Dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | import subprocess 5 | import torch 6 | import os.path as osp 7 | from torch_geometric.data import Dataset, download_url, HeteroData 8 | 9 | import Constants 10 | from Dataset.distanceTransform import myDistanceTransform 11 | from Dataset.myKnn import myKNNGraph 12 | from Dataset.myRadiusGraph import myRadiusGraph 13 | from Dataset.utils import find_files, process_pdbpandas, get_knn, generate_Identity_Matrix 14 | import torch_geometric.transforms as T 15 | from torch_geometric.data import Data 16 | from Dataset.AdjacencyTransform import AdjacencyFeatures 17 | from preprocessing.utils import pickle_load, pickle_save, get_sequence_from_pdb, fasta_to_dictionary, collect_test, \ 18 | read_test_set, read_test, cafa_fasta_to_dictionary 19 | import pandas as pd 20 | import random 21 | 22 | 23 | class PDBDataset(Dataset): 24 | """ 25 | Creates a dataset from a list of PDB files. 
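    Keyword arguments read from ``**kwargs`` include ``seq_id``, ``ont``,
    ``session`` (one of ``train``, ``validation``, ``test`` or ``selected``),
    ``prot_ids``, ``test_file`` and ``pdb_path`` (defaults to ``<root>/alphafold/``).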
26 | :param file_list: path to LMDB file containing dataset 27 | :type file_list: list[Union[str, Path]] 28 | :param transform: transformation function for data_bp augmentation, defaults to None 29 | :type transform: function, optional 30 | """ 31 | 32 | def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, **kwargs): 33 | 34 | self.root = root 35 | self.seq_id = kwargs.get('seq_id', None) 36 | self.ont = kwargs.get('ont', None) 37 | self.session = kwargs.get('session', None) 38 | self.prot_ids = kwargs.get('prot_ids', []) 39 | self.test_file = kwargs.get('test_file', None) 40 | self.pdb_pth = kwargs.get('pdb_path', self.root + "/alphafold/") 41 | 42 | self.raw_file_list = [] 43 | self.processed_file_list = [] 44 | 45 | if self.session == "selected": 46 | self.data = self.prot_ids 47 | for i in self.data: 48 | self.raw_file_list.append('{}'.format(i)) 49 | self.processed_file_list.append('{}.pt'.format(i)) 50 | else: 51 | if self.session == "train": 52 | self.data = pickle_load(self.root + "/{}/{}/{}".format(self.seq_id, self.ont, self.session)) 53 | for i in self.data: 54 | for j in self.data[i]: 55 | self.raw_file_list.append('AF-{}-F1-model_v2.pdb.gz'.format(j)) 56 | self.processed_file_list.append('{}.pt'.format(j)) 57 | elif self.session == "validation": 58 | self.data = list(pickle_load(self.root + "{}/{}".format(self.seq_id, self.session))) 59 | for i in self.data: 60 | self.raw_file_list.append('AF-{}-F1-model_v2.pdb.gz'.format(i)) 61 | self.processed_file_list.append('{}.pt'.format(i)) 62 | elif self.session == "test": 63 | self.data = self.get_test(self.test_file) 64 | for i in self.data: 65 | self.raw_file_list.append('AF-{}-F1-model_v2.pdb.gz'.format(i)) 66 | self.processed_file_list.append('{}.pt'.format(i)) 67 | 68 | super().__init__(self.root, transform, pre_transform, pre_filter) 69 | 70 | @property 71 | def raw_dir(self) -> str: 72 | return self.pdb_pth 73 | 74 | @property 75 | def processed_dir(self) -> str: 76 | return self.root + "/processed/" 77 | 78 | @property 79 | def raw_file_names(self): 80 | return self.raw_file_list 81 | 82 | @property 83 | def processed_file_names(self): 84 | return self.processed_file_list 85 | 86 | def download(self): 87 | rem_files = set(self.raw_file_list) - set(find_files(self.raw_dir, suffix="pdb.gz", type="Name")) 88 | for file in rem_files: 89 | src = "/data_bp/pycharm/TransFunData/data_bp/alphafold/AF-{}-F1-model_v2.pdb.gz" 90 | des = self.root + "/raw/{}".format(file) 91 | if os.path.isfile(src.format(file)): 92 | pass 93 | # subprocess.call('cp {} {}'.format(src.format("pdb", file), des), shell=True) 94 | else: 95 | pass 96 | # download 97 | 98 | def process(self): 99 | rem_files = set(self.processed_file_list) - set(find_files(self.processed_dir, suffix="pt", type="Name")) 100 | print("{} unprocessed proteins out of {}".format(len(rem_files), len(self.processed_file_list))) 101 | chain_id = 'A' 102 | 103 | for file in rem_files: 104 | protein = file.split(".")[0] 105 | print("Processing protein {}".format(protein)) 106 | 107 | raw_path = self.raw_dir + '{}.pdb.gz'.format(protein) 108 | 109 | labels = { 110 | 'molecular_function': [], 111 | 'biological_process': [], 112 | 'cellular_component': [] 113 | } 114 | 115 | emb = torch.load(self.root + "/esm/{}.pt".format(protein)) 116 | embedding_features_per_residue = emb['representations'][33] 117 | embedding_features_per_sequence = emb['mean_representations'][33].view(1, -1) 118 | 119 | if raw_path: 120 | node_coords, sequence_features, sequence_letters = 
process_pdbpandas(raw_path, chain_id) 121 | # else: node_coords, sequence_features, sequence_letters = generate_Identity_Matrix( 122 | # embedding_features_per_residue.shape, self.fasta[protein]) 123 | 124 | assert embedding_features_per_residue.shape[0] == node_coords.shape[0] 125 | assert embedding_features_per_residue.shape[1] == embedding_features_per_sequence.shape[1] 126 | 127 | node_size = node_coords.shape[0] 128 | names = torch.arange(0, node_size, dtype=torch.int8) 129 | 130 | data = HeteroData() 131 | data['atoms'].pos = node_coords 132 | 133 | data['atoms'].molecular_function = torch.IntTensor(labels['molecular_function']) 134 | data['atoms'].biological_process = torch.IntTensor(labels['biological_process']) 135 | data['atoms'].cellular_component = torch.IntTensor(labels['cellular_component']) 136 | 137 | data['atoms'].sequence_features = sequence_features 138 | data['atoms'].embedding_features_per_residue = embedding_features_per_residue 139 | data['atoms'].names = names 140 | data['atoms'].sequence_letters = sequence_letters 141 | data['atoms'].embedding_features_per_sequence = embedding_features_per_sequence 142 | data['atoms'].protein = protein 143 | 144 | if self.pre_filter is not None and not self.pre_filter(data): 145 | continue 146 | 147 | if self.pre_transform is not None: 148 | _transforms = [] 149 | for i in self.pre_transform: 150 | if i[0] == "KNN": 151 | kwargs = {'mode': i[1], 'sequence_length': node_size} 152 | knn = get_knn(**kwargs) 153 | _transforms.append(myKNNGraph(i[1], k=knn, force_undirected=True, )) 154 | if i[0] == "DIST": 155 | _transforms.append(myRadiusGraph(i[1], r=i[2], loop=False)) 156 | _transforms.append(myDistanceTransform(edge_types=self.pre_transform, norm=True)) 157 | _transforms.append(AdjacencyFeatures(edge_types=self.pre_transform)) 158 | 159 | pre_transform = T.Compose(_transforms) 160 | data = pre_transform(data) 161 | 162 | torch.save(data, osp.join(self.root + "/processed/", f'{protein}.pt')) 163 | 164 | def len(self): 165 | return len(self.data) 166 | 167 | def get(self, idx): 168 | if self.session == "train": 169 | rep = random.sample(self.data[idx], 1)[0] 170 | return torch.load(osp.join(self.processed_dir, f'{rep}.pt')) 171 | elif self.session == "validation" or self.session == "selected" or self.session == "test": 172 | rep = self.data[idx] 173 | return torch.load(osp.join(self.processed_dir, f'{rep}.pt')) 174 | 175 | def get_test(self, test_file): 176 | # all_test = set(self.all_test.keys()) 177 | # all_test = set(self.all_test) 178 | # x = list(set(read_test_set("{}supplementary_data/cafa3/benchmark20171115/groundtruth/{}".format(self.root, test_file)))) 179 | # onlystructs_filter = pickle_load("/home/fbqc9/PycharmProjects/TransFun/evaluation/available_structures") 180 | # onlystructs_filter = set([i[0].split(".")[0] for i in onlystructs_filter if i[1] == True]) 181 | # x = [i for i in x if i in onlystructs_filter] 182 | 183 | data = pd.read_csv(Constants.ROOT + "timebased/test_data", sep="\t") 184 | data = data.loc[data['ONTOLOGY'] == self.ont] 185 | missing = set(pickle_load(Constants.ROOT + "timebased/missing_proteins")) 186 | data = list(set(data['ACC'].to_list()).difference(missing)) 187 | 188 | # x = list(pickle_load(test_file)[self.ont]) 189 | return data 190 | 191 | 192 | def load_dataset(root=None, **kwargs): 193 | """ 194 | Load files in file_list into corresponding dataset object. All files should be of type filetype. 
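    A typical call, mirroring the commented example in
    Sampler/ImbalancedDatasetSampler.py, is
    ``load_dataset(root=Constants.ROOT, seq_id=0.95, ont='cellular_component', session='train')``.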
195 | 196 | :param root: path to root 197 | :type file_list: list[Union[str, Path]] 198 | :param raw_path: path to raw path 199 | :type file_list: list[Union[str, Path]] 200 | 201 | :return: Pytorch Dataset containing data_bp 202 | :rtype: torch.utils.data_bp.Dataset 203 | """ 204 | 205 | if root == None: 206 | raise ValueError('Root path is empty, specify root directory') 207 | 208 | # Group; name; operation/cutoff; Description 209 | pre_transform = [("KNN", "sqrt", "sqrt", "K nearest neighbour with sqrt for neighbours"), 210 | ("KNN", "cbrt", "cbrt", "K nearest neighbour with sqrt for neighbours"), 211 | ("DIST", "dist_3", 3, "Distance of 2angs"), 212 | ("DIST", "dist_4", 4, "Distance of 2angs"), 213 | ("DIST", "dist_6", 6, "Distance of 2angs"), 214 | ("DIST", "dist_10", 10, "Distance of 2angs"), 215 | ("DIST", "dist_12", 12, "Distance of 2angs")] 216 | # PDB URL has 1 attached to it 217 | dataset = PDBDataset(root, pre_transform=pre_transform, **kwargs) 218 | return dataset 219 | 220 | 221 | # create raw and processed list. 222 | def generate_dataset(_group="molecular_function"): 223 | # load sequences as dictionary 224 | if _group == "molecular_function": 225 | x = pickle_load('/data_bp/pycharm/TransFunData/data_bp/molecular_function/{}'.format(_group)) 226 | raw = list(set([i for i in x.keys()])) 227 | elif _group == "cellular_component": 228 | pass 229 | elif _group == "biological_process": 230 | pass 231 | if raw: 232 | pickle_save(raw, '/data_bp/pycharm/TransFunData/data_bp/molecular_function/{}' 233 | .format("molecular_function_raw_list")) 234 | -------------------------------------------------------------------------------- /Dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianlin-cheng/TransFun/ab801d838eb8d691831fe177c1925e27cf3eace6/Dataset/__init__.py -------------------------------------------------------------------------------- /Dataset/distanceTransform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.transforms import Distance 4 | 5 | 6 | class myDistanceTransform(Distance): 7 | r""" 8 | 9 | """ 10 | 11 | def __init__(self, edge_types, norm=True, max_value=None, cat=True): 12 | super().__init__(norm, max_value, cat) 13 | self.edge_types = edge_types 14 | 15 | def __call__(self, data): 16 | for i in self.edge_types: 17 | (row, col), pos, pseudo = data['atoms', i[1], 'atoms'].edge_index, \ 18 | data['atoms'].pos, \ 19 | data['atoms', i[1], 'atoms'].get('edge_attr', None) 20 | 21 | dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1) 22 | 23 | if self.norm and dist.numel() > 0: 24 | dist = dist / (dist.max() if self.max is None else self.max) 25 | 26 | if pseudo is not None and self.cat: 27 | pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo 28 | data['atoms', i[1], 'atoms'].edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1) 29 | else: 30 | data['atoms', i[1], 'atoms'].edge_attr = dist 31 | 32 | return data 33 | 34 | def __repr__(self) -> str: 35 | return (f'{self.__class__.__name__}(norm={self.norm}, ' 36 | f'max_value={self.max})') 37 | -------------------------------------------------------------------------------- /Dataset/myKnn.py: -------------------------------------------------------------------------------- 1 | import torch_geometric 2 | from torch_geometric.transforms import KNNGraph 3 | from torch_geometric.utils import to_undirected 4 | 5 | 6 | class 
myKNNGraph(KNNGraph): 7 | r"""Creates a k-NN graph based on node positions :obj:`pos`. 8 | """ 9 | def __init__(self, name: str, k=6, loop=False, force_undirected=False, 10 | flow='source_to_target'): 11 | 12 | super().__init__(k, loop, force_undirected, flow) 13 | self.name = name 14 | 15 | def __call__(self, data): 16 | data['atoms', self.name, 'atoms'].edge_attr = None 17 | batch = data.batch if 'batch' in data else None 18 | edge_index = torch_geometric.nn.knn_graph(data['atoms'].pos, 19 | self.k, batch, 20 | loop=self.loop, 21 | flow=self.flow) 22 | 23 | if self.force_undirected: 24 | edge_index = to_undirected(edge_index, num_nodes=data.num_nodes) 25 | 26 | data['atoms', self.name, 'atoms'].edge_index = edge_index 27 | return data 28 | 29 | def __repr__(self) -> str: 30 | return f'{self.__class__.__name__}(k={self.k})' 31 | -------------------------------------------------------------------------------- /Dataset/myRadiusGraph.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch_geometric 4 | from torch_geometric.transforms import RadiusGraph 5 | 6 | 7 | class myRadiusGraph(RadiusGraph): 8 | r"""Creates edges based on node positions :obj:`pos` to all points within a 9 | given distance. 10 | """ 11 | def __init__( 12 | self, 13 | name: str, 14 | r: float, 15 | loop: bool = False, 16 | max_num_neighbors: int = 32, 17 | flow: str = 'source_to_target', 18 | ): 19 | super().__init__(r, loop, max_num_neighbors, flow) 20 | self.name = name 21 | 22 | def __call__(self, data): 23 | data['atoms', self.name, 'atoms'].edge_attr = None 24 | batch = data.batch if 'batch' in data else None 25 | data['atoms', self.name, 'atoms'].edge_index = torch_geometric.nn.radius_graph( 26 | data['atoms'].pos, 27 | self.r, batch, self.loop, 28 | self.max_num_neighbors, 29 | self.flow) 30 | return data 31 | 32 | def __repr__(self) -> str: 33 | return f'{self.__class__.__name__}(r={self.r})' 34 | -------------------------------------------------------------------------------- /Dataset/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import re 4 | import subprocess 5 | from pathlib import Path 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | from biopandas.pdb import PandasPdb 11 | import torch.nn.functional as F 12 | from keras.utils import to_categorical 13 | from keras_preprocessing.sequence import pad_sequences 14 | 15 | import Constants 16 | from Constants import residues, amino_acids 17 | 18 | 19 | def get_input_data(): 20 | # test data 21 | input = ["1a0b", "1a0c", "1a0d", "1a0e", "1a0f", "1a0g", "1a0h", "1a0i", "1a0j", "1a0l"] 22 | raw = [s + ".pdb" for s in input] 23 | processed = [s + ".pt" for s in input] 24 | 25 | with open('../Dataset/raw.pickle', 'wb') as handle: 26 | pickle.dump(raw, handle, protocol=pickle.HIGHEST_PROTOCOL) 27 | 28 | with open('../Dataset/proceesed.pickle', 'wb') as handle: 29 | pickle.dump(processed, handle, protocol=pickle.HIGHEST_PROTOCOL) 30 | 31 | 32 | patterns = { 33 | 'pdb': r'pdb[0-9]*$', 34 | 'pdb.gz': r'pdb[0-9]*\.gz$', 35 | 'mmcif': r'(mm)?cif$', 36 | 'sdf': r'sdf[0-9]*$', 37 | 'xyz': r'xyz[0-9]*$', 38 | 'xyz-gdb': r'xyz[0-9]*$', 39 | 'silent': r'out$', 40 | 'sharded': r'@[0-9]+', 41 | } 42 | 43 | _regexes = {k: re.compile(v) for k, v in patterns.items()} 44 | 45 | 46 | def is_type(f, filetype): 47 | if filetype in _regexes: 48 | return _regexes[filetype].search(str(f)) 49 | else: 50 | return 
re.compile(filetype + r'$').search(str(f)) 51 | 52 | 53 | def find_files(path, suffix, relative=None, type="Path"): 54 | """ 55 | Find all files in path with given suffix. = 56 | 57 | :param path: Directory in which to find files. 58 | :type path: Union[str, Path] 59 | :param suffix: Suffix determining file type to search for. 60 | :type suffix: str 61 | :param relative: Flag to indicate whether to return absolute or relative path. 62 | 63 | :return: list of paths to all files with suffix sorted by their names. 64 | :rtype: list[str] 65 | """ 66 | if not relative: 67 | find_cmd = r"find {:} -regex '.*\.{:}' | sort".format(path, suffix) 68 | else: 69 | find_cmd = r"cd {:}; find . -regex '.*\.{:}' | cut -d '/' -f 2- | sort" \ 70 | .format(path, suffix) 71 | out = subprocess.Popen( 72 | find_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, 73 | cwd=os.getcwd(), shell=True) 74 | (stdout, stderr) = out.communicate() 75 | name_list = stdout.decode().split() 76 | name_list.sort() 77 | if type == "Path": 78 | return sorted([Path(x) for x in name_list]) 79 | elif type == "Name": 80 | return sorted([Path(x).name for x in name_list]) 81 | 82 | 83 | def process_pdbpandas(raw_path, chain_id): 84 | pdb_to_pandas = PandasPdb().read_pdb(raw_path) 85 | 86 | pdb_df = pdb_to_pandas.df['ATOM'] 87 | assert (len(set(pdb_df['chain_id'])) == 1) & (list(set(pdb_df['chain_id']))[0] == chain_id) 88 | 89 | pdb_df = pdb_df[(pdb_df['atom_name'] == 'CA') & (pdb_df['chain_id'] == chain_id)] 90 | pdb_df = pdb_df.drop_duplicates() 91 | 92 | _residues = pdb_df['residue_name'].to_list() 93 | _residues = [amino_acids[i] for i in _residues if i != "UNK"] 94 | 95 | sequence_features = [[residues[residue] for residue in _residues]] 96 | 97 | sequence_features = pad_sequences(sequence_features, maxlen=1024, truncating='post', padding='post') 98 | 99 | # sequences + padding 100 | sequence_features = torch.tensor(to_categorical(sequence_features, num_classes=len(residues) + 1)) 101 | # sequence_features = F.one_hot(sequence_features, num_classes=len(residues) + 1).to(dtype=torch.int64) 102 | 103 | node_coords = torch.tensor(pdb_df[['x_coord', 'y_coord', 'z_coord']].values, dtype=torch.float32) 104 | 105 | return node_coords, sequence_features, ''.join(_residues) 106 | 107 | return residues 108 | 109 | 110 | def generate_Identity_Matrix(shape, sequence): 111 | 112 | node_coords = torch.from_numpy(np.zeros(shape=(shape[0], 3))) 113 | _residues = sequence[3] 114 | 115 | # _residues = [amino_acids[i] for i in _residues if i != "UNK"] 116 | 117 | sequence_features = [[residues[residue] for residue in list(_residues) if residue not in Constants.INVALID_ACIDS]] 118 | sequence_features = pad_sequences(sequence_features, maxlen=1024, truncating='post', padding='post') 119 | # sequences + padding 120 | sequence_features = torch.tensor(to_categorical(sequence_features, num_classes=len(residues) + 1)) 121 | # sequence_features = F.one_hot(sequence_features, num_classes=len(residues) + 1).to(dtype=torch.int64) 122 | return node_coords, sequence_features, str(_residues) 123 | 124 | 125 | def get_cbrt(a): 126 | return a**(1./3.) 
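# Worked example of the neighbour-count selection implemented by get_knn() just
# below: for a 400-residue chain, mode="sqrt" gives int(math.sqrt(400)) = 20,
# which is even and is bumped to k = 21, while mode="cbrt" gives
# int(get_cbrt(400)) = 7, which is already odd and is returned as-is.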
127 | 128 | 129 | def get_knn(**kwargs): 130 | mode = kwargs["mode"] 131 | seq_length = kwargs["sequence_length"] 132 | if mode == "sqrt": 133 | x = int(math.sqrt(seq_length)) 134 | if x % 2 == 0: 135 | return x + 1 136 | return x 137 | elif mode == "cbrt": 138 | x = int(get_cbrt(seq_length)) 139 | if x % 2 == 0: 140 | return x + 1 141 | return x 142 | else: 143 | return seq_length -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransFun for Protein Function Prediction 2 | TransFun is a method using a transformer-based protein language model and 3D-equivariant graph neural networks (EGNN) to distill information from both protein sequences and structures to predict protein function in terms of Gene Ontology (GO) terms. It extracts feature embeddings from protein sequences using a pre-trained protein language model (ESM) via transfer learning and combines them with 3D structures of proteins predicted by AlphaFold2 through EGNN to predict function. It achieved the state-of-the-art performance on the CAFA3 test dataset and a new test dataset. 3 | 4 | 5 | 6 | ## Installation 7 | ``` 8 | # clone project 9 | git clone https://github.com/jianlin-cheng/TransFun.git 10 | cd TransFun/ 11 | 12 | # download trained models and test sample 13 | curl https://calla.rnet.missouri.edu/rnaminer/transfun/data --output data.zip 14 | unzip data 15 | 16 | # create conda environment 17 | conda env create -f environment.yml 18 | conda activate transfun 19 | ``` 20 | 21 | 22 | ## Training dataset 23 | ``` 24 | https://calla.rnet.missouri.edu/rnaminer/transfun_all_data/ 25 | ``` 26 | 27 | 28 | ## Prediction 29 | 1. To predict protein function with protein structures in the PDB format as input (note: protein sequences are automatically extracted from the PDB files in the input pdb path). 30 | ``` 31 | python predict.py --data-path path_to_store_intermediate_files --ontology GO_function_category --input-type pdb --pdb-path data/alphafold --output output_file --cut-off probability_threshold 32 | ``` 33 | 34 | 2. To predict protein function with protein sequences in the fasta format and protein structures in the PDB format as input: 35 | ``` 36 | python predict.py --data-path path_to_store_intermediate_files --ontology GO_function_category --input-type fasta --pdb-path data/alphafold --fasta-path path_to_a_fasta_file --output result.txt --cut-off probability_threshold 37 | ``` 38 | 39 | 3. Full prediction command: 40 | ``` 41 | Predict protein functions with TransFun 42 | 43 | optional arguments: 44 | -h, --help Help message 45 | --data-path DATA_PATH 46 | Path to store intermediate data files 47 | --ontology ONTOLOGY GO function category: cellular_component, molecular_function, biological_process 48 | --no-cuda NO_CUDA Disables CUDA training 49 | --batch-size BATCH_SIZE 50 | Batch size 51 | --input-type {fasta,pdb} 52 | Input data type: fasta file or PDB files 53 | --fasta-path FASTA_PATH 54 | Path to a fasta containing one or more protein sequences 55 | --pdb-path PDB_PATH Path to the directory of one or more protein structure files in the PDB format 56 | --cut-off CUT_OFF Cut-off probability threshold to report function 57 | --output OUTPUT A file to save output. All the predictions are stored in this file 58 | 59 | ``` 60 | 61 | 4. 
An example of predicting cellular component of some proteins: 62 | ``` 63 | python predict.py --data-path data --ontology cellular_component --input-type pdb --pdb-path test/pdbs/ --output result.txt 64 | ``` 65 | 66 | 5. An example of predicting molecular function of some proteins: 67 | ``` 68 | python predict.py --data-path data --ontology molecular_function --input-type pdb --pdb-path test/pdbs/ --output result.txt 69 | ``` 70 | 71 | ## Reference 72 | ``` 73 | @article{10.1093/bioinformatics/btad208, 74 | author = {Boadu, Frimpong and Cao, Hongyuan and Cheng, Jianlin}, 75 | title = "{Combining protein sequences and structures with transformers and equivariant graph neural networks to predict protein function}", 76 | journal = {Bioinformatics}, 77 | volume = {39}, 78 | number = {Supplement_1}, 79 | pages = {i318-i325}, 80 | year = {2023}, 81 | month = {06}, 82 | issn = {1367-4811}, 83 | doi = {10.1093/bioinformatics/btad208}, 84 | url = {https://doi.org/10.1093/bioinformatics/btad208}, 85 | eprint = {https://academic.oup.com/bioinformatics/article-pdf/39/Supplement\_1/i318/50741489/btad208.pdf}, 86 | } 87 | 88 | ``` 89 | -------------------------------------------------------------------------------- /Sampler/ImbalancedDatasetSampler.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | 4 | import Constants 5 | from preprocessing.utils import class_distribution_counter, pickle_load 6 | 7 | 8 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 9 | """Samples elements randomly from a given list of indices for imbalanced dataset 10 | Arguments: 11 | indices: a list of indices 12 | num_samples: number of samples to draw 13 | callback_get_label: a callback-like function which takes two arguments - dataset and index 14 | """ 15 | 16 | def __init__( 17 | self, 18 | dataset, 19 | labels: list = None, 20 | indices: list = None, 21 | num_samples: int = None, 22 | callback_get_label: Callable = None, 23 | device: str = 'cpu', 24 | **kwargs 25 | ): 26 | 27 | # if indices is not provided, all elements in the dataset will be considered 28 | self.indices = list(range(len(dataset))) if indices is None else indices 29 | 30 | # define custom callback 31 | self.callback_get_label = callback_get_label 32 | 33 | # if num_samples is not provided, draw `len(indices)` samples in each iteration 34 | self.num_samples = len(self.indices) if num_samples is None else num_samples 35 | 36 | # distribution of classes in the dataset 37 | # df["label"] = self._get_labels(dataset) if labels is None else labels 38 | label_to_count = class_distribution_counter(**kwargs) 39 | 40 | go_terms = pickle_load(Constants.ROOT + "/go_terms") 41 | terms = go_terms['GO-terms-{}'.format(kwargs['ont'])] 42 | 43 | class_weights = [label_to_count[i] for i in terms] 44 | total = sum(class_weights) 45 | self.weights = torch.tensor([1.0 / label_to_count[i] for i in terms], 46 | dtype=torch.float).to(device) 47 | 48 | # def _get_labels(self, dataset): 49 | # if self.callback_get_label: 50 | # return self.callback_get_label(dataset) 51 | # elif isinstance(dataset, torch.utils.data_bp.TensorDataset): 52 | # return dataset.tensors[1] 53 | # elif isinstance(dataset, torchvision.datasets.MNIST): 54 | # return dataset.train_labels.tolist() 55 | # elif isinstance(dataset, torchvision.datasets.ImageFolder): 56 | # return [x[1] for x in dataset.imgs] 57 | # elif isinstance(dataset, torchvision.datasets.DatasetFolder): 58 | # return dataset.samples[:][1] 
59 | # elif isinstance(dataset, torch.utils.data_bp.Subset): 60 | # return dataset.dataset.imgs[:][1] 61 | # elif isinstance(dataset, torch.utils.data_bp.Dataset): 62 | # return dataset.get_labels() 63 | # else: 64 | # raise NotImplementedError 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True)) 68 | 69 | def __len__(self): 70 | return self.num_samples 71 | 72 | # 73 | # from torch_geometric.loader import DataLoader 74 | # from Dataset.Dataset import load_dataset 75 | # 76 | # kwargs = { 77 | # 'seq_id': 0.95, 78 | # 'ont': 'cellular_component', 79 | # 'session': 'train' 80 | # } 81 | # 82 | # dataset = load_dataset(root=Constants.ROOT, **kwargs) 83 | # train_dataloader = DataLoader(dataset, 84 | # batch_size=30, 85 | # drop_last=False, 86 | # sampler=ImbalancedDatasetSampler(dataset, **kwargs)) 87 | # 88 | # 89 | # for i in train_dataloader: 90 | # print(i) -------------------------------------------------------------------------------- /Sampler/Smote.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from random import randint 3 | import random 4 | 5 | 6 | class SMOTE(object): 7 | def __init__(self, distance='euclidian', dims=512, k=5): 8 | super(SMOTE, self).__init__() 9 | self.newindex = 0 10 | self.k = k 11 | self.dims = dims 12 | self.distance_measure = distance 13 | 14 | def populate(self, N, i, nnarray, min_samples, k): 15 | while N: 16 | nn = randint(0, k - 2) 17 | 18 | diff = min_samples[nnarray[nn]] - min_samples[i] 19 | gap = random.uniform(0, 1) 20 | 21 | self.synthetic_arr[self.newindex, :] = min_samples[i] + gap * diff 22 | 23 | self.newindex += 1 24 | 25 | N -= 1 26 | 27 | def k_neighbors(self, euclid_distance, k): 28 | nearest_idx = torch.zeros((euclid_distance.shape[0], euclid_distance.shape[0]), dtype=torch.int64) 29 | 30 | idxs = torch.argsort(euclid_distance, dim=1) 31 | nearest_idx[:, :] = idxs 32 | 33 | return nearest_idx[:, 1:k] 34 | 35 | def find_k(self, X, k): 36 | euclid_distance = torch.zeros((X.shape[0], X.shape[0]), dtype=torch.float32) 37 | 38 | for i in range(len(X)): 39 | dif = (X - X[i]) ** 2 40 | dist = torch.sqrt(dif.sum(axis=1)) 41 | euclid_distance[i] = dist 42 | 43 | return self.k_neighbors(euclid_distance, k) 44 | 45 | def generate(self, min_samples, N, k): 46 | """ 47 | Returns (N/100) * n_minority_samples synthetic minority samples. 48 | Parameters 49 | ---------- 50 | min_samples : Numpy_array-like, shape = [n_minority_samples, n_features] 51 | Holds the minority samples 52 | N : percetange of new synthetic samples: 53 | n_synthetic_samples = N/100 * n_minority_samples. Can be < 100. 54 | k : int. Number of nearest neighbours. 55 | Returns 56 | ------- 57 | S : Synthetic samples. array, 58 | shape = [(N/100) * n_minority_samples, n_features]. 
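        For example, calling generate() with N = 300 and 10 minority samples
        returns int(300 / 100) * 10 = 30 synthetic rows, each of width self.dims.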
59 | """ 60 | T = min_samples.shape[0] 61 | self.synthetic_arr = torch.zeros(int(N / 100) * T, self.dims) 62 | N = int(N / 100) 63 | if self.distance_measure == 'euclidian': 64 | indices = self.find_k(min_samples, k) 65 | for i in range(indices.shape[0]): 66 | self.populate(N, i, indices[i], min_samples, k) 67 | self.newindex = 0 68 | return self.synthetic_arr 69 | 70 | def fit_generate(self, X, y): 71 | # get occurence of each class 72 | occ = torch.eye(int(y.max() + 1), int(y.max() + 1))[y].sum(axis=0) 73 | # get the dominant class 74 | dominant_class = torch.argmax(occ) 75 | # get occurence of the dominant class 76 | n_occ = int(occ[dominant_class].item()) 77 | for i in range(len(occ)): 78 | if i != dominant_class: 79 | # calculate the amount of synthetic data_bp to generate 80 | N = (n_occ - occ[i]) * 100 / occ[i] 81 | candidates = X[y == i] 82 | xs = self.generate(candidates, N, self.k) 83 | X = torch.cat((X, xs)) 84 | ys = torch.ones(xs.shape[0]) * i 85 | y = torch.cat((y, ys)) 86 | return X, y -------------------------------------------------------------------------------- /data/test/pdbs/A0A5S9MMK5.pdb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianlin-cheng/TransFun/ab801d838eb8d691831fe177c1925e27cf3eace6/data/test/pdbs/A0A5S9MMK5.pdb.gz -------------------------------------------------------------------------------- /data/test/pdbs/A3DD66.pdb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianlin-cheng/TransFun/ab801d838eb8d691831fe177c1925e27cf3eace6/data/test/pdbs/A3DD66.pdb.gz -------------------------------------------------------------------------------- /data/test/pdbs/E9Q0S6.pdb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianlin-cheng/TransFun/ab801d838eb8d691831fe177c1925e27cf3eace6/data/test/pdbs/E9Q0S6.pdb.gz -------------------------------------------------------------------------------- /data/test/pdbs/P97084.pdb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianlin-cheng/TransFun/ab801d838eb8d691831fe177c1925e27cf3eace6/data/test/pdbs/P97084.pdb.gz -------------------------------------------------------------------------------- /data/test/test.fasta: -------------------------------------------------------------------------------- 1 | >A0A5S9MMK5 2 | MMTTTVQKNCWRLDQTMLGLEKPGSSDISSSSTDTSAISPISVSSMPLSPDKEKKKISFV 3 | RYNPDIPQIVTSFKGYQKLMYQGYRYNIYQIAPERNFKSWRCVCAKKMHDGQWCKCRAET 4 | TMDNKNACTKGSHNHPPRHHVAEIEFIKSQLYSAALENPDHDAGDLVNQASMYLSDGVMF 5 | DNKESIKKSLVVARNKDGKPKKPRSKRMMKFEVDDDDENEYKMPKLETDISCFLPFINNM 6 | VKVEPPFSHTPTIQIPQPIPTPIQHQQQEQSNLLQPATLNGMNNPWMGMEDHLAMIWAAN 7 | AMLNPGLDVLSTIAALSKHQQHVQGPSPQQAATAPTTASLSSNLSVSSFTPQMPKEASIA 8 | IPAPLQVLNLKDLKPLPPLANIQTSPVIQAANLLLPVAALKKDSSTQTTEEIKVSQCLTS 9 | GCGCRVIRICCCDEGVCRRTAAC 10 | >A3DD66 11 | MPPGAKVPQAEIYKTSNLQGAVPTNSWESSILWNQYSLPIYAHPLTFKFKAEGIEVGKPA 12 | LGGSGIAYFGAHKNDFTVGHSSVYTFPDARADKISDFAVDAVMASGSGSIKATLMKGSPY 13 | AYFVFTGGNPRIDFSGTPTVFYGDSGSQCLGVTINGVNYGLFAPSGSKWQGIGTGTITCI 14 | LPAGKNYFSIAVLPDNTVSTLTYYKDYAYCFVTDTKVEWSYNETESTLTTTFTAEVSVKE 15 | GTNKGTILALYPHQWRNNPHILPLPYTYSTLRGIMKTIQGTSFKTVYRYHGILPNLPDKG 16 | TYDREALNRYINELALQADAPVAVDTYWFGKHLGKLSCALPIAEQLGNISAKDRFISFMK 17 | SSLEDWFTAKEGETAKLFYYDSNWGTLIGYPSSYGSDEELNDHHFHYGYFLHAAAQIALR 18 | DPQWASRDNWGAMVELLIKDIANWDRNDTRFPFLRNFDPYEGHSWASGHAGFADGNNQES 19 | 
SSEAINAWQAIILWGEATGNKTIRDLGIYLYTTEVEAVCNYWFDLYKDIFSPSYGHNYAS 20 | MVWGGKYCHEIWWNGTNSEKHGINFLPITAASLYLGKDPNYIKQNYEEMLRECGTSQPPN 21 | WKDIQYMYYALYDPAAAKNMWNESIVPEDGESKAHTYHWICNLDSLGLPDFSVTADTPLY 22 | SVFNKNNIRTYVVYNASSSAKKVTFSDGKVMTVGPHSMAVSTGSESEVLAGDLNGDGKIN 23 | STDISLMKRYLLKQIVDLPVEDDIKAADINKDGKVNSTDMSILKRVILRNYPL 24 | >P97084 25 | MALFNSAHGGNIREAATVLGISPDQLLDFSANINPLGMPVSVKRALIDNLDCIERYPDAD 26 | YFHLHQALARHHQVPASWILAGNGETESIFTVASGLKPRRAMIVTPGFAEYGRALAQSGC 27 | EIRRWSLREADGWQLTDAILEALTPDLDCLFLCTPNNPTGLLPERPLLQAIADRCKSLNI 28 | NLILDEAFIDFIPHETGFIPALKDNPHIWVLRSLTKFYAIPGLRLGYLVNSDDAAMARMR 29 | RQQMPWSVNALAALAGEVALQDSAWQQATWHWLREEGARFYQALCQLPLLTVYPGRANYL 30 | LLRCEREDIDLQRRLLTQRILIRSCANYPGLDSRYYRVAIRSAAQNERLLAALRNVLTGI 31 | APAD 32 | >E9Q0S6 33 | MGCTVSLVCCEALEPLPSCGPQPPGTPPGPARPERCEPGGAAPDPRRRLLLQPEDLEAPK 34 | THHFKVKAFKKVKPCGICRQAITREGCVCKVCSFSCHRKCQAKVAAPCVPPSSHELVPIT 35 | TETVPKNVVDVGEGDCRVGSSPKNLEEGGSMRVSPSIQPQPQSQPTSLSRNTSVSRAMED 36 | SCELDLVYVTERIIAVSFPSTANEENFRSNLREVAQMLKSKHGGNYLLFNLSEQRPDITK 37 | LHAKVLEFGWPDLHTPALEKICSVCKAMDTWLNADPHNVVVLHNKGNRGRIGVVIAAYLH 38 | YSNISASADQALDRFAMKRFYEDKIVPIGQPSQRRYVHYFSGLLSGSIKMNNKPLFLHHV 39 | IMHGIPNFESKGGCRPFLRIYQAMQPVYTSGIYNIPGDSQASICITIEPGLLLKGDILLK 40 | CYHKKFRSPARDVIFRVQFHTCAIHDLGVVFGKEDLDEAFKDDRFPDYGKVEFVFSYGPE 41 | KIQGMEHLENGPSVSVDYNTSDPLIRWDSYDNFSGHREDGMEEVVGHTQGPLDGSLYAKV 42 | KKKDSLNGSSGPVTTARPALSATPNHVEHTLSVSSDSGNSTASTKTDKTDEPVSGATTAP 43 | AALSPQEKKELDRLLSGFGVDREKQGAMYRAQQLRSHPGGGPTVPSPGRHIVPAQVHVNG 44 | GALASERETDILDDELPIQDGQSGGSMGTLSSLDGVTNTSESGYPETLSPLTNGLDKPYS 45 | TEPVLNGGGYPYEAANRVIPVHSSHSAPIRPSYSAQEGLAGYQREGPHPAWSQQVTSAHC 46 | GCDPSGLFRSQSFPDVEPQLPQAPTRGGSSREAVQRGLNSWQQQQPHPPPRQQERSPLQS 47 | LARSKPSPQLSAETPVAALPEFPRAASQQEIEQSIETLNMLMLDLEPASAAAPLHKSQSV 48 | PGAWPGASPLSSQPLLGSSRQSHPLTQSRSGYIPSGHSLGTPELVSSGRPYSPYDYQLHP 49 | AGSNQSFHPKSPASSTFLPSPHSSAGPQEPPASLPGLIAQPQLPPKETTSDPSRTPEEEP 50 | LNLEGLVAHRVAGVQARERQPAEPPGPLRRRAASDGQYENQSPEATSPRSPGVRSPVQCV 51 | SPELALTIALNPGGRPKEPHLHSYKEAFEEMEGTSPSSPPHSVARSPPGLAKTPLSALGL 52 | KPHNPADILLHPTGVARRLIQPEEDEGEEVTKPPEEPRSYVESVARTAVAGPRAQDVEPK 53 | SFSAPAAHAYGHETPLRNGTPGGSFVSPSPLSTSSPILSADSTSVGSFPSVVSSDQGPRT 54 | PFQPMLDSSIRSGSLGQPSPAALSYQSSSPVPVGGSSYNSPDYSLQPFSSSPESQGQPQY 55 | SAASVHMVPGSPQARHRTVGTNTPPSPGFGRRAVNPTMAAPGSPSLSHRQVMGPSGPGFH 56 | GNVVSGHPASAATTPGSPSLGRHPVGSHQVPGLHSSVVTTPGSPSLGRHPGAHQGNLASS 57 | LHSNAVISPGSPSLGRHLGGSGSVVPGSPSLDRHAAYGGYSTPEDRRPTLSRQSSASGYQ 58 | APSTPSFPVSPAYYPGLSSPATSPSPDSAAFRQGSPTPALPEKRRMSVGDRAGSLPNYAT 59 | INGKVSSSPVANGMASGSSTVSFSHTLPDFSKYSMPDNSPETRAKVKFVQDTSKYWYKPE 60 | ISREQAIALLKDQEPGAFIIRDSHSFRGAYGLAMKVSSPPPTITQQGKKGDMTHELVRHF 61 | LIETGPRGVKLKGCPNEPNFGSLSALVYQHSVIPLALPCKLVIPSRDPTDESKDSSGPAN 62 | STTDLLKQGAACNVLFVNSVDMESLTGPQAISKATSETLAADPTPAATIVHFKVSAQGIT 63 | LTDNQRKLFFRRHYPLNTVTFCDLDPQERKWMKTEGGAPAKLFGFVARKQGSTTDNACHL 64 | FAELDPNQPASAIVNFVSKVMLSAGQKR -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: transfun 2 | channels: 3 | - pytorch 4 | - pyg 5 | - bioconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - diamond=0.9.14 10 | - networkx=2.8.4 11 | - numpy=1.23.3 12 | - pandas=1.4.4 13 | - pillow=9.2.0 14 | - pip=22.3.* 15 | - pyg=2.0.4 16 | - python=3.8.* 17 | - pytorch=1.10.2 18 | - scikit-learn=1.1.2 19 | - scipy=1.9.1 20 | - setuptools=65.5.0 21 | - torchaudio=0.10.2 22 | - torchvision=0.11.3 23 | - pip: 24 | - biopandas==0.4.1 25 | - biopython==1.80 26 | - fair-esm==2.0.0 27 | - h5py==3.7.0 28 | - 
keras==2.10.0 29 | - keras-preprocessing==1.1.2 30 | - matplotlib==3.6.2 31 | - obonet==0.3.0 32 | - pip==22.3.1 33 | - tensorflow==2.10.0 34 | - torchviz==0.0.2 35 | - wandb==0.13.4 36 | -------------------------------------------------------------------------------- /models/egnn_clean/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianlin-cheng/TransFun/ab801d838eb8d691831fe177c1925e27cf3eace6/models/egnn_clean/__init__.py -------------------------------------------------------------------------------- /models/egnn_clean/egnn_clean.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import Sigmoid, Linear 4 | from torch_geometric.nn import global_mean_pool 5 | 6 | import net_utils 7 | 8 | 9 | class E_GCL(nn.Module): 10 | """ 11 | E(n) Equivariant Convolutional Layer 12 | re 13 | """ 14 | 15 | def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False): 16 | super(E_GCL, self).__init__() 17 | input_edge = input_nf * 2 18 | self.residual = residual 19 | self.attention = attention 20 | self.normalize = normalize 21 | self.coords_agg = coords_agg 22 | self.tanh = tanh 23 | self.epsilon = 1e-8 24 | edge_coords_nf = 1 25 | 26 | self.edge_mlp = nn.Sequential( 27 | nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf), 28 | act_fn, 29 | nn.Linear(hidden_nf, hidden_nf), 30 | act_fn) 31 | 32 | self.node_mlp = nn.Sequential( 33 | nn.Linear(hidden_nf + input_nf, hidden_nf), 34 | act_fn, 35 | nn.Linear(hidden_nf, output_nf)) 36 | 37 | layer = nn.Linear(hidden_nf, 1, bias=False) 38 | torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) 39 | 40 | coord_mlp = [] 41 | coord_mlp.append(nn.Linear(hidden_nf, hidden_nf)) 42 | coord_mlp.append(act_fn) 43 | coord_mlp.append(layer) 44 | if self.tanh: 45 | coord_mlp.append(nn.Tanh()) 46 | self.coord_mlp = nn.Sequential(*coord_mlp) 47 | 48 | if self.attention: 49 | self.att_mlp = nn.Sequential( 50 | nn.Linear(hidden_nf, 1), 51 | nn.Sigmoid()) 52 | 53 | def edge_model(self, source, target, radial, edge_attr): 54 | if edge_attr is None: # Unused. 
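            # source/target are the node features at each edge's endpoints and
            # radial is their squared distance from coord2radial(); together
            # (optionally with edge_attr) they form the message fed to edge_mlp.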
55 | out = torch.cat([source, target, radial], dim=1) 56 | else: 57 | out = torch.cat([source, target, radial, edge_attr], dim=1) 58 | out = self.edge_mlp(out) 59 | if self.attention: 60 | att_val = self.att_mlp(out) 61 | out = out * att_val 62 | return out 63 | 64 | def node_model(self, x, edge_index, edge_attr, node_attr): 65 | row, col = edge_index 66 | agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0)) 67 | if node_attr is not None: 68 | agg = torch.cat([x, agg, node_attr], dim=1) 69 | else: 70 | agg = torch.cat([x, agg], dim=1) 71 | out = self.node_mlp(agg) 72 | if self.residual: 73 | out = x + out 74 | return out, agg 75 | 76 | def coord_model(self, coord, edge_index, coord_diff, edge_feat): 77 | row, col = edge_index 78 | trans = coord_diff * self.coord_mlp(edge_feat) 79 | if self.coords_agg == 'sum': 80 | agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) 81 | elif self.coords_agg == 'mean': 82 | agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) 83 | else: 84 | raise Exception('Wrong coords_agg parameter' % self.coords_agg) 85 | coord = coord + agg 86 | return coord 87 | 88 | def coord2radial(self, edge_index, coord): 89 | row, col = edge_index 90 | coord_diff = coord[row] - coord[col] 91 | radial = torch.sum(coord_diff**2, 1).unsqueeze(1) 92 | 93 | if self.normalize: 94 | norm = torch.sqrt(radial).detach() + self.epsilon 95 | coord_diff = coord_diff / norm 96 | 97 | return radial, coord_diff 98 | 99 | def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None): 100 | row, col = edge_index 101 | radial, coord_diff = self.coord2radial(edge_index, coord) 102 | 103 | edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) 104 | coord = self.coord_model(coord, edge_index, coord_diff, edge_feat) 105 | h, agg = self.node_model(h, edge_index, edge_feat, node_attr) 106 | 107 | return h, coord, edge_attr 108 | 109 | 110 | class EGNN(nn.Module): 111 | def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False): 112 | ''' 113 | 114 | :param in_node_nf: Number of features for 'h' at the input 115 | :param hidden_nf: Number of hidden features 116 | :param out_node_nf: Number of features for 'h' at the output 117 | :param in_edge_nf: Number of features for the edge features 118 | :param device: Device (e.g. 'cpu', 'cuda:0',...) 119 | :param act_fn: Non-linearity 120 | :param n_layers: Number of layer for the EGNN 121 | :param residual: Use residual connections, we recommend not changing this one 122 | :param attention: Whether using attention or not 123 | :param normalize: Normalizes the coordinates messages such that: 124 | instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij) 125 | we get: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j|| 126 | We noticed it may help in the stability or generalization in some future works. 127 | We didn't use it in our paper. 128 | :param tanh: Sets a tanh activation function at the output of phi_x(m_ij). I.e. it bounds the output of 129 | phi_x(m_ij) which definitely improves in stability but it may decrease in accuracy. 130 | We didn't use it in our paper. 
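        Shapes expected by forward(): h is [n_nodes, in_node_nf], x is [n_nodes, 3],
        edges is a pair of index tensors [rows, cols], and edge_attr is
        [n_edges, in_edge_nf]; see the example under __main__ at the bottom of this file.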
131 | ''' 132 | 133 | super(EGNN, self).__init__() 134 | self.hidden_nf = hidden_nf 135 | self.device = device 136 | self.n_layers = n_layers 137 | self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf) 138 | self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf) 139 | for i in range(0, n_layers): 140 | self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, 141 | act_fn=act_fn, residual=residual, attention=attention, 142 | normalize=normalize, tanh=tanh)) 143 | self.to(self.device) 144 | 145 | def forward(self, h, x, edges, edge_attr): 146 | h = self.embedding_in(h) 147 | for i in range(0, self.n_layers): 148 | h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr) 149 | h = self.embedding_out(h) 150 | 151 | return h, x 152 | 153 | 154 | def unsorted_segment_sum(data, segment_ids, num_segments): 155 | result_shape = (num_segments, data.size(1)) 156 | result = data.new_full(result_shape, 0) # Init empty result tensor. 157 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 158 | result.scatter_add_(0, segment_ids, data) 159 | return result 160 | 161 | 162 | def unsorted_segment_mean(data, segment_ids, num_segments): 163 | result_shape = (num_segments, data.size(1)) 164 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 165 | result = data.new_full(result_shape, 0) # Init empty result tensor. 166 | count = data.new_full(result_shape, 0) 167 | result.scatter_add_(0, segment_ids, data) 168 | count.scatter_add_(0, segment_ids, torch.ones_like(data)) 169 | return result / count.clamp(min=1) 170 | 171 | 172 | def get_edges(n_nodes): 173 | rows, cols = [], [] 174 | for i in range(n_nodes): 175 | for j in range(n_nodes): 176 | if i != j: 177 | rows.append(i) 178 | cols.append(j) 179 | 180 | edges = [rows, cols] 181 | return edges 182 | 183 | 184 | def get_edges_batch(n_nodes, batch_size): 185 | edges = get_edges(n_nodes) 186 | edge_attr = torch.ones(len(edges[0]) * batch_size, 1) 187 | edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])] 188 | if batch_size == 1: 189 | return edges, edge_attr 190 | elif batch_size > 1: 191 | rows, cols = [], [] 192 | for i in range(batch_size): 193 | rows.append(edges[0] + n_nodes * i) 194 | cols.append(edges[1] + n_nodes * i) 195 | edges = [torch.cat(rows), torch.cat(cols)] 196 | return edges, edge_attr 197 | 198 | 199 | if __name__ == "__main__": 200 | # Dummy parameters 201 | batch_size = 8 202 | n_nodes = 4 203 | n_feat = 1 204 | x_dim = 3 205 | 206 | # Dummy variables h, x and fully connected edges 207 | h = torch.ones(batch_size * n_nodes, n_feat) 208 | x = torch.ones(batch_size * n_nodes, x_dim) 209 | edges, edge_attr = get_edges_batch(n_nodes, batch_size) 210 | 211 | # Initialize EGNN 212 | egnn = EGNN(in_node_nf=n_feat, hidden_nf=32, out_node_nf=1, in_edge_nf=1) 213 | 214 | # Run EGNN 215 | h, x = egnn(h, x, edges, edge_attr) 216 | 217 | -------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from torch.nn import Sigmoid 5 | from models.egnn_clean import egnn_clean as eg 6 | import net_utils 7 | 8 | 9 | class GCN(torch.nn.Module): 10 | def __init__(self, **kwargs): 11 | super(GCN, self).__init__() 12 | 13 | input_features_size = kwargs['input_features_size'] 14 | hidden_channels = kwargs['hidden'] 15 | edge_features = kwargs['edge_features'] 16 | num_classes = 
kwargs['num_classes'] 17 | num_egnn_layers = kwargs['egnn_layers'] 18 | 19 | self.edge_type = kwargs['edge_type'] 20 | self.num_layers = kwargs['layers'] 21 | self.device = kwargs['device'] 22 | 23 | self.egnn_1 = eg.EGNN(in_node_nf=input_features_size, 24 | hidden_nf=hidden_channels, 25 | n_layers=num_egnn_layers, 26 | out_node_nf=num_classes, 27 | in_edge_nf=edge_features, 28 | attention=True, 29 | normalize=False, 30 | tanh=True) 31 | 32 | self.egnn_2 = eg.EGNN(in_node_nf=num_classes, 33 | hidden_nf=hidden_channels, 34 | n_layers=num_egnn_layers, 35 | out_node_nf=int(num_classes / 2), 36 | in_edge_nf=edge_features, 37 | attention=True, 38 | normalize=False, 39 | tanh=True) 40 | 41 | self.egnn_3 = eg.EGNN(in_node_nf=input_features_size, 42 | hidden_nf=hidden_channels, 43 | n_layers=num_egnn_layers, 44 | out_node_nf=int(num_classes / 2), 45 | in_edge_nf=edge_features, 46 | attention=True, 47 | normalize=False, 48 | tanh=True) 49 | 50 | self.egnn_4 = eg.EGNN(in_node_nf=int(num_classes / 2), 51 | hidden_nf=hidden_channels, 52 | n_layers=num_egnn_layers, 53 | out_node_nf=int(num_classes / 4), 54 | in_edge_nf=edge_features, 55 | attention=True, 56 | normalize=False, 57 | tanh=True) 58 | 59 | self.fc1 = net_utils.FC(num_classes + int(num_classes / 2) * 2 + int(num_classes / 4), 60 | num_classes + 50, relu=False, bnorm=True) 61 | self.final = net_utils.FC(num_classes + 50, num_classes, relu=False, bnorm=False) 62 | 63 | self.bnrelu1 = net_utils.BNormRelu(num_classes) 64 | self.bnrelu2 = net_utils.BNormRelu(int(num_classes / 2)) 65 | self.bnrelu3 = net_utils.BNormRelu(int(num_classes / 4)) 66 | self.sig = Sigmoid() 67 | 68 | def forward_once(self, data): 69 | x_res, x_emb_seq, edge_index, x_batch, x_pos = data['atoms'].embedding_features_per_residue, \ 70 | data['atoms'].embedding_features_per_sequence, \ 71 | data[self.edge_type].edge_index, \ 72 | data['atoms'].batch, \ 73 | data['atoms'].pos 74 | 75 | ppi_shape = x_emb_seq.shape[0] 76 | 77 | if ppi_shape > 1: 78 | edge_index_2 = list(zip(*list(itertools.combinations(range(ppi_shape), 2)))) 79 | edge_index_2 = [torch.LongTensor(edge_index_2[0]).to(self.device), 80 | torch.LongTensor(edge_index_2[1]).to(self.device)] 81 | else: 82 | edge_index_2 = tuple(range(ppi_shape)) 83 | edge_index_2 = [torch.LongTensor(edge_index_2).to(self.device), 84 | torch.LongTensor(edge_index_2).to(self.device)] 85 | 86 | output_res, pre_pos_res = self.egnn_1(h=x_res, 87 | x=x_pos.float(), 88 | edges=edge_index, 89 | edge_attr=None) 90 | 91 | output_res_2, pre_pos_res_2 = self.egnn_2(h=output_res, 92 | x=pre_pos_res.float(), 93 | edges=edge_index, 94 | edge_attr=None) 95 | 96 | output_seq, pre_pos_seq = self.egnn_3(h=x_emb_seq, 97 | x=net_utils.get_pool(pool_type='mean')(x_pos.float(), x_batch), 98 | edges=edge_index_2, 99 | edge_attr=None) 100 | 101 | output_res_4, pre_pos_seq_4 = self.egnn_4(h=output_res_2, 102 | x=pre_pos_res_2.float(), 103 | edges=edge_index, 104 | edge_attr=None) 105 | 106 | output_res = net_utils.get_pool(pool_type='mean')(output_res, x_batch) 107 | output_res = self.bnrelu1(output_res) 108 | 109 | output_res_2 = net_utils.get_pool(pool_type='mean')(output_res_2, x_batch) 110 | output_res_2 = self.bnrelu2(output_res_2) 111 | 112 | output_seq = self.bnrelu2(output_seq) 113 | 114 | output_res_4 = net_utils.get_pool(pool_type='mean')(output_res_4, x_batch) 115 | output_res_4 = self.bnrelu3(output_res_4) 116 | 117 | output = torch.cat([output_res, output_res_2, output_seq, output_res_4], 1) 118 | 119 | return output 120 | 121 | def forward(self, 
data): 122 | passes = [] 123 | 124 | for i in range(self.num_layers): 125 | passes.append(self.forward_once(data)) 126 | 127 | x = torch.cat(passes, 1) 128 | 129 | x = self.fc1(x) 130 | x = self.final(x) 131 | x = self.sig(x) 132 | 133 | return x 134 | -------------------------------------------------------------------------------- /net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import GCNConv, BatchNorm, global_add_pool, global_mean_pool, global_max_pool 4 | 5 | 6 | class GCN(nn.Module): 7 | def __init__(self, input_features, out_channels, relu=True): 8 | super(GCN, self).__init__() 9 | self.conv = GCNConv(input_features, out_channels) 10 | self.relu = nn.LeakyReLU(0.1, inplace=True) if relu else None 11 | 12 | def forward(self, x): 13 | edge_index = x[1] 14 | x = self.conv(x[0], edge_index) 15 | if self.relu is not None: 16 | x = self.relu(x) 17 | return (x, edge_index) 18 | 19 | 20 | class GCN_BatchNorm(nn.Module): 21 | def __init__(self, in_channels, out_channels, relu=True): 22 | super(GCN_BatchNorm, self).__init__() 23 | 24 | self.conv = GCNConv(in_channels, out_channels, bias=False) 25 | self.bn = BatchNorm(out_channels, momentum=0.1) 26 | self.relu = nn.LeakyReLU(0.1, inplace=True) if relu else None 27 | 28 | def forward(self, x): 29 | edge_index = x[1] 30 | x = self.conv(x[0], edge_index) 31 | if self.relu is not None: 32 | x = self.relu(x) 33 | x = self.bn(x) 34 | return x 35 | 36 | 37 | class FC(nn.Module): 38 | def __init__(self, in_features, out_features, relu=True, bnorm=True): 39 | super(FC, self).__init__() 40 | _bias = False if bnorm else True 41 | self.fc = nn.Linear(in_features, out_features, bias=_bias) 42 | self.relu = nn.ReLU(inplace=True) if relu else None 43 | #self.bn = BatchNorm(out_features, momentum=0.1) if bnorm else None 44 | self.bn = nn.BatchNorm1d(out_features, momentum=0.1) if bnorm else None 45 | 46 | def forward(self, x): 47 | x = self.fc(x) 48 | if self.bn is not None: 49 | x = self.bn(x) 50 | if self.relu is not None: 51 | x = self.relu(x) 52 | return x 53 | 54 | 55 | class BNormRelu(nn.Module): 56 | def __init__(self, in_features, relu=True, bnorm=True): 57 | super(BNormRelu, self).__init__() 58 | self.relu = nn.ReLU(inplace=True) if relu else None 59 | self.bn = BatchNorm(in_features, momentum=0.1) if bnorm else None 60 | # self.bn = nn.BatchNorm1d(out_features, momentum=0.1) if bnorm else None 61 | 62 | def forward(self, x): 63 | if self.bn is not None: 64 | x = self.bn(x) 65 | if self.relu is not None: 66 | x = self.relu(x) 67 | return x 68 | 69 | def get_pool(pool_type='max'): 70 | if pool_type == 'mean': 71 | return global_mean_pool 72 | elif pool_type == 'add': 73 | return global_add_pool 74 | elif pool_type == 'max': 75 | return global_max_pool 76 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | bio_kwargs = { 2 | 'hidden': 16, 3 | 'input_features_size': 1280, 4 | 'num_classes': 3774, 5 | 'edge_type': 'cbrt', 6 | 'edge_features': 0, 7 | 'egnn_layers': 12, 8 | 'layers': 1, 9 | 'device': 'cuda', 10 | 'wd': 5e-4 11 | } 12 | 13 | mol_kwargs = { 14 | 'hidden': 16, 15 | 'input_features_size': 1280, 16 | 'num_classes': 600, 17 | 'edge_type': 'cbrt', 18 | 'edge_features': 0, 19 | 'egnn_layers': 12, 20 | 'layers': 1, 21 | 'device': 'cuda', 22 | 'wd': 0.001 23 | } 24 | 25 | cc_kwargs = { 26 | 'hidden': 16, 27 
| 'input_features_size': 1280, 28 | 'num_classes': 547, 29 | 'edge_type': 'cbrt', 30 | 'edge_features': 0, 31 | 'egnn_layers': 12, 32 | 'layers': 1, 33 | 'device': 'cuda', 34 | 'wd': 0.001 #5e-4 35 | } 36 | 37 | edge_types = set(['sqrt', 'cbrt', 'dist_3', 'dist_4', 'dist_6', 'dist_10', 'dist_12', 38 | 'molecular_function', 'biological_process', 'cellular_component', 39 | 'all', 'names', 'sequence_letters', 'ptr', 'sequence_features']) 40 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import argparse 3 | import os 4 | import torch 5 | 6 | warnings.filterwarnings("ignore", category=UserWarning) 7 | 8 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 12 | parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during training pass.') 13 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 14 | parser.add_argument('--epochs', type=int, default=50, help='Number of epochs to train.') 15 | parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.') 16 | parser.add_argument('--weight_decay', type=float, default=1e-16, help='Weight decay (L2 loss on parameters).') 17 | parser.add_argument('--train_batch', type=int, default=32, help='Training batch size.') 18 | parser.add_argument('--valid_batch', type=int, default=32, help='Validation batch size.') 19 | parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).') 20 | parser.add_argument('--seq', type=float, default=0.5, help='Sequence Identity (Sequence Identity).') 21 | parser.add_argument("--ont", default='biological_process', type=str, help='Ontology under consideration') 22 | 23 | args = parser.parse_args() 24 | args.cuda = not args.no_cuda and torch.cuda.is_available() 25 | 26 | 27 | def get_parser(): 28 | return args 29 | -------------------------------------------------------------------------------- /plots/protein_structure.py: -------------------------------------------------------------------------------- 1 | # import networkx as nx 2 | # import numpy as np 3 | # import matplotlib.pyplot as plt 4 | # from mpl_toolkits.mplot3d import Axes3D 5 | # from biopandas.pdb import PandasPdb 6 | import random 7 | 8 | import networkx as nx 9 | 10 | import Constants 11 | from Dataset.Dataset import load_dataset 12 | import matplotlib.pyplot as plt 13 | 14 | kwargs = { 15 | 'prot_ids': ['P83847', ], 16 | 'session': 'selected' 17 | } 18 | 19 | dataset = load_dataset(root=Constants.ROOT, **kwargs) 20 | protein = dataset[0] 21 | print(protein) 22 | node_coords = protein.pos 23 | edges = protein.edge_index 24 | 25 | 26 | exit() 27 | 28 | # x, y, z = node_coords[:, 0], node_coords[:, 1], node_coords[:, 2] 29 | # # print(x.shape) 30 | # # print(y.shape) 31 | # # print(z.shape) 32 | # 33 | # fig = plt.figure() 34 | # ax = plt.axes(projection='3d') 35 | # 36 | # ax.scatter3D(x, y, z, 'gray') 37 | # # ax.plot3D(x, y, z, 'gray') 38 | # 39 | # fig.tight_layout() 40 | # plt.show() 41 | 42 | 43 | 44 | 45 | # import networkx as nx 46 | # import numpy as np 47 | # import matplotlib.pyplot as plt 48 | # from mpl_toolkits.mplot3d import Axes3D 49 | # from biopandas.pdb import PandasPdb 50 | import networkx as nx 51 | import numpy as np 52 | 53 | import Constants 
54 | from Dataset.Dataset import load_dataset 55 | import matplotlib.pyplot as plt 56 | 57 | kwargs = { 58 | 'prot_ids': ['A0A023FBW4', ], 59 | 'session': 'selected' 60 | } 61 | 62 | dataset = load_dataset(root=Constants.ROOT, **kwargs) 63 | protein = dataset[0] 64 | 65 | print(protein) 66 | 67 | 68 | 69 | exit() 70 | 71 | 72 | def plot_residues(protein, add_edges=False, limit=25): 73 | node_coords = protein.pos.numpy() 74 | 75 | limit = len(node_coords) 76 | 77 | if add_edges: 78 | indicies = random.sample(range(0, limit), 25) 79 | else: 80 | indicies = random.sample(range(0, limit), limit) 81 | 82 | edges = protein.edge_index.numpy() 83 | some_edges = [] 84 | edges = [i for i in zip(edges[0], edges[1])] 85 | for i, j in edges: 86 | if i in indicies and j in indicies: 87 | some_edges.append(([node_coords[i][0], node_coords[j][0]], 88 | [node_coords[i][1], node_coords[j][1]], 89 | [node_coords[i][2], node_coords[j][2]])) 90 | 91 | node_coords = np.array([node_coords[i] for i in indicies]) 92 | 93 | x, y, z = node_coords[:, 0], node_coords[:, 1], node_coords[:, 2] 94 | # # print(x.shape) 95 | # # print(y.shape) 96 | # # print(z.shape) 97 | 98 | fig = plt.figure() 99 | ax = plt.axes(projection='3d') 100 | 101 | ax.scatter3D(x, y, z) 102 | 103 | if add_edges: 104 | for x in some_edges: 105 | ax.plot3D(x[0], x[1], x[2])# , 'gray') 106 | 107 | ax.set_xlabel('X axis') 108 | ax.set_ylabel('Y axis') 109 | ax.set_zlabel('Z axis') 110 | 111 | # fig.tight_layout() 112 | plt.title("Some protein") 113 | 114 | plt.show() 115 | 116 | 117 | plot_residues(protein) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import networkx as nx 5 | import obonet 6 | import torch 7 | from Bio import SeqIO 8 | from torch import optim 9 | from torch_geometric.loader import DataLoader 10 | import Constants 11 | import params 12 | from Dataset.Dataset import load_dataset 13 | from models.gnn import GCN 14 | from preprocessing.utils import load_ckp, get_sequence_from_pdb, create_seqrecord, get_proteins_from_fasta, \ 15 | generate_bulk_embedding, pickle_load, fasta_to_dictionary 16 | 17 | parser = argparse.ArgumentParser(description=" Predict protein functions with TransFun ", epilog=" Thank you !!!") 18 | parser.add_argument('--data-path', type=str, default="data", help="Path to data files") 19 | parser.add_argument('--ontology', type=str, default="cellular_component", help="Path to data files") 20 | parser.add_argument('--no-cuda', default=False, help='Disables CUDA training.') 21 | parser.add_argument('--batch-size', default=10, help='Batch size.') 22 | parser.add_argument('--input-type', choices=['fasta', 'pdb'], default="fasta", 23 | help='Input Data: fasta file or PDB files') 24 | parser.add_argument('--fasta-path', default="sequence.fasta", help='Path to Fasta') 25 | parser.add_argument('--pdb-path', default="alphafold", help='Path to directory of PDBs') 26 | parser.add_argument('--cut-off', type=float, default=0.0, help="Cut of to report function") 27 | parser.add_argument('--output', type=str, default="output", help="File to save output") 28 | # parser.add_argument('--add-ancestors', default=False, help="Add ancestor terms to prediction") 29 | 30 | args = parser.parse_args() 31 | args.cuda = not args.no_cuda and torch.cuda.is_available() 32 | 33 | if args.cuda: 34 | device = 'cuda' 35 | else: 36 | device = 'cpu' 37 | 38 | if args.ontology == 
'molecular_function': 39 | ont_kwargs = params.mol_kwargs 40 | elif args.ontology == 'cellular_component': 41 | ont_kwargs = params.cc_kwargs 42 | elif args.ontology == 'biological_process': 43 | ont_kwargs = params.bio_kwargs 44 | ont_kwargs['device'] = device 45 | 46 | FUNC_DICT = { 47 | 'cellular_component': 'GO:0005575', 48 | 'molecular_function': 'GO:0003674', 49 | 'biological_process': 'GO:0008150' 50 | } 51 | 52 | print("Predicting proteins") 53 | 54 | def create_fasta(proteins): 55 | fasta = [] 56 | for protein in proteins: 57 | alpha_fold_seq = get_sequence_from_pdb("{}/{}/{}.pdb.gz".format(args.data_path, args.pdb_path, protein), "A") 58 | fasta.append(create_seqrecord(id=protein, seq=alpha_fold_seq)) 59 | SeqIO.write(fasta, "{}/sequence.fasta".format(args.data_path), "fasta") 60 | args.fasta_path = "{}/sequence.fasta".format(args.data_path) 61 | 62 | 63 | def write_to_file(data, output): 64 | with open('{}'.format(output), 'w') as fp: 65 | for protein, go_terms in data.items(): 66 | for go_term, score in go_terms.items(): 67 | fp.write('%s %s %s\n' % (protein, go_term, score)) 68 | 69 | 70 | def generate_embeddings(fasta_path): 71 | def merge_pts(keys, fasta): 72 | embeddings = [0, 32, 33] 73 | for protein in keys: 74 | fasta_dic = fasta_to_dictionary(fasta) 75 | tmp = [] 76 | for level in range(keys[protein]): 77 | os_path = "{}/esm/{}_{}.pt".format(args.data_path, protein, level) 78 | tmp.append(torch.load(os_path)) 79 | 80 | data = {'representations': {}, 'mean_representations': {}} 81 | for index in tmp: 82 | for rep in embeddings: 83 | assert torch.equal(index['mean_representations'][rep], torch.mean(index['representations'][rep], dim=0)) 84 | 85 | if rep in data['representations']: 86 | data['representations'][rep] = torch.cat((data['representations'][rep], index['representations'][rep])) 87 | else: 88 | data['representations'][rep] = index['representations'][rep] 89 | 90 | for emb in embeddings: 91 | assert len(fasta_dic[protein][3]) == data['representations'][emb].shape[0] 92 | 93 | for rep in embeddings: 94 | data['mean_representations'][rep] = torch.mean(data['representations'][rep], dim=0) 95 | 96 | # print("saving {}".format(protein)) 97 | torch.save(data, "{}/esm/{}.pt".format(args.data_path, protein)) 98 | 99 | def crop_fasta(record): 100 | splits = [] 101 | keys = {} 102 | main_id = record.id 103 | keys[main_id] = int(len(record.seq) / 1021) + 1 104 | for pos in range(int(len(record.seq) / 1021) + 1): 105 | id = "{}_{}".format(main_id, pos) 106 | seq = str(record.seq[pos * 1021:(pos * 1021) + 1021]) 107 | splits.append(create_seqrecord(id=id, name=id, description="", seq=seq)) 108 | return splits, keys 109 | 110 | keys = {} 111 | sequences = [] 112 | input_seq_iterator = SeqIO.parse(fasta_path, "fasta") 113 | for record in input_seq_iterator: 114 | if len(record.seq) > 1021: 115 | _seqs, _keys = crop_fasta(record) 116 | sequences.extend(_seqs) 117 | keys.update(_keys) 118 | else: 119 | sequences.append(record) 120 | 121 | cropped_fasta = "{}/sequence_cropped.fasta".format(args.data_path) 122 | SeqIO.write(sequences, cropped_fasta, "fasta") 123 | 124 | generate_bulk_embedding("./preprocessing/extract.py", "{}".format(cropped_fasta), 125 | "{}/esm".format(args.data_path)) 126 | 127 | # merge 128 | if len(keys) > 0: 129 | print("Found {} protein with length > 1021".format(len(keys))) 130 | merge_pts(keys, fasta_path) 131 | 132 | 133 | if args.input_type == 'fasta': 134 | if not args.fasta_path is None: 135 | proteins = 
set(get_proteins_from_fasta("{}/{}".format(args.data_path, args.fasta_path))) 136 | pdbs = set([i.split(".")[0] for i in os.listdir("{}/{}".format(args.data_path, args.pdb_path))]) 137 | proteins = list(pdbs.intersection(proteins)) 138 | elif args.input_type == 'pdb': 139 | if args.pdb_path is not None: 140 | pdb_path = "{}/{}".format(args.data_path, args.pdb_path) 141 | if os.path.exists(pdb_path): 142 | proteins = os.listdir(pdb_path) 143 | proteins = [protein.split('.')[0] for protein in proteins if protein.endswith(".pdb.gz")] 144 | if len(proteins) == 0: 145 | print("No proteins found in {}".format(pdb_path)) 146 | exit() 147 | create_fasta(proteins) 148 | else: 149 | print("PDB directory not found -- {}".format(pdb_path)) 150 | exit() 151 | 152 | 153 | if len(proteins) > 0: 154 | print("Predicting for {} proteins".format(len(proteins))) 155 | 156 | print("Generating Embeddings from {}".format(args.fasta_path)) 157 | os.makedirs("{}/esm".format(args.data_path), exist_ok=True) 158 | generate_embeddings(args.fasta_path) 159 | 160 | kwargs = { 161 | 'seq_id': Constants.Final_thresholds[args.ontology], 162 | 'ont': args.ontology, 163 | 'session': 'selected', 164 | 'prot_ids': proteins, 165 | 'pdb_path': "{}/{}".format(args.data_path, args.pdb_path) 166 | } 167 | 168 | dataset = load_dataset(root=args.data_path, **kwargs) 169 | 170 | test_dataloader = DataLoader(dataset, 171 | batch_size=args.batch_size, 172 | drop_last=False, 173 | shuffle=False) 174 | 175 | # model 176 | model = GCN(**ont_kwargs) 177 | model.to(device) 178 | 179 | optimizer = optim.Adam(model.parameters()) 180 | 181 | ckp_pth = "{}/{}.pt".format(args.data_path, args.ontology) 182 | print(ckp_pth) 183 | # load the saved checkpoint 184 | if os.path.exists(ckp_pth): 185 | model, optimizer, current_epoch, min_val_loss = load_ckp(ckp_pth, model, optimizer, device) 186 | else: 187 | print("Model not found. 
Skipping...") 188 | exit() 189 | 190 | model.eval() 191 | 192 | scores = [] 193 | proteins = [] 194 | 195 | for data in test_dataloader: 196 | with torch.no_grad(): 197 | proteins.extend(data['atoms'].protein) 198 | scores.extend(model(data.to(device)).tolist()) 199 | 200 | 201 | assert len(proteins) == len(scores) 202 | 203 | goterms = pickle_load('{}/go_terms'.format(args.data_path))[f'GO-terms-{args.ontology}'] 204 | go_graph = obonet.read_obo(open("{}/go-basic.obo".format(args.data_path), 'r')) 205 | go_set = nx.ancestors(go_graph, FUNC_DICT[args.ontology]) 206 | 207 | 208 | results = {} 209 | for protein, score in zip(proteins, scores): 210 | protein_scores = {} 211 | 212 | for go_term, _score in zip(goterms, score): 213 | if _score > args.cut_off: 214 | protein_scores[go_term] = max(protein_scores.get(go_term, 0), _score) 215 | 216 | 217 | for go_term, max_score in list(protein_scores.items()): 218 | descendants = nx.descendants(go_graph, go_term).intersection(go_set) 219 | for descendant in descendants: 220 | protein_scores[descendant] = max(protein_scores.get(descendant, 0), max_score) 221 | 222 | results[protein] = protein_scores 223 | 224 | 225 | print("Writing output to {}".format(args.output)) 226 | write_to_file(results, "{}/{}".format(args.data_path, args.output)) 227 | -------------------------------------------------------------------------------- /prep/utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque, Counter 2 | import warnings 3 | import pandas as pd 4 | import numpy as np 5 | import math 6 | 7 | BIOLOGICAL_PROCESS = 'GO:0008150' 8 | MOLECULAR_FUNCTION = 'GO:0003674' 9 | CELLULAR_COMPONENT = 'GO:0005575' 10 | FUNC_DICT = { 11 | 'cc': CELLULAR_COMPONENT, 12 | 'mf': MOLECULAR_FUNCTION, 13 | 'bp': BIOLOGICAL_PROCESS 14 | } 15 | 16 | NAMESPACES = { 17 | 'cc': 'cellular_component', 18 | 'mf': 'molecular_function', 19 | 'bp': 'biological_process' 20 | } 21 | 22 | EXP_CODES = {'EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC', 'HTP', 'HDA', 'HMP', 'HGI', 'HEP'} 23 | 24 | def read_fasta(filename): 25 | seqs = list() 26 | info = list() 27 | seq = '' 28 | inf = '' 29 | with open(filename, 'r') as f: 30 | for line in f: 31 | line = line.strip() 32 | if line.startswith('>'): 33 | if seq != '': 34 | seqs.append(seq) 35 | info.append(inf) 36 | seq = '' 37 | inf = line[1:] 38 | else: 39 | seq += line 40 | seqs.append(seq) 41 | info.append(inf) 42 | return info, seqs 43 | 44 | 45 | class DataGenerator(object): 46 | 47 | def __init__(self, batch_size, is_sparse=False): 48 | self.batch_size = batch_size 49 | self.is_sparse = is_sparse 50 | 51 | def fit(self, inputs, targets=None): 52 | self.start = 0 53 | self.inputs = inputs 54 | self.targets = targets 55 | if isinstance(self.inputs, tuple) or isinstance(self.inputs, list): 56 | self.size = self.inputs[0].shape[0] 57 | else: 58 | self.size = self.inputs.shape[0] 59 | self.has_targets = targets is not None 60 | 61 | def __next__(self): 62 | return self.next() 63 | 64 | def reset(self): 65 | self.start = 0 66 | 67 | def next(self): 68 | if self.start < self.size: 69 | batch_index = np.arange( 70 | self.start, min(self.size, self.start + self.batch_size)) 71 | if isinstance(self.inputs, tuple) or isinstance(self.inputs, list): 72 | res_inputs = [] 73 | for inp in self.inputs: 74 | if self.is_sparse: 75 | res_inputs.append( 76 | inp[batch_index, :].toarray()) 77 | else: 78 | res_inputs.append(inp[batch_index, :]) 79 | else: 80 | if self.is_sparse: 81 | res_inputs 
= self.inputs[batch_index, :].toarray() 82 | else: 83 | res_inputs = self.inputs[batch_index, :] 84 | self.start += self.batch_size 85 | if self.has_targets: 86 | if self.is_sparse: 87 | labels = self.targets[batch_index, :].toarray() 88 | else: 89 | # 90 | # x = pd.read_csv("/data/pycharm/TransFun/nrPDB-GO_2019.06.18_annot.tsv", sep='\t', skiprows=12) 91 | # go_terms = set() 92 | # mf = x['GO-terms (cellular_component)'].to_list() 93 | # for i in mf: 94 | # if isinstance(i, str): 95 | # go_terms.update(i.split(',')) 96 | # print(len(go_terms)) 97 | # 98 | # xx = set(pickle_load(Constants.ROOT + "cellular_component/train_stats")) 99 | # print(len(xx)) 100 | # 101 | # print(len(xx.intersection(go_terms))) 102 | 103 | # bp = x['GO-terms (biological_process)'] 104 | # for i in mf: 105 | # go_terms.update(i.split(',')) 106 | # cc = x['GO-terms (cellular_component)'] 107 | # for i in mf: 108 | # go_terms.update(i.split(',')) 109 | 110 | labels = self.targets[batch_index, :] 111 | return res_inputs, labels 112 | return res_inputs 113 | else: 114 | self.reset() 115 | return self.next() 116 | -------------------------------------------------------------------------------- /preprocessing/create_go.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import subprocess 4 | import networkx as nx 5 | import numpy as np 6 | import obonet 7 | import pandas as pd 8 | from Bio.Seq import Seq 9 | from Bio import SeqIO, SwissProt 10 | from Bio.SeqRecord import SeqRecord 11 | 12 | import Constants 13 | from preprocessing.utils import pickle_save, pickle_load, get_sequence_from_pdb, fasta_for_msas, \ 14 | count_proteins_biopython, count_proteins, fasta_for_esm, fasta_to_dictionary, read_dictionary, \ 15 | get_proteins_from_fasta, create_seqrecord, read_test_set, alpha_seq_fasta_to_dictionary, collect_test, is_ok, \ 16 | test_annotation, get_test_classes 17 | 18 | 19 | def extract_id(header): 20 | return header.split('|')[1] 21 | 22 | 23 | def compare_sequence(uniprot_fasta_file, save=False): 24 | """ 25 | Script is used to compare sequences of uniprot & alpha fold. 26 | :param uniprot_fasta_file: input uniprot fasta file. 27 | :param save: whether to save the proteins that are similar & different. 28 | :return: None 29 | """ 30 | identical = [] 31 | unidentical = [] 32 | detected = 0 33 | for seq_record in SeqIO.parse(uniprot_fasta_file, "fasta"): 34 | uniprot_id = extract_id(seq_record.id) 35 | # print("Uniprot ID is {} and sequence is {}".format(uniprot_id, str(seq_record.seq))) 36 | # Check if alpha fold predicted structure 37 | src = os.path.join(Constants.ROOT, "alphafold/AF-{}-F1-model_v2.pdb.gz".format(uniprot_id)) 38 | if os.path.isfile(src): 39 | detected = detected + 1 40 | # compare sequence 41 | alpha_fold_seq = get_sequence_from_pdb(src, "A") 42 | uniprot_sequence = str(seq_record.seq) 43 | if alpha_fold_seq == uniprot_sequence: 44 | identical.append(uniprot_id) 45 | else: 46 | unidentical.append(uniprot_id) 47 | print("{} number of sequence structures detected, {} identical to uniprot sequence & {} " 48 | "not identical to uniprot sequence".format(detected, len(identical), len(unidentical))) 49 | if save: 50 | pickle_save(identical, Constants.ROOT + "uniprot/identical") 51 | pickle_save(unidentical, Constants.ROOT + "uniprot/unidentical") 52 | 53 | 54 | def filtered_sequences(uniprot_fasta_file): 55 | """ 56 | Script is used to create fasta files based on alphafold sequence, by replacing sequences that are different. 
57 | :param uniprot_fasta_file: input uniprot fasta file. 58 | :return: None 59 | """ 60 | identified = set(pickle_load(Constants.ROOT + "uniprot/identical")) 61 | unidentified = set(pickle_load(Constants.ROOT + "uniprot/unidentical")) 62 | 63 | input_seq_iterator = SeqIO.parse(uniprot_fasta_file, "fasta") 64 | identified_seqs = [record for record in input_seq_iterator if extract_id(record.id) in identified] 65 | 66 | input_seq_iterator = SeqIO.parse(uniprot_fasta_file, "fasta") 67 | unidentified_seqs = [] 68 | for record in input_seq_iterator: 69 | uniprot_id = extract_id(record.id) 70 | if uniprot_id in unidentified: 71 | src = os.path.join(Constants.ROOT, "alphafold/AF-{}-F1-model_v2.pdb.gz".format(uniprot_id)) 72 | record.seq = Seq(get_sequence_from_pdb(src)) 73 | unidentified_seqs.append(record) 74 | 75 | new_seq = identified_seqs + unidentified_seqs 76 | print(len(identified_seqs), len(unidentified_seqs), len(new_seq)) 77 | SeqIO.write(new_seq, Constants.ROOT + "uniprot/{}.fasta".format("cleaned"), "fasta") 78 | 79 | 80 | def get_protein_go(uniprot_sprot_dat=None, save_path=None): 81 | """ 82 | Get all the GO terms associated with a protein in a clean format 83 | Creates file with structure: ACCESSION, ID, DESCRIPTION, WITH_STRING, EVIDENCE, GO_ID 84 | :param uniprot_sprot_dat: 85 | :param save_path: 86 | :return: None 87 | """ 88 | handle = open(uniprot_sprot_dat) 89 | all = [["ACC", "ID", "DESCRIPTION", "WITH_STRING", "EVIDENCE", "GO_ID", "ORGANISM", "TAXONOMY"]] 90 | for record in SwissProt.parse(handle): 91 | primary_accession = record.accessions[0] 92 | entry_name = record.entry_name 93 | cross_refs = record.cross_references 94 | organism = record.organism 95 | taxonomy = record.taxonomy_id 96 | for ref in cross_refs: 97 | if ref[0] == "GO": 98 | assert len(ref) == 4 99 | go_id = ref[1] 100 | description = ref[2] 101 | evidence = ref[3].split(":") 102 | with_string = evidence[1] 103 | evidence = evidence[0] 104 | all.append( 105 | [primary_accession, entry_name, description, with_string, 106 | evidence, go_id, organism, taxonomy]) 107 | with open(save_path, "w") as f: 108 | wr = csv.writer(f, delimiter='\t') 109 | wr.writerows(all) 110 | 111 | 112 | def generate_go_counts(fname="", go_graph="", cleaned_proteins =None): 113 | """ 114 | Get only sequences that meet the criteria of sequence length and sequences. 
115 | :param cleaned_proteins: proteins filtered for alphafold sequence 116 | :param go_graph: obo-basic graph 117 | :param fname: accession2go file 118 | :param chains: the proteins we are considering 119 | :return: None 120 | """ 121 | 122 | df = pd.read_csv(fname, delimiter="\t") 123 | 124 | df = df[df['EVIDENCE'].isin(Constants.exp_evidence_codes)] 125 | df = df[df['ACC'].isin(cleaned_proteins)] 126 | 127 | protein2go = {} 128 | go2info = {} 129 | # for index, row in df.iterrows(): 130 | for line_number, (index, row) in enumerate(df.iterrows()): 131 | acc = row[0] 132 | evidence = row[4] 133 | go_id = row[5] 134 | 135 | # if (acc in chains) and (go_id in go_graph) and (go_id not in Constants.root_terms): 136 | if go_id in go_graph: 137 | if acc not in protein2go: 138 | protein2go[acc] = {'goterms': [go_id], 'evidence': [evidence]} 139 | namespace = go_graph.nodes[go_id]['namespace'] 140 | go_ids = nx.descendants(go_graph, go_id) 141 | go_ids.add(go_id) 142 | go_ids = go_ids.difference(Constants.root_terms) 143 | for go in go_ids: 144 | protein2go[acc]['goterms'].append(go) 145 | protein2go[acc]['evidence'].append(evidence) 146 | name = go_graph.nodes[go]['name'] 147 | if go not in go2info: 148 | go2info[go] = {'ont': namespace, 'goname': name, 'accessions': set([acc])} 149 | else: 150 | go2info[go]['accessions'].add(acc) 151 | return protein2go, go2info 152 | 153 | 154 | def one_line_format(input_file, dir): 155 | """ 156 | Script takes the mm2seq cluster output and converts to representative seq1, seq2, seq3 .... 157 | :param input_file: The clusters as csv file 158 | :return: None 159 | """ 160 | data = {} 161 | with open(input_file) as file: 162 | lines = file.read().splitlines() 163 | for line in lines: 164 | x = line.split("\t") 165 | if x[0] in data: 166 | data[x[0]].append(x[1]) 167 | else: 168 | data[x[0]] = list([x[1]]) 169 | result = [data[i] for i in data] 170 | with open(dir + "/final_clusters.csv", "w") as f: 171 | wr = csv.writer(f, delimiter='\t') 172 | wr.writerows(result) 173 | 174 | 175 | def get_prot_id_and_prot_name(cafa_proteins): 176 | print('Mapping CAFA PROTEINS') 177 | cafa_id_mapping = dict() 178 | with open(Constants.ROOT + 'uniprot/idmapping_selected.tab') as file: 179 | for line in file: 180 | _tmp = line.split("\t")[:2] 181 | if _tmp[1] in cafa_proteins: 182 | cafa_id_mapping[_tmp[1]] = _tmp[0] 183 | if len(cafa_id_mapping) == 97105: 184 | break 185 | return cafa_id_mapping 186 | 187 | 188 | def target_in_swissprot_trembl_no_alpha(): 189 | gps = ["LK_bpo", "LK_mfo", "LK_cco", "NK_bpo", "NK_mfo", "NK_cco"] 190 | targets = set() 191 | for ts in gps: 192 | ts_func_old = read_test_set("/data_bp/pycharm/TransFunData/data_bp/195-200/{}".format(ts)) 193 | targets.update(set([i[0] for i in ts_func_old])) 194 | 195 | ts_func_new = read_test_set("/data_bp/pycharm/TransFunData/data_bp/205-now/{}".format(ts)) 196 | targets.update(set([i[0] for i in ts_func_new])) 197 | 198 | target = [] 199 | for seq_record in SeqIO.parse(Constants.ROOT + "uniprot/uniprot_trembl.fasta", "fasta"): 200 | if extract_id(seq_record.id) in targets: 201 | target.append(seq_record) 202 | print(len(target)) 203 | if len(target) == len(targets): 204 | break 205 | for seq_record in SeqIO.parse(Constants.ROOT + "uniprot/uniprot_sprot.fasta", "fasta"): 206 | if extract_id(seq_record.id) in targets: 207 | target.append(seq_record) 208 | print(len(target)) 209 | if len(target) == len(targets): 210 | break 211 | SeqIO.write(target, Constants.ROOT + "uniprot/{}.fasta".format("target_and_sequence"), 
"fasta") 212 | 213 | 214 | # target_in_swissprot_trembl_no_alpha() 215 | 216 | 217 | def cluster_sequence(seq_id, proteins=None, add_target=False): 218 | """ 219 | Script is used to cluster the proteins with mmseq2. 220 | :param threshold: 221 | :param proteins: 222 | :param add_target: Add CAFA targets 223 | :param input_fasta: input uniprot fasta file. 224 | :return: None 225 | 226 | 1. sequence to cluster is cleaned sequence. 227 | 2. Filter for only selected proteins 228 | 3. Add proteins in target not in the filtered list: 229 | 3.1 230 | 3.2 231 | """ 232 | input_fasta = Constants.ROOT + "uniprot/cleaned.fasta" 233 | print("Number of proteins in raw cleaned is {}".format(count_proteins(input_fasta))) 234 | print("Number of selected proteins in raw cleaned is {}".format(len(proteins))) 235 | wd = Constants.ROOT + "{}/mmseq".format(seq_id) 236 | if not os.path.exists(wd): 237 | os.mkdir(wd) 238 | if proteins: 239 | fasta_path = wd + "/fasta_{}".format(seq_id) 240 | if os.path.exists(fasta_path): 241 | input_fasta = fasta_path 242 | else: 243 | input_seq_iterator = SeqIO.parse(input_fasta, "fasta") 244 | cleaned_fasta = [record for record in input_seq_iterator if extract_id(record.id) in proteins] 245 | SeqIO.write(cleaned_fasta, fasta_path, "fasta") 246 | assert len(cleaned_fasta) == len(proteins) 247 | input_fasta = fasta_path 248 | # Add sequence for target not in the uniprotKB 249 | if add_target: 250 | cleaned_missing_target_sequence = Constants.ROOT + "uniprot/cleaned_missing_target_sequence.fasta" 251 | if os.path.exists(cleaned_missing_target_sequence): 252 | input_fasta = cleaned_missing_target_sequence 253 | else: 254 | missing_targets_205_now = [] 255 | missing_targets_195_200 = [] 256 | # Adding missing target sequence 257 | all_list = set([extract_id(i.id) for i in (SeqIO.parse(input_fasta, "fasta"))]) 258 | extra_alpha_fold = alpha_seq_fasta_to_dictionary(Constants.ROOT + "uniprot/alphafold_sequences.fasta") 259 | extra_trembl = fasta_to_dictionary(Constants.ROOT + "uniprot/target_and_sequence.fasta", 260 | identifier='protein_id') 261 | for ts in Constants.TEST_GROUPS: 262 | ts_func_old = read_test_set("/data_bp/pycharm/TransFunData/data_bp/195-200/{}".format(ts)) 263 | ts_func_old = set([i[0] for i in ts_func_old]) 264 | 265 | ts_func_new = read_test_set("/data_bp/pycharm/TransFunData/data_bp/205-now/{}".format(ts)) 266 | ts_func_new = set([i[0] for i in ts_func_new]) 267 | 268 | print("Adding 195-200 {}".format(ts)) 269 | for _id in ts_func_old: 270 | # Alphafold sequence always takes precedence 271 | if _id not in all_list: 272 | if _id in extra_alpha_fold: 273 | _mp = extra_alpha_fold[_id] 274 | missing_targets_195_200.append(SeqRecord(id=_mp[0].replace("AFDB:", ""). 275 | replace("AF-", ""). 276 | replace("-F1", ""), 277 | name=_mp[1], 278 | description=_mp[2], 279 | seq=_mp[3])) 280 | # print("found {} in alphafold".format(_id)) 281 | elif _id in extra_trembl: 282 | _mp = extra_trembl[_id] 283 | missing_targets_195_200.append(SeqRecord(id=_mp[0], 284 | name=_mp[1], 285 | description=_mp[2], 286 | seq=_mp[3])) 287 | # print("found {} in trembl".format(_id)) 288 | else: 289 | print("Found in none for {}".format(_id)) 290 | 291 | print("Adding 205-now {}".format(ts)) 292 | for _id in ts_func_new: 293 | # Alphafold sequence always takes precedence 294 | if _id not in all_list: 295 | if _id in extra_alpha_fold: 296 | _mp = extra_alpha_fold[_id] 297 | missing_targets_205_now.append(SeqRecord(id=_mp[0].replace("AFDB:", ""). 298 | replace("AF-", ""). 
299 | replace("-F1", ""), 300 | name=_mp[1], 301 | description=_mp[2], 302 | seq=_mp[3])) 303 | # print("found {} in alphafold".format(_id)) 304 | elif _id in extra_trembl: 305 | _mp = extra_trembl[_id] 306 | missing_targets_205_now.append(SeqRecord(id=_mp[0], 307 | name=_mp[1], 308 | description=_mp[2], 309 | seq=_mp[3])) 310 | # print("found {} in trembl".format(_id)) 311 | else: 312 | print("Found in none for {}".format(_id)) 313 | 314 | # save missing sequence 315 | SeqIO.write(missing_targets_195_200, Constants.ROOT + "uniprot/{}.fasta".format("missing_targets_195_200"), 316 | "fasta") 317 | SeqIO.write(missing_targets_205_now, Constants.ROOT + "uniprot/{}.fasta".format("missing_targets_205_now"), 318 | "fasta") 319 | 320 | input_seq_iterator = list(SeqIO.parse(input_fasta, "fasta")) 321 | SeqIO.write(input_seq_iterator + missing_targets_195_200 + missing_targets_205_now, Constants.ROOT + 322 | "uniprot/{}.fasta".format("cleaned_missing_target_sequence"), "fasta") 323 | 324 | input_fasta = cleaned_missing_target_sequence 325 | 326 | # input_seq_iterator = SeqIO.parse(Constants.ROOT + 327 | # "uniprot/{}.fasta".format("cleaned_missing_target_sequence"), "fasta") 328 | # 329 | # cleaned_fasta = set() 330 | # for record in input_seq_iterator: 331 | # if record.id.startswith("AFDB"): 332 | # cleaned_fasta.add(record.id.split(':')[1].split('-')[1]) 333 | # else: 334 | # cleaned_fasta.add(extract_id(record.id)) 335 | # 336 | # print(len(collect_test() - cleaned_fasta), len(cleaned_fasta)) 337 | 338 | print("Number of proteins in cleaned_missing_target_sequence is {}".format(count_proteins(input_fasta))) 339 | 340 | command = "mmseqs createdb {} {} ; " \ 341 | "mmseqs cluster {} {} tmp --min-seq-id {};" \ 342 | "mmseqs createtsv {} {} {} {}.tsv" \ 343 | "".format(input_fasta, "targetDB", "targetDB", "outputClu", seq_id, "targetDB", "targetDB", 344 | "outputClu", "outputClu") 345 | subprocess.call(command, shell=True, cwd="{}".format(wd)) 346 | one_line_format(wd + "/outputClu.tsv", wd) 347 | 348 | 349 | def accession2sequence(fasta_file=""): 350 | """ 351 | Extract sequnce for each accession into dictionary. 
352 | :param fasta_file: 353 | :return: None 354 | """ 355 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 356 | acc2seq = {extract_id(record.id): str(record.seq) for record in input_seq_iterator} 357 | pickle_save(acc2seq, Constants.ROOT + "uniprot/acc2seq") 358 | 359 | 360 | def collect_test_clusters(cluster_path): 361 | # collect test and clusters 362 | total_test = collect_test() 363 | 364 | computed = pd.read_csv(cluster_path, names=['cluster'], header=None).to_dict()['cluster'] 365 | computed = {i: set(computed[i].split('\t')) for i in computed} 366 | 367 | cafa3_cluster = set() 368 | new_cluster = set() 369 | train_cluster_indicies = [] 370 | for i in computed: 371 | # cafa3 372 | if total_test[0].intersection(computed[i]): 373 | cafa3_cluster.update(computed[i]) 374 | # new set 375 | elif total_test[1].intersection(computed[i]): 376 | new_cluster.update(computed[i]) 377 | else: 378 | train_cluster_indicies.append(i) 379 | 380 | print(len(cafa3_cluster)) 381 | print(len(new_cluster)) 382 | exit() 383 | return test_cluster, train_cluster_indicies 384 | 385 | 386 | def write_output_files(protein2go, go2info, seq_id): 387 | onts = ['molecular_function', 'biological_process', 'cellular_component'] 388 | 389 | selected_goterms = {ont: set() for ont in onts} 390 | selected_proteins = set() 391 | 392 | print("Number of GO terms is {} proteins is {}".format(len(go2info), len(protein2go))) 393 | 394 | # for each go term count related proteins; if they are from 50 to 5000 395 | # then we can add them to our data_bp. 396 | for goterm in go2info: 397 | prots = go2info[goterm]['accessions'] 398 | num = len(prots) 399 | namespace = go2info[goterm]['ont'] 400 | if num >= 60: 401 | selected_goterms[namespace].add(goterm) 402 | selected_proteins = selected_proteins.union(prots) 403 | 404 | # Convert the accepted go terms into list, so they have a fixed order 405 | # Add the names of corresponding go terms. 
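# The `terms` dictionary assembled below is the same pickle that predict.py later
# reads back, e.g. pickle_load(<data_path>/go_terms)['GO-terms-molecular_function'],
# so every namespace carries a parallel pair of keys:
#   'GO-terms-<namespace>' -> ordered list of accepted GO ids for that namespace
#   'GO-names-<namespace>' -> the matching human-readable GO names
# plus 'GO-terms-all' / 'GO-names-all' concatenated across the three namespaces.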
406 | selected_goterms_list = {ont: list(selected_goterms[ont]) for ont in onts} 407 | selected_gonames_list = {ont: [go2info[goterm]['goname'] for goterm in selected_goterms_list[ont]] for ont in onts} 408 | 409 | # print the count of each go term 410 | for ont in onts: 411 | print("###", ont, ":", len(selected_goterms_list[ont])) 412 | 413 | terms = {} 414 | for ont in onts: 415 | terms['GO-terms-' + ont] = selected_goterms_list[ont] 416 | terms['GO-names-' + ont] = selected_gonames_list[ont] 417 | 418 | terms['GO-terms-all'] = selected_goterms_list['molecular_function'] + \ 419 | selected_goterms_list['biological_process'] + \ 420 | selected_goterms_list['cellular_component'] 421 | 422 | terms['GO-names-all'] = selected_gonames_list['molecular_function'] + \ 423 | selected_goterms_list['biological_process'] + \ 424 | selected_goterms_list['cellular_component'] 425 | 426 | pickle_save(terms, Constants.ROOT + 'go_terms') 427 | fasta_dic = fasta_to_dictionary(Constants.ROOT + "uniprot/cleaned.fasta") 428 | 429 | protein_list = set() 430 | terms_count = {'mf': set(), 'bp': set(), 'cc': set()} 431 | with open(Constants.ROOT + 'annot.tsv', 'wt') as out_file: 432 | tsv_writer = csv.writer(out_file, delimiter='\t') 433 | tsv_writer.writerow(["Protein", "molecular_function", "biological_process", "cellular_component", "all"]) 434 | 435 | for chain in selected_proteins: 436 | goterms = set(protein2go[chain]['goterms']) 437 | if len(goterms) > 2 and is_ok(str(fasta_dic[chain][3])): 438 | # selected goterms 439 | mf_goterms = goterms.intersection(set(selected_goterms_list[onts[0]])) 440 | bp_goterms = goterms.intersection(set(selected_goterms_list[onts[1]])) 441 | cc_goterms = goterms.intersection(set(selected_goterms_list[onts[2]])) 442 | if len(mf_goterms) > 0 or len(bp_goterms) > 0 or len(cc_goterms) > 0: 443 | terms_count['mf'].update(mf_goterms) 444 | terms_count['bp'].update(bp_goterms) 445 | terms_count['cc'].update(cc_goterms) 446 | protein_list.add(chain) 447 | tsv_writer.writerow([chain, ','.join(mf_goterms), ','.join(bp_goterms), ','.join(cc_goterms), 448 | ','.join(mf_goterms.union(bp_goterms).union(cc_goterms))]) 449 | 450 | assert len(terms_count['mf']) == len(selected_goterms_list['molecular_function']) \ 451 | and len(terms_count['mf']) == len(selected_goterms_list['molecular_function']) \ 452 | and len(terms_count['mf']) == len(selected_goterms_list['molecular_function']) 453 | 454 | 455 | print("Creating Clusters") 456 | cluster_path = Constants.ROOT + "{}/mmseq/final_clusters.csv".format(seq_id) 457 | if not os.path.exists(cluster_path): 458 | cluster_sequence(seq_id, protein_list, add_target=True) 459 | 460 | # Remove test proteins & their cluster 461 | # Decided to remove irrespective of mf, bp | cc 462 | # It should be fine. 
463 | # test_cluster, train_cluster_indicies = collect_test_clusters(cluster_path) 464 | # train_list = protein_list - test_cluster 465 | # assert len(protein_list.intersection(test_cluster)) == len(protein_list.intersection(collect_test())) == 0 466 | # print(len(protein_list), len(protein_list.intersection(cafa3)), len(protein_list.intersection(new_test))) 467 | 468 | print("Getting test cluster") 469 | cafa3, new_test = collect_test() 470 | 471 | train_list = protein_list - (cafa3.union(new_test)) 472 | assert len(train_list.intersection(cafa3)) == len(train_list.intersection(new_test)) == 0 473 | 474 | validation_len = 6000 #int(0.2 * len(protein_list)) 475 | validation_list = set() 476 | 477 | for chain in train_list: 478 | goterms = set(protein2go[chain]['goterms']) 479 | mf_goterms = set(goterms).intersection(set(selected_goterms_list[onts[0]])) 480 | bp_goterms = set(goterms).intersection(set(selected_goterms_list[onts[1]])) 481 | cc_goterms = set(goterms).intersection(set(selected_goterms_list[onts[2]])) 482 | 483 | if len(mf_goterms) > 0 and len(bp_goterms) > 0 and len(cc_goterms) > 0: 484 | validation_list.add(chain) 485 | 486 | if len(validation_list) >= validation_len: 487 | break 488 | 489 | pickle_save(validation_list, Constants.ROOT + '/{}/valid'.format(seq_id)) 490 | train_list = train_list - validation_list 491 | 492 | print("Total number of train nrPDB=%d" % (len(train_list))) 493 | 494 | annot = pd.read_csv(Constants.ROOT + 'annot.tsv', delimiter='\t') 495 | for ont in onts + ['all']: 496 | _pth = Constants.ROOT + '{}/{}'.format(seq_id, ont) 497 | if not os.path.exists(_pth): 498 | os.mkdir(_pth) 499 | 500 | tmp = annot[annot[ont].notnull()][['Protein', ont]] 501 | tmp_prot_list = set(tmp['Protein'].to_list()) 502 | tmp_prot_list = tmp_prot_list.intersection(train_list) 503 | 504 | computed = pd.read_csv(cluster_path, names=['cluster'], header=None) 505 | 506 | # train_indicies = computed.index.isin(train_cluster_indicies) 507 | 508 | # computed = computed.loc[train_indicies].to_dict()['cluster'] 509 | computed = computed.to_dict()['cluster'] 510 | computed = {ont: set(computed[ont].split('\t')) for ont in computed} 511 | 512 | new_computed = {} 513 | index = 0 514 | for i in computed: 515 | _tmp = tmp_prot_list.intersection(computed[i]) 516 | if len(_tmp) > 0: 517 | new_computed[index] = _tmp 518 | index += 1 519 | 520 | _train = set.union(*new_computed.values()) 521 | print("Total proteins for {} is {} in {} clusters".format(ont, len(_train), len(new_computed))) 522 | assert len(cafa3.intersection(_train)) == 0 and len(validation_list.intersection(_train)) == 0 523 | 524 | pickle_save(new_computed, _pth + '/train') 525 | 526 | 527 | def pipeline(compare=False, curate_protein_goterms=False, generate_go_count=False, 528 | generate_msa=False, generate_esm=False, seq_id=0.3): 529 | """ 530 | section 1 531 | 1. First compare the sequence in uniprot and alpha fold and retrieve same sequence and different sequences. 532 | 2. Replace mismatched sequences with alpha fold sequence & create the fasta from only alphafold sequences 533 | 3. Just another comparison to be sure, we have only alphafold sequences. 534 | 535 | section 2 536 | GO terms associated with a protein 537 | 538 | section 3 539 | 1. Convert Fasta to dictionary 540 | 2. Read OBO graph 541 | 3. 
Get proteins and related go terms & go terms and associated proteins 542 | 543 | :param generate_msa: 544 | :param generate_esm: 545 | :param generate_go_count: 546 | :param curate_protein_goterms: 547 | :param compare: Compare sequence between uniprot and alphafold 548 | :return: 549 | """ 550 | 551 | # section 1 552 | if compare: 553 | compare_sequence(Constants.ROOT + "uniprot/uniprot_sprot.fasta", save=True) # 1 554 | filtered_sequences(Constants.ROOT + "uniprot/uniprot_sprot.fasta") # 2 create cleaned.fasta 555 | compare_sequence(Constants.ROOT + "uniprot/cleaned.fasta", save=False) # 3 556 | 557 | # section 2 558 | if curate_protein_goterms: 559 | get_protein_go(uniprot_sprot_dat=Constants.ROOT + "uniprot/uniprot_sprot.dat", 560 | save_path=Constants.ROOT + "protein2go.csv") # 4 contains proteins and go terms. 561 | 562 | # section 3 563 | if generate_go_count: 564 | cleaned_proteins = fasta_to_dictionary(Constants.ROOT + "uniprot/cleaned.fasta") 565 | go_graph = obonet.read_obo(open(Constants.ROOT + "obo/go-basic.obo", 'r')) # 5 566 | protein2go, go2info = generate_go_counts(fname=Constants.ROOT + "protein2go.csv", go_graph=go_graph, 567 | cleaned_proteins=list(cleaned_proteins.keys())) 568 | pickle_save(protein2go, Constants.ROOT + "protein2go") 569 | pickle_save(go2info, Constants.ROOT + "go2info") 570 | 571 | protein2go = pickle_load(Constants.ROOT + "protein2go") 572 | go2info = pickle_load(Constants.ROOT + "go2info") 573 | 574 | print("Writing output for sequence identity {}".format(seq_id)) 575 | write_output_files(protein2go, go2info, seq_id=seq_id) 576 | 577 | if generate_msa: 578 | fasta_file = Constants.ROOT + "cleaned.fasta" 579 | protein2go_primary = set(protein2go) 580 | fasta_for_msas(protein2go_primary, fasta_file) 581 | 582 | if generate_esm: 583 | fasta_file = Constants.ROOT + "cleaned.fasta" 584 | fasta_for_esm(protein2go, fasta_file) 585 | 586 | # print(count_proteins(Constants.ROOT + "uniprot/{}.fasta".format("target_and_sequence"))) 587 | # 588 | 589 | 590 | seq = [0.3, 0.5, 0.9, 0.95] 591 | for i in seq: 592 | pipeline(compare=False, 593 | curate_protein_goterms=False, 594 | generate_go_count=False, 595 | generate_msa=False, 596 | generate_esm=False, 597 | seq_id=i) 598 | 599 | 600 | exit() 601 | groups = ['molecular_function', 'cellular_component', 'biological_process'] 602 | 603 | for i in seq: 604 | for j in groups: 605 | 606 | train = pd.read_pickle(Constants.ROOT + "{}/{}/train.pickle".format(i, j)) 607 | valid = set(pd.read_pickle(Constants.ROOT + "{}/{}/valid.pickle".format(i, j))) 608 | test_cluster,_ = collect_test_clusters(Constants.ROOT + "{}/mmseq/final_clusters.csv".format(i)) 609 | test = collect_test() 610 | print(i, j, len(test_cluster), len(test), len(test_cluster - test), len(test - test_cluster)) 611 | 612 | # assert len(train.intersection(test_cluster)) == 0 613 | # assert len(train.intersection(test)) == 0 614 | 615 | assert len(valid.intersection(test_cluster)) == 0 616 | assert len(valid.intersection(test)) == 0 617 | -------------------------------------------------------------------------------- /preprocessing/extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
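# Example invocation (this mirrors how generate_bulk_embedding drives this script
# from predict.py / preprocess.py; the model name, FASTA path and output directory
# are placeholders):
#
#   python extract.py esm1b_t33_650M_UR50S sequences.fasta esm_out \
#       --repr_layers 0 32 33 --include mean per_tok --truncate
#
# One <record_id>.pt file is written to the output directory per FASTA record.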
6 | 7 | import argparse 8 | import pathlib 9 | 10 | import torch 11 | 12 | from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained 13 | 14 | 15 | def create_parser(): 16 | parser = argparse.ArgumentParser( 17 | description="Extract per-token representations and model outputs for sequences in a FASTA file" # noqa 18 | ) 19 | 20 | parser.add_argument( 21 | "model_location", 22 | type=str, 23 | help="PyTorch model file OR name of pretrained model to download (see README for models)", 24 | ) 25 | parser.add_argument( 26 | "fasta_file", 27 | type=pathlib.Path, 28 | help="FASTA file on which to extract representations", 29 | ) 30 | parser.add_argument( 31 | "output_dir", 32 | type=pathlib.Path, 33 | help="output directory for extracted representations", 34 | ) 35 | 36 | parser.add_argument("--toks_per_batch", type=int, default=4096, help="maximum batch size") 37 | parser.add_argument( 38 | "--repr_layers", 39 | type=int, 40 | default=[-1], 41 | nargs="+", 42 | help="layers indices from which to extract representations (0 to num_layers, inclusive)", 43 | ) 44 | parser.add_argument( 45 | "--include", 46 | type=str, 47 | nargs="+", 48 | choices=["mean", "per_tok", "bos", "contacts"], 49 | help="specify which representations to return", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "--truncate", 54 | action="store_true", 55 | help="Truncate sequences longer than 1024 to match the training setup", 56 | ) 57 | 58 | parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available") 59 | return parser 60 | 61 | 62 | def main(args): 63 | model, alphabet = pretrained.load_model_and_alphabet(args.model_location) 64 | model.eval() 65 | if torch.cuda.is_available() and not args.nogpu: 66 | model = model.cuda() 67 | print("Transferred model to GPU") 68 | 69 | dataset = FastaBatchedDataset.from_file(args.fasta_file) 70 | batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1) 71 | data_loader = torch.utils.data.DataLoader( 72 | dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches 73 | ) 74 | print(f"Read {args.fasta_file} with {len(dataset)} sequences") 75 | 76 | args.output_dir.mkdir(parents=True, exist_ok=True) 77 | return_contacts = "contacts" in args.include 78 | 79 | assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers) 80 | repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers] 81 | 82 | with torch.no_grad(): 83 | for batch_idx, (labels, strs, toks) in enumerate(data_loader): 84 | print( 85 | f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)" 86 | ) 87 | if torch.cuda.is_available() and not args.nogpu: 88 | toks = toks.to(device="cuda", non_blocking=True) 89 | 90 | # The model is trained on truncated sequences and passing longer ones in at 91 | # infernce will cause an error. 
See https://github.com/facebookresearch/esm/issues/21 92 | if args.truncate: 93 | toks = toks[:, :1022] 94 | 95 | out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts) 96 | 97 | logits = out["logits"].to(device="cpu") 98 | representations = { 99 | layer: t.to(device="cpu") for layer, t in out["representations"].items() 100 | } 101 | if return_contacts: 102 | contacts = out["contacts"].to(device="cpu") 103 | 104 | for i, label in enumerate(labels): 105 | args.output_file = args.output_dir / f"{label}.pt" 106 | args.output_file.parent.mkdir(parents=True, exist_ok=True) 107 | result = {"label": label} 108 | # Call clone on tensors to ensure tensors are not views into a larger representation 109 | # See https://github.com/pytorch/pytorch/issues/1995 110 | if "per_tok" in args.include: 111 | result["representations"] = { 112 | layer: t[i, 1 : len(strs[i]) + 1].clone() 113 | for layer, t in representations.items() 114 | } 115 | if "mean" in args.include: 116 | result["mean_representations"] = { 117 | layer: t[i, 1 : len(strs[i]) + 1].mean(0).clone() 118 | for layer, t in representations.items() 119 | } 120 | if "bos" in args.include: 121 | result["bos_representations"] = { 122 | layer: t[i, 0].clone() for layer, t in representations.items() 123 | } 124 | if return_contacts: 125 | result["contacts"] = contacts[i, : len(strs[i]), : len(strs[i])].clone() 126 | 127 | torch.save( 128 | result, 129 | args.output_file, 130 | ) 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = create_parser() 135 | args = parser.parse_args() 136 | main(args) 137 | -------------------------------------------------------------------------------- /preprocessing/generate_msa.py: -------------------------------------------------------------------------------- 1 | import os, subprocess 2 | from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union 3 | from absl import logging 4 | from tools import jackhmmer, parsers, residue_constants, msa_identifiers, hhblits 5 | import numpy as np 6 | import shutil 7 | from absl import app 8 | from absl import logging 9 | import multiprocessing 10 | from glob import glob 11 | import sys 12 | 13 | jackhmmer_binary_path = shutil.which('jackhmmer') 14 | uniref90_database_path = "/data_bp/pycharm/Genetic_Databases/uniref90/uniref90.fasta" 15 | mgnify_database_path = "/data_bp/pycharm/Genetic_Databases/mgnify/mgy_clusters_2018_12.fa" 16 | small_bfd_database_path = "/data_bp/pycharm/Genetic_Databases/small_bfd/bfd-first_non_consensus_sequences.fasta" 17 | hhblits_binary_path = "/data_bp/pycharm/Genetic_Databases/small_bfd/bfd-first_non_consensus_sequences.fasta" 18 | uniclust30_database_path = "/data_bp/pycharm/Genetic_Databases/small_bfd/bfd-first_non_consensus_sequences.fasta" 19 | 20 | 21 | FeatureDict = MutableMapping[str, np.ndarray] 22 | 23 | 24 | def make_msa_features(msas: Sequence[parsers.Msa], combined_out_path: str) -> FeatureDict: 25 | """Constructs a feature dict of MSA features.""" 26 | if not msas: 27 | raise ValueError('At least one MSA must be provided.') 28 | 29 | int_msa = [] 30 | deletion_matrix = [] 31 | uniprot_accession_ids = [] 32 | species_ids = [] 33 | seen_sequences = [] 34 | name_identifiers = [] 35 | for msa_index, msa in enumerate(msas): 36 | if not msa: 37 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.') 38 | for sequence_index, sequence in enumerate(msa.sequences): 39 | if sequence in seen_sequences: 40 | continue 41 | seen_sequences.append(sequence) 42 | 
name_identifiers.append(msa.descriptions[sequence_index]) 43 | int_msa.append([residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) 44 | deletion_matrix.append(msa.deletion_matrix[sequence_index]) 45 | identifiers = msa_identifiers.get_identifiers(msa.descriptions[sequence_index]) 46 | uniprot_accession_ids.append(identifiers.uniprot_accession_id.encode('utf-8')) 47 | species_ids.append(identifiers.species_id.encode('utf-8')) 48 | 49 | num_res = len(msas[0].sequences[0]) 50 | num_alignments = len(int_msa) 51 | features = {} 52 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) 53 | features['msa'] = np.array(int_msa, dtype=np.int32) 54 | features['num_alignments'] = np.array([num_alignments] * num_res, dtype=np.int32) 55 | features['msa_uniprot_accession_identifiers'] = np.array(uniprot_accession_ids, dtype=np.object_) 56 | features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_) 57 | 58 | with open(combined_out_path, 'w') as f: 59 | for item in zip(seen_sequences, name_identifiers): 60 | f.write(">%s\n" % item[1]) 61 | f.write("%s\n" % item[0]) 62 | 63 | 64 | 65 | def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str, msa_format: str, use_precomputed_msas: bool, base_path: str) -> Mapping[str, Any]: 66 | """Runs an MSA tool, checking if output already exists first.""" 67 | if not use_precomputed_msas or not os.path.exists(msa_out_path): 68 | result = msa_runner.query(input_fasta_path, base_path=base_path)[0] 69 | with open(msa_out_path, 'w') as f: 70 | f.write(result[msa_format]) 71 | else: 72 | logging.error('Reading MSA from file %s', msa_out_path) 73 | with open(msa_out_path, 'r') as f: 74 | result = {msa_format: f.read()} 75 | return result 76 | 77 | class DataPipeline: 78 | """Runs the alignment tools and assembles the input features.""" 79 | 80 | def __init__(self, jackhmmer_binary_path: str, hhblits_binary_path: str, uniref90_database_path: str, mgnify_database_path: str, small_bfd_database_path: Optional[str], uniclust30_database_path: str, bfd_database_path: Optional[str]): 81 | """Initializes the data_bp pipeline.""" 82 | 83 | self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(binary_path=jackhmmer_binary_path, database_path=uniref90_database_path) 84 | 85 | #self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(binary_path=jackhmmer_binary_path, database_path=small_bfd_database_path) 86 | 87 | self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( binary_path=hhblits_binary_path, databases=[bfd_database_path, uniclust30_database_path]) 88 | 89 | self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(binary_path=jackhmmer_binary_path, database_path=mgnify_database_path) 90 | 91 | self.use_precomputed_msas = False 92 | 93 | self.mgnify_max_hits = 501 94 | 95 | self.uniref_max_hits = 10000 96 | 97 | def process(self, input_fasta_path: str, msa_output_dir: str, base_path: str, protein: str, combine: bool, make_diverse: bool) -> FeatureDict: 98 | """Runs alignment tools on the input sequence and creates features.""" 99 | 100 | uniref90_msa = "None" 101 | bfd_msa = "None" 102 | mgnify_msa = "None" 103 | 104 | combined_out_path = os.path.join(msa_output_dir, 'combined.a3m') 105 | diverse_out_path = os.path.join(msa_output_dir, 'diverse_{}.a3m') 106 | 107 | with open(input_fasta_path) as f: 108 | input_fasta_str = f.read() 109 | input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) 110 | if len(input_seqs) != 1: 111 | raise ValueError(f'More than one input sequence found in {input_fasta_path}.') 112 | 113 | if 
os.path.isfile(combined_out_path): 114 | logging.error("Combined already generated for {}".format(input_fasta_path)) 115 | else: 116 | uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') 117 | if not os.path.isfile(uniref90_out_path): 118 | logging.error("Generating msa for {} from {}".format(protein, "uniref90")) 119 | jackhmmer_uniref90_result = run_msa_tool(self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path, 'sto', self.use_precomputed_msas, base_path=base_path) 120 | uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto']) 121 | uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) 122 | else: 123 | if combine and not os.path.isfile(combined_out_path): 124 | logging.error("Loading msa for {} from {} @ {}".format(protein, "uniref90", uniref90_out_path)) 125 | with open(uniref90_out_path, 'r') as f: 126 | jackhmmer_uniref90_result = {'sto': f.read()} 127 | uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto']) 128 | uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) 129 | 130 | mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') 131 | if not os.path.isfile(mgnify_out_path): 132 | logging.error("Generating msa for {} from {}".format(protein, "mgnify")) 133 | jackhmmer_mgnify_result = run_msa_tool(self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto', self.use_precomputed_msas, base_path=base_path) 134 | mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) 135 | mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) 136 | else: 137 | if combine and not os.path.isfile(combined_out_path): 138 | logging.error("Loading msa for {} from {} @ {}".format(protein, "mgnify", mgnify_out_path)) 139 | with open(mgnify_out_path, 'r') as f: 140 | jackhmmer_mgnify_result = {'sto': f.read()} 141 | mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) 142 | mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) 143 | 144 | bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') 145 | if not os.path.isfile(bfd_out_path): 146 | logging.error("Generating msa for {} from {}".format(protein, "Bfd")) 147 | hhblits_bfd_uniclust_result = run_msa_tool(self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path, 'a3m', self.use_precomputed_msas, base_path=base_path) 148 | bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) 149 | # jackhmmer_small_bfd_result = run_msa_tool(self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path, 'sto', self.use_precomputed_msas, base_path=base_path) 150 | # bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) 151 | else: 152 | if combine and not os.path.isfile(combined_out_path): 153 | logging.error("Loading msa for {} from {} @ {}".format(protein, "small_bfd", bfd_out_path)) 154 | with open(bfd_out_path, 'r') as f: 155 | hhblits_small_bfd_result = {'a3m': f.read()} 156 | bfd_msa = parsers.parse_stockholm(hhblits_small_bfd_result['a3m']) 157 | # jackhmmer_small_bfd_result = {'sto': f.read()} 158 | # bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) 159 | 160 | 161 | msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa), combined_out_path=combined_out_path) 162 | 163 | if make_diverse: 164 | if not os.path.isfile(diverse_out_path.format(64)): 165 | subprocess.call( 166 | 'hhfilter -i {} -o {} -diff {}'.format(combined_out_path, diverse_out_path.format(64), 64), shell=True) 167 | 168 | if not 
os.path.isfile(diverse_out_path.format(128)): 169 | subprocess.call( 170 | 'hhfilter -i {} -o {} -diff {}'.format(combined_out_path, diverse_out_path.format(128), 128), shell=True) 171 | 172 | if not os.path.isfile(diverse_out_path.format(256)): 173 | subprocess.call( 174 | 'hhfilter -i {} -o {} -diff {}'.format(combined_out_path, diverse_out_path.format(256), 256), shell=True) 175 | 176 | if not os.path.isfile(diverse_out_path.format(512)): 177 | subprocess.call( 178 | 'hhfilter -i {} -o {} -diff {}'.format(combined_out_path, diverse_out_path.format(512), 512), shell=True) 179 | 180 | # logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) 181 | # logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) 182 | # logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) 183 | 184 | 185 | 186 | 187 | def run_main(directory): 188 | pipeline = DataPipeline(jackhmmer_binary_path, hhblits_binary_path, uniref90_database_path, mgnify_database_path, small_bfd_database_path, uniclust30_database_path, bfd_database_path) 189 | base_path = base+"{}".format(directory) 190 | logging.info("Generating for protein {}".format(directory)) 191 | input_path = base_path+"/{}.fasta".format(directory) 192 | output_path = base_path+"/msas" 193 | if not os.path.exists(output_path): 194 | os.makedirs(output_path) 195 | pipeline.process(input_fasta_path=input_path, msa_output_dir=output_path, base_path=base_path, protein=directory, combine=True, make_diverse=True) 196 | 197 | 198 | base = "/storage/htc/bdm/Frimpong/TransFun/msa_files/two/{}/".format(sys.argv[1]) 199 | directories = [x for x in os.listdir(base)] 200 | 201 | logging.info("Started") 202 | pool = multiprocessing.Pool(4) 203 | pool.map(run_main, directories) 204 | pool.close() 205 | -------------------------------------------------------------------------------- /preprocessing/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import pandas as pd 4 | import torch 5 | import esm 6 | import torch.nn.functional as F 7 | 8 | import Constants 9 | from preprocessing.utils import pickle_save, pickle_load, count_proteins_biopython 10 | 11 | 12 | # Script to test esm 13 | def test_esm(): 14 | # Load ESM-1b model 15 | model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() 16 | batch_converter = alphabet.get_batch_converter() 17 | 18 | # Prepare data_bp (first 2 sequences from ESMStructuralSplitDataset superfamily / 4) 19 | data = [ 20 | ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"), 21 | ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"), 22 | ("protein2 with mask", "KALTARQQEVFDLIRDISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"), 23 | ("protein3", "K A I S Q"), 24 | ] 25 | 26 | for i, j in data: 27 | print(len(j)) 28 | 29 | batch_labels, batch_strs, batch_tokens = batch_converter(data) 30 | 31 | # Extract per-residue representations (on CPU) 32 | with torch.no_grad(): 33 | results = model(batch_tokens, repr_layers=[33], return_contacts=True) 34 | token_representations = results["representations"][33] 35 | 36 | print(token_representations.shape) 37 | 38 | # Generate per-sequence representations via averaging 39 | # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1. 
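# Averaging the 1:len(seq)+1 slice below reproduces the 'mean_representations'
# that extract.py saves when run with --include mean; predict.py depends on that
# equivalence when it re-averages embeddings of sequences cropped to 1021 residues.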
40 | sequence_representations = [] 41 | for i, (_, seq) in enumerate(data): 42 | sequence_representations.append(token_representations[i, 1: len(seq) + 1].mean(0)) 43 | 44 | for i in sequence_representations: 45 | print(len(i)) 46 | 47 | # # Look at the unsupervised self-attention map contact predictions 48 | # import matplotlib.pyplot as plt 49 | # for (_, seq), attention_contacts in zip(data_bp, results["contacts"]): 50 | # plt.matshow(attention_contacts[: len(seq), : len(seq)]) 51 | # plt.title(seq) 52 | # plt.show() 53 | 54 | 55 | # Generate ESM embeddings in bulk 56 | # In this function, I create embedding for each fasta sequence in the fasta file 57 | # Extract file is taken from the github directory 58 | def generate_bulk_embedding(fasta_file, output_dir, path_to_extract_file): 59 | subprocess.call('python extract.py esm1b_t33_650M_UR50S {} {} --repr_layers 0 32 33 ' 60 | '--include mean per_tok --truncate'.format("{}".format(fasta_file), 61 | "{}".format(output_dir)), 62 | shell=True, cwd="{}".format(path_to_extract_file)) 63 | 64 | 65 | # print(count_proteins_biopython(Constants.ROOT + "eval/{}_1.fasta".format("test"))) 66 | # exit() 67 | # generate_bulk_embedding(Constants.ROOT + "eval/{}.fasta".format("cropped"), 68 | # "/data_bp/pycharm/TransFunData/data_bp/bnm", 69 | # "/data_bp/pycharm/TransFun/preprocessing") 70 | 71 | # generate_bulk_embedding(Constants.ROOT + "eval/{}.fasta".format("shorter"), 72 | # "/data_bp/pycharm/TransFunData/data_bp/shorter", 73 | # "/data_bp/pycharm/TransFun/preprocessing") 74 | 75 | exit() 76 | 77 | 78 | # Generate data_bp for each group 79 | def generate_data(): 80 | def get_stats(data): 81 | go_terms = {} 82 | for i in data: 83 | for j in i.split(","): 84 | if j in go_terms: 85 | go_terms[j] = go_terms[j] + 1 86 | else: 87 | go_terms[j] = 1 88 | return go_terms 89 | 90 | categories = [('molecular_function', 'GO-terms (molecular_function)'), 91 | ('biological_process', 'GO-terms (biological_process)'), 92 | ('cellular_component', 'GO-terms (cellular_component)')] 93 | x_id = '### PDB-chain' 94 | 95 | train_set = pickle_load(Constants.ROOT + "final_train") 96 | valid_set = pickle_load(Constants.ROOT + "final_valid") 97 | test_set = pickle_load(Constants.ROOT + "final_test") 98 | 99 | for i in categories: 100 | print("Generating for {}".format(i[0])) 101 | 102 | if not os.path.isdir(Constants.ROOT + i[0]): 103 | os.mkdir(Constants.ROOT + i[0]) 104 | 105 | df = pd.read_csv("/data_bp/pycharm/TransFunData/data_bp/final_annot.tsv", skiprows=12, delimiter="\t") 106 | df = df[df[i[1]].notna()][[x_id, i[1]]] 107 | 108 | train_df = df[df[x_id].isin(train_set)] 109 | train_df.to_pickle(Constants.ROOT + i[0] + "/train.pickle") 110 | stats = get_stats(train_df[i[1]].to_list()) 111 | pickle_save(stats, Constants.ROOT + i[0] + "/train_stats") 112 | print(len(stats)) 113 | 114 | valid_df = df[df[x_id].isin(valid_set)] 115 | valid_df.to_pickle(Constants.ROOT + i[0] + "/valid.pickle") 116 | stats = get_stats(valid_df[i[1]].to_list()) 117 | pickle_save(stats, Constants.ROOT + i[0] + "/valid_stats") 118 | print(len(stats)) 119 | 120 | test_df = df[df[x_id].isin(test_set)] 121 | test_df.to_pickle(Constants.ROOT + i[0] + "/test.pickle") 122 | stats = get_stats(test_df[i[1]].to_list()) 123 | pickle_save(stats, Constants.ROOT + i[0] + "/test_stats") 124 | print(len(stats)) 125 | 126 | 127 | # generate_data() 128 | 129 | 130 | # Generate labels for data_bp 131 | def generate_labels(_type='GO-terms (molecular_function)', _name='molecular_function'): 132 | # ['GO-terms 
(molecular_function)', 'GO-terms (biological_process)', 'GO-terms (cellular_component)'] 133 | 134 | # if not os.path.isfile('/data_bp/pycharm/TransFunData/data_bp/{}.pickle'.format(_name)): 135 | 136 | file = '/data_bp/pycharm/TransFunData/data_bp/nrPDB-GO_2021.01.23_annot.tsv' 137 | data = pd.read_csv(file, sep='\t', skiprows=12) 138 | data = data[["### PDB-chain", _type]] 139 | data = data[data[_type].notna()] 140 | 141 | classes = data[_type].to_list() 142 | classes = set([one_word for class_list in classes for one_word in class_list.split(',')]) 143 | class_keys = list(range(0, len(classes))) 144 | 145 | classes = dict(zip(classes, class_keys)) 146 | 147 | data_to_one_hot = {} 148 | for index, row in data.iterrows(): 149 | tmp = row[_type].split(',') 150 | x = torch.tensor([classes[i] for i in tmp]) 151 | x = F.one_hot(x, num_classes=len(classes)) 152 | x = x.sum(dim=0).float() 153 | assert len(tmp) == x.sum(dim=0).float() 154 | data_to_one_hot[row['### PDB-chain']] = x.to(dtype=torch.int) 155 | pickle_save(data_to_one_hot, '/data_bp/pycharm/TransFunData/data_bp/{}'.format(_name)) 156 | 157 | # generate_labels() 158 | -------------------------------------------------------------------------------- /preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os, subprocess 3 | import shutil 4 | 5 | import pandas as pd 6 | import torch 7 | from Bio import SeqIO 8 | import pickle 9 | 10 | from Bio.Seq import Seq 11 | from Bio.SeqRecord import SeqRecord 12 | from biopandas.pdb import PandasPdb 13 | from collections import deque, Counter 14 | import csv 15 | 16 | from sklearn.metrics import roc_curve, auc 17 | from torchviz import make_dot 18 | 19 | import Constants 20 | from Constants import INVALID_ACIDS, amino_acids 21 | 22 | 23 | def extract_id(header): 24 | return header.split('|')[1] 25 | 26 | 27 | def count_proteins(fasta_file): 28 | num = len([1 for line in open(fasta_file) if line.startswith(">")]) 29 | return num 30 | 31 | 32 | def read_dictionary(file): 33 | reader = csv.reader(open(file, 'r'), delimiter='\t') 34 | d = {} 35 | for row in reader: 36 | k, v = row[0], row[1] 37 | d[k] = v 38 | return d 39 | 40 | 41 | def create_seqrecord(id="", name="", description="", seq=""): 42 | record = SeqRecord(Seq(seq), id=id, name=name, description=description) 43 | return record 44 | 45 | 46 | # Count the number of protein sequences in a fasta file with biopython -- slower. 
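# (count_proteins above simply counts '>' header lines, which is faster on large FASTA
# files; this variant fully parses every record with Bio.SeqIO.)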
47 | def count_proteins_biopython(fasta_file): 48 | num = len(list(SeqIO.parse(fasta_file, "fasta"))) 49 | return num 50 | 51 | 52 | def get_proteins_from_fasta(fasta_file): 53 | proteins = list(SeqIO.parse(fasta_file, "fasta")) 54 | # proteins = [i.id.split("|")[1] for i in proteins] 55 | proteins = [i.id for i in proteins] 56 | return proteins 57 | 58 | 59 | def fasta_to_dictionary(fasta_file, identifier='protein_id'): 60 | if identifier == 'protein_id': 61 | loc = 1 62 | elif identifier == 'protein_name': 63 | loc = 2 64 | data = {} 65 | for seq_record in SeqIO.parse(fasta_file, "fasta"): 66 | if "|" in seq_record.id: 67 | data[seq_record.id.split("|")[loc]] = ( 68 | seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 69 | else: 70 | data[seq_record.id] = (seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 71 | return data 72 | 73 | 74 | def cafa_fasta_to_dictionary(fasta_file): 75 | data = {} 76 | for seq_record in SeqIO.parse(fasta_file, "fasta"): 77 | data[seq_record.description.split(" ")[0]] = ( 78 | seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 79 | return data 80 | 81 | 82 | def alpha_seq_fasta_to_dictionary(fasta_file): 83 | data = {} 84 | for seq_record in SeqIO.parse(fasta_file, "fasta"): 85 | _protein = seq_record.id.split(":")[1].split("-")[1] 86 | data[_protein] = (seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 87 | return data 88 | 89 | 90 | def pickle_save(data, filename): 91 | with open('{}.pickle'.format(filename), 'wb') as handle: 92 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 93 | 94 | 95 | def pickle_load(filename): 96 | with open('{}.pickle'.format(filename), 'rb') as handle: 97 | return pickle.load(handle) 98 | 99 | 100 | def download_msa_database(url, name): 101 | database_path = "./msa/hh_suite_database/{}".format(name) 102 | if not os.path.isdir(database_path): 103 | os.mkdir(database_path) 104 | # download database, note downloading a ~1Gb file can take a minute 105 | database_file = "{}/{}.tar.gz".format(database_path, name) 106 | subprocess.call('wget -O {} {}'.format(database_file, url), shell=True) 107 | # unzip the database 108 | subprocess.call('tar xzvf {}.tar.gz'.format(name), shell=True, cwd="{}".format(database_path)) 109 | 110 | 111 | def search_database(file, database): 112 | base_path = "./msa/{}" 113 | output_path = base_path.format("outputs/{}.hhr".format(file)) 114 | input_path = base_path.format("inputs/{}.fasta".format(file)) 115 | oa3m_path = base_path.format("oa3ms/{}.03m".format(file)) 116 | database_path = base_path.format("hh_suite_database/{}/{}".format(database, database)) 117 | if not os.path.isfile(oa3m_path): 118 | subprocess.call( 119 | 'hhblits -i {} -o {} -oa3m {} -d {} -cpu 4 -n 1'.format(input_path, output_path, oa3m_path, database_path), 120 | shell=True) 121 | 122 | 123 | # Just used to group msas to keep track of generation 124 | def partition_files(group): 125 | from glob import glob 126 | dirs = glob("/data_bp/fasta_files/{}/*/".format(group), recursive=False) 127 | for i in enumerate(dirs): 128 | prt = i[1].split('/')[4] 129 | if int(i[0]) % 100 == 0: 130 | current = "/data_bp/fasta_files/{}/{}".format(group, (int(i[0]) // 100)) 131 | if not os.path.isdir(current): 132 | os.mkdir(current) 133 | old = "/data_bp/fasta_files/{}/{}".format(group, prt) 134 | new = current + "/{}".format(prt) 135 | if old != current: 136 | os.rename(old, new) 137 | 138 | 139 | # Just used to group msas to keep track of generation 140 | 
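# fasta_for_msas below writes each requested protein to its own FASTA file under a
# numbered batch directory, ready for per-protein MSA generation.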
def fasta_for_msas(proteins, fasta_file): 141 | root_dir = '/data_bp/uniprot/' 142 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 143 | num_protein = 0 144 | for record in input_seq_iterator: 145 | if num_protein % 200 == 0: 146 | parent_dir = root_dir + str(int(num_protein / 200)) 147 | print(parent_dir) 148 | if not os.path.exists(parent_dir): 149 | os.mkdir(parent_dir) 150 | protein = extract_id(record.id) 151 | if protein in proteins: 152 | protein_dir = parent_dir + '/' + protein 153 | if not os.path.exists(protein_dir): 154 | os.mkdir(protein_dir) 155 | SeqIO.write(record, protein_dir + "/{}.fasta".format(protein), "fasta") 156 | 157 | 158 | # Files to generate esm embedding for. 159 | def fasta_for_esm(proteins, fasta_file): 160 | protein_path = Constants.ROOT + "uniprot/{}.fasta".format("filtered") 161 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 162 | 163 | filtered_seqs = [record for record in input_seq_iterator if extract_id(record.id) in proteins] 164 | 165 | if not os.path.exists(protein_path): 166 | SeqIO.write(filtered_seqs, protein_path, "fasta") 167 | 168 | 169 | def get_sequence_from_pdb(pdb_file, chain_id): 170 | pdb_to_pandas = PandasPdb().read_pdb(pdb_file) 171 | 172 | pdb_df = pdb_to_pandas.df['ATOM'] 173 | 174 | assert (len(set(pdb_df['chain_id'])) == 1) & (list(set(pdb_df['chain_id']))[0] == chain_id) 175 | 176 | pdb_df = pdb_df[(pdb_df['atom_name'] == 'CA') & ((pdb_df['chain_id'])[0] == chain_id)] 177 | pdb_df = pdb_df.drop_duplicates() 178 | 179 | residues = pdb_df['residue_name'].to_list() 180 | residues = ''.join([amino_acids[i] for i in residues if i != "UNK"]) 181 | return residues 182 | 183 | 184 | def is_ok(seq, MINLEN=49, MAXLEN=1022): 185 | """ 186 | Checks if sequence is of good quality 187 | :param MAXLEN: 188 | :param MINLEN: 189 | :param seq: 190 | :return: None 191 | """ 192 | if len(seq) < MINLEN or len(seq) >= MAXLEN: 193 | return False 194 | for c in seq: 195 | if c in INVALID_ACIDS: 196 | return False 197 | return True 198 | 199 | 200 | def is_cafa_target(org): 201 | return org in Constants.CAFA_TARGETS 202 | 203 | 204 | def is_exp_code(code): 205 | return code in Constants.exp_evidence_codes 206 | 207 | 208 | def read_test_set(file_name): 209 | with open(file_name) as file: 210 | lines = file.readlines() 211 | lines = [line.rstrip('\n').split("\t")[0] for line in lines] 212 | return lines 213 | 214 | 215 | def read_test_set_x(file_name): 216 | with open(file_name) as file: 217 | lines = file.readlines() 218 | lines = [line.rstrip('\n').split("\t") for line in lines] 219 | return lines 220 | 221 | 222 | def read_test(file_name): 223 | with open(file_name) as file: 224 | lines = file.readlines() 225 | lines = [line.rstrip('\n') for line in lines] 226 | return lines 227 | 228 | 229 | def collect_test(): 230 | cafa3 = pickle_load(Constants.ROOT + "test/test_proteins_list") 231 | cafa3 = set([i[0] for i in cafa3]) 232 | 233 | new_test = set() 234 | for ts in Constants.TEST_GROUPS: 235 | # tmp = read_test_set(Constants.ROOT + "test/195-200/{}".format(ts)) 236 | # total_test.update(set([i[0] for i in tmp])) 237 | tmp = read_test_set(Constants.ROOT + "test/205-now/{}".format(ts)) 238 | new_test.update(set([i[0] for i in tmp])) 239 | 240 | return cafa3, new_test 241 | 242 | 243 | def test_annotation(): 244 | # Add annotations for test set 245 | data = {} 246 | for ts in Constants.TEST_GROUPS: 247 | tmp = read_test_set("/data_bp/pycharm/TransFunData/data_bp/195-200/{}".format(ts)) 248 | for i in tmp: 249 | if i[0] in data: 250 | 
data[i[0]][ts].add(i[1]) 251 | else: 252 | data[i[0]] = {'LK_bpo': set(), 'LK_mfo': set(), 'LK_cco': set(), 'NK_bpo': set(), 'NK_mfo': set(), 253 | 'NK_cco': set()} 254 | data[i[0]][ts].add(i[1]) 255 | 256 | tmp = read_test_set("/data_bp/pycharm/TransFunData/data_bp/205-now/{}".format(ts)) 257 | for i in tmp: 258 | if i[0] in data: 259 | data[i[0]][ts].add(i[1]) 260 | else: 261 | data[i[0]] = {'LK_bpo': set(), 'LK_mfo': set(), 'LK_cco': set(), 'NK_bpo': set(), 'NK_mfo': set(), 262 | 'NK_cco': set()} 263 | data[i[0]][ts].add(i[1]) 264 | 265 | return data 266 | 267 | 268 | # GO terms for test set. 269 | def get_test_classes(): 270 | data = set() 271 | for ts in Constants.TEST_GROUPS: 272 | tmp = read_test_set("/data_bp/pycharm/TransFunData/data_bp/195-200/{}".format(ts)) 273 | for i in tmp: 274 | data.add(i[1]) 275 | 276 | tmp = read_test_set("/data_bp/pycharm/TransFunData/data_bp/205-now/{}".format(ts)) 277 | for i in tmp: 278 | data.add(i[1]) 279 | 280 | return data 281 | 282 | 283 | def create_cluster(seq_identity=None): 284 | def get_position(row, pos, column, split): 285 | primary = row[column].split(split)[pos] 286 | return primary 287 | 288 | computed = pd.read_pickle(Constants.ROOT + 'uniprot/set1/swissprot.pkl') 289 | computed['primary_accession'] = computed.apply(lambda row: get_position(row, 0, 'accessions', ';'), axis=1) 290 | annotated = pickle_load(Constants.ROOT + "uniprot/anotated") 291 | 292 | def max_go_terms(row): 293 | members = row['cluster'].split('\t') 294 | largest = 0 295 | max = 0 296 | for index, value in enumerate(members): 297 | x = computed.loc[computed['primary_accession'] == value]['prop_annotations'].values # .tolist() 298 | if len(x) > 0: 299 | if len(x[0]) > largest: 300 | largest = len(x[0]) 301 | max = index 302 | return members[max] 303 | 304 | if seq_identity is not None: 305 | src = "/data_bp/pycharm/TransFunData/data_bp/uniprot/set1/mm2seq_{}/max_term".format(seq_identity) 306 | if os.path.isfile(src): 307 | cluster = pd.read_pickle(src) 308 | else: 309 | cluster = pd.read_csv("/data_bp/pycharm/TransFunData/data_bp/uniprot/set1/mm2seq_{}/final_clusters.tsv" 310 | .format(seq_identity), names=['cluster'], header=None) 311 | 312 | cluster['rep'] = cluster.apply(lambda row: get_position(row, 0, 'cluster', '\t'), axis=1) 313 | cluster['max'] = cluster.apply(lambda row: max_go_terms(row), axis=1) 314 | cluster.to_pickle("/data_bp/pycharm/TransFunData/data_bp/uniprot/set1/mm2seq_{}/max_term".format(seq_identity)) 315 | 316 | cluster = cluster['max'].to_list() 317 | computed = computed[computed['primary_accession'].isin(cluster)] 318 | 319 | return computed 320 | 321 | 322 | def class_distribution_counter(**kwargs): 323 | """ 324 | Count the number of proteins for each GO term in training set. 
325 | """ 326 | data = pickle_load(Constants.ROOT + "{}/{}/{}".format(kwargs['seq_id'], kwargs['ont'], kwargs['session'])) 327 | 328 | all_proteins = [] 329 | for i in data: 330 | all_proteins.extend(data[i]) 331 | 332 | annot = pd.read_csv(Constants.ROOT + 'annot.tsv', delimiter='\t') 333 | annot = annot.where(pd.notnull(annot), None) 334 | annot = annot[annot['Protein'].isin(all_proteins)] 335 | annot = pd.Series(annot[kwargs['ont']].values, index=annot['Protein']).to_dict() 336 | 337 | terms = [] 338 | for i in annot: 339 | terms.extend(annot[i].split(",")) 340 | 341 | counter = Counter(terms) 342 | 343 | # for i in counter.most_common(): 344 | # print(i) 345 | # print("# of ontologies is {}".format(len(counter))) 346 | 347 | return counter 348 | 349 | 350 | def save_ckp(state, is_best, checkpoint_path, best_model_path): 351 | """ 352 | state: checkpoint we want to save 353 | is_best: is this the best checkpoint; min validation loss 354 | checkpoint_path: path to save checkpoint 355 | best_model_path: path to save best model 356 | """ 357 | f_path = checkpoint_path 358 | # save checkpoint data_bp to the path given, checkpoint_path 359 | torch.save(state, f_path) 360 | # if it is a best model, min validation loss 361 | if is_best: 362 | best_fpath = best_model_path 363 | # copy that checkpoint file to best path given, best_model_path 364 | shutil.copyfile(f_path, best_fpath) 365 | 366 | 367 | def load_ckp(checkpoint_fpath, model, optimizer, device): 368 | """ 369 | checkpoint_path: path to save checkpoint 370 | model: model that we want to load checkpoint parameters into 371 | optimizer: optimizer we defined in previous training 372 | """ 373 | # load check point 374 | checkpoint = torch.load(checkpoint_fpath, map_location=torch.device(device)) 375 | # initialize state_dict from checkpoint to model 376 | model.load_state_dict(checkpoint['state_dict']) 377 | # initialize optimizer from checkpoint to optimizer 378 | optimizer.load_state_dict(checkpoint['optimizer']) 379 | # initialize valid_loss_min from checkpoint to valid_loss_min 380 | valid_loss_min = checkpoint['valid_loss_min'] 381 | # return model, optimizer, epoch value, min validation loss 382 | return model, optimizer, checkpoint['epoch'], valid_loss_min 383 | 384 | 385 | def draw_architecture(model, data_batch): 386 | ''' 387 | Draw the network architecture. 388 | ''' 389 | output = model(data_batch) 390 | make_dot(output, params=dict(model.named_parameters())).render("rnn_lstm_torchviz", format="png") 391 | 392 | 393 | def compute_roc(labels, preds): 394 | # Compute ROC curve and ROC area for each class 395 | fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten()) 396 | roc_auc = auc(fpr, tpr) 397 | return roc_auc 398 | 399 | 400 | def generate_bulk_embedding(path_to_extract_file, fasta_file, output_dir): 401 | subprocess.call('python {} esm1b_t33_650M_UR50S {} {} --repr_layers 0 32 33 ' 402 | '--include mean per_tok --truncate'.format(path_to_extract_file, 403 | "{}".format(fasta_file), 404 | "{}".format(output_dir)), 405 | shell=True) -------------------------------------------------------------------------------- /tools/hhblits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHblits from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Any, List, Mapping, Optional, Sequence 21 | 22 | from absl import logging 23 | # Internal import (7716). 24 | from tools import utils 25 | 26 | _HHBLITS_DEFAULT_P = 20 27 | _HHBLITS_DEFAULT_Z = 500 28 | 29 | 30 | class HHBlits: 31 | """Python wrapper of the HHblits binary.""" 32 | 33 | def __init__(self, 34 | *, 35 | binary_path: str, 36 | databases: Sequence[str], 37 | n_cpu: int = 4, 38 | n_iter: int = 3, 39 | e_value: float = 0.001, 40 | maxseq: int = 1_000_000, 41 | realign_max: int = 100_000, 42 | maxfilt: int = 100_000, 43 | min_prefilter_hits: int = 1000, 44 | all_seqs: bool = False, 45 | alt: Optional[int] = None, 46 | p: int = _HHBLITS_DEFAULT_P, 47 | z: int = _HHBLITS_DEFAULT_Z): 48 | """Initializes the Python HHblits wrapper. 49 | 50 | Args: 51 | binary_path: The path to the HHblits executable. 52 | databases: A sequence of HHblits database paths. This should be the 53 | common prefix for the database files (i.e. up to but not including 54 | _hhm.ffindex etc.) 55 | n_cpu: The number of CPUs to give HHblits. 56 | n_iter: The number of HHblits iterations. 57 | e_value: The E-value, see HHblits docs for more details. 58 | maxseq: The maximum number of rows in an input alignment. Note that this 59 | parameter is only supported in HHBlits version 3.1 and higher. 60 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. 61 | maxfilt: Max number of hits allowed to pass the 2nd prefilter. 62 | HHblits default: 20000. 63 | min_prefilter_hits: Min number of hits to pass prefilter. 64 | HHblits default: 100. 65 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA. 66 | HHblits default: False. 67 | alt: Show up to this many alternative alignments. 68 | p: Minimum Prob for a hit to be included in the output hhr file. 69 | HHblits default: 20. 70 | z: Hard cap on number of hits reported in the hhr file. 71 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. 72 | 73 | Raises: 74 | RuntimeError: If HHblits binary not found within the path. 
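    Example (a minimal usage sketch; the binary and database paths below are
    placeholders, not files shipped with this repository):
      runner = HHBlits(binary_path='/usr/bin/hhblits',
                       databases=['/db/uniclust30/uniclust30_2018_08'])
      results = runner.query('/tmp/query.fasta', base_path='/tmp')
      a3m_string = results[0]['a3m']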
75 | """ 76 | self.binary_path = binary_path 77 | self.databases = databases 78 | 79 | for database_path in self.databases: 80 | if not glob.glob(database_path + '_*'): 81 | logging.error('Could not find HHBlits database %s', database_path) 82 | raise ValueError(f'Could not find HHBlits database {database_path}') 83 | 84 | self.n_cpu = n_cpu 85 | self.n_iter = n_iter 86 | self.e_value = e_value 87 | self.maxseq = maxseq 88 | self.realign_max = realign_max 89 | self.maxfilt = maxfilt 90 | self.min_prefilter_hits = min_prefilter_hits 91 | self.all_seqs = all_seqs 92 | self.alt = alt 93 | self.p = p 94 | self.z = z 95 | 96 | def query(self, input_fasta_path: str, base_path: str) -> List[Mapping[str, Any]]: 97 | """Queries the database using HHblits.""" 98 | with utils.tmpdir_manager(base_dir=base_path) as query_tmp_dir: 99 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 100 | 101 | db_cmd = [] 102 | for db_path in self.databases: 103 | db_cmd.append('-d') 104 | db_cmd.append(db_path) 105 | cmd = [ 106 | self.binary_path, 107 | '-i', input_fasta_path, 108 | '-cpu', str(self.n_cpu), 109 | '-oa3m', a3m_path, 110 | '-o', '/dev/null', 111 | '-n', str(self.n_iter), 112 | '-e', str(self.e_value), 113 | '-maxseq', str(self.maxseq), 114 | '-realign_max', str(self.realign_max), 115 | '-maxfilt', str(self.maxfilt), 116 | '-min_prefilter_hits', str(self.min_prefilter_hits)] 117 | if self.all_seqs: 118 | cmd += ['-all'] 119 | if self.alt: 120 | cmd += ['-alt', str(self.alt)] 121 | if self.p != _HHBLITS_DEFAULT_P: 122 | cmd += ['-p', str(self.p)] 123 | if self.z != _HHBLITS_DEFAULT_Z: 124 | cmd += ['-Z', str(self.z)] 125 | cmd += db_cmd 126 | 127 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 128 | process = subprocess.Popen( 129 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 130 | 131 | with utils.timing('HHblits query'): 132 | stdout, stderr = process.communicate() 133 | retcode = process.wait() 134 | 135 | if retcode: 136 | # Logs have a 15k character limit, so log HHblits error line by line. 137 | logging.error('HHblits failed. HHblits stderr begin:') 138 | for error_line in stderr.decode('utf-8').splitlines(): 139 | if error_line.strip(): 140 | logging.error(error_line.strip()) 141 | logging.error('HHblits stderr end') 142 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % ( 143 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) 144 | 145 | with open(a3m_path) as f: 146 | a3m = f.read() 147 | 148 | raw_output = dict( 149 | a3m=a3m, 150 | output=stdout, 151 | stderr=stderr, 152 | n_iter=self.n_iter, 153 | e_value=self.e_value) 154 | return [raw_output] 155 | -------------------------------------------------------------------------------- /tools/jackhmmer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Library to run Jackhmmer from Python.""" 16 | 17 | from concurrent import futures 18 | import glob 19 | import os 20 | import subprocess 21 | from typing import Any, Callable, Mapping, Optional, Sequence 22 | from urllib import request 23 | 24 | from absl import logging 25 | 26 | 27 | # Internal import (7716). 28 | from tools import utils 29 | 30 | 31 | class Jackhmmer: 32 | """Python wrapper of the Jackhmmer binary.""" 33 | 34 | def __init__(self, 35 | *, 36 | binary_path: str, 37 | database_path: str, 38 | n_cpu: int = 16, 39 | n_iter: int = 1, 40 | e_value: float = 0.0001, 41 | z_value: Optional[int] = None, 42 | get_tblout: bool = False, 43 | filter_f1: float = 0.0005, 44 | filter_f2: float = 0.00005, 45 | filter_f3: float = 0.0000005, 46 | incdom_e: Optional[float] = None, 47 | dom_e: Optional[float] = None, 48 | num_streamed_chunks: Optional[int] = None, 49 | streaming_callback: Optional[Callable[[int], None]] = None): 50 | """Initializes the Python Jackhmmer wrapper. 51 | 52 | Args: 53 | binary_path: The path to the jackhmmer executable. 54 | database_path: The path to the jackhmmer database (FASTA format). 55 | n_cpu: The number of CPUs to give Jackhmmer. 56 | n_iter: The number of Jackhmmer iterations. 57 | e_value: The E-value, see Jackhmmer docs for more details. 58 | z_value: The Z-value, see Jackhmmer docs for more details. 59 | get_tblout: Whether to save tblout string. 60 | filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. 61 | filter_f2: Viterbi pre-filter, set to >1.0 to turn off. 62 | filter_f3: Forward pre-filter, set to >1.0 to turn off. 63 | incdom_e: Domain e-value criteria for inclusion of domains in MSA/next 64 | round. 65 | dom_e: Domain e-value criteria for inclusion in tblout. 66 | num_streamed_chunks: Number of database chunks to stream over. 67 | streaming_callback: Callback function run after each chunk iteration with 68 | the iteration number as argument. 69 | """ 70 | self.binary_path = binary_path 71 | self.database_path = database_path 72 | self.num_streamed_chunks = num_streamed_chunks 73 | 74 | if not os.path.exists(self.database_path) and num_streamed_chunks is None: 75 | logging.error('Could not find Jackhmmer database %s', database_path) 76 | raise ValueError(f'Could not find Jackhmmer database {database_path}') 77 | 78 | self.n_cpu = n_cpu 79 | self.n_iter = n_iter 80 | self.e_value = e_value 81 | self.z_value = z_value 82 | self.filter_f1 = filter_f1 83 | self.filter_f2 = filter_f2 84 | self.filter_f3 = filter_f3 85 | self.incdom_e = incdom_e 86 | self.dom_e = dom_e 87 | self.get_tblout = get_tblout 88 | self.streaming_callback = streaming_callback 89 | 90 | def _query_chunk(self, input_fasta_path: str, database_path: str, base_path: str) -> Mapping[str, Any]: 91 | """Queries the database chunk using Jackhmmer.""" 92 | with utils.tmpdir_manager(base_dir=base_path) as query_tmp_dir: 93 | sto_path = os.path.join(query_tmp_dir, 'output.sto') 94 | 95 | # The F1/F2/F3 are the expected proportion to pass each of the filtering 96 | # stages (which get progressively more expensive), reducing these 97 | # speeds up the pipeline at the expensive of sensitivity. They are 98 | # currently set very low to make querying Mgnify run in a reasonable 99 | # amount of time. 100 | cmd_flags = [ 101 | # Don't pollute stdout with Jackhmmer output. 
102 | '-o', '/dev/null', 103 | '-A', sto_path, 104 | '--noali', 105 | '--F1', str(self.filter_f1), 106 | '--F2', str(self.filter_f2), 107 | '--F3', str(self.filter_f3), 108 | '--incE', str(self.e_value), 109 | # Report only sequences with E-values <= x in per-sequence output. 110 | '-E', str(self.e_value), 111 | '--cpu', str(self.n_cpu), 112 | '-N', str(self.n_iter) 113 | ] 114 | if self.get_tblout: 115 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') 116 | cmd_flags.extend(['--tblout', tblout_path]) 117 | 118 | if self.z_value: 119 | cmd_flags.extend(['-Z', str(self.z_value)]) 120 | 121 | if self.dom_e is not None: 122 | cmd_flags.extend(['--domE', str(self.dom_e)]) 123 | 124 | if self.incdom_e is not None: 125 | cmd_flags.extend(['--incdomE', str(self.incdom_e)]) 126 | 127 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path, 128 | database_path] 129 | 130 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 131 | process = subprocess.Popen( 132 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 133 | with utils.timing( 134 | f'Jackhmmer ({os.path.basename(database_path)}) query'): 135 | _, stderr = process.communicate() 136 | retcode = process.wait() 137 | 138 | if retcode: 139 | raise RuntimeError( 140 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8')) 141 | 142 | # Get e-values for each target name 143 | tbl = '' 144 | if self.get_tblout: 145 | with open(tblout_path) as f: 146 | tbl = f.read() 147 | 148 | with open(sto_path) as f: 149 | sto = f.read() 150 | 151 | raw_output = dict( 152 | sto=sto, 153 | tbl=tbl, 154 | stderr=stderr, 155 | n_iter=self.n_iter, 156 | e_value=self.e_value) 157 | 158 | return raw_output 159 | 160 | def query(self, input_fasta_path: str, base_path: str) -> Sequence[Mapping[str, Any]]: 161 | """Queries the database using Jackhmmer.""" 162 | if self.num_streamed_chunks is None: 163 | return [self._query_chunk(input_fasta_path, self.database_path, base_path=base_path)] 164 | 165 | db_basename = os.path.basename(self.database_path) 166 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' 167 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' 168 | 169 | # Remove existing files to prevent OOM 170 | for f in glob.glob(db_local_chunk('[0-9]*')): 171 | try: 172 | os.remove(f) 173 | except OSError: 174 | print(f'OSError while deleting {f}') 175 | 176 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk 177 | with futures.ThreadPoolExecutor(max_workers=2) as executor: 178 | chunked_output = [] 179 | for i in range(1, self.num_streamed_chunks + 1): 180 | # Copy the chunk locally 181 | if i == 1: 182 | future = executor.submit( 183 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i)) 184 | if i < self.num_streamed_chunks: 185 | next_future = executor.submit( 186 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1)) 187 | 188 | # Run Jackhmmer with the chunk 189 | future.result() 190 | chunked_output.append( 191 | self._query_chunk(input_fasta_path, db_local_chunk(i))) 192 | 193 | # Remove the local copy of the chunk 194 | os.remove(db_local_chunk(i)) 195 | # Do not set next_future for the last chunk so that this works even for 196 | # databases with only 1 chunk. 
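        # Hand the prefetched (i+1)-th chunk to the next iteration so its download
        # overlaps with the current Jackhmmer query (simple double buffering).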
197 | if i < self.num_streamed_chunks: 198 | future = next_future 199 | if self.streaming_callback: 200 | self.streaming_callback(i) 201 | return chunked_output 202 | -------------------------------------------------------------------------------- /tools/msa_identifiers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for extracting identifiers from MSA sequence descriptions.""" 16 | 17 | import dataclasses 18 | import re 19 | from typing import Optional 20 | 21 | 22 | # Sequences coming from UniProtKB database come in the 23 | # `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` 24 | # or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). 25 | _UNIPROT_PATTERN = re.compile( 26 | r""" 27 | ^ 28 | # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot 29 | (?:tr|sp) 30 | \| 31 | # A primary accession number of the UniProtKB entry. 32 | (?P[A-Za-z0-9]{6,10}) 33 | # Occasionally there is a _0 or _1 isoform suffix, which we ignore. 34 | (?:_\d)? 35 | \| 36 | # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic 37 | # protein ID code. 38 | (?:[A-Za-z0-9]+) 39 | _ 40 | # A mnemonic species identification code. 41 | (?P([A-Za-z0-9]){1,5}) 42 | # Small BFD uses a final value after an underscore, which we ignore. 43 | (?:_\d+)? 44 | $ 45 | """, 46 | re.VERBOSE) 47 | 48 | 49 | @dataclasses.dataclass(frozen=True) 50 | class Identifiers: 51 | uniprot_accession_id: str = '' 52 | species_id: str = '' 53 | 54 | 55 | def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: 56 | """Gets accession id and species from an msa sequence identifier. 57 | 58 | The sequence identifier has the format specified by 59 | _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. 60 | An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` 61 | 62 | Args: 63 | msa_sequence_identifier: a sequence identifier. 64 | 65 | Returns: 66 | An `Identifiers` instance with a uniprot_accession_id and species_id. These 67 | can be empty in the case where no identifier was found. 68 | """ 69 | matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) 70 | if matches: 71 | return Identifiers( 72 | uniprot_accession_id=matches.group('AccessionIdentifier'), 73 | species_id=matches.group('SpeciesIdentifier')) 74 | return Identifiers() 75 | 76 | 77 | def _extract_sequence_identifier(description: str) -> Optional[str]: 78 | """Extracts sequence identifier from description. 
Returns None if no match.""" 79 | split_description = description.split() 80 | if split_description: 81 | return split_description[0].partition('/')[0] 82 | else: 83 | return None 84 | 85 | 86 | def get_identifiers(description: str) -> Identifiers: 87 | """Computes extra MSA features from the description.""" 88 | sequence_identifier = _extract_sequence_identifier(description) 89 | if sequence_identifier is None: 90 | return Identifiers() 91 | else: 92 | return _parse_sequence_identifier(sequence_identifier) 93 | -------------------------------------------------------------------------------- /tools/parsers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for parsing various file formats.""" 16 | import collections 17 | import dataclasses 18 | import itertools 19 | import re 20 | import string 21 | from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set 22 | 23 | DeletionMatrix = Sequence[Sequence[int]] 24 | 25 | 26 | @dataclasses.dataclass(frozen=True) 27 | class Msa: 28 | """Class representing a parsed MSA file.""" 29 | sequences: Sequence[str] 30 | deletion_matrix: DeletionMatrix 31 | descriptions: Sequence[str] 32 | 33 | def __post_init__(self): 34 | if not (len(self.sequences) == 35 | len(self.deletion_matrix) == 36 | len(self.descriptions)): 37 | raise ValueError( 38 | 'All fields for an MSA must have the same length. ' 39 | f'Got {len(self.sequences)} sequences, ' 40 | f'{len(self.deletion_matrix)} rows in the deletion matrix and ' 41 | f'{len(self.descriptions)} descriptions.') 42 | 43 | def __len__(self): 44 | return len(self.sequences) 45 | 46 | def truncate(self, max_seqs: int): 47 | return Msa(sequences=self.sequences[:max_seqs], 48 | deletion_matrix=self.deletion_matrix[:max_seqs], 49 | descriptions=self.descriptions[:max_seqs]) 50 | 51 | 52 | @dataclasses.dataclass(frozen=True) 53 | class TemplateHit: 54 | """Class representing a template hit.""" 55 | index: int 56 | name: str 57 | aligned_cols: int 58 | sum_probs: Optional[float] 59 | query: str 60 | hit_sequence: str 61 | indices_query: List[int] 62 | indices_hit: List[int] 63 | 64 | 65 | def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: 66 | """Parses FASTA string and returns list of strings with amino-acid sequences. 67 | 68 | Arguments: 69 | fasta_string: The string contents of a FASTA file. 70 | 71 | Returns: 72 | A tuple of two lists: 73 | * A list of sequences. 74 | * A list of sequence descriptions taken from the comment lines. In the 75 | same order as the sequences. 76 | """ 77 | sequences = [] 78 | descriptions = [] 79 | index = -1 80 | for line in fasta_string.splitlines(): 81 | line = line.strip() 82 | if line.startswith('>'): 83 | index += 1 84 | descriptions.append(line[1:]) # Remove the '>' at the beginning. 
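      # Start an empty buffer for this record; subsequent non-header lines are
      # concatenated onto it.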
85 | sequences.append('') 86 | continue 87 | elif not line: 88 | continue # Skip blank lines. 89 | sequences[index] += line 90 | 91 | return sequences, descriptions 92 | 93 | 94 | def parse_stockholm(stockholm_string: str) -> Msa: 95 | """Parses sequences and deletion matrix from stockholm format alignment. 96 | 97 | Args: 98 | stockholm_string: The string contents of a stockholm file. The first 99 | sequence in the file should be the query sequence. 100 | 101 | Returns: 102 | A tuple of: 103 | * A list of sequences that have been aligned to the query. These 104 | might contain duplicates. 105 | * The deletion matrix for the alignment as a list of lists. The element 106 | at `deletion_matrix[i][j]` is the number of residues deleted from 107 | the aligned sequence i at residue position j. 108 | * The names of the targets matched, including the jackhmmer subsequence 109 | suffix. 110 | """ 111 | name_to_sequence = collections.OrderedDict() 112 | for line in stockholm_string.splitlines(): 113 | line = line.strip() 114 | if not line or line.startswith(('#', '//')): 115 | continue 116 | name, sequence = line.split() 117 | if name not in name_to_sequence: 118 | name_to_sequence[name] = '' 119 | name_to_sequence[name] += sequence 120 | 121 | msa = [] 122 | deletion_matrix = [] 123 | 124 | query = '' 125 | keep_columns = [] 126 | for seq_index, sequence in enumerate(name_to_sequence.values()): 127 | if seq_index == 0: 128 | # Gather the columns with gaps from the query 129 | query = sequence 130 | keep_columns = [i for i, res in enumerate(query) if res != '-'] 131 | 132 | # Remove the columns with gaps in the query from all sequences. 133 | aligned_sequence = ''.join([sequence[c] for c in keep_columns]) 134 | 135 | msa.append(aligned_sequence) 136 | 137 | # Count the number of deletions w.r.t. query. 138 | deletion_vec = [] 139 | deletion_count = 0 140 | for seq_res, query_res in zip(sequence, query): 141 | if seq_res != '-' or query_res != '-': 142 | if query_res == '-': 143 | deletion_count += 1 144 | else: 145 | deletion_vec.append(deletion_count) 146 | deletion_count = 0 147 | deletion_matrix.append(deletion_vec) 148 | 149 | return Msa(sequences=msa, 150 | deletion_matrix=deletion_matrix, 151 | descriptions=list(name_to_sequence.keys())) 152 | 153 | 154 | def parse_a3m(a3m_string: str) -> Msa: 155 | """Parses sequences and deletion matrix from a3m format alignment. 156 | 157 | Args: 158 | a3m_string: The string contents of a a3m file. The first sequence in the 159 | file should be the query sequence. 160 | 161 | Returns: 162 | A tuple of: 163 | * A list of sequences that have been aligned to the query. These 164 | might contain duplicates. 165 | * The deletion matrix for the alignment as a list of lists. The element 166 | at `deletion_matrix[i][j]` is the number of residues deleted from 167 | the aligned sequence i at residue position j. 168 | * A list of descriptions, one per sequence, from the a3m file. 169 | """ 170 | sequences, descriptions = parse_fasta(a3m_string) 171 | deletion_matrix = [] 172 | for msa_sequence in sequences: 173 | deletion_vec = [] 174 | deletion_count = 0 175 | for j in msa_sequence: 176 | if j.islower(): 177 | deletion_count += 1 178 | else: 179 | deletion_vec.append(deletion_count) 180 | deletion_count = 0 181 | deletion_matrix.append(deletion_vec) 182 | 183 | # Make the MSA matrix out of aligned (deletion-free) sequences. 
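  # Lowercase a3m characters are insertions relative to the query; stripping them yields
  # rows of equal length. For example, the a3m row 'MKaLV' becomes the aligned row 'MKLV'
  # with deletion vector [0, 0, 1, 0] (one lowercase insertion counted before 'L').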
184 | deletion_table = str.maketrans('', '', string.ascii_lowercase) 185 | aligned_sequences = [s.translate(deletion_table) for s in sequences] 186 | return Msa(sequences=aligned_sequences, 187 | deletion_matrix=deletion_matrix, 188 | descriptions=descriptions) 189 | 190 | 191 | def _convert_sto_seq_to_a3m( 192 | query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]: 193 | for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): 194 | if is_query_res_non_gap: 195 | yield sequence_res 196 | elif sequence_res != '-': 197 | yield sequence_res.lower() 198 | 199 | 200 | def convert_stockholm_to_a3m(stockholm_format: str, 201 | max_sequences: Optional[int] = None, 202 | remove_first_row_gaps: bool = True) -> str: 203 | """Converts MSA in Stockholm format to the A3M format.""" 204 | descriptions = {} 205 | sequences = {} 206 | reached_max_sequences = False 207 | 208 | for line in stockholm_format.splitlines(): 209 | reached_max_sequences = max_sequences and len(sequences) >= max_sequences 210 | if line.strip() and not line.startswith(('#', '//')): 211 | # Ignore blank lines, markup and end symbols - remainder are alignment 212 | # sequence parts. 213 | seqname, aligned_seq = line.split(maxsplit=1) 214 | if seqname not in sequences: 215 | if reached_max_sequences: 216 | continue 217 | sequences[seqname] = '' 218 | sequences[seqname] += aligned_seq 219 | 220 | for line in stockholm_format.splitlines(): 221 | if line[:4] == '#=GS': 222 | # Description row - example format is: 223 | # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... 224 | columns = line.split(maxsplit=3) 225 | seqname, feature = columns[1:3] 226 | value = columns[3] if len(columns) == 4 else '' 227 | if feature != 'DE': 228 | continue 229 | if reached_max_sequences and seqname not in sequences: 230 | continue 231 | descriptions[seqname] = value 232 | if len(descriptions) == len(sequences): 233 | break 234 | 235 | # Convert sto format to a3m line by line 236 | a3m_sequences = {} 237 | if remove_first_row_gaps: 238 | # query_sequence is assumed to be the first sequence 239 | query_sequence = next(iter(sequences.values())) 240 | query_non_gaps = [res != '-' for res in query_sequence] 241 | for seqname, sto_sequence in sequences.items(): 242 | # Dots are optional in a3m format and are commonly removed. 243 | out_sequence = sto_sequence.replace('.', '') 244 | if remove_first_row_gaps: 245 | out_sequence = ''.join( 246 | _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) 247 | a3m_sequences[seqname] = out_sequence 248 | 249 | fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" 250 | for k in a3m_sequences) 251 | return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. 252 | 253 | 254 | def _keep_line(line: str, seqnames: Set[str]) -> bool: 255 | """Function to decide which lines to keep.""" 256 | if not line.strip(): 257 | return True 258 | if line.strip() == '//': # End tag 259 | return True 260 | if line.startswith('# STOCKHOLM'): # Start tag 261 | return True 262 | if line.startswith('#=GC RF'): # Reference Annotation Line 263 | return True 264 | if line[:4] == '#=GS': # Description lines - keep if sequence in list. 265 | _, seqname, _ = line.split(maxsplit=2) 266 | return seqname in seqnames 267 | elif line.startswith('#'): # Other markup - filter out 268 | return False 269 | else: # Alignment data_bp - keep if sequence in list. 
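    # Everything before the first space on an alignment line is the sequence name.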
270 | seqname = line.partition(' ')[0] 271 | return seqname in seqnames 272 | 273 | 274 | def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: 275 | """Truncates a stockholm file to a maximum number of sequences.""" 276 | seqnames = set() 277 | filtered_lines = [] 278 | for line in stockholm_msa.splitlines(): 279 | if line.strip() and not line.startswith(('#', '//')): 280 | # Ignore blank lines, markup and end symbols - remainder are alignment 281 | # sequence parts. 282 | seqname = line.partition(' ')[0] 283 | seqnames.add(seqname) 284 | if len(seqnames) >= max_sequences: 285 | break 286 | 287 | for line in stockholm_msa.splitlines(): 288 | if _keep_line(line, seqnames): 289 | filtered_lines.append(line) 290 | 291 | return '\n'.join(filtered_lines) + '\n' 292 | 293 | 294 | def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: 295 | """Removes empty columns (dashes-only) from a Stockholm MSA.""" 296 | processed_lines = {} 297 | unprocessed_lines = {} 298 | for i, line in enumerate(stockholm_msa.splitlines()): 299 | if line.startswith('#=GC RF'): 300 | reference_annotation_i = i 301 | reference_annotation_line = line 302 | # Reached the end of this chunk of the alignment. Process chunk. 303 | _, _, first_alignment = line.rpartition(' ') 304 | mask = [] 305 | for j in range(len(first_alignment)): 306 | for _, unprocessed_line in unprocessed_lines.items(): 307 | prefix, _, alignment = unprocessed_line.rpartition(' ') 308 | if alignment[j] != '-': 309 | mask.append(True) 310 | break 311 | else: # Every row contained a hyphen - empty column. 312 | mask.append(False) 313 | # Add reference annotation for processing with mask. 314 | unprocessed_lines[reference_annotation_i] = reference_annotation_line 315 | 316 | if not any(mask): # All columns were empty. Output empty lines for chunk. 317 | for line_index in unprocessed_lines: 318 | processed_lines[line_index] = '' 319 | else: 320 | for line_index, unprocessed_line in unprocessed_lines.items(): 321 | prefix, _, alignment = unprocessed_line.rpartition(' ') 322 | masked_alignment = ''.join(itertools.compress(alignment, mask)) 323 | processed_lines[line_index] = f'{prefix} {masked_alignment}' 324 | 325 | # Clear raw_alignments. 326 | unprocessed_lines = {} 327 | elif line.strip() and not line.startswith(('#', '//')): 328 | unprocessed_lines[i] = line 329 | else: 330 | processed_lines[i] = line 331 | return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) 332 | 333 | 334 | def deduplicate_stockholm_msa(stockholm_msa: str) -> str: 335 | """Remove duplicate sequences (ignoring insertions wrt query).""" 336 | sequence_dict = collections.defaultdict(str) 337 | 338 | # First we must extract all sequences from the MSA. 339 | for line in stockholm_msa.splitlines(): 340 | # Only consider the alignments - ignore reference annotation, empty lines, 341 | # descriptions or markup. 342 | if line.strip() and not line.startswith(('#', '//')): 343 | line = line.strip() 344 | seqname, alignment = line.split() 345 | sequence_dict[seqname] += alignment 346 | 347 | seen_sequences = set() 348 | seqnames = set() 349 | # First alignment is the query. 350 | query_align = next(iter(sequence_dict.values())) 351 | mask = [c != '-' for c in query_align] # Mask is False for insertions. 352 | for seqname, alignment in sequence_dict.items(): 353 | # Apply mask to remove all insertions from the string. 
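    # itertools.compress keeps only the columns where the query has a residue, so
    # duplicates are detected on match columns and insertions w.r.t. the query are ignored.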
354 | masked_alignment = ''.join(itertools.compress(alignment, mask)) 355 | if masked_alignment in seen_sequences: 356 | continue 357 | else: 358 | seen_sequences.add(masked_alignment) 359 | seqnames.add(seqname) 360 | 361 | filtered_lines = [] 362 | for line in stockholm_msa.splitlines(): 363 | if _keep_line(line, seqnames): 364 | filtered_lines.append(line) 365 | 366 | return '\n'.join(filtered_lines) + '\n' 367 | 368 | 369 | def _get_hhr_line_regex_groups( 370 | regex_pattern: str, line: str) -> Sequence[Optional[str]]: 371 | match = re.match(regex_pattern, line) 372 | if match is None: 373 | raise RuntimeError(f'Could not parse query line {line}') 374 | return match.groups() 375 | 376 | 377 | def _update_hhr_residue_indices_list( 378 | sequence: str, start_index: int, indices_list: List[int]): 379 | """Computes the relative indices for each residue with respect to the original sequence.""" 380 | counter = start_index 381 | for symbol in sequence: 382 | if symbol == '-': 383 | indices_list.append(-1) 384 | else: 385 | indices_list.append(counter) 386 | counter += 1 387 | 388 | 389 | def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: 390 | """Parses the detailed HMM HMM comparison section for a single Hit. 391 | 392 | This works on .hhr files generated from both HHBlits and HHSearch. 393 | 394 | Args: 395 | detailed_lines: A list of lines from a single comparison section between 2 396 | sequences (which each have their own HMM's) 397 | 398 | Returns: 399 | A dictionary with the information from that detailed comparison section 400 | 401 | Raises: 402 | RuntimeError: If a certain line cannot be processed 403 | """ 404 | # Parse first 2 lines. 405 | number_of_hit = int(detailed_lines[0].split()[-1]) 406 | name_hit = detailed_lines[1][1:] 407 | 408 | # Parse the summary line. 409 | pattern = ( 410 | 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' 411 | ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' 412 | ']*Template_Neff=(.*)') 413 | match = re.match(pattern, detailed_lines[2]) 414 | if match is None: 415 | raise RuntimeError( 416 | 'Could not parse section: %s. Expected this: \n%s to contain summary.' % 417 | (detailed_lines, detailed_lines[2])) 418 | (_, _, _, aligned_cols, _, _, sum_probs, _) = [float(x) 419 | for x in match.groups()] 420 | 421 | # The next section reads the detailed comparisons. These are in a 'human 422 | # readable' format which has a fixed length. The strategy employed is to 423 | # assume that each block starts with the query sequence line, and to parse 424 | # that with a regexp in order to deduce the fixed length used for that block. 425 | query = '' 426 | hit_sequence = '' 427 | indices_query = [] 428 | indices_hit = [] 429 | length_block = None 430 | 431 | for line in detailed_lines[3:]: 432 | # Parse the query sequence line 433 | if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and 434 | not line.startswith('Q ss_pred') and 435 | not line.startswith('Q Consensus')): 436 | # Thus the first 17 characters must be 'Q ', and we can parse 437 | # everything after that. 438 | # start sequence end total_sequence_length 439 | patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' 440 | groups = _get_hhr_line_regex_groups(patt, line[17:]) 441 | 442 | # Get the length of the parsed block using the start and finish indices, 443 | # and ensure it is the same as the actual block length. 444 | start = int(groups[0]) - 1 # Make index zero based. 
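      # groups = (start, aligned query block, end); '-' in the block marks a gap in the
      # query, i.e. an insertion in the hit.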
445 | delta_query = groups[1] 446 | end = int(groups[2]) 447 | num_insertions = len([x for x in delta_query if x == '-']) 448 | length_block = end - start + num_insertions 449 | assert length_block == len(delta_query) 450 | 451 | # Update the query sequence and indices list. 452 | query += delta_query 453 | _update_hhr_residue_indices_list(delta_query, start, indices_query) 454 | 455 | elif line.startswith('T '): 456 | # Parse the hit sequence. 457 | if (not line.startswith('T ss_dssp') and 458 | not line.startswith('T ss_pred') and 459 | not line.startswith('T Consensus')): 460 | # Thus the first 17 characters must be 'T ', and we can 461 | # parse everything after that. 462 | # start sequence end total_sequence_length 463 | patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' 464 | groups = _get_hhr_line_regex_groups(patt, line[17:]) 465 | start = int(groups[0]) - 1 # Make index zero based. 466 | delta_hit_sequence = groups[1] 467 | assert length_block == len(delta_hit_sequence) 468 | 469 | # Update the hit sequence and indices list. 470 | hit_sequence += delta_hit_sequence 471 | _update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit) 472 | 473 | return TemplateHit( 474 | index=number_of_hit, 475 | name=name_hit, 476 | aligned_cols=int(aligned_cols), 477 | sum_probs=sum_probs, 478 | query=query, 479 | hit_sequence=hit_sequence, 480 | indices_query=indices_query, 481 | indices_hit=indices_hit, 482 | ) 483 | 484 | 485 | def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: 486 | """Parses the content of an entire HHR file.""" 487 | lines = hhr_string.splitlines() 488 | 489 | # Each .hhr file starts with a results table, then has a sequence of hit 490 | # "paragraphs", each paragraph starting with a line 'No '. We 491 | # iterate through each paragraph to parse each hit. 492 | 493 | block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] 494 | 495 | hits = [] 496 | if block_starts: 497 | block_starts.append(len(lines)) # Add the end of the final block. 498 | for i in range(len(block_starts) - 1): 499 | hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) 500 | return hits 501 | 502 | 503 | def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: 504 | """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" 505 | e_values = {'query': 0} 506 | lines = [line for line in tblout.splitlines() if line[0] != '#'] 507 | # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are 508 | # space-delimited. Relevant fields are (1) target name: and 509 | # (5) E-value (full sequence) (numbering from 1). 510 | for line in lines: 511 | fields = line.split() 512 | e_value = fields[4] 513 | target_name = fields[0] 514 | e_values[target_name] = float(e_value) 515 | return e_values 516 | 517 | 518 | def _get_indices(sequence: str, start: int) -> List[int]: 519 | """Returns indices for non-gap/insert residues starting at the given index.""" 520 | indices = [] 521 | counter = start 522 | for symbol in sequence: 523 | # Skip gaps but add a placeholder so that the alignment is preserved. 524 | if symbol == '-': 525 | indices.append(-1) 526 | # Skip deleted residues, but increase the counter. 527 | elif symbol.islower(): 528 | counter += 1 529 | # Normal aligned residue. Increase the counter and append to indices. 
530 | else: 531 | indices.append(counter) 532 | counter += 1 533 | return indices 534 | 535 | 536 | @dataclasses.dataclass(frozen=True) 537 | class HitMetadata: 538 | pdb_id: str 539 | chain: str 540 | start: int 541 | end: int 542 | length: int 543 | text: str 544 | 545 | 546 | def _parse_hmmsearch_description(description: str) -> HitMetadata: 547 | """Parses the hmmsearch A3M sequence description line.""" 548 | # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text 549 | # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 550 | match = re.match( 551 | r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', 552 | description.strip()) 553 | 554 | if not match: 555 | raise ValueError(f'Could not parse description: "{description}".') 556 | 557 | return HitMetadata( 558 | pdb_id=match[1], 559 | chain=match[2], 560 | start=int(match[3]), 561 | end=int(match[4]), 562 | length=int(match[5]), 563 | text=match[6]) 564 | 565 | 566 | def parse_hmmsearch_a3m(query_sequence: str, 567 | a3m_string: str, 568 | skip_first: bool = True) -> Sequence[TemplateHit]: 569 | """Parses an a3m string produced by hmmsearch. 570 | 571 | Args: 572 | query_sequence: The query sequence. 573 | a3m_string: The a3m string produced by hmmsearch. 574 | skip_first: Whether to skip the first sequence in the a3m string. 575 | 576 | Returns: 577 | A sequence of `TemplateHit` results. 578 | """ 579 | # Zip the descriptions and MSAs together, skip the first query sequence. 580 | parsed_a3m = list(zip(*parse_fasta(a3m_string))) 581 | if skip_first: 582 | parsed_a3m = parsed_a3m[1:] 583 | 584 | indices_query = _get_indices(query_sequence, start=0) 585 | 586 | hits = [] 587 | for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): 588 | if 'mol:protein' not in hit_description: 589 | continue # Skip non-protein chains. 590 | metadata = _parse_hmmsearch_description(hit_description) 591 | # Aligned columns are only the match states. 592 | aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) 593 | indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) 594 | 595 | hit = TemplateHit( 596 | index=i, 597 | name=f'{metadata.pdb_id}_{metadata.chain}', 598 | aligned_cols=aligned_cols, 599 | sum_probs=None, 600 | query=query_sequence, 601 | hit_sequence=hit_sequence.upper(), 602 | indices_query=indices_query, 603 | indices_hit=indices_hit, 604 | ) 605 | hits.append(hit) 606 | 607 | return hits 608 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Common utilities for data_bp pipeline tools.""" 15 | import contextlib 16 | import shutil 17 | import tempfile 18 | import time 19 | from typing import Optional 20 | 21 | from absl import logging 22 | 23 | 24 | @contextlib.contextmanager 25 | def tmpdir_manager(base_dir: Optional[str] = None): 26 | """Context manager that deletes a temporary directory on exit.""" 27 | tmpdir = tempfile.mkdtemp(dir=base_dir) 28 | try: 29 | yield tmpdir 30 | finally: 31 | shutil.rmtree(tmpdir, ignore_errors=True) 32 | 33 | 34 | @contextlib.contextmanager 35 | def timing(msg: str): 36 | logging.info('Started %s', msg) 37 | tic = time.time() 38 | yield 39 | toc = time.time() 40 | logging.info('Finished %s in %.3f seconds', msg, toc - tic) 41 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import numpy as np 4 | import torch.optim as optim 5 | from torchsummary import summary 6 | 7 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score 8 | import Constants 9 | import params 10 | 11 | from Dataset.Dataset import load_dataset 12 | from models.gnn import GCN 13 | import argparse 14 | import torch 15 | import time 16 | from torch_geometric.loader import DataLoader 17 | from preprocessing.utils import pickle_save, pickle_load, save_ckp, load_ckp, class_distribution_counter, \ 18 | draw_architecture, compute_roc 19 | 20 | import warnings 21 | 22 | warnings.filterwarnings("ignore", category=UserWarning) 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 27 | parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during training pass.') 28 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 29 | parser.add_argument('--epochs', type=int, default=70, help='Number of epochs to train.') 30 | parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate.') 31 | parser.add_argument('--train_batch', type=int, default=10, help='Training batch size.') 32 | parser.add_argument('--valid_batch', type=int, default=10, help='Validation batch size.') 33 | parser.add_argument('--seq', type=float, default=0.9, help='Sequence Identity (Sequence Identity).') 34 | parser.add_argument("--ont", default='molecular_function', type=str, help='Ontology under consideration') 35 | 36 | args = parser.parse_args() 37 | args.cuda = not args.no_cuda and torch.cuda.is_available() 38 | if args.cuda: 39 | device = 'cuda' 40 | 41 | kwargs = { 42 | 'seq_id': args.seq, 43 | 'ont': args.ont, 44 | 'session': 'train' 45 | } 46 | 47 | if args.ont == 'molecular_function': 48 | ont_kwargs = params.mol_kwargs 49 | elif args.ont == 'cellular_component': 50 | ont_kwargs = params.cc_kwargs 51 | elif args.ont == 'biological_process': 52 | ont_kwargs = params.bio_kwargs 53 | 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | if args.cuda: 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | num_class = len(pickle_load(Constants.ROOT + 'go_terms')[f'GO-terms-{args.ont}']) 60 | 61 | 62 | def create_class_weights(cnter): 63 | class_weight_path = Constants.ROOT + "{}/{}/class_weights".format(kwargs['seq_id'], kwargs['ont']) 64 | if os.path.exists(class_weight_path + ".pickle"): 65 | print("Loading class weights") 66 | class_weights = pickle_load(class_weight_path) 67 | else: 68 | 
69 |         go_terms = pickle_load(Constants.ROOT + "/go_terms")
70 |         terms = go_terms['GO-terms-{}'.format(args.ont)]
71 |         class_weights = [cnter[i] for i in terms]
72 | 
73 |     total = sum(class_weights)
74 |     _max = max(class_weights)
75 |     class_weights = torch.tensor([total / i for i in class_weights], dtype=torch.float).to(device)
76 | 
77 |     return class_weights
78 | class_weights = create_class_weights(class_distribution_counter(**kwargs))
79 | 
80 | 
81 | dataset = load_dataset(root=Constants.ROOT, **kwargs)
82 | labels = pickle_load(Constants.ROOT + "{}_labels".format(args.ont))
83 | 
84 | edge_types = list(params.edge_types)
85 | 
86 | train_dataloader = DataLoader(dataset,
87 |                               batch_size=args.train_batch,
88 |                               drop_last=True,
89 |                               exclude_keys=edge_types,
90 |                               shuffle=True)
91 | 
92 | kwargs['session'] = 'validation'
93 | val_dataset = load_dataset(root=Constants.ROOT, **kwargs)
94 | valid_dataloader = DataLoader(val_dataset,
95 |                               batch_size=args.valid_batch,
96 |                               drop_last=False,
97 |                               shuffle=False,
98 |                               exclude_keys=edge_types)
99 | 
100 | print('========================================')
101 | print(f'# training proteins: {len(dataset)}')
102 | print(f'# validation proteins: {len(val_dataset)}')
103 | print(f'# Number of classes: {num_class}')
104 | # print(f'# Max class weights: {torch.max(class_weights)}')
105 | # print(f'# Min class weights: {torch.min(class_weights)}')
106 | print('========================================')
107 | 
108 | current_epoch = 1
109 | min_val_loss = np.inf
110 | 
111 | inpu = next(iter(train_dataloader))
112 | model = GCN(**ont_kwargs)
113 | 
114 | 
115 | model.to(device)
116 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=ont_kwargs['wd'])
117 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
118 | criterion = torch.nn.BCELoss(reduction='none')
119 | 
120 | labels = pickle_load(Constants.ROOT + "{}_labels".format(args.ont))
121 | 
122 | 
123 | def train(start_epoch, min_val_loss, model, optimizer, criterion, data_loader):
124 |     min_val_loss = min_val_loss
125 | 
126 |     for epoch in range(start_epoch, args.epochs):
127 |         print(" ---------- Epoch {} ----------".format(epoch))
128 |         # initialize variables to monitor training and validation loss
129 |         epoch_loss, epoch_precision, epoch_recall, epoch_accuracy, epoch_f1 = 0.0, 0.0, 0.0, 0.0, 0.0
130 |         val_loss, val_precision, val_recall, val_accuracy, val_f1 = 0.0, 0.0, 0.0, 0.0, 0.0
131 | 
132 |         t = time.time()
133 | 
134 |         with torch.autograd.set_detect_anomaly(True):
135 |             lr_scheduler.step()
136 |             ###################
137 |             # train the model #
138 |             ###################
139 |             model.train()
140 |             for pos, data in enumerate(data_loader['train']):
141 |                 labs = []
142 |                 for la in data['atoms'].protein:
143 |                     labs.append(torch.tensor(labels[la], dtype=torch.float32).view(1, -1))
144 | 
145 |                 labs = torch.cat(labs, dim=0)
146 |                 cnts = torch.sum(labs, dim=0)
147 |                 total = torch.sum(cnts)
148 | 
149 |                 optimizer.zero_grad()
150 | 
151 |                 output = model(data.to(device))
152 |                 loss = criterion(output, labs.to(device))
153 |                 loss = (loss * class_weights).mean()
154 | 
155 | 
156 |                 pom = output.detach().cpu().numpy()
157 |                 bins = np.arange(0.0, 1.1, 0.1)
158 |                 digitized = np.digitize(pom, bins) - 1
159 |                 counts = np.bincount(digitized.flatten(), minlength=len(bins))
160 | 
161 |                 # Display the counts for each range
162 |                 for i, count in enumerate(counts):
163 |                     lower_bound = round(bins[i], 1)
164 |                     upper_bound = round(bins[i+1], 1) if i < len(bins) - 1 else 1.0
165 |                     print(f"Range {lower_bound:.1f} - {upper_bound:.1f}: {count} elements")
166 | 
167 | 
168 |                 # exit()  # leftover debugging call; commented out so training continues past the first batch
169 | 
170 |                 loss.backward()
171 |                 optimizer.step()
172 |                 epoch_loss += loss.data.item()
173 | 
174 |                 out_cpu_5 = output.cpu() > 0.5
175 |                 epoch_accuracy += accuracy_score(y_true=labs, y_pred=out_cpu_5)
176 |                 epoch_precision += precision_score(y_true=labs, y_pred=out_cpu_5, average="samples")
177 |                 epoch_recall += recall_score(y_true=labs, y_pred=out_cpu_5, average="samples")
178 |                 epoch_f1 += f1_score(y_true=labs, y_pred=out_cpu_5, average="samples")
179 | 
180 | 
181 |             epoch_accuracy = epoch_accuracy / len(loaders['train'])
182 |             epoch_precision = epoch_precision / len(loaders['train'])
183 |             epoch_recall = epoch_recall / len(loaders['train'])
184 |             epoch_f1 = epoch_f1 / len(loaders['train'])
185 | 
186 |             ###################
187 |             # Validate the model #
188 |             ###################
189 |             with torch.no_grad():
190 |                 model.eval()
191 |                 for data in data_loader['valid']:
192 | 
193 |                     labs = []
194 |                     for la in data['atoms'].protein:
195 |                         labs.append(torch.tensor(labels[la], dtype=torch.float32).view(1, -1))
196 |                     labs = torch.cat(labs)
197 | 
198 |                     output = model(data.to(device))
199 | 
200 |                     _val_loss = criterion(output, labs.to(device))
201 |                     _val_loss = (_val_loss * class_weights).mean()
202 |                     val_loss += _val_loss.data.item()
203 | 
204 |                     val_accuracy += accuracy_score(labs, output.cpu() > 0.5)
205 |                     val_precision += precision_score(labs, output.cpu() > 0.5, average="samples")
206 |                     val_recall += recall_score(labs, output.cpu() > 0.5, average="samples")
207 |                     val_f1 += f1_score(labs, output.cpu() > 0.5, average="samples")
208 | 
209 |             val_loss = val_loss / len(loaders['valid'])
210 |             val_accuracy = val_accuracy / len(loaders['valid'])
211 |             val_precision = val_precision / len(loaders['valid'])
212 |             val_recall = val_recall / len(loaders['valid'])
213 |             val_f1 = val_f1 / len(loaders['valid'])
214 | 
215 |         print('Epoch: {:04d}'.format(epoch),
216 |               'train_loss: {:.4f}'.format(epoch_loss),
217 |               'train_acc: {:.4f}'.format(epoch_accuracy),
218 |               'precision: {:.4f}'.format(epoch_precision),
219 |               'recall: {:.4f}'.format(epoch_recall),
220 |               'f1: {:.4f}'.format(epoch_f1),
221 |               'val_acc: {:.4f}'.format(val_accuracy),
222 |               'val_loss: {:.4f}'.format(val_loss),
223 |               'val_precision: {:.4f}'.format(val_precision),
224 |               'val_recall: {:.4f}'.format(val_recall),
225 |               'val_f1: {:.4f}'.format(val_f1),
226 |               'time: {:.4f}s'.format(time.time() - t))
227 | 
228 | 
229 | 
230 |         checkpoint = {
231 |             'epoch': epoch,
232 |             'valid_loss_min': val_loss,
233 |             'state_dict': model.state_dict(),
234 |             'optimizer': optimizer.state_dict(),
235 |         }
236 | 
237 |         # save checkpoint
238 |         # save_ckp(checkpoint, False, ckp_pth,
239 |         #          ckp_dir + "best_model.pt")
240 |         #
241 |         # if val_loss <= min_val_loss:
242 |         #     print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'. \
243 |         #           format(min_val_loss, val_loss))
244 |         #
245 |         #     # save checkpoint as best model
246 |         #     save_ckp(checkpoint, True, ckp_pth,
247 |         #              ckp_dir + "best_model.pt")
248 |         #     min_val_loss = val_loss
249 | 
250 |     return model
251 | 
252 | 
253 | loaders = {
254 |     'train': train_dataloader,
255 |     'valid': valid_dataloader
256 | }
257 | 
258 | 
259 | ckp_dir = Constants.ROOT + 'checkpoints/model_checkpoint/{}/'.format("cur")
260 | ckp_pth = ckp_dir + "current_checkpoint.pt"
261 | #ckp_pth = ""
262 | 
263 | ckp_dir = "/home/fbqc9/PycharmProjects/TFUNClone/TransFun/data/"
264 | ckp_pth = ckp_dir + "{}.pt".format(args.ont)  # use the checkpoint that matches --ont
265 | 
266 | print(ckp_pth)
267 | 
268 | if os.path.exists(ckp_pth):
269 |     print("Loading model checkpoint @ {}".format(ckp_pth))
270 |     model, optimizer, current_epoch, min_val_loss = load_ckp(ckp_pth, model, optimizer, device="cuda:0")
271 | else:
272 |     if not os.path.exists(ckp_dir):
273 |         os.makedirs(ckp_dir)
274 | 
275 | print("Training model from epoch {}, with minimum validation loss {}".format(current_epoch, min_val_loss))
276 | 
277 | config = {
278 |     "learning_rate": args.lr,
279 |     "epochs": current_epoch,
280 |     "batch_size": args.train_batch,
281 |     "valid_size": args.valid_batch,
282 |     "weight_decay": ont_kwargs['wd']
283 | }
284 | 
285 | 
286 | 
287 | trained_model = train(current_epoch, min_val_loss,
288 |                       model=model, optimizer=optimizer,
289 |                       criterion=criterion, data_loader=loaders)
290 | 
--------------------------------------------------------------------------------
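The loss used in training.py combines a per-class BCE with class weights of the form (total annotation count) / (per-term count). The sketch below restates that computation in isolation; it is illustrative only and not code from the repository. The helper names compute_class_weights and weighted_bce are hypothetical, the weights are derived directly from a label matrix rather than from class_distribution_counter, and a clamp is added to guard against GO terms with zero positives, which the original does not do.

import torch
import torch.nn.functional as F

def compute_class_weights(label_matrix: torch.Tensor) -> torch.Tensor:
    # label_matrix: (num_proteins, num_classes) binary GO-term annotations.
    counts = label_matrix.sum(dim=0).clamp(min=1)  # per-term positive counts
    total = counts.sum()
    return total / counts  # rarer GO terms receive larger weights

def weighted_bce(pred: torch.Tensor, target: torch.Tensor, class_weights: torch.Tensor) -> torch.Tensor:
    # Mirrors criterion = torch.nn.BCELoss(reduction='none') followed by (loss * class_weights).mean().
    per_element = F.binary_cross_entropy(pred, target, reduction='none')
    return (per_element * class_weights).mean()

# Toy example: 4 proteins, 3 GO terms.
labels = torch.tensor([[1., 1., 0.],
                       [1., 0., 1.],
                       [1., 1., 0.],
                       [1., 0., 1.]])
weights = compute_class_weights(labels)   # tensor([2., 4., 4.])
preds = torch.full_like(labels, 0.5)      # stand-in for the model's sigmoid outputs
print(weighted_bce(preds, labels, weights))  # ~2.31, i.e. log(2) scaled by the mean weight

Because BCELoss(reduction='none') returns a (batch, num_classes) matrix, multiplying by a (num_classes,) weight vector broadcasts the weights across the batch, which is exactly how the (loss * class_weights).mean() line behaves in the training loop above.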