├── CONSTANTS.py ├── Classes ├── Diamond.py ├── Embeddings.py ├── Fasta.py ├── Interpro.py ├── STRING.py └── Templates.py ├── DATASET.md ├── DataGen ├── Embeddings.py ├── LabelEmbeddings ├── msa.py ├── msa_pred.py └── msa_test_shard.py ├── Dataset ├── Dataset.py ├── Dataset_tofix.py ├── FastDataset.py └── MyDataset.py ├── Graph ├── DiamondDataset.py ├── DiamondDatasets.py ├── PPI.py └── ProteinGraph.py ├── LICENSE ├── Loss └── Loss.py ├── Preprocess.py ├── README.md ├── TODO.py ├── Utils.py ├── create_test.py ├── environment.yml ├── evaluate.py ├── evaluation.py ├── evaluation ├── evaluate.py ├── predictions │ └── deepgose │ │ └── format_output.py └── utils.py ├── evaluation_components.py ├── evaluation_label_embedding.py ├── evaluation_rare_terms.py ├── evaluation_scripts ├── diamondblast.py ├── format.py ├── format_deepgose.py ├── format_netgo3.py ├── format_sprof.py ├── format_tale.py └── naive.py ├── evaluation_seqID.py ├── external └── extract.py ├── hparams.py ├── inference.py ├── inference_combined.py ├── label_embedding.py ├── models ├── config.py ├── model.py ├── model_ablation.py ├── model_struct.py └── net_utils.py ├── notes ├── output_dir └── test_fasta.fasta ├── plot.ipynb ├── predict.py ├── similarity_measure.py ├── test_data.py ├── train_data.py ├── training.py ├── training_ablation.py ├── transfew.yaml └── workbook.py /CONSTANTS.py: -------------------------------------------------------------------------------- 1 | residues = { 2 | "A": 1, "C": 2, "D": 3, "E": 4, "F": 5, "G": 6, "H": 7, "I": 8, "K": 9, "L": 10, "M": 11, 3 | "N": 12, "P": 13, "Q": 14, "R": 15, "S": 16, "T": 17, "V": 18, "W": 19, "Y": 20 4 | } 5 | 6 | INVALID_ACIDS = {"U", "O", "B", "Z", "J", "X", "*"} 7 | 8 | amino_acids = { 9 | "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C", "GLN": "Q", "GLU": "E", 10 | "GLY": "G", "HIS": "H", "ILE": "I", "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", 11 | "PRO": "P", "PYL": "O", "SER": "S", "SEC": "U", "THR": "T", "TRP": "W", "TYR": "Y", 12 | "VAL": "V", "ASX": "B", "GLX": "Z", "XAA": "X", "XLE": "J" 13 | } 14 | 15 | root_terms = {"GO:0008150", "GO:0003674", "GO:0005575"} 16 | 17 | exp_evidence_codes = {"EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC", "HTP", "HDA", "HMP", "HGI", "HEP"} 18 | 19 | # ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 20 | ROOT_DIR = "/home/fbqc9/Workspace/TFewData/" 21 | 22 | ROOT = "/home/fbqc9/PycharmProjects/TransFun2/" 23 | 24 | NAMESPACES = { 25 | "cc": "cellular_component", 26 | "mf": "molecular_function", 27 | "bp": "biological_process" 28 | } 29 | 30 | FUNC_DICT = { 31 | 'cc': 'GO:0005575', 32 | 'mf': 'GO:0003674', 33 | 'bp': 'GO:0008150' 34 | } 35 | 36 | BENCH_DICT = { 37 | 'cc': "CCO", 38 | 'mf': 'MFO', 39 | 'bp': 'BPO' 40 | } 41 | 42 | NAMES = { 43 | "cc": "Cellular Component", 44 | "mf": "Molecular Function", 45 | "bp": "Biological Process" 46 | } 47 | 48 | GO_FILTERS = { 49 | 'cc': (25, 4), 50 | 'mf': (30, 4), 51 | 'bp': (30, 8) 52 | } 53 | 54 | go_graph_path = ROOT_DIR + "/obo/go-basic.obo" -------------------------------------------------------------------------------- /Classes/Diamond.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pickle 3 | import subprocess 4 | import networkx as nx 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import CONSTANTS 9 | import os.path as osp 10 | import torch_geometric.data as pygdata 11 | from Utils import is_file, pickle_load, pickle_save 12 | 13 | class Diamond: 14 | """ 15 | Class to handle 
Diamond data 16 | """ 17 | 18 | def __init__(self, session='train', **kwargs): 19 | self.session = session 20 | 21 | if session == 'train': 22 | self.fasta = kwargs.get('fasta_file', CONSTANTS.ROOT_DIR + "uniprot/uniprot_fasta.fasta") 23 | self.dbase = kwargs.get('dbase', CONSTANTS.ROOT_DIR + "diamond/database") 24 | self.query = kwargs.get('query', CONSTANTS.ROOT_DIR + "diamond/query.tsv") 25 | self.adjlist = kwargs.get('output', CONSTANTS.ROOT_DIR + "diamond/graph.list") 26 | elif session == 'test': 27 | self.fasta = kwargs.get('fasta_file', CONSTANTS.ROOT_DIR + "cafa5/Test_Target/testsuperset.fasta") 28 | self.dbase = kwargs.get('dbase', CONSTANTS.ROOT_DIR + "diamond/database") 29 | self.query = kwargs.get('query', CONSTANTS.ROOT_DIR + "diamond/test_query.tsv") 30 | self.adjlist = kwargs.get('output', CONSTANTS.ROOT_DIR + "diamond/test_graph.list") 31 | 32 | def create_diamond_dbase(self): 33 | print("Creating Diamond Database") 34 | if is_file(self.fasta): 35 | CMD = "diamond makedb --in {} -d {}" \ 36 | .format(self.fasta, self.dbase) 37 | subprocess.call(CMD, shell=True) 38 | else: 39 | print("Fasta file not found.") 40 | exit() 41 | 42 | def query_diamond(self): 43 | print("Querying Diamond Database") 44 | CMD = "diamond blastp -q {} -d {} -o {} --sensitive" \ 45 | .format(self.fasta, self.dbase, self.query) 46 | if is_file(self.dbase + ".dmnd"): 47 | subprocess.call(CMD, shell=True) 48 | else: 49 | print("Database not found. Creating database") 50 | self.create_diamond_dbase() 51 | print("Querying Diamond Database") 52 | subprocess.call(CMD, shell=True) 53 | 54 | def create_diamond_graph(self): 55 | if not is_file(self.query): 56 | self.query_diamond() 57 | self.read_file() 58 | 59 | def read_file(self): 60 | scores = open(self.query, 'r') 61 | res = {} 62 | for line in scores.readlines(): 63 | tmp = line.split("\t") 64 | src, des, wgt = tmp[0], tmp[1], float(tmp[2]) / 100 65 | 66 | if src not in res: 67 | res[src] = {} 68 | res[src][des] = wgt 69 | 70 | if des not in res: 71 | res[des] = {} 72 | res[des][src] = wgt 73 | 74 | pickle_save(res, self.adjlist) 75 | 76 | def get_graph(self): 77 | 78 | if not is_file(self.adjlist + ".pickle"): 79 | self.create_diamond_graph() 80 | res = pickle_load(self.adjlist) 81 | return res 82 | 83 | 84 | def get_graph_addlist(self): 85 | 86 | # load test diamond 87 | 88 | # fasta = "/home/fbqc9/Workspace/DATA/uniprot/test_proteins.fasta" 89 | # query = CONSTANTS.ROOT_DIR + "diamond/eval_test_query.tsv" 90 | 91 | # CMD = "diamond blastp -q {} -d {} -o {} --sensitive" \ 92 | # .format(fasta, self.dbase, query) 93 | 94 | # subprocess.call(CMD, shell=True) 95 | 96 | 97 | query = CONSTANTS.ROOT_DIR + "diamond/eval_test_query.tsv" 98 | scores = open(query, 'r') 99 | res = {} 100 | for line in scores.readlines(): 101 | tmp = line.split("\t") 102 | src, des, wgt = tmp[0], tmp[1], float(tmp[2]) / 100 103 | 104 | if src not in res: 105 | res[src] = [] 106 | res[src].append((des, wgt)) 107 | 108 | if des not in res: 109 | res[des] = [] 110 | res[des].append((src, wgt)) 111 | 112 | return res 113 | 114 | 115 | def create_pytorch_graph(self, ): 116 | 117 | onts = ['cc', 'mf', 'bp'] 118 | 119 | for ont in onts: 120 | 121 | labels = pickle_load(CONSTANTS.ROOT_DIR + "{}/labels".format(ont)) 122 | proteins = pickle_load(CONSTANTS.ROOT_DIR + "{}/all_proteins".format(ont)) 123 | indicies = list(range(0, len(proteins))) 124 | val_indicies = pickle_load(CONSTANTS.ROOT_DIR + "{}/validation_indicies".format(ont)) 125 | train_indicies = pickle_load(CONSTANTS.ROOT_DIR + 
"{}/train_indicies".format(ont)) 126 | 127 | 128 | protein_dic = { prot: pos for pos, prot in enumerate(proteins)} 129 | 130 | embeddings = [] 131 | ys = [] 132 | interactions = self.get_graph() 133 | 134 | rows = [] 135 | columns = [] 136 | weights = [] 137 | for src in interactions: 138 | for des, score in interactions[src].items(): 139 | if src == des: 140 | pass 141 | else: 142 | try: 143 | _row = protein_dic[src] 144 | _col = protein_dic[des] 145 | rows.append(_row) 146 | columns.append(_col) 147 | weights.append(score) 148 | except KeyError: 149 | pass 150 | 151 | assert len(rows) == len(columns) == len(weights) 152 | 153 | rows = np.array(rows, dtype='int64') 154 | columns = np.array(columns, dtype='int64') 155 | edges = torch.tensor(np.array([rows, columns])) 156 | 157 | edges_attr = torch.tensor(np.array(weights, dtype='float32')) 158 | 159 | 160 | for pos, prt in enumerate(proteins): 161 | print(pos, len(proteins)) 162 | tmp = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(prt)) 163 | esm = tmp['esm2_t48'].x#.squeeze(0) 164 | embeddings.append(esm) 165 | 166 | ys.append(torch.tensor(labels[prt], dtype=torch.float32).view(1, -1)) 167 | 168 | embeddings = torch.cat(embeddings, dim=0) 169 | ys = torch.cat(ys, dim=0) 170 | 171 | x_x = [False] * 92912 172 | for i in train_indicies: 173 | x_x[i] = True 174 | train_indicies = torch.tensor(np.array(x_x)) 175 | 176 | x_x = [False] * 92912 177 | for i in val_indicies: 178 | x_x[i] = True 179 | val_indicies = torch.tensor(np.array(x_x)) 180 | 181 | graph = pygdata.Data(num_nodes=len(proteins), edge_index=edges, 182 | x=embeddings, y=ys, edges_attr=edges_attr, 183 | train_mask=torch.tensor(train_indicies), 184 | val_mask=torch.tensor(val_indicies)) 185 | 186 | 187 | torch.save(graph, osp.join(CONSTANTS.ROOT_DIR + "{}_tformer_data.pt".format(ont))) 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /Classes/Embeddings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | import subprocess 4 | from os import listdir 5 | 6 | import pandas as pd 7 | from Bio import SeqIO 8 | import CONSTANTS 9 | from Utils import is_file, create_directory, readlines_cluster, extract_id 10 | from preprocessing.utils import create_seqrecord 11 | 12 | 13 | class Embeddings: 14 | """ 15 | This class is used to handle all embedding generations 16 | """ 17 | 18 | def __init__(self, **kwargs): 19 | self.dir = kwargs.get('dir', CONSTANTS.ROOT_DIR) 20 | self.fasta = kwargs.get('fasta', None) 21 | self.session = "training" 22 | self.database = kwargs.get('database', "") 23 | 24 | self.mmseq_root = self.dir + "mmseq/" 25 | self.mmseq_dbase_root = self.mmseq_root + "dbase/" 26 | self.mmseq_cluster_root = self.mmseq_root + "cluster/" 27 | 28 | self.uniclust_dbase = None 29 | 30 | self.run() 31 | 32 | def create_database(self): 33 | mmseq_dbase_path = self.mmseq_dbase_root + "mmseq_dbase" 34 | if not is_file(mmseq_dbase_path): 35 | create_directory(self.mmseq_dbase_root) 36 | CMD = "mmseqs createdb {} {}".format(self.fasta, mmseq_dbase_path) 37 | subprocess.call(CMD, shell=True, cwd="{}".format(self.dir)) 38 | print("mmseq database created") 39 | 40 | def generate_cluster(self): 41 | mmseq_dbase_path = self.mmseq_dbase_root + "mmseq_dbase" 42 | mmseq_cluster_path = self.mmseq_cluster_root + "mmseq_cluster" 43 | final_cluster = mmseq_cluster_path + ".tsv" 44 | output = self.mmseq_cluster_root+ "final" + ".tsv" 45 | 
if not is_file(final_cluster): 46 | create_directory(self.mmseq_cluster_root) 47 | CMD = "mmseqs cluster {} {} tmp ; " \ 48 | "mmseqs createtsv {} {} {} {}.tsv".format(mmseq_dbase_path, mmseq_cluster_path, mmseq_dbase_path, 49 | mmseq_dbase_path, mmseq_cluster_path, mmseq_cluster_path) 50 | 51 | subprocess.call(CMD, shell=True, cwd="{}".format(self.dir)) 52 | 53 | self.one_line_format(final_cluster, output) 54 | 55 | @staticmethod 56 | def one_line_format(input_file, output): 57 | """ 58 | Script takes the mm2seq cluster output and converts to representative seq1, seq2, seq3 .... 59 | :param output: 60 | :param input_file: The clusters as csv file 61 | :return: None 62 | """ 63 | data = {} 64 | with open(input_file) as file: 65 | lines = file.read().splitlines() 66 | for line in lines: 67 | x = line.split("\t") 68 | if x[0] in data: 69 | data[x[0]].append(x[1]) 70 | else: 71 | data[x[0]] = list([x[1]]) 72 | result = [data[i] for i in data] 73 | with open(output, "w") as f: 74 | wr = csv.writer(f, delimiter='\t') 75 | wr.writerows(result) 76 | 77 | def generate_msas(self): 78 | # this part was run on lotus 79 | all_fastas = listdir(CONSTANTS.ROOT_DIR + "uniprot/single_fasta/") 80 | for fasta in all_fastas: 81 | fasta_name = fasta.split(".")[0] 82 | CMD = "hhblits -i query.fasta -d {} -oa3m msas/{}.a3m -cpu 4 -n 2".format(self.uniclust_dbase, fasta_name) 83 | print(CMD) 84 | #subprocess.call(CMD, shell=True, cwd="{}".format(self.dir)) 85 | 86 | # generate msa from cluster 87 | def msa_from_cluster(self): 88 | 89 | cluster_name = self.dir + "cluster/mmseq_cluster" 90 | dbase_name = self.dir + "database/mmseq_dbase" 91 | msa_name = cluster_name + "_msa" 92 | if not is_file(msa_name): 93 | CMD = "D:/Workspace/python-3/TFUN/mmseqs/mmseqs.bat result2msa {} {} {} {}" \ 94 | .format(dbase_name, dbase_name, cluster_name, cluster_name + "_msa") 95 | subprocess.call(CMD, shell=True, cwd="{}".format(self.dir)) 96 | 97 | def read_fasta_list(self): 98 | training_proteins = set(pd.read_csv("../preprocessing/uniprot.csv", sep="\t")['ACC'].to_list()) 99 | seqs = [] 100 | input_seq_iterator = SeqIO.parse(self.fasta, "fasta") 101 | for record in input_seq_iterator: 102 | uniprot_id = extract_id(record.id) 103 | if uniprot_id in training_proteins: 104 | seqs.append(create_seqrecord(id=uniprot_id, seq=str(record.seq))) 105 | return seqs 106 | 107 | def search(self, sequences): 108 | pass 109 | 110 | # generate embeddings 111 | def generate_embeddings(self): 112 | # name model output dir, embedding layer 1, embedding layer 2, batch 113 | models = (("esm_msa_1b", "esm_msa1b_t12_100M_UR50S", "msa", CONSTANTS.ROOT_DIR + "embedding/esm_msa_1b", 11, 12, 10), 114 | ("esm_2", "esm2_t48_15B_UR50D", self.fasta, CONSTANTS.ROOT_DIR + "embedding/esm2_t48", 47, 48, 10), 115 | ("esm_2", "esm2_t36_3B_UR50D", self.fasta, CONSTANTS.ROOT_DIR + "embedding/esm_t36", 35, 36, 50)) 116 | for model in models[2:]: 117 | if model[0] == "esm_msa_1b": 118 | CMD = "python {} {} {} {} --repr_layers {} {} --include mean per_tok contacts " \ 119 | "--toks_per_batch {} ".format(CONSTANTS.ROOT + "external/extract.py", model[1], model[2], 120 | model[3], model[4], model[5], model[6]) 121 | else: 122 | CMD = "python {} {} {} {} --repr_layers {} {} --include mean per_tok --nogpu " \ 123 | "--toks_per_batch {} ".format(CONSTANTS.ROOT + "external/extract.py", model[1], model[2], 124 | model[3], model[4], model[5], model[6]) 125 | 126 | print(CMD) 127 | subprocess.call(CMD, shell=True, cwd="{}".format(self.dir)) 128 | 129 | def run(self): 130 | if 
self.session == "training": 131 | # self.create_database() 132 | # self.generate_cluster() 133 | # self.generate_msas() 134 | self.generate_embeddings() 135 | else: 136 | pass 137 | # self.search() 138 | 139 | 140 | kwargs = { 141 | 'fasta': CONSTANTS.ROOT_DIR + "testfasta" 142 | } 143 | embeddings = Embeddings(**kwargs) 144 | -------------------------------------------------------------------------------- /Classes/Fasta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from Bio import SeqIO 4 | from Bio.Seq import Seq 5 | from Bio.SeqRecord import SeqRecord 6 | 7 | import CONSTANTS 8 | from Utils import count_proteins 9 | 10 | 11 | def create_seqrecord(id="", name="", description="", seq=""): 12 | record = SeqRecord(Seq(seq), id=id, name=name, description=description) 13 | return record 14 | 15 | 16 | def extract_id(header): 17 | return header.split('|')[1] 18 | 19 | 20 | class Fasta: 21 | def __init__(self, fasta_path): 22 | self.fasta = fasta_path 23 | 24 | def fasta_to_list(self, _filter=None): 25 | seqs = [] 26 | input_seq_iterator = SeqIO.parse(self.fasta, "fasta") 27 | for record in input_seq_iterator: 28 | seqs.append(create_seqrecord(id=extract_id(record.id), seq=str(record.seq))) 29 | 30 | if _filter is not None: 31 | seqs = [record for record in seqs if record.id in _filter] 32 | return seqs 33 | 34 | def reformat(self, _filter=None, output=""): 35 | seqs = self.fasta_to_list(_filter=_filter) 36 | SeqIO.write(seqs, output, "fasta") 37 | 38 | # Create individual fasta files from a large fasta file for hhblits alignment 39 | def fastas_from_fasta(self, _filter=None, out_dir=""): 40 | seqs = self.fasta_to_list(_filter=_filter) 41 | for seq in seqs: 42 | SeqIO.write(seq, out_dir + "/{}.fasta".format(seq.id), "fasta") 43 | 44 | # Removes unwanted proteins from fasta file 45 | def subset_from_fasta(self, output=None, max_seq_len=1022): 46 | seqs = self.fasta_to_list() 47 | subset = [] 48 | for record in seqs: 49 | if len(record.seq) > max_seq_len: 50 | splits = range(int(len(record.seq)/max_seq_len) + 1) 51 | 52 | for split in splits: 53 | seq_len = str(record.seq[split*max_seq_len: (split*max_seq_len) + max_seq_len]) 54 | 55 | subset.append(create_seqrecord(id='{}_{}'.format(record.id, split), seq=seq_len)) 56 | 57 | SeqIO.write(subset, output, "fasta") 58 | 59 | # Count the number of protein sequences in a fasta file with biopython -- slower. 
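# (cf. the count_proteins helper imported from Utils above)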
60 | def count_proteins_biopython(self): 61 | num = len(list(SeqIO.parse(self.fasta, "fasta"))) 62 | return num 63 | -------------------------------------------------------------------------------- /Classes/Interpro.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import networkx as nx 3 | import pandas as pd 4 | import CONSTANTS 5 | from Utils import is_file, pickle_save, pickle_load 6 | 7 | 8 | class Interpro: 9 | ''' 10 | Class to handle interpro data 11 | ''' 12 | 13 | def __init__(self, ont): 14 | 15 | self.lines = None 16 | self.raw_graph_file = CONSTANTS.ROOT_DIR + "interpro/ParentChildTreeFile.txt" 17 | self.graph = nx.DiGraph() 18 | self.remap_keys = {} 19 | self.ont = ont 20 | self.data_path = CONSTANTS.ROOT_DIR + "interpro/interpro_data" 21 | self.ohe_path = CONSTANTS.ROOT_DIR + "interpro/interpro_ohe_{}".format(self.ont) 22 | self.categories_path = CONSTANTS.ROOT_DIR + "interpro/categories_{}".format(self.ont) 23 | self.category_count = CONSTANTS.ROOT_DIR + "interpro/category_count_{}".format(self.ont) 24 | 25 | 26 | def propagate_graph(self): 27 | self.read_file() 28 | 29 | for node in self.lines: 30 | # 8 dashes 31 | if node[0].startswith("--------"): 32 | l5 = node[0].strip("--------") 33 | self.graph.add_edges_from([(l4, l5)]) 34 | # 6 dashes 35 | elif node[0].startswith("------"): 36 | l4 = node[0].strip("------") 37 | self.graph.add_edges_from([(l3, l4)]) 38 | # 4 dashes 39 | elif node[0].startswith("----"): 40 | l3 = node[0].strip("----") 41 | self.graph.add_edges_from([(l2, l3)]) 42 | # 2 dashes 43 | elif node[0].startswith("--"): 44 | l2 = node[0].strip("--") 45 | self.graph.add_edges_from([(l1, l2)]) 46 | else: 47 | l1 = node[0] 48 | if not self.graph.has_node(l1): 49 | self.graph.add_node(l1) 50 | 51 | 52 | def read_file(self): 53 | rels = open(self.raw_graph_file, 'r') 54 | self.lines = [i.rstrip('\n').split("::") for i in rels.readlines()] 55 | 56 | 57 | def get_graph(self): 58 | self.propagate_graph() 59 | return self.graph 60 | 61 | 62 | def generate_features(): 63 | # generate features from interproscan. 64 | files = list(range(10000, 80000, 10000)) + [79220] 65 | # generate from terminal in chunks 66 | for file in files: 67 | CMD = "interproscan-5.61-93.0-64-bit/interproscan-5.61-93.0/interproscan.sh -cpu 10 \ 68 | -i uniprot_fasta_{}.fasta -o ./interpro_out_{} -f TSV --goterms".format(file, file) 69 | subprocess.call(CMD, shell=True) 70 | 71 | @staticmethod 72 | def merge_chunks(num_files=23): 73 | infile = CONSTANTS.ROOT_DIR + "interpro/interpro_out_{}" 74 | # merge chunks 75 | df = pd.DataFrame() 76 | for file in range(1, num_files, 1): 77 | data = pd.read_csv(infile.format(file), sep="\t", 78 | names=["Protein accession", "Sequence MD5", "Sequence length", "Analysis", 79 | "Signature accession", "Signature description", "Start location", 80 | "Stop location", "Score", "Status", "Date", 81 | "InterPro annotations", "InterPro annotations description ", "GO annotations"]) 82 | df = pd.concat([df, data], axis=0) 83 | # passed additional quality checks and is very unlikely to be a false match. 
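# keep only the accession and InterPro-annotation columns and drop rows with no InterPro hit ("-")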
84 | df = df[['Protein accession', 'InterPro annotations']] 85 | df = df[df["InterPro annotations"] != "-"] 86 | df.to_csv(CONSTANTS.ROOT_DIR + "interpro/interpro_filtered.csv", index=False, sep="\t") 87 | 88 | 89 | def get_features(self): 90 | if not is_file(CONSTANTS.ROOT_DIR + "interpro/interpro_filtered.csv"): 91 | self.merge_chunks(num_files=23) 92 | data = pd.read_csv(CONSTANTS.ROOT_DIR + "interpro/interpro_filtered.csv", sep="\t") 93 | return data 94 | 95 | 96 | def create_interpro_data(self): 97 | features = self.get_features() 98 | self.get_graph() 99 | data = {} 100 | 101 | for line_number, (index, row) in enumerate(features.iterrows()): 102 | acc = row[0] 103 | annot = row[1] 104 | try: 105 | tmp = nx.descendants(self.graph, annot) | set([annot]) 106 | except nx.exception.NetworkXError: 107 | tmp = set([annot]) 108 | 109 | if acc in data: 110 | data[acc].update(tmp) 111 | else: 112 | data[acc] = tmp 113 | pickle_save(data, self.data_path) 114 | 115 | 116 | def create_features(self, ont): 117 | 118 | # Convert Interpro to One-Hot 119 | if not is_file(self.data_path + ".pickle"): 120 | print(self.data_path + ".pickle") 121 | self.create_interpro_data() 122 | 123 | categories = set() 124 | data = pickle_load(self.data_path) 125 | 126 | train_proteins = pickle_load(CONSTANTS.ROOT_DIR + "train_validation") 127 | val_proteins = set(train_proteins[self.ont]['validation']) 128 | train_proteins = set(train_proteins[self.ont]['train']) 129 | 130 | test_proteins = set(pickle_load(CONSTANTS.ROOT_DIR + "test_proteins")) 131 | 132 | found_proteins = set() 133 | for protein, category in data.items(): 134 | if protein in train_proteins: 135 | categories.update(category) 136 | found_proteins.add(protein) 137 | 138 | categories = list(categories) 139 | categories.sort() 140 | 141 | all_proteins = train_proteins.union(val_proteins).union(test_proteins) 142 | print("training: {}, validation: {}, test: {}, all: {}".format(len(train_proteins), len(val_proteins), len(test_proteins), len(all_proteins))) 143 | print("training proteins with interpro: {}".format(len(found_proteins))) 144 | 145 | ohe = {} 146 | category_count = {i: 0 for i in categories} 147 | 148 | for protein, annots in data.items(): 149 | if protein in all_proteins: 150 | ohe[protein] = [] 151 | for cat in categories: 152 | if cat in annots: 153 | ohe[protein].append(1) 154 | category_count[cat] = category_count[cat] + 1 155 | else: 156 | ohe[protein].append(0) 157 | 158 | print(len(ohe.keys())) 159 | 160 | pickle_save(categories, self.categories_path) 161 | pickle_save(ohe, self.ohe_path) 162 | pickle_save(category_count, self.category_count) 163 | 164 | 165 | def get_interpro_ohe_data(self): 166 | if not is_file(self.ohe_path + ".pickle") or not is_file(self.categories_path + ".pickle"): 167 | print("creating interpro features") 168 | self.create_features(self.ont) 169 | return pickle_load(self.ohe_path), pickle_load(self.categories_path), pickle_load(self.category_count) 170 | 171 | 172 | 173 | def get_interpro_test(self): 174 | self.get_graph() 175 | # load test interpro 176 | '''data = pd.read_csv(CONSTANTS.ROOT_DIR + "interpro/test_intepro.out", sep="\t", 177 | names=["Protein accession", "Sequence MD5", "Sequence length", "Analysis", 178 | "Signature accession", "Signature description", "Start location", 179 | "Stop location", "Score", "Status", "Date", 180 | "InterPro annotations", "InterPro annotations description ", "GO annotations"])''' 181 | data = pd.read_csv(CONSTANTS.ROOT_DIR + "interpro/test_intepro.out", sep="\t", 182 | 
names=["Protein accession", "InterPro annotations"]) 183 | 184 | annots = {} 185 | for line_number, (index, row) in enumerate(data.iterrows()): 186 | 187 | acc = row.iloc[0] 188 | try: 189 | tmp = nx.descendants(self.graph, row.iloc[1]) | set([row.iloc[1]]) 190 | except nx.exception.NetworkXError: 191 | tmp = set([row.iloc[1]]) 192 | 193 | if acc in annots: 194 | annots[acc].update(tmp) 195 | else: 196 | annots[acc] = set(tmp) 197 | 198 | 199 | categories = pickle_load(self.categories_path) 200 | ohe = {} 201 | 202 | for protein, annot in annots.items(): 203 | ohe[protein] = [] 204 | for cat in categories: 205 | if cat in annot: 206 | ohe[protein].append(1) 207 | else: 208 | ohe[protein].append(0) 209 | 210 | 211 | 212 | return ohe, categories, 0 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | def create_indicies(): 223 | onts = ['cc', 'mf', 'bp'] 224 | 225 | for ont in onts: 226 | print("Processing {}".format(ont)) 227 | cat_counts = pickle_load(CONSTANTS.ROOT_DIR + "interpro/category_count_{}".format(ont)) 228 | cats = pickle_load(CONSTANTS.ROOT_DIR + "interpro/categories_{}".format(ont)) 229 | 230 | indicies = {3:[], 5:[], 10:[], 50:[], 100:[], 250:[], 500:[]} 231 | 232 | for ind in indicies: 233 | for pos, cat in enumerate(cats): 234 | if cat_counts[cat] > ind: 235 | indicies[ind].append(pos) 236 | 237 | pickle_save(indicies, CONSTANTS.ROOT_DIR + "interpro/indicies_{}".format(ont)) 238 | 239 | # create_indicies() -------------------------------------------------------------------------------- /Classes/STRING.py: -------------------------------------------------------------------------------- 1 | import CONSTANTS 2 | import pickle, os 3 | import pandas as pd 4 | import numpy as np 5 | import torch 6 | import os.path as osp 7 | import torch_geometric.data as pygdata 8 | from Utils import is_file, pickle_save 9 | from torch_geometric.data import Data 10 | 11 | def pickle_load(filename): 12 | with open('{}.pickle'.format(filename), 'rb') as handle: 13 | return pickle.load(handle) 14 | 15 | class STRING: 16 | ''' 17 | Class to handle STRING data 18 | ''' 19 | 20 | def __init__(self, session='train'): 21 | self.string_dbase = CONSTANTS.ROOT_DIR + "STRING/protein.links.v11.5.txt" 22 | self.mapping_file = CONSTANTS.ROOT_DIR + "STRING/protein.aliases.v11.5.txt" 23 | self.mapping = {} 24 | self.mapping_filtered = CONSTANTS.ROOT_DIR + "STRING/filtered" 25 | 26 | ### 27 | self.string_file = CONSTANTS.ROOT_DIR + "STRING/string.csv" 28 | self.neighbours_file = CONSTANTS.ROOT_DIR + "STRING/neighbours" 29 | 30 | 31 | def extract_mapping(self): 32 | if not is_file(self.mapping_filtered + ".pickle"): 33 | print("extracting mappings") 34 | proteins = set(pickle_load(CONSTANTS.ROOT_DIR + "all_proteins_cafa")) 35 | test_set = set(pickle_load(CONSTANTS.ROOT_DIR + "test_proteins")) 36 | proteins = proteins.union(test_set) 37 | 38 | print(len(proteins)) 39 | 40 | with open(self.mapping_file) as in_file: 41 | next(in_file) 42 | for line in in_file: 43 | x = line.strip().split("\t") 44 | if x[2] == "BLAST_UniProt_AC" and x[1] in proteins: 45 | self.mapping[x[0]] = x[1] 46 | pickle_save(self.mapping, self.mapping_filtered) 47 | else: 48 | print("mapping file exist, loading") 49 | self.extract_mapping = pickle_load(self.mapping_filtered) 50 | print("Mapping finished") 51 | 52 | def extract_uniprot(self): 53 | if not is_file(self.string_file + ".pickle"): 54 | self.mapping = pickle_load(self.mapping_filtered) 55 | interactions = [["String protein 1", "String protein 2", "Combined score", "Uniprot 
protein 1", "Uniprot protein 2"]] 56 | with open(self.string_dbase) as in_file: 57 | next(in_file) 58 | for line in in_file: 59 | x = line.strip().split(" ") 60 | if x[0] in self.mapping and x[1] in self.mapping: 61 | interactions.append([x[0], x[1], int(x[2]), self.mapping[x[0]], self.mapping[x[1]]]) 62 | 63 | df = pd.DataFrame(interactions[1:], columns=interactions[0]) 64 | df.to_csv(self.string_file, sep='\t', index=False) 65 | 66 | 67 | def get_String(self, confidence=0.7): 68 | 69 | if not is_file(self.string_file): 70 | print("generating interactions") 71 | self.extract_mapping() 72 | self.extract_uniprot() 73 | 74 | data = pd.read_csv(self.string_file, sep='\t') 75 | data = data[["Uniprot protein 1", "Uniprot protein 2", "Combined score"]] 76 | data = data[data["Combined score"] > confidence * 1000] 77 | return data 78 | 79 | 80 | def get_string_neighbours(self, ontology, confidence=0.7, recreate=False): 81 | 82 | proteins = set(pickle_load(CONSTANTS.ROOT_DIR + "/{}/all_proteins".format(ontology))) 83 | 84 | if recreate == True or not is_file(self.neighbours_file + "_" + ontology + ".pickle"): 85 | x = self.get_String(confidence=confidence) 86 | 87 | x = x[x['Uniprot protein 1'].isin(proteins) & x['Uniprot protein 2'].isin(proteins)] 88 | 89 | 90 | data = {} 91 | for index, row in x.iterrows(): 92 | p1, p2, prob = row[0], row[1], row[2] 93 | 94 | if p1 in data: 95 | data[p1].add(p2) 96 | else: 97 | data[p1] = set([p2, ]) 98 | 99 | if p2 in data: 100 | data[p2].add(p1) 101 | else: 102 | data[p2] = set([p1, ]) 103 | 104 | pickle_save(data, self.neighbours_file + "_" + ontology) 105 | else: 106 | data = pickle_load(self.neighbours_file + "_" + ontology) 107 | 108 | return data 109 | 110 | 111 | def create_pytorch_graph(self, ont): 112 | 113 | proteins = pickle_load(CONSTANTS.ROOT_DIR + "{}/all_proteins".format(ont)) 114 | 115 | labels = pickle_load(CONSTANTS.ROOT_DIR + "{}/labels".format(ont)) 116 | 117 | 118 | x = [] 119 | y = [] 120 | for j, i in enumerate(proteins): 121 | tmp = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(i)) 122 | string_data = tmp['string_{}'.format(ont)].x 123 | x.append(torch.mean(string_data, dim=0).unsqueeze(0)) 124 | y. append(torch.tensor(labels[i], dtype=torch.long).unsqueeze(0)) 125 | 126 | 127 | y = torch.cat(y, dim=0) 128 | x = torch.cat(x, dim=0) 129 | 130 | protein_dic = { prot: pos for pos, prot in enumerate(proteins)} 131 | 132 | interactions = self.get_string_neighbours(ont) 133 | 134 | 135 | #rows = list(protein_dic.values()) # [] 136 | #cols = list(protein_dic.values()) # [] 137 | rows = [] 138 | cols = [] 139 | for src in interactions: 140 | for des in interactions[src]: 141 | if src == des: 142 | pass 143 | else: 144 | # sort small first 145 | _row = protein_dic[src] 146 | _col = protein_dic[des] 147 | 148 | 149 | rows.append(_row) 150 | cols.append(_col) 151 | 152 | rows.append(_col) 153 | cols.append(_row) 154 | 155 | assert len(rows) == len(cols) 156 | 157 | 158 | nodes = np.unique(rows + cols) 159 | 160 | rows = np.array(rows, dtype='int64') 161 | cols = np.array(cols, dtype='int64') 162 | edges = torch.tensor(np.array([rows, cols])) 163 | 164 | train_size = int(len(nodes)*0.85) 165 | val_size = len(nodes) - train_size 166 | 167 | train_set = nodes[0:train_size] 168 | val_set = nodes[train_size:] 169 | 170 | 171 | # assert len(train_set)+len(val_set) == len(nodes) == len(protein_dic) 172 | 173 | 174 | # train_mask = torch.zeros(len(nodes),dtype=torch.long) 175 | # for i in train_set: 176 | # train_mask[i] = 1. 
177 | 178 | # val_mask = torch.zeros(len(nodes),dtype=torch.long) 179 | # for i in val_set: 180 | # val_mask[i] = 1. 181 | 182 | 183 | # data = Data(edge_index=edges, train_mask=train_mask, val_mask=val_mask, 184 | # key_val=protein_dic, x=x, y=y) 185 | 186 | data = Data(edge_index=edges, x=x, y=y, key_val=protein_dic) 187 | 188 | 189 | torch.save(data, osp.join(CONSTANTS.ROOT_DIR + "{}/node2vec.pt".format(ont))) 190 | 191 | def get_embeddings(self, ontology): 192 | path = CONSTANTS.ROOT_DIR + "{}/node2vec.pt".format(ontology) 193 | if is_file(path): 194 | pass 195 | else: 196 | print(" generating node2vec graph") 197 | self.create_pytorch_graph(ontology) 198 | 199 | return torch.load(path) 200 | 201 | 202 | 203 | # ROOT = "/home/fbqc9/Workspace/DATA/" 204 | # String = STRING(ROOT+"STRING/protein.links.v11.5.txt", ROOT+"STRING/protein.aliases.v11.5.txt") 205 | # # String.extract_mapping() 206 | # # String.extract_uniprot() 207 | # data = String.get_String() 208 | 209 | # data = data[["Uniprot protein 1", "Uniprot protein 2", "Combined score"]] 210 | # print(data["Combined score"].describe().apply(lambda x: format(x, 'f'))) -------------------------------------------------------------------------------- /Classes/Templates.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import CONSTANTS 4 | from Utils import is_file 5 | from preprocessing.utils import pickle_load, pickle_save 6 | 7 | 8 | class Templates: 9 | """ 10 | This class is used to handle all template generations 11 | """ 12 | 13 | def __init__(self, **kwargs): 14 | self.dir = kwargs.get('directory', CONSTANTS.ROOT_DIR + "datasets/") 15 | self.file_name = kwargs.get('file_name', "template") 16 | self.scaling = kwargs.get('scaling', "scaled") # norm / scaling 17 | self.template = {} 18 | self.diamond = kwargs.get('diamond_path', CONSTANTS.ROOT_DIR + "diamond/output.tsv") 19 | 20 | def generate_templates(self): 21 | if not is_file(self.dir + "template"): 22 | self._generate_template() 23 | else: 24 | self.template = pickle_load(self.dir + "template") 25 | 26 | def _generate_template(self): 27 | ontologies = ("cc", "mf", "bp") 28 | tmp = pickle_load(CONSTANTS.ROOT_DIR + "datasets/training_validation") 29 | for ontology in ontologies: 30 | self.template[ontology] = {} 31 | proteins = tmp[ontology]['train'].union(tmp[ontology]['valid']) 32 | labels = pickle_load(CONSTANTS.ROOT_DIR + "datasets/labels")[ontology] 33 | 34 | # BLAST Similarity 35 | diamond_scores = {} 36 | with open(self.diamond) as f: 37 | for line in f: 38 | it = line.strip().split() 39 | if it[0] not in diamond_scores: 40 | diamond_scores[it[0]] = {} 41 | diamond_scores[it[0]][it[1]] = float(it[2]) 42 | 43 | if self.scaling == 'norm': 44 | pass 45 | elif self.scaling == 'scaled': 46 | for protein in proteins: 47 | if protein in diamond_scores: 48 | sim_prots = diamond_scores[protein] 49 | neighbours = sim_prots.items() 50 | 51 | neighbour_score = [] 52 | go_scores = [] 53 | for neighbour, _score in neighbours: 54 | if neighbour in labels and _score < 100: 55 | go_scores.append(labels[neighbour]) 56 | neighbour_score.append(_score / 100) 57 | go_scores = np.array(go_scores) 58 | neighbour_score = np.array(neighbour_score) 59 | 60 | _score = np.matmul(neighbour_score, go_scores) 61 | 62 | if len(_score.shape) == 0: 63 | _score = np.zeros(len(labels[protein])) 64 | 65 | else: 66 | _score = np.zeros(len(labels[protein])) 67 | 68 | self.template[ontology][protein] = _score 69 | 70 | pickle_save(self.template, 
self.dir + "template") 71 | 72 | def get(self): 73 | return self.template 74 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | ## Train Dataset 2 | ``` 3 | Downloaded from: http://ftp.ebi.ac.uk/pub/databases/GO/goa/old/UNIPROT/goa_uniprot_all.gaf.212.gz 4 | ``` 5 | 6 | 7 | ## Test Dataset 8 | ``` 9 | Downloaded from: http://ftp.ebi.ac.uk/pub/databases/GO/goa/old/UNIPROT/goa_uniprot_all.gaf.218.gz 10 | date-generated: 2023-12-04 09:42 11 | ``` 12 | 13 | 14 | #### Preprocessed Train/Validation Data Description 15 | ``` 16 | train_validation fasta: TFewData/ont/train_sequences.fasta 17 | test fasta: test_fasta.fasta 18 | 19 | order of ontologies: TFewData/ont/sorted_terms.pickle 20 | Index of ontology in model: term_indicies.pickle 21 | train proteins: TFewData/ont/train_proteins.pickle 22 | validation proteins: TFewData/ont/validation_proteins.pickle 23 | train and validation proteins: TFewData/ont/all_proteins.pickle 24 | label data: TFewData/ont/graph.pt 25 | 26 | processed dataset: 27 | TFewData/ont/train_data.pickle 28 | TFewData/ont/validation_data.pickle 29 | 30 | Format: 31 | Dictionary containing the preprocessed. We explored esm, msa & interpro. 32 | Each index contains the data for same protein 33 | dictionary{ 34 | labels: [], labels 35 | 'esm2_t48': [], :-> esm data 36 | 'msa_1b': [] :-> msa data 37 | 'interpro': [], :-> Interpro data 38 | 'diamond': [], :-> Did not use 39 | 'string': [], :-> Did not use 40 | 'protein: [] :-> protein name 41 | } 42 | 43 | ``` 44 | 45 | 46 | #### Preprocessed Test Data Description 47 | ``` 48 | Created with create_test.py script, adapted from 49 | https://github.com/nguyenngochuy91/CAFA_benchmark(create_benchmark.py) 50 | 51 | Inputs: 52 | t1: goa_uniprot_all.gaf.212.gz 53 | t2: goa_uniprot_all.gaf.218.gz 54 | 55 | 56 | TFewData/test/t2/test_proteins: 57 | (NK or LK) for each ontology (bp,cc,mf). 58 | 59 | TFewData/test/t2/groundtruth: 60 | Test groundtruth 61 | 62 | 63 | Predictions from various models compared are kept in: 64 | TFewData/evaluation 65 | ``` 66 | 67 | 68 | #### Trained models 69 | ``` 70 | TFewData/ont/full_gcn :-> final model 71 | TFewData/ont/models/label/GCN :-> label embedding model 72 | 73 | All other models will be uploaded to zenodo soon. 
74 | ``` -------------------------------------------------------------------------------- /DataGen/Embeddings.py: -------------------------------------------------------------------------------- 1 | # generate embedding from esm sequence 2 | import subprocess 3 | import torch 4 | from Bio import SeqIO 5 | from Bio.Seq import Seq 6 | from Bio.SeqRecord import SeqRecord 7 | 8 | 9 | def fasta_to_dictionary(fasta_file): 10 | data = {} 11 | for seq_record in SeqIO.parse(fasta_file, "fasta"): 12 | data[seq_record.id] = (seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 13 | return data 14 | 15 | 16 | def create_seqrecord(id="", name="", description="", seq=""): 17 | record = SeqRecord(Seq(seq), id=id, name=name, description=description) 18 | return record 19 | 20 | 21 | def generate_bulk_embedding(fasta): 22 | # name model output dir, embedding layer 1, embedding layer 2, batch 23 | models = (("esm_msa_1b", "esm_msa1b_t12_100M_UR50S", "msa", "esm_msa_1b", 11, 12, 10), 24 | ("esm_2", "esm2_t48_15B_UR50D", fasta, "esm2_t48", 47, 48, 100), 25 | ("esm_2", "esm2_t36_3B_UR50D", fasta, "esm_t36", 35, 36, 4096)) 26 | model = models[1] 27 | CMD = "python -u {} {} {} /home/fbqc9/{} --repr_layers {} {} --include mean per_tok " \ 28 | "--toks_per_batch {}".format("external/extract.py", model[1], model[2], \ 29 | model[3], model[4], model[5], model[6]) 30 | 31 | print(CMD) 32 | subprocess.call(CMD, shell=True, cwd="./") 33 | 34 | 35 | def generate_embeddings(fasta_path): 36 | def merge_pts(keys, fasta): 37 | embeddings = [47, 48] 38 | for pos, protein in enumerate(keys): 39 | print(pos, protein) 40 | fasta_dic = fasta_to_dictionary(fasta) 41 | 42 | tmp = [] 43 | for level in range(keys[protein]): 44 | os_path = "/home/fbqc9/esm2_t48/{}_{}.pt".format(protein, level) 45 | tmp.append(torch.load(os_path)) 46 | 47 | data = {'representations': {}, 'mean_representations': {}} 48 | for index in tmp: 49 | for rep in embeddings: 50 | # print(index['mean_representations'][rep].shape, torch.mean(index['representations'][rep], dim=0).shape) 51 | assert torch.equal(index['mean_representations'][rep], torch.mean(index['representations'][rep], dim=0)) 52 | 53 | if rep in data['representations']: 54 | data['representations'][rep] = torch.cat((data['representations'][rep], index['representations'][rep])) 55 | else: 56 | data['representations'][rep] = index['representations'][rep] 57 | 58 | for emb in embeddings: 59 | assert len(fasta_dic[protein][3]) == data['representations'][emb].shape[0] 60 | 61 | for rep in embeddings: 62 | data['mean_representations'][rep] = torch.mean(data['representations'][rep], dim=0) 63 | 64 | # print("saving {}".format(protein)) 65 | torch.save(data, "/home/fbqc9/esm2_t48/{}.pt".format(protein)) 66 | 67 | def crop_fasta(record): 68 | splits = [] 69 | keys = {} 70 | main_id = record.id 71 | chnks = len(record.seq) / 1021 72 | remnder = len(record.seq) % 1021 73 | chnks = int(chnks) if remnder == 0 else int(chnks) + 1 74 | keys[main_id] = chnks 75 | for pos in range(chnks): 76 | id = "{}_{}".format(main_id, pos) 77 | seq = str(record.seq[pos * 1021:(pos * 1021) + 1021]) 78 | splits.append(create_seqrecord(id=id, name=id, description="", seq=seq)) 79 | return splits, keys 80 | 81 | keys = {} 82 | sequences = [] 83 | input_seq_iterator = SeqIO.parse(fasta_path, "fasta") 84 | for record in input_seq_iterator: 85 | if len(record.seq) > 1021: 86 | _seqs, _keys = crop_fasta(record) 87 | sequences.extend(_seqs) 88 | keys.update(_keys) 89 | else: 90 | sequences.append(record) 91 | 92 | 
cropped_fasta = "temp.fasta" 93 | SeqIO.write(sequences, cropped_fasta, "fasta") 94 | 95 | # generate_bulk_embedding(cropped_fasta) 96 | 97 | # merge 98 | if len(keys) > 0: 99 | print("Found {} protein with length > 1021".format(len(keys))) 100 | merge_pts(keys, fasta_path) 101 | 102 | 103 | # fasta_path = "/home/fbqc9/Workspace/DATA/uniprot/test_fasta_rem.fasta" 104 | fasta_path = "/home/fbqc9/Workspace/DATA/uniprot/test_fasta_rem.fasta" 105 | generate_embeddings(fasta_path) 106 | -------------------------------------------------------------------------------- /DataGen/LabelEmbeddings: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import obonet 5 | from sentence_transformers import SentenceTransformer 6 | import torch 7 | from transformers import AutoTokenizer, AutoModel 8 | import networkx as nx 9 | from torch_geometric.utils.convert import from_scipy_sparse_matrix 10 | from torch_geometric.data import Data 11 | 12 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 13 | GO_PATH = ROOT_DIR + "/obo/go-basic.obo" 14 | 15 | 16 | def pickle_save(data, filename): 17 | with open('{}.pickle'.format(filename), 'wb') as handle: 18 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 19 | 20 | 21 | def pickle_load(filename): 22 | with open('{}.pickle'.format(filename), 'rb') as handle: 23 | return pickle.load(handle) 24 | 25 | 26 | def is_file(path): 27 | return os.path.exists(path) 28 | 29 | 30 | def get_embedding_biobert(definitions): 31 | model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO') 32 | embeddings = model.encode(definitions) 33 | return embeddings 34 | 35 | 36 | def get_embedding_bert(definitions): 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | 39 | tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased") 40 | model = AutoModel.from_pretrained("bert-large-uncased").to(device) 41 | tokenized_train = tokenizer(definitions, padding = True, truncation = True, return_tensors="pt").to(device) 42 | 43 | 44 | with torch.no_grad(): 45 | hidden_train = model(**tokenized_train) 46 | 47 | cls_token = hidden_train.last_hidden_state[:,0,:] 48 | return cls_token 49 | 50 | 51 | 52 | 53 | ontologies = ['cc', 'mf', 'bp'] 54 | go_graph = obonet.read_obo(open(GO_PATH, 'r')) 55 | node_desc = dict(go_graph.nodes(data="def")) 56 | 57 | accepted_edges = set() 58 | unaccepted_edges = set() 59 | 60 | for edge in go_graph.edges: 61 | if edge[2] == 'is_a' or edge[2] == 'part_of': 62 | accepted_edges.add(edge) 63 | else: 64 | unaccepted_edges.add(edge) 65 | go_graph.remove_edges_from(unaccepted_edges) 66 | 67 | assert nx.is_directed_acyclic_graph(go_graph) 68 | 69 | 70 | 71 | # quality check to extract textual definition only(no refference) 72 | for i in node_desc: 73 | assert node_desc[i].count('"') == 2 74 | 75 | 76 | 77 | for ontology in ontologies: 78 | 79 | biobert_path = ROOT_DIR + '{}/biobert.pt'.format(ontology) 80 | bert_path = ROOT_DIR + '{}/bert.pt'.format(ontology) 81 | hierarchy_path = ROOT_DIR + '{}/hierarchy'.format(ontology) 82 | save_path = ROOT_DIR + '{}/graph.pt'.format(ontology) 83 | 84 | print("Computing for {}".format(ontology)) 85 | 86 | go_terms = pickle_load(ROOT_DIR+"/{}/sorted_terms".format(ontology)) 87 | 88 | definitions = [node_desc[go_term].split('"')[1] for go_term in go_terms] 89 | 90 | assert len(go_terms) == len(definitions) 91 | 92 | if not is_file(biobert_path): 93 | print("Generating Biobert") 94 | biobert_embeddings = 
torch.from_numpy(get_embedding_biobert(definitions)) 95 | torch.save(biobert_embeddings, biobert_path) 96 | else: 97 | biobert_embeddings = torch.load(biobert_path) 98 | 99 | if not is_file(bert_path): 100 | print("Generating Bert") 101 | tmp = [] 102 | for i in range(0, len(go_terms) + 1, 1000): 103 | bert_embeddings = get_embedding_bert(definitions[i:i + 1000]) 104 | tmp.append(bert_embeddings) 105 | bert_embeddings = torch.concat(tmp, dim=0).cpu() 106 | torch.save(bert_embeddings, bert_path) 107 | else: 108 | bert_embeddings = torch.load(bert_path) 109 | 110 | 111 | if not is_file(hierarchy_path + ".pickle"): 112 | print("Generating Hierarchy") 113 | hierarchy = np.zeros((len(go_terms), len(go_terms))) 114 | for rows in range(len(go_terms)): 115 | for cols in range(len(go_terms)): 116 | row = go_terms[rows] 117 | col = go_terms[cols] 118 | 119 | if col in nx.descendants(go_graph, row).union(set([row, ])): 120 | hierarchy[rows, cols] = 1 121 | 122 | pickle_save(hierarchy, ROOT_DIR + "{}/hierarchy".format(ontology)) 123 | 124 | else: 125 | hierarchy = pickle_load(hierarchy_path) 126 | 127 | 128 | subgraph = go_graph.subgraph(go_terms).copy() 129 | 130 | A = nx.to_scipy_sparse_array(subgraph, nodelist=go_terms) 131 | data = from_scipy_sparse_matrix(A) 132 | 133 | hierarchy = torch.tensor(hierarchy, dtype=torch.float32) 134 | 135 | 136 | print(biobert_embeddings) 137 | print(bert_embeddings) 138 | print(hierarchy) 139 | 140 | 141 | 142 | data = Data(x=hierarchy, edge_index=data[0], \ 143 | biobert=biobert_embeddings, bert=bert_embeddings) 144 | 145 | 146 | torch.save(data, save_path) 147 | 148 | -------------------------------------------------------------------------------- /DataGen/msa.py: -------------------------------------------------------------------------------- 1 | # HHblits msa generation 2 | import sys 3 | import csv 4 | import subprocess 5 | from os import listdir 6 | from Bio import SeqIO 7 | 8 | def generate_msas(): 9 | # this part was run on lotus 10 | gen_a3ms = listdir("/bmlfast/frimpong/shared_function_data/a3ms/") 11 | gen_a3ms = set([i.split(".")[0] for i in gen_a3ms]) 12 | 13 | single_fasta = listdir("/bmlfast/frimpong/shared_function_data/single_fastas/") 14 | single_fasta = set([i.split(".")[0] for i in single_fasta]) 15 | 16 | all_fastas = list(single_fasta.difference(gen_a3ms)) 17 | print(len(single_fasta), len(gen_a3ms), len(all_fastas)) 18 | 19 | 20 | all_fastas.sort() 21 | 22 | for fasta in all_fastas[0:100]: 23 | CMD = "hhblits -i /bmlfast/frimpong/shared_function_data/single_fastas/{}.fasta \ 24 | -d /bmlfast/frimpong/msa_database/uniref30/UniRef30_2023_02 -oa3m \ 25 | /bmlfast/frimpong/shared_function_data/a3ms/{}.a3m -cpu 2 -n 2".format(fasta, fasta) 26 | subprocess.call(CMD, shell=True, cwd="/bmlfast/frimpong/shared_function_data/") 27 | 28 | 29 | 30 | 31 | generate_msas() 32 | 33 | -------------------------------------------------------------------------------- /DataGen/msa_pred.py: -------------------------------------------------------------------------------- 1 | from Bio import SeqIO 2 | import traceback 3 | from scipy.spatial.distance import cdist 4 | from typing import Tuple, List, Set 5 | import numpy as np 6 | import string 7 | from Bio.UniProt import GOA 8 | import os 9 | import pickle 10 | 11 | 12 | def pickle_save(data, filename): 13 | with open('{}.pickle'.format(filename), 'wb') as handle: 14 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 15 | 16 | 17 | def pickle_load(filename): 18 | with open('{}.pickle'.format(filename), 
'rb') as handle: 19 | return pickle.load(handle) 20 | 21 | # This is an efficient way to delete lowercase characters and insertion characters from a string 22 | deletekeys = dict.fromkeys(string.ascii_lowercase) 23 | deletekeys["."] = None 24 | deletekeys["*"] = None 25 | translation = str.maketrans(deletekeys) 26 | 27 | def get_id_msa(id: str) -> str: 28 | try: 29 | return id.split("_")[1] 30 | except IndexError: 31 | return id 32 | 33 | def remove_insertions(sequence: str) -> str: 34 | """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """ 35 | return sequence.translate(translation) 36 | 37 | 38 | def read_msa(filename: str) -> List[Tuple[str, str]]: 39 | """ Reads the sequences from an MSA file, automatically removes insertions.""" 40 | return [(get_id_msa(record.id), remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")] 41 | 42 | 43 | # Subsampling MSA 44 | # Select sequences from the MSA to maximize the hamming distance 45 | # Alternatively, can use hhfilter 46 | def greedy_select(msa: List[Tuple[str, str]], num_seqs: int , valid_prots: Set[str], mode: str = "max") -> List[Tuple[str, str]]: 47 | assert mode in ("max", "min") 48 | if len(msa) <= num_seqs: 49 | return msa 50 | 51 | array = np.array([list(seq) for id, seq in msa if id in valid_proteins], dtype=np.bytes_).view(np.uint8) 52 | 53 | optfunc = np.argmax if mode == "max" else np.argmin 54 | all_indices = np.arange(len(msa)) 55 | indices = [0] 56 | pairwise_distances = np.zeros((0, len(msa))) 57 | 58 | for _ in range(num_seqs - 1): 59 | dist = cdist(array[indices[-1:]], array, "hamming") 60 | pairwise_distances = np.concatenate([pairwise_distances, dist]) 61 | shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) 62 | shifted_index = optfunc(shifted_distance) 63 | index = np.delete(all_indices, indices)[shifted_index] 64 | indices.append(index) 65 | indices = sorted(indices) 66 | return [msa[idx] for idx in indices] 67 | 68 | 69 | def read_gaf(handle): 70 | dic = {} 71 | all_protein_name = set() 72 | # evidence from experimental 73 | Evidence = {'Evidence': set(["EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC", "HTP", "HDA", "HMP", "HGI", "HEP"])} 74 | with open(handle, 'r') as handle: 75 | for rec in GOA.gafiterator(handle): 76 | if rec['DB'] == 'UniProtKB': 77 | all_protein_name.add(rec['DB_Object_ID']) 78 | if rec['DB_Object_ID'] not in dic: 79 | dic[rec['DB_Object_ID']] = {rec['Aspect']: set([rec['GO_ID']])} 80 | else: 81 | if rec['Aspect'] not in dic[rec['DB_Object_ID']]: 82 | dic[rec['DB_Object_ID']][rec['Aspect']] = set([rec['GO_ID']]) 83 | else: 84 | dic[rec['DB_Object_ID']][rec['Aspect']].add(rec['GO_ID']) 85 | return dic, all_protein_name 86 | 87 | 88 | 89 | # cv = pickle_load("proteins_xxx_proteins") 90 | 91 | 92 | 93 | # print(len(cv)) 94 | 95 | # exit() 96 | 97 | # valid_proteins = os.listdir("/home/fbqc9/a3ms") 98 | # valid_proteins = set([i.split(".")[0] for i in valid_proteins]) 99 | 100 | # print(len(valid_proteins)) 101 | 102 | # exit() 103 | data, proteins = read_gaf("/home/fbqc9/Workspace/DATA/uniprot/goa_uniprot_all.gaf.212") 104 | 105 | pickle_save(data, "data_xxx_data") 106 | 107 | pickle_save(proteins, "proteins_xxx_proteins") 108 | 109 | 110 | 111 | exit() 112 | num_seqs = 10 113 | 114 | inputs = read_msa("/home/fbqc9/a3ms/{}.a3m".format("Q54801")) 115 | 116 | 117 | valid_proteins = os.listdir("/home/fbqc9/a3ms") 118 | print(valid_proteins[:10]) 119 | valid_proteins = set([i.split(".")[0] for i in 
valid_proteins]) 120 | 121 | inputs = greedy_select(inputs, num_seqs=num_seqs, valid_prots=valid_proteins) -------------------------------------------------------------------------------- /DataGen/msa_test_shard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import string 3 | from typing import Tuple, List 4 | import esm 5 | import numpy as np 6 | import torch 7 | from Bio import SeqIO 8 | import traceback 9 | from scipy.spatial.distance import cdist 10 | 11 | # This is an efficient way to delete lowercase characters and insertion characters from a string 12 | deletekeys = dict.fromkeys(string.ascii_lowercase) 13 | deletekeys["."] = None 14 | deletekeys["*"] = None 15 | translation = str.maketrans(deletekeys) 16 | 17 | 18 | def remove_insertions(sequence: str) -> str: 19 | """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """ 20 | return sequence.translate(translation) 21 | 22 | 23 | def read_msa(filename: str) -> List[Tuple[str, str]]: 24 | """ Reads the sequences from an MSA file, automatically removes insertions.""" 25 | return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")] 26 | 27 | 28 | # Subsampling MSA 29 | # Select sequences from the MSA to maximize the hamming distance 30 | # Alternatively, can use hhfilter 31 | def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]: 32 | assert mode in ("max", "min") 33 | if len(msa) <= num_seqs: 34 | return msa 35 | array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8) 36 | optfunc = np.argmax if mode == "max" else np.argmin 37 | all_indices = np.arange(len(msa)) 38 | indices = [0] 39 | pairwise_distances = np.zeros((0, len(msa))) 40 | for _ in range(num_seqs - 1): 41 | dist = cdist(array[indices[-1:]], array, "hamming") 42 | pairwise_distances = np.concatenate([pairwise_distances, dist]) 43 | shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0) 44 | shifted_index = optfunc(shifted_distance) 45 | index = np.delete(all_indices, indices)[shifted_index] 46 | indices.append(index) 47 | indices = sorted(indices) 48 | return [msa[idx] for idx in indices] 49 | 50 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP 51 | from fairscale.nn.wrap import enable_wrap, wrap, auto_wrap 52 | 53 | import warnings 54 | warnings.filterwarnings("ignore", category=DeprecationWarning) 55 | warnings.filterwarnings("ignore") 56 | 57 | def generate_msa_embedding(in_dir, out_dir): 58 | if not os.path.exists(out_dir): 59 | os.makedirs(out_dir) 60 | 61 | a3ms = os.listdir(in_dir) 62 | a3ms = set([i.split(".")[0] for i in a3ms if i.endswith(".a3m")]) 63 | 64 | generated = os.listdir("/bmlfast/frimpong/shared_function_data/esm_msa1b") 65 | generated = set([i.split(".")[0] for i in generated]) 66 | print(len(a3ms), len(generated), len(a3ms.difference(generated))) 67 | 68 | 69 | 70 | #generated = generated.union(generated2) 71 | 72 | a3ms = a3ms.difference(generated) 73 | a3ms = sorted(list(a3ms))#[0:2500] 74 | # PDB_IDS.reverse() 75 | print(len(a3ms)) 76 | 77 | 78 | device = 'cuda:0' 79 | num_seqs = 128 80 | seq_len = 1024 81 | seq_len_min_1 = seq_len - 1 82 | 83 | # init the distributed world with world_size 1 84 | url = "tcp://localhost:23455" 85 | torch.distributed.init_process_group(backend="nccl", init_method=url, world_size=1, rank=0) 86 | 87 | # initialize the model with FSDP wrapper 88 | fsdp_params = 
dict(wrapper_cls=FSDP, 89 | compute_device=device, 90 | mixed_precision=True, 91 | flatten_parameters=True, 92 | #state_dict_device=torch.device("cpu"), # reduce GPU mem usage 93 | cpu_offload=False 94 | ) # enable cpu offloading 95 | 96 | with enable_wrap(**fsdp_params): 97 | msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S() 98 | msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter() 99 | msa_transformer.eval() 100 | 101 | # msa_transformer = msa_transformer.eval()#.to(device) 102 | 103 | 104 | for name, child in msa_transformer.named_children(): 105 | 106 | if name == "layers": 107 | for layer_name, layer in child.named_children(): 108 | wrapped_layer = wrap(layer) 109 | setattr(child, layer_name, wrapped_layer) 110 | msa_transformer = wrap(msa_transformer) 111 | 112 | for pos, name in enumerate(a3ms): 113 | 114 | try: 115 | inputs = read_msa(f"{in_dir}/{name}.a3m") 116 | ref = inputs[0][1] 117 | print("# {}/{}: generating {}, lenth {}".format(pos, len(a3ms), name, len(ref))) 118 | 119 | if len(ref) <= seq_len_min_1: 120 | inputs = greedy_select(inputs, num_seqs=num_seqs) 121 | msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = \ 122 | msa_transformer_batch_converter([inputs]) 123 | # msa_transformer = msa_transformer.to(device) 124 | msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).to(device)).long() 125 | #msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(device) 126 | out = msa_transformer(msa_transformer_batch_tokens, repr_layers=[12], return_contacts=False) 127 | 128 | results = { 129 | 'label': name, 130 | 'representations_12': out['representations'][12][:, :, 1: seq_len, :] 131 | } 132 | assert results['representations_12'].shape[2] == len(ref) 133 | results['representations_12'] = results['representations_12'].mean(2) 134 | torch.save(results, f"{out_dir}/{name}.pt") 135 | del out 136 | del results 137 | print("generated") 138 | 139 | else: 140 | _inputs = greedy_select(inputs, num_seqs=num_seqs) 141 | 142 | cuts = range(int(len(ref)/seq_len_min_1) + 1) 143 | 144 | rep_12 = [] 145 | log_ts = [] 146 | 147 | for cut in cuts: 148 | inputs = [('{}_{}'.format(x[0], cut), x[1][cut*seq_len_min_1: (cut*seq_len_min_1) + seq_len_min_1]) for x in _inputs] 149 | msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = \ 150 | msa_transformer_batch_converter([inputs]) 151 | #msa_transformer = msa_transformer.to(device) 152 | msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device).long() 153 | msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(device) 154 | out = msa_transformer(msa_transformer_batch_tokens, repr_layers=[12], return_contacts=False) 155 | 156 | torch.save(out, f"{out_dir}/{name}_{cut}.pt") 157 | 158 | del out 159 | 160 | for cut in cuts: 161 | out = torch.load(f"{out_dir}/{name}_{cut}.pt") 162 | rep_12.append(out['representations'][12][:, :, 1: seq_len, :]) 163 | log_ts.append(out['logits']) 164 | 165 | del out 166 | 167 | results = { 168 | 'label': name, 169 | 'representations_12': torch.cat(rep_12, dim=2) 170 | } 171 | 172 | assert results['representations_12'].shape[2] == len(ref) 173 | results['representations_12'] = results['representations_12'].mean(2) 174 | torch.save(results, f"{out_dir}/{name}.pt") 175 | del results 176 | 177 | for cut in cuts: 178 | os.remove(f"{out_dir}/{name}_{cut}.pt") 179 | print("generated") 180 
| except Exception as e: 181 | print(traceback.format_exc()) 182 | 183 | 184 | # add aggragation 185 | 186 | 187 | # /bmlfast/frimpong/shared_function_data/ 188 | # generate_msa_embedding("/bmlfast/frimpong/shared_function_data/a3ms", "/bmlfast/frimpong/shared_function_data/esm_msa1b") 189 | generate_msa_embedding("/bmlfast/frimpong/shared_function_data/a3ms", "/home/fbqc9/msa_new") 190 | # generate_msa_embedding("/home/fbqc9/a3ms", "/home/fbqc9/esm_msa1b") 191 | -------------------------------------------------------------------------------- /Dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from sklearn.utils.class_weight import compute_class_weight 4 | import numpy as np 5 | from Utils import pickle_load 6 | import CONSTANTS 7 | 8 | class TransFewDataset(Dataset): 9 | def __init__(self, data_pth=None, submodel=None): 10 | 11 | self.submodel = submodel 12 | 13 | data = pickle_load(data_pth) 14 | 15 | labels = data['labels'] 16 | labels = torch.cat(labels, dim=0) 17 | 18 | self.labels = labels 19 | self.esm_features = data['esm2_t48'] 20 | self.msa_features = data['msa_1b'] 21 | #self.diamond_features = data['diamond'] 22 | self.interpro_features = data['interpro'] 23 | #self.string_features = data['string'] 24 | 25 | 26 | def __getitem__(self, index): 27 | 28 | esm = self.esm_features[index] 29 | msa = self.msa_features[index] 30 | # diamond = self.diamond_features[index] 31 | interpro = self.interpro_features[index] 32 | # string = self.string_features[index] 33 | label = self.labels[index] 34 | 35 | if self.submodel == 'esm2_t48': 36 | return esm, label 37 | elif self.submodel == 'msa_1b': 38 | return msa, label 39 | elif self.submodel == 'interpro': 40 | return interpro, label 41 | elif self.submodel == 'full': 42 | # return esm, msa, diamond, interpro, string, label 43 | return esm, msa, interpro, label 44 | '''elif self.submodel == 'diamond': 45 | return diamond, label 46 | elif self.submodel == 'string': 47 | return string, label''' 48 | 49 | def __len__(self): 50 | return len(self.labels) 51 | 52 | 53 | class TestDataset(Dataset): 54 | def __init__(self, data_pth=None, submodel=None): 55 | 56 | self.submodel = submodel 57 | 58 | data = pickle_load(data_pth) 59 | 60 | self.proteins = data['protein'] 61 | self.esm_features = data['esm2_t48'] 62 | self.msa_features = data['msa_1b'] 63 | self.diamond_features = data['diamond'] 64 | self.interpro_features = data['interpro'] 65 | 66 | def __getitem__(self, index): 67 | 68 | esm = self.esm_features[index] 69 | msa = self.msa_features[index] 70 | diamond = self.diamond_features[index] 71 | interpro = self.interpro_features[index] 72 | proteins = self.proteins[index] 73 | 74 | 75 | 76 | if self.submodel == 'esm2_t48': 77 | return esm, proteins 78 | elif self.submodel == 'msa_1b': 79 | return msa, proteins 80 | elif self.submodel == 'diamond': 81 | return diamond, proteins 82 | elif self.submodel == 'interpro': 83 | return interpro, proteins 84 | elif self.submodel == 'full': 85 | return esm, msa, diamond, interpro, proteins 86 | 87 | def __len__(self): 88 | return len(self.proteins) 89 | 90 | 91 | 92 | class PredictDataset(Dataset): 93 | def __init__(self, data=None): 94 | 95 | self.esm_features = data['esm2_t48'] 96 | self.proteins = data['protein'] 97 | 98 | 99 | def __getitem__(self, index): 100 | 101 | esm = self.esm_features[index] 102 | proteins = self.proteins[index] 103 | 104 | return esm, proteins 105 | 106 | def 
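A hedged usage sketch for the `TransFewDataset` class defined above in `Dataset/Dataset.py`: wrapping it in a standard PyTorch `DataLoader` for the `full` submodel. The data path is a placeholder (the real pickles live under `CONSTANTS.ROOT_DIR`), and the import assumes the repository root is on `PYTHONPATH`.

```python
# Sketch only: iterating TransFewDataset with a DataLoader for the 'full' submodel.
from torch.utils.data import DataLoader
from Dataset.Dataset import TransFewDataset  # assumes repo root is importable

train_set = TransFewDataset(data_pth="path/to/train_data", submodel="full")  # placeholder path (no .pickle suffix)
loader = DataLoader(train_set, batch_size=32, shuffle=True)

for esm, msa, interpro, labels in loader:
    # each batch stacks per-protein feature tensors and the multi-hot GO label rows
    print(esm.shape, msa.shape, interpro.shape, labels.shape)
    break
```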
__len__(self): 107 | return len(self.proteins) -------------------------------------------------------------------------------- /Dataset/Dataset_tofix.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | import subprocess 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import os.path as osp 9 | import CONSTANTS 10 | from torch_geometric.data import Dataset 11 | from Classes.Diamond import Diamond 12 | from torch_geometric.data import Data, HeteroData 13 | from Classes.Interpro import Interpro 14 | from Classes.STRING import STRING 15 | from Utils import pickle_load, readlines_cluster 16 | import random 17 | 18 | 19 | class TransFunDataset(Dataset): 20 | """ 21 | Creates a dataset from a list of PDB files. 22 | :param file_list: path to LMDB file containing dataset 23 | :type file_list: list[Union[str, Path]] 24 | :param transform: transformation function for data augmentation, defaults to None 25 | :type transform: function, optional 26 | """ 27 | 28 | def __init__(self, transform=None, pre_transform=None, pre_filter=None, **kwargs): 29 | 30 | self.ont = kwargs.get('ont', None) 31 | self.split = kwargs.get('split', None) 32 | 33 | if self.split == 'selected': 34 | self.data = kwargs.get('proteins', 'proteins') 35 | else: 36 | self.cluster = readlines_cluster(CONSTANTS.ROOT_DIR + "{}/mmseq_0.6/final_clusters.csv".format(self.ont)) 37 | self.indicies = pickle_load(CONSTANTS.ROOT_DIR + "{}/{}_indicies".format(self.ont, self.split)) 38 | 39 | 40 | super().__init__(transform, pre_transform, pre_filter) 41 | 42 | @property 43 | def raw_dir(self) -> str: 44 | return CONSTANTS.ROOT_DIR + "/data/raw" 45 | 46 | @property 47 | def processed_dir(self) -> str: 48 | return CONSTANTS.ROOT_DIR + "/data/processed" 49 | 50 | @property 51 | def raw_file_names(self): 52 | 53 | if self.split == 'selected': 54 | return self.data 55 | else: 56 | x = list(pickle_load(CONSTANTS.ROOT_DIR + "train_validation")[self.ont][self.split]) 57 | self.data = x 58 | return x 59 | 60 | 61 | @property 62 | def processed_file_names(self): 63 | data = os.listdir(self.processed_dir) 64 | return data 65 | 66 | def download(self): 67 | pass 68 | 69 | 70 | def process(self): 71 | 72 | raw = set(self.raw_file_names) 73 | processed = set(self.processed_file_names) 74 | remain = raw - processed 75 | 76 | if len(remain) > 0: 77 | print("Raw Data {} --- Processed Data {} --- Remaining {}".\ 78 | format(len(raw), len(processed), len(remain))) 79 | 80 | # String Data 81 | ppi = STRING() 82 | ppi_data_cc = ppi.get_string_neighbours(ontology='cc', confidence=0.4, recreate=False) 83 | print("Number of nodes in ppi", len(ppi_data_cc)) 84 | 85 | ppi_data_mf = ppi.get_string_neighbours(ontology='mf', confidence=0.4, recreate=False) 86 | print("Number of nodes in ppi", len(ppi_data_mf)) 87 | 88 | ppi_data_bp = ppi.get_string_neighbours(ontology='bp', confidence=0.4, recreate=False) 89 | print("Number of nodes in ppi", len(ppi_data_bp)) 90 | 91 | labels_cc = pickle_load(CONSTANTS.ROOT_DIR + "cc/labels") 92 | 93 | labels_mf = pickle_load(CONSTANTS.ROOT_DIR + "mf/labels") 94 | 95 | labels_bp = pickle_load(CONSTANTS.ROOT_DIR + "bp/labels") 96 | 97 | for num, protein in enumerate(remain): 98 | print(protein, num, len(remain)) 99 | 100 | xx = torch.load("/home/fbqc9/Workspace/DATA/data/processed1/{}.pt".format(protein)) 101 | 102 | # STRING 103 | string_neighbours = ppi_data_cc.get(protein, []) 104 | string_cc = [] 105 | for neighbour in 
string_neighbours: 106 | if neighbour == protein: 107 | pass 108 | else: 109 | if neighbour in labels_cc: 110 | string_cc.append(np.array(labels_cc[neighbour], dtype=int)) 111 | 112 | string_neighbours = ppi_data_mf.get(protein, []) 113 | string_mf = [] 114 | for neighbour in string_neighbours: 115 | if neighbour == protein: 116 | pass 117 | else: 118 | if neighbour in labels_mf: 119 | string_mf.append(np.array(labels_mf[neighbour], dtype=int)) 120 | 121 | string_neighbours = ppi_data_bp.get(protein, []) 122 | string_bp = [] 123 | for neighbour in string_neighbours: 124 | if neighbour == protein: 125 | pass 126 | else: 127 | if neighbour in labels_bp: 128 | string_bp.append(np.array(labels_bp[neighbour], dtype=int)) 129 | 130 | if len(string_cc) > 0: 131 | string_cc = torch.Tensor(np.vstack(string_cc)) 132 | else: 133 | string_cc = torch.Tensor(np.array([0] * 2957)).unsqueeze(0) 134 | 135 | if len(string_mf) > 0: 136 | string_mf = torch.Tensor(np.vstack(string_mf)) 137 | else: 138 | string_mf = torch.Tensor(np.array([0] * 7224)).unsqueeze(0) 139 | 140 | if len(string_bp) > 0: 141 | string_bp = torch.Tensor(np.vstack(string_bp)) 142 | else: 143 | string_bp = torch.Tensor(np.array([0] * 21285)).unsqueeze(0) 144 | 145 | 146 | 147 | xx['string_cc'].x = string_cc 148 | xx['string_mf'].x = string_mf 149 | xx['string_bp'].x = string_bp 150 | 151 | del xx['interpro'] 152 | 153 | 154 | assert len(xx['esm2_t48'].x.shape) == len(xx['esm_msa1b'].x.shape) \ 155 | == len(xx['diamond_cc'].x.shape) == len(xx['diamond_mf'].x.shape) \ 156 | == len(xx['diamond_bp'].x.shape) == len(xx['string_cc'].x.shape) \ 157 | == len(xx['string_mf'].x.shape) == len(xx['string_bp'].x.shape) \ 158 | == len(xx['interpro_cc'].x.shape) == len(xx['interpro_mf'].x.shape)\ 159 | == len(xx['interpro_bp'].x.shape) == 2 160 | 161 | torch.save(xx, osp.join(self.processed_dir, f'{protein}.pt')) 162 | 163 | 164 | 165 | 166 | def len(self): 167 | if self.split == "train": 168 | return len(self.indicies) 169 | else: 170 | return len(self.raw_file_names) 171 | 172 | 173 | def get(self, idx): 174 | if self.split == "train": 175 | cluster_index = self.indicies[idx] 176 | rep = random.sample(self.cluster[cluster_index], 1)[0] 177 | assert rep in set(self.raw_file_names) 178 | return torch.load(osp.join(self.processed_dir, f'{rep}.pt')) 179 | else: 180 | rep = self.data[idx] 181 | return torch.load(osp.join(self.processed_dir, f'{rep}.pt')) -------------------------------------------------------------------------------- /Dataset/FastDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from sklearn.utils.class_weight import compute_class_weight 4 | import numpy as np 5 | from Utils import pickle_load 6 | import CONSTANTS 7 | 8 | class FastTransFunDataset(Dataset): 9 | def __init__(self, data_pth=None, submodel=None): 10 | 11 | self.submodel = submodel 12 | 13 | data = pickle_load(data_pth) 14 | 15 | labels = data['labels'] 16 | labels = torch.cat(labels, dim=0) 17 | 18 | self.labels = labels 19 | self.esm_features = data['esm2_t48'] 20 | self.msa_features = data['msa_1b'] 21 | #self.diamond_features = data['diamond'] 22 | self.interpro_features = data['interpro'] 23 | #self.string_features = data['string'] 24 | 25 | 26 | def __getitem__(self, index): 27 | 28 | esm = self.esm_features[index] 29 | msa = self.msa_features[index] 30 | # diamond = self.diamond_features[index] 31 | interpro = self.interpro_features[index] 32 | # string = 
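A toy, self-contained sketch of the cluster-representative sampling used in `get()` above: a training index addresses a sequence-identity cluster, and one member is drawn at random on each access, so successive epochs can see different proteins from the same cluster. Accessions and cluster contents are made up.

```python
# Toy sketch of cluster-representative sampling (mirrors get() above).
import random

clusters = [
    {"P11111", "P22222"},           # hypothetical cluster 0
    {"P33333"},                     # hypothetical cluster 1
    {"P44444", "P55555", "P66666"}, # hypothetical cluster 2
]
train_indicies = [0, 2]  # clusters used for training (spelling follows the code above)

def sample_representative(idx):
    cluster = clusters[train_indicies[idx]]
    return random.sample(sorted(cluster), 1)[0]  # sorted() keeps this version-safe for sets

print(sample_representative(0), sample_representative(1))
```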
self.string_features[index] 33 | label = self.labels[index] 34 | 35 | if self.submodel == 'esm2_t48': 36 | return esm, label 37 | elif self.submodel == 'msa_1b': 38 | return msa, label 39 | elif self.submodel == 'interpro': 40 | return interpro, label 41 | elif self.submodel == 'full': 42 | # return esm, msa, diamond, interpro, string, label 43 | return esm, msa, interpro, label 44 | '''elif self.submodel == 'diamond': 45 | return diamond, label 46 | elif self.submodel == 'string': 47 | return string, label''' 48 | 49 | def __len__(self): 50 | return len(self.labels) 51 | 52 | 53 | 54 | '''class TestDataset(Dataset): 55 | def __init__(self, data_pth=None, term_indicies=None, submodel=None): 56 | 57 | self.submodel = submodel 58 | 59 | data = pickle_load(data_pth) 60 | 61 | labels = pickle_load("tst_labesl")['cc'] 62 | 63 | self.proteins = data['protein'] 64 | self.esm_features = data['esm2_t48'] 65 | self.msa_features = data['msa_1b'] 66 | self.diamond_features = data['diamond'] 67 | self.interpro_features = data['interpro'] 68 | self.labels = labels 69 | self.term_indicies = term_indicies 70 | 71 | 72 | def __getitem__(self, index): 73 | 74 | esm = self.esm_features[index] 75 | msa = self.msa_features[index] 76 | diamond = self.diamond_features[index] 77 | interpro = self.interpro_features[index] 78 | proteins = self.proteins[index] 79 | 80 | 81 | 82 | if self.submodel == 'esm2_t48': 83 | lab = torch.tensor(self.labels[proteins], dtype=torch.float32).view(1, -1) 84 | lab = torch.index_select(lab, 1, self.term_indicies) 85 | return esm, lab #proteins 86 | elif self.submodel == 'msa_1b': 87 | return msa, proteins 88 | elif self.submodel == 'diamond': 89 | return diamond, proteins 90 | elif self.submodel == 'interpro': 91 | return interpro, proteins 92 | elif self.submodel == 'full': 93 | return esm, msa, diamond, interpro, proteins 94 | 95 | def __len__(self): 96 | return len(self.proteins) 97 | 98 | 99 | ''' 100 | 101 | 102 | 103 | class TestDataset(Dataset): 104 | def __init__(self, data_pth=None, submodel=None): 105 | 106 | self.submodel = submodel 107 | 108 | data = pickle_load(data_pth) 109 | 110 | self.proteins = data['protein'] 111 | self.esm_features = data['esm2_t48'] 112 | self.msa_features = data['msa_1b'] 113 | self.diamond_features = data['diamond'] 114 | self.interpro_features = data['interpro'] 115 | 116 | def __getitem__(self, index): 117 | 118 | esm = self.esm_features[index] 119 | msa = self.msa_features[index] 120 | diamond = self.diamond_features[index] 121 | interpro = self.interpro_features[index] 122 | proteins = self.proteins[index] 123 | 124 | 125 | 126 | if self.submodel == 'esm2_t48': 127 | return esm, proteins 128 | elif self.submodel == 'msa_1b': 129 | return msa, proteins 130 | elif self.submodel == 'diamond': 131 | return diamond, proteins 132 | elif self.submodel == 'interpro': 133 | return interpro, proteins 134 | elif self.submodel == 'full': 135 | return esm, msa, diamond, interpro, proteins 136 | 137 | def __len__(self): 138 | return len(self.proteins) 139 | 140 | 141 | 142 | class PredictDataset(Dataset): 143 | def __init__(self, data=None): 144 | 145 | self.esm_features = data['esm2_t48'] 146 | self.proteins = data['protein'] 147 | 148 | 149 | def __getitem__(self, index): 150 | 151 | esm = self.esm_features[index] 152 | proteins = self.proteins[index] 153 | 154 | return esm, proteins 155 | 156 | def __len__(self): 157 | return len(self.proteins) -------------------------------------------------------------------------------- /Dataset/MyDataset.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from sklearn.utils.class_weight import compute_class_weight 4 | import numpy as np 5 | from Utils import pickle_load 6 | import CONSTANTS 7 | 8 | class TransFewDataset(Dataset): 9 | def __init__(self, data_pth=None, submodel=None): 10 | 11 | self.submodel = submodel 12 | 13 | data = pickle_load(data_pth) 14 | 15 | labels = data['labels'] 16 | labels = torch.cat(labels, dim=0) 17 | 18 | self.labels = labels 19 | self.esm_features = data['esm2_t48'] 20 | self.msa_features = data['msa_1b'] 21 | #self.diamond_features = data['diamond'] 22 | self.interpro_features = data['interpro'] 23 | #self.string_features = data['string'] 24 | 25 | 26 | def __getitem__(self, index): 27 | 28 | esm = self.esm_features[index] 29 | msa = self.msa_features[index] 30 | # diamond = self.diamond_features[index] 31 | interpro = self.interpro_features[index] 32 | # string = self.string_features[index] 33 | label = self.labels[index] 34 | 35 | if self.submodel == 'esm2_t48': 36 | return esm, label 37 | elif self.submodel == 'msa_1b': 38 | return msa, label 39 | elif self.submodel == 'interpro': 40 | return interpro, label 41 | elif self.submodel == 'full': 42 | # return esm, msa, diamond, interpro, string, label 43 | return esm, msa, interpro, label 44 | 45 | def __len__(self): 46 | return len(self.labels) 47 | 48 | 49 | class TestDataset(Dataset): 50 | def __init__(self, data_pth=None, submodel=None): 51 | 52 | self.submodel = submodel 53 | 54 | data = pickle_load(data_pth) 55 | 56 | self.proteins = data['protein'] 57 | self.esm_features = data['esm2_t48'] 58 | self.msa_features = data['msa_1b'] 59 | self.diamond_features = data['diamond'] 60 | self.interpro_features = data['interpro'] 61 | # self.labs = data['labels'] 62 | 63 | def __getitem__(self, index): 64 | 65 | esm = self.esm_features[index] 66 | msa = self.msa_features[index] 67 | diamond = self.diamond_features[index] 68 | interpro = self.interpro_features[index] 69 | proteins = self.proteins[index] 70 | 71 | # pop = self.labs[index] 72 | 73 | 74 | 75 | if self.submodel == 'esm2_t48': 76 | return esm, proteins 77 | elif self.submodel == 'msa_1b': 78 | return msa, proteins 79 | elif self.submodel == 'diamond': 80 | return diamond, proteins 81 | elif self.submodel == 'interpro': 82 | return interpro, proteins 83 | elif self.submodel == 'full': 84 | return esm, msa, diamond, interpro, proteins#, pop 85 | 86 | def __len__(self): 87 | return len(self.proteins) 88 | 89 | 90 | 91 | class PredictDataset(Dataset): 92 | def __init__(self, data=None): 93 | 94 | self.esm_features = data['esm2_t48'] 95 | self.proteins = data['protein'] 96 | 97 | 98 | def __getitem__(self, index): 99 | 100 | esm = self.esm_features[index] 101 | proteins = self.proteins[index] 102 | 103 | return esm, proteins 104 | 105 | def __len__(self): 106 | return len(self.proteins) -------------------------------------------------------------------------------- /Graph/DiamondDataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import Dataset, download_url 7 | import os.path as osp 8 | from typing import Callable, List, Optional 9 | 10 | from torch_geometric.utils import from_networkx 11 | from torch_geometric.data import InMemoryDataset 12 | from Utils import is_file 13 | import CONSTANTS 14 | 
from Classes.Diamond import Diamond 15 | from preprocessing.utils import pickle_load 16 | 17 | 18 | class DiamondDataset(InMemoryDataset): 19 | 20 | def __init__(self, split: str = "random", ont: str = "cc"): 21 | 22 | self.ont = ont 23 | self.name = "diamond_{}".format(ont) 24 | super().__init__() 25 | self.data, self.slices = torch.load(self.processed_paths[0]) 26 | 27 | self.split = split 28 | 29 | if split == 'preset': 30 | data = self.get(0) 31 | data.train_mask.fill_(True) 32 | data.train_mask[data.val_mask | data.test_mask] = False 33 | self.data, self.slices = self.collate([data]) 34 | elif split == 'random': 35 | pass 36 | 37 | @property 38 | def processed_file_names(self) -> str: 39 | return '{}_data.pt'.format(self.name) 40 | 41 | @property 42 | def processed_dir(self) -> str: 43 | return CONSTANTS.ROOT_DIR + 'datasets/diamond' 44 | 45 | def process(self): 46 | kwargs = { 47 | 'fasta_file': CONSTANTS.ROOT_DIR + "uniprot/uniprot_fasta.fasta" 48 | } 49 | diamond = Diamond(CONSTANTS.ROOT_DIR + "diamond", **kwargs) 50 | G = diamond.get_graph() 51 | nodes = set(G.nodes) 52 | 53 | 54 | 55 | tmp = pickle_load(CONSTANTS.ROOT_DIR + "datasets/training_validation") 56 | train_val = tmp[self.ont]['train'].union(tmp[self.ont]['valid']) 57 | 58 | remove = nodes.difference(train_val) 59 | print("Total nodes {}; Removing {}; Retaining {}".format(len(nodes), len(remove), len(train_val))) 60 | G.remove_nodes_from(remove) 61 | nodes = list(G.nodes) 62 | 63 | embeddings = pickle_load(CONSTANTS.ROOT_DIR + "embedding/esm_36_all") 64 | embeddings = {key: embeddings[key] for key in nodes if key in embeddings} 65 | 66 | 67 | node_features = {} 68 | layer = 36 69 | test_nodes = set() 70 | for pos, node in enumerate(nodes): 71 | if node in embeddings: 72 | node_features[node] = {'{}'.format(layer): embeddings[node]} 73 | test_nodes.add(node) 74 | else: 75 | G.remove_node(node) 76 | 77 | print("Diamond Graph with {} nodes".format(len(G.nodes))) 78 | nodes = list(G.nodes) 79 | 80 | y = pickle_load(CONSTANTS.ROOT_DIR + "datasets/labels")[self.ont] 81 | y = np.array([y[node] for node in y if node in test_nodes]) 82 | y = torch.from_numpy(y).float() 83 | 84 | indicies = set(enumerate(nodes)) 85 | train_mask = [] 86 | valid_mask = [] 87 | validation_nodes = tmp[self.ont]['valid'] 88 | for i in indicies: 89 | if i[1] in validation_nodes: 90 | valid_mask.append(True) 91 | train_mask.append(False) 92 | else: 93 | valid_mask.append(False) 94 | train_mask.append(True) 95 | 96 | for i, j in zip(train_mask, valid_mask): 97 | assert i != j and type(i) == type(True) and type(j) == type(True) 98 | 99 | # add node features 100 | nx.set_node_attributes(G, node_features) 101 | 102 | data = from_networkx(G, group_node_attrs=all) 103 | # print(data.nodes) 104 | print(data.generate_ids()) 105 | #print(zip(G.nodes, data.nodes)) 106 | exit() 107 | data.train_mask = torch.BoolTensor( train_mask) 108 | data.valid_mask = torch.BoolTensor(valid_mask) 109 | data.nodes = nodes 110 | data.y = y 111 | self.clusters = [] 112 | torch.save(self.collate([data]), self.processed_paths[0]) 113 | 114 | def __repr__(self) -> str: 115 | return f'{self.name}()' 116 | -------------------------------------------------------------------------------- /Graph/DiamondDatasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import Dataset, download_url 7 | import os.path as osp 8 | from typing 
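A small, self-contained sketch of the networkx-to-PyTorch-Geometric conversion pattern used in `process()` above: node features are attached as node attributes and gathered into `data.x` by `from_networkx(..., group_node_attrs=...)`, while boolean masks are added afterwards. Node names, features, and edge weights here are toy values.

```python
# Sketch: networkx graph with node features -> torch_geometric Data.
import networkx as nx
import torch
from torch_geometric.utils import from_networkx

G = nx.Graph()
G.add_edge("A", "B", weight=0.9)
G.add_edge("B", "C", weight=0.4)
nx.set_node_attributes(G, {
    "A": {"feat": [0.1, 0.2]},
    "B": {"feat": [0.3, 0.4]},
    "C": {"feat": [0.5, 0.6]},
})

data = from_networkx(G, group_node_attrs=["feat"])   # collects 'feat' into data.x
data.train_mask = torch.tensor([True, True, False])
data.valid_mask = torch.tensor([False, False, True])
print(data)  # Data(edge_index=[2, 4], weight=[4], x=[3, 2], train_mask=[3], valid_mask=[3])
```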
import Callable, List, Optional 9 | 10 | from torch_geometric.utils import from_networkx 11 | from torch_geometric.data import InMemoryDataset 12 | from Utils import is_file 13 | import CONSTANTS 14 | from Classes.Diamond import Diamond 15 | from preprocessing.utils import pickle_load 16 | 17 | 18 | class DiamondDataset(InMemoryDataset): 19 | 20 | def __init__(self, split: str = "random", ont: str = "cc"): 21 | 22 | self.ont = ont 23 | self.name = "diamond_{}".format(ont) 24 | super().__init__() 25 | self.data, self.slices = torch.load(self.processed_paths[0]) 26 | 27 | self.split = split 28 | 29 | if split == 'preset': 30 | data = self.get(0) 31 | data.train_mask.fill_(True) 32 | data.train_mask[data.val_mask | data.test_mask] = False 33 | self.data, self.slices = self.collate([data]) 34 | elif split == 'random': 35 | pass 36 | 37 | @property 38 | def processed_file_names(self) -> str: 39 | return '{}_data.pt'.format(self.name) 40 | 41 | @property 42 | def processed_dir(self) -> str: 43 | return CONSTANTS.ROOT_DIR + 'datasets/diamond' 44 | 45 | def process(self): 46 | kwargs = { 47 | 'fasta_file': CONSTANTS.ROOT_DIR + "uniprot/uniprot_fasta.fasta" 48 | } 49 | diamond = Diamond(CONSTANTS.ROOT_DIR + "diamond", **kwargs) 50 | G = diamond.get_graph() 51 | nodes = set(G.nodes) 52 | 53 | tmp = pickle_load(CONSTANTS.ROOT_DIR + "datasets/training_validation") 54 | train_val = tmp[self.ont]['train'].union(tmp[self.ont]['valid']) 55 | 56 | remove = nodes.difference(train_val) 57 | print("Total nodes {}; Removing {}; Retaining {}".format(len(nodes), len(remove), len(train_val))) 58 | G.remove_nodes_from(remove) 59 | nodes = list(G.nodes) 60 | 61 | y = pickle_load(CONSTANTS.ROOT_DIR + "datasets/labels")[self.ont] 62 | y = np.array([y[node] for node in y]) 63 | y = torch.from_numpy(y).float() 64 | 65 | indicies = set(enumerate(nodes)) 66 | train_mask = [] 67 | valid_mask = [] 68 | validation_nodes = tmp[self.ont]['valid'] 69 | for i in indicies: 70 | if i[1] in validation_nodes: 71 | valid_mask.append(True) 72 | train_mask.append(False) 73 | else: 74 | valid_mask.append(False) 75 | train_mask.append(True) 76 | 77 | for i, j in zip(train_mask, valid_mask): 78 | assert i != j and type(i) == type(True) and type(j) == type(True) 79 | 80 | node_features = {} 81 | layer = 36 82 | pt = 0 83 | for node in nodes: 84 | if is_file(CONSTANTS.ROOT_DIR + "embedding/esm_t36/{}.pt".format(node)): 85 | _x = torch.load(CONSTANTS.ROOT_DIR + "embedding/esm_t36/{}.pt".format(node)) 86 | node_features[node] = {'{}'.format(layer): _x['mean_representations'][layer].tolist()} 87 | else: 88 | node_features[node] = {'{}'.format(layer): [0] * 2560} 89 | pt = pt + 1 90 | 91 | # add node features 92 | nx.set_node_attributes(G, node_features) 93 | 94 | data = from_networkx(G, group_node_attrs=all, group_edge_attrs=all) 95 | data.train_mask = torch.BoolTensor(train_mask) 96 | data.valid_mask = torch.BoolTensor(valid_mask) 97 | data.y = y 98 | torch.save(self.collate([data]), self.processed_paths[0]) 99 | 100 | def __repr__(self) -> str: 101 | return f'{self.name}()' 102 | -------------------------------------------------------------------------------- /Graph/PPI.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, List, Tuple, Optional, Callable 3 | 4 | from torch_geometric.data import Dataset, Data 5 | from torch_geometric.utils import from_networkx 6 | from torch_geometric.loader import DataLoader 7 | from Classes.Diamond import Diamond 8 | 9 | 10 | class 
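The node-feature loop above reads per-protein `.pt` files in the fair-esm extraction format (`mean_representations[layer]`) and falls back to a zero vector of width 2560 when a file is missing. A hedged sketch of that lookup as a standalone helper; the directory is a placeholder and the accession is taken from the examples elsewhere in this repository.

```python
# Sketch of the per-protein ESM embedding lookup with a zero-vector fallback.
import os
import torch

def load_esm_mean(acc, emb_dir="embedding/esm_t36", layer=36, dim=2560):
    path = os.path.join(emb_dir, f"{acc}.pt")   # placeholder directory layout
    if os.path.exists(path):
        rec = torch.load(path, map_location="cpu")
        return rec["mean_representations"][layer].tolist()
    return [0.0] * dim                          # protein without a precomputed embedding

feat = load_esm_mean("A0JNW5")
print(len(feat))
```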
PPI(Dataset): 11 | def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None, 12 | pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, **kwargs): 13 | self.proteins_names = kwargs.get('proteins_names', []) 14 | super().__init__(root, transform, pre_transform, pre_filter) 15 | 16 | @staticmethod 17 | def find_files_dir(path): 18 | return os.listdir(path) 19 | 20 | @property 21 | def raw_dir(self) -> str: 22 | return "D:/Workspace/python-3/TransFun2/data/raw" 23 | 24 | @property 25 | def processed_dir(self) -> str: 26 | return "D:/Workspace/python-3/TransFun2/data/processed" 27 | 28 | @property 29 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 30 | return ["AF-{}-F1-model_v4.pdb".format(i) for i in self.proteins_names] 31 | 32 | @property 33 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 34 | return ["{}.pt".format(i) for i in self.proteins_names] 35 | 36 | def download(self): 37 | available = [file.split("-")[1] for file in self.find_files_dir(self.raw_dir)] 38 | unavailabe = set(self.proteins_names) - set(available) 39 | print("Downloading {} proteins".format(len(unavailabe))) 40 | 41 | def process(self): 42 | unprocessed = [file.split(".")[0] for file in self.find_files_dir(self.processed_dir)] 43 | unprocessed = set(self.proteins_names) - set(unprocessed) 44 | print("{} unprocessed proteins out of {}".format(len(unprocessed), len(self.raw_file_names))) 45 | 46 | for protein in unprocessed: 47 | print("Processing {}".format(protein)) 48 | 49 | diamond = Diamond("../data/{}".format("Diamond"), **kwargs) 50 | G = diamond.get_graph() 51 | pyg_graph = from_networkx(G) 52 | 53 | print(pyg_graph) 54 | 55 | def len(self) -> int: 56 | return 0 57 | 58 | def get(self, idx: int) -> Data: 59 | pass 60 | 61 | 62 | kwargs = { 63 | "proteins_names": ["bbdfnk", "A0JNW5", "A0JP26", "A2A2Y4", "A5D8V7", "A7MD48", "O14503"] 64 | } 65 | dataset = PPI(**kwargs) 66 | 67 | 68 | train_dataloader = DataLoader(dataset, batch_size=10, drop_last=False, shuffle=True) 69 | 70 | for i in train_dataloader: 71 | print(i) 72 | -------------------------------------------------------------------------------- /Graph/ProteinGraph.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, List, Tuple, Optional, Callable 3 | 4 | from torch_geometric.data import Dataset, Data 5 | 6 | 7 | class GraphDataset(Dataset): 8 | def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None, 9 | pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, **kwargs): 10 | 11 | self.proteins_names = kwargs.get('proteins_names', []) 12 | super().__init__(root, transform, pre_transform, pre_filter) 13 | 14 | @staticmethod 15 | def find_files_dir(path): 16 | return os.listdir(path) 17 | 18 | @property 19 | def raw_dir(self) -> str: 20 | return "D:/Workspace/python-3/TransFun2/data/raw" 21 | 22 | @property 23 | def processed_dir(self) -> str: 24 | return "D:/Workspace/python-3/TransFun2/data/processed" 25 | 26 | @property 27 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 28 | return ["AF-{}-F1-model_v4.pdb".format(i) for i in self.proteins_names] 29 | 30 | @property 31 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 32 | return ["{}.pt".format(i) for i in self.proteins_names] 33 | 34 | def download(self): 35 | available = [file.split("-")[1] for file in self.find_files_dir(self.raw_dir)] 36 | unavailabe = 
set(self.proteins_names) - set(available) 37 | print("Downloading {} proteins".format(len(unavailabe))) 38 | 39 | def process(self): 40 | unprocessed = [file.split(".")[0] for file in self.find_files_dir(self.processed_dir)] 41 | unprocessed = set(self.proteins_names) - set(unprocessed) 42 | print("{} unprocessed proteins out of {}".format(len(unprocessed), len(self.raw_file_names))) 43 | 44 | for protein in unprocessed: 45 | print("Processing {}".format(protein)) 46 | 47 | def len(self) -> int: 48 | return 0 49 | 50 | def get(self, idx: int) -> Data: 51 | pass 52 | 53 | 54 | kwargs = { 55 | "proteins_names": ["bbdfnk", "A0JNW5", "A0JP26", "A2A2Y4", "A5D8V7", "A7MD48", "O14503"] 56 | } 57 | graph = GraphDataset(**kwargs) 58 | print(graph) 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 BioinfoMachineLearning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Loss/Loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from torch import nn 3 | 4 | 5 | class HierarchicalLoss(nn.Module): 6 | def __init__(self, hierarchy, weight_factor=0.5): 7 | super(HierarchicalLoss, self).__init__() 8 | self.hierarchy = hierarchy 9 | self.weight_factor = weight_factor 10 | 11 | def forward(self, logits, targets): 12 | loss = nn.CrossEntropyLoss()(logits, targets) 13 | 14 | # Add hierarchical penalty 15 | for parent, children in self.hierarchy.items(): 16 | if targets.item() in [torch.tensor(child) for child in children]: 17 | loss += self.weight_factor * nn.CrossEntropyLoss()(logits, torch.tensor(children).to(targets.device)) 18 | 19 | return loss 20 | 21 | 22 | 23 | class DiceLoss(torch.nn.Module): 24 | def __init__(self): 25 | super(DiceLoss, self).__init__() 26 | 27 | def forward(self, y_pred, y_true): 28 | smooth = 1e-7 29 | intersection = torch.sum(y_true * y_pred) 30 | union = torch.sum(y_true) + torch.sum(y_pred) 31 | dice = (2. * intersection + smooth) / (union + smooth) 32 | return 1. 
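A hedged usage sketch for the `DiceLoss` in `Loss/Loss.py`: it expects probabilities, so model outputs are passed through a sigmoid first. Shapes and values are toy data, and the import assumes the repository root is importable.

```python
# Usage sketch for DiceLoss on multi-label (protein x GO term) predictions.
import torch
from Loss.Loss import DiceLoss  # assumes repo root is on PYTHONPATH

criterion = DiceLoss()
logits = torch.randn(4, 10)                      # 4 proteins, 10 GO terms
targets = torch.randint(0, 2, (4, 10)).float()   # toy multi-hot labels
loss = criterion(torch.sigmoid(logits), targets)
print(loss.item())  # roughly in [0, 1); lower means better overlap with the labels
```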
- dice 33 | 34 | 35 | '''class HierarchicalLoss(torch.nn.Module): 36 | def __init__(self, alpha=0.5, beta=0.5): 37 | super(HierarchicalLoss, self).__init__() 38 | self.alpha = alpha 39 | self.beta = beta 40 | 41 | def forward(self, y_pred, y_true, hierarchy): 42 | """ 43 | Computes the hierarchical loss. 44 | 45 | Args: 46 | - y_pred (Tensor): Predicted probabilities (batch_size, num_classes). 47 | - y_true (Tensor): True binary labels (batch_size, num_classes). 48 | - hierarchy (Tensor): Hierarchical structure matrix (num_classes, num_classes). 49 | 50 | Returns: 51 | - loss (Tensor): Hierarchical loss. 52 | """ 53 | y_pred = torch.sigmoid(y_pred) # Apply sigmoid to convert logits to probabilities 54 | 55 | # Compute binary cross-entropy loss 56 | bce_loss = F.binary_cross_entropy(y_pred, y_true, reduction='none') 57 | 58 | # Compute hierarchical loss 59 | h_loss = torch.zeros_like(bce_loss) 60 | for i in range(y_pred.size(0)): 61 | for j in range(y_pred.size(1)): 62 | if y_true[i][j] == 1: # Only consider losses for positive classes 63 | ancestors = torch.nonzero(hierarchy[:, j]).squeeze(1) 64 | if len(ancestors) > 0: 65 | ancestor_loss = torch.mean(bce_loss[i, ancestors]) 66 | descendant_loss = torch.mean(bce_loss[i, hierarchy[j]]) 67 | h_loss[i, j] = self.alpha * ancestor_loss + self.beta * descendant_loss 68 | 69 | # Average over samples and classes 70 | loss = torch.mean(h_loss) 71 | return loss''' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransFew 2 | #### Improving protein function prediction by learning and integrating representations of protein sequences and function labels 3 | 4 | TransFew leaverages representations of both protein sequences and 5 | function labels (Gene Ontology (GO) terms) to predict the function of proteins. It improves the accuracy of predicting both common and rare function terms (GO terms). 6 | 7 | 8 | 9 | ## Installation 10 | ``` 11 | # clone project 12 | git clone https://github.com/BioinfoMachineLearning/TransFew.git 13 | cd TransFew/ 14 | 15 | # download trained models and test sample 16 | https://calla.rnet.missouri.edu/rnaminer/tfew/TFewDataset 17 | 18 | # Unzip Dataset 19 | unzip TFewDataset 20 | 21 | 22 | # create conda environment 23 | conda env create -f transfew.yaml 24 | conda activate transfew 25 | ``` 26 | 27 | ## Prediction 28 | ``` 29 | Predict protein functions with TransFew 30 | 31 | options: 32 | -h, --help show this help message and exit 33 | 34 | --data-path DATA_PATH Path to data files (models) 35 | 36 | --working-dir WORKING_DIR Path to generate temporary 37 | files 38 | 39 | --ontology ONTOLOGY Path to data files 40 | 41 | --no-cuda NO_CUDA Disables CUDA training. 42 | 43 | --batch-size BATCH_SIZE Batch size. 44 | 45 | --fasta-path FASTA_PATH Path to Fasta 46 | 47 | --output OUTPUT File to save output 48 | ``` 49 | 50 | 4. An example of predicting cellular component of some proteins: 51 | ``` 52 | 1. Change ROOT_DIR in CONSTANTS.py to path of data directory 53 | 54 | 2. 
python predict.py --data-path /TFewData/ --fasta-path output_dir/test_fasta.fasta --ontology cc --working-dir output_dir --output result.tsv 55 | ``` 56 | 57 | ##### Output format 58 | ``` 59 | protein GO term score 60 | A0A7I2V2M2 GO:0043227 0.996 61 | A0A7I2V2M2 GO:0043226 0.996 62 | A0A7I2V2M2 GO:0005737 0.926 63 | A0A7I2V2M2 GO:0043233 0.924 64 | A0A7I2V2M2 GO:0031974 0.913 65 | A0A7I2V2M2 GO:0070013 0.912 66 | A0A7I2V2M2 GO:0031981 0.831 67 | A0A7I2V2M2 GO:0005654 0.767 68 | ``` 69 | 70 | ## Dataset 71 | ``` 72 | See DATASET.md (https://github.com/BioinfoMachineLearning/TransFew/blob/main/DATASET.md) for description of data 73 | ``` 74 | 75 | 76 | 77 | ## Training 78 | The training program is available in training.py, to train the model: 79 | ``` 80 | 1. Change ROOT_DIR in CONSTANTS.py to path of data directory 81 | 2. Run: python training.py 82 | ``` 83 | 84 | 85 | 86 | ## Reference 87 | ``` 88 | Boadu, F., & Cheng, J. (2024). Improving protein function prediction by learning and integrating representations of protein sequences and function labels. Bioinformatics Advances. Volume 4, Issue 1, vbae120. 89 | 90 | ``` 91 | 92 | 93 | -------------------------------------------------------------------------------- /TODO.py: -------------------------------------------------------------------------------- 1 | diamondL 2 | batch: 5000 3 | vbn: 100 4 | weight: multiply 5 | 6 | 7 | STATE OF ARTS: 8 | 1. Protranslator 9 | 2. DeepGOZero 10 | 3. Tale 11 | 4. go labeler 12 | -------------------------------------------------------------------------------- /Utils.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, shutil 2 | from Bio import SeqIO 3 | from Bio.Seq import Seq 4 | from Bio.SeqRecord import SeqRecord 5 | import obonet 6 | import pandas as pd 7 | import torch 8 | import pickle 9 | from biopandas.pdb import PandasPdb 10 | from collections import Counter 11 | import csv 12 | from sklearn.metrics import roc_curve, auc 13 | # from torchviz import make_dot 14 | from CONSTANTS import INVALID_ACIDS, amino_acids 15 | 16 | 17 | def is_file(path): 18 | return os.path.exists(path) 19 | 20 | 21 | def create_directory(dir): 22 | if not os.path.exists(dir): 23 | os.makedirs(dir) 24 | 25 | 26 | def count_proteins(fasta_file): 27 | num = len([1 for line in open(fasta_file) if line.startswith(">")]) 28 | return num 29 | 30 | 31 | def extract_id(header): 32 | return header.split('|')[1] 33 | 34 | 35 | def create_seqrecord(id="", name="", description="", seq=""): 36 | record = SeqRecord(Seq(seq), id=id, name=name, description=description) 37 | return record 38 | 39 | 40 | def remove_ungenerated_esm2_daisy_script(fasta_file, generated_directory): 41 | import os 42 | # those generated 43 | gen = os.listdir(generated_directory) 44 | gen = set([i.split(".")[0] for i in gen]) 45 | 46 | seq_records = [] 47 | 48 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 49 | for record in input_seq_iterator: 50 | uniprot_id = extract_id(record.id) 51 | seq_records.append(create_seqrecord(id=uniprot_id, seq=str(record.seq))) 52 | 53 | print(len(seq_records), len(gen), len(set(seq_records).difference(gen))) 54 | 55 | def filtered_sequences(fasta_file): 56 | """ 57 | Script is used to create fasta files based on alphafold sequence, by replacing sequences that are different. 
58 | :param fasta_file: 59 | :return: None 60 | """ 61 | 62 | seq_records = [] 63 | 64 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 65 | for record in input_seq_iterator: 66 | uniprot_id = extract_id(record.id) 67 | seq_records.append(create_seqrecord(id=uniprot_id, seq=str(record.seq))) 68 | 69 | SeqIO.write(seq_records, "data/Fasta/id2.fasta", "fasta") 70 | 71 | 72 | def readlines_cluster(in_file): 73 | file = open(in_file) 74 | lines = [set(line.strip("\n").split("\t")) for line in file.readlines() if line.strip()] 75 | file.close() 76 | return lines 77 | 78 | 79 | def read_dictionary(file): 80 | reader = csv.reader(open(file, 'r'), delimiter='\t') 81 | d = {} 82 | for row in reader: 83 | k, v = row[0], row[1] 84 | d[k] = v 85 | return d 86 | 87 | 88 | def get_proteins_from_fasta(fasta_file): 89 | proteins = list(SeqIO.parse(fasta_file, "fasta")) 90 | proteins = [i.id for i in proteins] 91 | return proteins 92 | 93 | 94 | def read_cafa5_scores(file_name): 95 | with open(file_name) as file: 96 | lines = file.readlines() 97 | return lines 98 | 99 | 100 | def fasta_to_dictionary(fasta_file, identifier='protein_id'): 101 | if identifier == 'protein_id': 102 | loc = 1 103 | elif identifier == 'protein_name': 104 | loc = 2 105 | data = {} 106 | for seq_record in SeqIO.parse(fasta_file, "fasta"): 107 | if "|" in seq_record.id: 108 | data[seq_record.id.split("|")[loc]] = ( 109 | seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 110 | else: 111 | data[seq_record.id] = (seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 112 | return data 113 | 114 | 115 | def pickle_save(data, filename): 116 | with open('{}.pickle'.format(filename), 'wb') as handle: 117 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 118 | 119 | 120 | def pickle_load(filename): 121 | with open('{}.pickle'.format(filename), 'rb') as handle: 122 | return pickle.load(handle) 123 | 124 | 125 | def get_sequence_from_pdb(pdb_file, chain_id): 126 | pdb_to_pandas = PandasPdb().read_pdb(pdb_file) 127 | 128 | pdb_df = pdb_to_pandas.df['ATOM'] 129 | 130 | assert (len(set(pdb_df['chain_id'])) == 1) & (list(set(pdb_df['chain_id']))[0] == chain_id) 131 | 132 | pdb_df = pdb_df[(pdb_df['atom_name'] == 'CA') & ((pdb_df['chain_id'])[0] == chain_id)] 133 | pdb_df = pdb_df.drop_duplicates() 134 | 135 | residues = pdb_df['residue_name'].to_list() 136 | residues = ''.join([amino_acids[i] for i in residues if i != "UNK"]) 137 | return residues 138 | 139 | 140 | def is_ok(seq, MINLEN=49, MAXLEN=1022): 141 | """ 142 | Checks if sequence is of good quality 143 | :param MAXLEN: 144 | :param MINLEN: 145 | :param seq: 146 | :return: None 147 | """ 148 | if len(seq) < MINLEN or len(seq) >= MAXLEN: 149 | return False 150 | for c in seq: 151 | if c in INVALID_ACIDS: 152 | return False 153 | return True 154 | 155 | 156 | def class_distribution_counter(**kwargs): 157 | """ 158 | Count the number of proteins for each GO term in training set. 
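A short sketch combining the helpers above (`is_ok`, `create_seqrecord`) to write a filtered FASTA containing only sequences that pass the length and residue checks. Input and output paths are placeholders.

```python
# Sketch: filter a FASTA with is_ok() and rewrite it with create_seqrecord().
from Bio import SeqIO
from Utils import is_ok, create_seqrecord  # assumes repo root is importable

kept = []
for rec in SeqIO.parse("input.fasta", "fasta"):          # placeholder input
    seq = str(rec.seq)
    if is_ok(seq):                                       # length bounds and invalid-residue check
        kept.append(create_seqrecord(id=rec.id, seq=seq))

SeqIO.write(kept, "filtered.fasta", "fasta")             # placeholder output
print(f"kept {len(kept)} sequences")
```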
159 | """ 160 | data = pickle_load(Constants.ROOT + "{}/{}/{}".format(kwargs['seq_id'], kwargs['ont'], kwargs['session'])) 161 | 162 | all_proteins = [] 163 | for i in data: 164 | all_proteins.extend(data[i]) 165 | 166 | annot = pd.read_csv(Constants.ROOT + 'annot.tsv', delimiter='\t') 167 | annot = annot.where(pd.notnull(annot), None) 168 | annot = annot[annot['Protein'].isin(all_proteins)] 169 | annot = pd.Series(annot[kwargs['ont']].values, index=annot['Protein']).to_dict() 170 | 171 | terms = [] 172 | for i in annot: 173 | terms.extend(annot[i].split(",")) 174 | 175 | counter = Counter(terms) 176 | 177 | # for i in counter.most_common(): 178 | # print(i) 179 | # print("# of ontologies is {}".format(len(counter))) 180 | 181 | return counter 182 | 183 | 184 | def save_ckp(state, is_best, checkpoint_dir): 185 | """ 186 | state: checkpoint we want to save 187 | is_best: is this the best checkpoint; min validation loss 188 | checkpoint_path: path to save checkpoint 189 | best_model_path: path to save best model 190 | """ 191 | if not os.path.exists(checkpoint_dir): 192 | os.makedirs(checkpoint_dir) 193 | 194 | checkpoint_path = checkpoint_dir + "current_checkpoint.pt" 195 | best_model_path = checkpoint_dir + "best_model.pt" 196 | # save checkpoint data_bp to the path given, checkpoint_path 197 | torch.save(state, checkpoint_path) 198 | # if it is a best model, min validation loss 199 | if is_best: 200 | # copy that checkpoint file to best path given, best_model_path 201 | shutil.copyfile(checkpoint_path, best_model_path) 202 | 203 | 204 | # def load_ckp_model_only(checkpoint_dir, model, best_model=False): 205 | # if not os.path.exists(checkpoint_dir): 206 | # os.makedirs(checkpoint_dir) 207 | # if best_model: 208 | # checkpoint_fpath = checkpoint_dir + "best_checkpoint.pt" 209 | # else: 210 | # checkpoint_fpath = checkpoint_dir + "current_checkpoint.pt" 211 | 212 | # if os.path.exists(checkpoint_fpath): 213 | # print("Loading model checkpoint @ {}".format(checkpoint_fpath)) 214 | # checkpoint = torch.load(checkpoint_fpath) 215 | # model.load_state_dict(checkpoint['state_dict']) 216 | # return model 217 | 218 | 219 | def load_ckp(checkpoint_dir, model, optimizer=None, lr_scheduler=None, best_model=False, model_only=False): 220 | """ 221 | checkpoint_path: path to save checkpoint 222 | model: model that we want to load checkpoint parameters into 223 | optimizer: optimizer we defined in previous training 224 | """ 225 | # load check point 226 | if best_model: 227 | checkpoint_fpath = checkpoint_dir + "best_model.pt" 228 | else: 229 | checkpoint_fpath = checkpoint_dir + "current_checkpoint.pt" 230 | 231 | checkpoint = torch.load(checkpoint_fpath, map_location="cpu") 232 | 233 | model.load_state_dict(checkpoint['state_dict']) 234 | 235 | # initialize optimizer from checkpoint to optimizer 236 | if optimizer is not None: 237 | optimizer.load_state_dict(checkpoint['optimizer']) 238 | 239 | # initialize lr scheduler from checkpoint to optimizer 240 | if lr_scheduler is not None: 241 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 242 | 243 | # initialize valid_loss_min from checkpoint to valid_loss_min 244 | valid_loss_min = checkpoint['valid_loss_min'] 245 | # return model, optimizer, epoch value, min validation loss 246 | 247 | if model_only: 248 | return model 249 | 250 | return model, optimizer, lr_scheduler, checkpoint['epoch'], valid_loss_min 251 | 252 | 253 | def draw_architecture(model, data_batch): 254 | ''' 255 | Draw the network architecture. 
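A hedged round-trip sketch for `save_ckp` / `load_ckp` defined above, using a toy model and optimizer. The checkpoint directory is a placeholder; the state dictionary mirrors the keys that `load_ckp` expects.

```python
# Sketch: save a checkpoint with save_ckp and restore it with load_ckp.
import torch
from torch import nn
from Utils import save_ckp, load_ckp  # assumes repo root is importable

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

state = {
    "epoch": 1,
    "valid_loss_min": 0.42,
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "lr_scheduler": None,
}
save_ckp(state, is_best=True, checkpoint_dir="checkpoints/")  # placeholder dir (trailing slash matters)

model, optimizer, _, epoch, best_loss = load_ckp(
    "checkpoints/", model, optimizer=optimizer, best_model=True
)
print(epoch, best_loss)
```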
256 | ''' 257 | output = model(data_batch) 258 | make_dot(output, params=dict(model.named_parameters())).render("rnn_lstm_torchviz", format="png") 259 | 260 | 261 | def compute_roc(labels, preds): 262 | # Compute ROC curve and ROC area for each class 263 | fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten()) 264 | roc_auc = auc(fpr, tpr) 265 | return roc_auc 266 | 267 | 268 | def get_graph(obo_file): 269 | go_graph = obonet.read_obo(open(obo_file, 'r')) 270 | 271 | accepted_edges = set() 272 | unaccepted_edges = set() 273 | 274 | for edge in go_graph.edges: 275 | if edge[2] == 'is_a' or edge[2] == 'part_of': 276 | accepted_edges.add(edge) 277 | else: 278 | unaccepted_edges.add(edge) 279 | 280 | print("Number of nodes: {}, edges: {}".format(len(go_graph.nodes), len(go_graph.edges))) 281 | go_graph.remove_edges_from(unaccepted_edges) 282 | print("Number of nodes: {}, edges: {}".format(len(go_graph.nodes), len(go_graph.edges))) 283 | 284 | return go_graph -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tfun2 2 | channels: 3 | - pytorch 4 | - bioconda 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - appdirs=1.4.4=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - brotlipy=0.7.0=py39h27cfd23_1003 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2022.12.7=ha878542_0 16 | - certifi=2022.12.7=pyhd8ed1ab_0 17 | - cffi=1.15.1=py39h5eee18b_3 18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | - cryptography=39.0.1=py39h9ce1e76_0 20 | - cuda-cudart=11.7.99=0 21 | - cuda-cupti=11.7.101=0 22 | - cuda-libraries=11.7.1=0 23 | - cuda-nvrtc=11.7.99=0 24 | - cuda-nvtx=11.7.91=0 25 | - cuda-runtime=11.7.1=0 26 | - diamond=0.9.14=h2e03b76_4 27 | - ffmpeg=4.2.2=h20bf706_0 28 | - filelock=3.9.0=py39h06a4308_0 29 | - freetype=2.12.1=h4a9f257_0 30 | - gawk=5.1.0=h7b6447c_0 31 | - giflib=5.2.1=h5eee18b_3 32 | - gmp=6.2.1=h295c915_3 33 | - gmpy2=2.1.2=py39heeb90bb_0 34 | - gnutls=3.6.15=he1e5248_0 35 | - hhsuite=3.3.0=py39pl5321h67e14b5_5 36 | - idna=3.4=py39h06a4308_0 37 | - intel-openmp=2021.4.0=h06a4308_3561 38 | - jinja2=3.1.2=py39h06a4308_0 39 | - jpeg=9e=h5eee18b_1 40 | - lame=3.100=h7b6447c_0 41 | - lcms2=2.12=h3be6417_0 42 | - ld_impl_linux-64=2.38=h1181459_1 43 | - lerc=3.0=h295c915_0 44 | - libcublas=11.10.3.66=0 45 | - libcufft=10.7.2.124=h4fbf590_0 46 | - libcufile=1.6.0.25=0 47 | - libcurand=10.3.2.56=0 48 | - libcusolver=11.4.0.1=0 49 | - libcusparse=11.7.4.91=0 50 | - libdeflate=1.17=h5eee18b_0 51 | - libffi=3.4.2=h6a678d5_6 52 | - libgcc-ng=11.2.0=h1234567_1 53 | - libgfortran-ng=11.2.0=h00389a5_1 54 | - libgfortran5=11.2.0=h1234567_1 55 | - libgomp=11.2.0=h1234567_1 56 | - libidn2=2.3.2=h7f8727e_0 57 | - libnpp=11.7.4.75=0 58 | - libnsl=2.0.0=h5eee18b_0 59 | - libnvjpeg=11.8.0.2=0 60 | - libopus=1.3.1=h7b6447c_0 61 | - libpng=1.6.39=h5eee18b_0 62 | - libstdcxx-ng=11.2.0=h1234567_1 63 | - libtasn1=4.19.0=h5eee18b_0 64 | - libtiff=4.5.0=h6a678d5_2 65 | - libunistring=0.9.10=h27cfd23_0 66 | - libvpx=1.7.0=h439df22_0 67 | - libwebp=1.2.4=h11a3e52_1 68 | - libwebp-base=1.2.4=h5eee18b_1 69 | - lz4-c=1.9.4=h6a678d5_0 70 | - markupsafe=2.1.1=py39h7f8727e_0 71 | - mkl=2021.4.0=h06a4308_640 72 | - mkl-service=2.4.0=py39h7f8727e_0 73 | - mkl_fft=1.3.1=py39hd3c417c_0 74 | - mkl_random=1.2.2=py39h51133e4_0 75 | - mmseqs2=13.45111=h95f258a_1 76 | - mpc=1.1.0=h10f8cd9_1 77 | - 
mpfr=4.0.2=hb69a4c5_1 78 | - mpmath=1.2.1=py39h06a4308_0 79 | - ncurses=6.4=h6a678d5_0 80 | - nettle=3.7.3=hbbd107a_1 81 | - networkx=2.8.4=py39h06a4308_1 82 | - numpy=1.24.3=py39h14f4228_0 83 | - numpy-base=1.24.3=py39h31eccc5_0 84 | - openh264=2.1.1=h4ff587b_0 85 | - openssl=1.1.1t=h7f8727e_0 86 | - perl=5.32.1=0_h5eee18b_perl5 87 | - pillow=9.4.0=py39h6a678d5_0 88 | - pip=23.0.1=py39h06a4308_0 89 | - pooch=1.4.0=pyhd3eb1b0_0 90 | - pycparser=2.21=pyhd3eb1b0_0 91 | - pyopenssl=23.0.0=py39h06a4308_0 92 | - pysocks=1.7.1=py39h06a4308_0 93 | - python=3.9.16=h7a1cb2a_2 94 | - python_abi=3.9=2_cp39 95 | - pytorch=2.0.0=py3.9_cuda11.7_cudnn8.5.0_0 96 | - pytorch-cuda=11.7=h778d358_3 97 | - pytorch-mutex=1.0=cuda 98 | - readline=8.2=h5eee18b_0 99 | - requests=2.29.0=py39h06a4308_0 100 | - scikit-learn=1.2.2=py39h6a678d5_0 101 | - setuptools=66.0.0=py39h06a4308_0 102 | - six=1.16.0=pyhd3eb1b0_1 103 | - sqlite=3.41.2=h5eee18b_0 104 | - sympy=1.11.1=py39h06a4308_0 105 | - tk=8.6.12=h1ccaba5_0 106 | - torchaudio=2.0.0=py39_cu117 107 | - torchtriton=2.0.0=py39 108 | - torchvision=0.15.0=py39_cu117 109 | - typing_extensions=4.5.0=py39h06a4308_0 110 | - tzdata=2023c=h04d1e81_0 111 | - urllib3=1.26.15=py39h06a4308_0 112 | - wget=1.21.3=h0b77cf5_0 113 | - wheel=0.38.4=py39h06a4308_0 114 | - x264=1!157.20191217=h7b6447c_0 115 | - xz=5.2.10=h5eee18b_1 116 | - zlib=1.2.13=h5eee18b_0 117 | - zstd=1.5.5=hc292b87_0 118 | - pip: 119 | - biopandas==0.4.1 120 | - biopython==1.81 121 | - click==8.1.3 122 | - contourpy==1.0.7 123 | - cycler==0.11.0 124 | - docker-pycreds==0.4.0 125 | - fair-esm==2.0.0 126 | - fairscale==0.4.13 127 | - fonttools==4.39.3 128 | - gitdb==4.0.10 129 | - gitpython==3.1.31 130 | - importlib-resources==5.12.0 131 | - joblib==1.2.0 132 | - kiwisolver==1.4.4 133 | - matplotlib==3.7.1 134 | - obonet==1.0.0 135 | - opencv-python==4.7.0.72 136 | - packaging==23.1 137 | - pandas==1.5.3 138 | - pathtools==0.1.2 139 | - protobuf==4.23.0 140 | - psutil==5.9.5 141 | - pyg-lib==0.2.0+pt20cu118 142 | - pyparsing==3.0.9 143 | - python-dateutil==2.8.2 144 | - pytz==2023.2 145 | - pyyaml==6.0 146 | - scipy==1.10.1 147 | - seaborn==0.12.2 148 | - sentry-sdk==1.19.1 149 | - setproctitle==1.3.2 150 | - smmap==5.0.0 151 | - thop==0.1.1-2209072238 152 | - threadpoolctl==3.1.0 153 | - torch-cluster==1.6.1+pt20cu118 154 | - torch-geometric==2.3.1 155 | - torch-scatter==2.1.1+pt20cu118 156 | - torch-sparse==0.6.17+pt20cu118 157 | - torch-spline-conv==1.2.2+pt20cu118 158 | - torch-summary==1.4.5 159 | - tqdm==4.65.0 160 | - ultralytics==8.0.82 161 | - wandb==0.15.2 162 | - zipp==3.15.0 163 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import CONSTANTS 4 | 5 | 6 | # Evaluate all 7 | def evaluate_group(obo_file, ia_file, prediction_folder, groundtruth, outdir): 8 | # command = "cafaeval go-basic.obo prediction_dir test_terms.tsv -ia IA.txt -prop fill -norm cafa -th_step 0.001 -max_terms 500" 9 | command = "cafaeval {} {} {} -ia {} -norm cafa -th_step 0.01 -out_dir {}".\ 10 | format(obo_file, prediction_folder, groundtruth, ia_file, outdir) 11 | subprocess.call(command, shell=True) 12 | 13 | 14 | obo_file = CONSTANTS.ROOT_DIR + "obo/go-basic.obo" 15 | ia_file = CONSTANTS.ROOT_DIR + "test/output_t1_t2/ia_prop.txt" 16 | 17 | 18 | 19 | db_subset = ["swissprot", "trembl"] 20 | ontologies = ["cc", "mf", "bp"] 21 | 22 | 23 | # full evaluation 24 | def 
full_evaluation(): 25 | prediction_folder = CONSTANTS.ROOT_DIR + "evaluation/predictions/full/{}_{}" 26 | groundtruth = CONSTANTS.ROOT_DIR + "test/output_t1_t2/groundtruths/full/{}_{}.tsv" 27 | out_dir = "results/full/{}_{}" 28 | for ont in ontologies: 29 | for sptr in db_subset: 30 | print("Evaluating {} {}".format(ont, sptr)) 31 | evaluate_group(obo_file=obo_file, 32 | ia_file=ia_file, 33 | prediction_folder=prediction_folder.format(sptr, ont), 34 | groundtruth=groundtruth.format(ont, sptr), 35 | outdir=out_dir.format(ont, sptr)) 36 | 37 | 38 | # 30% SeqID evaluation 39 | def seq_30_evaluation(): 40 | prediction_folder = CONSTANTS.ROOT_DIR + "evaluation/predictions/seq_ID_30/{}_{}" 41 | groundtruth = CONSTANTS.ROOT_DIR + "test/output_t1_t2/groundtruths/seq_ID_30/{}_{}.tsv" 42 | out_dir = "results/seq_ID_30/{}_{}" 43 | for ont in ontologies: 44 | for sptr in db_subset: 45 | print("Evaluating {} {}".format(ont, sptr)) 46 | evaluate_group(obo_file=obo_file, 47 | ia_file=ia_file, 48 | prediction_folder=prediction_folder.format(sptr, ont), 49 | groundtruth=groundtruth.format(ont, sptr), 50 | outdir=out_dir.format(ont, sptr)) 51 | 52 | 53 | # 30% SeqID evaluation 54 | def components_evaluation(): 55 | prediction_folder = CONSTANTS.ROOT_DIR + "evaluation/predictions/components/{}_{}" 56 | groundtruth = CONSTANTS.ROOT_DIR + "test/output_t1_t2/groundtruths/full/{}_{}.tsv" 57 | out_dir = "results/components/{}_{}" 58 | for ont in ontologies: 59 | for sptr in db_subset: 60 | print("Evaluating {} {}".format(ont, sptr)) 61 | evaluate_group(obo_file=obo_file, 62 | ia_file=ia_file, 63 | prediction_folder=prediction_folder.format(sptr, ont), 64 | groundtruth=groundtruth.format(ont, sptr), 65 | outdir=out_dir.format(ont, sptr)) 66 | 67 | 68 | 69 | # full_evaluation() 70 | # seq_30_evaluation() 71 | components_evaluation() -------------------------------------------------------------------------------- /evaluation/predictions/deepgose/format_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import subprocess 4 | from Bio import SeqIO 5 | from Bio.Seq import Seq 6 | from Bio.SeqRecord import SeqRecord 7 | 8 | 9 | def pickle_save(data, filename): 10 | with open('{}.pickle'.format(filename), 'wb') as handle: 11 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 12 | 13 | 14 | def pickle_load(filename): 15 | with open('{}.pickle'.format(filename), 'rb') as handle: 16 | return pickle.load(handle) 17 | 18 | 19 | def csv_to_dic(filename): 20 | res = {} 21 | with open(filename) as f: 22 | lines = [line.split("\t") for line in f] 23 | 24 | for i in lines: 25 | if i[0] in res: 26 | res[i[0]].append((i[1], float(i[2]))) 27 | else: 28 | res[i[0]] = [(i[1], float(i[2])), ] 29 | return res 30 | 31 | 32 | ontologies = ["cc", "mf", "bp"] 33 | 34 | 35 | 36 | output = {} 37 | for ont in ontologies: 38 | res = csv_to_dic("evaluation/predictions/deepgose/test_fasta_preds_{}.tsv".format(ont)) 39 | output[ont] = res 40 | 41 | pickle_save(output, "evaluation/results/{}_out".format("deepgose")) -------------------------------------------------------------------------------- /evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def pickle_save(data, filename): 5 | with open('{}.pickle'.format(filename), 'wb') as handle: 6 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 7 | 8 | 9 | def pickle_load(filename): 10 | with open('{}.pickle'.format(filename), 
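The `format_output.py` script above stores predictions as a nested dictionary, `{ontology: {protein: [(GO term, score), ...]}}`. A hedged sketch of flattening that structure back into the `protein<TAB>term<TAB>score` rows used by the evaluation predictions; the sample values come from the README output shown earlier, and file names are placeholders.

```python
# Sketch: flatten {ont: {protein: [(GO term, score), ...]}} into TSV rows.
preds = {
    "cc": {
        "A0A7I2V2M2": [("GO:0043227", 0.996), ("GO:0005737", 0.926)],
    }
}

for ont, by_protein in preds.items():
    with open(f"{ont}_flat.tsv", "w") as out:            # placeholder file name
        for protein, annots in by_protein.items():
            for go_term, score in annots:
                out.write(f"{protein}\t{go_term}\t{score:.3f}\n")
```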
'rb') as handle: 11 | return pickle.load(handle) -------------------------------------------------------------------------------- /evaluation_scripts/diamondblast.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import networkx as nx 4 | import numpy as np 5 | import obonet 6 | import pandas as pd 7 | from Bio import SeqIO 8 | import pickle 9 | 10 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 11 | 12 | 13 | def pickle_save(data, filename): 14 | with open('{}.pickle'.format(filename), 'wb') as handle: 15 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 16 | 17 | 18 | def pickle_load(filename): 19 | with open('{}.pickle'.format(filename), 'rb') as handle: 20 | return pickle.load(handle) 21 | 22 | 23 | def count_proteins(fasta_file): 24 | num = len([1 for line in open(fasta_file) if line.startswith(">")]) 25 | return num 26 | 27 | 28 | def get_graph(go_path=ROOT_DIR + "obo/go-basic.obo"): 29 | go_graph = obonet.read_obo(open(go_path, 'r')) 30 | 31 | accepted_edges = set() 32 | unaccepted_edges = set() 33 | 34 | for edge in go_graph.edges: 35 | if edge[2] == 'is_a' or edge[2] == 'part_of': 36 | accepted_edges.add(edge) 37 | else: 38 | unaccepted_edges.add(edge) 39 | go_graph.remove_edges_from(unaccepted_edges) 40 | 41 | return go_graph 42 | 43 | 44 | def create_db(wd, fasta_seq, dbase_name): 45 | command = "diamond makedb --in {} -d {}".format(fasta_seq, dbase_name) 46 | subprocess.call(command, shell=True, cwd="{}".format(wd)) 47 | 48 | 49 | def diamomd_blast(wd, dbase_name, query, output): 50 | command = "diamond blastp -d {} -q {} --outfmt 6 qseqid sseqid bitscore -o {}". \ 51 | format(dbase_name, query, output) 52 | # command = "diamond blastp -d {} -q {} -o {} --sensitive" \ 53 | # .format(dbase_name, query, output) 54 | 55 | subprocess.call(command, shell=True, cwd="{}".format(wd)) 56 | 57 | 58 | def create_diamond(proteins, groundtruth, diamond_scores_file): 59 | 60 | train_test_proteins = set(ontology_groundtruth.keys()) 61 | proteins = set(proteins) 62 | 63 | # BLAST Similarity (Diamond) 64 | diamond_scores = {} 65 | with open(diamond_scores_file) as f: 66 | for line in f: 67 | it = line.strip().split() 68 | if it[0] in proteins and it[1] in train_test_proteins: 69 | if it[0] not in diamond_scores: 70 | diamond_scores[it[0]] = {} 71 | diamond_scores[it[0]][it[1]] = float(it[2]) 72 | 73 | # BlastKNN 74 | results = {} 75 | for protein in proteins: 76 | tmp_annots = [] 77 | if protein in diamond_scores: 78 | sim_prots = diamond_scores[protein] 79 | allgos = set() 80 | total_score = 0.0 81 | for p_id, score in sim_prots.items(): 82 | allgos |= groundtruth[p_id] 83 | total_score += score 84 | # allgos is all go terms for protein 85 | allgos = list(sorted(allgos)) 86 | sim = np.zeros(len(allgos), dtype=np.float32) 87 | 88 | for j, go_id in enumerate(allgos): 89 | s = 0.0 90 | for p_id, score in sim_prots.items(): 91 | if go_id in groundtruth[p_id]: 92 | s = max(s, score) 93 | sim[j] = s 94 | sim = sim / np.max(sim) 95 | for go_id, score in zip(allgos, sim): 96 | tmp_annots.append((go_id, score)) 97 | results[protein] = tmp_annots 98 | 99 | return results 100 | 101 | 102 | def create_directory(dir): 103 | if not os.path.exists(dir): 104 | os.makedirs(dir) 105 | 106 | 107 | train_fasta = ROOT_DIR + "uniprot/train_sequences.fasta" 108 | wd = ROOT_DIR + "evaluation/raw_predictions/diamond" 109 | dbase_name = "diamond_db" 110 | query = wd + "/test_fasta" 111 | diamond_res = wd + "/diamond_res" 112 | 113 | 
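A toy, self-contained sketch of the BLAST-KNN scoring implemented in `create_diamond` above: every GO term inherited from the Diamond hits gets the maximum bitscore among hits annotated with it, then scores are normalized by the overall maximum. Hit accessions, bitscores, and annotations below are made up.

```python
# Toy sketch of the BLAST-KNN scoring in create_diamond above.
import numpy as np

sim_prots = {"P11111": 250.0, "P22222": 120.0}     # hypothetical Diamond hits -> bitscores
groundtruth = {
    "P11111": {"GO:0043227"},
    "P22222": {"GO:0005737", "GO:0043227"},
}

allgos = sorted(set().union(*(groundtruth[p] for p in sim_prots)))
sim = np.zeros(len(allgos), dtype=np.float32)
for j, go_id in enumerate(allgos):
    for p_id, score in sim_prots.items():
        if go_id in groundtruth[p_id]:
            sim[j] = max(sim[j], score)            # best-scoring hit carrying this term
sim = sim / np.max(sim)                            # normalize to (0, 1]
print(dict(zip(allgos, sim.round(3))))             # GO:0043227 -> 1.0, GO:0005737 -> 0.48
```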
print(count_proteins(train_fasta)) 114 | # create_db(wd, train_fasta, dbase_name) 115 | # diamomd_blast(wd, dbase_name, query, diamond_res) 116 | 117 | go_graph = get_graph() 118 | test_group = pickle_load("/home/fbqc9/Workspace/DATA/test/output_t1_t2/test_proteins") 119 | groundtruth = pickle_load("/home/fbqc9/Workspace/DATA/groundtruth") 120 | parent_terms = { 121 | 'cc': 'GO:0005575', 122 | 'mf': 'GO:0003674', 123 | 'bp': 'GO:0008150' 124 | } 125 | 126 | for ont in test_group: 127 | parent_term = parent_terms[ont] 128 | 129 | train_data = list(pickle_load("/home/fbqc9/Workspace/DATA/{}/train_proteins".format(ont))) 130 | valid_data = list(pickle_load("/home/fbqc9/Workspace/DATA/{}/validation_proteins".format(ont))) 131 | data = train_data + valid_data 132 | 133 | all_go_terms = nx.ancestors(go_graph, parent_term)#.union(set([parent_term])) 134 | ontology_groundtruth = {prot: set(groundtruth[prot]).intersection(all_go_terms) for prot in data} 135 | 136 | 137 | for sptr in test_group[ont]: 138 | 139 | print("Swissprot or Trembl is {}".format(sptr)) 140 | 141 | dir_pth = ROOT_DIR + "evaluation/predictions/{}_{}/".format(sptr, ont) 142 | create_directory(dir_pth) 143 | 144 | proteins = test_group[ont][sptr] 145 | 146 | diamond_scores = create_diamond(proteins, ontology_groundtruth, diamond_res) 147 | 148 | file_out = open(dir_pth+"{}.tsv".format("diamond"), 'w') 149 | for prot in proteins: 150 | for annot in diamond_scores[prot]: 151 | file_out.write(prot + '\t' + annot[0] + '\t' + str(annot[1]) + '\n') 152 | file_out.close() 153 | -------------------------------------------------------------------------------- /evaluation_scripts/format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import subprocess 4 | from Bio import SeqIO 5 | from Bio.Seq import Seq 6 | from Bio.SeqRecord import SeqRecord 7 | 8 | 9 | 10 | def read_tsv(file_name): 11 | results = {} 12 | with open(file_name) as f: 13 | lines = f.readlines() 14 | lines = [line.strip().split("\t") for line in lines] 15 | 16 | for line in lines: 17 | if line[0] in results: 18 | results[line[0]].append((line[1], line[2])) 19 | else: 20 | results[line[0]] = [(line[1], line[2]), ] 21 | return results 22 | 23 | 24 | def pickle_load(filename): 25 | with open('{}.pickle'.format(filename), 'rb') as handle: 26 | return pickle.load(handle) 27 | 28 | 29 | def create_directory(dir): 30 | if not os.path.exists(dir): 31 | os.makedirs(dir) 32 | 33 | 34 | ontologies = ["cc", "bp", "mf"] 35 | db_subset = ['swissprot', 'trembl'] 36 | methods = ["esm2_t48", "msa_1b", "interpro", "split_mlp1", "split_mlp2", "combined", "transfew"] 37 | # "slpit_mlp3" 38 | 39 | 40 | def write_to_file(methods, out_dir): 41 | 42 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 43 | proteins = pickle_load(ROOT_DIR + "test/output_t1_t2/test_proteins") 44 | RAW_PRED_pth = ROOT_DIR + "evaluation/raw_predictions/{}/{}.tsv" 45 | 46 | 47 | for method in methods: 48 | for ont in ontologies[1:2]: 49 | results = read_tsv(RAW_PRED_pth.format(method, ont)) 50 | for sptr in db_subset: 51 | tmp_proteins = proteins[ont][sptr] 52 | 53 | print(ont, sptr, len(results), len(tmp_proteins)) 54 | 55 | dir_pth = ROOT_DIR +"evaluation/predictions/{}/{}_{}/".format(out_dir, sptr, ont) 56 | create_directory(dir_pth) 57 | 58 | file_out = open(dir_pth+"{}.tsv".format(method), 'w') 59 | for prot in results: 60 | if prot in tmp_proteins: 61 | annots = results[prot] 62 | for annot in annots: 63 | file_out.write(prot + '\t' + annot[0] + '\t' + 
str(annot[1]) + '\n') 64 | file_out.close() 65 | 66 | 67 | # Write full predictions to file to compare with other methods. 68 | methods = ["transfew"] 69 | # write_to_file(methods=methods, out_dir="full") 70 | 71 | 72 | # write the combined: 73 | methods = ["esm2_t48", "msa_1b", "interpro", "combined", "transfew"] 74 | # write_to_file(methods=methods, out_dir="components") 75 | 76 | 77 | # write random split: 78 | methods = ["split_mlp1", "split_mlp2", "split_mlp3", "transfew"] 79 | # write_to_file(methods=methods, out_dir="random_split") -------------------------------------------------------------------------------- /evaluation_scripts/format_deepgose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import subprocess 4 | from Bio import SeqIO 5 | from Bio.Seq import Seq 6 | from Bio.SeqRecord import SeqRecord 7 | import networkx as nx 8 | 9 | 10 | def pickle_save(data, filename): 11 | with open('{}.pickle'.format(filename), 'wb') as handle: 12 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 13 | 14 | 15 | def pickle_load(filename): 16 | with open('{}.pickle'.format(filename), 'rb') as handle: 17 | return pickle.load(handle) 18 | 19 | 20 | def read_to_dic(filename): 21 | res = {} 22 | with open(filename) as f: 23 | lines = [line.strip().split("\t") for line in f] 24 | 25 | 26 | for i in lines: 27 | prot = i[0].split("|")[1] 28 | if prot in res: 29 | res[prot].append((i[1], float(i[2]))) 30 | else: 31 | res[prot] = [(i[1], float(i[2])), ] 32 | return res 33 | 34 | 35 | def create_directory(dir): 36 | if not os.path.exists(dir): 37 | os.makedirs(dir) 38 | 39 | def to_file(data): 40 | with open('somefile.txt', 'a') as f: 41 | for ont, prot in data.items(): 42 | f.write("{}\n".format(ont)) 43 | f.write("\t".join(prot)) 44 | f.write("\n") 45 | 46 | 47 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 48 | 49 | ontologies = ["cc", "mf", "bp"] 50 | sptr = ['swissprot', 'trembl'] 51 | proteins = pickle_load("/home/fbqc9/Workspace/DATA/test/output_t1_t2/test_proteins") 52 | 53 | for ont in ontologies: 54 | print(ont) 55 | filename = ROOT_DIR + "evaluation/raw_predictions/deepgose/idmapping_2024_04_28_preds_{}.tsv".format(ont) 56 | data = read_to_dic(filename) 57 | 58 | for st in sptr: 59 | print(st) 60 | dir_pth = ROOT_DIR +"evaluation/predictions/{}_{}/".format(st, ont) 61 | create_directory(dir_pth) 62 | 63 | filt_proteins = proteins[ont][st] 64 | 65 | file_out = open(dir_pth+"{}.tsv".format("deepgose"), 'w') 66 | for prot in filt_proteins: 67 | try: 68 | annots = data[prot] 69 | for annot in annots: 70 | file_out.write(prot + '\t' + annot[0] + '\t' + str(annot[1]) + '\n') 71 | except KeyError: 72 | pass 73 | file_out.close() 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | exit() 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | test_group = pickle_load("/home/fbqc9/Workspace/DATA/test/t3/test_proteins") 96 | 97 | # add limited known and no known 98 | test_group = { 99 | 'bp': test_group['LK_bp'] | test_group['NK_bp'], 100 | 'mf': test_group['LK_mf'] | test_group['NK_mf'], 101 | 'cc': test_group['LK_cc'] | test_group['NK_cc'] 102 | } 103 | 104 | to_remove = {'C0HM98', 'C0HM97', 'C0HMA1', 'C0HM44'} 105 | 106 | 107 | 108 | output = {} 109 | for ont in ontologies: 110 | res = csv_to_dic("evaluation/predictions/deepgose/test_fasta_preds_{}.tsv".format(ont)) 111 | 112 | proteins = set(test_group[ont]).difference(to_remove) 113 | 114 | output[ont] = {key: res[key] for key in proteins} 115 | 116 | # 
to_file(output) 117 | 118 | # pickle_save(output, "evaluation/results/{}_out".format("deepgose")) 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /evaluation_scripts/format_netgo3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import subprocess 4 | from Bio import SeqIO 5 | from Bio.Seq import Seq 6 | from Bio.SeqRecord import SeqRecord 7 | 8 | 9 | 10 | def pickle_save(data, filename): 11 | with open('{}.pickle'.format(filename), 'wb') as handle: 12 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 13 | 14 | 15 | def pickle_load(filename): 16 | with open('{}.pickle'.format(filename), 'rb') as handle: 17 | return pickle.load(handle) 18 | 19 | 20 | # divide data into chunks of 1000 21 | def divide(fasta_file): 22 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 23 | seqs = [] 24 | for pos, record in enumerate(input_seq_iterator): 25 | seqs.append(record) 26 | if pos % 900 == 0 and pos>0: 27 | SeqIO.write(seqs, "/home/fbqc9/Workspace/TransFun2/evaluation/predictions/netgo/raw/{}".format('test_{}.fasta'.format(pos)), "fasta") 28 | seqs = [] 29 | SeqIO.write(seqs, "/home/fbqc9/Workspace/TransFun2/evaluation/predictions/netgo/raw/{}".format('test_{}.fasta'.format(pos)), "fasta") 30 | 31 | # divide("/home/fbqc9/Workspace/DATA/uniprot/test_fasta_rem.fasta") 32 | 33 | 34 | def combine_predictions(): 35 | path = "evaluation/predictions/netgo/result_{}.txt" 36 | results = [1699845719, 1699845761, 1699845770, 1699845780, 1699897400, 37 | 1705853012, 1705853252, 1705853349, 1705908596, 1705908641, 38 | 1705908662, 1705908663] 39 | 40 | with open("evaluation/predictions/netgo/combined_result.txt", 'w') as outfile: 41 | for fname in results: 42 | with open(path.format(fname)) as infile: 43 | outfile.write(infile.read()) 44 | 45 | 46 | 47 | def read_files(): 48 | 49 | res_dic = {'mf':{}, 'cc':{}, 'bp':{}} 50 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 51 | all_netgo_output = [1699845719, 1699845761, 1699845770, 1699845780, 1699897400, 52 | 1705853012, 1705853252, 1705853349, 1705908596, 1705908641, 53 | 1705908662, 1705908663] 54 | # test_group = pickle_load("/home/fbqc9/Workspace/DATA/test/t3/test_proteins") 55 | 56 | 57 | for netgo_output in all_netgo_output: 58 | with open(ROOT_DIR + "evaluation/raw_predictions/netgo/{}.txt".format("result_{}".format(netgo_output))) as f: 59 | lines = [line.rstrip('\n').split("\t") for line in f] 60 | 61 | for i in lines: 62 | if i[0] == "=====": 63 | continue 64 | if i[0] in res_dic[i[3]]: 65 | res_dic[i[3]][i[0]].append((i[1], float(i[2]))) 66 | else: 67 | res_dic[i[3]][i[0]] = [(i[1], float(i[2])), ] 68 | 69 | return res_dic 70 | 71 | 72 | def create_directory(dir): 73 | if not os.path.exists(dir): 74 | os.makedirs(dir) 75 | 76 | 77 | def write_to_files(): 78 | 79 | all_data = read_files() 80 | 81 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 82 | 83 | ontologies = ["cc", "mf", "bp"] 84 | sptr = ['swissprot', 'trembl'] 85 | proteins = pickle_load(ROOT_DIR + "test/output_t1_t2/test_proteins") 86 | 87 | for ont in ontologies: 88 | print("Ontology is {}".format(ont)) 89 | data = all_data[ont] 90 | 91 | for st in sptr: 92 | print("Catehory is {}".format(st)) 93 | 94 | dir_pth = ROOT_DIR +"evaluation/predictions/{}_{}/".format(st, ont) 95 | create_directory(dir_pth) 96 | 97 | filt_proteins = proteins[ont][st] 98 | 99 | file_out = open(dir_pth+"{}.tsv".format("netgo3"), 'w') 100 | for prot in filt_proteins: 101 | 
annots = data[prot] 102 | for annot in annots: 103 | file_out.write(prot + '\t' + annot[0] + '\t' + str(annot[1]) + '\n') 104 | file_out.close() 105 | 106 | 107 | 108 | write_to_files() -------------------------------------------------------------------------------- /evaluation_scripts/format_sprof.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from Bio import SeqIO 4 | from Bio.Seq import Seq 5 | from Bio.SeqRecord import SeqRecord 6 | 7 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 8 | 9 | def count_proteins(fasta_file): 10 | num = len([1 for line in open(fasta_file) if line.startswith(">")]) 11 | return num 12 | 13 | def pickle_save(data, filename): 14 | with open('{}.pickle'.format(filename), 'wb') as handle: 15 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 16 | 17 | 18 | def pickle_load(filename): 19 | with open('{}.pickle'.format(filename), 'rb') as handle: 20 | return pickle.load(handle) 21 | 22 | 23 | def create_directory(dir): 24 | if not os.path.exists(dir): 25 | os.makedirs(dir) 26 | 27 | 28 | def combine_predictions(): 29 | path = ROOT_DIR + "evaluation/raw_predictions/sprof/test_{}_all_preds.txt" 30 | in_files = [900, 1800, 2700, 3600, 4500, 5400, 6300, 7200, 8100, 9000, 9900, 10047] 31 | 32 | results = {"cc": {}, "bp": {}, "mf": {} } 33 | 34 | for in_file in in_files: 35 | print(path.format(in_file)) 36 | with open(path.format(in_file), 'r') as f: 37 | lines = f.readlines() 38 | 39 | mf_terms, bp_terms, cc_terms = lines[2].strip().split("; "), lines[5].strip().split("; "), lines[8].strip().split("; ") 40 | 41 | assert lines[10] == "\n" 42 | 43 | for line in lines[10:]: 44 | if line == "\n": 45 | pass 46 | elif line == "MF:\n": 47 | cur = "_mf" 48 | elif line == "CC:\n": 49 | cur = "_cc" 50 | elif line == "BP:\n": 51 | cur = "_bp" 52 | else: 53 | split_line = line.strip().split(";") 54 | if len(split_line) == 1: 55 | protein = split_line[0] 56 | elif cur == "_mf": 57 | assert len(split_line) == len(mf_terms) 58 | split_line = [float(i) for i in split_line] 59 | results['mf'][protein] = list(zip(mf_terms, split_line)) 60 | elif cur == "_cc": 61 | assert len(split_line) == len(cc_terms) 62 | split_line = [float(i) for i in split_line] 63 | results['cc'][protein] = list(zip(cc_terms, split_line)) 64 | elif cur == "_bp": 65 | assert len(split_line) == len(bp_terms) 66 | split_line = [float(i) for i in split_line] 67 | results['bp'][protein] = list(zip(bp_terms, split_line)) 68 | 69 | return results 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | ontologies = ["cc", "mf", "bp"] 78 | sptr = ['swissprot', 'trembl'] 79 | proteins = pickle_load(ROOT_DIR + "test/output_t1_t2/test_proteins") 80 | 81 | all_data = combine_predictions() 82 | 83 | 84 | for ont in ontologies: 85 | print("Ontology is {}".format(ont)) 86 | 87 | data = all_data[ont] 88 | 89 | for st in sptr: 90 | print("Category is {}".format(st)) 91 | 92 | dir_pth = ROOT_DIR +"evaluation/predictions/full/{}_{}/".format(st, ont) 93 | create_directory(dir_pth) 94 | 95 | filt_proteins = proteins[ont][st] 96 | 97 | 98 | file_out = open(dir_pth+"{}.tsv".format("sprof"), 'w') 99 | for prot in filt_proteins: 100 | annots = data[prot] 101 | for annot in annots: 102 | file_out.write(prot + '\t' + annot[0] + '\t' + str(annot[1]) + '\n') 103 | file_out.close() 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /evaluation_scripts/format_tale.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from Bio import SeqIO 4 | from Bio.SeqRecord import SeqRecord 5 | 6 | 7 | def pickle_save(data, filename): 8 | with open('{}.pickle'.format(filename), 'wb') as handle: 9 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 10 | 11 | 12 | def pickle_load(filename): 13 | with open('{}.pickle'.format(filename), 'rb') as handle: 14 | return pickle.load(handle) 15 | 16 | 17 | def format_input(fasta_file): 18 | seqs = [] 19 | input_seq_iterator = SeqIO.parse(fasta_file, "fasta") 20 | for pos, record in enumerate(input_seq_iterator): 21 | 22 | if len(record.seq) > 1000: 23 | seqs.append(SeqRecord(id=record.id, seq=record.seq[:1000], description=record.description)) 24 | else: 25 | seqs.append(record) 26 | 27 | SeqIO.write(seqs, "evaluation/predictions/tale/test_fasta.fasta", "fasta") 28 | 29 | 30 | # format_input("/home/fbqc9/Workspace/DATA/uniprot/test_fasta.fasta") 31 | # exit() 32 | 33 | 34 | def read_files(ont): 35 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 36 | res = {} 37 | with open(ROOT_DIR + "evaluation/raw_predictions/tale/{}.out".format(ont)) as f: 38 | lines = [line.rstrip('\n').replace(")", "")\ 39 | .replace("(", "") 40 | .split("\t") for line in f] 41 | 42 | for i in lines: 43 | term = i[1].replace("'", "").replace(" ", "").split(",")[0] 44 | if i[0] in res: 45 | res[i[0]].append((term, float(i[2]))) 46 | else: 47 | res[i[0]] = [(term, float(i[2])), ] 48 | 49 | return res 50 | 51 | 52 | def create_directory(dir): 53 | if not os.path.exists(dir): 54 | os.makedirs(dir) 55 | 56 | 57 | def write_to_files(): 58 | 59 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 60 | 61 | ontologies = ["cc", "mf", "bp"] 62 | sptr = ['swissprot', 'trembl'] 63 | proteins = pickle_load("/home/fbqc9/Workspace/DATA/test/output_t1_t2/test_proteins") 64 | 65 | for ont in ontologies: 66 | print("Ontology is {}".format(ont)) 67 | 68 | data = read_files(ont) 69 | 70 | 71 | for st in sptr: 72 | print("Catehory is {}".format(st)) 73 | 74 | dir_pth = ROOT_DIR +"evaluation/predictions/{}_{}/".format(st, ont) 75 | create_directory(dir_pth) 76 | 77 | filt_proteins = proteins[ont][st] 78 | 79 | file_out = open(dir_pth+"{}.tsv".format("tale"), 'w') 80 | for prot in filt_proteins: 81 | annots = data[prot] 82 | for annot in annots: 83 | file_out.write(prot + '\t' + annot[0] + '\t' + str(annot[1]) + '\n') 84 | file_out.close() 85 | 86 | 87 | write_to_files() -------------------------------------------------------------------------------- /evaluation_scripts/naive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import obonet 3 | import pandas as pd 4 | from collections import Counter 5 | import pickle 6 | import networkx as nx 7 | 8 | ROOT_DIR = "/home/fbqc9/Workspace/DATA/" 9 | 10 | def pickle_save(data, filename): 11 | with open('{}.pickle'.format(filename), 'wb') as handle: 12 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 13 | 14 | 15 | def pickle_load(filename): 16 | with open('{}.pickle'.format(filename), 'rb') as handle: 17 | return pickle.load(handle) 18 | 19 | 20 | def get_graph(go_path=ROOT_DIR + "obo/go-basic.obo"): 21 | go_graph = obonet.read_obo(open(go_path, 'r')) 22 | 23 | accepted_edges = set() 24 | unaccepted_edges = set() 25 | 26 | for edge in go_graph.edges: 27 | if edge[2] == 'is_a' or edge[2] == 'part_of': 28 | accepted_edges.add(edge) 29 | else: 30 | unaccepted_edges.add(edge) 31 | 
go_graph.remove_edges_from(unaccepted_edges) 32 | 33 | return go_graph 34 | 35 | 36 | def create_directory(dir): 37 | if not os.path.exists(dir): 38 | os.makedirs(dir) 39 | 40 | 41 | def create_naive(groundtruth): 42 | 43 | frequency = [] 44 | for prot, annot in groundtruth.items(): 45 | frequency.extend(list(annot)) 46 | 47 | cnt = Counter(frequency) 48 | max_n = cnt.most_common(1)[0][1] 49 | 50 | scores = [] 51 | for go_id, n in cnt.items(): 52 | score = n / max_n 53 | scores.append((go_id, score)) 54 | 55 | return scores 56 | 57 | 58 | parent_terms = { 59 | 'cc': 'GO:0005575', 60 | 'mf': 'GO:0003674', 61 | 'bp': 'GO:0008150' 62 | } 63 | 64 | go_graph = get_graph() 65 | test_group = pickle_load("/home/fbqc9/Workspace/DATA/test/output_t1_t2/test_proteins") 66 | groundtruth = pickle_load("/home/fbqc9/Workspace/DATA/groundtruth") 67 | 68 | for ont in test_group: 69 | 70 | parent_term = parent_terms[ont] 71 | train_data = list(pickle_load("/home/fbqc9/Workspace/DATA/{}/train_proteins".format(ont))) 72 | valid_data = list(pickle_load("/home/fbqc9/Workspace/DATA/{}/validation_proteins".format(ont))) 73 | 74 | data = train_data + valid_data 75 | 76 | all_go_terms = nx.ancestors(go_graph, parent_term)#.union(set([parent_term])) 77 | 78 | ontology_groundtruth = {prot: set(groundtruth[prot]).intersection(all_go_terms) for prot in data} 79 | naive_scores = create_naive(ontology_groundtruth) 80 | 81 | for sptr in test_group[ont]: 82 | print("Swissprot or Trembl is {}".format(sptr)) 83 | 84 | dir_pth = ROOT_DIR +"evaluation/predictions/{}_{}/".format(sptr, ont) 85 | create_directory(dir_pth) 86 | 87 | filt_proteins = test_group[ont][sptr] 88 | 89 | file_out = open(dir_pth+"{}.tsv".format("naive"), 'w') 90 | for prot in filt_proteins: 91 | for annot in naive_scores: 92 | file_out.write(prot + '\t' + annot[0] + '\t' + str(annot[1]) + '\n') 93 | file_out.close() 94 | -------------------------------------------------------------------------------- /evaluation_seqID.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from matplotlib import pyplot as plt, rcParams 5 | import numpy as np 6 | from Utils import count_proteins, create_directory, pickle_save, pickle_load 7 | from collections import Counter 8 | import math 9 | import CONSTANTS 10 | import obonet 11 | import networkx as nx 12 | from Bio import SeqIO 13 | 14 | 15 | def filter_fasta(proteins, infile, outfile): 16 | seqs = [] 17 | input_seq_iterator = SeqIO.parse(infile, "fasta") 18 | 19 | for pos, record in enumerate(input_seq_iterator): 20 | if record.id in proteins: 21 | seqs.append(record) 22 | SeqIO.write(seqs, outfile, "fasta") 23 | 24 | 25 | def extract_from_results(infile): 26 | file = open(infile) 27 | lines = [] 28 | for _line in file.readlines(): 29 | line = _line.strip("\n").split("\t") 30 | lines.append((line[0], line[1], line[3])) 31 | file.close() 32 | return lines 33 | 34 | 35 | def get_seq_less(ontology, test_proteins, seq_id=0.3): 36 | # mmseqs createdb ... 
| [options] 37 | 38 | full_train_fasta = "/home/fbqc9/Workspace/DATA/uniprot/train_sequences.fasta" 39 | test_fasta = "/home/fbqc9/Workspace/DATA/uniprot/test_fasta.fasta" 40 | 41 | # Number of training proteins & test proteins in combined fasta 42 | # print("# Training: {}, # Testing {}".format(count_proteins(full_train_fasta), count_proteins(test_fasta))) 43 | 44 | train_data = list(pickle_load("/home/fbqc9/Workspace/DATA/{}/train_proteins".format(ontology))) 45 | valid_data = list(pickle_load("/home/fbqc9/Workspace/DATA/{}/validation_proteins".format(ontology))) 46 | 47 | 48 | train_data = set(train_data + valid_data) # set for fast lookup 49 | test_proteins = set(test_proteins) # set for fast lookup 50 | 51 | # Number of training proteins & test proteins 52 | print("# Training & Validation: {}, # Testing {}".format(len(train_data), len(test_proteins))) 53 | 54 | # No train data in test data 55 | assert len(train_data.intersection(test_proteins)) == 0 56 | 57 | 58 | # make temporary directory 59 | wkdir = "/home/fbqc9/Workspace/TransFun2/evaluation/seqID/{}".format(seq_id) 60 | create_directory(wkdir) 61 | 62 | # create query and target databases 63 | target_fasta = wkdir+"/target_fasta" 64 | query_fasta = wkdir+"/query_fasta" 65 | filter_fasta(train_data, full_train_fasta, target_fasta) 66 | filter_fasta(test_proteins, test_fasta, query_fasta) 67 | 68 | # All train and test in respective fasta 69 | assert len(train_data) == count_proteins(target_fasta) 70 | assert len(test_proteins) == count_proteins(query_fasta) 71 | 72 | print("Creating target Database") 73 | target_dbase = wkdir+"/target_dbase" 74 | CMD = "mmseqs createdb {} {}".format(target_fasta, target_dbase) 75 | subprocess.call(CMD, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) 76 | 77 | print("Creating query Database") 78 | query_dbase = wkdir+"/query_dbase" 79 | CMD = "mmseqs createdb {} {}".format(query_fasta, query_dbase) 80 | subprocess.call(CMD, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) 81 | 82 | print("Mapping very similar sequences") 83 | result_dbase = wkdir+"/result_dbase" 84 | CMD = "mmseqs map {} {} {} {} --min-seq-id {}".\ 85 | format(query_dbase, target_dbase, result_dbase, wkdir, seq_id) 86 | subprocess.call(CMD, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) 87 | 88 | bestResultDB = wkdir+"/bestResultDB" 89 | CMD = "mmseqs filterdb {} {} --extract-lines 1".format(result_dbase, bestResultDB) 90 | subprocess.call(CMD, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) 91 | 92 | final_res = wkdir+"/final_res.tsv" 93 | CMD = "mmseqs createtsv {} {} {} {}".format(query_dbase, target_dbase, bestResultDB, final_res) 94 | subprocess.call(CMD, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) 95 | 96 | 97 | lines = extract_from_results(final_res) 98 | 99 | shutil.rmtree(wkdir) 100 | 101 | querys, targets, seq_ids = zip(*lines) 102 | 103 | querys = set(querys) 104 | targets = set(targets) 105 | 106 | assert len(train_data.intersection(querys)) == 0 107 | assert len(test_proteins.intersection(targets)) == 0 108 | 109 | assert len(train_data.intersection(targets)) == len(targets) 110 | assert len(test_proteins.intersection(querys)) == len(querys) 111 | 112 | 113 | # the proteins with less than X seq identity to the training set 114 | return test_proteins.difference(querys) 115 | 116 | 117 | def read_filter_write(proteins, in_file, out_file): 118 | 119 | with open(in_file) as f: 120 | lines = [line.strip('\n').split("\t") for line 
in f] 121 | 122 | lines = [i for i in lines if i[0] in proteins] 123 | 124 | '''file_out = open(out_file, 'w') 125 | file_out.write('\t'.join(lines) + '\n') 126 | file_out.close()''' 127 | 128 | file_out = open(out_file, 'w') 129 | for line in lines: 130 | file_out.write('\t'.join(line) + '\n') 131 | # file_out.write(line[0] + '\t' + line[1] + '\t' + line[2] + '\n') 132 | file_out.close() 133 | 134 | 135 | all_proteins = pickle_load(CONSTANTS.ROOT_DIR + "test/output_t1_t2/test_proteins") 136 | 137 | methods = ['naive', 'diamond', 'tale', 'deepgose', 'netgo3', 'sprof', 'transfew'] 138 | 139 | in_file_pths = CONSTANTS.ROOT_DIR + "evaluation/predictions/full/{}_{}/{}.tsv" 140 | out_file_pths = CONSTANTS.ROOT_DIR + "evaluation/predictions/seq_ID_30/{}_{}/" 141 | 142 | gt_in_file_pths = CONSTANTS.ROOT_DIR + "test/output_t1_t2/groundtruths/full/{}_{}.tsv" 143 | gt_out_file_pths = CONSTANTS.ROOT_DIR + "test/output_t1_t2/groundtruths/seq_ID_30/{}_{}.tsv" 144 | 145 | def main(): 146 | for ont in all_proteins: 147 | for sptr in all_proteins[ont]: 148 | 149 | proteins = all_proteins[ont][sptr] 150 | proteins = get_seq_less(ontology=ont, test_proteins=proteins) 151 | 152 | print("Writing groundtruth {} {}".format(ont, sptr)) 153 | read_filter_write(proteins, gt_in_file_pths.format(ont, sptr), gt_out_file_pths.format(ont, sptr)) 154 | 155 | create_directory(out_file_pths.format(sptr, ont)) 156 | # write output from all_output 157 | for method in methods: 158 | 159 | print("Ontology: {} --- DB subset {} --- Method {}".format(ont, sptr, method)) 160 | read_filter_write(proteins, in_file_pths.format(sptr, ont, method), out_file_pths.format(sptr, ont) + "{}.tsv".format(method)) 161 | 162 | 163 | 164 | if __name__ == '__main__': 165 | 166 | main() 167 | 168 | -------------------------------------------------------------------------------- /external/extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
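# Example invocation of this embedding extractor (illustrative only: the model name,
# FASTA path and output directory are placeholders, not values taken from this repo):
#   python external/extract.py esm2_t48_15B_UR50D proteins.fasta embeddings/ \
#       --include mean per_tok --truncation_seq_length 1022
# One .pt file per sequence is written to <output_dir>/<label>.pt; note that the
# device is hard-coded to "cuda:1" below unless --nogpu is passed.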
6 | 7 | import argparse 8 | import pathlib 9 | 10 | import torch 11 | 12 | from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer 13 | 14 | 15 | def create_parser(): 16 | parser = argparse.ArgumentParser( 17 | description="Extract per-token representations and model outputs for sequences in a FASTA file" # noqa 18 | ) 19 | 20 | parser.add_argument( 21 | "model_location", 22 | type=str, 23 | help="PyTorch model file OR name of pretrained model to download (see README for models)", 24 | ) 25 | parser.add_argument( 26 | "fasta_file", 27 | type=pathlib.Path, 28 | help="FASTA file on which to extract representations", 29 | ) 30 | parser.add_argument( 31 | "output_dir", 32 | type=pathlib.Path, 33 | help="output directory for extracted representations", 34 | ) 35 | 36 | parser.add_argument("--toks_per_batch", type=int, default=4096, help="maximum batch size") 37 | parser.add_argument( 38 | "--repr_layers", 39 | type=int, 40 | default=[-1], 41 | nargs="+", 42 | help="layers indices from which to extract representations (0 to num_layers, inclusive)", 43 | ) 44 | parser.add_argument( 45 | "--include", 46 | type=str, 47 | nargs="+", 48 | choices=["mean", "per_tok", "bos", "contacts"], 49 | help="specify which representations to return", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "--truncation_seq_length", 54 | type=int, 55 | default=1022, 56 | help="truncate sequences longer than the given value", 57 | ) 58 | 59 | parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available") 60 | return parser 61 | 62 | 63 | def main(args): 64 | model, alphabet = pretrained.load_model_and_alphabet(args.model_location) 65 | model.eval() 66 | if isinstance(model, MSATransformer): 67 | raise ValueError( 68 | "This script currently does not handle models with MSA input (MSA Transformer)." 
69 | ) 70 | if torch.cuda.is_available() and not args.nogpu: 71 | model = model.to(device="cuda:1") 72 | print("Transferred model to GPU") 73 | 74 | dataset = FastaBatchedDataset.from_file(args.fasta_file) 75 | batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1) 76 | data_loader = torch.utils.data.DataLoader( 77 | dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length), batch_sampler=batches 78 | ) 79 | print(f"Read {args.fasta_file} with {len(dataset)} sequences") 80 | 81 | args.output_dir.mkdir(parents=True, exist_ok=True) 82 | return_contacts = "contacts" in args.include 83 | 84 | assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers) 85 | repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers] 86 | 87 | with torch.no_grad(): 88 | for batch_idx, (labels, strs, toks) in enumerate(data_loader): 89 | print( 90 | f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)" 91 | ) 92 | if torch.cuda.is_available() and not args.nogpu: 93 | toks = toks.to(device="cuda:1", non_blocking=True) 94 | 95 | out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts) 96 | 97 | logits = out["logits"].to(device="cpu") 98 | representations = { 99 | layer: t.to(device="cpu") for layer, t in out["representations"].items() 100 | } 101 | if return_contacts: 102 | contacts = out["contacts"].to(device="cpu") 103 | 104 | for i, label in enumerate(labels): 105 | args.output_file = args.output_dir / f"{label}.pt" 106 | args.output_file.parent.mkdir(parents=True, exist_ok=True) 107 | result = {"label": label} 108 | truncate_len = min(args.truncation_seq_length, len(strs[i])) 109 | # Call clone on tensors to ensure tensors are not views into a larger representation 110 | # See https://github.com/pytorch/pytorch/issues/1995 111 | if "per_tok" in args.include: 112 | result["representations"] = { 113 | layer: t[i, 1 : truncate_len + 1].clone() 114 | for layer, t in representations.items() 115 | } 116 | if "mean" in args.include: 117 | result["mean_representations"] = { 118 | layer: t[i, 1 : truncate_len + 1].mean(0).clone() 119 | for layer, t in representations.items() 120 | } 121 | if "bos" in args.include: 122 | result["bos_representations"] = { 123 | layer: t[i, 0].clone() for layer, t in representations.items() 124 | } 125 | if return_contacts: 126 | result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone() 127 | 128 | torch.save( 129 | result, 130 | args.output_file, 131 | ) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = create_parser() 136 | args = parser.parse_args() 137 | main(args) 138 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | cc = { 2 | 'esm2_t48' : { 3 | 'lr': 0.0001, #0.0001, #0.00001, #0.0001 4 | 'batch_size': 500, 5 | 'epochs': 50, 6 | 'weight_decay': 5e-4, 7 | 'lr_scheduler': 50, 8 | 'weight_factor': 2 9 | }, 10 | 'msa_1b' : { 11 | 'lr': 0.0001, 12 | 'batch_size': 500, 13 | 'epochs': 50, 14 | 'weight_decay': 5e-4, 15 | 'lr_scheduler': 50, 16 | 'weight_factor': 2 17 | }, 18 | 'interpro': { 19 | 'lr': 0.0001, 20 | 'batch_size': 500, 21 | 'epochs': 50, 22 | 'weight_decay': 5e-4, 23 | 'lr_scheduler': 50, 24 | 'weight_factor': 2 25 | }, 26 | 'full' : { 27 | 'lr': 0.0001, #0.00005, #0.0001, #0.00005, #0.00003, #0.0001,#0.00005, 28 | 'batch_size': 256, #256, #500, #64 29 | 'epochs': 150, 30 | 
'weight_decay': 0.0001, #5e-4, #1e-2, #1e-2, #5e-4 31 | 'lr_scheduler': 50, 32 | 'weight_factor': 3 33 | }, 34 | 'label_ae' : { 35 | 'lr': 0.001, 36 | 'weight_decay': 1e-2, 37 | # 'lr': 0.0001, for GAT cc 38 | } 39 | } 40 | 41 | # rare: 2, full: 3 42 | 43 | mf = { 44 | 'esm2_t48' : { 45 | 'lr': 0.0001, 46 | 'batch_size': 500, 47 | 'epochs': 100, 48 | 'weight_decay': 5e-4, 49 | 'lr_scheduler': 50, 50 | 'weight_factor': 2 51 | }, 52 | 'msa_1b' : { 53 | 'lr': 0.0001, 54 | 'batch_size': 500, 55 | 'epochs': 50, 56 | 'weight_decay': 5e-4, 57 | 'lr_scheduler': 50, 58 | 'weight_factor': 2 59 | }, 60 | 'interpro': { 61 | 'lr': 0.0001, 62 | 'batch_size': 500, 63 | 'epochs': 50, 64 | 'weight_decay': 5e-4, 65 | 'lr_scheduler': 50, 66 | 'weight_factor': 2 67 | }, 68 | 'full' : { 69 | 'lr': 0.0001, #0.0001, 70 | 'batch_size': 256, 71 | 'epochs': 150, 72 | 'weight_decay': 0.0001, #1e-4, #5e-4 73 | 'lr_scheduler': 50, 74 | 'weight_factor': 3 75 | }, 76 | 'label_ae' : { 77 | 'lr': 0.0001, 78 | 'weight_decay': 1e-2, 79 | } 80 | } 81 | 82 | bp = { 83 | 'esm2_t48' : { 84 | 'lr': 0.0001, #0.00005, #0.0001, #0.00005, 85 | 'batch_size': 500, 86 | 'epochs': 52, 87 | 'weight_decay': 5e-4, 88 | 'lr_scheduler': 50, 89 | 'weight_factor': 2 90 | }, 91 | 'msa_1b' : { 92 | 'lr': 0.0001, 93 | 'batch_size': 250, 94 | 'epochs': 50, 95 | 'weight_decay': 5e-4, 96 | 'lr_scheduler': 50, 97 | 'weight_factor': 2 98 | }, 99 | 'interpro': { 100 | 'lr': 0.0001, 101 | 'batch_size': 250, 102 | 'epochs': 50, 103 | 'weight_decay': 5e-4, 104 | 'lr_scheduler': 50, 105 | 'weight_factor': 2 106 | }, 107 | 'diamond': { 108 | 'lr': 0.0001, 109 | 'batch_size': 500, 110 | 'epochs': 50, 111 | 'weight_decay': 5e-4 112 | }, 113 | 'string' : { 114 | 'lr': 0.0001, 115 | 'batch_size': 500, 116 | 'epochs': 50, 117 | 'weight_decay': 5e-4 118 | }, 119 | 'full' : { 120 | 'lr': 0.0001, 121 | 'batch_size': 256, 122 | 'epochs': 150, 123 | 'weight_decay': 1e-4,#1e-2,#5e-4 124 | 'lr_scheduler': 50, 125 | 'weight_factor': 3 126 | }, 127 | 'label_ae' : { 128 | 'lr': 0.0001, 129 | 'weight_decay': 1e-2, 130 | } 131 | 132 | } -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Utils import create_directory, load_ckp, pickle_load, pickle_save 3 | import CONSTANTS 4 | import math, os, time 5 | import argparse 6 | from models.model import TFun, TFun_submodel 7 | from Dataset.MyDataset import TestDataset 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 12 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 13 | parser.add_argument("--load_weights", default=False, type=bool, help='Load weights from saved model') 14 | args = parser.parse_args() 15 | 16 | torch.manual_seed(args.seed) 17 | args.cuda = not args.no_cuda and torch.cuda.is_available() 18 | 19 | if args.cuda: 20 | device = 'cuda:1' 21 | else: 22 | device = 'cpu' 23 | 24 | # load all test 25 | all_test = pickle_load(CONSTANTS.ROOT_DIR + "test/output_t1_t2/test_proteins") 26 | 27 | 28 | ontologies = ["cc", "mf", "bp"] 29 | models = ['esm2_t48', 'msa_1b', 'interpro', 'full'] 30 | 31 | 32 | 33 | def write_output(results, terms, filepath, cutoff=0.001): 34 | with open(filepath, 'w') as fp: 35 | for prt in results: 36 | assert len(terms) == len(results[prt]) 37 | tmp = list(zip(terms, results[prt])) 38 | tmp.sort(key = lambda x: x[1], 
reverse=True) 39 | for trm, score in tmp: 40 | if score > cutoff: 41 | fp.write('%s\t%s\t%0.3f\n' % (prt, trm, score)) 42 | 43 | 44 | 45 | def get_term_indicies(ontology): 46 | 47 | _term_indicies = pickle_load(CONSTANTS.ROOT_DIR + "{}/term_indicies".format(ontology)) 48 | 49 | if ontology == 'bp': 50 | full_term_indicies, mid_term_indicies, freq_term_indicies = _term_indicies[0], _term_indicies[5], _term_indicies[30] 51 | rare_term_indicies_2 = torch.tensor([i for i in full_term_indicies if not i in set(mid_term_indicies)]).to(device) 52 | rare_term_indicies = torch.tensor([i for i in mid_term_indicies if not i in set(freq_term_indicies)]).to(device) 53 | full_term_indicies, freq_term_indicies = torch.tensor(_term_indicies[0]).to(device), torch.tensor(freq_term_indicies).to(device) 54 | else: 55 | full_term_indicies = _term_indicies[0] 56 | freq_term_indicies = _term_indicies[30] 57 | rare_term_indicies = torch.tensor([i for i in full_term_indicies if not i in set(freq_term_indicies)]).to(device) 58 | full_term_indicies = torch.tensor(full_term_indicies).to(device) 59 | freq_term_indicies = torch.tensor(freq_term_indicies).to(device) 60 | rare_term_indicies_2 = None 61 | 62 | return full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 63 | 64 | 65 | 66 | for ontology in ontologies: 67 | 68 | data_pth = CONSTANTS.ROOT_DIR + "test/dataset/{}".format(ontology) 69 | sorted_terms = pickle_load(CONSTANTS.ROOT_DIR+"/{}/sorted_terms".format(ontology)) 70 | 71 | for sub_model in models: 72 | 73 | tst_dataset = TestDataset(data_pth=data_pth, submodel=sub_model) 74 | tstloader = torch.utils.data.DataLoader(tst_dataset, batch_size=500, shuffle=False) 75 | # terms, term_indicies, sub_indicies = get_term_indicies(ontology=ontology, submodel=sub_model) 76 | full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 = get_term_indicies(ontology=ontology) 77 | 78 | 79 | kwargs = { 80 | 'device': device, 81 | 'ont': ontology, 82 | 'full_indicies': full_term_indicies, 83 | 'freq_indicies': freq_term_indicies, 84 | 'rare_indicies': rare_term_indicies, 85 | 'rare_indicies_2': rare_term_indicies_2, 86 | 'sub_model': sub_model, 87 | 'load_weights': True, 88 | 'group': "" 89 | } 90 | 91 | if sub_model == "full": 92 | print("Generating for {} {}".format(ontology, sub_model)) 93 | 94 | ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/{}_gcn_old/'.format(ontology, sub_model) 95 | ckp_pth = ckp_dir + "current_checkpoint.pt" 96 | model = TFun(**kwargs) 97 | 98 | # load model 99 | model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True) 100 | 101 | model.to(device) 102 | model.eval() 103 | 104 | results = {} 105 | for data in tstloader: 106 | _features, _proteins = data[:4], data[4] 107 | output = model(_features) 108 | output = torch.index_select(output, 1, full_term_indicies) 109 | output = output.tolist() 110 | 111 | for i, j in zip(_proteins, output): 112 | results[i] = j 113 | 114 | terms = [sorted_terms[i] for i in full_term_indicies] 115 | 116 | 117 | filepath = CONSTANTS.ROOT_DIR + 'evaluation/raw_predictions/transfew/' 118 | create_directory(filepath) 119 | write_output(results, terms, filepath+'{}.tsv'.format(ontology), cutoff=0.01) 120 | 121 | else: 122 | print("Generating for {} {}".format(ontology, sub_model)) 123 | 124 | ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/{}/'.format(ontology, sub_model) 125 | ckp_pth = ckp_dir + "current_checkpoint.pt" 126 | 127 | model = TFun_submodel(**kwargs) 128 | model.to(device) 129 | 130 | # 
print("Loading model checkpoint @ {}".format(ckp_pth)) 131 | model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True) 132 | model.eval() 133 | 134 | results = {} 135 | for data in tstloader: 136 | _features, _proteins = data[0], data[1] 137 | 138 | output = model(_features).tolist() 139 | for i, j in zip(_proteins, output): 140 | results[i] = j 141 | 142 | terms = [sorted_terms[i] for i in freq_term_indicies] 143 | 144 | filepath = CONSTANTS.ROOT_DIR + 'evaluation/raw_predictions/{}/'.format(sub_model) 145 | create_directory(filepath) 146 | write_output(results, terms, filepath+'{}.tsv'.format(ontology), cutoff=0.01) 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /inference_combined.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Dataset.FastDataset import FastTransFunDataset 3 | from Utils import load_ckp, pickle_load, pickle_save 4 | import CONSTANTS 5 | import math, os, time 6 | import argparse 7 | from models.model_ablation import TFun, TFun_submodel 8 | from Dataset.FastDataset import TestDataset 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 12 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 13 | parser.add_argument("--load_weights", default=False, type=bool, help='Load weights from saved model') 14 | parser.add_argument('--label_features', type=str, default='linear', help='Sub model to train') 15 | args = parser.parse_args() 16 | 17 | torch.manual_seed(args.seed) 18 | args.cuda = not args.no_cuda and torch.cuda.is_available() 19 | 20 | if args.cuda: 21 | device = 'cuda:1' 22 | else: 23 | device = 'cpu' 24 | 25 | # load all test 26 | all_test = pickle_load(CONSTANTS.ROOT_DIR + "test/t3/test_proteins") 27 | 28 | 29 | ontologies = ["cc", "mf", "bp"] 30 | annotation_depths = ["LK", "NK"] 31 | 32 | sub_models = ['full'] 33 | label_features = ['gcn'] 34 | 35 | 36 | 37 | def write_output(results, terms, filepath, cutoff=0.001): 38 | with open(filepath, 'w') as fp: 39 | for prt in results: 40 | assert len(terms) == len(results[prt]) 41 | tmp = list(zip(terms, results[prt])) 42 | tmp.sort(key = lambda x: x[1], reverse=True) 43 | for trm, score in tmp: 44 | if score > cutoff: 45 | fp.write('%s\t%s\t%0.3f\n' % (prt, trm, score)) 46 | 47 | 48 | def get_term_indicies(ontology, submodel="full", label_feature="max"): 49 | 50 | _term_indicies = pickle_load(CONSTANTS.ROOT_DIR + "{}/term_indicies".format(ontology)) 51 | 52 | if ontology == 'bp': 53 | full_term_indicies, mid_term_indicies, freq_term_indicies = _term_indicies[0], _term_indicies[5], _term_indicies[30] 54 | rare_term_indicies_2 = torch.tensor([i for i in full_term_indicies if not i in set(mid_term_indicies)]).to(device) 55 | rare_term_indicies = torch.tensor([i for i in mid_term_indicies if not i in set(freq_term_indicies)]).to(device) 56 | full_term_indicies, freq_term_indicies = torch.tensor(_term_indicies[0]).to(device), torch.tensor(freq_term_indicies).to(device) 57 | else: 58 | full_term_indicies = _term_indicies[0] 59 | freq_term_indicies = _term_indicies[30] 60 | rare_term_indicies = torch.tensor([i for i in full_term_indicies if not i in set(freq_term_indicies)]).to(device) 61 | full_term_indicies = torch.tensor(full_term_indicies).to(device) 62 | freq_term_indicies = torch.tensor(freq_term_indicies).to(device) 63 | rare_term_indicies_2 = 
None 64 | 65 | return full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 66 | 67 | 68 | '''if submodel == 'full' and label_feature not in ['max', 'mean']: 69 | term_indicies = torch.tensor(_term_indicies[0]) 70 | sub_indicies = torch.tensor(_term_indicies[threshold[ontology]]) 71 | else: 72 | term_indicies = torch.tensor(_term_indicies[threshold[ontology]]) 73 | sub_indicies = term_indicies 74 | 75 | 76 | sorted_terms = pickle_load(CONSTANTS.ROOT_DIR+"/{}/sorted_terms".format(ontology)) 77 | 78 | terms = [sorted_terms[i] for i in term_indicies] 79 | 80 | return terms, term_indicies, sub_indicies''' 81 | 82 | 83 | 84 | for annotation_depth in annotation_depths: 85 | for ontology in ontologies: 86 | 87 | data_pth = CONSTANTS.ROOT_DIR + "test/t3/dataset/{}_{}".format(annotation_depth, ontology) 88 | sorted_terms = pickle_load(CONSTANTS.ROOT_DIR+"/{}/sorted_terms".format(ontology)) 89 | 90 | for sub_model in sub_models: 91 | 92 | tst_dataset = TestDataset(data_pth=data_pth, submodel=sub_model) 93 | tstloader = torch.utils.data.DataLoader(tst_dataset, batch_size=500, shuffle=False) 94 | # terms, term_indicies, sub_indicies = get_term_indicies(ontology=ontology, submodel=sub_model) 95 | full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 = get_term_indicies(ontology=ontology, submodel=sub_model) 96 | 97 | 98 | kwargs = { 99 | 'device': device, 100 | 'ont': ontology, 101 | 'full_indicies': full_term_indicies, 102 | 'freq_indicies': freq_term_indicies, 103 | 'rare_indicies': rare_term_indicies, 104 | 'rare_indicies_2': rare_term_indicies_2, 105 | 'sub_model': sub_model, 106 | 'load_weights': True, 107 | 'label_features': "", 108 | 'group': "" 109 | } 110 | 111 | 112 | for label_feature in label_features: 113 | print("Generating for {} {} {} {}".format(annotation_depth, ontology, sub_model, label_feature)) 114 | 115 | kwargs['label_features'] = label_feature 116 | 117 | ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/{}_{}_combined/'.format(ontology, sub_model, label_feature) 118 | ckp_pth = ckp_dir + "current_checkpoint.pt" 119 | model = TFun(**kwargs) 120 | 121 | # load model 122 | if label_feature != 'max' and label_feature != 'mean': 123 | model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True) 124 | 125 | model.to(device) 126 | model.eval() 127 | 128 | results = {} 129 | for data in tstloader: 130 | _features, _proteins = data[:4], data[4] 131 | output, _ = model(_features) 132 | output = torch.index_select(output, 1, full_term_indicies) 133 | output = output.tolist() 134 | 135 | for i, j in zip(_proteins, output): 136 | results[i] = j 137 | 138 | terms = [sorted_terms[i] for i in full_term_indicies] 139 | 140 | filepath = 'evaluation/predictions/transfew/{}_{}_{}_combined_{}.tsv'.format(annotation_depth, ontology, sub_model, label_feature) 141 | write_output(results, terms, filepath, cutoff=0.01) 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /label_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import wandb 6 | import CONSTANTS 7 | import torch_geometric.transforms as T 8 | from Utils import load_ckp, save_ckp 9 | from models.model import LabelEncoder 10 | import time 11 | from torch_geometric.nn import GAE 12 | import hparams 13 | from num2words import num2words 14 | 15 | os.environ["WANDB_API_KEY"] = 
"" 16 | os.environ["WANDB_MODE"] = "online" 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') 20 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 21 | parser.add_argument('--epochs', type=int, default=500, help='Number of epochs to train.') 22 | parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate.') 23 | parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay (L2 loss on parameters).') 24 | parser.add_argument("--ont", default='bp', type=str, help='Ontology under consideration') 25 | parser.add_argument("--load_weights", default=False, type=bool, help='Load weights from saved model') 26 | parser.add_argument("--save_weights", default=False, type=bool, help='Save model weights') 27 | parser.add_argument("--log_output", default=False, type=bool, help='Log output to weights and bias') 28 | parser.add_argument("--features", default='biobert', type=str, help='Which features to use') 29 | parser.add_argument("--gnn", default='GCN', type=str, help='Which GNN network to use') 30 | parser.add_argument("--out_channels", default=1024, type=int, help='Final Hidden Dimension') 31 | 32 | 33 | args = parser.parse_args() 34 | 35 | torch.manual_seed(args.seed) 36 | 37 | 38 | hyps = getattr(hparams, args.ont) 39 | args.weight_decay = hyps['label_ae']['weight_decay'] 40 | args.lr = hyps['label_ae']['lr'] 41 | 42 | args.cuda = not args.no_cuda and torch.cuda.is_available() 43 | device = 'cpu' 44 | if args.cuda: 45 | device = 'cuda:1' 46 | 47 | print("Ontology: {}, Learning rate: {}, Weight Decay: {}, Device: {}"\ 48 | .format(args.ont, args.lr, args.weight_decay, device)) 49 | 50 | 51 | 52 | graph_path = CONSTANTS.ROOT_DIR + '{}/graph.pt'.format(args.ont) 53 | data = torch.load(graph_path) 54 | 55 | if args.features == 'x': 56 | if args.ont == 'cc': 57 | num_features = 2957 58 | elif args.ont == 'mf': 59 | num_features = 7224 60 | elif args.ont == 'bp': 61 | num_features = 21285 62 | elif args.features == 'biobert': 63 | num_features = 768 64 | elif args.features == 'bert': 65 | num_features = 1024 66 | 67 | variational = False 68 | 69 | 70 | transform = T.Compose([ 71 | T.NormalizeFeatures(), 72 | T.ToDevice(device), 73 | T.RandomLinkSplit(num_val=0.2, num_test=0.0, is_undirected=False, 74 | split_labels=True, add_negative_train_samples=True), 75 | ]) 76 | train_data, val_data, _ = transform(data) 77 | 78 | 79 | def check_data_integrity(data, train_data): 80 | # wasn't sure about the Random split and wanted to confirm whether 81 | # the positive is not in the negative. 
82 | 83 | a1, b1 = data.edge_index 84 | t1 = list(zip(a1.tolist(), b1.tolist())) 85 | t1 = sorted(t1, key=lambda element: (element[0], element[1])) 86 | 87 | 88 | a1, b1 = train_data.pos_edge_label_index 89 | t2 = list(zip(a1.tolist(), b1.tolist())) 90 | t2 = sorted(t2, key=lambda element: (element[0], element[1])) 91 | 92 | 93 | a1, b1 = train_data.neg_edge_label_index 94 | t3 = list(zip(a1.tolist(), b1.tolist())) 95 | t3 = sorted(t3, key=lambda element: (element[0], element[1])) 96 | 97 | 98 | assert len(t2) == len(set(t1).intersection(set(t2))) 99 | 100 | assert len(set(t1).intersection(set(t3))) == 0 101 | 102 | 103 | # check_data_integrity(data=data, train_data=train_data) 104 | 105 | 106 | def count_params(model): 107 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 108 | 109 | 110 | def train(data, model, optimizer): 111 | model.train() 112 | optimizer.zero_grad() 113 | z = model.encode(data[args.features], data.edge_index) 114 | loss = model.recon_loss(z, data.pos_edge_label_index) 115 | auc, ap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index) 116 | if variational: 117 | print("Using variational loss") 118 | loss = loss + (1 / data.num_nodes) * model.kl_loss() 119 | loss.backward() 120 | optimizer.step() 121 | return float(loss), auc, ap 122 | 123 | 124 | @torch.no_grad() 125 | def validate(data, model): 126 | model.eval() 127 | z = model.encode(data[args.features], data.edge_index) 128 | loss = model.recon_loss(z, data.pos_edge_label_index) 129 | auc, ap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index) 130 | return float(loss), auc, ap 131 | 132 | 133 | def train_model(start_epoch, min_val_loss, train_data, 134 | val_data, model, optimizer): 135 | 136 | for epoch in range(start_epoch, args.epochs): 137 | print(" ---------- Epoch {} ----------".format(epoch)) 138 | 139 | start = time.time() 140 | 141 | loss, auc, ap = train(data=train_data, model=model, optimizer=optimizer) 142 | val_loss, val_auc, val_ap = validate(val_data, model=model) 143 | 144 | epoch_time = time.time() - start 145 | 146 | print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}, Val loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}, time: {epoch_time:.4f}') 147 | 148 | 149 | if args.log_output: 150 | wandb.log({"Epoch": epoch, 151 | "train_loss": loss, 152 | "train_auc": auc, 153 | "train_ap": ap, 154 | "val_loss": val_loss, 155 | "val_auc": val_auc, 156 | "val_ap": val_ap, 157 | "time": epoch_time 158 | }) 159 | 160 | checkpoint = { 161 | 'epoch': epoch, 162 | 'valid_loss_min': val_loss, 163 | 'state_dict': model.state_dict(), 164 | 'optimizer': optimizer.state_dict(), 165 | 'lr_scheduler': None 166 | } 167 | 168 | if args.save_weights: 169 | # save checkpoint 170 | save_ckp(checkpoint, False, ckp_dir) 171 | 172 | if val_loss <= min_val_loss: 173 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'. 
\ 174 | format(min_val_loss, val_loss)) 175 | 176 | # save checkpoint as best model 177 | save_ckp(checkpoint, True, ckp_dir) 178 | min_val_loss = val_loss 179 | 180 | 181 | model = GAE(LabelEncoder(num_features, args.out_channels, features=args.features, gnn=args.gnn)) 182 | model = model.to(device) 183 | 184 | print(num2words(count_params(model))) 185 | 186 | 187 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 188 | 189 | ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/label/{}'.format(args.ont, args.gnn) 190 | ckp_pth = ckp_dir + "current_checkpoint.pt" 191 | 192 | if args.load_weights and os.path.exists(ckp_pth): 193 | print("Loading model checkpoint @ {}".format(ckp_pth)) 194 | model, optimizer, lr_scheduler, current_epoch, min_val_loss = load_ckp(checkpoint_dir=ckp_dir, model=model, optimizer=optimizer, lr_scheduler=None, best_model=True) 195 | else: 196 | current_epoch = 0 197 | min_val_loss = np.Inf 198 | 199 | 200 | config = { 201 | "learning_rate": args.lr, 202 | "epochs": current_epoch, # saved previous epoch 203 | "weight_decay": args.weight_decay 204 | } 205 | 206 | if args.log_output: 207 | wandb.init(project="LabelEmbedding", entity='frimpz', config=config, name="{}_label_{}_{}_{}".format(args.ont, args.lr, args.weight_decay, args.gnn)) 208 | 209 | train_model(start_epoch=current_epoch, min_val_loss=min_val_loss, 210 | train_data=train_data, val_data=val_data, 211 | model=model, optimizer=optimizer) -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | biobert = 768 2 | bert = 1024 3 | 4 | ##### cellular component ##### 5 | cc_out_shape = 2957 6 | cc_out_freq_shape = 873 #2957# 873 7 | cc_out_rare_shape = 2083 8 | 9 | esm_layers_freq_cc = [ # layer 0 10 | (5120, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 11 | (1024, 256, True, 'gelu', 'batchnorm', 0.2, (1, )), 12 | (256, cc_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 13 | ] 14 | esm_layers_rare_cc = [ # layer 0 15 | (5120, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 16 | (1024, 256, True, 'gelu', 'batchnorm', 0.2, (1, )), 17 | (256, cc_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 18 | ] 19 | 20 | esm_layers_cc = [ # layer 0 21 | (5120, 2048, True, 'gelu', 'batchnorm', 0.2, (0, )), 22 | (2048, 1024, True, 'gelu', 'batchnorm', 0.2, (1, )), 23 | (1024, cc_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 24 | ] 25 | 26 | msa_layers_cc = [ # 0 input 27 | (768, 1024, True, 'gelu', 'batchnorm', 'none', (0, )), 28 | (1024, 930, True, 'gelu', 'batchnorm', 'none', (1, )), 29 | (930, cc_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 30 | ] 31 | 32 | interpro_layers_cc = [ 33 | (24714, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 34 | (1024, 930, True, 'gelu', 'batchnorm', 0.2, (1, )), 35 | (930, cc_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 36 | ] 37 | 38 | 39 | 40 | ##### molecular function ##### 41 | mf_out_shape = 7224 42 | mf_out_freq_shape = 1183# 7224 #1183 43 | mf_out_rare_shape = 6040 44 | 45 | interpro_layers_mf = [ 46 | (25523, 1400, True, 'gelu', 'batchnorm', 0.2, (0, )), 47 | (1400, 1200, True, 'gelu', 'batchnorm', 0.2, (1, )), 48 | (1200, mf_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 49 | ] 50 | 51 | msa_layers_mf = [ # 0 input 52 | (768, 1400, True, 'gelu', 'batchnorm', 'none', (0, )), 53 | (1400, 1200, True, 'gelu', 'batchnorm', 'none', (1, )), 54 | (1200, 
mf_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 55 | ] 56 | 57 | esm_layers_mf = [ # layer 0 58 | (5120, 2048, True, 'gelu', 'batchnorm', 0.2, (0, )), 59 | (2048, 1200, True, 'gelu', 'batchnorm', 0.2, (1, )), 60 | (1200, mf_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 61 | ] 62 | 63 | esm_layers_freq_mf = [ # layer 0 64 | (5120, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 65 | (1024, 256, True, 'gelu', 'batchnorm', 0.2, (1, )), 66 | (256, mf_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 67 | ] 68 | 69 | esm_layers_rare_mf = [ # layer 0 70 | (5120, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 71 | (1024, 256, True, 'gelu', 'batchnorm', 0.2, (1, )), 72 | (256, mf_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 73 | ] 74 | 75 | 76 | 77 | ##### biological process ##### 78 | bp_out_shape = 21285 79 | bp_out_freq_shape = 6415 80 | bp_out_rare_shape = 6977 81 | bp_out_rare_2_shape = 7892 82 | 83 | interpro_layers_bp = [ 84 | (24846, 2048, True, 'gelu', 'batchnorm', 0.2, (0, )), 85 | (2048, 1200, True, 'gelu', 'batchnorm', 0.2, (1, )), 86 | (1200, bp_out_freq_shape, True, 'gelu', 'none', 'none', (2, )), 87 | ] 88 | 89 | esm_layers_bp = [ # layer 0 90 | (5120, 2048, True, 'gelu', 'batchnorm', 0.1, (0, )), 91 | (2048, 1200, True, 'gelu', 'batchnorm', 0.1, (1, )), 92 | (1200, bp_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 93 | ] 94 | 95 | msa_layers_bp = [ 96 | (768, 2048, True, 'gelu', 'batchnorm', 'none', (0, )), 97 | (2048, 1200, True, 'gelu', 'batchnorm', 'none', (1, )), 98 | (1200, bp_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 99 | ] 100 | 101 | esm_layers_rare_bp = [ # layer 0 102 | (5120, 1024, True, 'gelu', 'batchnorm', 0.1, (0, )), 103 | (1024, 256, True, 'gelu', 'batchnorm', 0.1, (1, )), 104 | (256, bp_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 105 | ] 106 | 107 | esm_layers_rare_2_bp = [ # layer 0 108 | (5120, 1024, True, 'gelu', 'batchnorm', 0.1, (0, )), 109 | (1024, 256, True, 'gelu', 'batchnorm', 0.1, (1, )), 110 | (256, bp_out_rare_2_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 111 | ] 112 | 113 | esm_layers_freq_bp = [ # layer 0 114 | (5120, 1024, True, 'gelu', 'batchnorm', 0.1, (0, )), 115 | (1024, 256, True, 'gelu', 'batchnorm', 0.1, (1, )), 116 | (256, bp_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 117 | ] 118 | 119 | 120 | 121 | 122 | ####### For ablation studies 123 | 124 | ##### cellular component ##### 125 | interpro_layers_freq_cc = [ 126 | (24714, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 127 | (1024, 930, True, 'gelu', 'batchnorm', 0.2, (1, )), 128 | (930, cc_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 129 | ] 130 | interpro_layers_rare_cc = [ 131 | (24714, 1024, True, 'gelu', 'batchnorm', 0.2, (0, )), 132 | (1024, 930, True, 'gelu', 'batchnorm', 0.2, (1, )), 133 | (930, cc_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 134 | ] 135 | 136 | msa_layers_freq_cc = [ # 0 input 137 | (768, 1024, True, 'gelu', 'batchnorm', 'none', (0, )), 138 | (1024, 930, True, 'gelu', 'batchnorm', 'none', (1, )), 139 | (930, cc_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 140 | ] 141 | 142 | msa_layers_rare_cc = [ # 0 input 143 | (768, 1024, True, 'gelu', 'batchnorm', 'none', (0, )), 144 | (1024, 930, True, 'gelu', 'batchnorm', 'none', (1, )), 145 | (930, cc_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 146 | ] 147 | 148 | 149 | ##### molecular function ##### 150 | msa_layers_freq_mf = [ # layer 0 151 | (768, 1400, True, 'gelu', 
'batchnorm', 'none', (0, )), 152 | (1400, 1200, True, 'gelu', 'batchnorm', 'none', (1, )), 153 | (1200, mf_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 154 | ] 155 | msa_layers_rare_mf = [ # layer 0 156 | (768, 1400, True, 'gelu', 'batchnorm', 'none', (0, )), 157 | (1400, 1200, True, 'gelu', 'batchnorm', 'none', (1, )), 158 | (1200, mf_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 159 | ] 160 | 161 | interpro_layers_freq_mf = [ # layer 0 162 | (25523, 1400, True, 'gelu', 'batchnorm', 0.2, (0, )), 163 | (1400, 1200, True, 'gelu', 'batchnorm', 0.2, (1, )), 164 | (1200, mf_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 165 | ] 166 | 167 | interpro_layers_rare_mf = [ # layer 0 168 | (25523, 1400, True, 'gelu', 'batchnorm', 0.2, (0, )), 169 | (1400, 1200, True, 'gelu', 'batchnorm', 0.2, (1, )), 170 | (1200, mf_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )), 171 | ] 172 | 173 | 174 | 175 | ##### biological process ##### 176 | interpro_layers_freq_bp = [ 177 | (24846, bp_out_freq_shape + 512, True, 'gelu', 'batchnorm', 0.2, (0, )), 178 | (bp_out_freq_shape + 512, bp_out_freq_shape + 128, True, 'gelu', 'batchnorm', 0.2, (1, )), 179 | (bp_out_freq_shape + 128, bp_out_freq_shape, True, 'gelu', 'none', 'none', (2, )), 180 | ] 181 | interpro_layers_rare_bp = [ 182 | (24846, bp_out_rare_shape + 512, True, 'gelu', 'batchnorm', 0.2, (0, )), 183 | (bp_out_rare_shape + 512, bp_out_rare_shape + 128, True, 'gelu', 'batchnorm', 0.2, (1, )), 184 | (bp_out_rare_shape + 128, bp_out_rare_shape, True, 'gelu', 'none', 'none', (2, )), 185 | ] 186 | interpro_layers_rare_2_bp = [ 187 | (24846, bp_out_rare_2_shape + 512, True, 'gelu', 'batchnorm', 0.2, (0, )), 188 | (bp_out_rare_2_shape + 512, bp_out_rare_2_shape + 128, True, 'gelu', 'batchnorm', 0.2, (1, )), 189 | (bp_out_rare_2_shape + 128, bp_out_rare_2_shape, True, 'gelu', 'none', 'none', (2, )), 190 | ] 191 | 192 | msa_layers_freq_bp = [ 193 | (768, 3200, True, 'gelu', 'batchnorm', 'none', (0, )), 194 | (3200, 3104, True, 'gelu', 'batchnorm', 'none', (1, )), 195 | (3104, bp_out_freq_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 196 | ] 197 | msa_layers_rare_bp = [ 198 | (768, 3200, True, 'gelu', 'batchnorm', 'none', (0, )), 199 | (3200, 3104, True, 'gelu', 'batchnorm', 'none', (1, )), 200 | (3104, bp_out_rare_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 201 | ] 202 | msa_layers_rare_2_bp = [ 203 | (768, 3200, True, 'gelu', 'batchnorm', 'none', (0, )), 204 | (3200, 3104, True, 'gelu', 'batchnorm', 'none', (1, )), 205 | (3104, bp_out_rare_2_shape, True, 'gelu', 'batchnorm', 'none', (2, )) 206 | ] -------------------------------------------------------------------------------- /models/model_struct.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from torch import nn 3 | from torch.nn import Linear, Sigmoid 4 | from torch.nn.modules.module import Module 5 | from torchsummary import summary 6 | from Utils import load_ckp, pickle_load 7 | import torch.nn.functional as F 8 | from models import net_utils 9 | import models.config as config 10 | import math 11 | import CONSTANTS 12 | from torch.nn import LayerNorm, BatchNorm1d 13 | from collections import OrderedDict 14 | from torch_geometric.nn import GCNConv, TransformerConv, GATv2Conv 15 | from torch_geometric.nn import GAE 16 | 17 | 18 | def getactfn(actfn = 'relu'): 19 | if actfn == 'relu': 20 | return nn.ReLU() 21 | elif actfn == 'gelu': 22 | return nn.GELU() 23 | elif actfn == 'tanh': 24 | return 
nn.Tanh() 25 | 26 | def getnorm(norm = 'layernorm', norm_shape =0): 27 | if norm == 'layernorm': 28 | return LayerNorm(norm_shape) 29 | elif norm == 'batchnorm': 30 | return BatchNorm1d(norm_shape) 31 | 32 | 33 | class TFun(nn.Module): 34 | def __init__(self, **kwargs): 35 | super(TFun, self).__init__() 36 | 37 | self.ont = kwargs['ont'] 38 | self.device = kwargs['device'] 39 | self.indicies = kwargs['indicies'] 40 | self.load_weights = True #kwargs['load_weights'] 41 | self.out_shape = getattr(config, "{}_out_shape".format(self.ont)) 42 | self.label_features = kwargs['label_features'] 43 | self.dropout = nn.Dropout(0.1) 44 | 45 | self.layer1 = nn.Linear(1024, 1200) 46 | self.layer2 = nn.Linear(1200, 1500) 47 | self.layer3 = nn.Linear(1500, 2000) 48 | self.final = nn.Linear(2000, self.out_shape) 49 | 50 | 51 | self.gelu = nn.GELU() 52 | self.sigmoid = Sigmoid() 53 | 54 | 55 | def forward(self, x): 56 | x = self.layer1(x.to(self.device)) 57 | x = self.gelu(x) 58 | x = self.layer2(x) 59 | x = self.gelu(x) 60 | x = self.layer3(x) 61 | x = self.gelu(x) 62 | x = self.final(x) 63 | x = torch.index_select(x, 1, self.indicies) 64 | x = self.sigmoid(x) 65 | return x 66 | 67 | 68 | class TFunSequence(nn.Module): 69 | def __init__(self, **kwargs): 70 | super(TFunSequence, self).__init__() 71 | 72 | self.ont = kwargs['ont'] 73 | self.device = kwargs['device'] 74 | self.indicies = kwargs['indicies'] 75 | self.load_weights = True #kwargs['load_weights'] 76 | self.out_shape = getattr(config, "{}_out_shape".format(self.ont)) 77 | self.label_features = kwargs['label_features'] 78 | self.dropout = nn.Dropout(0.1) 79 | 80 | 81 | # Transformer encoder layer 82 | self.transformer_encoder = nn.TransformerEncoder( 83 | nn.TransformerEncoderLayer(d_model=512, nhead=4), 84 | num_layers=6 85 | ) 86 | 87 | # Token embeddings 88 | # self.embedding = nn.Embedding(100, 512) 89 | 90 | # Positional encoding 91 | self.positional_encoding = nn.Embedding(1024, 512) # Assuming sequences of max length 1024 92 | 93 | # Fully connected layer for classification 94 | self.fc = nn.Linear(512, self.out_shape) 95 | 96 | self.sigmoid = Sigmoid() 97 | 98 | 99 | def forward(self, x): 100 | 101 | 102 | # Add positional encoding to the input 103 | seq_length = x.size(1) 104 | positions = torch.arange(0, seq_length).unsqueeze(0).expand(x.size(0), -1).to(self.device) 105 | positions = self.positional_encoding(positions) 106 | #print(self.positional_encoding(positions).shape) 107 | x = x + positions 108 | 109 | 110 | 111 | 112 | 113 | 114 | # Transformer encoder 115 | x = self.transformer_encoder(x) 116 | 117 | # Global average pooling 118 | x = F.adaptive_avg_pool1d(x.permute(0, 2, 1), (1,)).view(x.size(0), -1) 119 | 120 | # Classification layer 121 | x = self.fc(x) 122 | x = self.sigmoid(x) 123 | 124 | return x -------------------------------------------------------------------------------- /models/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool 4 | 5 | 6 | class FC(nn.Module): 7 | def __init__(self, in_features, out_features, act_fun='relu', bnorm=True): 8 | super(FC, self).__init__() 9 | bias = False if bnorm else True 10 | self.fc = nn.Linear(in_features, out_features, bias=bias) 11 | self.act_fun = get_act_fun(act_fun) 12 | self.bn = nn.BatchNorm1d(out_features, momentum=0.1) if bnorm else None 13 | 14 | def forward(self, x): 15 | x
= self.fc(x) 16 | if self.bn is not None: 17 | x = self.bn(x) 18 | if self.act_fun is not None: 19 | x = self.act_fun(x) 20 | return x 21 | 22 | 23 | class BNormActFun(nn.Module): 24 | def __init__(self, in_features, act_fun='relu', momentum=0.1, bnorm=True): 25 | super(BNormActFun, self).__init__() 26 | 27 | self.act_fun = get_act_fun(act_fun) 28 | self.bn = nn.BatchNorm1d(in_features, momentum=momentum) if bnorm else None 29 | 30 | def forward(self, x): 31 | if self.bn is not None: 32 | x = self.bn(x) 33 | if self.act_fun is not None: 34 | x = self.act_fun(x) 35 | return x 36 | 37 | 38 | class MLP(nn.Module): 39 | # [(in_features, out_features, bnorm, act_fun, dropout)] 40 | def __init__(self, layers_data: list): 41 | super().__init__() 42 | torch.manual_seed(12345) 43 | self.layers = nn.ModuleList() 44 | for layer in layers_data: 45 | # input_size output_size bias 46 | self.layers.append(nn.Linear(layer[0], layer[1], not layer[2])) 47 | if layer[2]: 48 | self.layers.append(nn.BatchNorm1d(layer[1], momentum=0.1)) 49 | # activation function 50 | if layer[3] is not None: 51 | self.layers.append(get_act_fun(layer[3])) 52 | # dropout 53 | if layer[4] is not None: 54 | self.layers.append(nn.Dropout(p=layer[4])) 55 | 56 | def forward(self, x): 57 | for layer in self.layers: 58 | x = layer(x) 59 | return x 60 | 61 | 62 | def get_pool(pool_type='max'): 63 | if pool_type == 'mean': 64 | return global_mean_pool 65 | elif pool_type == 'add': 66 | return global_add_pool 67 | elif pool_type == 'max': 68 | return global_max_pool 69 | 70 | 71 | def get_act_fun(act_fun): 72 | if act_fun == 'relu': 73 | return nn.ReLU() 74 | elif act_fun == 'tanh': 75 | return nn.Tanh() 76 | elif act_fun == 'leakyrelu': 77 | return nn.LeakyReLU() 78 | else: 79 | return None 80 | -------------------------------------------------------------------------------- /notes: -------------------------------------------------------------------------------- 1 | 2 | 3 | bp.x -> 2325219 4 | bp.biobert -> 5 | bp.linear -> 3017061 6 | bp.gcn -> 7 | 8 | mf.linear -> 9 | mf.gcn -> 10 | mf.biobert -> 11 | mf.x -> 12 | 13 | cc.linear -> 14 | cc.gcn -> 15 | cc.biobert -> 16 | cc.x -> 17 | 18 | 19 | cc.log -> 20 | 21 | 22 | t1 -> goa_uniprot_all.gaf.212 -> !date-generated: 2022-11-17 13:42 23 | 24 | t2 -> !date-generated: 2023-07-12 11:49 25 | 26 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import torch 3 | from Bio import SeqIO 4 | from Bio.Seq import Seq 5 | from Bio.SeqRecord import SeqRecord 6 | import torch 7 | from Utils import is_file, load_ckp, pickle_load 8 | import os 9 | import argparse 10 | from models.model import TFun 11 | from Dataset.MyDataset import PredictDataset 12 | 13 | def write_output(results, terms, filepath, cutoff=0.001): 14 | with open(filepath, 'w') as fp: 15 | for prt in results: 16 | assert len(terms) == len(results[prt]) 17 | tmp = list(zip(terms, results[prt])) 18 | tmp.sort(key = lambda x: x[1], reverse=True) 19 | for trm, score in tmp: 20 | if score > cutoff: 21 | fp.write('%s\t%s\t%0.3f\n' % (prt, trm, score)) 22 | 23 | 24 | def generate_bulk_embedding(wd, fasta_path): 25 | # (tag, ESM model, fasta file, output subdir, repr layer, tokens per batch) 26 | model = ("esm_2", "esm2_t48_15B_UR50D", fasta_path, "esm2_t48", 48, 100) 27 | CMD = "python -u {} {} {} {}/{} --repr_layers {} --include mean per_tok " \ 28 | "--toks_per_batch {}".format("external/extract.py",
model[1], model[2], \ 29 | wd, model[3], model[4], model[5]) 30 | print(CMD) 31 | subprocess.call(CMD, shell=True, cwd="./") 32 | 33 | 34 | def fasta_to_dictionary(fasta_file): 35 | data = {} 36 | for seq_record in SeqIO.parse(fasta_file, "fasta"): 37 | data[seq_record.id] = (seq_record.id, seq_record.name, seq_record.description, seq_record.seq) 38 | return data 39 | 40 | 41 | def merge_pts(keys, fasta, wd): 42 | for pos, protein in enumerate(keys): 43 | fasta_dic = fasta_to_dictionary(fasta) 44 | 45 | tmp = [] 46 | for level in range(keys[protein]): 47 | os_path = "{}/esm2_t48/{}_{}.pt".format(wd, protein, level) 48 | tmp.append(torch.load(os_path)) 49 | 50 | data = {'representations': {}, 'mean_representations': {}} 51 | for index in tmp: 52 | # print(index['mean_representations'][rep].shape, torch.mean(index['representations'][rep], dim=0).shape) 53 | assert torch.equal(index['mean_representations'][48], torch.mean(index['representations'][48], dim=0)) 54 | 55 | if 48 in data['representations']: 56 | data['representations'][48] = torch.cat((data['representations'][48], index['representations'][48])) 57 | else: 58 | data['representations'][48] = index['representations'][48] 59 | 60 | assert len(fasta_dic[protein][3]) == data['representations'][48].shape[0] 61 | 62 | data['mean_representations'][48] = torch.mean(data['representations'][48], dim=0) 63 | 64 | # print("saving {}".format(protein)) 65 | torch.save(data, "{}/esm2_t48/{}.pt".format(wd, protein)) 66 | 67 | 68 | def create_seqrecord(id="", name="", description="", seq=""): 69 | record = SeqRecord(Seq(seq), id=id, name=name, description=description) 70 | return record 71 | 72 | def crop_fasta(record): 73 | splits = [] 74 | keys = {} 75 | main_id = record.id 76 | chnks = len(record.seq) / 1021 77 | remnder = len(record.seq) % 1021 78 | chnks = int(chnks) if remnder == 0 else int(chnks) + 1 79 | keys[main_id] = chnks 80 | for pos in range(chnks): 81 | id = "{}_{}".format(main_id, pos) 82 | seq = str(record.seq[pos * 1021:(pos * 1021) + 1021]) 83 | splits.append(create_seqrecord(id=id, name=id, description="", seq=seq)) 84 | return splits, keys 85 | 86 | 87 | 88 | def generate_embeddings(in_fasta, wd): 89 | keys = {} 90 | sequences = [] 91 | proteins = [] 92 | input_seq_iterator = SeqIO.parse(in_fasta, "fasta") 93 | for record in input_seq_iterator: 94 | proteins.append(record.id) 95 | if len(record.seq) > 1021: 96 | _seqs, _keys = crop_fasta(record) 97 | sequences.extend(_seqs) 98 | keys.update(_keys) 99 | else: 100 | sequences.append(record) 101 | 102 | # any sequence > 1022 103 | cropped_fasta = wd + "/temp.fasta" 104 | if len(keys) > 0: 105 | SeqIO.write(sequences, cropped_fasta, "fasta") 106 | generate_bulk_embedding(wd, cropped_fasta) 107 | merge_pts(keys, fasta_path, wd) 108 | else: 109 | generate_bulk_embedding(wd, in_fasta) 110 | 111 | return proteins 112 | 113 | 114 | def create_dataset(proteins, wd): 115 | data = {'esm2_t48': [], 'protein': [] } 116 | for _protein in proteins: 117 | tmp = torch.load("{}/esm2_t48/{}.pt".format(wd, _protein)) 118 | tmp = tmp['mean_representations'][48].view(1, -1).squeeze(0).cpu() 119 | 120 | data['esm2_t48'].append(tmp) 121 | data['protein'].append(_protein) 122 | 123 | dataset = PredictDataset(data=data) 124 | return dataset 125 | 126 | 127 | def get_term_indicies(ontology, device, data_path): 128 | 129 | _term_indicies = pickle_load(data_path + "/{}/term_indicies".format(ontology)) 130 | 131 | if ontology == 'bp': 132 | full_term_indicies, mid_term_indicies, freq_term_indicies = 
_term_indicies[0], _term_indicies[5], _term_indicies[30] 133 | rare_term_indicies_2 = torch.tensor([i for i in full_term_indicies if not i in set(mid_term_indicies)]).to(device) 134 | rare_term_indicies = torch.tensor([i for i in mid_term_indicies if not i in set(freq_term_indicies)]).to(device) 135 | full_term_indicies, freq_term_indicies = torch.tensor(_term_indicies[0]).to(device), torch.tensor(freq_term_indicies).to(device) 136 | else: 137 | full_term_indicies = _term_indicies[0] 138 | freq_term_indicies = _term_indicies[30] 139 | rare_term_indicies = torch.tensor([i for i in full_term_indicies if not i in set(freq_term_indicies)]).to(device) 140 | full_term_indicies = torch.tensor(full_term_indicies).to(device) 141 | freq_term_indicies = torch.tensor(freq_term_indicies).to(device) 142 | rare_term_indicies_2 = None 143 | 144 | return full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 145 | 146 | 147 | parser = argparse.ArgumentParser(description=" Predict protein functions with TransFew ", epilog=" Thank you !!!") 148 | parser.add_argument('--data-path', type=str, default="", help="Path to data files (models)") 149 | parser.add_argument('--working-dir', type=str, default=".", help="Path to generate temporary files") 150 | parser.add_argument('--ontology', type=str, default="cc", help="Path to data files") 151 | parser.add_argument('--no-cuda', default=False, help='Disables CUDA training.') 152 | parser.add_argument('--batch-size', default=10, help='Batch size.') 153 | parser.add_argument('--fasta-path', default="sequence.fasta", help='Path to Fasta') 154 | parser.add_argument('--output', type=str, default="result.tsv", help="File to save output") 155 | 156 | args = parser.parse_args() 157 | args.cuda = not args.no_cuda and torch.cuda.is_available() 158 | 159 | if args.cuda: 160 | device = 'cuda:1' 161 | else: 162 | device = 'cpu' 163 | 164 | 165 | fasta_path = args.fasta_path 166 | wd = args.working_dir 167 | ontology = args.ontology 168 | data_path = args.data_path 169 | 170 | 171 | proteins = generate_embeddings(in_fasta=fasta_path, wd=wd) 172 | 173 | dataset = create_dataset(proteins, wd=wd) 174 | 175 | loader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=False) 176 | 177 | 178 | sorted_terms = pickle_load(data_path+"/{}/sorted_terms".format(ontology)) 179 | full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 = \ 180 | get_term_indicies(ontology=ontology, device=device, data_path=data_path) 181 | 182 | kwargs = { 183 | 'device': device, 184 | 'ont': ontology, 185 | 'full_indicies': full_term_indicies, 186 | 'freq_indicies': freq_term_indicies, 187 | 'rare_indicies': rare_term_indicies, 188 | 'rare_indicies_2': rare_term_indicies_2, 189 | 'sub_model': 'full', 190 | 'load_weights': True, 191 | 'label_features': 'gcn', 192 | 'group': "" 193 | } 194 | 195 | ckp_dir = data_path + '/{}/models/{}_{}/'.format(ontology, kwargs['sub_model'], kwargs['label_features']) 196 | ckp_pth = ckp_dir + "current_checkpoint.pt" 197 | model = TFun(**kwargs) 198 | 199 | # load model 200 | model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True) 201 | 202 | model.to(device) 203 | model.eval() 204 | 205 | results = {} 206 | for data in loader: 207 | _features, _proteins = data[:1], data[1] 208 | output = model(_features) 209 | output = torch.index_select(output, 1, full_term_indicies) 210 | output = output.tolist() 211 | 212 | for i, j in zip(_proteins, output): 213 | results[i] = j 214 | 215 | terms = 
[sorted_terms[i] for i in full_term_indicies] 216 | write_output(results, terms, "{}/{}".format(wd, args.output), cutoff=0.01) 217 | 218 | -------------------------------------------------------------------------------- /test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | import torch_geometric.datasets as datasets 6 | import torch_geometric.data as data 7 | import torch_geometric.transforms as transforms 8 | import networkx as nx 9 | from torch_geometric.utils.convert import to_networkx 10 | import CONSTANTS 11 | from Classes.Diamond import Diamond 12 | from Classes.Fasta import Fasta 13 | from Classes.Interpro import Interpro, create_indicies 14 | from Classes.STRING import STRING 15 | from Dataset.Dataset import TransFunDataset 16 | from Utils import count_proteins, get_proteins_from_fasta, pickle_load, pickle_save 17 | 18 | 19 | 20 | 21 | 22 | '''infiles = ["test_proteins_1800", "test_proteins_3600", "test_proteins_5400", 23 | "test_proteins_900", "test_proteins_2700", "test_proteins_4500", 24 | "test_proteins_5641", "test_proteins.out"] 25 | 26 | # merge chunks 27 | df = pd.DataFrame() 28 | prefix = "/home/fbqc9/testinetrpro/{}" 29 | for infile in infiles: 30 | print(infile) 31 | data = pd.read_csv(prefix.format(infile), sep="\t", 32 | names=["Protein accession", "Sequence MD5", "Sequence length", "Analysis", 33 | "Signature accession", "Signature description", "Start location", 34 | "Stop location", "Score", "Status", "Date", 35 | "InterPro annotations", "InterPro annotations description ", "GO annotations"]) 36 | print(data[['Protein accession', 'InterPro annotations']].head(3)) 37 | df = pd.concat([df, data], axis=0) 38 | # passed additional quality checks and is very unlikely to be a false match. 
39 | df = df[['Protein accession', 'InterPro annotations']] 40 | df = df[df["InterPro annotations"] != "-"] 41 | 42 | print(df[['Protein accession', 'InterPro annotations']].head(3)) 43 | 44 | df.to_csv(prefix.format("test_proteins.out"), index=False, sep="\t") 45 | 46 | 47 | exit()''' 48 | 49 | 50 | '''p1 = "/home/fbqc9/Workspace/DATA/interpro/test_proteins.out" 51 | p2 = "/home/fbqc9/testinetrpro/test_proteins.out" 52 | data = pd.read_csv(p1, sep="\t", 53 | names=["Protein accession", "Sequence MD5", "Sequence length", "Analysis", 54 | "Signature accession", "Signature description", "Start location", 55 | "Stop location", "Score", "Status", "Date", 56 | "InterPro annotations", "InterPro annotations description ", "GO annotations"]) 57 | 58 | print(len(set(data["Protein accession"].to_list()))) 59 | exit()''' 60 | 61 | 62 | '''all_test_proteins = set() 63 | dta = pickle_load(CONSTANTS.ROOT_DIR + "test/t3/test_proteins") 64 | for i in dta: 65 | all_test_proteins.update(dta[i]) 66 | all_test_proteins = list(all_test_proteins) 67 | print(len(all_test_proteins)) 68 | 69 | 70 | for i in all_test_proteins: 71 | try: 72 | x = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(i)) 73 | print(i, torch.sum(x['interpro_mf'].x)) 74 | break 75 | except FileNotFoundError: 76 | pass 77 | 78 | 79 | 80 | interpro = Interpro(ont='mf') 81 | 82 | mf_interpro_data, mf_interpro_sig, _ = interpro.get_interpro_test() 83 | ct = 0 84 | for i in mf_interpro_data: 85 | print(sum(mf_interpro_data[i])) 86 | for j,k in zip(mf_interpro_sig, mf_interpro_data[i]): 87 | if k == 1: 88 | print(i, j , k) 89 | ct = ct + 1 90 | if ct == 5: 91 | exit() 92 | 93 | 94 | exit()''' 95 | 96 | 97 | '''to_remove = {'C0HM98', 'C0HM97', 'C0HMA1', 'C0HM44'} 98 | all_test_proteins = set() 99 | dta = pickle_load(CONSTANTS.ROOT_DIR + "test/t3/test_proteins") 100 | for i in dta: 101 | all_test_proteins.update(dta[i]) 102 | all_test_proteins = list(all_test_proteins.difference(to_remove)) 103 | print(len(all_test_proteins)) 104 | 105 | kwargs = { 106 | 'split': 'selected', 107 | 'proteins': all_test_proteins 108 | } 109 | train_dataset = TransFunDataset(**kwargs) 110 | 111 | exit() 112 | 113 | 114 | 115 | 116 | x = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format("A0A7I2V2R9")) 117 | 118 | print(x) 119 | 120 | 121 | x = torch.load("/bmlfast/frimpong/shared_function_data/esm_msa1b/{}.pt".format("Q75WF1")) 122 | print(x) 123 | 124 | exit() 125 | 126 | 127 | x = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format("Q75WF1")) 128 | 129 | print(x) 130 | 131 | exit()''' 132 | 133 | to_remove = {'C0HM98', 'C0HM97', 'C0HMA1', 'C0HM44'} 134 | all_test_proteins = set() 135 | dta = pickle_load(CONSTANTS.ROOT_DIR + "test/t3/test_proteins") 136 | for i in dta: 137 | all_test_proteins.update(dta[i]) 138 | dt = list(all_test_proteins.difference(to_remove)) 139 | print(len(dt)) 140 | 141 | 142 | for i in dt: 143 | tmp = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(i)) 144 | mas = tmp['esm_msa1b'].x 145 | 146 | if len(mas.shape) == 3: 147 | tmp['esm_msa1b'].x = torch.mean(mas, dim=1) 148 | print(mas.shape, tmp['esm_msa1b'].x.shape) 149 | torch.save(tmp, CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(i)) 150 | 151 | 152 | exit() 153 | 154 | onts = ['cc', 'bp', 'mf'] 155 | 156 | for ont in onts: 157 | store = {'labels': [], 158 | 'esm2_t48': [], 159 | 'msa_1b': [], 160 | 'interpro': [], 161 | 'diamond': [], 162 | 'string': [], 163 | 'protein': [] 164 | } 165 | 166 | for po, i in enumerate(dt): 167 | print("{}, 
{}, {}".format(ont, i, po)) 168 | tmp = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(i)) 169 | esm = tmp['esm2_t48'].x 170 | msa = torch.mean(tmp['esm_msa1b'].x, dim=0).unsqueeze(0).cpu() 171 | diamond = tmp['diamond_{}'.format(ont)].x 172 | diamond = torch.mean(diamond, dim=0).unsqueeze(0) 173 | interpro = tmp['interpro_{}'.format(ont)].x 174 | string_data = tmp['string_{}'.format(ont)].x 175 | string_data = torch.mean(string_data, dim=0).unsqueeze(0) 176 | 177 | assert esm.shape == torch.Size([1, 5120]) 178 | assert msa.shape == torch.Size([1, 768]) 179 | 180 | store['esm2_t48'].append(esm) 181 | store['msa_1b'].append(msa) 182 | store['diamond'].append(diamond) 183 | store['interpro'].append(interpro) 184 | store['string'].append(string_data) 185 | store['protein'].append(i) 186 | 187 | 188 | pickle_save(store, "com_data/{}.data_test".format(ont)) 189 | 190 | 191 | 192 | exit() 193 | onts = ['cc', 'bp', 'mf'] 194 | 195 | for ont in onts: 196 | 197 | data = pickle_load("com_data/{}.data_test".format(ont)) 198 | 199 | 200 | msa_data = data['msa_1b'] 201 | 202 | for i in msa_data: 203 | if i.device != torch.device("cpu"): 204 | print(i.device) 205 | 206 | 207 | 208 | 209 | 210 | ''' 211 | def generate_test_data(): 212 | all_test = pickle_load(CONSTANTS.ROOT_DIR + "test/test_proteins") 213 | 214 | lk = set() 215 | for i in all_test: 216 | if i.startswith('LK_'): 217 | for j in all_test[i]: 218 | lk.add(j) 219 | 220 | all_test = set([j for i in all_test for j in all_test[i]]) 221 | 222 | 223 | 224 | 225 | 226 | 227 | for i in all_test: 228 | 229 | x = torch.load("/home/fbqc9/esm_msa1b/{}.pt".format(i)) 230 | print(x['representations_12'].shape) 231 | 232 | 233 | 234 | 235 | exit() 236 | 237 | # check if all data is available 238 | 239 | # esm 240 | esm = os.listdir("/bmlfast/frimpong/shared_function_data/esm2_t48/") 241 | esm = set([i.split(".")[0] for i in esm]) 242 | 243 | 244 | msa = os.listdir("/home/fbqc9/esm_msa1b/") 245 | msa = set([i.split(".")[0] for i in msa]) 246 | 247 | a3ms = os.listdir("/bmlfast/frimpong/shared_function_data/a3ms/") 248 | a3ms = set([i.split(".")[0] for i in a3ms]) 249 | 250 | 251 | print(len(all_test.difference(esm)), len(all_test.difference(msa)), \ 252 | len(all_test.difference(a3ms))) 253 | 254 | exit() 255 | 256 | 257 | 258 | generate_test_data() 259 | 260 | exit() 261 | ''' -------------------------------------------------------------------------------- /train_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | import torch_geometric.datasets as datasets 6 | import torch_geometric.data as data 7 | import torch_geometric.transforms as transforms 8 | import networkx as nx 9 | from torch_geometric.utils.convert import to_networkx 10 | import CONSTANTS 11 | from Classes.Diamond import Diamond 12 | from Classes.Fasta import Fasta 13 | from Classes.Interpro import Interpro, create_indicies 14 | from Classes.STRING import STRING 15 | from Utils import count_proteins, get_proteins_from_fasta, pickle_load, pickle_save 16 | import urllib.request 17 | 18 | 19 | 20 | onts = ['cc', 'bp', 'mf'] 21 | sess = ['validation', 'train'] 22 | 23 | 24 | key_val = pickle_load("cc_string_comps").key_val 25 | data = torch.load("cc_string_comp.pt") 26 | data = data.cpu() 27 | 28 | # res = dict((v,k) for k,v in _data.items()) 29 | 30 | # you = {} 31 | # for i, j in enumerate(data): 32 | # you[res[i]] = j 33 | # print(len(you)) 34 | 35 | 36 | for 
ont in onts[0:1]: 37 | print(ont) 38 | 39 | for s in sess: 40 | print(s) 41 | 42 | dt = list(pickle_load(CONSTANTS.ROOT_DIR + "{}/{}_proteins".format(ont, s))) 43 | 44 | indicies = torch.tensor([key_val[i] for i in dt]) 45 | 46 | 47 | bn = torch.index_select(data, 0, indicies) 48 | bn = bn.tolist() 49 | 50 | labels = pickle_load(CONSTANTS.ROOT_DIR + "{}/labels".format(ont)) 51 | 52 | store = {'labels': [], 53 | 'string': bn 54 | } 55 | 56 | for i in dt: 57 | label = torch.tensor(labels[i], dtype=torch.float32).view(1, -1) 58 | 59 | store['labels'].append(label) 60 | 61 | 62 | pickle_save(store, CONSTANTS.ROOT_DIR + "{}/{}_data_2".format(ont, s)) 63 | 64 | 65 | 66 | 67 | exit() 68 | 69 | onts = ['cc', 'bp', 'mf'] 70 | sess = ['train', 'validation'] 71 | 72 | 73 | for ont in onts: 74 | print(ont) 75 | 76 | for s in sess: 77 | 78 | dt = list(pickle_load(CONSTANTS.ROOT_DIR + "{}/{}_proteins".format(ont, s))) 79 | 80 | labels = pickle_load(CONSTANTS.ROOT_DIR + "{}/labels".format(ont)) 81 | 82 | store = {'labels': [], 83 | 'esm2_t48': [], 84 | 'msa_1b': [], 85 | 'interpro': [], 86 | 'diamond': [], 87 | 'string': [], 88 | 'protein': [] 89 | } 90 | 91 | for i in dt: 92 | print("{}, {}, {}".format(ont, s, i)) 93 | tmp = torch.load(CONSTANTS.ROOT_DIR + "data/processed/{}.pt".format(i)) 94 | esm = tmp['esm2_t48'].x.squeeze(0) 95 | msa = torch.mean(tmp['esm_msa1b'].x, dim=0).detach()#.unsqueeze(0) 96 | diamond = tmp['diamond_{}'.format(ont)].x 97 | diamond = torch.mean(diamond, dim=0)#.unsqueeze(0) 98 | interpro = tmp['interpro_{}'.format(ont)].x.squeeze(0) 99 | string_data = tmp['string_{}'.format(ont)].x 100 | string_data = torch.mean(string_data, dim=0)#.unsqueeze(0) 101 | label = torch.tensor(labels[i], dtype=torch.float32).view(1, -1) 102 | 103 | 104 | store['labels'].append(label) 105 | store['esm2_t48'].append(esm) 106 | store['msa_1b'].append(msa) 107 | store['diamond'].append(diamond) 108 | store['interpro'].append(interpro) 109 | store['string'].append(string_data) 110 | store['protein'].append(i) 111 | 112 | 113 | pickle_save(store, CONSTANTS.ROOT_DIR + "{}/{}_data".format(ont, s)) 114 | 115 | 116 | 117 | exit() 118 | # device = 'cuda:1' 119 | train_data = pickle_load("com_data/{}.data_{}".format('cc', 'train')) 120 | print(train_data['esm2_t48'].shape) 121 | print(train_data['msa_1b'].shape) 122 | print(train_data['diamond'].shape) 123 | print(train_data['interpro'].shape) 124 | print(train_data['string'].shape) 125 | print(train_data['labels'].shape) 126 | 127 | 128 | # # labels = torch.cat(labels, dim=0).to(device) 129 | # # labels = torch.index_select(labels, 1, term_indicies) 130 | 131 | # msa_features = train_data['msa_1b'] 132 | 133 | # for i in msa_features: 134 | # print(i) 135 | # exit() 136 | 137 | # print(type(msa_features)) 138 | # exit() 139 | # msa_features = torch.cat(msa_features, dim=0).to(device) 140 | 141 | 142 | 143 | # esm_features = train_data['esm2_t48'] 144 | # esm_features = torch.cat(esm_features, dim=0).to(device) 145 | 146 | 147 | 148 | 149 | # data = pickle_load("com_data/{}.data_{}".format(self.ont, self.session)) 150 | 151 | # labels = data['labels'] 152 | # labels = torch.cat(labels, dim=0).to(device) 153 | # labels = torch.index_select(labels, 1, term_indicies) 154 | 155 | # esm_features = train_data[args.submodel] 156 | # esm_features = torch.cat(esm_features, dim=0).to(device) 157 | 158 | 159 | # exit() 160 | 161 | # print(esm_features.shape, msa_features.shape) 162 | # print(esm_features.dtype, msa_features.dtype) 163 | # print(esm_features.device, 
msa_features.device) 164 | # print(esm_features[0].dtype, msa_features[0].dtype) 165 | -------------------------------------------------------------------------------- /transfew.yaml: -------------------------------------------------------------------------------- 1 | name: transfew 2 | channels: 3 | - pyg 4 | - pytorch 5 | - bioconda 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - diamond=0.9.14 11 | - hhsuite=3.3.0 12 | - mmseqs2=13.45111 13 | - networkx=3.1 14 | - numpy=1.23.5 15 | - pillow=9.4.0 16 | - pip=23.2.1 17 | - pyg=2.3.1 18 | - python=3.10.0 19 | - pytorch=2.0.1 20 | - scikit-learn=1.3.0 21 | - scipy=1.10.1 22 | - setuptools=68.0.0 23 | - torchaudio=2.0.2 24 | - torchmetrics=1.2.0 25 | - torchtriton=2.0.0 26 | - torchvision=0.15.2 27 | - tqdm=4.65.0 28 | - pip: 29 | - biopandas==0.5.1.dev0 30 | - biopython==1.81 31 | - fair-esm==2.0.0 32 | - h5py==3.10.0 33 | - huggingface-hub==0.16.4 34 | - keras==2.15.0 35 | - keras-preprocessing==1.1.2 36 | - matplotlib==3.8.0 37 | - num2words==0.5.13 38 | - obonet==1.0.0 39 | - pandas==1.5.3 40 | - seaborn==0.12.2 41 | - tensorflow==2.15.0.post1 42 | - torch-cluster==1.6.1 43 | - torch-scatter==2.1.1 44 | - torch-sparse==0.6.17 45 | - torch-spline-conv==1.2.2 46 | - torchsummary==1.5.1 47 | - torchviz==0.0.2 48 | - wandb==0.15.12 49 | - xlsxwriter==3.1.5 --------------------------------------------------------------------------------
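
A note on the two layer-configuration formats above: models/net_utils.MLP consumes 5-element tuples of the form (in_features, out_features, batchnorm, act_fun, dropout), while the lists in models/config.py carry 7 elements, adding what appears to be a normalization choice and a trailing layer-index tuple that are handled by models/model.py (not part of this excerpt). Below is a minimal usage sketch of the 5-tuple builder; the layer sizes are illustrative only, and 'relu' is used because net_utils.get_act_fun has no 'gelu' branch.

import torch
from models.net_utils import MLP

# (in_features, out_features, batchnorm, act_fun, dropout)
layers = [
    (1280, 512, True, 'relu', 0.2),   # Linear(bias=False) -> BatchNorm1d -> ReLU -> Dropout(0.2)
    (512, 100, True, 'relu', None),   # Linear(bias=False) -> BatchNorm1d -> ReLU
]
mlp = MLP(layers)
mlp.eval()                            # freshly initialised BatchNorm1d has running stats, so eval() works directly
with torch.no_grad():
    out = mlp(torch.randn(4, 1280))   # batch of 4 dummy feature vectors
print(out.shape)                      # torch.Size([4, 100])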
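
crop_fasta in predict.py splits any input sequence longer than 1021 residues into consecutive 1021-residue pieces (presumably to stay within the ESM-2 input limit), and merge_pts later concatenates the per-chunk representations and recomputes the mean embedding. A quick worked example of the chunk count for a hypothetical 2500-residue protein:

length = 2500                                          # hypothetical protein length
chunks = length // 1021 + (1 if length % 1021 else 0)  # same rule as crop_fasta
print(chunks)                                          # 3 -> pieces of 1021, 1021 and 458 residues, saved as <accession>_0, _1, _2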
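
Finally, a hedged sketch of invoking predict.py end to end, mirroring the argparse options defined above. The data directory is a placeholder: it must contain the per-ontology term_indicies, sorted_terms and models/ checkpoints that the script loads. Predictions are written into the working directory as tab-separated (protein, GO term, score) rows.

import subprocess

CMD = ("python predict.py "
       "--data-path /path/to/transfew_data "   # placeholder: holds <ont>/term_indicies, <ont>/sorted_terms, <ont>/models/
       "--working-dir ./wd "                   # temporary fasta chunks and esm2_t48/ embeddings are written here
       "--ontology cc "                        # one of cc, mf, bp
       "--fasta-path sequence.fasta "
       "--output result.tsv")
subprocess.call(CMD, shell=True)               # results land in ./wd/result.tsv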