├── README.md ├── profun ├── __init__.py ├── evaluation │ ├── __init__.py │ └── metrics.py ├── models │ ├── __init__.py │ ├── blast_model.py │ ├── foldseek_model.py │ ├── hmm │ │ ├── __init__.py │ │ ├── hmm_dataclasses.py │ │ └── hmm_model.py │ └── ifaces │ │ ├── __init__.py │ │ ├── config_baseclasses.py │ │ └── model_baseclass.py └── utils │ ├── __init__.py │ ├── alphafold_struct_downloader.py │ ├── msa.py │ └── project_info.py └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # ProFun 2 | Library of models for **Pro**tein **Fun**ction prediction 3 | 4 | # Installation 5 | The majority of dependencies will be installed automatically via the command 6 | ``` 7 | pip install git+https://github.com/SamusRam/ProFun.git 8 | ``` 9 | 10 | If you want to use the BLAST-based model, please run these commands: 11 | ``` 12 | wget https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/2.14.0/ncbi-blast-2.14.0+-x64-linux.tar.gz 13 | tar zxvpf ncbi-blast-2.14.0+-x64-linux.tar.gz 14 | # add ncbi-blast-2.14.0+/bin to PATH 15 | ``` 16 | If you want to use profile Hidden Markov models, please run the following commands: 17 | ``` 18 | conda install -c bioconda mafft -y 19 | conda install -c bioconda hmmer -y 20 | ``` 21 | 22 | If you want to use Foldseek-based model, please run the following command: 23 | ``` 24 | conda install -c conda-forge -c bioconda foldseek -y 25 | ``` 26 | 27 | # Basic usage 28 | ## BLAST 29 | Please see [this notebook](https://www.kaggle.com/code/samusram/blastp-sprof-go) as a usage demo. 30 | 31 | ``` 32 | from profun.models import BlastMatching, BlastConfig 33 | from profun.utils.project_info import ExperimentInfo 34 | 35 | experiment_info = ExperimentInfo(validation_schema='public_lb', 36 | model_type='blast', model_version='1nn') 37 | 38 | config = BlastConfig(experiment_info=experiment_info, 39 | id_col_name='EntryID', 40 | target_col_name='term', 41 | seq_col_name='Seq', 42 | class_names=list(train_df_long['term'].unique()), 43 | optimize_hyperparams=False, 44 | n_calls_hyperparams_opt=None, 45 | hyperparam_dimensions=None, 46 | per_class_optimization=None, 47 | class_weights=None, 48 | n_neighbours=5, 49 | e_threshold=0.0001, 50 | n_jobs=100, 51 | pred_batch_size=10 52 | ) 53 | 54 | blast_model = BlastMatching(config) 55 | 56 | # fit 57 | blast_model.fit(train_df_long) 58 | 59 | # predict 60 | test_pred_df = blast_model.predict_proba(test_seqs_df.sample(42).drop_duplicates('EntryID'), return_long_df=True) 61 | ``` 62 | 63 | ## Profile Hidden Markov model 64 | ``` 65 | from profun.models import ProfileHMM, HmmConfig 66 | from profun.utils.project_info import ExperimentInfo 67 | 68 | experiment_info = ExperimentInfo(validation_schema='public_lb', 69 | model_type='profileHMM', model_version='24additional') 70 | 71 | config = HmmConfig(experiment_info=experiment_info, 72 | id_col_name='EntryID', 73 | target_col_name='term', 74 | seq_col_name='Seq', 75 | class_names=list(additional_classes), 76 | optimize_hyperparams=False, 77 | n_calls_hyperparams_opt=None, 78 | hyperparam_dimensions=None, 79 | per_class_optimization=None, 80 | class_weights=None, 81 | search_e_threshold=0.000001, 82 | zero_conf_level=0.00001, 83 | group_column_name='taxonomyID', 84 | n_jobs=56, 85 | pred_batch_size=20000) 86 | 87 | hmm_model = ProfileHMM(config) 88 | hmm_model.fit(train_df_long) 89 | test_pred_df = hmm_model.predict_proba(test_seqs_df.drop_duplicates('EntryID'), return_long_df=True) 90 | ``` 91 | 92 | ## Foldseek-based classifier 93 | 
Please see [this notebook](https://www.kaggle.com/code/samusram/leveraging-foldseek) as a usage demo. 94 | 95 | ``` 96 | from profun.models import FoldseekMatching, FoldseekConfig 97 | from profun.utils.project_info import ExperimentInfo 98 | 99 | experiment_info = ExperimentInfo(validation_schema='public_lb', 100 | model_type='foldseek', model_version='5nn') 101 | 102 | config = FoldseekConfig(experiment_info=experiment_info, 103 | id_col_name='EntryID', 104 | target_col_name='term', 105 | seq_col_name='Seq', 106 | class_names=list(train_df_long_sample['term'].unique()), 107 | optimize_hyperparams=False, 108 | n_calls_hyperparams_opt=None, 109 | hyperparam_dimensions=None, 110 | per_class_optimization=None, 111 | class_weights=None, 112 | n_neighbours=5, 113 | e_threshold=0.0001, 114 | n_jobs=56, 115 | pred_batch_size=10, 116 | local_pdb_storage_path=None #then it stores structures into the working dir 117 | ) 118 | 119 | model = FoldseekMatching(config) 120 | model.fit(train_df_long) 121 | test_pred_df = model.predict_proba(test_seqs_df.drop_duplicates('EntryID'), return_long_df=True) 122 | ``` 123 | 124 | -------------------------------------------------------------------------------- /profun/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SamusRam/ProFun/084028f6e3df999bb2625428aa466fe8437daf6a/profun/__init__.py -------------------------------------------------------------------------------- /profun/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """The init is added to enable finding all the submodules""" 2 | -------------------------------------------------------------------------------- /profun/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | """This module implements metrics computation""" 2 | 3 | import numpy as np # type: ignore 4 | import pandas as pd 5 | from sklearn.metrics import average_precision_score, recall_score # type: ignore 6 | import logging 7 | 8 | logging.basicConfig() 9 | logging.root.setLevel(logging.NOTSET) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def eval_model_mean_average_precision( 14 | model, val_df: pd.DataFrame, selected_class_name: str = None 15 | ): 16 | id_col_name = model.config.id_col_name 17 | val_df_unique = val_df.drop_duplicates(subset=id_col_name) 18 | gt_type = val_df[[id_col_name, model.config.target_col_name]].drop_duplicates() 19 | gt_type = ( 20 | gt_type.groupby(id_col_name)[model.config.target_col_name] 21 | .agg(set) 22 | .reset_index() 23 | ) 24 | gt_type.columns = [id_col_name, "target_set"] 25 | val_df_unique = val_df_unique.merge(gt_type, on=id_col_name) 26 | y_pred = model.predict_proba(val_df_unique) 27 | 28 | average_precisions = [] 29 | try: 30 | class_weights = model.config.class_weights 31 | if class_weights is None: 32 | class_weights = 1.0 33 | except AttributeError: 34 | class_weights = 1.0 35 | weights_sum = 0 36 | for class_i, class_name in enumerate(model.config.class_names): 37 | if selected_class_name is None or class_name == selected_class_name: 38 | y_true = val_df_unique["target_set"].map(lambda x: class_name in x) 39 | ap = average_precision_score(y_true, y_pred[:, class_i]) 40 | if isinstance(class_weights, float): 41 | class_weight = class_weights 42 | elif isinstance(class_weights, dict): 43 | class_weight = class_weights[class_name] 44 | else: 45 | raise NotImplementedError(f"Unexpected type 
{type(class_weights)} for the class_weight parameter.") 46 | ap_weighted = class_weight*ap 47 | average_precisions.append(ap_weighted) 48 | weights_sum += class_weight 49 | logger.info(f"{class_name}: ap = {ap:.3f}, weighted ap = {ap_weighted: .3f}") 50 | 51 | return 1 - np.sum(average_precisions)/weights_sum -------------------------------------------------------------------------------- /profun/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module with models and configs""" 2 | 3 | from .blast_model import BlastMatching, BlastConfig 4 | from .hmm import ProfileHMM, HmmConfig 5 | from .foldseek_model import FoldseekMatching, FoldseekConfig 6 | -------------------------------------------------------------------------------- /profun/models/blast_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import subprocess 6 | import sys 7 | from collections import Counter 8 | from dataclasses import dataclass 9 | from shutil import rmtree 10 | from typing import Type, Optional, Iterable 11 | 12 | import numpy as np 13 | import pandas as pd 14 | from tqdm.auto import tqdm 15 | 16 | from profun.models.ifaces import BaseConfig, BaseModel 17 | 18 | logging.basicConfig() 19 | logging.root.setLevel(logging.NOTSET) 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | @dataclass 24 | class BlastConfig(BaseConfig): 25 | """ 26 | A data class to store Blast-model attributes 27 | """ 28 | 29 | n_neighbours: int 30 | e_threshold: float 31 | n_jobs: Optional[int] = 20 32 | pred_batch_size: Optional[int] = 10_000 33 | 34 | 35 | class BlastMatching(BaseModel): 36 | def __init__( 37 | self, 38 | config: BlastConfig, 39 | ): 40 | super().__init__( 41 | config=config, 42 | ) 43 | self.config = config 44 | self.working_directory = self.output_root / "_working_directory" 45 | if os.path.exists(self.working_directory): 46 | rmtree(self.working_directory) 47 | os.makedirs(self.working_directory) 48 | self.db_path = None 49 | self.train_df = None 50 | 51 | def get_fasta_seqs(self, df, type_name=None): 52 | df.drop_duplicates(subset=[self.config.id_col_name], inplace=True) 53 | if type_name is not None: 54 | seqs = df.loc[ 55 | df[self.config.target_col_name] == type_name, self.config.seq_col_name 56 | ].values 57 | ids = df.loc[ 58 | df[self.config.target_col_name] == type_name, self.config.id_col_name 59 | ].values 60 | else: 61 | seqs = df[self.config.seq_col_name].values 62 | ids = df[self.config.id_col_name].values 63 | full_entries = [f">{entry_id}\n{entry_seq}" for entry_id, entry_seq in zip(ids, seqs)] 64 | unique_ids = {el.replace("'", "").replace('"', "") for el in ids} 65 | logger.info(f"For class {type_name}, the number of duplicated ids is {len(ids) - len(unique_ids)}") 66 | fasta_str = "\n".join(full_entries) 67 | return fasta_str.replace("'", "").replace('"', "") 68 | 69 | def _train(self, df: pd.DataFrame) -> str: 70 | fasta_str = self.get_fasta_seqs(df) 71 | with open(f"{self.working_directory}/_temp.fasta", "w") as f: 72 | f.writelines(fasta_str) 73 | all_id_lines = [line for line in fasta_str.split() if ">" in line] 74 | logger.info( 75 | f"Written fasta file. 
Number of duplicated id lines: {len(all_id_lines) - len(set(all_id_lines))}" 76 | ) 77 | 78 | x = subprocess.check_output( 79 | f"makeblastdb -in {self.working_directory}/_temp.fasta -dbtype prot -parse_seqids".split(), 80 | stderr=sys.stdout, 81 | ) 82 | logger.info(f"makeblastdb output: {x}") 83 | 84 | return f"{self.working_directory}/_temp.fasta" 85 | 86 | def _predict(self, df: pd.DataFrame, db_name: str) -> str: 87 | test_fasta = self.get_fasta_seqs(df) 88 | with open(f"{self.working_directory}/_test.fasta", "w") as f: 89 | f.writelines(test_fasta.replace("'", "").replace('"', "")) 90 | if os.path.exists(f"{self.working_directory}/results_raw.csv"): 91 | os.remove(f"{self.working_directory}/results_raw.csv") 92 | os.system( 93 | f"blastp -db {db_name} -evalue {self.config.e_threshold} -query {self.working_directory}/_test.fasta -out {self.working_directory}/results_raw.csv -max_target_seqs {self.config.n_neighbours} -outfmt 10 -num_threads {self.config.n_jobs}" 94 | ) 95 | os.remove(f"{self.working_directory}/_test.fasta") 96 | return f"{self.working_directory}/results_raw.csv" 97 | 98 | def fit_core(self, train_df: pd.DataFrame, class_name: str = None): 99 | try: 100 | train_df.drop_duplicates( 101 | subset=[self.config.id_col_name, self.config.target_col_name], inplace=True 102 | ) 103 | except TypeError: 104 | try: 105 | train_df.drop_duplicates( 106 | subset=[self.config.id_col_name], inplace=True 107 | ) 108 | except TypeError: 109 | train_df[self.config.id_col_name] = train_df[self.config.id_col_name].map(lambda x: str(sorted(x))) 110 | train_df.drop_duplicates( 111 | subset=[self.config.id_col_name], inplace=True 112 | ) 113 | 114 | if (self.db_path is None or len(self.train_df) != len(train_df) or 115 | np.any(self.train_df[[self.config.id_col_name, self.config.target_col_name]].values != train_df[ 116 | [self.config.id_col_name, self.config.target_col_name]].values)): 117 | train_df.drop_duplicates(subset=[self.config.id_col_name], inplace=True) 118 | self.train_df = train_df.copy() 119 | self.db_path = self._train(train_df) 120 | 121 | def predict_proba(self, val_df: pd.DataFrame, return_long_df: bool = False) -> [np.ndarray | pd.DataFrame]: 122 | assert val_df[self.config.id_col_name].nunique() == len( 123 | val_df 124 | ), "Expected input to predict_proba without duplicated ids" 125 | if return_long_df: 126 | predicted_ids, predicted_classes, predicted_probs = [], [], [] 127 | else: 128 | all_predicted_batches = [] 129 | for batch_i in tqdm(range(len(val_df) // self.config.pred_batch_size + 1), 130 | desc='Predicting with BLASTp-matching..'): 131 | val_df_batch = val_df.iloc[ 132 | batch_i * self.config.pred_batch_size: (batch_i + 1) * self.config.pred_batch_size] 133 | if len(val_df_batch): 134 | output_path = self._predict(val_df_batch, self.db_path) 135 | blast_results_df = pd.read_csv( 136 | output_path, names=[f"{self.config.id_col_name}_blasted", "Matched ID"] + list(range(10)) 137 | ) 138 | train_df_with_targets = self.train_df[[self.config.id_col_name, self.config.target_col_name]] 139 | if np.any(self.train_df[self.config.target_col_name].map(lambda x: isinstance(x, Iterable))): 140 | train_df_with_targets = train_df_with_targets.explode(self.config.target_col_name) 141 | 142 | blasted_merged_with_train_df = blast_results_df.merge( 143 | train_df_with_targets, 144 | left_on="Matched ID", 145 | right_on=self.config.id_col_name, 146 | copy=False, 147 | ) 148 | label_and_nn_counts = (blasted_merged_with_train_df 149 | .groupby(f"{self.config.id_col_name}_blasted")[ 
150 | [self.config.target_col_name, "Matched ID"]] 151 | .agg( 152 | {self.config.target_col_name: lambda x: [Counter(x)], "Matched ID": lambda x: [len(set(x))]}) 153 | .reset_index() 154 | ) 155 | label_and_nn_counts['prediction_dict'] = ( 156 | label_and_nn_counts[self.config.target_col_name] + label_and_nn_counts['Matched ID']).map( 157 | lambda x: {class_name: class_count / x[1] for class_name, class_count in x[0].items()}) 158 | label_and_nn_counts = label_and_nn_counts.merge( 159 | val_df_batch, left_on=f"{self.config.id_col_name}_blasted", 160 | right_on=self.config.id_col_name, how="right" 161 | ) 162 | if return_long_df: 163 | for _, row in label_and_nn_counts.iterrows(): 164 | if isinstance(row['prediction_dict'], dict): 165 | for class_name, class_prob in row['prediction_dict'].items(): 166 | predicted_ids.append(row[self.config.id_col_name]) 167 | predicted_classes.append(class_name) 168 | predicted_probs.append(class_prob) 169 | else: 170 | val_proba_np_batch = np.zeros((len(val_df_batch), len(self.config.class_names))) 171 | for class_i, class_name in enumerate(self.config.class_names): 172 | val_proba_np_batch[:, class_i] = label_and_nn_counts['prediction_dict'].map( 173 | lambda x: x[class_name] if isinstance(x, dict) and class_name in x else 0 174 | ) 175 | indices_batch = label_and_nn_counts[self.config.id_col_name].values 176 | orig_val_2_ord = {value: i for i, value in enumerate(val_df_batch[self.config.id_col_name])} 177 | order_of_predictions_in_orig_batch = sorted(range(len(indices_batch)), 178 | key=lambda idx: orig_val_2_ord[indices_batch[idx]]) 179 | all_predicted_batches.append(val_proba_np_batch[order_of_predictions_in_orig_batch]) 180 | if return_long_df: 181 | return pd.DataFrame({self.config.id_col_name: predicted_ids, 182 | self.config.target_col_name: predicted_classes, 183 | "probability": predicted_probs}) 184 | val_proba_np = all_predicted_batches[0] if len(all_predicted_batches) == 1 else np.concatenate( 185 | all_predicted_batches) 186 | return val_proba_np 187 | 188 | @classmethod 189 | def config_class(cls) -> Type[BlastConfig]: 190 | return BlastConfig 191 | -------------------------------------------------------------------------------- /profun/models/foldseek_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import subprocess 6 | import sys 7 | from collections import Counter 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | from shutil import rmtree, move 11 | from typing import Type, Optional, List, Iterable 12 | 13 | import numpy as np 14 | import pandas as pd 15 | from tqdm.auto import tqdm 16 | from uuid import uuid4 17 | 18 | from profun.models.ifaces import BaseModel 19 | from profun.models.blast_model import BlastConfig 20 | 21 | logging.basicConfig() 22 | logging.root.setLevel(logging.NOTSET) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass 27 | class FoldseekConfig(BlastConfig): 28 | """ 29 | A data class to store Foldseek-model attributes 30 | """ 31 | local_pdb_storage_path: Optional[str | Path] = None 32 | 33 | 34 | class FoldseekMatching(BaseModel): 35 | def __init__( 36 | self, 37 | config: FoldseekConfig, 38 | ): 39 | super().__init__( 40 | config=config, 41 | ) 42 | self.config = config 43 | self.working_directory = self.output_root / "_working_directory" 44 | if os.path.exists(self.working_directory): 45 | rmtree(self.working_directory) 46 | 
os.makedirs(self.working_directory) 47 | if self.config.local_pdb_storage_path is None: 48 | self.local_pdb_storage_path = self.working_directory / "_trn_db" 49 | else: 50 | self.local_pdb_storage_path = Path(self.config.local_pdb_storage_path) 51 | self.local_pdb_storage_path.mkdir(exist_ok=True, parents=True) 52 | self.trn_db_path = None 53 | self.train_df = None 54 | 55 | def organize_db_folder(self, list_of_required_ids: List[str]): 56 | available_ids = {file.stem for file in self.local_pdb_storage_path.glob('*.pdb')} 57 | ids_to_download = [uniprot_id for uniprot_id in list_of_required_ids if uniprot_id not in available_ids] 58 | if len(ids_to_download): 59 | path_to_id_file = self.working_directory / "_temp_ids_list" 60 | with open(path_to_id_file, "w") as file: 61 | file.writelines('\n'.join(ids_to_download)) 62 | try: 63 | downloading_output = subprocess.check_output( 64 | f"python -m profun.utils.alphafold_struct_downloader --structures-output-path {self.local_pdb_storage_path} --path-to-file-with-ids {path_to_id_file} --n-jobs {self.config.n_jobs}".split(), 65 | ) 66 | logger.info(f"AlphaFold2 structure download finished with output: {downloading_output}") 67 | except subprocess.CalledProcessError: 68 | logger.error("AlphaFold2 structures downloading failed") 69 | raise subprocess.CalledProcessError 70 | finally: 71 | os.remove(path_to_id_file) 72 | # moving only the required ids; a possible alternative for the future: --tar-exclude option of foldseek createdb 73 | selection_path = self.working_directory / f"_{uuid4()}" 74 | selection_path.mkdir() 75 | for uniprot_id in tqdm(list_of_required_ids, desc="Moving the PDB files..."): 76 | filename = f"{uniprot_id}.pdb" 77 | move(self.local_pdb_storage_path/filename, selection_path/filename) 78 | return selection_path 79 | 80 | def move_pdbs_to_main_storage(self, direction_to_move: str | Path): 81 | direction_to_move = Path(direction_to_move) 82 | available_pdb_files = [file.name for file in direction_to_move.glob('*.pdb')] 83 | for filename in tqdm(available_pdb_files, desc="Moving the PDB files back..."): 84 | move(direction_to_move/filename, self.local_pdb_storage_path/filename) 85 | rmtree(direction_to_move) 86 | 87 | def _train(self, df: pd.DataFrame) -> str: 88 | list_of_required_trn_ids = list(set(df[self.config.id_col_name].values)) 89 | trn_structs_path = self.organize_db_folder(list_of_required_trn_ids) 90 | logger.info( 91 | f"Prepared Foldseek trn folder" 92 | ) 93 | createdb_out = subprocess.check_output( 94 | f"foldseek createdb {trn_structs_path} {self.working_directory/'trn_db'} --threads {self.config.n_jobs}".split(), 95 | stderr=sys.stdout, 96 | ) 97 | logger.info(f"Trn DB, foldseek createdb output: {createdb_out}") 98 | # moving back the additional ids 99 | self.move_pdbs_to_main_storage(trn_structs_path) 100 | return self.working_directory/"trn_db" 101 | 102 | def _predict(self, df: pd.DataFrame, trn_db_name: str) -> str: 103 | list_of_required_trn_ids = list(set(df[self.config.id_col_name].values)) 104 | query_structs_path = self.organize_db_folder(list_of_required_trn_ids) 105 | logger.info( 106 | f"Prepared Foldseek query folder" 107 | ) 108 | createdb_out = subprocess.check_output( 109 | f"foldseek createdb {query_structs_path} {self.working_directory/'query_db'} --threads {self.config.n_jobs}".split(), 110 | stderr=sys.stdout, 111 | ) 112 | logger.info(f"Query DB, foldseek createdb output: {createdb_out}") 113 | self.move_pdbs_to_main_storage(query_structs_path) 114 | search_out = 
subprocess.check_output(f"foldseek search {self.working_directory/'query_db'} {self.trn_db_path} {self.working_directory}/resultDB tmp -e {self.config.e_threshold} --max-seqs {self.config.n_neighbours}".split(), 115 | stderr=sys.stdout) 116 | logger.info(f"Foldseek search output: {search_out}") 117 | result_conversion_out = subprocess.check_output(f"foldseek convertalis {self.working_directory/'query_db'} {self.trn_db_path} {self.working_directory}/resultDB {self.working_directory}/result.tsv --format-output query,target,evalue".split(), 118 | stderr=sys.stdout) 119 | logger.info(f"Result conversion output: {result_conversion_out}") 120 | return f"{self.working_directory}/result.tsv" 121 | 122 | def fit_core(self, train_df: pd.DataFrame, class_name: str = None): 123 | try: 124 | train_df.drop_duplicates( 125 | subset=[self.config.id_col_name, self.config.target_col_name], inplace=True 126 | ) 127 | except TypeError: 128 | try: 129 | train_df.drop_duplicates( 130 | subset=[self.config.id_col_name], inplace=True 131 | ) 132 | except TypeError: 133 | train_df[self.config.id_col_name] = train_df[self.config.id_col_name].map(lambda x: str(sorted(x))) 134 | train_df.drop_duplicates( 135 | subset=[self.config.id_col_name], inplace=True 136 | ) 137 | if (self.trn_db_path is None or len(self.train_df) != len(train_df) or 138 | np.any(self.train_df[[self.config.id_col_name, self.config.target_col_name]].values != train_df[ 139 | [self.config.id_col_name, self.config.target_col_name]].values)): 140 | train_df.drop_duplicates(subset=[self.config.id_col_name], inplace=True) 141 | self.train_df = train_df.copy() 142 | self.trn_db_path = self._train(train_df.drop_duplicates(subset=[self.config.id_col_name])) 143 | 144 | def predict_proba(self, val_df: pd.DataFrame, return_long_df: bool = False) -> [np.ndarray | pd.DataFrame]: 145 | assert val_df[self.config.id_col_name].nunique() == len( 146 | val_df 147 | ), "Expected input to predict_proba without duplicated ids" 148 | if return_long_df: 149 | predicted_ids, predicted_classes, predicted_probs = [], [], [] 150 | else: 151 | all_predicted_batches = [] 152 | for batch_i in tqdm(range(len(val_df) // self.config.pred_batch_size + 1), 153 | desc='Predicting with Foldseek-matching..'): 154 | val_df_batch = val_df.iloc[ 155 | batch_i * self.config.pred_batch_size: (batch_i + 1) * self.config.pred_batch_size] 156 | if len(val_df_batch): 157 | output_path = self._predict(val_df_batch, self.trn_db_path) 158 | results_df = pd.read_csv( 159 | output_path, sep='\t', header=None, names=[f"{self.config.id_col_name}_queried", f"{self.config.id_col_name}_matched", "evalue"], 160 | ) 161 | for colname in [f"{self.config.id_col_name}_queried", f"{self.config.id_col_name}_matched"]: 162 | results_df[colname] = results_df[colname].map(lambda x: x.replace(".pdb", "")) 163 | train_df_with_targets = self.train_df[[self.config.id_col_name, self.config.target_col_name]] 164 | if np.any(self.train_df[self.config.target_col_name].map(lambda x: isinstance(x, Iterable))): 165 | train_df_with_targets = train_df_with_targets.explode(self.config.target_col_name) 166 | results_merged_with_train_df = results_df.merge( 167 | train_df_with_targets, 168 | left_on=f"{self.config.id_col_name}_matched", 169 | right_on=self.config.id_col_name, 170 | copy=False, 171 | ) 172 | label_and_nn_counts = (results_merged_with_train_df 173 | .groupby(f"{self.config.id_col_name}_queried")[ 174 | [self.config.target_col_name, f"{self.config.id_col_name}_matched"]] 175 | .agg( 176 | 
{self.config.target_col_name: lambda x: [Counter(x)], f"{self.config.id_col_name}_matched": lambda x: [len(set(x))]}) 177 | .reset_index() 178 | ) 179 | label_and_nn_counts['prediction_dict'] = ( 180 | label_and_nn_counts[self.config.target_col_name] + label_and_nn_counts[f"{self.config.id_col_name}_matched"]).map( 181 | lambda x: {class_name: class_count / x[1] for class_name, class_count in x[0].items()}) 182 | label_and_nn_counts = label_and_nn_counts.merge( 183 | val_df_batch, left_on=f"{self.config.id_col_name}_queried", 184 | right_on=self.config.id_col_name, how="right" 185 | ) 186 | if return_long_df: 187 | for _, row in label_and_nn_counts.iterrows(): 188 | if isinstance(row['prediction_dict'], dict): 189 | for class_name, class_prob in row['prediction_dict'].items(): 190 | predicted_ids.append(row[self.config.id_col_name]) 191 | predicted_classes.append(class_name) 192 | predicted_probs.append(class_prob) 193 | else: 194 | val_proba_np_batch = np.zeros((len(val_df_batch), len(self.config.class_names))) 195 | for class_i, class_name in enumerate(self.config.class_names): 196 | val_proba_np_batch[:, class_i] = label_and_nn_counts['prediction_dict'].map( 197 | lambda x: x[class_name] if isinstance(x, dict) and class_name in x else 0 198 | ) 199 | indices_batch = label_and_nn_counts[self.config.id_col_name].values 200 | orig_val_2_ord = {value: i for i, value in enumerate(val_df_batch[self.config.id_col_name])} 201 | order_of_predictions_in_orig_batch = sorted(range(len(indices_batch)), 202 | key=lambda idx: orig_val_2_ord[indices_batch[idx]]) 203 | all_predicted_batches.append(val_proba_np_batch[order_of_predictions_in_orig_batch]) 204 | if return_long_df: 205 | return pd.DataFrame({self.config.id_col_name: predicted_ids, 206 | self.config.target_col_name: predicted_classes, 207 | "probability": predicted_probs}) 208 | val_proba_np = all_predicted_batches[0] if len(all_predicted_batches) == 1 else np.concatenate( 209 | all_predicted_batches) 210 | return val_proba_np 211 | 212 | @classmethod 213 | def config_class(cls) -> Type[FoldseekConfig]: 214 | return FoldseekConfig 215 | -------------------------------------------------------------------------------- /profun/models/hmm/__init__.py: -------------------------------------------------------------------------------- 1 | """ A module for direct importing 2 | """ 3 | 4 | from .hmm_model import ProfileHMM 5 | from .hmm_dataclasses import HmmConfig 6 | -------------------------------------------------------------------------------- /profun/models/hmm/hmm_dataclasses.py: -------------------------------------------------------------------------------- 1 | """Defining self-explainable datastructures""" 2 | 3 | from dataclasses import dataclass 4 | from functools import total_ordering 5 | from typing import Optional 6 | 7 | from profun.models.ifaces import BaseConfig 8 | 9 | 10 | @dataclass 11 | class HmmConfig(BaseConfig): 12 | """ 13 | A config class for profile HMM 14 | """ 15 | 16 | search_e_threshold: float 17 | zero_conf_level: float 18 | group_column_name: Optional[str] = None 19 | n_jobs: Optional[int] = 56 20 | pred_batch_size: Optional[int] = 10000 21 | 22 | 23 | @total_ordering 24 | @dataclass 25 | class HmmPrediction: 26 | """ 27 | A data class to store and post-process profile HMM predictions 28 | """ 29 | 30 | e_value: float 31 | score: float 32 | id: str 33 | prediction_label: str 34 | 35 | def __eq__(self, other): 36 | return self.score == other.score 37 | 38 | def __lt__(self, other): 39 | return self.score < other.score 
40 | 41 | -------------------------------------------------------------------------------- /profun/models/hmm/hmm_model.py: -------------------------------------------------------------------------------- 1 | """This class implements profile Hidden Markov model""" 2 | from __future__ import annotations 3 | 4 | import logging 5 | import os 6 | import pickle 7 | import uuid 8 | from itertools import groupby 9 | from shutil import rmtree 10 | from typing import Dict, List, Type, Iterable 11 | 12 | import numpy as np 13 | import pandas as pd # type: ignore 14 | from tqdm.auto import tqdm 15 | 16 | from profun.models.ifaces import BaseModel 17 | from profun.utils.msa import get_fasta_seqs, generate_msa_mafft 18 | from .hmm_dataclasses import HmmConfig, HmmPrediction 19 | 20 | logging.basicConfig() 21 | logging.root.setLevel(logging.INFO) 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def read_predictions_from_file( 26 | file_path: str, prediction_label: str 27 | ) -> List[HmmPrediction]: 28 | """ 29 | The function gathers predictions from outputs on the disk 30 | :param file_path: a path to file with raw prediction 31 | :param prediction_label: label of interest 32 | :return: list of predictions 33 | """ 34 | if not os.path.exists(file_path): 35 | return [] 36 | with open(file_path, "r", encoding="utf8") as file: 37 | lines_file = file.readlines() 38 | lines = [line.split() for line in lines_file[3:]] 39 | predictions = [] 40 | i = 0 41 | if len(lines) == 0: 42 | return [] 43 | while len(lines[i]) > 1: 44 | line = lines[i] 45 | predictions.append( 46 | HmmPrediction(float(line[4]), float(line[5]), line[0], prediction_label) 47 | ) 48 | i += 1 49 | return predictions 50 | 51 | 52 | class ProfileHMM(BaseModel): 53 | """Class with the profile HMM algorithm implementation""" 54 | 55 | def __init__(self, config: HmmConfig): 56 | super().__init__( 57 | config=config, 58 | ) 59 | self.config = config 60 | self.working_directory = self.output_root / "_out" 61 | if os.path.exists(self.working_directory): 62 | rmtree(self.working_directory) 63 | os.makedirs(self.working_directory) 64 | self.class_name_2_path_to_model_path: Dict[str, str] 65 | self.class_2_groups = None 66 | self.class_name_2_path_to_model_paths = None 67 | 68 | def prep_fasta_seqs( 69 | self, df: pd.DataFrame, type_name: str = None, group_name: str = None 70 | ) -> str: 71 | """ 72 | This function prepares inputs for a particular class as a fasta format 73 | :param df: input dataframe 74 | :param type_name: class name 75 | :param group_name: name of a group (e.g. 
a clade) 76 | :return: string representation in a fasta format 77 | """ 78 | if type_name is not None and group_name is not None: 79 | df_subset = df.loc[ 80 | (df[self.config.target_col_name].map(lambda x: (isinstance(x, str) and x == type_name) or (isinstance(x, set) and type_name in x))) 81 | & (df[self.config.group_column_name] == group_name) 82 | ] 83 | seqs = df_subset[self.config.seq_col_name].values 84 | ids = df_subset[self.config.id_col_name].values 85 | else: 86 | seqs = df[self.config.seq_col_name].values 87 | ids = df[self.config.id_col_name].values 88 | return get_fasta_seqs(seqs, ids) 89 | 90 | def _train_for_class_group( 91 | self, df: pd.DataFrame, class_name: str, group_name: str 92 | ) -> str: 93 | """ 94 | This function trains a HMM predictor for a given class 95 | :param df: input dataframe 96 | :param class_name: class name 97 | :return: path to a stored HMM 98 | """ 99 | fasta_str = self.prep_fasta_seqs(df, class_name, group_name) 100 | 101 | logger.info( 102 | "Training for class %s, group %s, fasta size: %d", 103 | class_name, 104 | group_name, 105 | len(fasta_str.split(">")), 106 | ) 107 | 108 | model_id = str(uuid.uuid4()) 109 | generate_msa_mafft(fasta_str=fasta_str, 110 | output_name=f"{self.working_directory}/{model_id}_msa.out", 111 | n_jobs=self.config.n_jobs, 112 | clustal_output_format=False) 113 | # check number of lines in the msa file 114 | with open(f"{self.working_directory}/{model_id}_msa.out", "r") as file: 115 | msa_lines = file.readlines() 116 | if len(msa_lines) == 0: 117 | raise ValueError("Empty MSA file") 118 | os.system( 119 | f"hmmbuild {self.working_directory}/{model_id}.hmm {self.working_directory}/{model_id}_msa.out" 120 | ) 121 | return f"{self.working_directory}/{model_id}.hmm" 122 | 123 | def predict_for_class_group( 124 | self, df: pd.DataFrame, class_name: str, group: str 125 | ) -> str: 126 | """ 127 | Prediction of the specified class 128 | :param df: input dataframe 129 | :param class_name: class to predict 130 | :return: path to a table with predictions 131 | """ 132 | test_fasta = self.prep_fasta_seqs(df) 133 | with open( 134 | f"{self.working_directory}/_test.fasta", 135 | "w", 136 | encoding="utf8", 137 | ) as file: 138 | file.writelines(test_fasta.replace("'", "").replace('"', "")) 139 | 140 | # logger.info(f'Predicting for class {class_name}, fasta size: {len(test_fasta.split(">"))}') 141 | 142 | result_id = str(uuid.uuid4()) 143 | assert ( 144 | self.class_name_2_path_to_model_paths is not None 145 | ), "Predicting class before training the profile HMM model" 146 | os.system( 147 | f"hmmsearch -E {self.config.search_e_threshold} --tblout {self.working_directory}/_{result_id}.tbl {self.class_name_2_path_to_model_paths[(class_name, group)]} {self.working_directory}/_test.fasta > {self.working_directory}/_{result_id}.out" 148 | ) 149 | return f"{self.working_directory}/_{result_id}.tbl" 150 | 151 | def aggregate_predictions(self, class_name_2_pred_path: Dict[tuple[str, str], str], do_major_class_agg: bool = False 152 | ) -> pd.DataFrame: 153 | """ 154 | The function aggregates prediction analogously to Terzyme algorithm https://link.springer.com/article/10.1186/s13007-017-0269-0 if do_major_class_agg == True 155 | :return: a dataframe with predictions 156 | """ 157 | class_name_2_pred_list = {} 158 | for (class_name, kingdom), prediction_path in class_name_2_pred_path.items(): 159 | class_name_2_pred_list[(class_name, kingdom)] = read_predictions_from_file( 160 | prediction_path, class_name 161 | ) 162 | 
os.remove(prediction_path) 163 | os.remove(prediction_path.replace(".tbl", ".out")) 164 | if do_major_class_agg: 165 | all_predictions_sorted = sorted( 166 | sum(class_name_2_pred_list.values(), []), key=lambda x: x.id 167 | ) 168 | predictions_dict = { 169 | uniprot_id: max(predictions) 170 | for uniprot_id, predictions in groupby( 171 | all_predictions_sorted, key=lambda x: x.id 172 | ) 173 | } 174 | predictions_to_output = predictions_dict.values() 175 | else: 176 | predictions_to_output = sum(class_name_2_pred_list.values(), []) 177 | ids_list = [] 178 | e_val_list = [] 179 | y_pred_list = [] 180 | for prediction in predictions_to_output: 181 | ids_list.append(prediction.id) 182 | e_val_list.append(prediction.e_value) 183 | y_pred_list.append(prediction.prediction_label) 184 | 185 | predictions_df = pd.DataFrame( 186 | {self.config.id_col_name: ids_list, 187 | self.config.target_col_name: y_pred_list, 188 | "E": e_val_list} 189 | ) 190 | return predictions_df 191 | 192 | def fit_core(self, train_df: pd.DataFrame, class_name: str = None): 193 | # assert isinstance( 194 | # self.config, HmmConfig 195 | # ), "HHM config instance is expected to be of type HmmConfig" 196 | if self.config.group_column_name is None: 197 | train_df[self.config.group_column_name] = "all" 198 | try: 199 | train_df.drop_duplicates( 200 | subset=[self.config.id_col_name, self.config.target_col_name], inplace=True 201 | ) 202 | except TypeError: 203 | try: 204 | train_df.drop_duplicates( 205 | subset=[self.config.id_col_name], inplace=True 206 | ) 207 | except TypeError: 208 | train_df[self.config.id_col_name] = train_df[self.config.id_col_name].map(lambda x: str(sorted(x))) 209 | train_df.drop_duplicates( 210 | subset=[self.config.id_col_name], inplace=True 211 | ) 212 | 213 | logger.info("Train size: %d", len(train_df)) 214 | if self.config.class_names is None: 215 | self.config.class_names = [ 216 | x 217 | for x in train_df[self.config.target_col_name].unique() # TODO: handle sets 218 | if not pd.isnull(x) 219 | ] 220 | 221 | self.class_2_groups = { 222 | class_name: train_df.loc[ 223 | train_df[self.config.target_col_name].map(lambda x: (isinstance(x, str) and x == class_name) or (isinstance(x, set) and class_name in x)), 224 | self.config.group_column_name, 225 | ].unique() 226 | for class_name in self.config.class_names 227 | } 228 | self.class_name_2_path_to_model_paths = dict() 229 | for class_name in self.config.class_names: 230 | for kingdom in self.class_2_groups[class_name]: 231 | n_samples = sum( 232 | (train_df[self.config.target_col_name].map(lambda x: (isinstance(x, str) and x == class_name) or (isinstance(x, set) and class_name in x))) 233 | & (train_df[self.config.group_column_name] == kingdom) 234 | ) 235 | logger.info("For a class %s and group %s there are %d training samples", 236 | class_name, kingdom, n_samples) 237 | if n_samples >= 2: 238 | try: 239 | self.class_name_2_path_to_model_paths[(class_name, kingdom)] = self._train_for_class_group( 240 | train_df, class_name=class_name, group_name=kingdom) 241 | except ValueError: 242 | continue 243 | 244 | with open( 245 | f"{self.working_directory}/class_name_2_path_to_model_paths.pkl", 246 | "wb", 247 | ) as file: 248 | pickle.dump(self.class_name_2_path_to_model_paths, file) 249 | 250 | def predict_proba(self, val_df: pd.DataFrame, return_long_df: bool = False) -> np.ndarray | pd.DataFrame: 251 | if self.config.group_column_name is None: 252 | val_df[self.config.group_column_name] = "all" 253 | assert val_df[self.config.id_col_name].nunique() 
== len( 254 | val_df 255 | ), "Expected input to predict_proba without duplicated ids" 256 | logger.info("Val size: %d", len(val_df)) 257 | assert ( 258 | self.class_name_2_path_to_model_paths is not None 259 | ), "Predicting before training the HMM model" 260 | assert ( 261 | self.config.class_names is not None 262 | ), "Class names were not derived and stored during training" 263 | batch_results = [] 264 | 265 | for batch_i in tqdm(range(len(val_df) // self.config.pred_batch_size + 1), 266 | desc='Predicting with Profile HMM..'): 267 | val_df_batch = val_df.iloc[ 268 | batch_i * self.config.pred_batch_size: (batch_i + 1) * self.config.pred_batch_size] 269 | if len(val_df_batch): 270 | class_name_2_pred_path = { 271 | (class_name, kingdom): self.predict_for_class_group( 272 | val_df_batch, class_name, kingdom 273 | ) 274 | for class_name in self.config.class_names 275 | for kingdom in self.class_2_groups[class_name] 276 | if (class_name, kingdom) in self.class_name_2_path_to_model_paths 277 | } 278 | pred_df = self.aggregate_predictions(class_name_2_pred_path) 279 | pred_df = pred_df.merge(val_df_batch[[self.config.id_col_name, self.config.seq_col_name]], on=self.config.id_col_name, how="right").set_index( 280 | self.config.id_col_name 281 | ) 282 | pred_df["probability"] = pred_df["E"] 283 | pred_df.loc[ 284 | pred_df["probability"] > self.config.zero_conf_level, "probability" 285 | ] = self.config.zero_conf_level 286 | pred_df["probability"] /= self.config.zero_conf_level 287 | pred_df.loc[pred_df["probability"].isnull(), "probability"] = 1 288 | pred_df["probability"] = 1 - pred_df["probability"] 289 | 290 | if return_long_df: 291 | batch_results.append(pred_df.loc[~pred_df[self.config.target_col_name].isnull(), 292 | [self.config.id_col_name, 293 | self.config.target_col_name, 294 | "probability"]]) 295 | else: 296 | val_proba_np = np.zeros((len(val_df_batch), len(self.config.class_names))) 297 | pred_df = pred_df.groupby(self.config.id_col_name)[[self.config.target_col_name, "probability"]].agg(list) 298 | pred_df = pred_df.loc[val_df_batch[self.config.id_col_name]] 299 | for class_i, class_name in enumerate(self.config.class_names): 300 | bool_idx = (pred_df[self.config.target_col_name].map(lambda x: class_name in x)).values 301 | if sum(bool_idx): 302 | val_proba_np[bool_idx, class_i] = pred_df[bool_idx].apply(lambda row: row["probability"][row[self.config.target_col_name].index(class_name)], axis=1) 303 | batch_results.append(val_proba_np) 304 | if return_long_df: 305 | return pd.concat(batch_results) 306 | return np.concatenate(batch_results) 307 | 308 | @classmethod 309 | def config_class(cls) -> Type[HmmConfig]: 310 | return HmmConfig 311 | -------------------------------------------------------------------------------- /profun/models/ifaces/__init__.py: -------------------------------------------------------------------------------- 1 | """The module for abstract classes""" 2 | 3 | from .config_baseclasses import BaseConfig 4 | from .model_baseclass import BaseModel 5 | -------------------------------------------------------------------------------- /profun/models/ifaces/config_baseclasses.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | """This module defines an abstract class for models""" 3 | 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | import yaml # type: ignore 9 | 10 | from profun.utils.project_info import ExperimentInfo 11 | 12 | 13 | 
@dataclass 14 | class BaseConfig: 15 | """ 16 | A data class to store model attributes 17 | """ 18 | 19 | experiment_info: ExperimentInfo 20 | id_col_name: str 21 | target_col_name: str 22 | seq_col_name: str 23 | class_names: list[str] 24 | optimize_hyperparams: bool 25 | n_calls_hyperparams_opt: int 26 | hyperparam_dimensions: dict[ 27 | str, 28 | ] 29 | per_class_optimization: bool 30 | class_weights: dict[str, float] 31 | 32 | @classmethod 33 | def load(cls, path_to_config: Union[str, Path]) -> dict: 34 | """ 35 | This class function loads config from a configs folder 36 | :param path_to_config: 37 | :return: a dictionary loaded from the config yaml 38 | """ 39 | with open(path_to_config, encoding="utf-8") as file: 40 | configs_dict = yaml.load(file, Loader=yaml.FullLoader) 41 | return configs_dict 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /profun/models/ifaces/model_baseclass.py: -------------------------------------------------------------------------------- 1 | """This module defines an abstract class for models""" 2 | from __future__ import annotations 3 | 4 | import inspect 5 | import json 6 | import os 7 | import pickle 8 | from abc import ABC, abstractmethod 9 | from copy import deepcopy 10 | from datetime import datetime 11 | from pathlib import Path 12 | from typing import Type 13 | 14 | import numpy as np # type: ignore 15 | import pandas as pd # type: ignore 16 | from sklearn.base import BaseEstimator 17 | from sklearn.model_selection import KFold, cross_val_score, StratifiedKFold # type: ignore 18 | from skopt import gp_minimize # type: ignore 19 | from skopt.space import Categorical, Integer, Real # type: ignore 20 | from skopt.utils import use_named_args # type: ignore 21 | 22 | from profun.evaluation.metrics import eval_model_mean_average_precision 23 | from profun.models.ifaces.config_baseclasses import BaseConfig 24 | from profun.utils.project_info import get_output_root 25 | import logging 26 | 27 | logging.basicConfig() 28 | logging.root.setLevel(logging.NOTSET) 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class BaseModel(ABC, BaseEstimator): 33 | """Base model class with an abstract method train_and_predict""" 34 | 35 | def __init__( 36 | self, 37 | config: BaseConfig, 38 | ): 39 | self.config = config 40 | assert isinstance(self.config.experiment_info.timestamp, datetime) 41 | if self.config.experiment_info.model_type is None: 42 | self.config.experiment_info.model = self.__class__.__name__ 43 | self.output_root = ( 44 | get_output_root() 45 | / config.experiment_info.model_type 46 | / config.experiment_info.model_version 47 | / config.experiment_info.validation_schema 48 | / config.experiment_info.fold 49 | / config.experiment_info.timestamp.strftime("%Y%m%d-%H%M%S") 50 | ) 51 | self.output_root.mkdir(exist_ok=True, parents=True) 52 | self.classifier_class = None 53 | try: 54 | self.per_class_optimization = self.config.per_class_optimization 55 | if self.per_class_optimization is None: 56 | self.per_class_optimization = False 57 | except AttributeError: 58 | self.per_class_optimization = False 59 | if self.per_class_optimization: 60 | self.class_2_classifier = dict() 61 | 62 | @abstractmethod 63 | def fit_core(self, train_df: pd.DataFrame, class_name: str = None): 64 | """ 65 | Function for training model instance 66 | :param train_df: pandas dataframe containing training data 67 | :param class_name: name of a class for the separate model fitting for the class 68 | """ 69 | raise 
NotImplementedError 70 | 71 | # TODO: add proper map metric to hyperparam tuning 72 | def fit(self, train_df: pd.DataFrame): 73 | """ 74 | Fit function 75 | :param train_df: pandas dataframe containing training data 76 | """ 77 | try: 78 | per_class_optimization = self.config.per_class_optimization 79 | except AttributeError: 80 | per_class_optimization = False 81 | if self.config.optimize_hyperparams: 82 | try: 83 | n_fold_splits = self.config.n_fold_splits 84 | except AttributeError: 85 | n_fold_splits = 5 86 | try: 87 | use_cross_validation = self.config.use_cross_validation 88 | except AttributeError: 89 | use_cross_validation = True 90 | try: 91 | reuse_existing_partial_results = ( 92 | self.config.reuse_existing_partial_results 93 | ) 94 | except AttributeError: 95 | reuse_existing_partial_results = False 96 | 97 | self.optimize_hyperparameters( 98 | train_df, 99 | n_calls=self.config.n_calls_hyperparams_opt, 100 | per_class_optimization=per_class_optimization, 101 | n_fold_splits=n_fold_splits, 102 | use_cross_validation=use_cross_validation, 103 | reuse_existing_partial_results=reuse_existing_partial_results, 104 | **self.config.hyperparam_dimensions, 105 | ) 106 | try: 107 | load_per_class_params_from = self.config.load_per_class_params_from 108 | if load_per_class_params_from is None: 109 | load_per_class_params_from = False 110 | except AttributeError: 111 | load_per_class_params_from = False 112 | if load_per_class_params_from: 113 | load_per_class_params_from = Path(load_per_class_params_from) 114 | previous_results = list( 115 | load_per_class_params_from.glob( 116 | f"*/hyperparameters_optimization/best_params_*.json" 117 | ) 118 | ) 119 | assert len( 120 | previous_results 121 | ), f"Requested to load per-class parameters from {load_per_class_params_from}, but no parameters in json are found" 122 | logger.info(f"Loading hyper parameters from: {previous_results[0]}") 123 | with open(previous_results[0], "r") as file: 124 | best_params = json.load(file) 125 | self.set_params(**best_params) 126 | 127 | if ( 128 | self.config.optimize_hyperparams or load_per_class_params_from 129 | ) and per_class_optimization: 130 | for class_name in self.config.class_names: 131 | self.fit_core(train_df, class_name=class_name) 132 | else: 133 | self.fit_core(train_df) 134 | 135 | @abstractmethod 136 | def predict_proba(self, val_df: pd.DataFrame, return_long_df: bool = False) -> [np.ndarray | pd.DataFrame]: 137 | """ 138 | Model predict method 139 | :param val_df: pandas dataframe containing instances to score the model on 140 | :param return_long_df: flag to return predictions in a long format dataframe with columns [id, target_class, prob] 141 | :returns predicted class probabilities either as wide numpy array or long pandas dataframe 142 | """ 143 | raise NotImplementedError 144 | 145 | def set_params(self, **kwargs): 146 | """ 147 | It's a generic function setting values of all parameters in the kwargs 148 | """ 149 | for attribute_name, value in kwargs.items(): 150 | if attribute_name not in {"class_name", "per_class"}: 151 | self.__setattr__(attribute_name, value if value != "None" else None) 152 | 153 | def optimize_hyperparameters( 154 | self, 155 | train_df: pd.DataFrame, 156 | n_calls: int, 157 | per_class_optimization: bool, 158 | n_fold_splits: int, 159 | use_cross_validation: bool, 160 | reuse_existing_partial_results: bool, 161 | **dimension_params, 162 | ): 163 | logger.info("Starting hyperparameter optimization...") 164 | if per_class_optimization: 165 | class_names = 
self.config.class_names 166 | else: 167 | class_names = ["all_classes"] 168 | for class_name in class_names: 169 | prefix = "" if class_name == "all_classes" else f"{class_name}_" 170 | if reuse_existing_partial_results: 171 | previous_results = list( 172 | self.output_root.glob( 173 | f"../*/hyperparameters_optimization/optimization_results_detailed_{prefix}*.pkl" 174 | ) 175 | ) 176 | if len(previous_results): 177 | logger.info( 178 | f"found previous results for class {class_name}: {previous_results}" 179 | ) 180 | with open(previous_results[0], "rb") as file: 181 | best_params, _, _ = pickle.load(file) 182 | self.set_params(**best_params) 183 | logger.info("Restored previous results") 184 | continue 185 | 186 | type_2_skopt_class = { 187 | "categorical": Categorical, 188 | "float": Real, 189 | "int": Integer, 190 | } 191 | dimensions = [] 192 | # x0 -> to enforce evaluation of the default parameters 193 | initial_instance_parameters = self.__dict__ 194 | if hasattr(self, "classifier_class"): 195 | classifier_attributes = inspect.getfullargspec( 196 | self.classifier_class 197 | ).kwonlydefaults 198 | if classifier_attributes is not None: 199 | classifier_attributes.update(initial_instance_parameters) 200 | initial_instance_parameters = classifier_attributes 201 | x0 = [] 202 | for name, characteristics in dimension_params.items(): 203 | if characteristics["type"] != "categorical": 204 | next_dims = type_2_skopt_class[characteristics["type"]]( 205 | *characteristics["args"], name=name 206 | ) 207 | else: 208 | next_dims = type_2_skopt_class[characteristics["type"]]( 209 | characteristics["args"], name=name 210 | ) 211 | dimensions.append(next_dims) 212 | assert ( 213 | name in initial_instance_parameters 214 | ), f"Hyperparameter {name} does not seem to be a model attribute" 215 | x0.append(initial_instance_parameters[name]) 216 | run_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 217 | 218 | # optimization results will be stored in 219 | if not (self.output_root / "hyperparameters_optimization").exists(): 220 | (self.output_root / "hyperparameters_optimization").mkdir(exist_ok=True) 221 | 222 | # The objective function to be minimized 223 | def make_objective(train_df, space, cross_validation): 224 | # This decorator converts your objective function with named arguments into one that 225 | # accepts a list as argument, while doing the conversion automatically 226 | @use_named_args(space) 227 | def objective_value(**params): 228 | if class_name == "all_classes": 229 | prefix = "" 230 | else: 231 | prefix = f"{class_name}_" 232 | params = { 233 | f"{prefix}{param_name}": param_value 234 | for param_name, param_value in params.items() 235 | } 236 | logger.info("setting params", params) 237 | self.set_params(**params) 238 | map_scores = [] 239 | available_ids = train_df[self.config.id_col_name].drop_duplicates() 240 | id_2_classes = train_df.groupby(self.config.id_col_name)[ 241 | self.config.target_col_name 242 | ].agg(set) 243 | try: 244 | for trn_idx, val_idx in cross_validation.split( 245 | available_ids, 246 | available_ids.map( 247 | lambda uniprot_id: tuple( 248 | sorted(id_2_classes.loc[uniprot_id]) 249 | ) 250 | ) 251 | if not per_class_optimization 252 | else available_ids.map( 253 | lambda uniprot_id: class_name 254 | in id_2_classes.loc[uniprot_id] 255 | ).astype(int), 256 | ): 257 | trn_ids = set(available_ids.iloc[trn_idx].values) 258 | val_ids = set(available_ids.iloc[val_idx].values) 259 | trn_df = train_df[train_df[self.config.id_col_name].isin(trn_ids)] 260 | vl_df 
= train_df[train_df[self.config.id_col_name].isin(val_ids)] 261 | self.fit_core( 262 | trn_df, 263 | class_name=class_name 264 | if per_class_optimization 265 | else None, 266 | ) 267 | map_scores.append( 268 | eval_model_mean_average_precision( 269 | self, 270 | vl_df, 271 | selected_class_name=None 272 | if not per_class_optimization 273 | else class_name, 274 | ) 275 | ) 276 | if not use_cross_validation: 277 | break 278 | score = np.mean(map_scores) 279 | except ValueError as e: 280 | print(e) 281 | score = 1.0 282 | 283 | ckpts = list( 284 | (self.output_root / "hyperparameters_optimization").glob( 285 | "*_params.json" 286 | ) 287 | ) 288 | if len(ckpts) > 0: 289 | past_performances = sorted( 290 | [ 291 | float(str(ckpt_name.stem).split("_")[0]) 292 | for ckpt_name in ckpts 293 | ] 294 | ) 295 | if len(ckpts) == 0 or past_performances[-1] > score: 296 | for ckpt in ckpts: 297 | os.remove(ckpt) 298 | with open( 299 | self.output_root 300 | / "hyperparameters_optimization" 301 | / f"{score:.5f}_params_{run_timestamp}.json", 302 | "w", 303 | encoding="utf8", 304 | ) as file: 305 | json.dump( 306 | { 307 | key: ( 308 | val 309 | if not isinstance(val, np.integer) 310 | else int(val) 311 | ) 312 | for key, val in params.items() 313 | }, 314 | file, 315 | ) 316 | 317 | return score 318 | 319 | return objective_value 320 | 321 | k_fold = StratifiedKFold( 322 | n_splits=n_fold_splits, shuffle=True, random_state=42 323 | ) 324 | 325 | objective = make_objective( 326 | train_df, space=dimensions, cross_validation=k_fold 327 | ) 328 | 329 | gp_round = gp_minimize( 330 | func=objective, 331 | dimensions=dimensions, 332 | acq_func="gp_hedge", 333 | n_calls=n_calls, 334 | n_initial_points=min(10, n_calls // 5), 335 | random_state=42, 336 | verbose=True, 337 | x0=x0, 338 | ) 339 | best_params = { 340 | f"{prefix}{dimensions[i].name}": param_value 341 | for i, param_value in enumerate(gp_round.x) 342 | } 343 | self.set_params(**best_params) 344 | 345 | with open( 346 | self.output_root 347 | / "hyperparameters_optimization" 348 | / f"optimization_results_detailed_{prefix}{run_timestamp}.pkl", 349 | "wb", 350 | ) as file: 351 | pickle.dump((best_params, gp_round.x_iters, gp_round.func_vals), file) 352 | 353 | def _jsonify_value(value): 354 | if isinstance(value, np.int64): 355 | return int(value) 356 | if isinstance(value, np.float): 357 | return float(value) 358 | return value 359 | 360 | # just in case all hyperparameters for all classes were already pre-computed 361 | if not (self.output_root / "hyperparameters_optimization").exists(): 362 | (self.output_root / "hyperparameters_optimization").mkdir(exist_ok=True) 363 | with open( 364 | self.output_root 365 | / "hyperparameters_optimization" 366 | / f"best_params_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json", 367 | "w", 368 | ) as file: 369 | json.dump( 370 | { 371 | param: _jsonify_value(val) 372 | for param, val in self.get_model_specific_params().items() 373 | }, 374 | file, 375 | ) 376 | 377 | @classmethod 378 | @abstractmethod 379 | def config_class(cls) -> Type[BaseConfig]: 380 | """ 381 | A getter of a config class 382 | """ 383 | raise NotImplementedError 384 | 385 | def get_params(self, deep: bool = True): 386 | return { 387 | "config": ( 388 | deepcopy(self.__dict__["config"]) if deep else self.__dict__["config"] 389 | ) 390 | } 391 | 392 | def get_model_specific_params(self, class_name: str = None): 393 | initial_instance_parameters = self.__dict__ 394 | try: 395 | classifier_args = inspect.getfullargspec(self.classifier_class) 396 | 
classifier_attributes = set(classifier_args.kwonlydefaults.keys()) 397 | classifier_attributes.update( 398 | {arg for arg in classifier_args.args if arg != "self"} 399 | ) 400 | except AttributeError: # sometimes everything is being hidden in **kwargs 401 | # (e.g. it's the case for sklearn wrapper of xgboost), 402 | # then fallback to default constructor 403 | instance = self.classifier_class() 404 | classifier_attributes = instance.__dict__ 405 | return { 406 | key: val 407 | for key, val in initial_instance_parameters.items() 408 | if key in classifier_attributes 409 | or ( 410 | "_".join(key.split("_")[1:]) in classifier_attributes 411 | and (class_name is None or key.split("_")[0] == class_name) 412 | ) 413 | } 414 | -------------------------------------------------------------------------------- /profun/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """The init is added to enable finding all the submodules""" 2 | -------------------------------------------------------------------------------- /profun/utils/alphafold_struct_downloader.py: -------------------------------------------------------------------------------- 1 | """This script downloads protein structures predicted by AlphaFold2""" 2 | 3 | import argparse 4 | import logging 5 | from pathlib import Path 6 | from functools import partial 7 | 8 | import requests 9 | from multiprocessing import Pool 10 | 11 | import pandas as pd # type: ignore 12 | 13 | from tqdm.auto import tqdm 14 | 15 | logging.basicConfig() 16 | logger = logging.getLogger("Downloading AlphaFold2 structures") 17 | logger.setLevel(logging.INFO) 18 | logging.getLogger("requests").setLevel(logging.WARNING) 19 | logging.getLogger("urllib3").setLevel(logging.WARNING) 20 | 21 | 22 | def parse_args() -> argparse.Namespace: 23 | """ 24 | This function parses arguments 25 | :return: current argparse.Namespace 26 | """ 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--structures-output-path", type=str, default="../data/af_structs" 30 | ) 31 | parser.add_argument("--path-to-file-with-ids", type=str, 32 | default="../data/uniprot_ids_of_interest.txt", help="Path to a file containing UniProt IDs," 33 | "for which the script will download AF2 structures") 34 | parser.add_argument("--n-jobs", type=int, default=1) 35 | 36 | args = parser.parse_args() 37 | return args 38 | 39 | def download_af_struct(uniprot_id, root_af, fails_count=0, max_fails_count=3): 40 | try: 41 | URL = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v3.pdb" 42 | response = requests.get(URL) 43 | with open(root_af / f"{uniprot_id}.pdb", "wb") as file: 44 | file.write(response.content) 45 | except: 46 | logger.warning(f"Error downloading AlphaFold2 structure for {uniprot_id}") 47 | if fails_count < max_fails_count: 48 | download_af_struct(uniprot_id, root_af, fails_count+1) 49 | 50 | 51 | def main(): 52 | """ 53 | This function downloads protein structures predicted by AlphaFold 54 | """ 55 | cl_args = parse_args() 56 | root_af = Path(cl_args.structures_output_path) 57 | if not root_af.exists(): 58 | root_af.mkdir() 59 | 60 | download_af_struct_for_current_root = partial(download_af_struct, root_af=root_af) 61 | 62 | with open(cl_args.path_to_file_with_ids, 'r') as file: 63 | all_ids_of_interest = [line.strip() for line in file.readlines()] 64 | 65 | with Pool(processes=cl_args.n_jobs) as pool: 66 | pool.map(download_af_struct_for_current_root, all_ids_of_interest) 67 | 68 | 69 | if __name__ == "__main__": 70 | 
main() 71 | -------------------------------------------------------------------------------- /profun/utils/msa.py: -------------------------------------------------------------------------------- 1 | """Utils for MSA (Multiple Sequence Alignment)""" 2 | import os 3 | from typing import Iterable, Optional 4 | 5 | 6 | def get_fasta_seqs(seqs: Iterable[str], ids: Optional[Iterable[str]] = None): 7 | """ 8 | This function dumps list of sequences with optional ids to a fasta file 9 | """ 10 | if ids is None: 11 | full_entries = [ 12 | f">{i}\n{''.join(tps_seq.split())}" for i, tps_seq in enumerate(seqs) 13 | ] 14 | else: 15 | full_entries = [ 16 | f">{id}\n{''.join(tps_seq.split())}" for id, tps_seq in zip(ids, seqs) 17 | ] 18 | return "\n".join(full_entries) 19 | 20 | 21 | def generate_msa_mafft(*, seqs: Optional[Iterable[str]] = None, ids: Iterable[str] = None, 22 | fasta_str: str = None, 23 | output_name: str = "_msa.fasta", 24 | n_jobs: int = 26, 25 | clustal_output_format: bool = True): 26 | """ 27 | This function generates multiple sequence alignment using MAFFT 28 | 29 | Either accepts all fasta-format sequences prepared in `fasta_str` argument, 30 | or prepares the fasta-format sequences based on 31 | """ 32 | assert fasta_str is None or seqs is None, ("The input sequences must be passed either as an iterable of strings or " 33 | "as the preprocessed fasta_str, " 34 | "but cannot be passed by both options simultaneously") 35 | if fasta_str is None: 36 | fasta_str = get_fasta_seqs(seqs, ids) 37 | with open("_temp_mafft.fasta", "w", encoding="utf8") as f: 38 | f.writelines(fasta_str.replace("'", "").replace('"', "")) 39 | os.system( 40 | f"mafft --thread {n_jobs} --auto --quiet {'--clustalout ' if clustal_output_format else ''}_temp_mafft.fasta > {output_name}") 41 | os.remove("_temp_mafft.fasta") 42 | -------------------------------------------------------------------------------- /profun/utils/project_info.py: -------------------------------------------------------------------------------- 1 | """This script contains routines for working with project info""" 2 | 3 | import datetime 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | 7 | from dataclasses_json import dataclass_json # type: ignore 8 | from typing import Optional 9 | 10 | 11 | @dataclass_json 12 | @dataclass 13 | class ExperimentInfo: 14 | """A dataclass to store information about a particular experiment scenario""" 15 | 16 | validation_schema: Optional[str] = None 17 | model_type: Optional[str] = None 18 | model_version: Optional[str] = None 19 | 20 | def __post_init__(self): 21 | """Setting up an experiment timestamp and fold info""" 22 | self.timestamp = datetime.datetime.now() 23 | self._fold = "all_folds" 24 | 25 | @property 26 | def fold(self): 27 | return self._fold 28 | 29 | @fold.setter 30 | def fold(self, value: str): 31 | self._fold = value 32 | 33 | @property 34 | def model(self): 35 | return self.model_type 36 | 37 | @model.setter 38 | def model(self, value: str): 39 | self.model_type = value 40 | 41 | def get_experiment_name(self): 42 | """Detailed experiment name getter""" 43 | experiment_name = ( 44 | f"validation_{self.validation_schema}__model_{self.model_type}_{self.model_version}_" 45 | f'{self.timestamp.strftime("%Y%m%d-%H%M")}' 46 | ) 47 | return experiment_name 48 | 49 | 50 | def get_project_root() -> Path: 51 | """ 52 | Returns: absolute path to the project root directory 53 | """ 54 | return Path.home() / "profun_outs" 55 | 56 | 57 | def get_output_root() -> Path: 58 | """ 59 | 
Returns: absolute path to the output directory 60 | """ 61 | return get_project_root() / "outputs" 62 | 63 | 64 | def get_experiments_output() -> Path: 65 | """ 66 | Returns: absolute path to the experiments directory 67 | """ 68 | return get_output_root() / "experiment_results" 69 | 70 | 71 | def get_config_root() -> Path: 72 | """ 73 | Returns: absolute path to the config directory 74 | """ 75 | return Path(get_project_root()) / "configs" 76 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | 5 | setup( 6 | name='profun', 7 | description='Library for protein function prediction', 8 | license='MIT', 9 | version='0.1', 10 | zip_safe=True, 11 | include_package_data=True, 12 | packages=find_packages(), 13 | entry_points={ 14 | "console_scripts": [ 15 | "alphafold_struct_downloader=profun.utils.alphafold_struct_downloader:main", 16 | ], 17 | }, 18 | install_requires=['pandas', 'numpy', 'dataclasses_json', 'scikit-learn', 19 | 'iterative-stratification', 'scikit-optimize'], 20 | author='Raman Samusevich', 21 | author_email='raman.samusevich@gmail.com' 22 | ) 23 | --------------------------------------------------------------------------------
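
setup.py registers `alphafold_struct_downloader` as a console script pointing at `profun.utils.alphafold_struct_downloader:main`, the same downloader that `FoldseekMatching` calls internally via `python -m profun.utils.alphafold_struct_downloader`. A minimal sketch of a standalone invocation, assuming the package is installed and `ids.txt` is a placeholder path to a file with one UniProt accession per line:
```
# download AlphaFold2-predicted structures for the listed UniProt IDs
alphafold_struct_downloader \
    --structures-output-path ./af_structs \
    --path-to-file-with-ids ./ids.txt \
    --n-jobs 8

# equivalent module-style invocation (this is the form used by the Foldseek model)
python -m profun.utils.alphafold_struct_downloader \
    --structures-output-path ./af_structs \
    --path-to-file-with-ids ./ids.txt \
    --n-jobs 8
```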
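
All three models expose `predict_proba`, and `profun.evaluation.metrics.eval_model_mean_average_precision` turns those probabilities into a single score: it computes (optionally class-weighted) average precision per class and returns one minus the weighted mean, i.e. a loss suitable for minimization during hyperparameter search. A minimal sketch, assuming `blast_model` was fitted as in the README's BLAST example and `val_df_long` (a placeholder name) is a long-format validation dataframe with the same `EntryID`/`term`/`Seq` columns as the training data:
```
from profun.evaluation.metrics import eval_model_mean_average_precision

# returns 1 - (class-weighted) mean average precision over config.class_names
loss = eval_model_mean_average_precision(blast_model, val_df_long)
print(f"mean average precision: {1 - loss:.3f}")

# restrict scoring to a single class (the GO term below is only an illustration)
loss_one_class = eval_model_mean_average_precision(
    blast_model, val_df_long, selected_class_name="GO:0003824"
)
```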