├── src ├── __init__.py └── space │ ├── __init__.py │ ├── models │ ├── __init__.py │ ├── node2vec.py │ └── fedcoder.py │ └── tools │ ├── __init__.py │ ├── convert_h5_fmt.py │ ├── taxonomy.py │ ├── process_orthologs.py │ └── data.py ├── figures └── space_overview.png ├── .gitignore ├── setup.py ├── LICENSE ├── scripts ├── download.sh ├── umap_species.py ├── node2vec.py ├── align_seeds.py ├── distances │ ├── kegg_stats.py │ ├── deeploc_distance.py │ ├── kegg_single.py │ └── og_sampling.py ├── align_non_seeds.py ├── pr_curves.py ├── prott5_emb.py ├── func_pred.py ├── subloc.py └── add_singleton.py ├── README.md └── reproduce.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/space/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/space/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/space/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/space_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deweihu96/SPACE/HEAD/figures/space_overview.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/space.egg-info 2 | data/ 3 | # macos related 4 | __MACOSX/ 5 | .DS_Store 6 | # local python cache 7 | __pycache__/ 8 | # local jupyter notebook cache 9 | .ipynb_checkpoints/ 10 | local/ 11 | results/ 12 | *.ipynb 13 | logs/ 14 | *.h5 15 | data.tgz 16 | todo.md 17 | .gitignore 18 | README_.md 19 | temp.sh 20 | temp/ 21 | !figures/space_overview.png 22 | figures/ 23 | *.log 24 | temp.sh 25 | scripts/download.sh 26 | data.tar 27 | e6.og2seqs_and_species.tsv 28 | make_your_alignment.md 29 | scripts/cross_training.py 30 | cross_training_species.txt 31 | kegg_benchmarking.CONN_maps_in.v12.tsv 32 | kegg_single.py 33 | kegg_species.txt 34 | resub_/ 35 | !scripts/distances/kegg_single.py -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='space', 5 | version='1.0', 6 | packages=find_packages(where='src'), 7 | package_dir={'': 'src'}, 8 | install_requires=[ 9 | 'gensim==4.3.2', 10 | 'h5py==3.12.1', 11 | 'loguru==0.7.3', 12 | 'matplotlib==3.10.0', 13 | 'numba==0.60.0', 14 | 'numpy==1.26.4', 15 | 'pandas==2.2.3', 16 | 'pecanpy==2.0.9', 17 | 'scikit_learn==1.6.0', 18 | 'tqdm==4.67.1', 19 | 'transformers==4.47.1', 20 | "scipy==1.11.4", 21 | "tensorboard==2.18.0", 22 | "umap-learn==0.5.7", 23 | "seaborn==0.13.2", 24 | "cafaeval==1.2.1", 25 | "torch==2.7.0", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Dewei Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 
a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | # download the preprocessed data 2 | curl https://erda.ku.dk/archives/c2a0ba424cf75184c39a3cd37e4fe1a6/space-2024-10-28/extra_data/data/data.tgz -o data.tgz 3 | mkdir -p data 4 | tar -xzvf data.tgz -C data 5 | 6 | # download the STRING networks and sequences 7 | mkdir -p data/networks 8 | SPECIES="data/euks.txt" 9 | for species in $(cat $SPECIES); do 10 | wget https://stringdb-downloads.org/download/protein.links.v12.0/$species.protein.links.v12.0.txt.gz -O data/networks/$species.protein.links.v12.0.txt.gz -q 11 | wget https://stringdb-downloads.org/download/protein.sequences.v12.0/$species.protein.sequences.v12.0.fa.gz -O data/networks/$species.protein.sequences.v12.0.fa.gz -q 12 | done 13 | 14 | # download the subcellular localization data from DeepLoc2.0 15 | mkdir -p data/benchmarks/deeploc20 16 | wget https://services.healthtech.dtu.dk/services/DeepLoc-2.0/data/Swissprot_Train_Validation_dataset.csv -O data/benchmarks/deeploc20/Swissprot_Train_Validation_dataset.csv -q 17 | wget https://services.healthtech.dtu.dk/services/DeepLoc-2.0/data/hpa_testset.csv -O data/benchmarks/deeploc20/hpa_testset.csv -q 18 | 19 | # the protein function prediction benchmark has to be downloaded manually from the following link, 20 | # according to the NetGO paper (https://doi.org/10.1093/nar/gkz388): 21 | # https://drive.google.com/drive/folders/1HLH1aCDxlrVpu1zKvgfdQFEFnbT8gChm -------------------------------------------------------------------------------- /src/space/tools/convert_h5_fmt.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from space.tools.data import H5pyData 4 | import os 5 | from multiprocessing import Pool 6 | 7 | def convert_h5_format(paths, precision=16): 8 | input_file_path, output_file_path = paths 9 | with h5py.File(input_file_path, 'r') as f: 10 | proteins = list(f.keys()) 11 | 12 | proteins = np.array(proteins).astype('U').reshape(-1) 13 | 14 | embedding = [f[protein][:] for protein in proteins] 15 | 16 | if precision == 32: 17 | embedding = np.array(embedding).astype(np.float32) 18 | elif precision == 16: 19 | embedding = np.array(embedding).astype(np.float16) 20 | 21 | H5pyData.write(proteins=proteins, embedding=embedding, 22 | save_path=output_file_path, precision=precision) 23 | taxid = 
input_file_path.split('/')[-1].split('.')[0] 24 | with open(f'logs/fmt/{taxid}.txt', 'w') as f: 25 | f.write(f'Converted {taxid}') 26 | 27 | return None 28 | 29 | 30 | # if __name__ == '__main__': 31 | 32 | input_dir = 'data/aligned_non_seeds' 33 | save_dir = 'data/new_aligned_non_seeds' 34 | 35 | # for f in os.listdir('data/aligned_seeds'): 36 | 37 | # if f.endswith('.h5'): 38 | # input_file_path = os.path.join(input_dir, f) 39 | # output_file_path = os.path.join(save_dir, f) 40 | # convert_h5_format(input_file_path, output_file_path, precision=32) 41 | # print(f'Converted {f}.') 42 | 43 | with Pool(7) as p: 44 | p.map(convert_h5_format, [(f'data/aligned_non_seeds/{f}', f'data/new_aligned_non_seeds/{f}') 45 | for f in os.listdir('data/aligned_non_seeds') if f.endswith('.h5')]) 46 | -------------------------------------------------------------------------------- /src/space/tools/taxonomy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import gzip 4 | 5 | ## for some species, the taxid used in STRING is different from the one used in ncbi 6 | __TAXID_CONVERTER__ = {339724:2059318, 7 | 1745343:2082293, 8 | 1325735:2604345, 9 | 1658172:1955775, 10 | 56484:2754530, 11 | 1266660:1072105, 12 | 1229665:2502994, 13 | 944018:1913371, 14 | 743788:2126942 15 | } 16 | 17 | 18 | class Lineage: 19 | 20 | def __init__(self,node_dmp_zip,group_dir) -> None: 21 | 22 | self.df = pd.read_csv(node_dmp_zip,sep='|',compression='zip',header=None) 23 | self.eggnog_ancestors = {f.split('.')[0] for f in os.listdir(group_dir) } 24 | self.group_dir = group_dir 25 | 26 | def get_lineage(self,taxid): 27 | taxid = int(taxid) 28 | l = [taxid] 29 | while taxid != 1: 30 | taxid = self.df[self.df.iloc[:,0]==taxid].iloc[0,1] 31 | l = [taxid] + l 32 | return l 33 | 34 | def common_ancestor(self,taxid_1,taxid_2): 35 | 36 | 37 | taxid_1 = int(taxid_1) 38 | taxid_2 = int(taxid_2) 39 | 40 | use_taxid_1 = __TAXID_CONVERTER__.get(taxid_1,taxid_1) 41 | use_taxid_2 = __TAXID_CONVERTER__.get(taxid_2,taxid_2) 42 | 43 | l_1 = self.get_lineage(use_taxid_1) 44 | l_2 = self.get_lineage(use_taxid_2) 45 | for idx,taxid in enumerate(l_1): 46 | if taxid in l_2: 47 | common_ancestor = taxid 48 | else: 49 | break 50 | 51 | ## make sure eggNOG has the common ancestor, and the orthologs are not empty 52 | while True: 53 | if self.check_ortholog_group(taxid_1,taxid_2,common_ancestor): 54 | break 55 | idx -= 1 56 | common_ancestor = l_1[idx-1] 57 | 58 | return str(taxid_1),str(taxid_2),int(common_ancestor) 59 | 60 | 61 | def check_ortholog_group(self,taxid_1,taxid_2,ancestor): 62 | 63 | group_file = f'{self.group_dir}/{ancestor}.tsv.gz' 64 | 65 | if not os.path.exists(group_file): 66 | return False 67 | 68 | with gzip.open(group_file,'rt') as f: 69 | 70 | for line in f: 71 | line = line.strip().split('\t') 72 | species_list = line[-2].split(',') 73 | 74 | if str(taxid_1) in species_list and str(taxid_2) in species_list: 75 | return True 76 | 77 | return False 78 | -------------------------------------------------------------------------------- /scripts/umap_species.py: -------------------------------------------------------------------------------- 1 | import umap 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | from space.tools.data import H5pyData 6 | from tqdm import tqdm 7 | import os 8 | import argparse 9 | 10 | argparser = argparse.ArgumentParser(description='UMAP visualization of species embeddings') 11 | 
argparser.add_argument('--save_dir', type=str, default='./results', 12 | help='Directory to save the results') 13 | argparser.add_argument('--embedding_dir', type=str, default='data/aligned', 14 | help='Directory containing the species embeddings') 15 | args = argparser.parse_args() 16 | 17 | 18 | 19 | if os.path.exists(args.save_dir) == False: 20 | os.makedirs(args.save_dir) 21 | 22 | species = [9606,4932,3702,44689,10116] 23 | 24 | 25 | species_names={4932:'Saccharomyces cerevisiae',3702:'Arabidopsis thaliana', 26 | 9606:'Homo sapiens',44689:'Dictyostelium discoideum',10116:'Rattus norvegicus'} 27 | colors = ["#e60049","#0bb4ff","#50e991","#9b19f5","#ffa300"] 28 | embedding_dir = args.embedding_dir 29 | 30 | if __name__ == '__main__': 31 | 32 | emb = list() 33 | labels = list() 34 | num_points = list() 35 | 36 | for s in tqdm(species): 37 | _, e = H5pyData.read(f'{embedding_dir}/{s}.h5',16) 38 | emb.append(e) 39 | labels.append([species_names[s]] * e.shape[0]) 40 | num_points.append(e.shape[0]) 41 | 42 | umap_emb = umap.UMAP(n_neighbors=100, min_dist=1, metric='cosine').fit_transform(np.concatenate(emb)) 43 | umap_df = pd.DataFrame(umap_emb, columns=['UMAP1', 'UMAP2']) 44 | umap_df['species'] = np.concatenate(labels) 45 | 46 | plt.figure(figsize=(20, 4)) 47 | 48 | for i, s in enumerate(species): 49 | other_species = [species_names[k] for k in species if k != s] 50 | plt.subplot(1, 5, i+1) 51 | 52 | others = umap_df[umap_df['species'].isin(other_species)] 53 | plt.scatter(others['UMAP1'], others['UMAP2'], color='#cccccc', s=3) 54 | 55 | this_species = umap_df[umap_df['species'] == species_names[s]] 56 | plt.scatter(this_species['UMAP1'], this_species['UMAP2'], s=3, color=colors[i]) 57 | 58 | plt.title(species_names[s], fontsize=20, style='italic') 59 | 60 | plt.gca().spines['top'].set_visible(False) 61 | plt.gca().spines['right'].set_visible(False) 62 | if s == 9606: 63 | plt.xlabel('UMAP1',fontsize=20) 64 | plt.ylabel('UMAP2',fontsize=20) 65 | ## change the tick number sizes 66 | plt.xticks(fontsize=15) 67 | plt.yticks(fontsize=15) 68 | # remove the top and right spines 69 | else: 70 | plt.gca().spines['left'].set_visible(False) 71 | plt.gca().spines['bottom'].set_visible(False) 72 | plt.yticks([]) 73 | plt.xticks([]) 74 | 75 | plt.tight_layout() 76 | 77 | plt.savefig(f'{args.save_dir}/figure_1_umap_species.png', dpi=300) -------------------------------------------------------------------------------- /scripts/node2vec.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script is used to run the node2vec algorithm on the STRING functional networks. 
3 | ''' 4 | 5 | import argparse 6 | import os 7 | from space.models.node2vec import run_single_embedding 8 | 9 | 10 | def main(args): 11 | 12 | species_name = os.path.basename(args.input_network).split(".")[0] 13 | 14 | # create a temporary file 15 | temp_path = f"{args.node2vec_output}/{species_name}.tsv" 16 | 17 | # run the single embedding 18 | run_single_embedding(args.input_network, temp_path, args.node2vec_output, args.dimensions, 19 | args.p, args.q, args.num_walks, args.walk_length, args.window_size, args.sg, 20 | args.negative, args.epochs, args.workers, args.random_state) 21 | 22 | # remove the temporary file 23 | os.remove(temp_path) 24 | 25 | return None 26 | 27 | 28 | if __name__ == "__main__": 29 | 30 | 31 | parser_single_embedding = argparse.ArgumentParser(description="Run node2vec on STRING functional networks with PecanPy.") 32 | 33 | parser_single_embedding.add_argument('-i','--input_network', type=str, help='File to run the\ 34 | embedding for, e.g. .tsv.gz. During running, the file will be processed into another temporary tsv file,\n\ 35 | proteins will be replaced by integers, and the scores will be converted to float (0~1). Once finished, the temporary file will be deleted.') 36 | 37 | parser_single_embedding.add_argument('-o','--node2vec_output', type=str, help='Path to the output folder to save the embeddings.\n\ 38 | The embeddings will be saved in the format: /.h5') 39 | 40 | 41 | ### model parameters, optional 42 | parser_single_embedding.add_argument('-d', '--dimensions', type=int, default=128,help='The number of dimensions for the embedding.') 43 | parser_single_embedding.add_argument('-p', '--p', type=float, default=0.3, help='The return parameter for the random walk.') 44 | parser_single_embedding.add_argument('-q', '--q', type=float, default=0.7, help='The in-out parameter for the random walk.') 45 | parser_single_embedding.add_argument('--num_walks', type=int, default=10, help='The number of walks to perform.') 46 | parser_single_embedding.add_argument('--walk_length', type=int, default=50,help='The length of the walk.') 47 | parser_single_embedding.add_argument('--window_size', type=int, default=5, help='The window size for the skip-gram model.') 48 | parser_single_embedding.add_argument('--sg', type=int, default=1, help='The type of training to use for the skip-gram model. 
0 for cbow, 1 for skip-gram.') 49 | parser_single_embedding.add_argument('--negative', type=int, default=5, help='The number of negative samples to use for training the model.') 50 | parser_single_embedding.add_argument('-e', '--epochs', default=5, type=int, help='The number of epochs to train the model.') 51 | parser_single_embedding.add_argument('--workers', type=int, default=-1, help='The number of workers to use for training the model.') 52 | parser_single_embedding.add_argument('--random_state', type=int, default=1234, help='The random state to use for the random number generator.') 53 | 54 | args = parser_single_embedding.parse_args() 55 | 56 | main(args) 57 | -------------------------------------------------------------------------------- /scripts/align_seeds.py: -------------------------------------------------------------------------------- 1 | from space.models.fedcoder import FedCoder 2 | import argparse 3 | 4 | 5 | def main(args): 6 | args = vars(args) 7 | fedcoder = FedCoder(**args) 8 | fedcoder.fit() 9 | 10 | fedcoder.save_embeddings() 11 | 12 | return None 13 | 14 | if __name__ == '__main__': 15 | 16 | 17 | argparser = argparse.ArgumentParser(description='Align seed species') 18 | 19 | argparser.add_argument('--seed_species', type=str, default='data/seeds.txt', 20 | help='Path to seed species file, including the taxon IDs of the species to align') 21 | 22 | argparser.add_argument('--node2vec_dir', type=str, default='data/node2vec', 23 | help='Path to node2vec files') 24 | 25 | argparser.add_argument('--ortholog_dir', type=str,default='data/orthologs/seeds', 26 | help='Path to eggnog group files') 27 | 28 | argparser.add_argument('--aligned_embedding_save_dir', type=str,default='results/aligned_embeddings', 29 | help='Path to save embeddings') 30 | 31 | argparser.add_argument('--save_top_k', type=int, default=3, 32 | help='Number of top moldels to save') 33 | 34 | argparser.add_argument('--log_dir', type=str, default='logs/seeds', 35 | help='Path to save logs') 36 | 37 | argparser.add_argument('--input_dim', type=int, default=128, 38 | help='Input dimension') 39 | 40 | argparser.add_argument('--latent_dim', type=int, default=512, 41 | help='Latent dimension') 42 | 43 | argparser.add_argument('--hidden_dims', type=int, default=None, 44 | help='Hidden dimension') 45 | 46 | argparser.add_argument('--activation_fn', type=str, default=None, 47 | help='Activation function') 48 | 49 | argparser.add_argument('--batch_norm', type=bool, default=False, 50 | help='Batch normalization') 51 | 52 | argparser.add_argument('--number_iters', type=int, default=10, 53 | help='Number of iterations per epoch') 54 | 55 | argparser.add_argument('--autoencoder_type', type=str, default='naive', 56 | help='Type of autoencoder') 57 | 58 | argparser.add_argument('--gamma', type=float, default=0.1, 59 | help='Margin of the alignment loss') 60 | 61 | argparser.add_argument('--alpha', type=float, default=0.5, 62 | help='Balance between reconstruction and alignment (1-alpha) loss') 63 | 64 | argparser.add_argument('--lr', type=float, default=1e-2, 65 | help='Learning rate') 66 | 67 | argparser.add_argument('--device', type=str, default='cpu', 68 | help='Device to train on') 69 | 70 | argparser.add_argument('--patience', type=int, default=5, 71 | help='Patience for early stopping') 72 | 73 | argparser.add_argument('--delta', type=float, default=1e-4, 74 | help='Delta for early stopping') 75 | 76 | argparser.add_argument('--epochs', type=int, default=500, 77 | help='Number of maximum epochs') 78 | 79 | 
argparser.add_argument('--from_pretrained', type=str, default=None, 80 | help='Path to pretrained model') 81 | 82 | args = argparser.parse_args() 83 | 84 | main(args) -------------------------------------------------------------------------------- /scripts/distances/kegg_stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from scipy import stats 4 | import matplotlib.pyplot as plt 5 | import itertools 6 | import seaborn as sns 7 | 8 | 9 | def sci_round(x, digits=2): 10 | """Round a number to scientific notation with specified digits.""" 11 | if x == 0: 12 | return '0' 13 | else: 14 | return f"{x:.{digits}e}" 15 | 16 | def format_diff(diff:str): 17 | if diff.startswith('-'): 18 | return diff 19 | else: 20 | return f'+{diff}' 21 | 22 | 23 | 24 | def get_stats_report(df,): 25 | 26 | col_a,col_b = df.columns[1], df.columns[2] 27 | 28 | if col_b == 'aligned_auc': 29 | # swap the columns 30 | col_a, col_b = col_b, col_a 31 | 32 | 33 | 34 | sample_size = df.shape[0] 35 | 36 | # map to the new name 37 | name_map = { 38 | 'node2vec_auc': 'node2vec', 39 | 'aligned_auc': 'Aligned', 40 | 't5_auc': 'ProtT5', 41 | } 42 | 43 | 44 | statistic, p_value = stats.wilcoxon(df[col_a], df[col_b],alternative='two-sided') 45 | 46 | differences = df[col_b] - df[col_a] 47 | n_pos = np.sum(differences > 0) 48 | n_neg = np.sum(differences < 0) 49 | effect_size = abs((n_pos - n_neg) / (n_pos + n_neg)) 50 | 51 | ratios = df[col_a] / df[col_b] 52 | ratio_q1, ratio_q3 = np.quantile(ratios, [0.25, 0.75]) 53 | ratio_median = np.median(ratios) 54 | 55 | ratio_q1 = round(ratio_q1, 2) 56 | ratio_q3 = round(ratio_q3, 2) 57 | ratio_median = round(ratio_median, 2) 58 | 59 | p_value = sci_round(p_value, 2) 60 | effect_size = round(effect_size, 2) 61 | 62 | # use a dataframe to store the results 63 | results = pd.DataFrame({ 64 | 'Method A': [name_map.get(col_a, col_a)], 65 | 'Method B': [name_map.get(col_b, col_b)], 66 | 'Sample Size': [sample_size], 67 | 'Ratio A/B Median (IQR)': [f'{ratio_median} ({ratio_q1} - {ratio_q3})'], 68 | 'p-value': [f'{p_value}'], 69 | 'Effect Size': [f'{effect_size}'], 70 | }) 71 | return results 72 | 73 | if __name__ == "__main__": 74 | 75 | df = pd.read_csv('./results/kegg_scores.tsv', sep='\t') 76 | seed_list = open('./data/seeds.txt').read().splitlines() 77 | seed_list = [int(x) for x in seed_list] 78 | df['seed'] = False 79 | for seed in seed_list: 80 | df.loc[df['species'] == seed, 'seed'] = True 81 | df_euk_groups = pd.read_csv('./data/euks_groups.tsv', 82 | sep='\t') 83 | df_euk_groups = df_euk_groups.iloc[:,:-1] 84 | df_euk_groups.columns = ['species', 'kingdom',] 85 | df = df.merge(df_euk_groups, on='species', how='left') 86 | ## if the kingdom is `other` set it to `protists` 87 | df.loc[df['kingdom'] == 'other', 'kingdom'] = 'protists' 88 | 89 | df_all_results = list() 90 | for kingdom in df['kingdom'].unique(): 91 | df_subset = df[(df['kingdom'] == kingdom)] 92 | if len(df_subset) > 0: 93 | for method1,method2 in itertools.combinations(['node2vec_auc', 'aligned_auc','t5_auc'],2): 94 | df_subset_method = df_subset[['species',method1, method2]] 95 | df_result = get_stats_report(df_subset_method) 96 | # uppercase the first letter of the kingdom 97 | df_result['Kingdom'] = kingdom.capitalize() 98 | df_all_results.append(df_result) 99 | 100 | df_all_results = pd.concat(df_all_results, ignore_index=True) 101 | 102 | ## change the order of the columns 103 | df_all_results = df_all_results[['Method A', 'Method 
B','Kingdom', 'Sample Size', 104 | 'Ratio A/B Median (IQR)', 105 | 'p-value', 'Effect Size']] 106 | # order by Method A, then Method B 107 | df_all_results = df_all_results.sort_values(by=['Method A', 'Method B']) 108 | 109 | # print(df_all_results) 110 | df_all_results.to_csv('./results/kegg_stats.csv', 111 | sep=',', index=False) -------------------------------------------------------------------------------- /scripts/align_non_seeds.py: -------------------------------------------------------------------------------- 1 | from space.models.fedcoder import FedCoderNonSeed 2 | import argparse 3 | import os 4 | 5 | def main(args): 6 | args = vars(args) 7 | fedcoder = FedCoderNonSeed(**args) 8 | fedcoder.fit() 9 | 10 | fedcoder.save_embeddings() 11 | 12 | os.system(f'rm -r {fedcoder.log_dir}') 13 | 14 | return None 15 | 16 | if __name__ == '__main__': 17 | 18 | 19 | argparser = argparse.ArgumentParser(description='Align non-seed species') 20 | 21 | argparser.add_argument('--seed_groups', type=str, default='data/euk_seed_groups.json', 22 | help='Path to seed groups file') 23 | 24 | argparser.add_argument('--node2vec_dir', type=str, default='data/node2vec', 25 | help='Folder of node2vec embeddings') 26 | 27 | argparser.add_argument('--tax_group', type=str, default='data/euks_groups.tsv', 28 | help='Path to taxonomic group file') 29 | 30 | argparser.add_argument('--non_seed_species', type=int,required=True, 31 | help='Taxonomy id of non seed species') 32 | 33 | argparser.add_argument('--aligned_dir', type=str, 34 | default='data/aligned', 35 | help='Path to save aligned embeddings') 36 | 37 | argparser.add_argument('--ortholog_dir', type=str, 38 | default='data/orthologs/non_seeds', 39 | help='Path to eggnog group files') 40 | 41 | argparser.add_argument('--aligned_embedding_save_dir', type=str, 42 | default='results/non_seed_embeddings', 43 | help='Path to save embeddings') 44 | 45 | argparser.add_argument('--save_top_k', type=int, default=3, 46 | help='Number of top moldels to save') 47 | 48 | argparser.add_argument('--log_dir', type=str, default='logs/non_seeds', 49 | help='Path to save logs') 50 | 51 | argparser.add_argument('--input_dim', type=int, default=128, 52 | help='Input dimension') 53 | 54 | argparser.add_argument('--latent_dim', type=int, default=512, 55 | help='Latent dimension') 56 | 57 | argparser.add_argument('--hidden_dims', type=int, default=None, 58 | help='Hidden dimension') 59 | 60 | argparser.add_argument('--activation_fn', type=str, default=None, 61 | help='Activation function') 62 | 63 | argparser.add_argument('--batch_norm', type=bool, default=False, 64 | help='Batch normalization') 65 | 66 | argparser.add_argument('--number_iters', type=int, default=10, 67 | help='Number of iterations per epoch') 68 | 69 | argparser.add_argument('--autoencoder_type', type=str, default='naive', 70 | help='Type of autoencoder') 71 | 72 | argparser.add_argument('--gamma', type=float, default=0.1, 73 | help='Margin of the alignment loss') 74 | 75 | argparser.add_argument('--alpha', type=float, default=0.5, 76 | help='Balance between reconstruction and alignment (1-alpha) loss') 77 | 78 | argparser.add_argument('--lr', type=float, default=1e-2, 79 | help='Learning rate') 80 | 81 | argparser.add_argument('--device', type=str, default='cpu', 82 | help='Device to train on') 83 | 84 | argparser.add_argument('--patience', type=int, default=5, 85 | help='Patience for early stopping') 86 | 87 | argparser.add_argument('--delta', type=float, default=1e-4, 88 | help='Delta for early stopping') 89 | 
90 | argparser.add_argument('--epochs', type=int, default=500, 91 | help='Number of maximum epochs') 92 | 93 | argparser.add_argument('--from_pretrained', type=str, default=None, 94 | help='Path to pretrained model') 95 | 96 | args = argparser.parse_args() 97 | 98 | main(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPACE: STRING proteins as complementary embeddings 2 | 3 | ## Table of Contents 4 | - [Introduction](#introduction) 5 | - [Star History](#star-history) 6 | - [Reproduce the Results](#reproduce-the-results-in-the-paper) 7 | - [How to Cite](#how-to-cite) 8 | - [How to load the embeddings](#how-to-load-the-embeddings) 9 | - [Python Example](#python-example) 10 | - [R Example](#r-example) 11 | - [Read Combined Files](#read-combined-files) 12 | - [Contact](#contact) 13 | - [License](#license) 14 | 15 | ## Introduction 16 | Official repository for the paper in **Bioinformatics**: [SPACE: STRING proteins as complementary embeddings](https://doi.org/10.1093/bioinformatics/btaf496), 17 | in which we precalculated: 18 | - cross-species network embeddings 19 | - ProtT5 sequence embeddings 20 | 21 | for all eukaryotic proteins in STRING v12.0. 22 | 23 | You can [download all the embeddings from the STRING website](https://string-db.org/cgi/download): 24 | - protein.network.embeddings.v12.0.h5 25 | - protein.sequence.embeddings.v12.0.h5 26 | 27 | ![SPACE](./figures/space_overview.png) 28 | 29 | ## Star history 30 | [![Star History Chart](https://api.star-history.com/svg?repos=deweihu96/SPACE&type=Date)](https://www.star-history.com/#deweihu96/SPACE&Date) 31 | 32 | ## Reproduce the results in the paper 33 | Please follow this [document](./reproduce.md). 34 | 35 | ## How to Cite 36 | If you use this work in your research, please cite **the SPACE paper**: 37 | 38 | Hu, Dewei, et al. "SPACE: STRING proteins as complementary embeddings." Bioinformatics (2025): btaf496. [https://doi.org/10.1101/2024.11.25.625140](https://doi.org/10.1093/bioinformatics/btaf496) 39 | 40 | and **the STRING database**: 41 | 42 | Szklarczyk, D., Nastou, K., Koutrouli, M., Kirsch, R., Mehryary, F., Hachilif, R., ... & von Mering, C. (2025). The STRING database in 2025: protein networks with directionality of regulation. Nucleic Acids Research, 53(D1), D730-D737. [https://doi.org/10.1093/nar/gkae1113](https://doi.org/10.1093/nar/gkae1113) 43 | 44 | ## How to load the embeddings 45 | 46 | The following code reads the cross-species network embedding file `9606.protein.network.embeddings.v12.0.h5`. 47 | 48 | #### Python example 49 | ```bash 50 | pip install h5py 51 | ``` 52 | 53 | 54 | ```Python 55 | import h5py 56 | 57 | filename = '9606.protein.network.embeddings.v12.0.h5' 58 | 59 | with h5py.File(filename, 'r') as f: 60 | meta_keys = f['metadata'].attrs.keys() 61 | for key in meta_keys: 62 | print(key, f['metadata'].attrs[key]) 63 | 64 | embedding = f['embeddings'][:] 65 | proteins = f['proteins'][:] 66 | 67 | # protein names are stored as bytes, convert them to strings 68 | proteins = [p.decode('utf-8') for p in proteins] 69 | ``` 70 | 71 | #### R example: 72 | Install the `rhdf5` package to read the embedding files. The following code reads the embedding file `9606.protein.network.embeddings.v12.0.h5`. 
73 | 74 | ```R 75 | # Install required packages if not already installed 76 | # install.packages("rhdf5") 77 | 78 | # Load the library 79 | library(rhdf5) 80 | 81 | filename <- '9606.protein.network.embeddings.v12.0.h5' 82 | 83 | metadata <- h5readAttributes(filename, "metadata") 84 | for (key in names(metadata)) { 85 | print(paste(key, metadata[[key]])) 86 | } 87 | 88 | embeddings <- h5read(filename, "embeddings") 89 | proteins <- h5read(filename, "proteins") 90 | ``` 91 | #### Read combined files 92 | Read the combined network embedding file of all eukaryotes with Python 93 | ```Python 94 | import h5py 95 | 96 | filename = 'protein.network.embeddings.v12.0.h5' 97 | 98 | with h5py.File(filename, 'r') as f: 99 | meta_keys = f['metadata'].attrs.keys() 100 | for key in meta_keys: 101 | print(key, f['metadata'].attrs[key]) 102 | 103 | species = '4932' # if we check the brewer's yeast 104 | embeddings = f['species'][species]['embeddings'][:] 105 | proteins = f['species'][species]['proteins'][:] 106 | 107 | # protein names are stored as bytes, convert them to strings 108 | proteins = [p.decode('utf-8') for p in proteins] 109 | 110 | ``` 111 | Read the combined file with R 112 | ```R 113 | library(rhdf5) 114 | 115 | filename <- 'protein.network.embeddings.v12.0.h5' 116 | 117 | meta_keys <- h5readAttributes(filename, "metadata") 118 | for (key in names(meta_keys)) { 119 | print(paste(key, meta_keys[[key]])) 120 | } 121 | 122 | species <- '4932' # for brewer's yeast 123 | embeddings <- h5read(filename, paste0('species/', species, '/embeddings')) 124 | proteins <- h5read(filename, paste0('species/', species, '/proteins')) 125 | ``` 126 | 127 | ## Contact 128 | [dewei.hu@sund.ku.dk](mailto:dewei.hu@sund.ku.dk). 129 | 130 | 131 | ## License 132 | MIT. 133 | -------------------------------------------------------------------------------- /scripts/distances/deeploc_distance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('./scripts') 4 | from subloc import load_cv_set,filter_and_load_proteins_embeddings 5 | import random 6 | import pandas as pd 7 | import numpy as np 8 | from scipy import stats 9 | import itertools 10 | from tqdm import tqdm 11 | 12 | def format_diff(difference): 13 | """Format the difference for output.""" 14 | if difference > 0: 15 | return f'+{difference:.2f}' 16 | else: 17 | return f'{difference:.2f}' 18 | 19 | def normalize_vectors(vectors): 20 | """Normalize the vectors to unit length.""" 21 | norm = np.linalg.norm(vectors, axis=1, keepdims=True) 22 | return vectors / np.maximum(norm, 1e-9) 23 | 24 | 25 | def mutual_distances(embed_loc, embed_other_loc, number_of_samples): 26 | 27 | used_proteins = set() 28 | sampled_pairs = [] 29 | 30 | possible_pairs = list(itertools.product(range(len(embed_loc)), range(len(embed_other_loc)))) 31 | 32 | random.shuffle(possible_pairs) 33 | for i, j in possible_pairs: 34 | if i not in used_proteins or j not in used_proteins: 35 | used_proteins.add(i) 36 | used_proteins.add(j) 37 | sampled_pairs.append((i, j)) 38 | if len(sampled_pairs) > number_of_samples: 39 | break 40 | 41 | sampled_indices = np.array(sampled_pairs).T 42 | src_emb = embed_loc[sampled_indices[0]] 43 | tgt_emb = embed_other_loc[sampled_indices[1]] 44 | cosine_similarities = np.sum(src_emb * tgt_emb, axis=1) / (np.linalg.norm(src_emb, axis=1) * np.linalg.norm(tgt_emb, axis=1)) 45 | 46 | return cosine_similarities 47 | 48 | def inner_distances(embeddings, ): 49 | """Calculate distances for a given set of 
embeddings.""" 50 | possible_pairs = list(itertools.combinations(range(len(embeddings)), 2)) 51 | # shuffle the pairs to ensure randomness 52 | random.shuffle(possible_pairs) 53 | 54 | used_proteins = set() 55 | sampled_pairs = [] 56 | 57 | for i, j in possible_pairs: 58 | if i not in used_proteins or j not in used_proteins: 59 | used_proteins.add(i) 60 | used_proteins.add(j) 61 | sampled_pairs.append((i, j)) 62 | 63 | sampled_indices = np.array(sampled_pairs).T 64 | 65 | src_emb = embeddings[sampled_indices[0]] 66 | tgt_emb = embeddings[sampled_indices[1]] 67 | 68 | cosine_similarities = np.sum(src_emb * tgt_emb, axis=1) / (np.linalg.norm(src_emb, axis=1) * np.linalg.norm(tgt_emb, axis=1)) 69 | 70 | return cosine_similarities 71 | 72 | def single_location_p_value(location,embed,label_df): 73 | embed_loc = embed[label_df[label_df[location] == 1].index] 74 | 75 | inner_cosine = inner_distances(embed_loc) 76 | 77 | embed_other_loc = embed[label_df[label_df[location] == 0].index] 78 | 79 | mutual_cosine = mutual_distances(embed_loc, embed_other_loc, number_of_samples=len(inner_cosine)) 80 | mutual_q1 = round(np.quantile(mutual_cosine, 0.25),2) 81 | mutual_q3 = round(np.quantile(mutual_cosine, 0.75),2) 82 | 83 | 84 | mutual_median = round(float(np.median(mutual_cosine)),2) 85 | 86 | inner_q1 = round(np.quantile(inner_cosine, 0.25),2) 87 | inner_q3 = round(np.quantile(inner_cosine, 0.75),2) 88 | inner_median = round(float(np.median(inner_cosine)),2) 89 | 90 | stat_cosine, p_value_cosine = stats.mannwhitneyu(mutual_cosine, inner_cosine, alternative='two-sided') 91 | 92 | effect_cosine = round((1 - (2*stat_cosine) / (len(mutual_cosine) * len(inner_cosine))),2) 93 | 94 | difference = format_diff(float(inner_median - mutual_median)) 95 | 96 | # use scientific notation for p-value 97 | p_value_cosine = f'{p_value_cosine:.2e}' 98 | 99 | # return p_value_euclidean, p_value_cosine, stat_mutual, stat_inner 100 | results = { 101 | 'Location': location, 102 | 'N': str(len(mutual_cosine)), 103 | 'Intra-Loc (IQR)': f'{inner_median} ({inner_q1} - {inner_q3})', 104 | 'Inter-Loc (IQR)': f'{mutual_median} ({mutual_q1} - {mutual_q3})', 105 | 'Difference': difference, 106 | 'p-value': p_value_cosine, 107 | 'Effect Size': f'{round(effect_cosine,2)}', 108 | } 109 | return results 110 | 111 | def main(cv_set, cv_id_mapping,aligned_dir, 112 | jobs=1): 113 | 114 | cv_ids, cv_labels, cv_label_headers, cv_partitions, cv_species = load_cv_set(cv_set, cv_id_mapping) 115 | 116 | cv_ids_aligned, cv_labels_aligned, cv_partitions_aligned, \ 117 | aligned_proteins, aligned_embeddings = filter_and_load_proteins_embeddings( 118 | cv_ids, 119 | cv_labels, 120 | cv_partitions, 121 | cv_species, 122 | aligned_dir, 123 | n_jobs=jobs 124 | ) 125 | 126 | df_labels = pd.DataFrame(cv_labels_aligned, columns=cv_label_headers) 127 | df_labels['id'] = cv_ids_aligned 128 | 129 | # use for loop 130 | results = [] 131 | # loc = 'Peroxisome' 132 | # results = [single_location_p_value(loc, aligned_embeddings, df_labels)] 133 | for loc in tqdm(cv_label_headers): 134 | if loc != 'id': 135 | result = single_location_p_value(loc, aligned_embeddings, df_labels) 136 | results.append(result) 137 | # break 138 | 139 | df = pd.DataFrame(results) 140 | # put the location as the first column 141 | df = df[['Location',] + [col for col in df.columns if col != 'Location']] 142 | 143 | df.to_csv('./results/deeploc_location_distances.csv', index=False) 144 | 145 | 146 | 147 | if __name__ == '__main__': 148 | 149 | np.random.seed(42) 150 | random.seed(42) 151 | 
152 | cv_set = './data/benchmarks/deeploc/Swissprot_Train_Validation_dataset.csv' 153 | cv_id_mapping = './data/benchmarks/deeploc/cv_idmapping.tsv' 154 | aligned_dir = './data/functional_emb' 155 | 156 | main(cv_set, cv_id_mapping, aligned_dir, jobs=8) 157 | 158 | -------------------------------------------------------------------------------- /src/space/tools/process_orthologs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import gzip 4 | from itertools import combinations, product 5 | from multiprocessing import Pool 6 | from typing import Set 7 | import re 8 | ## for some species, the taxid used in STRING is different from the one used in ncbi 9 | __TAXID_CONVERTER__ = {339724:2059318, 10 | 1745343:2082293, 11 | 1325735:2604345, 12 | 1658172:1955775, 13 | 56484:2754530, 14 | 1266660:1072105, 15 | 1229665:2502994, 16 | 944018:1913371, 17 | 743788:2126942 18 | } 19 | 20 | 21 | class Lineage: 22 | 23 | def __init__(self,node_dmp_zip,group_dir) -> None: 24 | 25 | self.df = pd.read_csv(node_dmp_zip,sep='|',compression='zip',header=None) 26 | self.eggnog_ancestors = {f.split('.')[0] for f in os.listdir(group_dir) } 27 | self.group_dir = group_dir 28 | 29 | def get_lineage(self,taxid): 30 | taxid = int(taxid) 31 | line = [taxid] 32 | while taxid != 1: 33 | taxid = self.df[self.df.iloc[:,0]==taxid].iloc[0,1] 34 | line = [taxid] + line 35 | return line 36 | 37 | def common_ancestor(self,taxid_1,taxid_2): 38 | taxid_1 = int(taxid_1) 39 | taxid_2 = int(taxid_2) 40 | 41 | use_taxid_1 = __TAXID_CONVERTER__.get(taxid_1,taxid_1) 42 | use_taxid_2 = __TAXID_CONVERTER__.get(taxid_2,taxid_2) 43 | 44 | l_1 = self.get_lineage(use_taxid_1) 45 | l_2 = self.get_lineage(use_taxid_2) 46 | for idx,taxid in enumerate(l_1): 47 | if taxid in l_2: 48 | common_ancestor = taxid 49 | else: 50 | break 51 | 52 | ## make sure eggNOG has the common ancestor, and the orthologs are not empty 53 | while True: 54 | if self.check_ortholog_group(taxid_1,taxid_2,common_ancestor): 55 | break 56 | idx -= 1 57 | common_ancestor = l_1[idx-1] 58 | 59 | return str(taxid_1),str(taxid_2),int(common_ancestor) 60 | 61 | 62 | def check_ortholog_group(self,taxid_1,taxid_2,ancestor): 63 | 64 | group_file = f'{self.group_dir}/{ancestor}.tsv.gz' 65 | 66 | if not os.path.exists(group_file): 67 | return False 68 | 69 | with gzip.open(group_file,'rt') as f: 70 | 71 | for line in f: 72 | line = line.strip().split('\t') 73 | species_list = line[-2].split(',') 74 | 75 | if str(taxid_1) in species_list and str(taxid_2) in species_list: 76 | return True 77 | 78 | return False 79 | 80 | 81 | 82 | def infer_common_ancestor(self,seed_species,ncbi_taxonomy_file,eggnog_group_main_folder): 83 | lineage = Lineage(ncbi_taxonomy_file,group_dir=eggnog_group_main_folder) 84 | 85 | src_tgt = list(combinations(seed_species,2)) 86 | # sort the pairs 87 | src_tgt = [sorted(pair) for pair in src_tgt] 88 | if self.jobs > 1: 89 | use_jobs = int(self.jobs/2) ## for memory safaty 90 | with Pool(use_jobs) as p: 91 | common_ancestors = p.starmap(lineage.common_ancestor,src_tgt) 92 | 93 | return common_ancestors 94 | 95 | def get_proteins_from_members(tax, proteins): 96 | pattern = re.compile(rf'{tax}\.[^,]*') 97 | return pattern.findall(proteins) 98 | 99 | def parse_single_og_file(src,tgt, 100 | ancestor, 101 | eggnog_group_file_main_folder, 102 | alpha=1,og_threshold=0.1, 103 | save_to:str=None, 104 | src_set:Set=None, 105 | tgt_set:Set=None): 106 | 107 | group_file = 
f'{eggnog_group_file_main_folder}/{ancestor}.tsv.gz' 108 | 109 | src_set_index = {p:i for i,p in enumerate(src_set)} 110 | tgt_set_index = {p:i for i,p in enumerate(tgt_set)} 111 | 112 | results = set() ## store the results for a single species pair 113 | 114 | with gzip.open(group_file, 'rt') as f: 115 | for line in f: ## each line is an ortholog group 116 | 117 | line = line.strip().split('\t') 118 | _, species_list, orthologs = line[1], line[-2].split(','), line[-1] 119 | 120 | if str(src) in species_list and str(tgt) in species_list: 121 | source_proteins = get_proteins_from_members(src, orthologs) 122 | target_proteins = get_proteins_from_members(tgt, orthologs) 123 | 124 | ## use a filter to remove proteins that are not in the embeddings, if both src_set and tgt_set are provided 125 | if src_set is not None and tgt_set is not None: 126 | source_proteins = list(filter(lambda x: x in src_set, source_proteins)) 127 | target_proteins = list(filter(lambda x: x in tgt_set, target_proteins)) 128 | 129 | ## get the frequency inverse of each protein in the ortholog group 130 | source_protein_freq_inv = {p: 1/len(source_proteins) for p in source_proteins} 131 | target_protein_freq_inv = {p: 1/len(target_proteins) for p in target_proteins} 132 | 133 | ## get the product of the proteins with their frequency inverse product 134 | 135 | results.update({'\t'.join([str(src_set_index[p1]), str(tgt_set_index[p2]), 136 | str((source_protein_freq_inv[p1]*target_protein_freq_inv[p2])**alpha)]) 137 | for p1, p2 in product(source_proteins, target_proteins) 138 | if (source_protein_freq_inv[p1]*target_protein_freq_inv[p2])**alpha >= og_threshold}) 139 | 140 | if save_to is not None: 141 | if not os.path.exists(save_to): 142 | os.makedirs(save_to) 143 | ## save the results to main_save_folder 144 | with open(f'{save_to}/{src}_{tgt}.tsv','w') as f: 145 | f.write('\n'.join(results)) 146 | 147 | return results ## return the results for a single species pair -------------------------------------------------------------------------------- /reproduce.md: -------------------------------------------------------------------------------- 1 | ## Reproduce the results 2 | 3 | ### 1. Install the dependencies 4 | 5 | ```bash 6 | git clone https://github.com/deweihu96/SPACE.git 7 | conda create -n space python=3.11 8 | conda activate space 9 | cd SPACE 10 | pip install . 11 | mkdir results 12 | ``` 13 | 14 | ### 2. Download the data 15 | 16 | If you only need to use the embeddings, you should download from the STRING website: https://string-db.org/cgi/download 17 | 18 | This data set only serves as a backup and reference for reimplementation. 19 | 20 | Download the data and decompress it to the `data` folder. 21 | The `data/` is around `212GB` in total, including the networks, sequences, all the embeddings, and benchmark datasets. 22 | ```bash 23 | wget https://sid.erda.dk/share_redirect/cZ4tLqQZhv -O data.tar 24 | tar -xvf data.tar 25 | ``` 26 | eggNOG dataset is sourced from the eggNOG database: http://eggnog6.embl.de/download/eggnog_6.0/, make sure you cite the eggNOG database if you use the data in your work. 27 | ``` 28 | Hernández-Plaza, Ana, et al. "eggNOG 6.0: enabling comparative genomics across 12 535 organisms." Nucleic Acids Research 51.D1 (2023): D389-D394. 
29 | ``` 30 | 31 | 32 | The DeepLoc dataset is originally from https://services.healthtech.dtu.dk/services/DeepLoc-2.0/. 33 | The protein function prediction benchmark has to be downloaded manually from the following link, according to the NetGO paper (https://doi.org/10.1093/nar/gkz388): https://drive.google.com/drive/folders/1HLH1aCDxlrVpu1zKvgfdQFEFnbT8gChm 34 | 35 | Please cite the original data sources and respect their usage rules if you use the benchmark data in your work: 36 | ``` 37 | Thumuluri, Vineet, et al. "DeepLoc 2.0: multi-label subcellular localization prediction using protein language models." Nucleic acids research 50.W1 (2022): W228-W234. 38 | 39 | Yao, Shuwei, et al. "NetGO 2.0: improving large-scale protein function prediction with massive sequence, text, domain, family and network information." Nucleic acids research 49.W1 (2021): W469-W475. 40 | ``` 41 | 42 | 43 | ### 3. Generate the functional embeddings 44 | You can get help for each script by running `python scripts/xxx.py -h`. 45 | 46 | #### 3.1 Run the node2vec algorithm to generate the node embeddings 47 | ```bash 48 | ## for instance, run node2vec on the human network 49 | mkdir results/node2vec 50 | 51 | python scripts/node2vec.py \ 52 | --input_network data/networks/9606.protein.links.v12.0.txt.gz \ 53 | --node2vec_output results/node2vec 54 | ``` 55 | You can also use other node embedding algorithms as input to the alignment, but make sure that the output embeddings are in the same format as the node2vec embeddings. 56 | The node indices in the embeddings should also match the node indices in the network file, i.e. the same ordering as in the node2vec embeddings. 57 | 58 | #### 3.2 Run the FedCoder to align the seed species 59 | 60 | ```bash 61 | # with the best hyperparameters in the paper 62 | python scripts/align_seeds.py \ 63 | --seed_species data/seeds.txt \ 64 | --node2vec_dir data/node2vec \ 65 | --aligned_embedding_save_dir results/aligned_embeddings 66 | ``` 67 | 68 | #### 3.3 Run the FedCoder to align the non-seed species 69 | 70 | ```bash 71 | # for instance, align Rattus norvegicus (Norway rat) 72 | python scripts/align_non_seeds.py \ 73 | --node2vec_dir data/node2vec \ 74 | --aligned_dir data/aligned \ 75 | --non_seed_species 10116 \ 76 | --aligned_embedding_save_dir results/aligned_embeddings 77 | ``` 78 | 79 | #### 3.4 Add the singletons to the aligned embedding space 80 | 81 | ```bash 82 | python scripts/add_singleton.py \ 83 | --aligned_dir data/aligned \ 84 | --full_embedding_save_dir results/functional_embeddings 85 | ``` 86 | 87 | #### 3.5 Generate the ProtT5 embeddings 88 | 89 | ```bash 90 | # make sure you have the sentencepiece library installed 91 | pip install sentencepiece 92 | 93 | # make sure you have enough GPU memory, otherwise adjust min_length and max_length to run the script only for sequences within that range 94 | # for example, with min_length=1 and max_length=1000, sequences with 1<=length<=1000 will be processed. 95 | 96 | # if you have enough GPU memory (~60GB), sequences with length [1,8000] can be processed. 97 | # we ran the ProtT5 embedding on the super-long sequences with length [8000, 100000] on a CPU server with up to 400GB of memory. 
98 | 99 | # for example, human sequences 100 | python scripts/prott5_emb.py \ 101 | --seq_file data/sequences/9606.protein.sequences.v12.0.fa.gz \ 102 | --save_path results/prott5/9606.h5 \ 103 | --max_length 1000 \ 104 | --min_length 1 \ 105 | --device cuda 106 | ``` 107 | 108 | Reference: 109 | ``` 110 | https://github.com/agemagician/ProtTrans 111 | https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc 112 | ``` 113 | 114 | 115 | #### 4. Evaluate the functional embeddings 116 | 117 | To run the following scripts with default parameters, it uses the data inside the `data` folder. 118 | 119 | Change the input to your own data if you want to evaluate the functional embeddings on your own networks. 120 | 121 | Use the `-h` option to get the help message for each script. 122 | ```bash 123 | # subcellular localization prediction, it also generates the umap plot of the subcellular localization 124 | python scripts/subloc.py 125 | 126 | # function prediction 127 | python scripts/func_pred.py 128 | 129 | # to have the umap plot of the species (fig1.b) 130 | python scripts/umap_species.py 131 | 132 | # precision-recall curve 133 | python scripts/pr_curves.py 134 | 135 | # due to the copyright issue, we cannot provide the KEGG data and its evaluation 136 | ``` 137 | 138 | #### 5. Alignment quality evaluation 139 | This part is not included in the preprint now. 140 | You can use the following scripts to evaluate the alignment quality of the functional embeddings. 141 | 142 | To assess if the orthologous proteins are similar to each other than non-orthologous proteins in the aligned embedding space. We sample the orthologs and non-orthologs to reduce the memory usage. 143 | Please check the code for the details of the sampling strategy. 144 | ```bash 145 | python scripts/distances/og_sampling.py 146 | ``` 147 | 148 | 149 | To assess if the proteins are similar to each other in the same subcellular localization in the aligned embedding space: 150 | ```bash 151 | python scripts/distances/deeploc_distance.py 152 | ``` 153 | -------------------------------------------------------------------------------- /scripts/pr_curves.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import argparse 5 | 6 | if __name__ == "__main__": 7 | 8 | 9 | argparser = argparse.ArgumentParser(description="Plot PR curves for SPACE and other methods.") 10 | argparser.add_argument('--cv_data', type=str, default='data/benchmark_results/sub_loc/cv_data.csv', 11 | help='Path to the cross-validation data CSV file.') 12 | argparser.add_argument('--cv_pr_curves', type=str, default='data/benchmark_results/sub_loc/cv_pr_curves.csv', 13 | help='Path to the cross-validation PR curves CSV file.') 14 | argparser.add_argument('--hpa_pr_curves', type=str, default='data/benchmark_results/sub_loc/hpa_pr_curves.csv', 15 | help='Path to the HPA PR curves CSV file.') 16 | argparser.add_argument('--cafa_eval_index', type=str, default='data/benchmark_results/func_pred/cafa-eval-index', 17 | help='Path to the CAFA evaluation index directory.') 18 | argparser.add_argument('--scores_csv', type=str, default='data/benchmark_results/func_pred/scores.csv', 19 | help='Path to the scores CSV file for functional prediction.') 20 | argparser.add_argument('--output', type=str, default='results/pr_curves.png', 21 | help='Output path for the PR curves plot.') 22 | args = argparser.parse_args() 23 | 24 | 25 | custom_dashes = [5, 8] 26 | 
title_size = 14 27 | tick_size = 14 28 | label_size = 14 29 | legend_fontsize = 14 30 | 31 | df_cv = pd.read_csv(args.cv_pr_curves) 32 | df_hpa = pd.read_csv(args.hpa_pr_curves) 33 | 34 | plt.figure(figsize=(18, 10)) 35 | plt.subplots_adjust(wspace=0.3,hspace=0.37) # Increase vertical space between plots 36 | 37 | plt.subplot(2, 3, 1) 38 | colors = ['#1e1e1e','#bcbcbc',"#e60049",] 39 | plt.plot(df_cv['space_recall'], df_cv['space_prec'], label='SPACE', color=colors[2], linestyle='-', linewidth=2) 40 | plt.plot(df_cv['aligned_recall'], df_cv['aligned_prec'], label='Aligned', color=colors[0], linestyle='-', linewidth=2) 41 | plt.plot(df_cv['t5_recall'], df_cv['t5_prec'], label='ProtT5', color=colors[1], linestyle='-', linewidth=2) 42 | 43 | 44 | plt.yticks(np.arange(0,1.1,0.2),fontsize=tick_size) 45 | plt.xticks(np.arange(0,1.1,0.2),fontsize=tick_size) 46 | plt.xlim(0,1) 47 | plt.ylim(0,1) 48 | plt.xlabel('Recall',fontsize=label_size) 49 | plt.ylabel('Precision',fontsize=label_size) 50 | plt.gca().spines['top'].set_visible(False) 51 | plt.gca().spines['right'].set_visible(False) 52 | plt.title("SwissProt Cross Validation Set",fontsize=title_size) 53 | 54 | 55 | plt.subplot(2, 3, 2) 56 | plt.plot(df_hpa['space_recall'], df_hpa['space_prec'], label='SPACE', color=colors[2], linestyle='-', linewidth=2) 57 | plt.plot(df_hpa['aligned_recall'], df_hpa['aligned_prec'], label='Aligned', color=colors[0], linestyle='-', linewidth=2) 58 | plt.plot(df_hpa['t5_recall'], df_hpa['t5_prec'], label='ProtT5', color=colors[1], linestyle='-', linewidth=2) 59 | # deeploc2 at the best cutoffs per label 60 | # Precision: 0.5763269140441521 61 | # Recall: 0.6147294589178357 62 | plt.scatter(0.6147294589178357, 0.5763269140441521, s=120, color="#0bb4ff", marker='*',) 63 | 64 | plt.yticks(np.arange(0,1.1,0.2),fontsize=tick_size) 65 | plt.xticks(np.arange(0,1.1,0.2),fontsize=tick_size) 66 | plt.xlim(0,1) 67 | plt.ylim(0,1) 68 | plt.xlabel('Recall',fontsize=label_size) 69 | plt.ylabel('Precision',fontsize=label_size) 70 | plt.gca().spines['top'].set_visible(False) 71 | plt.gca().spines['right'].set_visible(False) 72 | plt.annotate('DeepLoc 2.0', # Text 73 | xy=(0.6147294589178357, 0.5763269140441521), # Point to annotate 74 | xytext=(0.5, 0.5), # Text position 75 | fontsize=12, 76 | ) 77 | 78 | # have a horizontal legend in the middle bottom, no border 79 | plt.legend(loc='lower center', bbox_to_anchor=(-0.25, -0.3), 80 | ncol=4, frameon=False, fontsize=legend_fontsize, 81 | ) 82 | 83 | plt.title("HPA Test Set",fontsize=title_size) 84 | 85 | 86 | aspects = ['mf', 'bp','cc'] 87 | 88 | netgo_scores = pd.read_csv(args.scores_csv) 89 | # plt.figure(figsize=(15, 8)) 90 | label_dict = {'aligned':"Aligned", 'seq':"ProtT5", 'space':"SPACE"} 91 | aspect_dict = {'cc':"Cellular Component", 'bp':"Biological Process", 'mf':"Molecular Function"} 92 | for aspect in aspects: 93 | plt.subplot(2, 3, aspects.index(aspect)+4) 94 | colors = ["#e60049",'#1e1e1e','#bcbcbc',] 95 | for idx,emb in enumerate(['space','aligned','seq',]): 96 | 97 | if emb == 'space': 98 | emb_ = 'seq_concat_aligned' 99 | else: 100 | emb_ = emb 101 | 102 | df = pd.read_csv(f'{args.cafa_eval_index}/{aspect}_{emb_}_merged.csv') 103 | rc,pr = df['rc'].values, df['pr'].values 104 | ## make sure both rc and pr are 105 | 106 | plt.plot(df['rc'], df['pr'], label=label_dict[emb], 107 | linestyle='-', linewidth=2, color=colors[idx]) 108 | 109 | pr,rc = netgo_scores[netgo_scores['entry']==aspect+'_'+emb_+'_merged'][['pr','rc']].values[0] 110 | ## annotate this point, with 
'x' 111 | plt.scatter(rc, pr, s=80, color=colors[idx], marker='*') 112 | plt.xlabel('Recall',fontsize=label_size) 113 | plt.ylabel('Precision',fontsize=label_size) 114 | plt.title(aspect_dict[aspect],fontsize=title_size) 115 | plt.xticks(np.arange(0,1.1,0.2),fontsize=tick_size) 116 | plt.yticks(np.arange(0,1.1,0.2),fontsize=tick_size) 117 | plt.xlim(0,1) 118 | plt.ylim(0,1) 119 | ## get rid of the up and right spines 120 | plt.gca().spines['top'].set_visible(False) 121 | plt.gca().spines['right'].set_visible(False) 122 | if aspect == 'bp': 123 | plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.16), ncol=3, frameon=False, 124 | fontsize=legend_fontsize) 125 | 126 | # plt.tight_layout() 127 | 128 | plt.savefig(args.output, dpi=300, bbox_inches='tight') 129 | print(f"PR curves saved to {args.output}") 130 | 131 | -------------------------------------------------------------------------------- /scripts/prott5_emb.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer, T5EncoderModel 2 | import torch 3 | import re 4 | from Bio import SeqIO 5 | import gzip 6 | import numpy as np 7 | from torch.utils.data import DataLoader, Dataset 8 | from space.tools.data import H5pyData 9 | import sys 10 | from loguru import logger 11 | import argparse 12 | import os 13 | from tqdm import tqdm 14 | 15 | def load_sequences(filename,max_length=8000, 16 | min_length=1): 17 | 18 | logger.info(f"Loading sequences from {filename} with max_length={max_length} and min_length={min_length}") 19 | 20 | file_type = filename.split('.')[-1] 21 | if file_type in ['fasta', 'fa', 'txt']: 22 | with open(filename, 'r') as f: 23 | # if the sequence is too long, we drop it 24 | sequences = [] 25 | for record in SeqIO.parse(f, 'fasta'): 26 | seq = str(record.seq) 27 | if len(seq) <= max_length and len(seq) >= min_length: 28 | sequences.append([str(record.id), seq]) 29 | elif filename.endswith('.gz'): 30 | with gzip.open(filename, 'rt') as f: 31 | sequences = [] 32 | for record in SeqIO.parse(f, 'fasta'): 33 | seq = str(record.seq) 34 | if len(seq) <= max_length and len(seq) >= min_length: 35 | sequences.append([str(record.id), seq]) 36 | else: 37 | raise ValueError('Unsupported file format') 38 | 39 | return sequences 40 | 41 | 42 | 43 | class SeqDataset(Dataset): 44 | 45 | def __init__(self,sequences) -> None: 46 | super().__init__() 47 | 48 | self.sequences = sequences 49 | 50 | def __len__(self): 51 | return len(self.sequences) 52 | 53 | def __getitem__(self, index): 54 | 55 | return len(self.sequences[index][1]), self.sequences[index][0], self.sequences[index][1] 56 | 57 | 58 | def main(seq_file, 59 | save_path, 60 | device, 61 | max_length=8000, 62 | min_length=1, 63 | batch_size=1,): 64 | 65 | 66 | 67 | # check if the device is valid 68 | if device not in ['cuda', 'cpu']: 69 | logger.error("Invalid device specified. 
Use 'cuda' or 'cpu'.") 70 | sys.exit(1) 71 | device = torch.device(device) 72 | logger.info(f"Using device: {device}") 73 | 74 | sequences = load_sequences(seq_file,max_length=max_length, 75 | min_length=min_length) 76 | # check if the sequences are empty 77 | if len(sequences) == 0: 78 | logger.error("No sequences found in the input file.") 79 | sys.exit(1) 80 | 81 | logger.info(f"Loaded {len(sequences)} sequences from {seq_file}") 82 | 83 | 84 | 85 | 86 | # Load the tokenizer 87 | if not os.path.exists("temp/seq_models"): 88 | os.makedirs("temp/seq_models", exist_ok=True) 89 | tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', 90 | do_lower_case=False,cache_dir="temp/seq_models", 91 | ) 92 | 93 | # Load the model 94 | model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc",cache_dir="temp/seq_models").to(device) 95 | 96 | # only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower) 97 | model = model.to(torch.float16) if device!=torch.device("cpu") else model 98 | 99 | 100 | seq_loader = DataLoader(SeqDataset(sequences),batch_size=batch_size,shuffle=True) 101 | 102 | output_seq_names = [] 103 | output_seq_embeddings = [] 104 | 105 | for seq_len,seq_name,seq in tqdm(seq_loader, desc="Calculating embeddings"): 106 | ## sort the sequences by length 107 | seq_len,sort_index = seq_len.sort(descending=False) 108 | seq_name = [seq_name[i] for i in sort_index] 109 | seq = [seq[i] for i in sort_index] 110 | 111 | seq = [" ".join(list(re.sub(r"[UZOB]", "X", seq_))) for seq_ in seq] 112 | 113 | ids = tokenizer(seq, add_special_tokens=True, padding="longest") 114 | input_ids = torch.tensor(ids['input_ids']).to(device) 115 | attention_mask = torch.tensor(ids['attention_mask']).to(device) 116 | 117 | # logger.info(f"Processing {seq_name}") 118 | # generate embeddings 119 | with torch.no_grad(): 120 | embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask) 121 | 122 | for i in range(len(seq_len)): 123 | emb = embedding_repr.last_hidden_state[i][:seq_len[i]].mean(dim=0) 124 | # output.append([seq_name[i],emb]) 125 | output_seq_names.append(seq_name[i]) 126 | output_seq_embeddings.append(emb.to('cpu').numpy()) 127 | # logger.info(f"Processing finished {seq_name}") 128 | 129 | # H5pyData.write(output_seq_names,output_seq_embeddings,'temp/seq_embeddings.h5') 130 | if len(output_seq_names) > 0: 131 | save_dir = os.path.dirname(save_path) 132 | if not os.path.exists(save_dir): 133 | os.makedirs(save_dir, exist_ok=True) 134 | 135 | H5pyData.write(output_seq_names, output_seq_embeddings, 136 | save_path,16) 137 | logger.info(f"Embeddings saved to {save_path}") 138 | else: 139 | logger.warning("No sequences processed, no embeddings saved.") 140 | logger.warning("Check if the input file is empty or all sequences are too long.") 141 | 142 | return None 143 | 144 | if __name__ == '__main__': 145 | 146 | args = argparse.ArgumentParser(description="Generate protein embeddings using ProtT5") 147 | args.add_argument('--seq_file', type=str, required=True, help='Path to the input sequence file (FASTA format)') 148 | args.add_argument('--save_path', type=str, required=True, help='Path to save the output embeddings (HDF5 format)') 149 | args.add_argument('--max_length', type=int, default=8000, help='Maximum sequence length to process (default: 8000). ') 150 | args.add_argument('--min_length', type=int, default=1, help='Minimum sequence length to process (default: 1). 
Sequences shorter than this will be ignored.') 151 | args.add_argument('--device', type=str, default='cuda', help='Device to run the model on (default: cuda)') 152 | args.add_argument('--batch_size', type=int, default=1, help='Batch size for processing sequences (default: 1)') 153 | args = args.parse_args() 154 | 155 | main(seq_file=args.seq_file, 156 | save_path=args.save_path, 157 | device=args.device, 158 | max_length=args.max_length, 159 | batch_size=1) # Set batch_size to 1 for simplicity -------------------------------------------------------------------------------- /src/space/tools/data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This module contains the data structures for the space module. 3 | ''' 4 | import numpy as np 5 | import h5py 6 | from typing import List, Tuple, Iterable, Union 7 | import gzip 8 | import csv 9 | import os 10 | from multiprocessing import Pool 11 | import itertools 12 | from loguru import logger 13 | 14 | class H5pyData: 15 | 16 | @staticmethod 17 | def write(proteins: Union[np.ndarray, List, Tuple], 18 | embedding: np.ndarray, 19 | save_path: str, 20 | precision: int, 21 | chunk_size: int = 10000) -> None: 22 | ''' 23 | Write proteins and embeddings to HDF5 file efficiently. 24 | 25 | Args: 26 | proteins: Array-like of protein identifiers 27 | embedding: numpy array of embeddings (n_proteins x embedding_dim) 28 | save_path: Path to save the HDF5 file 29 | precision: Precision of embeddings (16 or 32) 30 | chunk_size: Size of chunks for HDF5 dataset 31 | ''' 32 | # Convert proteins to numpy array if needed 33 | proteins = np.array(proteins).astype('U').reshape(-1) 34 | embedding = np.array(embedding) 35 | 36 | # Validate inputs 37 | if len(proteins) != len(embedding): 38 | raise ValueError(f"Number of proteins ({len(proteins)}) doesn't match number of embeddings ({len(embedding)})") 39 | 40 | if precision not in [16, 32]: 41 | raise ValueError(f"Precision must be 16 or 32, got {precision}") 42 | 43 | # Determine dtype for embeddings 44 | dtype = np.float16 if precision == 16 else np.float32 45 | embedding = embedding.astype(dtype) 46 | 47 | n_proteins = len(proteins) 48 | embedding_dim = embedding.shape[1] 49 | 50 | with h5py.File(save_path, 'w') as f: 51 | # Create groups to organize data 52 | f.create_group('metadata') 53 | 54 | # Store metadata 55 | f['metadata'].attrs['n_proteins'] = n_proteins 56 | f['metadata'].attrs['embedding_dim'] = embedding_dim 57 | f['metadata'].attrs['precision'] = precision 58 | 59 | # Create datasets with chunking and compression 60 | protein_ds = f.create_dataset( 61 | 'proteins', 62 | shape=(n_proteins,), 63 | dtype=h5py.string_dtype(), 64 | chunks=(min(chunk_size, n_proteins),), 65 | compression='gzip', 66 | compression_opts=4 67 | ) 68 | 69 | embedding_ds = f.create_dataset( 70 | 'embeddings', 71 | shape=(n_proteins, embedding_dim), 72 | dtype=dtype, 73 | chunks=(min(chunk_size, n_proteins), embedding_dim), 74 | compression='gzip', 75 | compression_opts=4 76 | ) 77 | 78 | # Write data in chunks for better memory management 79 | for i in range(0, n_proteins, chunk_size): 80 | end_idx = min(i + chunk_size, n_proteins) 81 | 82 | protein_chunk = proteins[i:end_idx] 83 | embedding_chunk = embedding[i:end_idx] 84 | 85 | protein_ds[i:end_idx] = protein_chunk 86 | embedding_ds[i:end_idx] = embedding_chunk 87 | 88 | @staticmethod 89 | def read(file_path: str,precision: int) -> tuple[np.ndarray, np.ndarray]: 90 | ''' 91 | Read proteins and embeddings from HDF5 file. 
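        A minimal round-trip sketch (the file name and protein identifiers are
        made up for illustration only):

            >>> import numpy as np
            >>> ids = ['9606.ENSP00000000001', '9606.ENSP00000000002']
            >>> emb = np.random.rand(2, 128).astype(np.float32)
            >>> H5pyData.write(ids, emb, 'example_embeddings.h5', precision=32)
            >>> proteins, embeddings = H5pyData.read('example_embeddings.h5', precision=32)
            >>> embeddings.shape
            (2, 128)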
92 | 93 | Args: 94 | file_path: Path to HDF5 file 95 | precision: Precision of embeddings (16 or 32) 96 | 97 | Returns: 98 | tuple: (proteins array, embeddings array) 99 | ''' 100 | with h5py.File(file_path, 'r') as f: 101 | proteins = f['proteins'][:] 102 | proteins = np.vectorize(lambda x: str(x)[2:-1])(proteins) 103 | embeddings = f['embeddings'][:] 104 | if precision == 16: 105 | embeddings = embeddings.astype(np.float16) 106 | elif precision == 32: 107 | embeddings = embeddings.astype(np.float32) 108 | 109 | return proteins, embeddings 110 | 111 | 112 | 113 | 114 | 115 | def query_single_species(querys, precision, aligned_path): 116 | 117 | # logger.info(f'Loading {aligned_path}...') 118 | proteins,embeddings = H5pyData.read(aligned_path,precision) 119 | 120 | protein2index = {protein: index for index, protein in enumerate(proteins)} 121 | 122 | indices = [] 123 | for query in querys: 124 | try: 125 | index = protein2index[query] 126 | indices.append( index) 127 | except NameError as e: 128 | logger.info(f'Query {query} not found. Error: {e}') 129 | 130 | ## get the embeddings 131 | query_proteins = [proteins[index] for index in indices] 132 | query_embeddings = [embeddings[index] for index in indices] 133 | 134 | return query_proteins, query_embeddings 135 | 136 | def query_embedding(querys:Iterable, query_dir:str,precision:int, n_jobs:int) -> Tuple[np.ndarray,np.ndarray]: 137 | ''' 138 | Query the embeddings for the given proteins in multiple species. 139 | Args: 140 | querys: Iterable, the proteins to query. 141 | aligned_dir: str, the directory of the aligned embeddings. 142 | precision: int, the precision of the embeddings, either 32 or 16. 143 | n_jobs: int, the number of jobs to run in parallel. 144 | Returns: 145 | output_proteins: np.array, the proteins that are found. 146 | output_embeddings: np.array, the embeddings of the proteins. 
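    Example:
        An illustrative call; the directory and identifiers are placeholders,
        and the only assumption is one '<taxid>.h5' file per species inside
        query_dir, as written by H5pyData.write:

        >>> proteins, embeddings = query_embedding(
        ...     ['9606.ENSP00000000001', '10090.ENSMUSP00000000001'],
        ...     query_dir='data/functional_emb',
        ...     precision=16,
        ...     n_jobs=2)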
147 | 148 | ''' 149 | 150 | ## check if the directory of query_dir exists 151 | if not os.path.exists(query_dir): 152 | raise FileNotFoundError(f'{query_dir} does not exist') 153 | 154 | if precision not in [32,16]: 155 | raise ValueError('Precision should be either 32 or 16') 156 | 157 | ## put the quries in a dict where the key is taxid and the value is the list of querys 158 | query_dict = {} 159 | 160 | for query in querys: 161 | taxid = query.split('.')[0] 162 | if taxid not in query_dict: 163 | query_dict[taxid] = [query] 164 | else: 165 | query_dict[taxid].append(query) 166 | 167 | ## for each taxid get the data and the querys 168 | with Pool(n_jobs) as p: 169 | results = p.starmap(query_single_species, [(query_dict[taxid],precision, f'{query_dir}/{taxid}.h5') for taxid in query_dict]) 170 | 171 | output_proteins = list() 172 | output_embeddings = list() 173 | 174 | for result in results: 175 | output_proteins.append(result[0]) 176 | output_embeddings.append(result[1]) 177 | 178 | output_proteins = np.array(list(itertools.chain.from_iterable(output_proteins))) 179 | output_embeddings = np.array(list(itertools.chain.from_iterable(output_embeddings))) 180 | 181 | return output_proteins, output_embeddings 182 | 183 | 184 | ## a class to handle the gz file 185 | 186 | class GzipData: 187 | @staticmethod 188 | def string2idx(file_path:str,temp_path)->dict: 189 | 190 | nodes = dict() 191 | 192 | ## check if the directory of temp_path exists 193 | if not os.path.exists(os.path.dirname(temp_path)): 194 | os.makedirs(os.path.dirname(temp_path)) 195 | 196 | edges_writer = csv.writer(open(temp_path, 'w'), delimiter='\t') 197 | 198 | with gzip.open(file_path, 'rt') as f: 199 | 200 | reader = csv.reader(f, delimiter=' ') 201 | 202 | next(reader) # skip the header 203 | 204 | for row in reader: 205 | 206 | if row[0] not in nodes: 207 | nodes[row[0]] = len(nodes) 208 | if row[1] not in nodes: 209 | nodes[row[1]] = len(nodes) 210 | 211 | ## get the index of the protein 212 | src_idx = nodes[row[0]] 213 | dst_idx = nodes[row[1]] 214 | weight = int(row[-1])/1000 215 | 216 | edges_writer.writerow([src_idx,dst_idx,weight]) 217 | 218 | return nodes 219 | 220 | @staticmethod 221 | def read_nodes(file_path:str)->dict: 222 | nodes = dict() 223 | with gzip.open(file_path, 'rt') as f: 224 | reader = csv.reader(f, delimiter=' ') 225 | 226 | next(reader) # skip the header 227 | 228 | for row in reader: 229 | if row[0] not in nodes: 230 | nodes[row[0]] = len(nodes) 231 | if row[1] not in nodes: 232 | nodes[row[1]] = len(nodes) 233 | return nodes -------------------------------------------------------------------------------- /src/space/models/node2vec.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numba 3 | from gensim.models import Word2Vec 4 | from gensim.models.callbacks import CallbackAny2Vec 5 | from pecanpy.cli import pecanpy 6 | from loguru import logger 7 | from space.tools.data import H5pyData, GzipData 8 | 9 | 10 | 11 | class callback(CallbackAny2Vec): 12 | '''Callback to print loss after each epoch.''' 13 | 14 | def __init__(self,total_epoch): 15 | self.epoch = 1 16 | self.loss_to_be_subed = 0 17 | self.total_epoch = total_epoch 18 | self.saved_loss = list() 19 | 20 | def on_epoch_end(self, model): 21 | loss = model.get_latest_training_loss() 22 | loss_now = loss - self.loss_to_be_subed 23 | self.loss_to_be_subed = loss 24 | self.saved_loss.append([self.epoch,loss_now]) 25 | self.epoch += 1 26 | 27 | 28 | 29 | def 
learn_embeddings(walks,epochs=1,dimensions=64,window_size=10,workers=1,random_state=1234,taxid="",writer=None,hs=0): 30 | """_summary_ 31 | 32 | Args: 33 | epochs (int, optional): _description_. Defaults to 1. 34 | walks (int, optional): _description_. Defaults to 20. 35 | dimensions (int, optional): _description_. Defaults to 64. 36 | window_size (int, optional): _description_. Defaults to 10. 37 | workers (int, optional): _description_. Defaults to 1. 38 | random_state (int, optional): _description_. Defaults to 1234. 39 | output_path (str, optional): _description_. Defaults to ''. 40 | 41 | Returns: 42 | _type_: _description_ 43 | """ 44 | # cb = callback(epochs,taxid,writer) ## de-comment this line if you want to save the loss 45 | 46 | model = Word2Vec( 47 | walks, 48 | vector_size=dimensions, 49 | window=window_size, 50 | min_count=0, 51 | sg=1, 52 | hs=hs, 53 | workers=workers, 54 | epochs=epochs, 55 | seed=random_state, 56 | compute_loss=True, 57 | # callbacks=[cb] ## de-comment this line if you want to save the loss 58 | ) 59 | 60 | return model 61 | 62 | 63 | 64 | 65 | 66 | def check_mode(g, mode,weighted,p,q): 67 | """Check mode selection. 68 | 69 | Give recommendation to user for pecanpy mode based on graph size and density. 70 | 71 | """ 72 | # mode = args.mode 73 | # weighted = args.weighted 74 | # p = args.p 75 | # q = args.q 76 | 77 | # Check unweighted first order random walk usage 78 | if mode == "FirstOrderUnweighted": 79 | if not p == q == 1 or weighted: 80 | raise ValueError( 81 | f"FirstOrderUnweighted only works when weighted = False and " 82 | f"p = q = 1, got {weighted=}, {p=}, {q=}", 83 | ) 84 | return 85 | 86 | if mode != "FirstOrderUnweighted" and p == q == 1 and not weighted: 87 | warnings.warn( 88 | "When p = 1 and q = 1 with unweighted graph, it is highly " 89 | f"recommended to use FirstOrderUnweighted over {mode} (current " 90 | "selection). The runtime could be improved greatly with improved " 91 | "memory usage.", 92 | ) 93 | return 94 | 95 | # Check first order random walk usage 96 | if mode == "PreCompFirstOrder": 97 | if not p == q == 1: 98 | raise ValueError( 99 | f"PreCompFirstOrder only works when p = q = 1, got {p=}, {q=}", 100 | ) 101 | return 102 | 103 | if mode != "PreCompFirstOrder" and p == 1 == q: 104 | warnings.warn( 105 | "When p = 1 and q = 1, it is highly recommended to use " 106 | f"PreCompFirstOrder over {mode} (current selection). 
The runtime " 107 | "could be improved greatly with low memory usage.", 108 | ) 109 | return 110 | 111 | # Check network density and recommend appropriate mode 112 | g_size = g.num_nodes 113 | g_dens = g.density 114 | if (g_dens >= 0.2) & (mode != "DenseOTF"): 115 | warnings.warn( 116 | f"Network density = {g_dens:.3f} (> 0.2), it is recommended to use " 117 | f"DenseOTF over {mode} (current selection)", 118 | ) 119 | if (g_dens < 0.001) & (g_size < 10000) & (mode != "PreComp"): 120 | warnings.warn( 121 | f"Network density = {g_dens:.2e} (< 0.001) with {g_size} nodes " 122 | f"(< 10000), it is recommended to use PreComp over {mode} (current " 123 | "selection)", 124 | ) 125 | if (g_dens >= 0.001) & (g_dens < 0.2) & (mode != "SparseOTF"): 126 | warnings.warn( 127 | f"Network density = {g_dens:.3f}, it is recommended to use " 128 | f"SparseOTF over {mode} (current selection)", 129 | ) 130 | if (g_dens < 0.001) & (g_size >= 10000) & (mode != "SparseOTF"): 131 | warnings.warn( 132 | f"Network density = {g_dens:.3f} (< 0.001) with {g_size} nodes " 133 | f"(>= 10000), it is recommended to use SparseOTF over {mode} " 134 | "(current selection)", 135 | ) 136 | 137 | 138 | 139 | def read_graph(path,p,q,workers,verbose,weighted,directed,extend,gamma,random_state,mode,delimiter,implicit_ids): 140 | """Read input network to memory. 141 | 142 | Depending on the mode selected, reads the network either in CSR 143 | representation (``PreComp`` and ``SparseOTF``) or 2d numpy array 144 | (``DenseOTF``). 145 | 146 | """ 147 | 148 | if directed and extend: 149 | raise NotImplementedError("Node2vec+ not implemented for directed graph yet.") 150 | 151 | if extend and not weighted: 152 | print("NOTE: node2vec+ is equivalent to node2vec for unweighted graphs.") 153 | 154 | # if task in ["tocsr", "todense"]: # perform conversion then save and exit 155 | # g = graph.SparseGraph() if task == "tocsr" else graph.DenseGraph() 156 | # g.read_edg(path, weighted, directed, delimiter) 157 | # g.save(output) 158 | # exit() 159 | 160 | pecanpy_mode = getattr(pecanpy, mode, None) 161 | g = pecanpy_mode(p, q, workers, verbose, extend, gamma, random_state) 162 | 163 | if path.endswith(".npz"): 164 | g.read_npz(path, weighted, implicit_ids=implicit_ids) 165 | else: 166 | g.read_edg(path, weighted, directed, delimiter) 167 | 168 | check_mode(g, mode,weighted,p,q) 169 | 170 | return g 171 | 172 | 173 | def preprocess(g): 174 | """Preprocessing transition probabilities with timer.""" 175 | g.preprocess_transition_probs() 176 | 177 | 178 | def simulate_walks(num_walks, walk_length, g): 179 | """Simulate random walks with timer.""" 180 | return g.simulate_walks(num_walks, walk_length) 181 | 182 | 183 | class PecanpyEmbedder(): 184 | 185 | def __init__(self,graph_path,p=1,q=1,workers=-1, 186 | weighted=True,directed=False, 187 | extend=False,gamma=0,random_state=1234, 188 | delimiter:str='\t'): 189 | super().__init__() 190 | 191 | if workers == -1: 192 | workers = numba.config.NUMBA_DEFAULT_NUM_THREADS 193 | 194 | ## load the graph 195 | self.graph = read_graph(graph_path,p,q,workers,verbose=False, 196 | weighted=weighted,directed=directed, 197 | extend=extend,gamma=gamma, 198 | random_state=random_state, 199 | mode="SparseOTF",delimiter=delimiter, 200 | implicit_ids=False) 201 | preprocess(self.graph) 202 | 203 | 204 | def generate_walks(self,num_walks:int,walk_length:int) -> list: 205 | return simulate_walks(num_walks,walk_length,self.graph) 206 | 207 | 208 | 209 | 210 | def learn_embeddings(self, walks, epochs=1,dimensions=128, 211 
| window_size=5,workers=-1, 212 | negative=5, 213 | hs=0,sg=1, 214 | random_state=1234) -> Word2Vec: 215 | """ 216 | Word2Vec API of gensim 217 | 218 | Parameters 219 | ---------- 220 | walks : list, list of walks 221 | epochs : int, number of epochs, default 1. 222 | dimensions : int, number of dimensions, default 128. 223 | window_size : int, window size, default 5. 224 | workers : int, number of workers, default -1 (all workers). 225 | negative : int, if >0, negative sampling will be used, number of negative samples if use negative sampling, default 5. 226 | hs : int, if 1, use hierarchical softmax; if 0, use negative sampling, default 0. 227 | sg : int, if 1, use skip-gram; if 0, use CBOW, default 1. 228 | random_state : int, random state, default 1234. 229 | """ 230 | 231 | cb = callback(epochs) 232 | 233 | if workers == -1: 234 | workers = numba.config.NUMBA_DEFAULT_NUM_THREADS 235 | 236 | model = Word2Vec( 237 | walks, 238 | vector_size=dimensions, 239 | window=window_size, 240 | min_count=0, 241 | sg=sg, 242 | hs=hs, 243 | workers=workers, 244 | epochs=epochs, 245 | seed=random_state, 246 | compute_loss=True, 247 | callbacks=[cb], 248 | negative=negative 249 | ) 250 | 251 | return model 252 | 253 | 254 | 255 | def run_single_embedding(species_file, temp_path, output_folder, dimensions, 256 | p, q, num_walks, walk_length, window_size, sg, 257 | negative, epochs, workers, random_state): 258 | """Run single species-specific embedding.""" 259 | ## process the gz file 260 | logger.info(f"Processing {species_file}...") 261 | nodes = GzipData.string2idx(species_file,temp_path) 262 | 263 | 264 | if len(nodes) > 50000: 265 | logger.warning(f"Number of nodes in {species_file} is {len(nodes)}, if it fails, try larger memory") 266 | 267 | logger.info(f"Running embedding for {species_file}...") 268 | # Read the graph 269 | embedder = PecanpyEmbedder(temp_path,p=p,q=q,workers=workers,weighted=True,directed=False, 270 | extend=False,gamma=0,random_state=random_state,delimiter='\t') 271 | # Generate the walks 272 | 273 | embeddings = embedder.learn_embeddings(embedder.generate_walks(num_walks=num_walks,walk_length=walk_length), 274 | epochs=epochs,dimensions=dimensions,window_size=window_size, 275 | workers=workers,negative=negative,hs=0,sg=sg,random_state=random_state) 276 | 277 | emb = embeddings.wv.vectors 278 | index = embeddings.wv.index_to_key 279 | 280 | proteins = list(nodes.keys()) 281 | 282 | ## map the index to the protein 283 | map_proteins = [proteins[int(i)] for i in index] 284 | 285 | ## save the embeddings 286 | species = species_file.split('/')[-1].split('.')[0] 287 | save_path = f"{output_folder}/{species}.h5" 288 | H5pyData.write(map_proteins,emb,save_path,32) 289 | logger.info(f"Embedding for {species_file} is saved at {save_path}") 290 | 291 | return None -------------------------------------------------------------------------------- /scripts/distances/kegg_single.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import pandas as pd 3 | from space.tools.data import H5pyData 4 | import numpy as np 5 | import csv 6 | from typing import Tuple 7 | import sys 8 | from sklearn.metrics import auc 9 | import itertools 10 | 11 | class BaseRoc: 12 | """Base class for ROC plot.""" 13 | 14 | def __init__(self): 15 | pass 16 | 17 | 18 | def load_protein_cluster(self) -> dict: 19 | """ 20 | Load protein cluster from benchmark dataset. 
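        Subclasses should return a mapping from a protein identifier to the set
        of cluster identifiers it belongs to, e.g. (values are illustrative only):

            {'9606.ENSP00000000001': {'hsa00010', 'hsa01100'},
             '9606.ENSP00000000002': {'hsa04010'}}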
21 | """ 22 | raise NotImplementedError("load_protein_cluster() is not implemented.") 23 | 24 | def get_tp_fp(self,sorted_obj, 25 | protein_cluster, 26 | max_fp=-1, 27 | weight_threshold=-1, 28 | debug=False) -> Tuple[list, list]: 29 | """ 30 | Return true positive and false positive of a sorted object. 31 | """ 32 | try: 33 | iter(sorted_obj) 34 | link = sorted_obj 35 | except TypeError: 36 | raise TypeError("sorted_obj must be either a list, numpy array or an iterable object.") 37 | 38 | tp_cumu = 0 39 | fp_cumu = 0 40 | tp_cumu_list = [] 41 | fp_cumu_list = [] 42 | 43 | if max_fp == -1: 44 | max_fp = len(link) 45 | 46 | protein_cluster_keys = set(protein_cluster.keys()) 47 | 48 | for row in link: 49 | if row[2] < weight_threshold*1000: 50 | if debug: 51 | import pdb; pdb.set_trace() 52 | break 53 | if row[0] in protein_cluster_keys and row[1] in protein_cluster_keys: 54 | intersection = protein_cluster[row[0]].intersection(protein_cluster[row[1]]) 55 | 56 | if intersection: 57 | tp_cumu += 1 58 | else: 59 | fp_cumu += 1 60 | if fp_cumu > max_fp: 61 | break 62 | 63 | tp_cumu_list.append(tp_cumu) 64 | fp_cumu_list.append(fp_cumu) 65 | if fp_cumu > max_fp: 66 | break 67 | return tp_cumu_list, fp_cumu_list 68 | 69 | 70 | class KEGGRoc(BaseRoc): 71 | """Class for plotting ROC curve for KEGG.""" 72 | 73 | def __init__(self): 74 | pass 75 | 76 | def load_protein_cluster(self, benchmark_dataset,taxid) -> dict: 77 | 78 | prot_family = {} 79 | 80 | with open(benchmark_dataset, 'r') as f: 81 | 82 | reader = csv.reader(f, delimiter='\t') 83 | 84 | for row in reader: 85 | if int(row[0]) == int(taxid): 86 | 87 | if row[1] == 'pfa05144' and int(taxid) == 36329: 88 | continue 89 | elif row[1] in ['tbr05143','tbr00230'] and int(taxid) == 185431: 90 | continue 91 | 92 | for p in row[-1].split(' '): 93 | p = str(taxid)+'.'+p 94 | 95 | if p not in prot_family: 96 | prot_family[p] = set([row[1]]) 97 | else: 98 | prot_family[p].add(row[1]) 99 | else: 100 | pass 101 | self.prot_family = prot_family 102 | 103 | return prot_family 104 | 105 | def normalize_vectors(vectors): 106 | """Normalize the vectors to unit length.""" 107 | norm = np.linalg.norm(vectors, axis=1, keepdims=True) 108 | return vectors / np.maximum(norm, 1e-9) 109 | 110 | def get_embedding_rank_cos(embed:str|pd.DataFrame,proteins): 111 | """ 112 | Generate tp fp for node2vec embeddings, with cosine distance 113 | """ 114 | 115 | if isinstance(embed,str): 116 | _ , data = H5pyData.read(embed,16) 117 | data = data[proteins] 118 | 119 | elif isinstance(embed,pd.DataFrame): 120 | data = embed.values[proteins] 121 | else: 122 | raise ValueError('embed must be either str or pd.DataFrame') 123 | 124 | data = normalize_vectors(data) 125 | 126 | N, D = data.shape # Number of samples and dimension 127 | 128 | # Step 1: Load data into FAISS 129 | index = faiss.IndexFlatIP(D) # Create a flat (brute force) index 130 | index.add(data) # Add vectors to the index 131 | 132 | # Step 2: Compute pairwise distances 133 | D, I = index.search(data, N + 1) # Search all vectors (including itself) 134 | 135 | num_pairs = (N * (N - 1)) // 2 136 | results = np.empty((num_pairs, 3), dtype=object) 137 | 138 | count = 0 139 | 140 | for i in range(N): 141 | 142 | # Get indices where the condition is met 143 | valid_indices = I[i] > i 144 | 145 | # Number of valid pairs for this row 146 | num_valid = np.sum(valid_indices) 147 | 148 | # Store the results 149 | results[count:count + num_valid, 0] = proteins[i] 150 | results[count:count + num_valid, 1] = proteins[I[i, 
valid_indices]] 151 | results[count:count + num_valid, 2] = D[i, valid_indices] 152 | 153 | count += num_valid 154 | 155 | results = list(results) 156 | 157 | results.sort(key=lambda x: x[2],reverse=True) # Sort by cosine distance, descending 158 | 159 | # cum_tp, cum_fp = keggroc.get_tp_fp(results,kegg_clusters,max_fp) 160 | 161 | return results 162 | 163 | def get_auc(cumu_tp,cumu_fp): 164 | 165 | # Normalize the cumulative counts to get rates 166 | max_fp = cumu_fp[-1] # Maximum FP count 167 | max_tp = cumu_tp[-1] # Maximum TP count 168 | 169 | fp_rates = np.array(cumu_fp) / max_fp 170 | tp_rates = np.array(cumu_tp) / max_tp 171 | 172 | # Ensure the data is sorted by the normalized FP rates 173 | sorted_indices = np.argsort(fp_rates) 174 | fp_rates_sorted = fp_rates[sorted_indices] 175 | tp_rates_sorted = tp_rates[sorted_indices] 176 | 177 | # Calculate AUC using the trapezoidal rule 178 | auc = np.trapz(tp_rates_sorted, fp_rates_sorted) 179 | 180 | # print(f"Area Under the Curve (AUC): {auc}") 181 | 182 | return auc 183 | 184 | def calculate_auc_with_artificial_cap(cumulative_tp, cumulative_fp, max_tp,max_fp): 185 | """ 186 | Calculate AUC by setting an artificial maximum number of false positives. 187 | 188 | Parameters: 189 | - cumulative_tp: list of cumulative true positives at each threshold 190 | - cumulative_fp: list of cumulative false positives at each threshold 191 | - max_fp_cap: artificial maximum false positives (default: 100000) 192 | 193 | Returns: 194 | - auc_score: Area under the ROC curve 195 | - fpr: False positive rates 196 | - tpr: True positive rates 197 | """ 198 | 199 | # Get the maximum TP from your data (total positives found) 200 | # max_tp = max(cumulative_tp) 201 | 202 | # Use the artificial cap as total negatives 203 | # total_negatives = max_fp_cap 204 | # total_positives = max_tp # This assumes you found all positives in your ranking 205 | 206 | # Convert to rates 207 | tpr = [tp / max_tp for tp in cumulative_tp] 208 | fpr = [fp / max_fp for fp in cumulative_fp] 209 | 210 | # Add (0,0) point at the beginning 211 | tpr = [0] + tpr 212 | fpr = [0] + fpr 213 | 214 | # Calculate AUC using trapezoidal rule 215 | auc_score = auc(fpr, tpr) 216 | 217 | return auc_score, fpr, tpr 218 | 219 | 220 | def get_max_tp_fp(kegg_proteins,kegg_clusters): 221 | """ 222 | Get the maximum number of false positives for the given proteins and clusters. 
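    Both the maximum true-positive and the maximum false-positive pair counts
    are returned, taken over all unordered protein pairs. A small worked example
    with hypothetical clusters: if p1 and p2 share the map 'hsa00010' while p3
    belongs only to 'hsa04010', the three possible pairs give max_tp = 1 (p1-p2)
    and max_fp = 2 (p1-p3, p2-p3).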
223 | """ 224 | max_tp, max_fp = 0,0 225 | for p1,p2 in itertools.combinations(kegg_proteins, 2): 226 | if p1 in kegg_clusters and p2 in kegg_clusters: 227 | intersection = kegg_clusters[p1].intersection(kegg_clusters[p2]) 228 | if intersection: 229 | max_tp += 1 230 | else: 231 | max_fp += 1 232 | return max_tp, max_fp 233 | 234 | def run_single_species(taxid,node2vec_dir,t5_dir,aligned_dir,kegg_benchmarking_file,percent_threshold = 0.001): 235 | # print(f'Running {taxid}') 236 | node2vec_path = f'{node2vec_dir}/{taxid}.h5' 237 | 238 | t5_path = f'{t5_dir}/{taxid}.h5' 239 | 240 | aligned_path = f'{aligned_dir}/{taxid}.h5' 241 | 242 | node2ind = {k:v for v,k in enumerate(H5pyData.read(node2vec_path,16)[0])} 243 | 244 | keggroc = KEGGRoc() 245 | 246 | kegg_clusters = keggroc.load_protein_cluster(kegg_benchmarking_file,taxid) 247 | 248 | ## replace the protein names with the indices 249 | kegg_clusters = {node2ind[k]:v for k,v in kegg_clusters.items() if k in node2ind.keys() } 250 | 251 | kegg_proteins = np.array(list(kegg_clusters.keys())) 252 | 253 | # print('Running node2vec rank') 254 | node2vec_rank = get_embedding_rank_cos(node2vec_path,kegg_proteins) 255 | 256 | 257 | # max_fp = 100000 258 | # max_fp = get_max_fp(kegg_proteins,kegg_clusters) 259 | max_tp, max_fp = get_max_tp_fp(kegg_proteins,kegg_clusters) 260 | # use only the top 0.1% of the maximum false positives 261 | use_max_fp = int(max_fp * percent_threshold) # to reduce the calculation 262 | 263 | node2vec_tp, node2vec_fp = keggroc.get_tp_fp(node2vec_rank,kegg_clusters,use_max_fp) 264 | 265 | aligned2ind = {k:v for v,k in enumerate(H5pyData.read(aligned_path,16)[0])} 266 | kegg_clusters = keggroc.load_protein_cluster(kegg_benchmarking_file,taxid) 267 | kegg_clusters = {aligned2ind[k]:v for k,v in kegg_clusters.items() if k in aligned2ind.keys() } 268 | kegg_proteins = np.array(list(kegg_clusters.keys())) 269 | # print('Running aligned rank') 270 | aligned_rank = get_embedding_rank_cos(aligned_path,kegg_proteins,) 271 | aligned_tp, aligned_fp = keggroc.get_tp_fp(aligned_rank,kegg_clusters,use_max_fp,debug=False) 272 | 273 | t52ind = {k:v for v,k in enumerate(H5pyData.read(t5_path,16)[0])} 274 | kegg_clusters = keggroc.load_protein_cluster(kegg_benchmarking_file,taxid) 275 | kegg_clusters = {t52ind[k]:v for k,v in kegg_clusters.items() if k in t52ind.keys() } 276 | kegg_proteins = np.array(list(kegg_clusters.keys())) 277 | # print('Running t5 rank') 278 | seq_rank = get_embedding_rank_cos(t5_path,kegg_proteins) 279 | t5_tp,t5_fp = keggroc.get_tp_fp(seq_rank,kegg_clusters,use_max_fp) 280 | 281 | 282 | return node2vec_tp, node2vec_fp, aligned_tp, aligned_fp, t5_tp, t5_fp, max_tp, max_fp 283 | 284 | 285 | if __name__ == "__main__": 286 | 287 | species = int(sys.argv[1]) 288 | 289 | node2vec_embed = './data/node2vec' 290 | aligned_embed = './data/functional_emb' 291 | 292 | t5_embed = './data/t5_emb' 293 | 294 | benchmark = './kegg_benchmarking.CONN_maps_in.v12.tsv' 295 | 296 | 297 | node2vec_tp, node2vec_fp, aligned_tp, aligned_fp, \ 298 | t5_tp, t5_fp, max_tp, max_fp = run_single_species(species, 299 | node2vec_embed, 300 | t5_embed, 301 | aligned_embed, 302 | benchmark) 303 | 304 | node2vec_auc,fpr, tpr = calculate_auc_with_artificial_cap(node2vec_tp,node2vec_fp,max_fp=max_fp,max_tp=max_tp) 305 | aligned_auc,fpr, tpr = calculate_auc_with_artificial_cap(aligned_tp,aligned_fp,max_fp=max_fp,max_tp=max_tp) 306 | t5_auc,fpr, tpr = calculate_auc_with_artificial_cap(t5_tp,t5_fp,max_fp=max_fp,max_tp=max_tp) 307 | 308 | 309 | 
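    # The line below emits one tab-separated record per run:
    # taxid, node2vec AUC, aligned AUC, ProtT5 AUC, max TP pairs, max FP pairs.
    # A typical (hypothetical) way to collect a table over many species:
    #   python scripts/distances/kegg_single.py 9606 >> results/kegg_auc.tsv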
print(f"{species}\t{node2vec_auc}\t{aligned_auc}\t{t5_auc}\t{max_tp}\t{max_fp}") 310 | 311 | -------------------------------------------------------------------------------- /scripts/func_pred.py: -------------------------------------------------------------------------------- 1 | from space.tools.data import H5pyData 2 | import pandas as pd 3 | import os 4 | from loguru import logger 5 | from multiprocessing import Pool 6 | import numpy as np 7 | from sklearn.linear_model import LogisticRegression 8 | from cafaeval.evaluation import cafa_eval 9 | from sklearn import metrics 10 | import argparse 11 | 12 | def load_data(idmapping,dataset,train): 13 | 14 | idmapping = pd.read_csv(idmapping, sep='\t', ) 15 | 16 | labels = pd.read_csv(dataset, sep='\t', header=None) 17 | 18 | labels = labels[labels.iloc[:, 0].isin(idmapping['From'])] 19 | 20 | labels =labels.merge(idmapping, left_on=0, right_on='From', how='inner') 21 | 22 | labels = labels.drop(columns=[3, 'From']) 23 | 24 | labels.columns = ['uniprot', 'label', 'aspect', 'protein'] 25 | 26 | labels = labels.drop_duplicates().dropna() 27 | 28 | if train: 29 | labels = labels.groupby('label').filter(lambda x: len(x) >=10) 30 | 31 | return labels 32 | 33 | def load_embeddings_for_species(args): 34 | s, directory, ids_set = args 35 | file = f'{directory}/{s}.h5' 36 | 37 | if not os.path.exists(file): 38 | logger.error(f'{file} does not exist') 39 | return {} 40 | 41 | species_proteins, species_embeddings = H5pyData.read(file, precision=16) 42 | species_prot2idx = {p: i for i, p in enumerate(species_proteins)} 43 | 44 | # Filter species_proteins to include only those present in cv_ids 45 | return {p:species_embeddings[species_prot2idx[p]] for p in ids_set.intersection(species_prot2idx)} 46 | 47 | def load_embeddings_for_species_parallel(directory,protein_ids_set,species,n_jobs): 48 | pool_args = [(s, directory, protein_ids_set) for s in species] 49 | with Pool(n_jobs) as pool: 50 | results = pool.map(load_embeddings_for_species, pool_args) 51 | 52 | proteins = list() 53 | embeddings = list() 54 | 55 | for r in results: 56 | for p, e in r.items(): 57 | proteins.append(p) 58 | embeddings.append(e) 59 | 60 | return np.array(proteins), np.array(embeddings) 61 | 62 | 63 | def prepare_embeddings(train_idmapping, train_dataset, test_idmapping, 64 | test_dataset, aligned_dir, t5_dir, n_jobs): 65 | 66 | train_labels = load_data(train_idmapping, train_dataset, True) 67 | 68 | test_labels = load_data(test_idmapping, test_dataset, False) 69 | 70 | proteins = set(train_labels['protein']).union(set(test_labels['protein'])) 71 | 72 | species = list(set(map(lambda x: x.split('.')[0], proteins))) 73 | 74 | aligned_proteins, aligned_embeddings = load_embeddings_for_species_parallel(aligned_dir, 75 | proteins, 76 | species, 77 | n_jobs) 78 | 79 | seq_proteins, seq_embeddings = load_embeddings_for_species_parallel(t5_dir, 80 | proteins, 81 | species, 82 | n_jobs) 83 | # remove the rows with missing embeddings 84 | train_labels = train_labels[train_labels['protein'].isin(aligned_proteins)] 85 | test_labels = test_labels[test_labels['protein'].isin(aligned_proteins)] 86 | 87 | return train_labels, test_labels, \ 88 | dict(zip(aligned_proteins, aligned_embeddings)), \ 89 | dict(zip(seq_proteins, seq_embeddings)) 90 | 91 | def predict_single_label(train_labels,test_labels,label,aspect,seq_embeddings,aligned_embeddings,): 92 | 93 | train = train_labels[(train_labels['aspect']==aspect)] 94 | 95 | train_idx_string2uniprot = train[['protein','uniprot']].drop_duplicates() 96 
| train_idx_string2uniprot = dict(zip(train_idx_string2uniprot['protein'],train_idx_string2uniprot['uniprot'])) 97 | 98 | train_pos = train[train['label']==label]['protein'].unique() 99 | 100 | train_neg = train[~train['uniprot'].isin(train_pos)]['protein'].unique() 101 | 102 | Y_train = np.array([1]*len(train_pos) + [0]*len(train_neg)) 103 | 104 | ## prepare the embeddings array 105 | X_train_seq = np.array([seq_embeddings[protein] for protein in train_pos] + [seq_embeddings[protein] for protein in train_neg]) 106 | X_train_aligned = np.array([aligned_embeddings[protein] for protein in train_pos] + [aligned_embeddings[protein] for protein in train_neg]) 107 | 108 | 109 | ## test set 110 | test = test_labels[test_labels['aspect']==aspect] 111 | test_idx_string2uniprot = test[['protein','uniprot']].drop_duplicates() 112 | test_idx_string2uniprot = dict(zip(test_idx_string2uniprot['protein'],test_idx_string2uniprot['uniprot'])) 113 | 114 | test_proteins = test['protein'].unique() 115 | X_test_seq = np.array([seq_embeddings[protein] for protein in test_proteins]) 116 | X_test_aligned = np.array([aligned_embeddings[protein] for protein in test_proteins]) 117 | 118 | ## 1. seq 119 | clf = LogisticRegression(max_iter=1000).fit(X_train_seq, Y_train) 120 | y_pred_seq = clf.predict_proba(X_test_seq)[:,1] 121 | 122 | ## 2. aligned 123 | clf = LogisticRegression(max_iter=1000).fit(X_train_aligned, Y_train) 124 | y_pred_aligned = clf.predict_proba(X_test_aligned)[:,1] 125 | 126 | ## 3. seq concatenated with aligned 127 | clf = LogisticRegression(max_iter=1000).fit(np.concatenate([X_train_seq,X_train_aligned],axis=1), Y_train) 128 | y_pred_seq_concat_aligned = clf.predict_proba(np.concatenate([X_test_seq,X_test_aligned],axis=1))[:,1] 129 | 130 | df_seq = pd.DataFrame(list(zip([test_idx_string2uniprot[protein] for protein in test_proteins], [label]*len(test_proteins), y_pred_seq)), columns=['uniprot','label','prediction']) 131 | df_aligned = pd.DataFrame(list(zip([test_idx_string2uniprot[protein] for protein in test_proteins], [label]*len(test_proteins), y_pred_aligned)), columns=['uniprot','label','prediction']) 132 | df_seq_concat_aligned = pd.DataFrame(list(zip([test_idx_string2uniprot[protein] for protein in test_proteins], [label]*len(test_proteins), y_pred_seq_concat_aligned)), columns=['uniprot','label','prediction']) 133 | 134 | df_seq = df_seq[df_seq['prediction'] > 0.01] 135 | df_aligned = df_aligned[df_aligned['prediction'] > 0.01] 136 | df_seq_concat_aligned = df_seq_concat_aligned[df_seq_concat_aligned['prediction'] > 0.01] 137 | 138 | return df_seq, df_aligned, df_seq_concat_aligned 139 | 140 | def eval_single_modal(ontology, prediction_dir, ground_truth,save_name,n_jobs): 141 | res = cafa_eval(ontology, prediction_dir, ground_truth, n_cpu=n_jobs, th_step=0.001) 142 | 143 | res[0].to_csv(f'{save_name}_metrics.csv', index=False) 144 | 145 | fmax_row = res[0].sort_values(by='f',ascending=False).head(1) 146 | fmax,s = float(fmax_row['f'].values[0]), float(fmax_row['s'].values[0]) 147 | pr,rc = fmax_row['pr'].values[0], fmax_row['rc'].values[0] 148 | 149 | auprc = metrics.auc(res[0]['rc'], res[0]['pr']) 150 | 151 | record = [fmax, s, auprc, pr, rc] 152 | # round 153 | record = [round(x, 3) if isinstance(x, float) else x for x in record] 154 | 155 | return record 156 | 157 | 158 | 159 | 160 | 161 | if __name__ == "__main__": 162 | 163 | argparser = argparse.ArgumentParser(description='Run functional prediction using SPACE embeddings') 164 | argparser.add_argument('--train_idmapping', type=str, 
default='data/benchmarks/netgo/train_idmapping_euk.tsv', help='Path to train idmapping file') 165 | argparser.add_argument('--train_dataset', type=str, default='data/benchmarks/netgo/train.txt', help='Path to train dataset file') 166 | argparser.add_argument('--test_idmapping', type=str, default='data/benchmarks/netgo/test_idmapping_euk.tsv', help='Path to test idmapping file') 167 | argparser.add_argument('--test_dataset', type=str, default='data/benchmarks/netgo/test.txt', help='Path to test dataset file') 168 | argparser.add_argument('--aligned_dir', type=str, default='data/functional_emb', help='Path to aligned embeddings directory') 169 | argparser.add_argument('--t5_dir', type=str, default='data/t5_emb', help='Path to T5 embeddings directory') 170 | argparser.add_argument('--save_dir', type=str, default='results/func_pred', help='Directory to save results') 171 | argparser.add_argument('--ontology', type=str, default='data/benchmarks/netgo/go_2020_10_09.obo', help='Path to ontology file') 172 | argparser.add_argument('--ground_truth_dir', type=str, default='data/benchmarks/netgo', help='Path to ground truth directory') 173 | argparser.add_argument('--n_jobs', type=int, default=3, help='Number of jobs for parallel processing') 174 | args = argparser.parse_args() 175 | 176 | save_dir = args.save_dir 177 | train_idmapping = args.train_idmapping 178 | train_dataset = args.train_dataset 179 | test_idmapping = args.test_idmapping 180 | test_dataset = args.test_dataset 181 | aligned_dir = args.aligned_dir 182 | t5_dir = args.t5_dir 183 | ontology = args.ontology 184 | ground_truth_dir = args.ground_truth_dir 185 | n_jobs = args.n_jobs 186 | logger.info('Starting functional prediction with SPACE embeddings') 187 | logger.info(f'Using {n_jobs} jobs for parallel processing') 188 | 189 | if not os.path.exists(save_dir): 190 | os.makedirs(save_dir) 191 | 192 | logger.info('Preparing embeddings') 193 | train_labels, test_labels, \ 194 | seq_embeddings, aligned_embeddings = prepare_embeddings(train_idmapping,train_dataset, 195 | test_idmapping,test_dataset, 196 | aligned_dir,t5_dir,n_jobs) 197 | scores = list() 198 | for aspect in ['cc', 'bp', 'mf']: 199 | logger.info(f'Predicting {aspect}') 200 | aspect_labels = train_labels[train_labels['aspect'] == aspect]['label'].unique() 201 | 202 | logger.info(f'Aspect {aspect} has {len(aspect_labels)} labels') 203 | seq_predictions = [] 204 | aligned_predictions = [] 205 | seq_concat_aligned_predictions = [] 206 | 207 | with Pool(n_jobs) as pool: 208 | results = pool.starmap(predict_single_label, [(train_labels,test_labels, 209 | label,aspect,seq_embeddings, 210 | aligned_embeddings) for label in aspect_labels]) 211 | 212 | for df_seq, df_aligned, df_seq_concat_aligned in results: 213 | seq_predictions.append(df_seq) 214 | aligned_predictions.append(df_aligned) 215 | seq_concat_aligned_predictions.append(df_seq_concat_aligned) 216 | 217 | os.system(f'mkdir -p {save_dir}/{aspect}_seq_pred') 218 | os.system(f'mkdir -p {save_dir}/{aspect}_aligned_pred') 219 | os.system(f'mkdir -p {save_dir}/{aspect}_space_pred') 220 | 221 | pd.concat(seq_predictions).to_csv(f'{save_dir}/{aspect}_seq_pred/{aspect}_seq_pred.tsv', header=False,index=False,sep='\t') 222 | pd.concat(aligned_predictions).to_csv(f'{save_dir}/{aspect}_aligned_pred/{aspect}_aligned_pred.tsv', index=False,header=False,sep='\t') 223 | pd.concat(seq_concat_aligned_predictions).to_csv(f'{save_dir}/{aspect}_space_pred/{aspect}_space_pred.tsv', index=False,header=False,sep='\t') 224 | 225 | ## evaluate the 
predictions with cafa-eval 226 | group_truth = f'{ground_truth_dir}/test_{aspect}_ground_truth.txt' 227 | 228 | # eval 229 | logger.info('Evaluating predictions') 230 | seq_record = eval_single_modal(ontology, f'{save_dir}/{aspect}_seq_pred', group_truth, f'{save_dir}/{aspect}_seq_eval', n_jobs) 231 | aligned_record = eval_single_modal(ontology, f'{save_dir}/{aspect}_aligned_pred', group_truth, f'{save_dir}/{aspect}_aligned_eval', n_jobs) 232 | seq_concat_aligned_record = eval_single_modal(ontology, f'{save_dir}/{aspect}_space_pred', group_truth, f'{save_dir}/{aspect}_space_eval', n_jobs) 233 | 234 | seq_record = [aspect,'seq'] + seq_record 235 | aligned_record = [aspect,'aligned'] + aligned_record 236 | seq_concat_aligned_record = [aspect,'space'] + seq_concat_aligned_record 237 | 238 | scores.append(seq_record) 239 | scores.append(aligned_record) 240 | scores.append(seq_concat_aligned_record) 241 | 242 | scores = pd.DataFrame(scores, columns=['aspect','method', 'fmax', 's', 'auprc', 'pr', 'rc']) 243 | scores.to_csv(f'{save_dir}/scores.csv', index=False) 244 | 245 | logger.info('DONE.') -------------------------------------------------------------------------------- /scripts/distances/og_sampling.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pandas as pd 3 | from multiprocessing import Pool 4 | from space.tools.data import H5pyData 5 | from tqdm import tqdm 6 | import itertools 7 | import numpy as np 8 | import scipy 9 | import random 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | 13 | def read_single_embeddings(taxid): 14 | proteins, emb = H5pyData.read(f'./data/aligned/{taxid}.h5', 16) 15 | return taxid, (proteins, emb) 16 | 17 | def load_embeddings(): 18 | # load all the embeddings 19 | aligned_embeddings = dict() 20 | 21 | with open('./data/euks.txt') as f: 22 | taxids = f.read().strip().split('\n') 23 | 24 | with Pool(processes=8) as pool: 25 | results = list(tqdm(pool.imap(read_single_embeddings, taxids), total=len(taxids), desc="Loading embeddings")) 26 | 27 | for taxid, (proteins, emb) in results: 28 | aligned_embeddings[taxid] = (proteins, emb) 29 | 30 | protein_2_index = dict() 31 | for taxid, (proteins, emb) in aligned_embeddings.items(): 32 | protein_2_index[taxid] = dict(zip(proteins, range(len(proteins)))) 33 | 34 | return aligned_embeddings, protein_2_index 35 | 36 | def sample_proteins(protein_list, probability=0.01): 37 | """Fastest method using NumPy vectorized operations""" 38 | protein_array = np.array(protein_list) 39 | mask = np.random.random(len(protein_array)) < probability 40 | return protein_array[mask].tolist() 41 | 42 | def kick_out_pairs(protein_pairs): 43 | # randomize the paris 44 | random.shuffle(protein_pairs) 45 | 46 | used_proteins = set() 47 | filtered_pairs = [] 48 | 49 | for src, tgt in protein_pairs: 50 | if src not in used_proteins or tgt not in used_proteins: 51 | filtered_pairs.append((src, tgt)) 52 | used_proteins.add(src) 53 | used_proteins.add(tgt) 54 | return filtered_pairs 55 | 56 | def process_single_level(df_ancestor, level, all_proteins, aligned_embeddings): 57 | 58 | level = int(level) 59 | species_taxids = df_ancestor[df_ancestor['ancestor'] == level][['taxid_1', 'taxid_2']] 60 | 61 | # Convert to efficient set of integer tuples 62 | species_pairs = set() 63 | for taxid_1, taxid_2 in species_taxids.values: 64 | t1, t2 = int(taxid_1), int(taxid_2) 65 | species_pairs.add((min(t1, t2), max(t1, t2))) 66 | 67 | protein_pairs = list() 68 | 69 | with 
gzip.open(f'./data/eggnog/{level}.tsv.gz', 'rt') as f: 70 | 71 | for line in f: 72 | line = line.strip().split('\t') 73 | _, species_list, orthologs = line[1], line[-2].split(','), line[-1] 74 | 75 | species_list = [int(s) for s in species_list] 76 | orthologs = orthologs.split(',') 77 | 78 | ortholog_per_species = dict() 79 | 80 | for og_ in orthologs: 81 | taxid_int = int(og_.split('.')[0]) 82 | # make sure the taxid is in the aligned embeddings 83 | if str(taxid_int) not in aligned_embeddings: 84 | continue 85 | if taxid_int not in ortholog_per_species: 86 | ortholog_per_species[taxid_int] = 0 87 | ortholog_per_species[taxid_int] += 1 88 | 89 | # make sure we have this protein in the aligned embeddings 90 | valid_proteins = set(orthologs) & all_proteins 91 | if len(valid_proteins) < 2: 92 | continue 93 | # proteins from the same species can only show up once 94 | valid_proteins = [og_ for og_ in valid_proteins if ortholog_per_species.get(int(og_.split('.')[0]), 0) == 1] 95 | if len(valid_proteins) < 2: 96 | continue 97 | 98 | # protein_pairs.extend(itertools.combinations(valid_proteins, 2)) 99 | 100 | # for each protein we have 0.01% chance to sample it 101 | # sampled_proteins = np.random.choice(valid_proteins, size=int(len(valid_proteins) * 0.0001), replace=False) 102 | sampled_proteins = sample_proteins(valid_proteins) 103 | 104 | protein_pairs.extend(itertools.combinations(sampled_proteins, 2)) 105 | 106 | return protein_pairs 107 | 108 | def get_protein_pairs_distances(protein_pairs, aligned_embeddings, protein_2_index): 109 | 110 | src_emb = list() 111 | tgt_emb = list() 112 | for src,tgt in protein_pairs: 113 | src_taxid = src.split('.')[0] 114 | tgt_taxid = tgt.split('.')[0] 115 | 116 | src_index = protein_2_index[src_taxid][src] 117 | tgt_index = protein_2_index[tgt_taxid][tgt] 118 | src_emb.append(aligned_embeddings[src_taxid][1][src_index]) 119 | tgt_emb.append(aligned_embeddings[tgt_taxid][1][tgt_index]) 120 | src_emb = np.array(src_emb) 121 | tgt_emb = np.array(tgt_emb) 122 | euc_distances = np.linalg.norm(src_emb - tgt_emb, axis=1) 123 | cos_similarities = np.sum(src_emb * tgt_emb, axis=1) / (np.linalg.norm(src_emb, axis=1) * np.linalg.norm(tgt_emb, axis=1)) 124 | return euc_distances, cos_similarities 125 | 126 | def sample_negative_pairs(protein_pairs, protein_2_index): 127 | 128 | negative_pairs = [] 129 | 130 | for og_1, og_2 in protein_pairs: 131 | taxid_1 = int(og_1.split('.')[0]) 132 | taxid_2 = int(og_2.split('.')[0]) 133 | 134 | # sample the same taxid but random index 135 | while True: 136 | sample_prot1 = random.choice(list(protein_2_index[str(taxid_1)].keys())) 137 | sample_prot2 = random.choice(list(protein_2_index[str(taxid_2)].keys())) 138 | if sample_prot1 != og_1 and sample_prot2 != og_2: 139 | negative_pairs.append((sample_prot1, sample_prot2)) 140 | break 141 | return negative_pairs 142 | 143 | 144 | def get_stat_report(pos_distance, neg_distance, pos_similarity, neg_similarity): 145 | _, p_value = scipy.stats.wilcoxon(pos_distance, neg_distance) 146 | differences = pos_distance - neg_distance 147 | n_pos = np.sum(differences > 0) 148 | n_neg = np.sum(differences < 0) 149 | effect_size = abs(n_pos - n_neg) / len(differences) 150 | 151 | if effect_size < 0.1: 152 | effect_interpretation = "negligible" 153 | elif effect_size < 0.3: 154 | effect_interpretation = "small" 155 | elif effect_size < 0.5: 156 | effect_interpretation = "medium" 157 | else: 158 | effect_interpretation = "large" 159 | 160 | _, p_value = scipy.stats.wilcoxon(pos_similarity, 
neg_similarity) 161 | differences = pos_similarity - neg_similarity 162 | n_pos = np.sum(differences > 0) 163 | n_neg = np.sum(differences < 0) 164 | effect_size = abs(n_pos - n_neg) / len(differences) 165 | if effect_size < 0.1: 166 | effect_interpretation = "negligible" 167 | elif effect_size < 0.3: 168 | effect_interpretation = "small" 169 | elif effect_size < 0.5: 170 | effect_interpretation = "medium" 171 | else: 172 | effect_interpretation = "large" 173 | stat_dict_cos = { 174 | 'pos_cos_similarity_mean': np.mean(pos_similarity), 175 | 'pos_cos_similarity_std': np.std(pos_similarity), 176 | 'pos_cos_min': np.min(pos_similarity), 177 | 'pos_cos_max': np.max(pos_similarity), 178 | 'pos_cos_median': np.median(pos_similarity), 179 | 'neg_cos_similarity_mean': np.mean(neg_similarity), 180 | 'neg_cos_similarity_std': np.std(neg_similarity), 181 | 'neg_cos_min': np.min(neg_similarity), 182 | 'neg_cos_max': np.max(neg_similarity), 183 | 'neg_cos_median': np.median(neg_similarity), 184 | 'cos_similarity_wilcoxon_effect_size': effect_size, 185 | 'cos_similarity_wilcoxon_effect_interpretation': effect_interpretation, 186 | 'cos_p_value': p_value, 187 | } 188 | return stat_dict_cos 189 | 190 | 191 | def distances_by_group(df,all_proteins,aligned_embeddings,protein_2_index): 192 | 193 | unique_levels = df['ancestor'].unique() 194 | unique_levels = [int(level) for level in unique_levels if level != 'nan'] 195 | 196 | unique_levels = sorted(unique_levels) 197 | 198 | total_pairs = list() 199 | for level in tqdm(unique_levels, desc="Processing levels"): 200 | protein_pairs = process_single_level(df, level, all_proteins, aligned_embeddings) 201 | total_pairs.extend(protein_pairs) 202 | 203 | total_pairs = kick_out_pairs(total_pairs) 204 | # negative sampling 205 | negative_pairs = sample_negative_pairs(total_pairs, protein_2_index) 206 | pos_euc_distances, pos_cos_similarities = get_protein_pairs_distances(total_pairs, aligned_embeddings, protein_2_index) 207 | neg_euc_distances, neg_cos_similarities = get_protein_pairs_distances(negative_pairs, aligned_embeddings, protein_2_index) 208 | 209 | report = get_stat_report(pos_euc_distances, neg_euc_distances, pos_cos_similarities, neg_cos_similarities) 210 | return report, (pos_euc_distances, pos_cos_similarities), (neg_euc_distances, neg_cos_similarities) 211 | 212 | 213 | def main(): 214 | random.seed(42) 215 | np.random.seed(42) 216 | 217 | report_results = list() 218 | distances = list() 219 | df_ancestor = pd.read_csv('./data/euks_ancestors.tsv', 220 | sep='\t') 221 | seeds = open('./data/seeds.txt').read().strip().split('\n') 222 | seeds = [int(s) for s in seeds] 223 | df_ancestor['taxid_1_seed'] = False 224 | df_ancestor['taxid_2_seed'] = False 225 | for seed in seeds: 226 | df_ancestor.loc[df_ancestor['taxid_1'] == seed, 'taxid_1_seed'] = True 227 | df_ancestor.loc[df_ancestor['taxid_2'] == seed, 'taxid_2_seed'] = True 228 | 229 | aligned_embeddings, protein_2_index = load_embeddings() 230 | 231 | all_proteins = set() 232 | for proteins, _ in aligned_embeddings.values(): 233 | all_proteins.update(proteins) 234 | 235 | # seed with seed 236 | df_seed_seed = df_ancestor[(df_ancestor['taxid_1_seed'] == True) & (df_ancestor['taxid_2_seed'] == True)] 237 | report, (pos_euc_distances, pos_cos_similarities), \ 238 | (neg_euc_distances, neg_cos_similarities) = distances_by_group(df_seed_seed, 239 | all_proteins, 240 | aligned_embeddings, 241 | protein_2_index) 242 | report['Group'] = 'Seed with Seed' 243 | report_results.append(report) 244 | 245 | 
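    # `distances` keeps the raw per-group distributions as a tuple of
    # (positive Euclidean, positive cosine, negative Euclidean, negative cosine)
    # arrays, while the per-group summary statistics go into `report_results`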
distances.append((pos_euc_distances, pos_cos_similarities, neg_euc_distances, neg_cos_similarities)) 246 | 247 | # seed with non-seed 248 | df1 = df_ancestor[(df_ancestor['taxid_1_seed'] == True) & (df_ancestor['taxid_2_seed'] == False)] 249 | df2 = df_ancestor[(df_ancestor['taxid_1_seed'] == False) & (df_ancestor['taxid_2_seed'] == True)] 250 | df_seed_nonseed = pd.concat([df1, df2]) 251 | report, (pos_euc_distances, pos_cos_similarities), \ 252 | (neg_euc_distances, neg_cos_similarities) = distances_by_group(df_seed_nonseed, 253 | all_proteins, 254 | aligned_embeddings, 255 | protein_2_index) 256 | report['Group'] = 'Seed with Non-Seed' 257 | report_results.append(report) 258 | distances.append((pos_euc_distances, pos_cos_similarities, neg_euc_distances, neg_cos_similarities)) 259 | 260 | # non-seed with non-seed 261 | df_nonseed = df_ancestor[(df_ancestor['taxid_1_seed'] == False) & (df_ancestor['taxid_2_seed'] == False)] 262 | report, (pos_euc_distances, pos_cos_similarities), \ 263 | (neg_euc_distances, neg_cos_similarities) = distances_by_group(df_nonseed, 264 | all_proteins, 265 | aligned_embeddings, 266 | protein_2_index) 267 | report['Group'] = 'Non-Seed with Non-Seed' 268 | report_results.append(report) 269 | distances.append((pos_euc_distances, pos_cos_similarities, neg_euc_distances, neg_cos_similarities)) 270 | 271 | # save the report 272 | report_df = pd.DataFrame(report_results) 273 | report_df.to_csv('./results/og_sampling_report.csv', index=False) 274 | 275 | # plot the distances 276 | 277 | colors = ['#e70148','#bbbbbb'] 278 | plt.figure(figsize=(10, 4)) 279 | sns.boxplot(data=report_df, x='Group', y='Distance', hue='Type', 280 | showfliers=False, 281 | widths=0.2, 282 | gap=0.01, 283 | palette=colors) 284 | plt.title('Cosine Similarity Distribution') 285 | plt.ylabel('Cosine Similarity') 286 | plt.xlabel('Groups') 287 | # use the custom colors 288 | 289 | plt.grid(True) 290 | # plt.tight_layout() 291 | plt.savefig('./results/og_sampling_cosine_similarity.png', dpi=300) 292 | plt.show() 293 | -------------------------------------------------------------------------------- /scripts/subloc.py: -------------------------------------------------------------------------------- 1 | from space.tools.data import H5pyData 2 | import pandas as pd 3 | import random 4 | import numpy as np 5 | import os 6 | from sklearn.metrics import accuracy_score, f1_score, jaccard_score, matthews_corrcoef 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.multioutput import MultiOutputClassifier 9 | from loguru import logger 10 | from typing import Dict, List, Tuple 11 | from multiprocessing import Pool 12 | import umap 13 | import matplotlib.pyplot as plt 14 | import argparse 15 | 16 | 17 | 18 | 19 | def precision_recall(y_scores_flat,y_true_flat): 20 | thresholds = np.linspace(0, 1, 1000) 21 | 22 | # Calculate precision and recall for each threshold 23 | precisions = [] 24 | recalls = [] 25 | for threshold in thresholds: 26 | y_pred = (y_scores_flat > threshold).astype(int) 27 | true_positives = np.sum((y_true_flat == 1) & (y_pred == 1)) 28 | false_positives = np.sum((y_true_flat == 0) & (y_pred == 1)) 29 | false_negatives = np.sum((y_true_flat == 1) & (y_pred == 0)) 30 | 31 | precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 1 32 | recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 33 | 34 | precisions.append(precision) 35 | recalls.append(recall) 36 | 37 
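    # by construction, precision falls back to 1.0 at thresholds where nothing
    # is predicted positive, so the curve is anchored at the low-recall end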
| # Convert to numpy arrays 38 | precisions = np.array(precisions) 39 | recalls = np.array(recalls) 40 | 41 | # Sort the points by recall 42 | sorted_indices = np.argsort(recalls) 43 | recalls = recalls[sorted_indices] 44 | precisions = precisions[sorted_indices] 45 | 46 | return precisions, recalls 47 | 48 | def filter_non_existing_proteins(cv_ids,cv_labels,cv_partitions,cv_species,directory:List[str]): 49 | 50 | ## load the proteins from the embeddings 51 | proteins = set() 52 | for s in cv_species: 53 | file = f'{directory}/{s}.h5' 54 | if not os.path.exists(file): 55 | logger.error(f'{file} does not exist') 56 | continue 57 | 58 | s_proteins, _ = H5pyData.read(file,precision=16) 59 | 60 | proteins.update(s_proteins) 61 | 62 | proteins = list(proteins) 63 | 64 | ## filter the proteins that are not in the embeddings 65 | idx = [i for i,p in enumerate(cv_ids) if p in proteins] 66 | 67 | cv_ids = cv_ids[idx] 68 | cv_labels = cv_labels[idx] 69 | cv_partitions = cv_partitions[idx] 70 | 71 | return cv_ids, cv_labels, cv_partitions 72 | 73 | def init_logistic_regression(random_seed): 74 | return MultiOutputClassifier(LogisticRegression(max_iter=1000,random_state=random_seed)) 75 | 76 | def load_cv_set(cv_set, cv_id_mapping, ): 77 | 78 | cv_set = pd.read_csv(cv_set,index_col=0) 79 | 80 | ## drop the sequence column 81 | cv_set = cv_set.drop(columns=['Sequence']) 82 | 83 | 84 | ## open the id mapping file 85 | cv_id_mapping = pd.read_csv(cv_id_mapping,sep='\t') 86 | cv_id_mapping.columns = ['ACC','STRING_id'] 87 | 88 | 89 | ## merge the id mapping file with the cross validation set 90 | cv_set = cv_set.merge(cv_id_mapping,on='ACC').dropna() 91 | ## if one ACC has multiple STRING_id, keep the first one 92 | cv_set = cv_set.drop_duplicates(subset='ACC') 93 | 94 | ## extract the species from the string id 95 | cv_species = cv_set['STRING_id'].apply(lambda x: x.split('.')[0]).unique() 96 | 97 | cv_ids = cv_set['STRING_id'] 98 | 99 | cv_label_headers = ['Cytoplasm','Nucleus','Extracellular','Cell membrane', 100 | 'Mitochondrion','Plastid','Endoplasmic reticulum', 101 | 'Lysosome/Vacuole','Golgi apparatus','Peroxisome'] 102 | 103 | cv_labels = cv_set[cv_label_headers].values.astype(int) 104 | 105 | cv_partitions = cv_set['Partition'].values 106 | 107 | return cv_ids, cv_labels, cv_label_headers, cv_partitions, cv_species 108 | 109 | def load_embeddings_for_species(args) -> Dict[str, np.ndarray]: 110 | s, directory, cv_ids_set = args 111 | file = f'{directory}/{s}.h5' 112 | species_proteins, species_embeddings = H5pyData.read(file, precision=16) 113 | species_prot2idx = {p: i for i, p in enumerate(species_proteins)} 114 | return {p: species_embeddings[species_prot2idx[p]] for p in cv_ids_set.intersection(species_prot2idx)} 115 | 116 | def filter_and_load_proteins_embeddings( 117 | cv_ids: np.ndarray, 118 | cv_labels: np.ndarray, 119 | cv_partitions: np.ndarray, 120 | cv_species: List[str], 121 | directory: str, 122 | n_jobs: int = 1, 123 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 124 | # Convert cv_ids to a set for faster lookup 125 | cv_ids_set = set(cv_ids) 126 | 127 | pool_args = [(s, directory, cv_ids_set) for s in cv_species] 128 | with Pool(n_jobs) as pool: 129 | results = pool.map(load_embeddings_for_species, pool_args) 130 | protein_to_embedding = dict() 131 | for result in results: 132 | protein_to_embedding.update(result) 133 | 134 | filtered_idx = [i for i, p in enumerate(cv_ids) if p in protein_to_embedding] 135 | cv_ids = cv_ids[filtered_idx] 136 | cv_labels 
= cv_labels[filtered_idx] 137 | cv_partitions = cv_partitions[filtered_idx] 138 | 139 | # Extract the ordered embeddings 140 | output_proteins = cv_ids 141 | output_embeddings = np.array([protein_to_embedding[p] for p in output_proteins]) 142 | 143 | return cv_ids, cv_labels, cv_partitions, np.array(output_proteins), output_embeddings 144 | 145 | 146 | def evaluate_model(model, X_test:np.ndarray, y_test:np.ndarray,eval_human=False): 147 | 148 | y_test = y_test.astype(int) 149 | ypred = model.predict(X_test) 150 | 151 | if eval_human: 152 | ypred = ypred[:,[0,1,3,4,6,8]] 153 | y_test = y_test[:,[0,1,3,4,6,8]] 154 | 155 | metrics = [f1_score(y_test, ypred, average='micro'), 156 | f1_score(y_test, ypred, average='macro'), 157 | accuracy_score(y_test, ypred), 158 | jaccard_score(y_test, ypred, average='micro')] 159 | 160 | mcc = [matthews_corrcoef(y_test[:,i], ypred[:,i]) for i in range(y_test.shape[-1])] 161 | 162 | return metrics, mcc 163 | 164 | 165 | def benchmark_single_modal_on_cv(cv_partitions,embeddings,cv_labels,cv_label_headers,random_seed): 166 | # ## 1. run logreg with t5 embeddings 167 | scores = list() 168 | mccs = list() 169 | 170 | pred_scores = list() 171 | y_trues = list() 172 | 173 | for i in range(5): 174 | training_idx = cv_partitions != i 175 | val_idx = cv_partitions == i 176 | 177 | X_train = embeddings[training_idx] 178 | y_train = cv_labels[training_idx] 179 | 180 | X_val = embeddings[val_idx] 181 | y_val = cv_labels[val_idx] 182 | 183 | clf = init_logistic_regression(random_seed) 184 | 185 | clf.fit(X_train,y_train) 186 | 187 | score, mcc = evaluate_model(clf, X_val, y_val) 188 | 189 | pred_scores.append(np.array([s[:,1] for s in clf.predict_proba(X_val)]).T.flatten()) 190 | y_trues.append(y_val.flatten()) 191 | 192 | scores.append(score) 193 | mccs.append(mcc) 194 | 195 | y_scores_flat = np.concatenate(pred_scores) 196 | y_trues_flat = np.concatenate(y_trues) 197 | 198 | scores = pd.DataFrame(scores, columns=['f1_micro','f1_macro','accuracy','jaccard']) 199 | 200 | mccs = pd.DataFrame(mccs, columns=cv_label_headers) 201 | 202 | return scores, mccs, np.stack(precision_recall(y_scores_flat,y_trues_flat),axis=1) 203 | 204 | def mean_std(scores:pd.DataFrame): 205 | means = np.round(scores.mean().values,2) 206 | stddevs = np.round(scores.std().values,2) 207 | means_std = [str(mn)+' ± '+str(sd) for mn, sd in zip(means, stddevs)] 208 | return means_std 209 | 210 | 211 | def plot_umap_projection(projection,cv_labels_aligned, 212 | cv_label_headers,save_name,random_seed): 213 | # import pdb; pdb.set_trace() 214 | projection = np.stack(projection,axis=1) 215 | 216 | df_cv_labels = pd.DataFrame(cv_labels_aligned,columns=cv_label_headers) 217 | df_cv_labels['num_labels'] = df_cv_labels.sum(axis=1) 218 | 219 | indices = df_cv_labels[df_cv_labels['num_labels'] == 1].index 220 | indices_labels = df_cv_labels.iloc[indices, :-1].idxmax(axis=1) 221 | 222 | reducer = umap.UMAP(n_components=2,n_neighbors=100,min_dist=0.1,random_state=random_seed) 223 | 224 | umap_embeddings = reducer.fit_transform(projection[indices]) 225 | 226 | colors = ["#e60049", "#0bb4ff", "#50e991", "#e6d800", "#9b19f5", "#ffa300", "#dc0ab4", "#b3d4ff", "#00bfa0", "grey"] 227 | locations = ['Cytoplasm','Cell membrane', 'Nucleus', 'Lysosome/Vacuole', 228 | 'Mitochondrion', 'Plastid','Endoplasmic reticulum', 229 | 'Golgi apparatus', 'Peroxisome', 'Extracellular',] 230 | 231 | loc2color = dict(zip(locations, colors)) 232 | 233 | df_umap_emb = pd.DataFrame(umap_embeddings) 234 | df_umap_emb['label'] = 
indices_labels.values 235 | df_umap_emb['color'] = df_umap_emb['label'].apply(lambda x: loc2color[x]) 236 | 237 | scatter = plt.scatter(df_umap_emb.iloc[:, 0], df_umap_emb.iloc[:, 1], 238 | c=df_umap_emb['color'], 239 | s=1) 240 | 241 | legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 242 | markerfacecolor=colors[i], label=label, markersize=5) 243 | for i, label in enumerate(locations)] 244 | 245 | plt.legend(handles=legend_elements, 246 | bbox_to_anchor=(0.95, 0.5), # Position relative to the plot 247 | loc='center left', # Anchor point of the legend 248 | borderaxespad=0, # Padding between legend and axes 249 | bbox_transform=plt.gca().transAxes, 250 | frameon=False, 251 | fontsize=13,) 252 | 253 | plt.xticks([]) 254 | plt.yticks([]) 255 | plt.gca().spines['top'].set_visible(False) 256 | plt.gca().spines['right'].set_visible(False) 257 | plt.gca().spines['bottom'].set_visible(False) 258 | plt.gca().spines['left'].set_visible(False) 259 | plt.tight_layout() 260 | plt.savefig(save_name, dpi=300) 261 | return None 262 | 263 | 264 | 265 | def benchmark_cv(cv_set, cv_id_mapping, t5_dir, aligned_dir:str,save_dir,jobs:int,random_seed:int): 266 | 267 | cv_ids, cv_labels, cv_label_headers, cv_partitions, cv_species = load_cv_set(cv_set, cv_id_mapping) 268 | 269 | ## filter the proteins that are not in the embeddings 270 | 271 | cv_ids_t5, cv_labels_t5, cv_partitions_t5, \ 272 | t5_proteins, t5_embeddings = filter_and_load_proteins_embeddings( 273 | cv_ids, 274 | cv_labels, 275 | cv_partitions, 276 | cv_species, 277 | t5_dir, 278 | n_jobs=jobs,) 279 | 280 | cv_ids_aligned, cv_labels_aligned, cv_partitions_aligned, \ 281 | aligned_proteins, aligned_embeddings = filter_and_load_proteins_embeddings( 282 | cv_ids, 283 | cv_labels, 284 | cv_partitions, 285 | cv_species, 286 | aligned_dir, 287 | n_jobs=jobs 288 | ) 289 | 290 | scores = list() 291 | mccs = list() 292 | 293 | t5_scores, t5_mccs, t5_pr_curves = benchmark_single_modal_on_cv(cv_partitions_t5,t5_embeddings, 294 | cv_labels_t5,cv_label_headers, 295 | random_seed) 296 | scores.append(mean_std(t5_scores)) 297 | mccs.append(mean_std(t5_mccs)) 298 | 299 | aligned_scores, aligned_mccs, aligned_pr_curves = benchmark_single_modal_on_cv(cv_partitions_t5,aligned_embeddings, 300 | cv_labels_t5,cv_label_headers, 301 | random_seed) 302 | scores.append(mean_std(aligned_scores)) 303 | mccs.append(mean_std(aligned_mccs)) 304 | 305 | space_scores, space_mccs, space_pr_curves = benchmark_single_modal_on_cv(cv_partitions_t5, 306 | np.concatenate([t5_embeddings,aligned_embeddings],axis=-1), 307 | cv_labels_t5,cv_label_headers, 308 | random_seed) 309 | 310 | 311 | scores.append(mean_std(space_scores)) 312 | mccs.append(mean_std(space_mccs)) 313 | 314 | scores = pd.DataFrame(scores,columns=['f1_micro','f1_macro','accuracy','jaccard'],index=['t5','aligned','t5_concat_aligned']) 315 | mccs = pd.DataFrame(mccs,columns=cv_label_headers,index=['t5','aligned','t5_concat_aligned']) 316 | 317 | pr_curves = np.concatenate([t5_pr_curves,aligned_pr_curves,space_pr_curves],axis=1) 318 | pr_curves = pd.DataFrame(pr_curves, 319 | columns=['t5_prec','t5_recall', 320 | 'aligned_prec','aligned_recall', 321 | 'space_prec','space_recall']) 322 | 323 | scores.to_csv(f'{save_dir}/cv_scores.csv') 324 | mccs.to_csv(f'{save_dir}/cv_mccs.csv') 325 | pr_curves.to_csv(f'{save_dir}/cv_pr_curves.csv') 326 | 327 | t5_clf = init_logistic_regression(random_seed) 328 | t5_clf.fit(t5_embeddings,cv_labels_t5) 329 | 330 | ## get the projection of the t5 embeddings on clf 331 | 
t5_projection = [est_.decision_function(t5_embeddings) for est_ in t5_clf.estimators_] 332 | np.save(f'{save_dir}/t5_projection_cv.npy',np.stack(t5_projection,axis=1),) 333 | 334 | aligned_clf = init_logistic_regression(random_seed) 335 | aligned_clf.fit(aligned_embeddings,cv_labels_t5) 336 | aligned_projection = [est_.decision_function(aligned_embeddings) for est_ in aligned_clf.estimators_] 337 | np.save(f'{save_dir}/aligned_projection_cv.npy',np.stack(aligned_projection,axis=1)) 338 | 339 | space_clf = init_logistic_regression(random_seed) 340 | space_clf.fit(np.concatenate([t5_embeddings,aligned_embeddings],axis=-1),cv_labels_t5) 341 | space_projection = [est_.decision_function(np.concatenate([t5_embeddings,aligned_embeddings],axis=-1)) 342 | for est_ in space_clf.estimators_] 343 | np.save(f'{save_dir}/space_projection_cv.npy',np.stack(space_projection,axis=1)) 344 | 345 | # plot the umap projection of aligned embeddings 346 | plot_umap_projection(aligned_projection,cv_labels_aligned, 347 | cv_label_headers,f'{save_dir}/aligned_cv_clf_project_umap.png',random_seed) 348 | 349 | return {'t5':t5_clf,'aligned':aligned_clf,'space':space_clf}, cv_label_headers 350 | 351 | 352 | def process_hpa_set(human_alias,hpa_set,save_dir): 353 | 354 | aliases = pd.read_csv(human_alias,compression='gzip',sep='\t') 355 | hpa = pd.read_csv(hpa_set,sep=',') 356 | idmapping = aliases[aliases['alias'].isin(hpa['sid'])][['#string_protein_id','alias']].drop_duplicates() 357 | hpa = pd.merge(hpa,idmapping,left_on='sid',right_on='alias',how='left').drop_duplicates().dropna() 358 | 359 | hpa[['#string_protein_id', 360 | 'Cytoplasm','Nucleus','Cell membrane', 361 | 'Mitochondrion','Endoplasmic reticulum', 362 | 'Golgi apparatus','sid']].to_csv(f'{save_dir}/hpa_testset_mapped.csv',index=False) 363 | 364 | hpa_proteins = hpa['#string_protein_id'].values 365 | 366 | hpa_headers = hpa.columns[1:-4].values 367 | hpa_labels = hpa.iloc[:,1:-4].values 368 | 369 | use_headers = 'Cytoplasm, Nucleus, Cell membrane, Mitochondrion, Endoplasmic reticulum, Golgi apparatus' 370 | use_headers = use_headers.split(', ') 371 | 372 | ## filter the headers and labels 373 | hpa_labels = hpa_labels[:,[i for i,h in enumerate(hpa_headers) if h in use_headers]] 374 | hpa_headers = [h for h in hpa_headers if h in use_headers] 375 | 376 | return hpa_proteins, hpa_labels, hpa_headers 377 | 378 | def load_human_embeddings(file,hpa_proteins): 379 | 380 | proteins, embeddings = H5pyData.read(file,precision=16) 381 | 382 | prot2idx = {p:i for i,p in enumerate(proteins)} 383 | 384 | idx = [prot2idx[p] for p in hpa_proteins] 385 | 386 | return np.array(embeddings)[idx] 387 | 388 | 389 | def benchmark_hpa(human_alias,hpa_set,t5_path,aligned_path,clfs,cv_headers,save_dir): 390 | 391 | hpa_proteins, hpa_labels, hpa_headers = process_hpa_set(human_alias,hpa_set,save_dir) 392 | 393 | hpa_labels_list = list() 394 | hpa_headers_list = list() 395 | 396 | headers2int = {h:idx for idx,h in enumerate(hpa_headers)} 397 | 398 | for h in cv_headers: 399 | if h in headers2int: 400 | hpa_labels_list.append(hpa_labels[:,headers2int[h]]) 401 | hpa_headers_list.append(h) 402 | else: 403 | hpa_labels_list.append(np.zeros(hpa_labels.shape[0])) 404 | hpa_headers_list.append(h) 405 | 406 | hpa_labels = np.array(hpa_labels_list).T 407 | hpa_headers = np.array(hpa_headers_list) 408 | 409 | t5_embeddings = load_human_embeddings(t5_path,hpa_proteins) 410 | aligned_embeddings = load_human_embeddings(aligned_path,hpa_proteins) 411 | 412 | metrics_output = list() 413 | mcc_output 
= list() 414 | y_test = hpa_labels[:,[0,1,3,4,6,8]] 415 | pr_curves = list() 416 | 417 | for model_name, clf in clfs.items(): 418 | if model_name == 't5': 419 | 420 | X_test = t5_embeddings 421 | 422 | elif model_name == 'aligned': 423 | X_test = aligned_embeddings 424 | 425 | elif model_name == 'space': 426 | X_test = np.concatenate([t5_embeddings,aligned_embeddings],axis=1) 427 | 428 | metrics, mcc = evaluate_model(clf,X_test,hpa_labels,True) 429 | 430 | metrics_output.append([model_name]+metrics) 431 | 432 | mcc_output.append([model_name]+mcc) 433 | 434 | pred_scores = clf.predict_proba(X_test) 435 | pred_scores = np.array([s[:,1] for s in pred_scores]).T[:,[0,1,3,4,6,8]] 436 | p,r = precision_recall(pred_scores.flatten(),y_test.flatten()) 437 | pr_curves.append(np.stack([p,r],axis=1)) 438 | 439 | pr_curves = pd.DataFrame(np.concatenate(pr_curves,axis=1),columns=['t5_prec','t5_recall', 440 | 'aligned_prec','aligned_recall', 441 | 'space_prec','space_recall']) 442 | 443 | pr_curves.to_csv(f'{save_dir}/hpa_pr_curves.csv') 444 | 445 | hpa_mcc_headers = ['model_name'] + [cv_headers[i] for i in [0,1,3,4,6,8]] 446 | 447 | pd.DataFrame(mcc_output,columns=hpa_mcc_headers).to_csv(f'{save_dir}/hpa_mccs.csv') 448 | 449 | pd.DataFrame(metrics_output,columns=['model_name','f1_micro','f1_macro','accuracy','jaccard']).to_csv(f'{save_dir}/hpa_scores.csv') 450 | 451 | return None 452 | 453 | 454 | if __name__ == '__main__': 455 | 456 | parser = argparse.ArgumentParser(description='Benchmarking subcellular localization with logistic regression') 457 | 458 | parser.add_argument('--aligned_dir',type=str,help='Directory containing aligned embeddings',required=False,default='data/functional_emb') 459 | 460 | parser.add_argument('--t5_dir',type=str,help='Directory containing t5 embeddings',required=False,default='data/t5_emb') 461 | 462 | parser.add_argument('--cv_set',type=str,help='Cross validation set',required=False,default='data/benchmarks/deeploc/Swissprot_Train_Validation_dataset.csv') 463 | 464 | parser.add_argument('--cv_mapping',type=str,help='Cross validation id mapping',required=False,default='data/benchmarks/deeploc/cv_idmapping.tsv') 465 | 466 | parser.add_argument('--hpa_set',type=str,help='HPA test set',required=False,default='data/benchmarks/deeploc/hpa_testset.csv') 467 | 468 | parser.add_argument('--human_alias',type=str,help='Human alias mapping',required=False,default='data/benchmarks/deeploc/9606.protein.aliases.v12.0.txt.gz') 469 | 470 | parser.add_argument('--save_dir',type=str,help='Directory to save results',required=False,default='results/subloc') 471 | 472 | parser.add_argument('--jobs',type=int,help='Number of jobs to run in parallel',required=False,default=7) 473 | 474 | parser.add_argument('--random_seed',type=int,help='Random seed',required=False,default=5678) 475 | 476 | args = parser.parse_args() 477 | 478 | if not os.path.exists(args.save_dir): 479 | os.makedirs(args.save_dir) 480 | 481 | random.seed(args.random_seed) 482 | np.random.seed(args.random_seed) 483 | 484 | logger.info('Starting benchmarking') 485 | logger.info('Benchmarking cross validation set') 486 | clfs, cv_headers = benchmark_cv(args.cv_set, args.cv_mapping, args.t5_dir, args.aligned_dir, args.save_dir, args.jobs, args.random_seed) 487 | logger.info('Benchmarking HPA test set') 488 | benchmark_hpa(args.human_alias, args.hpa_set, f'{args.t5_dir}/9606.h5', f'{args.aligned_dir}/9606.h5', 489 | clfs, cv_headers, args.save_dir) 490 | logger.info('Results saved in {}'.format(args.save_dir)) 491 | 
logger.info('Done!') 492 | -------------------------------------------------------------------------------- /scripts/add_singleton.py: -------------------------------------------------------------------------------- 1 | from space.tools.taxonomy import Lineage 2 | from space.tools.data import H5pyData 3 | import os 4 | import pandas as pd 5 | import numpy as np 6 | import itertools 7 | from multiprocessing import Pool 8 | import gzip 9 | import csv 10 | import sys 11 | from loguru import logger 12 | import argparse 13 | 14 | np.random.seed(42) 15 | csv.field_size_limit(sys.maxsize) 16 | 17 | class LLineage(Lineage): 18 | 19 | def __init__(self,node_dmp_zip,group_dir) -> None: 20 | 21 | self.df = pd.read_csv(node_dmp_zip,sep='|',compression='zip',header=None) 22 | self.eggnog_ancestors = {f.split('.')[0] for f in os.listdir(group_dir) } 23 | self.group_dir = group_dir 24 | 25 | def common_ancestor(self,taxid_1, 26 | taxid_2,l_1,l_2 27 | ): 28 | 29 | 30 | taxid_1 = int(taxid_1) 31 | taxid_2 = int(taxid_2) 32 | 33 | for idx,taxid in enumerate(l_1): 34 | if taxid in l_2: 35 | common_ancestor = taxid 36 | else: 37 | break 38 | 39 | ## make sure eggNOG has the common ancestor, and the orthologs are not empty 40 | while True: 41 | if self.check_ortholog_group(taxid_1,taxid_2,common_ancestor): 42 | break 43 | idx -= 1 44 | common_ancestor = l_1[idx-1] 45 | 46 | return str(taxid_1),str(taxid_2),int(common_ancestor) 47 | 48 | 49 | def infer_common_ancestors(euk_groups,ncbi_lineage,eggnog_dir, 50 | orthologs_dir,): 51 | 52 | euk_groups = pd.read_csv(euk_groups,sep='\t') 53 | 54 | ancestor_finder = LLineage(ncbi_lineage,eggnog_dir) 55 | 56 | lineages = dict(zip(euk_groups['taxid'],euk_groups['lineage'])) 57 | ## change the lineages to list of integers 58 | lineages = {k:list(map(int,v.split(','))) for k,v in lineages.items()} 59 | 60 | pairs = list() 61 | 62 | for f in os.listdir(f'{orthologs_dir}/seeds'): 63 | src,tgt = f.split('.')[0].split('_') 64 | src,tgt = int(src),int(tgt) 65 | pairs.append((src,tgt)) 66 | 67 | # data structure is different for non_seeds 68 | for d in os.listdir(f'{orthologs_dir}/non_seeds'): 69 | for f in os.listdir(f'{orthologs_dir}/non_seeds/{d}'): 70 | src,tgt = f.split('.')[0].split('_') 71 | src,tgt = int(src),int(tgt) 72 | pairs.append((src,tgt)) 73 | 74 | ancestors = list() 75 | 76 | for src,tgt in pairs: 77 | l1 = lineages[src] 78 | l2 = lineages[tgt] 79 | ancestors.append(ancestor_finder.common_ancestor(src,tgt,l1,l2)) 80 | return ancestors 81 | 82 | 83 | def min_max(filename): 84 | 85 | p,e = H5pyData.read(filename, 16) 86 | 87 | min_e = float(np.min(e)) 88 | 89 | max_e = float(np.max(e)) 90 | 91 | return min_e, max_e 92 | 93 | def find_min_max(species_file,directory,num_jobs): 94 | 95 | species_list = open(species_file).read().strip().split('\n') 96 | 97 | with Pool(num_jobs) as p: 98 | results = p.map(min_max, [f'{directory}/{species}.h5' 99 | for species in species_list]) 100 | 101 | results = list(itertools.chain(*results)) 102 | 103 | return min(results), max(results) 104 | 105 | 106 | #2. 
scale the embeddings 107 | def scale_fn(filename, scaler, save_dir): 108 | 109 | taxid = filename.split('/')[-1].split('.')[0] 110 | 111 | p,e = H5pyData.read(filename, 16) 112 | 113 | e = e*scaler 114 | 115 | H5pyData.write(proteins=p, embedding=e, 116 | save_path=f'{save_dir}/{taxid}.h5', 117 | precision=16) 118 | 119 | return None 120 | 121 | def scale_embeddings(species_file,aligned_dir,num_jobs,save_dir,scaler=None): 122 | 123 | min_e, max_e = find_min_max(species_file,aligned_dir,num_jobs) 124 | 125 | if not scaler: 126 | scaler = min(0.99/max_e, abs(0.99/min_e)) 127 | 128 | species_list = open(species_file).read().strip().split('\n') 129 | 130 | with Pool(num_jobs) as p: 131 | results = p.starmap(scale_fn, [(f'{aligned_dir}/{species}.h5', 132 | scaler, save_dir) 133 | for species in species_list]) 134 | return None 135 | 136 | def filter_singleton(species,embed_dir,sequence_dir,save_dir): 137 | 138 | emb_path = os.path.join(embed_dir,f'{species}.h5') 139 | 140 | interaction_proteins, _ = H5pyData.read(emb_path,16) 141 | interaction_proteins = set(interaction_proteins) 142 | 143 | sequence_file = f'{sequence_dir}/{species}.protein.sequences.v12.0.fa.gz' 144 | ## read the sequence file 145 | sequence_proteins = set() 146 | with gzip.open(sequence_file,'rt') as f: 147 | for line in f: 148 | if line.startswith('>'): 149 | sequence_proteins.add(line.strip().split()[0][1:]) 150 | 151 | singleton_proteins = sequence_proteins - interaction_proteins 152 | 153 | with open(f'{save_dir}/{species}.singleton.proteins.v12.0.txt','w') as f: 154 | for protein in singleton_proteins: 155 | f.write(f'{protein}\n') 156 | return None 157 | 158 | 159 | 160 | def filter_singleton_parallel(species,embed_dir,sequence_dir,save_dir,number_jobs): 161 | 162 | with Pool(number_jobs) as p: 163 | results = p.starmap(filter_singleton, [(species,embed_dir,sequence_dir,save_dir) 164 | for species in species]) 165 | 166 | return None 167 | 168 | def filter_singleton_in_og(ancestor,singleton_dir,eggnog_dir,save_dir): 169 | 170 | group_file = f'{eggnog_dir}/{ancestor}.tsv.gz' 171 | 172 | singletons = dict() 173 | 174 | for f in os.listdir(singleton_dir): 175 | f = os.path.join(singleton_dir,f) 176 | s = open(f).read().strip().split('\n') 177 | species = f.split('/')[-1].split('.')[0] 178 | singletons[species] = set(s) 179 | 180 | records = list() 181 | 182 | with gzip.open(group_file,'rt') as f: 183 | 184 | for line in f: 185 | line = line.strip().split('\t') 186 | orthologs = set(line[-1].split(',')) 187 | og_name = line[1] 188 | species = line[4].split(',') 189 | 190 | og_singletons = list() 191 | 192 | ## singleton in this og 193 | for o in orthologs: 194 | s = o.split('.')[0] 195 | ## check if this is a singleton 196 | if o in singletons[s]: 197 | og_singletons.append(o) 198 | 199 | if len(og_singletons) == 0: 200 | continue 201 | 202 | ## check if all the orthologs are singletons 203 | if len(og_singletons) == len(orthologs): 204 | records.append((og_name,','.join(species),','.join(og_singletons),'','non-interaction')) 205 | else: 206 | non_singleton_og = orthologs - set(og_singletons) 207 | records.append((og_name,','.join(species),','.join(og_singletons),','.join(non_singleton_og),'partial-interaction')) 208 | 209 | 210 | ## save as a gzipped file 211 | with gzip.open(f'{save_dir}/{ancestor}.tsv.gz','wt') as f: 212 | writer = csv.writer(f,delimiter='\t') 213 | writer.writerow(['og_name','species','singletons','non_singleton_orths','interaction']) 214 | writer.writerows(records) 215 | return None 216 | 217 | def 
filter_singleton_in_og_parallel(ancestors,singleton_dir,eggnog_dir,save_dir,num_jobs): 218 | 219 | # ancestors = set([int(l.split('\t')[-1]) for l in ancestors]) 220 | ancestors = set([int(l[-1]) for l in ancestors]) 221 | 222 | with Pool(num_jobs) as p: 223 | results = p.starmap(filter_singleton_in_og, [(ancestor,singleton_dir,eggnog_dir,save_dir) 224 | for ancestor in ancestors]) 225 | 226 | return None 227 | 228 | 229 | def generate_random_embeddings(size=512): 230 | # Generate random values between 0 and 1 231 | random_values = np.random.random(size) 232 | 233 | # Randomly choose whether each value should be in the negative or positive range 234 | negative_mask = np.random.random(size) < 0.5 235 | 236 | # Transform values to desired ranges 237 | result = np.where(negative_mask, 238 | # For negative range: [-1, -0.99] 239 | -1 + (random_values * 0.01), 240 | # For positive range: [0.99, 1] 241 | 0.99 + (random_values * 0.01)) 242 | 243 | return result 244 | 245 | 246 | def allocate_orthologous_singletons(singleton_og_dir,save_name): 247 | 248 | noise = 1e-5 249 | 250 | singleton_embs = dict() 251 | 252 | for f in os.listdir(singleton_og_dir): 253 | f = f'{singleton_og_dir}/{f}' 254 | 255 | df = pd.read_csv(f,compression='gzip',sep='\t',header=None) 256 | 257 | df = df[df.iloc[:,-1] == 'non-interaction'] 258 | 259 | for _,line in df.iterrows(): 260 | 261 | proteins = line[2].split(',') 262 | 263 | random_emb = generate_random_embeddings() 264 | 265 | noise_ = noise * np.random.random(size=(len(proteins),512)) 266 | 267 | protein_embs = random_emb + noise_ 268 | 269 | for p_idx, p in enumerate(proteins): 270 | if p in singleton_embs: 271 | singleton_embs[p].append(protein_embs[p_idx]) 272 | else: 273 | singleton_embs[p] = [protein_embs[p_idx]] 274 | 275 | ## average the embeddings 276 | for p,e in singleton_embs.items(): 277 | e = np.array(e) 278 | e = np.mean(e,axis=0).reshape(1,-1) 279 | singleton_embs[p] = e 280 | 281 | ## save the embeddings 282 | H5pyData.write(list(singleton_embs.keys()), 283 | np.array(list(singleton_embs.values())).reshape(-1,512), 284 | save_name,16) 285 | 286 | return None 287 | 288 | 289 | def extract_network_proteins(species,og_proteins,scaled_dir): 290 | # get the embeddings for og_proteins 291 | p,e = H5pyData.read(f'{scaled_dir}/{species}.h5',16) 292 | 293 | # make a dictionary of proteins 294 | p2indx = {p:i for i,p in enumerate(p)} 295 | 296 | # get the indices of the proteins in the og_proteins 297 | indices = [p2indx[p] for p in og_proteins] 298 | used_e = e[indices] 299 | return og_proteins,used_e 300 | 301 | def extract_network_proteins_parallel(singleton_ogs,scaled_dir,save_name,number_jobs): 302 | og_interaction_proteins = set() 303 | 304 | for f in os.listdir(singleton_ogs): 305 | f = f'{singleton_ogs}/{f}' 306 | 307 | df = pd.read_csv(f,compression='gzip',sep='\t',header=None) 308 | 309 | df_interaction = df[df.iloc[:,-1] == 'partial-interaction'] 310 | 311 | for line in df_interaction.iloc[:,-2]: 312 | og_interaction_proteins.update(line.split(',')) 313 | 314 | proteins_per_species = dict() 315 | for p in og_interaction_proteins: 316 | species = p.split('.')[0] 317 | if species in proteins_per_species: 318 | proteins_per_species[species].append(p) 319 | else: 320 | proteins_per_species[species] = [p] 321 | 322 | ## read the embeddings 323 | with Pool(number_jobs) as p: 324 | results = p.starmap(extract_network_proteins,[(s,p,scaled_dir) 325 | for s,p in proteins_per_species.items()]) 326 | 327 | # save the embeddings in a single file 328 | 
proteins = list() 329 | embeddings = list() 330 | 331 | ## prepare the data for saving 332 | for og_proteins,used_e in results: 333 | proteins.extend(og_proteins) 334 | embeddings.extend(used_e) 335 | 336 | H5pyData.write(proteins,embeddings,save_name,16) 337 | 338 | return None 339 | 340 | 341 | def collect_embeddings(species:str,interation_og:str, 342 | singleton_og_dir:str, 343 | pure_singleton_og:str, 344 | singleton_dir:str, 345 | scaled_dir:str, 346 | save_dir:str): 347 | 348 | singleton_embeddings = dict() 349 | ## load the singletons 350 | noise = 1e-5 351 | interaction_og,interaction_og_emb = H5pyData.read(interation_og,16) 352 | interaction_og2idx = {str(p):i for i,p in enumerate(interaction_og)} 353 | 354 | ## load the singleton orthologs 355 | for f in os.listdir(singleton_og_dir): 356 | f = f'{singleton_og_dir}/{f}' 357 | 358 | df = pd.read_csv(f,compression='gzip',sep='\t',header=None) 359 | 360 | df_interaction = df[df.iloc[:,-1] == 'partial-interaction'] 361 | 362 | for _, line in df_interaction.iterrows(): 363 | og_species = line[1].split(',') 364 | if species not in og_species: 365 | continue 366 | 367 | ## proteins with interactions, in orthologous groups 368 | og_proteins = line[3].split(',') 369 | 370 | indices = list(map(lambda x: interaction_og2idx[x],og_proteins)) 371 | og_emb = np.mean(interaction_og_emb[indices],axis=0) 372 | 373 | ## species proteins, singletons 374 | species_proteins = line[2].split(',') 375 | species_proteins = [p for p in species_proteins if p.split('.')[0] == species] 376 | noise_ = noise * np.random.random(size=(len(species_proteins),512)) 377 | 378 | protein_embs = og_emb + noise_ 379 | 380 | for p_idx, p in enumerate(species_proteins): 381 | if p in singleton_embeddings: 382 | singleton_embeddings[p].append(protein_embs[p_idx]) 383 | else: 384 | singleton_embeddings[p] = [protein_embs[p_idx]] 385 | 386 | ## average the embeddings in orthologous groups 387 | for p,e in singleton_embeddings.items(): 388 | e = np.array(e) 389 | e = np.mean(e,axis=0).reshape(-1) 390 | singleton_embeddings[p] = e 391 | 392 | ## put the pure singletons into the dictionary 393 | ## just to be incorporated in the end 394 | pure_singleton_og,pure_singleton_ogs_emb = H5pyData.read(pure_singleton_og,16) 395 | pure_singleton_og2idx = {p:i for i,p in enumerate(pure_singleton_og)} 396 | ## filter out other species 397 | # pure_singleton_og = [p for p in pure_singleton_og if p.split('.')[0] == species] 398 | pure_singleton_og = list(filter(lambda x: x.split('.')[0] == species,pure_singleton_og)) 399 | indices = list(map(lambda x: pure_singleton_og2idx[x],pure_singleton_og)) 400 | pure_singleton_ogs_emb = pure_singleton_ogs_emb[indices] 401 | singleton_embeddings.update({p:e for p,e in zip(pure_singleton_og,pure_singleton_ogs_emb)}) 402 | 403 | ## load the singletons list 404 | all_singletons = f'{singleton_dir}/{species}.singleton.proteins.v12.0.txt' 405 | all_singletons = open(all_singletons).read().strip().split('\n') 406 | 407 | ## 408 | unsolved_singletons = set(all_singletons) - set(singleton_embeddings.keys()) 409 | 410 | ## give random embeddings for the unsolved singletons 411 | for p in unsolved_singletons: 412 | random_emb = generate_random_embeddings() 413 | singleton_embeddings[p] = random_emb 414 | 415 | ## check if every proteins are solved 416 | assert len(set(all_singletons) - set(singleton_embeddings.keys())) == 0 417 | 418 | ## save the embeddings 419 | proteins,embeddings = H5pyData.read(f'{scaled_dir}/{species}.h5',16) 420 | 421 | proteins = 
list(singleton_embeddings.keys()) + list(map(str,proteins)) 422 | embeddings = np.concatenate([np.array(list(singleton_embeddings.values())),embeddings],axis=0) 423 | 424 | ## save the embeddings 425 | H5pyData.write(proteins,embeddings,f'{save_dir}/{species}.h5',16) 426 | 427 | return None 428 | 429 | def collect_embeddings_parallel(species:list,interation_og:str, 430 | singleton_og_dir:str, 431 | pure_singleton_og:str, 432 | singleton_dir:str, 433 | scaled_dir:str, 434 | save_dir:str, 435 | number_jobs): 436 | 437 | with Pool(number_jobs) as p: 438 | p.starmap(collect_embeddings,[(s,interation_og,singleton_og_dir, 439 | pure_singleton_og,singleton_dir,scaled_dir,save_dir) 440 | for s in species]) 441 | 442 | return None 443 | 444 | 445 | def main(aligned_dir,species_file,euk_groups,ncbi_lineage,eggnog_dir, 446 | sequence_dir, 447 | orthologs_dir,working_dir,full_embedding_save_dir, 448 | number_jobs,scaler=1.497): 449 | 450 | if not os.path.exists(working_dir): 451 | os.makedirs(working_dir) 452 | 453 | if not os.path.exists(full_embedding_save_dir): 454 | os.makedirs(full_embedding_save_dir) 455 | 456 | logger.info('Infer common ancestors...') 457 | common_ancestors = infer_common_ancestors(euk_groups,ncbi_lineage,eggnog_dir, 458 | orthologs_dir,) 459 | 460 | #save all the scaled embeddings 461 | scaled_save_dir = f'{working_dir}/scaled' 462 | if not os.path.exists(scaled_save_dir): 463 | os.makedirs(scaled_save_dir) 464 | logger.info('Scaling the embeddings...') 465 | scale_embeddings(species_file,aligned_dir,number_jobs, 466 | scaled_save_dir,scaler=scaler) 467 | 468 | #filter singletons 469 | logger.info('Filtering singletons...') 470 | species = open(species_file).read().strip().split('\n') 471 | if not os.path.exists(f'{working_dir}/singleton_ids'): 472 | os.makedirs(f'{working_dir}/singleton_ids') 473 | filter_singleton_parallel(species,f'{working_dir}/scaled',sequence_dir=sequence_dir, 474 | save_dir=f'{working_dir}/singleton_ids', 475 | number_jobs=number_jobs) 476 | 477 | if not os.path.exists(f'{working_dir}/singleton_in_og'): 478 | os.makedirs(f'{working_dir}/singleton_in_og') 479 | logger.info('Filtering singletons in orthologous groups...') 480 | filter_singleton_in_og_parallel(common_ancestors,f'{working_dir}/singleton_ids', 481 | eggnog_dir,f'{working_dir}/singleton_in_og', 482 | number_jobs) 483 | logger.info('Allocating orthologous singletons...') 484 | allocate_orthologous_singletons(f'{working_dir}/singleton_in_og', 485 | f'{working_dir}/singleton_in_og_embeddings.h5') 486 | 487 | logger.info('Extracting network orthologs...') 488 | extract_network_proteins_parallel(f'{working_dir}/singleton_in_og', 489 | scaled_dir=scaled_save_dir, 490 | save_name=f'{working_dir}/og_proteins.h5', 491 | number_jobs=number_jobs) 492 | 493 | logger.info('Collecting embeddings...') 494 | collect_embeddings_parallel(species,interation_og=f'{working_dir}/og_proteins.h5', 495 | singleton_og_dir=f'{working_dir}/singleton_in_og', 496 | pure_singleton_og=f'{working_dir}/singleton_in_og_embeddings.h5', 497 | singleton_dir=f'{working_dir}/singleton_ids', 498 | scaled_dir=scaled_save_dir, 499 | save_dir=full_embedding_save_dir, 500 | number_jobs=number_jobs) 501 | # remove the working directory 502 | os.system(f'rm -r {working_dir}') 503 | 504 | logger.info('DONE.') 505 | return None 506 | 507 | 508 | if __name__ == "__main__": 509 | 510 | argparser = argparse.ArgumentParser(description='Allocate embeddings for singletons') 511 | 512 | 
argparser.add_argument('--aligned_dir',type=str,required=False, 513 | default='data/aligned', 514 | help='Directory with aligned embeddings') 515 | argparser.add_argument('--species_file',type=str,required=False, 516 | default='data/euks.txt', 517 | help='File with species names') 518 | argparser.add_argument('--euk_groups',type=str,required=False, 519 | default='data/euks_groups.tsv', 520 | help='File with eukaryotic groups') 521 | argparser.add_argument('--ncbi_lineage',type=str,required=False, 522 | default='data/ncbi_lineage.zip', 523 | help='NCBI taxonomy lineage file') 524 | argparser.add_argument('--eggnog_dir',type=str,required=False, 525 | default='data/eggnog', 526 | help='Directory with eggNOG orthologous groups') 527 | argparser.add_argument('--sequence_dir',type=str,required=False, 528 | default='data/sequences', 529 | help='Directory with protein sequences') 530 | argparser.add_argument('--orthologs_dir',type=str,required=False, 531 | default='data/orthologs', 532 | help='Directory with orthologous groups') 533 | argparser.add_argument('--working_dir',type=str,required=False, 534 | default='temp/singletons', 535 | help='Working directory') 536 | argparser.add_argument('--full_embedding_save_dir',type=str,required=False, 537 | default='results/functional_embeddings', 538 | help='Directory to save the embeddings') 539 | argparser.add_argument('--number_jobs',type=int,required=False, 540 | default=7, 541 | help='Number of parallel jobs') 542 | argparser.add_argument('--scaler',type=float,required=False, 543 | default=1.497, 544 | help='Scaling factor for embeddings, default is 1.497 if you use the embeddings in data/aligned. \ 545 | Use `None` if you want to calculate the scaling factor') 546 | 547 | args = argparser.parse_args() 548 | 549 | main(args.aligned_dir,args.species_file,args.euk_groups,args.ncbi_lineage,args.eggnog_dir, 550 | args.sequence_dir, 551 | args.orthologs_dir,args.working_dir,args.full_embedding_save_dir, 552 | args.number_jobs,scaler=args.scaler) 553 | 554 | 555 | 556 | -------------------------------------------------------------------------------- /src/space/models/fedcoder.py: -------------------------------------------------------------------------------- 1 | from space.tools.data import H5pyData 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader, Dataset 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | from torch.utils.tensorboard import SummaryWriter 8 | import numpy as np 9 | import math 10 | import os 11 | from itertools import combinations, product 12 | from typing import Iterable 13 | from loguru import logger 14 | import yaml 15 | from datetime import datetime 16 | from tqdm import tqdm 17 | import json 18 | import pandas as pd 19 | from uuid import uuid4 20 | 21 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 22 | 23 | class Encoder(nn.Module): 24 | def __init__(self,input_dim,latent_dim,hidden_dims:list=None,activation_function=None) -> None: 25 | super(Encoder,self).__init__() 26 | 27 | self.hidden_dims = hidden_dims 28 | self.activation_function = activation_function 29 | 30 | self.hidden_layers = nn.ModuleList() 31 | ## create hidden layers 32 | if self.hidden_dims is not None: 33 | for h_dim in self.hidden_dims: 34 | self.hidden_layers.append(nn.Linear(input_dim,h_dim)) 35 | input_dim = h_dim 36 | 37 | self.l1 = nn.Linear(input_dim,latent_dim) 38 | 39 | def forward(self,x): 40 | 41 | if self.hidden_dims is not None: 42 | if self.activation_function is not 
None: 43 | for layer in self.hidden_layers: 44 | x = self.activation_function(layer(x)) 45 | else: 46 | for layer in self.hidden_layers: 47 | x = layer(x) 48 | x = self.l1(x) 49 | return x 50 | 51 | class Decoder(nn.Module): 52 | def __init__(self,latent_dim,input_dim,hidden_dims:list=None,activation_function=None) -> None: 53 | super(Decoder,self).__init__() 54 | 55 | self.hidden_dims = hidden_dims 56 | self.activation_function = activation_function 57 | 58 | self.hidden_layers = nn.ModuleList() 59 | ## create hidden layers, reverse order 60 | if self.hidden_dims is not None: 61 | for h_dim in reversed(self.hidden_dims): 62 | self.hidden_layers.append(nn.Linear(latent_dim,h_dim)) 63 | latent_dim = h_dim 64 | 65 | self.l1 = nn.Linear(latent_dim,input_dim) 66 | 67 | def forward(self,z): 68 | 69 | if self.hidden_dims is not None: 70 | if self.activation_function is not None: 71 | for layer in self.hidden_layers: 72 | z = self.activation_function(layer(z)) 73 | else: 74 | for layer in self.hidden_layers: 75 | z = layer(z) 76 | z = self.l1(z) 77 | return z 78 | 79 | 80 | class VAEEncoder(nn.Module): 81 | def __init__(self, input_dim, latent_dim, hidden_dims, activation_function, batch_norm, ): 82 | super(VAEEncoder, self).__init__() 83 | self.layers = nn.ModuleList() 84 | if batch_norm: 85 | self.batch_norms = nn.ModuleList() 86 | self.activation_function = activation_function 87 | self.batch_norm = batch_norm 88 | 89 | # Create layers 90 | for h_dim in hidden_dims: 91 | self.layers.append(nn.Linear(input_dim, h_dim)) 92 | if batch_norm: 93 | self.batch_norms.append(nn.BatchNorm1d(h_dim)) 94 | input_dim = h_dim 95 | 96 | self.fc_mean = nn.Linear(hidden_dims[-1], latent_dim) 97 | self.fc_log_var = nn.Linear(hidden_dims[-1], latent_dim) 98 | 99 | def reparameterize(self, mean, log_var): 100 | std = torch.exp(0.5 * log_var) 101 | eps = torch.randn_like(std) 102 | return mean + eps * std 103 | 104 | 105 | def xavier_init(self): 106 | for layer in self.layers: 107 | torch.nn.init.xavier_uniform_(layer.weight) 108 | layer.bias.data.fill_(0.01) 109 | torch.nn.init.xavier_uniform_(self.fc_mean.weight) 110 | torch.nn.init.xavier_uniform_(self.fc_log_var.weight) 111 | 112 | def forward(self, x): 113 | 114 | if self.batch_norm: 115 | for layer, batch_norm in zip(self.layers, self.batch_norms): 116 | if self.activation_function: 117 | x = self.activation_function(batch_norm(layer(x))) 118 | else: 119 | x = batch_norm(layer(x)) 120 | else: 121 | for layer in self.layers: 122 | if self.activation_function: 123 | x = self.activation_function(layer(x)) 124 | else: 125 | x = layer(x) 126 | mean = self.fc_mean(x) 127 | log_var = self.fc_log_var(x) 128 | return self.reparameterize(mean, log_var) 129 | 130 | 131 | 132 | class VAEDecoder(nn.Module): 133 | def __init__(self, latent_dim, output_dim, hidden_dims, activation_function, batch_norm): 134 | super(VAEDecoder, self).__init__() 135 | self.layers = nn.ModuleList() 136 | if batch_norm: 137 | self.batch_norms = nn.ModuleList() 138 | self.activation_function = activation_function 139 | self.batch_norm = batch_norm 140 | # Create layers 141 | for h_dim in reversed(hidden_dims): 142 | self.layers.append(nn.Linear(latent_dim, h_dim)) 143 | if batch_norm: 144 | self.batch_norms.append(nn.BatchNorm1d(h_dim)) 145 | latent_dim = h_dim 146 | 147 | self.final_layer = nn.Linear(hidden_dims[0], output_dim) 148 | self.xavier_init() 149 | 150 | 151 | def xavier_init(self): 152 | for layer in self.layers: 153 | torch.nn.init.xavier_uniform_(layer.weight) 154 | 
torch.nn.init.xavier_uniform_(self.final_layer.weight) 155 | 156 | 157 | def forward(self, z): 158 | if self.batch_norm: 159 | for layer, batch_norm in zip(self.layers, self.batch_norms): 160 | if self.activation_function: 161 | z = self.activation_function(batch_norm(layer(z))) 162 | else: 163 | z = batch_norm(layer(z)) 164 | 165 | else: 166 | for layer in self.layers: 167 | if self.activation_function: 168 | z = self.activation_function(layer(z)) 169 | else: 170 | z = layer(z) 171 | reconstruction = torch.sigmoid(self.final_layer(z)) 172 | return reconstruction 173 | 174 | 175 | 176 | class BaseVAEEncoder(nn.Module): 177 | def __init__(self, input_dim, latent_dim): 178 | super(BaseVAEEncoder, self).__init__() 179 | self.fc1 = nn.Linear(input_dim, 2*latent_dim) 180 | self.fc2_mean = nn.Linear(2*latent_dim, latent_dim) 181 | self.fc2_log_var = nn.Linear(2*latent_dim, latent_dim) 182 | 183 | def reparameterize(self, mean, log_var): 184 | std = torch.exp(0.5 * log_var) 185 | eps = torch.randn_like(std) 186 | return mean + eps * std 187 | 188 | def forward(self, x): 189 | x = F.relu(self.fc1(x)) 190 | mean = self.fc2_mean(x) 191 | log_var = self.fc2_log_var(x) 192 | # z = self.reparameterize(mean, log_var) 193 | return self.reparameterize(mean, log_var) 194 | 195 | class BaseVAEDecoder(nn.Module): 196 | def __init__(self, latent_dim, output_dim): 197 | super(BaseVAEDecoder, self).__init__() 198 | self.fc1 = nn.Linear(latent_dim, 2*latent_dim) 199 | self.fc2 = nn.Linear(2*latent_dim, output_dim) 200 | 201 | def forward(self, z): 202 | z = F.relu(self.fc1(z)) 203 | reconstruction = torch.sigmoid(self.fc2(z)) 204 | return reconstruction 205 | 206 | def init_xavier(module): 207 | if isinstance(module, nn.Linear): 208 | nn.init.xavier_uniform_(module.weight) 209 | if module.bias is not None: 210 | nn.init.constant_(module.bias, 0) 211 | 212 | 213 | 214 | class EarlyStopping(): 215 | 216 | def __init__(self, patience=5, delta=0.0001, save_models=True): 217 | 218 | self.patience = patience 219 | self.counter = 0 220 | self.best_score = None 221 | self.early_stop = False 222 | self.delta = delta 223 | self.save_models = save_models 224 | 225 | def save_checkpoint(self, models:dict, save_name:str): 226 | '''Saves model when the metric improves.''' 227 | logger.info(f'Saving the checkpoint to {save_name}') 228 | torch.save(models.state_dict(), save_name) 229 | 230 | return True 231 | 232 | def reset(self): 233 | self.best_score = None 234 | self.counter = 0 235 | self.early_stop = False 236 | 237 | def __call__(self, loss, model:dict, save_folder:str): 238 | if self.best_score is None: 239 | self.best_score = loss 240 | if self.save_models: 241 | self.save_checkpoint(model, save_folder) 242 | elif loss > self.best_score - self.delta: 243 | self.counter += 1 244 | if self.counter >= self.patience: 245 | self.early_stop = True 246 | else: 247 | self.best_score = loss 248 | if self.save_models: 249 | self.save_checkpoint(model, save_folder) 250 | self.counter = 0 251 | 252 | 253 | class NodeEmbedData(Dataset): 254 | 255 | def __init__(self,file_path) -> None: 256 | super().__init__() 257 | 258 | self.protein_names, self.embedding = self.read_h5_file(file_path) 259 | 260 | def read_h5_file(self,file_path): 261 | 262 | # embed_np = np.load(file_path)['data'] 263 | proteins,embed_np = H5pyData.read(file_path,32) 264 | return proteins,torch.Tensor(embed_np).requires_grad_(False) 265 | 266 | def __len__(self): 267 | 268 | return len(self.embedding) 269 | 270 | def __getitem__(self, index): 271 | 272 | return 
self.embedding[index] 273 | 274 | class OrthologPair: 275 | 276 | def __init__(self, 277 | orthologs:Iterable,): 278 | 279 | self.pairs, self.weights = self.load_ortholog_pairs(orthologs) 280 | 281 | 282 | def __len__(self): 283 | 284 | return self.pairs.shape[0] 285 | 286 | def __getitem__(self, index): 287 | 288 | src_idx = self.pairs[:,0][index] 289 | tgt_idx = self.pairs[:,1][index] 290 | 291 | weight = self.weights[index] 292 | 293 | return src_idx, tgt_idx, weight 294 | 295 | 296 | def load_ortholog_pairs(self,pairs_file:str): 297 | pairs_weights = [pair.strip().split('\t') for pair in open(pairs_file,'r').readlines()] 298 | pairs = np.array(pairs_weights)[:, :-1].astype(int) 299 | pairs = torch.tensor(pairs).requires_grad_(False) 300 | 301 | weights = np.array(pairs_weights)[:, -1].astype(float) 302 | weights = torch.tensor(weights).requires_grad_(False) 303 | 304 | return pairs, weights 305 | 306 | 307 | class FedCoder: 308 | # add the docstring here 309 | ''' 310 | Parameters 311 | ---------- 312 | seed_species : str 313 | The file containing the seed species taxonomy ids. 314 | node2vec_dir : str 315 | The directory containing the node2vec embeddings. 316 | ortholog_dir : str 317 | The directory containing the ortholog pairs. 318 | aligned_embedding_save_dir : str 319 | The directory to save the embeddings. 320 | save_top_k : int, optional 321 | The number of top models to save, by default 3. 322 | log_dir : str, optional 323 | The directory to save the logs, by default None. 324 | input_dim : int, optional 325 | The input dimension of the autoencoder, by default 128. 326 | latent_dim : int, optional 327 | The latent dimension of the autoencoder, by default 512. 328 | hidden_dims : list, optional 329 | The hidden dimensions of the autoencoder, by default None. 330 | activation_fn : str, optional 331 | The activation function of the autoencoder, by default None. Only useful when hidden_dims is not None. 332 | batch_norm : bool, optional 333 | Whether to use batch normalization, by default False. Only useful when hidden_dims is not None. 334 | number_iters : int, optional 335 | The number of iterations per epoch to train the model, by default 10. 336 | autoencoder_type : str, optional 337 | The type of autoencoder to use, by default 'naive'. 338 | gamma : float, optional 339 | The gamma parameter for the alignment loss, by default 0.1. 340 | alpha : float, optional 341 | The alpha parameter for balancing the alignment loss and reconstruction loss, by default 0.5. 342 | lr : float, optional 343 | The learning rate, by default 0.01. 344 | device : str, optional 345 | The device to use, by default 'cpu'. 346 | patience : int, optional 347 | The patience for early stopping, by default 5. 348 | delta : float, optional 349 | The delta parameter for early stopping, by default 0.001. 350 | epochs : int, optional 351 | The number of epochs to train the model, by default 400. 352 | from_pretrained : str, optional 353 | The path to the pretrained model, by default None. 
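    Examples
    --------
    A minimal usage sketch. The paths below are illustrative placeholders
    (a seed-species file with one taxonomy id per line, node2vec embeddings
    stored as ``{taxid}.h5``, ortholog pairs stored as ``{src}_{tgt}.tsv``),
    not files shipped with the repository.

    >>> from space.models.fedcoder import FedCoder
    >>> aligner = FedCoder(seed_species='data/seed_species.txt',
    ...                    node2vec_dir='data/node2vec',
    ...                    ortholog_dir='data/orthologs/seeds',
    ...                    aligned_embedding_save_dir='data/aligned',
    ...                    device='cpu')
    >>> aligner.fit()
    >>> aligner.save_embeddings()  # writes one aligned {taxid}.h5 per seed species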
354 | ''' 355 | 356 | def __init__(self,seed_species:str, 357 | node2vec_dir:str, 358 | ortholog_dir:str, 359 | aligned_embedding_save_dir:str, 360 | save_top_k:int=3, 361 | log_dir:str=None, 362 | input_dim:int=128, 363 | latent_dim:int=512, 364 | hidden_dims:list=None, 365 | activation_fn:str=None, 366 | batch_norm:bool=False, 367 | number_iters:int=10, 368 | autoencoder_type:str='naive', 369 | gamma:float=0.1, 370 | alpha:float=0.5, 371 | lr:float=0.01, 372 | device:str='cpu', 373 | patience:int=5, 374 | delta:float=0.0001, 375 | epochs:int=600, 376 | from_pretrained:str=None, 377 | ) -> None: 378 | 379 | 380 | seed_species = open(seed_species,'r').read().strip().split('\n') 381 | self.seed_species = list(map(int,seed_species)) 382 | 383 | self.node2vec_dir = node2vec_dir 384 | self.ortholog_dir = ortholog_dir 385 | self.embedding_save_folder = aligned_embedding_save_dir 386 | self.save_top_k = save_top_k 387 | log_dir = log_dir if log_dir is not None else os.path.join(aligned_embedding_save_dir,'logs') 388 | log_dir = os.path.join(log_dir, datetime.now().strftime("%Y%m%d-%H%M%S")) 389 | self.log_dir = log_dir + '-' + str(uuid4()) 390 | self.model_save_path = os.path.join(self.log_dir,'model.pth') 391 | 392 | if not os.path.exists(self.log_dir): 393 | os.makedirs(self.log_dir) 394 | if not os.path.exists(self.embedding_save_folder): 395 | os.makedirs(self.embedding_save_folder) 396 | 397 | 398 | self.input_dim = input_dim 399 | self.latent_dim = latent_dim 400 | self.hidden_dims = hidden_dims 401 | self.activation_fn = activation_fn 402 | self.batch_norm = batch_norm 403 | self.number_iters = number_iters 404 | self.autoencoder_type = autoencoder_type 405 | self.gamma = gamma 406 | self.alpha = alpha 407 | self.lr = lr 408 | self.device = device 409 | self.patience = patience 410 | self.delta = delta 411 | self.epochs = epochs 412 | self.from_pretrained = from_pretrained 413 | 414 | def save_hyperparameters(self,save_dict,save_dir): 415 | ## save as a yaml file 416 | with open(f'{save_dir}/hyperparameters.yaml','w') as f: 417 | yaml.dump(save_dict,f) 418 | return None 419 | 420 | 421 | def init_everything(self): 422 | 423 | def init_autoencoder(input_dim,latent_dim,hidden_dims,activation_fn,batch_norm,device): 424 | # to simply, use a function to initialize the autoencoder 425 | if self.autoencoder_type == 'naive': 426 | return Encoder(input_dim,latent_dim,hidden_dims,activation_fn).to(device), \ 427 | Decoder(latent_dim,input_dim,hidden_dims,activation_fn).to(device) 428 | elif self.autoencoder_type == 'vae': 429 | return VAEEncoder(input_dim,latent_dim,hidden_dims,activation_fn,batch_norm).to(device), \ 430 | VAEDecoder(latent_dim,input_dim,hidden_dims,activation_fn,batch_norm).to(device) 431 | else: 432 | raise ValueError('Unknown autoencoder type') 433 | 434 | ## load the node2vec embeddings 435 | logger.info('Loading the node2vec embeddings') 436 | self.node2vec_embeddings = dict() 437 | for species in self.seed_species: 438 | self.node2vec_embeddings[str(species)] = NodeEmbedData(f'{self.node2vec_dir}/{species}.h5') 439 | ## dataloader 440 | self.node2vec_dataloader = {species:DataLoader(embed, 441 | batch_size=math.ceil(len(embed)/self.number_iters), 442 | shuffle=True) 443 | for species,embed in self.node2vec_embeddings.items()} 444 | 445 | 446 | ## load the ortholog pairs 447 | logger.info('Loading the ortholog pairs') 448 | species_pairs = list(combinations(self.seed_species,2)) 449 | # sort the species pairs 450 | species_pairs = [sorted(pair) for pair in species_pairs] 451 | 
self.ortholog_pairs = dict() 452 | for src,tgt in species_pairs: 453 | self.ortholog_pairs[f'{src}_{tgt}'] = OrthologPair(f'{self.ortholog_dir}/{src}_{tgt}.tsv') 454 | ## dataloader 455 | self.ortholog_dataloader = {pair:DataLoader(ortholog, 456 | batch_size=math.ceil(len(ortholog)/self.number_iters), 457 | shuffle=True) 458 | for pair,ortholog in self.ortholog_pairs.items()} 459 | 460 | ## init the models: {'encoder_species':encoder,'decoder_species':decoder} 461 | logger.info('Initializing the models') 462 | self.models = dict() 463 | for species in self.seed_species: 464 | encoder, decoder = init_autoencoder(self.input_dim,self.latent_dim,self.hidden_dims,self.activation_fn,self.batch_norm,self.device) 465 | self.models[f'encoder_{species}'] = encoder 466 | self.models[f'decoder_{species}'] = decoder 467 | self.models = torch.nn.ModuleDict(self.models) 468 | 469 | if self.from_pretrained is not None: 470 | logger.info(f'Loading the pretrained model from {self.from_pretrained}') 471 | self.models.load_state_dict(torch.load(self.from_pretrained)) 472 | else: 473 | for model in self.models.values(): 474 | model.apply(init_xavier) 475 | 476 | 477 | ## init the optimizers 478 | self.parameters = [] 479 | for model in self.models.values(): 480 | self.parameters += list(model.parameters()) 481 | 482 | self.optimizer = torch.optim.Adam(self.parameters,lr=self.lr) 483 | 484 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',patience=3,factor=0.1) 485 | 486 | self.early_stopping = EarlyStopping(patience=self.patience,delta=self.delta,save_models=True) 487 | 488 | # init the tensorboard writer 489 | logger.info('Initializing the tensorboard writer') 490 | 491 | self.writer = SummaryWriter(self.log_dir) 492 | 493 | logger.info('Everything initialized') 494 | 495 | hyperparameters = { 'seed_species':self.seed_species, 496 | 'node2vec_dir':self.node2vec_dir, 497 | 'ortholog_dir':self.ortholog_dir, 498 | 'embedding_save_folder':self.embedding_save_folder, 499 | 'model_save_path':self.model_save_path, 500 | 'save_top_k':self.save_top_k, 501 | 'log_dir':self.log_dir, 502 | 'input_dim':self.input_dim, 503 | 'latent_dim':self.latent_dim, 504 | 'hidden_dims':self.hidden_dims, 505 | 'activation_fn':self.activation_fn, 506 | 'batch_norm':self.batch_norm, 507 | 'number_iters':self.number_iters, 508 | 'autoencoder_type':self.autoencoder_type, 509 | 'gamma':self.gamma, 510 | 'alpha':self.alpha, 511 | 'lr':self.lr, 512 | 'device':self.device, 513 | 'patience':self.patience, 514 | 'delta':self.delta, 515 | 'epochs':self.epochs, 516 | 'from_pretrained':self.from_pretrained,} 517 | self.save_hyperparameters(hyperparameters,self.log_dir) 518 | 519 | return None 520 | 521 | 522 | def reconstruction_loss(self,node_batches): 523 | loss = list() 524 | for taxid, batch in node_batches.items(): 525 | 526 | batch = batch.to(self.device) 527 | 528 | latent = self.models[f'encoder_{taxid}'](batch) 529 | 530 | reconstruction = self.models[f'decoder_{taxid}'](latent) 531 | 532 | loss.append(F.pairwise_distance(batch,reconstruction,p=2).mean().unsqueeze(0)) 533 | 534 | return torch.cat(loss).mean() 535 | 536 | 537 | def alignment_loss(self,pair_batches): 538 | 539 | loss = list() 540 | 541 | for src_tgt, (src_index,tgt_index,weight) in pair_batches.items(): 542 | 543 | src,tgt = src_tgt.split('_') 544 | 545 | src_batch = self.node2vec_embeddings[src][src_index] 546 | tgt_batch = self.node2vec_embeddings[tgt][tgt_index] 547 | 548 | src_batch = src_batch.to(self.device) 549 | tgt_batch = tgt_batch.to(self.device) 550 | 
551 | src_latent = self.models[f'encoder_{src}'](src_batch) 552 | tgt_latent = self.models[f'encoder_{tgt}'](tgt_batch) 553 | 554 | src_tgt_loss = F.pairwise_distance(src_latent,tgt_latent,p=2) 555 | 556 | weight = weight.to(self.device) 557 | 558 | src_tgt_loss = (-F.logsigmoid(self.gamma - src_tgt_loss)*weight).mean().unsqueeze(0) 559 | 560 | loss.append(src_tgt_loss) 561 | 562 | return torch.cat(loss).mean() 563 | 564 | def one_epoch(self,crt_epoch): 565 | 566 | loss_dict = {'epoch_loss':0,'reconstruction_loss':0,'alignment_loss':0} 567 | 568 | node_iterators = {taxid: iter(loader) for taxid, loader in self.node2vec_dataloader.items()} 569 | pair_iterators = {src_tgt: iter(loader) for src_tgt, loader in self.ortholog_dataloader.items()} 570 | 571 | for iter_ in tqdm(range(self.number_iters)): 572 | 573 | node_batches = {taxid: next(node_iterators[str(taxid)],None) for taxid in self.seed_species} 574 | 575 | pair_batches = {src_tgt: next(pair_iterators[src_tgt],None) for src_tgt in self.ortholog_pairs.keys()} 576 | 577 | self.optimizer.zero_grad() 578 | 579 | reconstruction_loss = self.reconstruction_loss(node_batches) * self.alpha 580 | 581 | alignment_loss = self.alignment_loss(pair_batches) * (1-self.alpha) 582 | 583 | loss = reconstruction_loss + alignment_loss 584 | 585 | loss_dict['reconstruction_loss'] += reconstruction_loss.item() 586 | loss_dict['alignment_loss'] += alignment_loss.item() 587 | loss_dict['epoch_loss'] += loss.item() 588 | 589 | loss.backward() 590 | 591 | self.optimizer.step() 592 | 593 | # log the losses 594 | for key,value in loss_dict.items(): 595 | self.writer.add_scalar(key,value,crt_epoch+1) 596 | 597 | logger.info(f'Epoch {crt_epoch+1} loss: {loss_dict["epoch_loss"]}\n \ 598 | reconstruction loss: {loss_dict["reconstruction_loss"]}\n \ 599 | alignment loss: {loss_dict["alignment_loss"]}') 600 | 601 | return tuple(loss_dict.values()) 602 | 603 | 604 | def fit(self): 605 | self.init_everything() 606 | 607 | for epoch in range(self.epochs): 608 | logger.info(f'Epoch {epoch+1}') 609 | epoch_loss, reconstruction_loss, alignment_loss = self.one_epoch(epoch) 610 | self.scheduler.step(alignment_loss,epoch) 611 | 612 | ## save the best model 613 | self.early_stopping(alignment_loss,self.models,self.model_save_path) 614 | 615 | if self.early_stopping.early_stop: 616 | logger.info('Early stopping') 617 | break 618 | 619 | logger.info(f'Training completed after {epoch+1} epochs') 620 | 621 | return None 622 | 623 | 624 | @torch.no_grad() 625 | def save_embeddings(self,species:int=None): 626 | 627 | if species is not None: 628 | species = [str(species)] 629 | else: 630 | species = self.seed_species 631 | 632 | for taxid in species: 633 | taxid = str(taxid) 634 | encoder = self.models[f'encoder_{taxid}'] 635 | data = self.node2vec_embeddings[str(taxid)].embedding 636 | encoder.eval() 637 | 638 | data = data.to(self.device) 639 | 640 | latent = encoder(data) 641 | 642 | latent = latent.cpu().detach().numpy() 643 | 644 | H5pyData.write(self.node2vec_embeddings[taxid].protein_names, 645 | latent, 646 | f'{self.embedding_save_folder}/{taxid}.h5', 647 | 16) 648 | return None 649 | 650 | 651 | 652 | 653 | class FedCoderNonSeed(FedCoder): 654 | # docstring here 655 | ''' 656 | Parameters 657 | ---------- 658 | seed_groups : str 659 | The file containing the seed species taxonomy ids. 660 | tax_group : str 661 | The file containing the taxonomy groups. 662 | non_seed_species : str|int 663 | The non seed species taxonomy id. 
664 | node2vec_dir : str 665 | The directory containing the node2vec embeddings. 666 | aligned_dir : str 667 | The directory containing the aligned embeddings. 668 | ortholog_dir : str 669 | The directory containing the ortholog pairs. 670 | aligned_embedding_save_dir : str 671 | The directory to save the embeddings. 672 | save_top_k : int, optional 673 | The number of top models to save, by default 3. 674 | log_dir : str, optional 675 | The directory to save the logs, by default None. 676 | input_dim : int, optional 677 | The input dimension of the autoencoder, by default 128. 678 | latent_dim : int, optional 679 | The latent dimension of the autoencoder, by default 512. 680 | hidden_dims : list, optional 681 | The hidden dimensions of the autoencoder, by default None. 682 | activation_fn : str, optional 683 | The activation function of the autoencoder, by default None. Only useful when hidden_dims is not None. 684 | batch_norm : bool, optional 685 | Whether to use batch normalization, by default False. Only useful when hidden_dims is not None. 686 | number_iters : int, optional 687 | The number of iterations per epoch to train the model, by default 10. 688 | autoencoder_type : str, optional 689 | The type of autoencoder to use, by default 'naive'. 690 | gamma : float, optional 691 | The gamma parameter for the alignment loss, by default 0.1. 692 | alpha : float, optional 693 | The alpha parameter for balancing the alignment loss and reconstruction loss, by default 0.5. 694 | lr : float, optional 695 | The learning rate, by default 0.01. 696 | device : str, optional 697 | The device to use, by default 'cpu'. 698 | patience : int, optional 699 | The patience for early stopping, by default 5. 700 | delta : float, optional 701 | The delta parameter for early stopping, by default 0.0001. 702 | epochs : int, optional 703 | The number of epochs to train the model, by default 600. 704 | from_pretrained : str, optional 705 | The path to the pretrained model, by default None. 
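    Examples
    --------
    A minimal usage sketch for aligning a single non-seed species against its
    seed group. The paths and the taxonomy id are illustrative placeholders;
    ``ortholog_dir`` is expected to contain a ``{non_seed_species}/``
    subdirectory with ``{src}_{tgt}.tsv`` pair files.

    >>> from space.models.fedcoder import FedCoderNonSeed
    >>> aligner = FedCoderNonSeed(seed_groups='data/seed_groups.json',
    ...                           tax_group='data/tax_groups.tsv',
    ...                           non_seed_species=4932,
    ...                           node2vec_dir='data/node2vec',
    ...                           aligned_dir='data/aligned',
    ...                           ortholog_dir='data/orthologs/non_seeds',
    ...                           aligned_embedding_save_dir='data/aligned_non_seeds')
    >>> aligner.fit()
    >>> aligner.save_embeddings(species=4932)  # only the non-seed encoder is trained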
706 | ''' 707 | 708 | def __init__(self,seed_groups:str, 709 | tax_group:str, 710 | non_seed_species:str|int, 711 | node2vec_dir:str, 712 | aligned_dir:str, 713 | ortholog_dir:str, 714 | aligned_embedding_save_dir:str, 715 | save_top_k:int=3, 716 | log_dir:str=None, 717 | input_dim:int=128, 718 | latent_dim:int=512, 719 | hidden_dims:list=None, 720 | activation_fn:str=None, 721 | batch_norm:bool=False, 722 | number_iters:int=10, 723 | autoencoder_type:str='naive', 724 | gamma:float=0.1, 725 | alpha:float=0.5, 726 | lr:float=0.01, 727 | device:str='cpu', 728 | patience:int=5, 729 | delta:float=0.0001, 730 | epochs:int=600, 731 | from_pretrained:str=None, 732 | ) -> None: 733 | 734 | self.non_seed_species = int(non_seed_species) 735 | 736 | self.seed_groups = json.load(open(seed_groups,'r')) 737 | tax_group = pd.read_csv(tax_group,sep='\t') 738 | 739 | seed_species = tax_group[tax_group['taxid']==self.non_seed_species]['group'].values[0] 740 | self.seed_species = list(map(int,self.seed_groups[seed_species])) 741 | 742 | self.node2vec_dir = node2vec_dir 743 | self.aligned_dir = aligned_dir 744 | self.ortholog_dir = f'{ortholog_dir}/{self.non_seed_species}' 745 | self.embedding_save_folder = aligned_embedding_save_dir 746 | self.save_top_k = save_top_k 747 | log_dir = log_dir if log_dir is not None else os.path.join(aligned_embedding_save_dir,'logs') 748 | log_dir = os.path.join(log_dir, datetime.now().strftime("%Y%m%d-%H%M%S")) 749 | log_dir = log_dir + '-' + str(uuid4()) 750 | self.log_dir = log_dir 751 | self.model_save_path = os.path.join(self.log_dir,'model.pth') 752 | 753 | if not os.path.exists(self.log_dir): 754 | os.makedirs(self.log_dir) 755 | if not os.path.exists(self.embedding_save_folder): 756 | os.makedirs(self.embedding_save_folder) 757 | 758 | self.input_dim = input_dim 759 | self.latent_dim = latent_dim 760 | self.hidden_dims = hidden_dims 761 | self.activation_fn = activation_fn 762 | self.batch_norm = batch_norm 763 | self.number_iters = number_iters 764 | self.autoencoder_type = autoencoder_type 765 | self.gamma = gamma 766 | self.alpha = alpha 767 | self.lr = lr 768 | self.device = device 769 | self.patience = patience 770 | self.delta = delta 771 | self.epochs = epochs 772 | self.from_pretrained = from_pretrained 773 | 774 | 775 | def init_everything(self): 776 | def init_autoencoder(input_dim,latent_dim,hidden_dims,activation_fn,batch_norm,device): 777 | # to simply, use a function to initialize the autoencoder 778 | if self.autoencoder_type == 'naive': 779 | return Encoder(input_dim,latent_dim,hidden_dims,activation_fn).to(device), \ 780 | Decoder(latent_dim,input_dim,hidden_dims,activation_fn).to(device) 781 | elif self.autoencoder_type == 'vae': 782 | return VAEEncoder(input_dim,latent_dim,hidden_dims,activation_fn,batch_norm).to(device), \ 783 | VAEDecoder(latent_dim,input_dim,hidden_dims,activation_fn,batch_norm).to(device) 784 | else: 785 | raise ValueError('Unknown autoencoder type') 786 | 787 | 788 | ## load the node2vec embeddings 789 | logger.info('Loading the node2vec embeddings') 790 | self.node2vec_embeddings = dict() 791 | for species in self.seed_species: 792 | print(f'Loading {species} from {self.aligned_dir}/{species}.h5') 793 | self.node2vec_embeddings[str(species)] = NodeEmbedData(f'{self.aligned_dir}/{species}.h5') 794 | self.node2vec_embeddings[str(self.non_seed_species)] = NodeEmbedData(f'{self.node2vec_dir}/{self.non_seed_species}.h5') 795 | ## dataloader 796 | # only need the dataloader for non seed species 797 | self.node2vec_dataloader = 
798 |             DataLoader(self.node2vec_embeddings[str(self.non_seed_species)],
799 |                        batch_size=math.ceil(len(self.node2vec_embeddings[str(self.non_seed_species)])/self.number_iters),
800 |                        shuffle=True)}
801 | 
802 |         ## load the ortholog pairs
803 |         logger.info('Loading the ortholog pairs')
804 |         species_pairs = list(product(self.seed_species,[self.non_seed_species]))
805 |         # sort the species pairs
806 |         species_pairs = [sorted(pair) for pair in species_pairs]
807 |         self.ortholog_pairs = dict()
808 |         for src,tgt in species_pairs:
809 |             self.ortholog_pairs[f'{src}_{tgt}'] = OrthologPair(f'{self.ortholog_dir}/{src}_{tgt}.tsv')
810 | 
811 |         ## dataloader
812 |         self.ortholog_dataloader = {pair:DataLoader(ortholog,
813 |                                                     batch_size=math.ceil(len(ortholog)/self.number_iters),
814 |                                                     shuffle=True)
815 |                                     for pair,ortholog in self.ortholog_pairs.items()}
816 |         ## init the models: {'encoder_<taxid>':encoder,'decoder_<taxid>':decoder}
817 |         logger.info('Initializing the models')
818 |         self.models = dict()
819 |         encoder, decoder = init_autoencoder(self.input_dim,self.latent_dim,self.hidden_dims,self.activation_fn,self.batch_norm,self.device)
820 |         self.models[f'encoder_{self.non_seed_species}'] = encoder
821 |         self.models[f'decoder_{self.non_seed_species}'] = decoder
822 |         self.models = torch.nn.ModuleDict(self.models)
823 | 
824 |         if self.from_pretrained is not None:
825 |             logger.info(f'Loading the pretrained model from {self.from_pretrained}')
826 |             self.models.load_state_dict(torch.load(self.from_pretrained))
827 |         else:
828 |             for model in self.models.values():
829 |                 model.apply(init_xavier)
830 | 
831 |         ## init the optimizer
832 |         self.parameters = []
833 |         for model in self.models.values():
834 |             self.parameters += list(model.parameters())
835 | 
836 |         self.optimizer = torch.optim.Adam(self.parameters,lr=self.lr)
837 | 
838 |         self.scheduler = ReduceLROnPlateau(self.optimizer,'min',patience=3,factor=0.1)
839 | 
840 |         self.early_stopping = EarlyStopping(patience=self.patience,delta=self.delta,save_models=True)
841 | 
842 |         # init the tensorboard writer
843 |         logger.info('Initializing the tensorboard writer')
844 | 
845 |         self.writer = SummaryWriter(self.log_dir)
846 | 
847 |         logger.info('Everything initialized')
848 | 
849 |         hyperparameters = { 'seed_species':self.seed_species,
850 |                             'node2vec_dir':self.node2vec_dir,
851 |                             'ortholog_dir':self.ortholog_dir,
852 |                             'embedding_save_folder':self.embedding_save_folder,
853 |                             'model_save_path':self.model_save_path,
854 |                             'save_top_k':self.save_top_k,
855 |                             'log_dir':self.log_dir,
856 |                             'input_dim':self.input_dim,
857 |                             'latent_dim':self.latent_dim,
858 |                             'hidden_dims':self.hidden_dims,
859 |                             'activation_fn':self.activation_fn,
860 |                             'batch_norm':self.batch_norm,
861 |                             'number_iters':self.number_iters,
862 |                             'autoencoder_type':self.autoencoder_type,
863 |                             'gamma':self.gamma,
864 |                             'alpha':self.alpha,
865 |                             'lr':self.lr,
866 |                             'device':self.device,
867 |                             'patience':self.patience,
868 |                             'delta':self.delta,
869 |                             'epochs':self.epochs,
870 |                             'from_pretrained':self.from_pretrained,}
871 |         self.save_hyperparameters(hyperparameters,self.log_dir)
872 | 
873 |         return None
874 | 
875 | 
876 |     def alignment_loss(self, pair_batches):
877 | 
878 |         loss = list()
879 | 
880 |         for src_tgt, (src_index,tgt_index,weight) in pair_batches.items():
881 | 
882 |             src,tgt = src_tgt.split('_')
883 | 
884 |             src_batch = self.node2vec_embeddings[src][src_index].to(self.device)
885 |             tgt_batch = self.node2vec_embeddings[tgt][tgt_index].to(self.device)
886 | 
887 |             ## check if src or tgt is the non seed species
888 |             if int(src) == self.non_seed_species:
889 |                 src_latent = self.models[f'encoder_{self.non_seed_species}'](src_batch)
890 |                 tgt_latent = tgt_batch
891 |             else:
892 |                 src_latent = src_batch
893 |                 tgt_latent = self.models[f'encoder_{self.non_seed_species}'](tgt_batch)
894 | 
895 |             src_tgt_loss = F.pairwise_distance(src_latent,tgt_latent,p=2)
896 | 
897 |             weight = weight.to(self.device)
898 | 
899 |             src_tgt_loss = (-F.logsigmoid(self.gamma - src_tgt_loss)*weight)  # soft-margin penalty, weighted per ortholog pair
900 | 
901 |             loss.append(src_tgt_loss)
902 | 
903 |         return torch.cat(loss).mean().unsqueeze(0)
904 | 
905 | 
906 |     def one_epoch(self,crt_epoch):
907 | 
908 |         loss_dict = {'epoch_loss':0,'reconstruction_loss':0,'alignment_loss':0}
909 | 
910 |         node_iterators = {taxid: iter(loader) for taxid, loader in self.node2vec_dataloader.items()}
911 |         pair_iterators = {src_tgt: iter(loader) for src_tgt, loader in self.ortholog_dataloader.items()}
912 | 
913 |         for iter_ in tqdm(range(self.number_iters)):
914 | 
915 |             node_batches = {self.non_seed_species: next(node_iterators[str(self.non_seed_species)],None)}
916 | 
917 |             pair_batches = {src_tgt: next(pair_iterators[src_tgt],None) for src_tgt in self.ortholog_pairs.keys()}
918 | 
919 |             self.optimizer.zero_grad()
920 | 
921 |             reconstruction_loss = self.reconstruction_loss(node_batches) * self.alpha
922 | 
923 |             alignment_loss = self.alignment_loss(pair_batches) * (1-self.alpha)
924 | 
925 |             loss = reconstruction_loss + alignment_loss
926 | 
927 |             loss_dict['reconstruction_loss'] += reconstruction_loss.item()
928 |             loss_dict['alignment_loss'] += alignment_loss.item()
929 |             loss_dict['epoch_loss'] += loss.item()
930 | 
931 |             loss.backward()
932 | 
933 |             self.optimizer.step()
934 | 
935 |         # log the losses
936 |         for key,value in loss_dict.items():
937 |             self.writer.add_scalar(key,value,crt_epoch+1)
938 | 
939 |         logger.info(f'Epoch {crt_epoch+1} loss: {loss_dict["epoch_loss"]}\n'
940 |                     f'reconstruction loss: {loss_dict["reconstruction_loss"]}\n'
941 |                     f'alignment loss: {loss_dict["alignment_loss"]}')
942 | 
943 |         return tuple(loss_dict.values())
944 | 
945 |     def fit(self):
946 |         self.init_everything()
947 | 
948 |         for epoch in range(self.epochs):
949 |             logger.info(f'Epoch {epoch+1}')
950 |             epoch_loss, reconstruction_loss, alignment_loss = self.one_epoch(epoch)
951 |             self.scheduler.step(alignment_loss)
952 | 
953 |             ## save the best model
954 |             self.early_stopping(alignment_loss,self.models,self.model_save_path)
955 | 
956 |             if self.early_stopping.early_stop:
957 |                 logger.info('Early stopping')
958 |                 break
959 | 
960 |         logger.info(f'Training completed after {epoch+1} epochs')
961 | 
962 |         return None
963 | 
964 | 
965 |     ## change the default to save only the non seed species (only the non-seed encoder exists, so the species argument is currently unused)
966 |     @torch.no_grad()
967 |     def save_embeddings(self, species: int=None):
968 | 
969 |         if species is not None:
970 |             species = [str(species)]
971 |         else:
972 |             species = [str(self.non_seed_species)]
973 | 
974 |         encoder = self.models[f'encoder_{self.non_seed_species}']
975 | 
976 |         data = self.node2vec_embeddings[str(self.non_seed_species)].embedding
977 | 
978 |         encoder.eval()
979 | 
980 |         data = data.to(self.device)
981 | 
982 |         latent = encoder(data)
983 | 
984 |         latent = latent.cpu().detach().numpy()
985 | 
986 |         H5pyData.write(self.node2vec_embeddings[str(self.non_seed_species)].protein_names,
987 |                        latent,
988 |                        f'{self.embedding_save_folder}/{self.non_seed_species}.h5',
989 |                        16)
990 |         return None
991 | 
--------------------------------------------------------------------------------
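For reference, the alignment term computed in `alignment_loss` above reduces, per ortholog pair, to a weighted soft-margin penalty on the Euclidean distance between the two latent embeddings. The snippet below is a self-contained sketch of just that formula with toy tensors; the function name, shapes, and values are illustrative (not part of the package), and the default margin mirrors the class's `gamma=0.1`:

import torch
import torch.nn.functional as F

def soft_margin_alignment(src_latent, tgt_latent, weight, gamma=0.1):
    # Euclidean distance between the embeddings of each ortholog pair
    dist = F.pairwise_distance(src_latent, tgt_latent, p=2)
    # pairs within the margin gamma contribute little; distant pairs are penalized
    # roughly linearly, scaled by the per-pair ortholog weight
    return (-F.logsigmoid(gamma - dist) * weight).mean()

# toy batch: 4 ortholog pairs in a 512-dimensional latent space (the default latent_dim)
src = torch.randn(4, 512)
tgt = torch.randn(4, 512)
weights = torch.ones(4)
print(soft_margin_alignment(src, tgt, weights))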