├── __init__.py ├── tests ├── __init__.py ├── test_preprocess.py ├── test_triad_pre.py ├── test_triad_post.py ├── test_hd.py ├── test_coves.py ├── test_esmif.py ├── test_de_vis.py ├── test_de.py ├── test_loc_opt.py ├── test_gen_learned_emb.py ├── test_ev_esm.py ├── test_alde.py ├── test_mlde_vis.py ├── test_finetune.py ├── test_pairwise_epistasis.py ├── test_zs.py └── test_corr.py ├── fig1.png ├── envs ├── esmif.yml ├── coves.yml ├── SSMuLA.yml ├── frozen │ ├── esmif.yml │ ├── coves.yml │ └── SSMuLA.yml └── finetune.yml ├── SSMuLA ├── gen_atom3d.py ├── __init__.py ├── run_ev_esm.py ├── finetune_analysis.py ├── est_ep.py ├── util.py ├── zs_calc.py ├── alde_analysis.py ├── triad_prepost.py ├── aa_global.py ├── vis.py ├── calc_hd.py ├── zs_models.py ├── plm_finetune.py └── get_factor.py ├── esmif ├── esmif.sh └── score_log_likelihoods.py ├── .gitignore └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fhalab/SSMuLA/HEAD/fig1.png -------------------------------------------------------------------------------- /envs/esmif.yml: -------------------------------------------------------------------------------- 1 | # For the general environment: 2 | # Install or update using 3 | # conda env update --file esmif.yml --prune 4 | 5 | name: esmif 6 | channels: 7 | - pytorch 8 | - pyg 9 | - conda-forge 10 | - defaults 11 | dependencies: 12 | - python=3.9 13 | - biopandas 14 | - biopython 15 | - biotite 16 | - cudatoolkit=11.3 17 | - pandas 18 | - pyg 19 | - pytorch 20 | - pip 21 | - pip: 22 | - git+https://github.com/facebookresearch/esm.git 23 | -------------------------------------------------------------------------------- /envs/coves.yml: -------------------------------------------------------------------------------- 1 | # conda env update --file coves.yml --prune 2 | # need to be fixed 3 | name: coves 4 | channels: 5 | - pytorch 6 | - pyg 7 | - pytorch3d 8 | - conda-forge 9 | - defaults 10 | - anaconda 11 | dependencies: 12 | - python=3.9 13 | - numpy # =1.19.4 14 | - pip 15 | - pyg 16 | - pytorch # =1.8.1 17 | - scikit-learn # =0.24.1 18 | # - torch_geometric # =1.7.0 19 | # - torch_scatter # =2.0.6 20 | # - torch_cluster # =1.5.9 21 | - tqdm 22 | 23 | - pip: 24 | - atom3d # =0.2.1 25 | - blackcellmagic 26 | - rdkit -------------------------------------------------------------------------------- /tests/test_preprocess.py: -------------------------------------------------------------------------------- 1 | """Test the preprocess module.""" 2 | 3 | import sys 4 | import os 5 | 6 | from datetime import datetime 7 | 8 | from SSMuLA.fitness_process_vis import process_all, get_all_lib_stats 9 | from SSMuLA.util import checkNgen_folder 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | log_folder = checkNgen_folder("logs/fitness_process_vis") 15 | 16 | # log outputs 17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 18 | sys.stdout = f 19 | 20 | process_all(scale_fit="max") 21 | get_all_lib_stats() 22 | 23 | f.close() -------------------------------------------------------------------------------- 
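Note: each test script in tests/ repeats the same boilerplate shown above — create a timestamped .out file under logs/ with checkNgen_folder and point sys.stdout at it. Below is a minimal sketch of that pattern as a reusable context manager; the helper name log_stdout is illustrative only and is not part of the repository.

import os
import sys
from contextlib import contextmanager
from datetime import datetime

from SSMuLA.util import checkNgen_folder


@contextmanager
def log_stdout(log_subdir: str):
    """Redirect stdout to a timestamped .out file under logs/<log_subdir>."""
    log_folder = checkNgen_folder(os.path.join("logs", log_subdir))
    log_file = open(
        os.path.join(
            log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"
        ),
        "w",
    )
    original_stdout = sys.stdout
    sys.stdout = log_file
    try:
        yield
    finally:
        sys.stdout = original_stdout
        log_file.close()


# usage, mirroring tests/test_preprocess.py:
# with log_stdout("fitness_process_vis"):
#     process_all(scale_fit="max")
#     get_all_lib_stats()
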
/tests/test_triad_pre.py: -------------------------------------------------------------------------------- 1 | """Test the triad pre and post processing.""" 2 | 3 | import sys 4 | import os 5 | 6 | from glob import glob 7 | 8 | from datetime import datetime 9 | 10 | from SSMuLA.triad_prepost import run_traid_gen_mut_file 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/triad/pre"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | run_traid_gen_mut_file(lib_list = ["kej"]) 26 | 27 | f.close() -------------------------------------------------------------------------------- /tests/test_triad_post.py: -------------------------------------------------------------------------------- 1 | """Test the triad pre and post processing.""" 2 | 3 | import sys 4 | import os 5 | 6 | from glob import glob 7 | 8 | from datetime import datetime 9 | 10 | from SSMuLA.triad_prepost import run_parse_triad_results 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/triad/post"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | run_parse_triad_results(all_lib = False, lib_list = ["T7", "TEV"]) 26 | 27 | f.close() -------------------------------------------------------------------------------- /tests/test_hd.py: -------------------------------------------------------------------------------- 1 | """A script for testing plotting for de""" 2 | 3 | import sys 4 | import os 5 | 6 | from datetime import datetime 7 | 8 | from SSMuLA.calc_hd import run_hd_avg_fit, run_hd_avg_metric 9 | from SSMuLA.util import checkNgen_folder 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | log_folder = checkNgen_folder("logs/hd") 15 | 16 | # log outputs 17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 18 | sys.stdout = f 19 | 20 | run_hd_avg_fit( 21 | data_dir = "data", 22 | num_processes=256, 23 | hd_dir = "results/hd_fit", 24 | ) 25 | 26 | run_hd_avg_metric() 27 | 28 | f.close() -------------------------------------------------------------------------------- /tests/test_coves.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for testing coves 3 | Use coves environment 4 | """ 5 | 6 | import sys 7 | import os 8 | 9 | from datetime import datetime 10 | 11 | from SSMuLA.run_coves import run_all_coves, append_all_coves_scores 12 | from SSMuLA.util import checkNgen_folder 13 | 14 | if __name__ == "__main__": 15 | 16 | log_folder = checkNgen_folder("logs/coves") 17 | 18 | # log outputs 19 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 20 | sys.stdout = f 21 | 22 | run_all_coves(n_ave=100) 23 | append_all_coves_scores() 24 | 25 | f.close() 26 | 27 | """ 28 | append_all_coves_scores( 29 | libs: list|str = "ev_esm2/*", 30 | ev_esm_dir: str = "ev_esm2", 31 | coves_dir: str = "coves/100", 32 | t: float = 0.1 33 | """ -------------------------------------------------------------------------------- /tests/test_esmif.py: -------------------------------------------------------------------------------- 1 | """Test ems inverse folding zs""" 2 | 3 | import sys 4 | import os 5 | 6 | 7 | from datetime import datetime 8 | 9 | # from 
SSMuLA.zs_analysis import run_zs_analysis 10 | from SSMuLA.zs_data import get_all_mutfasta 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/zs"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | get_all_mutfasta( 26 | ev_esm_dir="ev_esm2", 27 | all_libs=True, 28 | ) 29 | 30 | f.close() 31 | 32 | """ 33 | get_all_mutfasta( 34 | ev_esm_dir: str = "ev_esm2", 35 | all_libs: bool = True, 36 | lib_list: list[str] = [] 37 | ) 38 | """ -------------------------------------------------------------------------------- /tests/test_de_vis.py: -------------------------------------------------------------------------------- 1 | """A script for testing plotting for de""" 2 | 3 | import sys 4 | import os 5 | 6 | from datetime import datetime 7 | 8 | from SSMuLA.de_simulations import run_plot_de 9 | from SSMuLA.util import checkNgen_folder 10 | 11 | if __name__ == "__main__": 12 | 13 | log_folder = checkNgen_folder("logs/plot_de") 14 | 15 | # log outputs 16 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 17 | sys.stdout = f 18 | 19 | run_plot_de( 20 | scale_types = ["scale2max"], 21 | de_opts = ["DE-active"], 22 | ) 23 | 24 | """ 25 | def run_plot_de( 26 | scale_types: list = ["scale2max", "scale2parent"], 27 | de_opts: list = ["DE-active"], 28 | sim_folder: str = "results/de", 29 | vis_folder: str = "results/de_vis", 30 | v_width: int = 400, 31 | all_lib: bool = True, 32 | lib_list: list[str] = [], 33 | ): 34 | """ 35 | 36 | f.close() -------------------------------------------------------------------------------- /tests/test_de.py: -------------------------------------------------------------------------------- 1 | """A script for testing plotting for de""" 2 | 3 | import sys 4 | import os 5 | 6 | from datetime import datetime 7 | 8 | from SSMuLA.de_simulations import run_all_lib_de_simulations 9 | from SSMuLA.util import checkNgen_folder 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | log_folder = checkNgen_folder("logs/de") 15 | 16 | # log outputs 17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 18 | sys.stdout = f 19 | 20 | run_all_lib_de_simulations( 21 | scale_types = ["scale2max"], 22 | de_opts = ["DE-active", "DE-0", "DE-all"], 23 | all_lib = True, 24 | lib_list = [], 25 | rerun = False 26 | ) 27 | 28 | """ 29 | run_all_lib_de_simulations( 30 | scale_types: list = ["scale2max", "scale2parent"], 31 | de_opts: list = ["DE-active", "DE-0", "DE-all"], 32 | save_dir: str = "results/de", 33 | all_lib: bool = True, 34 | lib_list: list[str] = [], 35 | rerun: bool = False, 36 | ) 37 | """ 38 | 39 | f.close() -------------------------------------------------------------------------------- /tests/test_loc_opt.py: -------------------------------------------------------------------------------- 1 | """A script for testing the local optima""" 2 | 3 | import sys 4 | import os 5 | 6 | from datetime import datetime 7 | 8 | from SSMuLA.landscape_optima import run_loc_opt 9 | from SSMuLA.util import checkNgen_folder 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | log_folder = checkNgen_folder("logs/loc_opt") 15 | 16 | # log outputs 17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 18 | sys.stdout = f 19 | 20 | run_loc_opt(input_folder = "data", 21 | fitness_process_type = 
"scale2max", 22 | output_folder = "results/local_optima", 23 | n_jobs = 16, 24 | if_append_escape = True, 25 | rerun = True, 26 | ) 27 | 28 | f.close() 29 | 30 | """ 31 | run_loc_opt( 32 | input_folder: str = "data", 33 | fitness_process_type: str = "scale2max", 34 | output_folder: str = "results/local_optima", 35 | n_jobs: int = 16, 36 | rerun: bool = False, 37 | ) -> None 38 | """ 39 | 40 | 41 | -------------------------------------------------------------------------------- /tests/test_gen_learned_emb.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to test the learned embedding generation 3 | Test the triad pre and post processing. 4 | """ 5 | 6 | import sys 7 | import os 8 | 9 | from datetime import datetime 10 | 11 | from SSMuLA.gen_learned_emb import gen_all_learned_emb 12 | from SSMuLA.util import checkNgen_folder 13 | 14 | if __name__ == "__main__": 15 | 16 | # log outputs 17 | f = open( 18 | os.path.join( 19 | checkNgen_folder("logs/emb"), 20 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 21 | ), 22 | "w", 23 | ) 24 | sys.stdout = f 25 | 26 | gen_all_learned_emb( 27 | input_folder = "results/zs_comb_6/none/scale2max/all", 28 | all_libs = False, 29 | regen = True, 30 | lib_list = ["TrpB3F"], 31 | ) 32 | 33 | f.close() 34 | 35 | """ 36 | input_folder: str = "results/zs_comb/none/scale2max", 37 | encoder_name: str = DEFAULT_ESM, 38 | batch_size: int = 128, 39 | regen: bool = False, 40 | emb_folder: str = "learned_emb", 41 | all_libs: bool = True, 42 | lib_list: list[str] = [], 43 | """ -------------------------------------------------------------------------------- /tests/test_ev_esm.py: -------------------------------------------------------------------------------- 1 | """Test ems inverse folding zs""" 2 | 3 | import sys 4 | import os 5 | 6 | 7 | from datetime import datetime 8 | 9 | # from SSMuLA.zs_analysis import run_zs_analysis 10 | from SSMuLA.zs_calc import calc_all_zs 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/zs"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | calc_all_zs( 26 | landscape_folder="data", 27 | output_folder="results/zs", 28 | ev_model_folder="data", 29 | regen_esm = False, 30 | rerun_zs = False 31 | ) 32 | 33 | f.close() 34 | 35 | """ 36 | calc_all_zs( 37 | landscape_folder: str = "data/processed", 38 | dataset_list: list[str] = [], 39 | output_folder: str = "results/zs", 40 | zs_model_names: str = "all", 41 | ev_model_folder: str = "data/evmodels", 42 | regen_esm: str = False, 43 | rerun_zs: str = False, 44 | ) 45 | """ -------------------------------------------------------------------------------- /tests/test_alde.py: -------------------------------------------------------------------------------- 1 | 2 | """Test alde file comb""" 3 | 4 | import sys 5 | import os 6 | 7 | 8 | from datetime import datetime 9 | 10 | from SSMuLA.alde_analysis import aggregate_alde_df 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/alde_comb"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | aggregate_alde_df( 26 | eq_ns = [2, 3, 4], 27 | alde_dir = "results/alde", 28 | alde_df_path = "results/alde/alde_all.csv", 29 | 30 | ) 31 | 32 | 
f.close() 33 | 34 | """ 35 | aggregate_alde_df( 36 | eq_ns: list[int] = [1, 2, 3, 4], 37 | zs_opts: list[str] = ["esmif", "ev", "coves", "ed", "esm", "Triad", ""], 38 | alde_model: str = "Boosting Ensemble", 39 | alde_encoding: str = "onehot", 40 | alde_acq: str = "GREEDY", 41 | alde_dir: str = "/disk2/fli/alde4ssmula", 42 | alde_df_path: str = "results/alde/alde_all.csv", 43 | ) 44 | """ 45 | 46 | -------------------------------------------------------------------------------- /tests/test_mlde_vis.py: -------------------------------------------------------------------------------- 1 | """Test MLDE.""" 2 | 3 | import sys 4 | import os 5 | 6 | from glob import glob 7 | 8 | from datetime import datetime 9 | 10 | from SSMuLA.mlde_analysis import MLDESum 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/mlde_vis"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | 26 | # MLDESum( 27 | # mlde_results_dir = "results_rev/mlde_coves_ens/saved", 28 | # mlde_vis_dir = "results_rev/mlde_coves_ens/vis" 29 | # ) 30 | MLDESum( 31 | mlde_results_dir = "results_rev/mlde_lown/saved", 32 | mlde_vis_dir = "results_rev/mlde_lown/vis" 33 | ) 34 | 35 | """ 36 | MLDESum: 37 | 38 | def __init__( 39 | self, 40 | mlde_results_dir: str = "results/mlde/saved", 41 | mlde_vis_dir: str = "results/mlde/vis", 42 | all_encoding: bool = True, 43 | encoding_lists: list[str] = [], 44 | ifvis: bool = False, 45 | ) -> None: 46 | """ 47 | f.close() -------------------------------------------------------------------------------- /SSMuLA/gen_atom3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | For generating CoVES based zs scores 3 | 4 | NOTE: 5 | - have to use atom3d evn for atom3d gen 6 | - have to be run in the coves env for coves 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | import os 12 | from glob import glob 13 | 14 | import atom3d.datasets as da 15 | 16 | from SSMuLA.util import checkNgen_folder, get_file_name 17 | 18 | 19 | def gen_lmdb_dataset(pdb_path: str, lmdb_path: str): 20 | 21 | """ 22 | Generate LMDB dataset from PDB dataset 23 | 24 | Args: 25 | - pdb_path, str: Path to the PDB files 26 | - lmdb_path, str: Path to directory to save LMDB dataset 27 | """ 28 | 29 | # Load dataset from directory of PDB files 30 | pdb_file = da.load_dataset(pdb_path, 'pdb') 31 | # Create LMDB dataset from PDB dataset 32 | checkNgen_folder(lmdb_path) 33 | da.make_lmdb_dataset(pdb_file, lmdb_path) 34 | 35 | 36 | def gen_all_lmdb(pdb_pattern: str = "data/*/*.pdb", lmdb_dir: str = "lmdb"): 37 | 38 | for pdb_path in sorted(glob(pdb_pattern)): 39 | 40 | print(f"Generating LMDB dataset for {pdb_path}...") 41 | 42 | protein_name = get_file_name(pdb_path) 43 | 44 | lmdb_path = os.path.join(lmdb_dir, protein_name, "lmdb") 45 | gen_lmdb_dataset(pdb_path, lmdb_path) 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /esmif/esmif.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 2 | 3 | python score_log_likelihoods.py DHFR.pdb DHFR.fasta --chain A --outpath DHFR_esmif_scores.csv 4 | python score_log_likelihoods.py GB1.pdb GB1.fasta --chain A --outpath GB1_esmif_scores.csv 5 | python score_log_likelihoods.py ParD2.pdb ParD2.fasta --chain A --outpath ParD2_esmif_scores.csv 6 | python 
score_log_likelihoods.py ParD3.pdb ParD3.fasta --chain A --outpath ParD3_esmif_scores.csv 7 | python score_log_likelihoods.py TrpB.pdb TrpB3A.fasta --chain A --outpath TrpB3A_esmif_scores.csv 8 | python score_log_likelihoods.py TrpB.pdb TrpB3B.fasta --chain A --outpath TrpB3B_esmif_scores.csv 9 | python score_log_likelihoods.py TrpB.pdb TrpB3C.fasta --chain A --outpath TrpB3C_esmif_scores.csv 10 | python score_log_likelihoods.py TrpB.pdb TrpB3D.fasta --chain A --outpath TrpB3D_esmif_scores.csv 11 | python score_log_likelihoods.py TrpB.pdb TrpB3E.fasta --chain A --outpath TrpB3E_esmif_scores.csv 12 | python score_log_likelihoods.py TrpB.pdb TrpB3F.fasta --chain A --outpath TrpB3F_esmif_scores.csv 13 | python score_log_likelihoods.py TrpB.pdb TrpB3G.fasta --chain A --outpath TrpB3G_esmif_scores.csv 14 | python score_log_likelihoods.py TrpB.pdb TrpB3H.fasta --chain A --outpath TrpB3H_esmif_scores.csv 15 | python score_log_likelihoods.py TrpB.pdb TrpB3I.fasta --chain A --outpath TrpB3I_esmif_scores.csv 16 | python score_log_likelihoods.py TrpB.pdb TrpB4.fasta --chain A --outpath TrpB4_esmif_scores.csv -------------------------------------------------------------------------------- /envs/SSMuLA.yml: -------------------------------------------------------------------------------- 1 | # For the general environment: 2 | # Install or update using 3 | # conda env update --file SSMuLA.yml --prune 4 | 5 | name: SSMuLA 6 | channels: 7 | - pytorch 8 | - pyg 9 | - pytorch3d 10 | - nvidia 11 | - conda-forge 12 | - salilab 13 | - pyviz 14 | - defaults 15 | - anaconda 16 | dependencies: 17 | - python=3.11 18 | - biopandas 19 | - biopython 20 | - biotite 21 | - bokeh 22 | - brokenaxes 23 | - cairosvg 24 | - colorcet 25 | - datashader 26 | - firefox 27 | - flake8 28 | - geckodriver 29 | - hdf5 30 | - holoviews 31 | - hvplot 32 | - ipykernel 33 | - ipympl 34 | - ipywidgets 35 | - jinja2 36 | - jupyterlab 37 | - jupyter_bokeh 38 | - matplotlib 39 | - multipledispatch 40 | - mypy 41 | - networkx 42 | - nodejs 43 | - numpy 44 | - openpyxl 45 | - pandas 46 | - panel 47 | - param 48 | - pdbfixer 49 | - pip 50 | - psutil 51 | - pyg>=2.0.3 52 | - pytables 53 | - pytorch=2.1.1 # for h100 54 | - pytorch-cuda=12.1 # change this as needed 55 | - pyviz_comms 56 | - pyyaml 57 | - requests 58 | - scikit-learn # might not be evcouplings compatible 59 | - scipy 60 | - seaborn 61 | - tensorflow 62 | - tqdm 63 | - versioneer 64 | - xarray 65 | - xgboost 66 | 67 | - pip: 68 | - blackcellmagic 69 | - fair-esm 70 | - ordered-set 71 | - rdkit 72 | # pip install https://github.com/debbiemarkslab/EVcouplings/archive/develop.zip install after the environment is created 73 | prefix: /disk2/fli/setup/miniconda3/ -------------------------------------------------------------------------------- /tests/test_finetune.py: -------------------------------------------------------------------------------- 1 | """Test MLDE.""" 2 | 3 | import sys 4 | import os 5 | 6 | from glob import glob 7 | 8 | from datetime import datetime 9 | 10 | from SSMuLA.plm_finetune import train_predict_per_protein 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/finetune"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | 26 | for landscape in sorted(glob("results/zs_comb/all/*.csv")): 27 | for i in range(5): 28 | train_predict_per_protein( 29 | df_csv=landscape, 30 | rep=i, 31 | ) 32 | 
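    # The loop above fine-tunes one model per landscape CSV under results/zs_comb/all/
    # for five replicates each, using the defaults quoted in the signature at the end
    # of this script (ESM-2 650M checkpoint, 384 train+val samples, zs_predictor "none").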
33 | f.close() 34 | 35 | """ 36 | train_predict_per_protein( 37 | df_csv: str, # csv file with landscape data 38 | rep: int, # replicate number 39 | checkpoint: str = "facebook/esm2_t33_650M_UR50D", # model checkpoint 40 | n_sample: int = 384, # number of train+val 41 | zs_predictor: str = "none", # zero-shot predictor 42 | ft_frac: float = 0.125, # fraction of data for focused sampling 43 | plot_dir: str = "results/finetuning/plot", # directory to save the plot 44 | model_dir: str = "results/finetuning/model", # directory to save the model 45 | pred_dir: str = "results/finetuning/predictions", # directory to save the predictions 46 | train_kwargs: dict = {}, # additional training arguments 47 | ) 48 | """ -------------------------------------------------------------------------------- /SSMuLA/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # # 3 | # This program is free software: you can redistribute it and/or modify # 4 | # it under the terms of the GNU General Public License as published by # 5 | # the Free Software Foundation, either version 3 of the License, or # 6 | # (at your option) any later version. # 7 | # # 8 | # This program is distributed in the hope that it will be useful, # 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of # 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # 11 | # GNU General Public License for more details. # 12 | # # 13 | # You should have received a copy of the GNU General Public License # 14 | # along with this program. If not, see . # 15 | # # 16 | ############################################################################### 17 | 18 | __title__ = 'SSMuLA' 19 | __description__ = 'Site Saturation Mutagenesis Landscape Analysis' 20 | __url__ = 'https://github.com/fhalab/SSMuLA.git' 21 | __version__ = '1.0.0' 22 | __author__ = 'Francesca-Zhoufan Li' 23 | __author_email__ = 'fzl@caltech.edu' 24 | __license__ = 'GPL3' 25 | -------------------------------------------------------------------------------- /tests/test_pairwise_epistasis.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to test the pairwise epistasis calculation 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import sys 8 | import os 9 | 10 | from datetime import datetime 11 | 12 | 13 | from SSMuLA.pairwise_epistasis import calc_all_pairwise_epistasis, plot_pairwise_epistasis 14 | from SSMuLA.util import checkNgen_folder 15 | 16 | 17 | if __name__ == "__main__": 18 | 19 | log_folder = checkNgen_folder("logs/pairwise_epistasis") 20 | 21 | # log outputs 22 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w') 23 | sys.stdout = f 24 | 25 | fitness_process_type = "scale2max" 26 | 27 | calc_all_pairwise_epistasis( 28 | fitness_process_type=fitness_process_type, 29 | ifall=False, 30 | lib_list=["T7", "TEV"], 31 | output_folder="results/pairwise_epistasis", 32 | n_jobs=128, 33 | ) 34 | 35 | """ 36 | calc_all_pairwise_epistasis( 37 | input_folder: str = "data", 38 | fitness_process_type: str = "scale2max", 39 | activestart: bool = True, 40 | ifall: bool = True, 41 | lib_list: list[str] = [], 42 | output_folder: str = "results/pairwise_epistasis", 43 | n_jobs: int = 128, 44 | """ 45 | 46 | pos_calc_filter_min = "none" 47 | 48 | plot_pairwise_epistasis( 49 | fitness_process_type=fitness_process_type, 50 | 
pos_calc_filter_min=pos_calc_filter_min, 51 | input_folder="results/pairwise_epistasis", 52 | output_folder="results/pairwise_epistasis_vis", 53 | dets_folder="results/pairwise_epistasis_dets", 54 | ) 55 | 56 | f.close() -------------------------------------------------------------------------------- /tests/test_zs.py: -------------------------------------------------------------------------------- 1 | """Test the triad pre and post processing.""" 2 | 3 | import sys 4 | import os 5 | 6 | from glob import glob 7 | 8 | from datetime import datetime 9 | 10 | from SSMuLA.zs_analysis import run_zs_analysis 11 | from SSMuLA.util import checkNgen_folder 12 | 13 | if __name__ == "__main__": 14 | 15 | # log outputs 16 | f = open( 17 | os.path.join( 18 | checkNgen_folder("logs/zs"), 19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out", 20 | ), 21 | "w", 22 | ) 23 | sys.stdout = f 24 | 25 | # run_zs_analysis( 26 | # scale_types=["max"], 27 | # filter_min_by="none", 28 | # ev_esm_folder = "ev_esm2", 29 | # zs_comb_dir = "results/zs_comb_2", 30 | # zs_vis_dir = "results/zs_vis_2", 31 | # zs_sum_dir = "results/zs_sum_2", 32 | # ) 33 | 34 | # run_zs_analysis( 35 | # scale_types=["max"], 36 | # filter_min_by="min0", 37 | # ev_esm_folder = "ev_esm2", 38 | # zs_comb_dir = "results/zs_comb_2", 39 | # zs_vis_dir = "results/zs_vis_2", 40 | # zs_sum_dir = "results/zs_sum_2", 41 | # ) 42 | 43 | # run_zs_analysis( 44 | # scale_types=["max"], 45 | # filter_min_by="none", 46 | # ev_esm_folder = "ev_esm2", 47 | # zs_comb_dir = "results/zs_comb_6", 48 | # zs_vis_dir = "results/zs_vis_6", 49 | # zs_sum_dir = "results/zs_sum_6", 50 | # ) 51 | 52 | run_zs_analysis( 53 | scale_types=["max"], 54 | filter_min_by="none", 55 | ev_esm_folder = "ev_esm2", 56 | zs_comb_dir = "results_old/zs_comb_7", 57 | zs_vis_dir = "results_old/zs_vis_7", 58 | zs_sum_dir = "results_old/zs_sum_7", 59 | ) 60 | 61 | f.close() 62 | 63 | """ 64 | run_zs_analysis( 65 | scale_types: list = ["max", "parent"], 66 | data_folder: str = "data", 67 | ev_esm_folder: str = "ev_esm", 68 | triad_folder: str = "triad", 69 | esmif_folder: str = "esmif", 70 | filter_min_by: str = "none", 71 | n_mut_cutoff_list: list[int] = [0, 1, 2], 72 | zs_comb_dir: str = "results/zs_comb", 73 | zs_vis_dir: str = "results/zs_vis", 74 | zs_sum_dir: str = "results/zs_sum", 75 | ) 76 | """ -------------------------------------------------------------------------------- /SSMuLA/run_ev_esm.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for generating zs scores gratefully adapted from EmreGuersoy's work 3 | """ 4 | 5 | # Import packages 6 | import glob 7 | import json 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import random 12 | from collections import Counter 13 | 14 | 15 | import argparse 16 | from pathlib import Path 17 | 18 | from tqdm import tqdm 19 | from typing import List, Tuple, Optional 20 | 21 | 22 | import warnings 23 | 24 | from SSMuLA.zs_calc import calc_all_zs 25 | from SSMuLA.util import get_file_name, checkNgen_folder 26 | 27 | 28 | # TODO clean up path to be dataset name independent 29 | 30 | def create_parser(): 31 | parser = argparse.ArgumentParser(description="Run zero-shot predictions") 32 | 33 | parser.add_argument( 34 | "--landscape_folder", 35 | type=str, 36 | default="data/processed", 37 | metavar="LSF", 38 | help="A folder path for all landscape data" 39 | ) 40 | 41 | parser.add_argument( 42 | "--dataset_list", 43 | type=json.loads, 44 | metavar="dsl", 45 | default=[], 46 | 
help="default dataset list empty to use glob for all", 47 | ) 48 | 49 | parser.add_argument( 50 | "--output_folder", 51 | type=str, 52 | default="results/zs", 53 | metavar="OPF", 54 | help="A output folder path with landscape subfolders" 55 | ) 56 | 57 | parser.add_argument( 58 | "--zs_model_names", 59 | type=str, 60 | metavar="ZSMN", 61 | help="A str of name(s) of zero-shot models to use seperated by comma, \ 62 | available: 'esm', 'ev', developing: 'ddg', 'Bert', all: 'all' runs all currently available models \ 63 | ie, 'esm, ev'", 64 | ) 65 | 66 | parser.add_argument( 67 | "--ev_model_folder", 68 | type=str, 69 | default="data/", 70 | metavar="EVF", 71 | help="folder for evmodels" 72 | ) 73 | 74 | parser.add_argument( 75 | "--regen_esm", 76 | type=bool, 77 | default=False, 78 | metavar="RG", 79 | help="if regenerate esm logits or load directly" 80 | ) 81 | 82 | parser.add_argument( 83 | "--rerun_zs", 84 | type=bool, 85 | default=False, 86 | metavar="RR", 87 | help="if append new zs to current csv or create new output" 88 | ) 89 | 90 | return parser 91 | 92 | 93 | def main(args): 94 | 95 | # Input processing 96 | 97 | calc_all_zs(landscape_folder = args.landscape_folder, 98 | dataset_list=args.dataset_list, 99 | output_folder = args.output_folder, 100 | zs_model_names = args.zs_model_names, 101 | ev_model_folder = args.ev_model_folder, 102 | regen_esm = args.regen_esm, 103 | rerun_zs = args.rerun_zs) 104 | 105 | # Run EvMutation 106 | if __name__ == "__main__": 107 | parser = create_parser() 108 | args = parser.parse_args() 109 | main(args) -------------------------------------------------------------------------------- /SSMuLA/finetune_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for combining and analyzing the results of the LoRA fine-tuning analysis. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from glob import glob 8 | import pandas as pd 9 | 10 | from SSMuLA.landscape_global import N_SAMPLE_LIST 11 | from SSMuLA.util import get_file_name 12 | 13 | 14 | def parse_finetune_df( 15 | finetune_dir: str, # ie results/finetuning/ev or none 16 | lib_list: list, 17 | n_top: int = 96, 18 | ) -> pd.DataFrame: 19 | 20 | """ 21 | Parse the finetune dataframe and return a summary dataframe. 22 | 23 | The finetune_dir should contain the results of the finetuning analysis, 24 | where each landscape is in a separate folder, and each folder contains 25 | the results of the finetuning analysis for each landscape 26 | with the format __.csv. 27 | 28 | Args: 29 | finetune_dir (str): The directory containing the finetuning results. 30 | lib_list (list): A list of libraries to include in the analysis. 31 | n_top (int): The number of top variants to consider. 
32 | """ 33 | 34 | sum_df_list = [] 35 | 36 | for df_path in sorted(glob(f"{finetune_dir}/*/*.csv")): 37 | 38 | landscape, n_sample, rep = get_file_name(df_path).split("_") 39 | 40 | if landscape not in lib_list: 41 | continue 42 | 43 | df = pd.read_csv(df_path) 44 | max_fit_seq = df.loc[df["fitness"].idxmax()]["seq"] 45 | 46 | # get top 96 maxes 47 | top_df = ( 48 | df.sort_values(by="predictions", ascending=False) 49 | .reset_index(drop=True) 50 | .iloc[:n_top, :] 51 | ) 52 | top_seqs = top_df["seq"].astype(str).values 53 | 54 | # write to sum_df 55 | sum_df_list.append( 56 | { 57 | "landscape": landscape, 58 | "n_sample": int(n_sample), 59 | "rep": int(rep), 60 | "top_maxes": top_df["fitness"].max(), 61 | "if_truemaxs": int(max_fit_seq in top_seqs), 62 | } 63 | ) 64 | 65 | return ( 66 | pd.DataFrame(sum_df_list) 67 | .sort_values(by=["n_sample", "landscape", "rep"]) 68 | .reset_index(drop=True) 69 | .copy() 70 | ) 71 | 72 | 73 | def avg_finetune_df( 74 | finetune_df: pd.DataFrame, 75 | n_sample_list: list = N_SAMPLE_LIST, 76 | ) -> pd.DataFrame: 77 | 78 | """ 79 | Average the finetune dataframe over the number of samples and repetitions. 80 | 81 | Args: 82 | finetune_df (pd.DataFrame): The dataframe containing the finetuning results. 83 | n_sample_list (list): A list of the number of samples to consider. 84 | """ 85 | 86 | avg_sum_df = ( 87 | finetune_df[["n_sample", "top_maxes", "if_truemaxs"]] 88 | .groupby("n_sample") 89 | .agg(["mean", "std"]) 90 | .reset_index() 91 | ) 92 | avg_sum_df.columns = ["{}_{}".format(i, j) for i, j in avg_sum_df.columns] 93 | return ( 94 | avg_sum_df.rename(columns={"n_sample_": "n_sample"}) 95 | .set_index("n_sample") 96 | .copy() 97 | ) -------------------------------------------------------------------------------- /tests/test_corr.py: -------------------------------------------------------------------------------- 1 | """A script for testing plotting for de""" 2 | 3 | import sys 4 | import os 5 | 6 | from datetime import datetime 7 | 8 | from SSMuLA.get_corr import MergeLandscapeAttributes, MergeMLDEAttributes, perfom_corr 9 | from SSMuLA.util import checkNgen_folder 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | log_folder = checkNgen_folder("logs/corr") 15 | 16 | # log outputs 17 | f = open( 18 | os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 19 | "w", 20 | ) 21 | sys.stdout = f 22 | 23 | # MergeLandscapeAttributes( 24 | # lib_stat_path="results/fitness_distribution/max/all_lib_stats.csv", 25 | # loc_opt_path="results/local_optima/scale2max.csv", 26 | # pwe_path="results/pairwise_epistasis_vis/none/scale2max.csv", 27 | # zs_path="results/zs_sum_5/none/zs_stat_scale2max.csv", 28 | # de_path="results/de/DE-active/scale2max/all_landscape_de_summary.csv", 29 | # merge_dir="results/merged", 30 | # ) 31 | 32 | # MergeMLDEAttributes(mlde_path = "results/mlde/all_results.csv", merge_dir = "results/merged_2", models=["boosting"]) 33 | 34 | 35 | perfom_corr( 36 | n_mut_cutoff=0, 37 | n_list=[384], 38 | zs_path="results/zs_sum/none/zs_stat_scale2max.csv", 39 | mlde_path="results/mlde/all_df_comb_onehot.csv", 40 | corr_dir="results/corr", 41 | ifplot=False, 42 | ) 43 | 44 | """ 45 | 46 | MergeLandscapeAttributes: 47 | 48 | def __init__( 49 | self, 50 | lib_stat_path: str = "results/fitness_distribution/max/all_lib_stats.csv", 51 | loc_opt_path: str = "results/local_optima/scale2max.csv", 52 | pwe_path: str = "results/pairwise_epistasis_vis/none/scale2max.csv", 53 | zs_path: str = 
"results/zs_sum/none/zs_stat_scale2max.csv", 54 | de_path: str = "results/de/DE-active/scale2max/all_landscape_de_summary.csv", 55 | merge_dir: str = "results/merged", 56 | ) 57 | 58 | MergeMLDEAttributes(MergeLandscapeAttributes): 59 | 60 | def __init__( 61 | self, 62 | lib_stat_path: str = "results/fitness_distribution/max/all_lib_stats.csv", 63 | loc_opt_path: str = "results/local_optima/scale2max.csv", 64 | pwe_path: str = "results/pairwise_epistasis_vis/none/scale2max.csv", 65 | zs_path: str = "results/zs_sum/none/zs_stat_scale2max.csv", 66 | de_path: str = "results/de/DE-active/scale2max/all_landscape_de_summary.csv", 67 | mlde_path: str = "results/mlde/all_df_comb_onehot_2.csv", 68 | merge_dir: str = "results/merged", 69 | n_mut_cutoff: int = 0, 70 | n_sample: int = 384, 71 | n_top: int = 96, 72 | filter_active: float = 1, 73 | ft_frac=0.125, 74 | models: list[str] = ["boosting", "ridge"], 75 | ifplot: bool = True, 76 | ) 77 | 78 | perfom_corr( 79 | lib_stat_path: str = "results/fitness_distribution/max/all_lib_stats.csv", 80 | loc_opt_path: str = "results/local_optima/scale2max.csv", 81 | pwe_path: str = "results/pairwise_epistasis_vis/none/scale2max.csv", 82 | zs_path: str = "results/zs_sum/none/zs_stat_scale2max.csv", 83 | de_path: str = "results/de/DE-active/scale2max/all_landscape_de_summary.csv", 84 | mlde_path: str = "results/mlde/all_df_comb_onehot.csv", 85 | corr_dir: str = "results/corr", 86 | n_mut_cutoff: int = 0, 87 | filter_active: float = 1, 88 | ft_frac: float = 0.125, 89 | n_top_list: list[int] = [96, 384], 90 | n_list: list[int] = N_SAMPLE_LIST, 91 | models_list: list[list[str]] = [["boosting", "ridge"], ["boosting"], ["ridge"]], 92 | ifplot: bool = True, 93 | ) 94 | """ 95 | 96 | f.close() -------------------------------------------------------------------------------- /envs/frozen/esmif.yml: -------------------------------------------------------------------------------- 1 | name: esmif 2 | channels: 3 | - pyg 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - aiohappyeyeballs=2.4.4=pyhd8ed1ab_1 11 | - aiohttp=3.11.11=py39h9399b63_0 12 | - aiosignal=1.3.2=pyhd8ed1ab_0 13 | - async-timeout=5.0.1=pyhd8ed1ab_1 14 | - attrs=25.1.0=pyh71513ae_0 15 | - biopandas=0.5.1=pyhd8ed1ab_1 16 | - biopython=1.85=py39h8cd3c5a_1 17 | - biotite=0.38.0=py39h44dd56e_0 18 | - blas=1.0=mkl 19 | - brotli-python=1.1.0=py39hf88036b_2 20 | - bzip2=1.0.8=h4bc722e_7 21 | - ca-certificates=2025.1.31=hbcca054_0 22 | - certifi=2024.12.14=pyhd8ed1ab_0 23 | - cffi=1.17.1=py39h15c3d72_0 24 | - charset-normalizer=3.4.1=pyhd8ed1ab_0 25 | - colorama=0.4.6=pyhd8ed1ab_1 26 | - cpuonly=2.0=0 27 | - cpython=3.9.21=py39hd8ed1ab_1 28 | - cudatoolkit=11.3.1=hb98b00a_13 29 | - filelock=3.17.0=pyhd8ed1ab_0 30 | - frozenlist=1.5.0=py39h9399b63_1 31 | - fsspec=2025.2.0=pyhd8ed1ab_0 32 | - gmp=6.3.0=hac33072_2 33 | - gmpy2=2.1.5=py39h7196dd7_3 34 | - h2=4.1.0=pyhd8ed1ab_1 35 | - hpack=4.1.0=pyhd8ed1ab_0 36 | - hyperframe=6.1.0=pyhd8ed1ab_0 37 | - idna=3.10=pyhd8ed1ab_1 38 | - intel-openmp=2022.0.1=h06a4308_3633 39 | - jinja2=3.1.5=pyhd8ed1ab_0 40 | - ld_impl_linux-64=2.43=h712a8e2_2 41 | - libblas=3.9.0=16_linux64_mkl 42 | - libcblas=3.9.0=16_linux64_mkl 43 | - libffi=3.4.2=h7f98852_5 44 | - libgcc=14.2.0=h77fa898_1 45 | - libgcc-ng=14.2.0=h69a702a_1 46 | - libgfortran=14.2.0=h69a702a_1 47 | - libgfortran-ng=14.2.0=h69a702a_1 48 | - libgfortran5=14.2.0=hd5240d6_1 49 | - libgomp=14.2.0=h77fa898_1 50 | - 
liblapack=3.9.0=16_linux64_mkl 51 | - liblzma=5.6.4=hb9d3cd8_0 52 | - libnsl=2.0.1=hd590300_0 53 | - libsqlite=3.48.0=hee588c1_1 54 | - libstdcxx=14.2.0=hc0a3c3a_1 55 | - libstdcxx-ng=14.2.0=h4852527_1 56 | - libuuid=2.38.1=h0b41bf4_0 57 | - libxcrypt=4.4.36=hd590300_1 58 | - libzlib=1.3.1=hb9d3cd8_2 59 | - llvm-openmp=15.0.7=h0cdce71_0 60 | - looseversion=1.3.0=pyhd8ed1ab_0 61 | - markupsafe=3.0.2=py39h9399b63_1 62 | - mkl=2022.1.0=hc2b9512_224 63 | - mmtf-python=1.1.3=pyhd8ed1ab_0 64 | - mpc=1.3.1=h24ddda3_1 65 | - mpfr=4.2.1=h90cbb55_3 66 | - mpmath=1.3.0=pyhd8ed1ab_1 67 | - msgpack-python=1.1.0=py39h74842e3_0 68 | - multidict=6.1.0=py39h9399b63_2 69 | - ncurses=6.5=h2d0b736_3 70 | - networkx=3.2.1=pyhd8ed1ab_0 71 | - numpy=1.26.4=py39h474f0d3_0 72 | - openssl=3.4.0=h7b32b05_1 73 | - pandas=2.2.3=py39h3b40f6f_2 74 | - pip=25.0=pyh8b19718_0 75 | - propcache=0.2.1=py39h9399b63_1 76 | - psutil=6.1.1=py39h8cd3c5a_0 77 | - pycparser=2.22=pyh29332c3_1 78 | - pyg=2.6.1=py39_torch_2.4.0_cpu 79 | - pyparsing=3.2.1=pyhd8ed1ab_0 80 | - pysocks=1.7.1=pyha55dd90_7 81 | - python=3.9.21=h9c0c6dc_1_cpython 82 | - python-dateutil=2.9.0.post0=pyhff2d567_1 83 | - python-tzdata=2025.1=pyhd8ed1ab_0 84 | - python_abi=3.9=5_cp39 85 | - pytorch=2.4.1=py3.9_cpu_0 86 | - pytorch-mutex=1.0=cpu 87 | - pytz=2024.1=pyhd8ed1ab_0 88 | - pyyaml=6.0.2=py39h9399b63_2 89 | - readline=8.2=h8228510_1 90 | - requests=2.32.3=pyhd8ed1ab_1 91 | - scipy=1.13.1=py39haf93ffa_0 92 | - setuptools=75.8.0=pyhff2d567_0 93 | - six=1.17.0=pyhd8ed1ab_0 94 | - sympy=1.13.3=pyh2585a3b_105 95 | - tk=8.6.13=noxft_h4845f30_101 96 | - tqdm=4.67.1=pyhd8ed1ab_1 97 | - typing-extensions=4.12.2=hd8ed1ab_1 98 | - typing_extensions=4.12.2=pyha770c72_1 99 | - tzdata=2025a=h78e105d_0 100 | - urllib3=2.3.0=pyhd8ed1ab_0 101 | - wheel=0.45.1=pyhd8ed1ab_1 102 | - yaml=0.2.5=h7f98852_2 103 | - yarl=1.18.3=py39h9399b63_1 104 | - zstandard=0.23.0=py39h08a7858_1 105 | - zstd=1.5.6=ha6fb4c9_0 106 | - pip: 107 | - fair-esm==2.0.1 108 | - pillow==11.1.0 109 | - rdkit==2024.9.4 110 | prefix: /disk2/fli/miniconda3/envs/esmif2 111 | -------------------------------------------------------------------------------- /SSMuLA/est_ep.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to estimate pairwise epistasis in a given dataset. 3 | """ 4 | 5 | import os 6 | from glob import glob 7 | import pandas as pd 8 | import numpy as np 9 | 10 | from SSMuLA.landscape_global import LIB_INFO_DICT, lib2prot 11 | from SSMuLA.fitness_process_vis import parse_lib_stat 12 | 13 | from Bio.PDB import PDBParser, PDBIO 14 | from Bio.PDB.PDBExceptions import PDBConstructionWarning 15 | import warnings 16 | import itertools 17 | 18 | # Suppress PDB construction warnings 19 | warnings.simplefilter('ignore', PDBConstructionWarning) 20 | 21 | def get_ca_distance(structure, residue1, residue2, chain_id='A'): 22 | """ 23 | Calculate the C-alpha distance between two residues within a structure object. 24 | 25 | Parameters: 26 | - structure: PDB structure object. 27 | - residue1, residue2: Residue numbers (integers) of the residues to measure. 28 | - chain_id: ID of the chain where the residues are located (default is chain 'A'). 29 | 30 | Returns: 31 | - distance: Distance between the C-alpha atoms of the specified residues. 
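    - None is returned instead if either residue lacks a C-alpha atom; the value is in
      ångströms, as given by Biopython's atom subtraction.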
32 | """ 33 | # Select chain and residues 34 | chain = structure[0][chain_id] # Assume using the first model 35 | res1 = chain[residue1] 36 | res2 = chain[residue2] 37 | 38 | # Fetch the 'CA' atoms if they exist 39 | if 'CA' in res1 and 'CA' in res2: 40 | ca1 = res1['CA'] 41 | ca2 = res2['CA'] 42 | # Calculate distance 43 | distance = ca1 - ca2 44 | return distance 45 | else: 46 | return None 47 | 48 | def calculate_pairwise_distances(pdb_file, residues_dict, chain_id='A'): 49 | """ 50 | Calculate pairwise C-alpha distances for a set of residues specified in a dictionary. 51 | 52 | Parameters: 53 | - pdb_file: Path to the PDB file. 54 | - residues_dict: Dictionary mapping indices to residue numbers. 55 | - chain_id: Chain ID to look for residues. 56 | 57 | Returns: 58 | - distances: Dictionary of tuple (residue pair) to distance. 59 | """ 60 | # Parse the PDB file 61 | parser = PDBParser() 62 | structure = parser.get_structure('PDB', pdb_file) 63 | 64 | # Calculate distances for all pairs 65 | distances = {} 66 | for (idx1, res1), (idx2, res2) in itertools.combinations(residues_dict.items(), 2): 67 | distance = get_ca_distance(structure, res1, res2, chain_id) 68 | distances[(res1, res2)] = distance 69 | 70 | return distances 71 | 72 | 73 | def all_lib_pairwise_dist( 74 | data_dir: str = "data", 75 | ): 76 | 77 | """ 78 | Calculate pairwise distances for all libraries in the specified data directory. 79 | 80 | Args: 81 | - data_dir: Directory containing PDB files for each library. 82 | 83 | Returns: 84 | - pwd: DataFrame containing mean and standard deviation of distances for each library. 85 | """ 86 | df = pd.DataFrame(columns=["lib", "res1", "res2", "dist"]) 87 | 88 | chain_id = "A" 89 | for lib, l_d in LIB_INFO_DICT.items(): 90 | pdb_path = os.path.join(data_dir, f"{lib2prot(lib)}/{lib2prot(lib)}.pdb") 91 | 92 | parser = PDBParser() 93 | structure = parser.get_structure("PDB", pdb_path) 94 | 95 | for (idx1, res_id1), (idx2, res_id2) in itertools.combinations( 96 | l_d["positions"].items(), 2 97 | ): 98 | df = df._append( 99 | { 100 | "lib": lib, 101 | "res1": res_id1, 102 | "res2": res_id2, 103 | "dist": get_ca_distance(structure, res_id1, res_id2, chain_id), 104 | }, 105 | ignore_index=True, 106 | ) 107 | 108 | pwd = df[["lib", "dist"]].groupby(["lib"]).agg(["mean", "std"]) 109 | pwd.columns = ["mean", "std"] 110 | 111 | return pwd -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom ignore 2 | 3 | *.zip 4 | *.tar.gz 5 | 6 | data 7 | data/ 8 | data/* 9 | 10 | data_old 11 | data_old/ 12 | data_old/* 13 | 14 | data4upload 15 | data4upload/ 16 | data4upload/* 17 | 18 | coves_data 19 | coves_data/ 20 | coves_data/* 21 | 22 | add_data 23 | add_data/ 24 | add_data/* 25 | 26 | coves 27 | coves/ 28 | coves/* 29 | 30 | results 31 | results/ 32 | results/* 33 | results.tar.gz 34 | 35 | results_rev 36 | results_rev/ 37 | results_rev/* 38 | 39 | results_raw 40 | results_raw/ 41 | results_raw/* 42 | 43 | results4upload 44 | results4upload/ 45 | results4upload/* 46 | 47 | figs 48 | figs/ 49 | figs/* 50 | 51 | runs 52 | runs/ 53 | runs/* 54 | 55 | logs 56 | logs/ 57 | logs/* 58 | 59 | esmif/*csv 60 | 61 | ev_esm 62 | ev_esm/ 63 | ev_esm/* 64 | 65 | ev_esm1v 66 | ev_esm1v/ 67 | ev_esm1v/* 68 | 69 | ev_esm2 70 | ev_esm2/ 71 | ev_esm2/* 72 | 73 | gvp-pytorch 74 | gvp-pytorch/ 75 | gvp-pytorch/* 76 | 77 | lmdb 78 | lmdb/ 79 | lmdb/* 80 | 81 | triad 82 | triad/ 83 | 
triad/* 84 | 85 | mlde 86 | mlde/ 87 | mlde/* 88 | 89 | mlde_messy 90 | mlde_messy/ 91 | mlde_messy/* 92 | 93 | learned_emb 94 | learned_emb/ 95 | learned_emb/* 96 | 97 | KEJ 98 | KEJ/ 99 | KEJ/* 100 | 101 | sandbox 102 | sandbox/ 103 | sandbox/* 104 | 105 | sandbox/esmif 106 | sandbox/esmif/ 107 | sandbox/esmif/* 108 | 109 | sandbox/fig_svg 110 | sandbox/fig_svg/ 111 | sandbox/fig_svg/* 112 | 113 | zs 114 | zs/ 115 | zs/* 116 | 117 | *.idea 118 | *.npy 119 | 120 | # Byte-compiled / optimized / DLL files 121 | __pycache__/ 122 | *.py[cod] 123 | *$py.class 124 | 125 | # C extensions 126 | *.so 127 | 128 | # Distribution / packaging 129 | .Python 130 | build/ 131 | develop-eggs/ 132 | dist/ 133 | downloads/ 134 | eggs/ 135 | .eggs/ 136 | lib/ 137 | lib64/ 138 | parts/ 139 | sdist/ 140 | var/ 141 | wheels/ 142 | pip-wheel-metadata/ 143 | share/python-wheels/ 144 | *.egg-info/ 145 | .installed.cfg 146 | *.egg 147 | MANIFEST 148 | 149 | # PyInstaller 150 | # Usually these files are written by a python script from a template 151 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 152 | *.manifest 153 | *.spec 154 | 155 | # Installer logs 156 | pip-log.txt 157 | pip-delete-this-directory.txt 158 | 159 | # Unit test / coverage reports 160 | htmlcov/ 161 | .tox/ 162 | .nox/ 163 | .coverage 164 | .coverage.* 165 | .cache 166 | nosetests.xml 167 | coverage.xml 168 | *.cover 169 | *.py,cover 170 | .hypothesis/ 171 | .pytest_cache/ 172 | 173 | # Translations 174 | *.mo 175 | *.pot 176 | 177 | # Django stuff: 178 | *.log 179 | local_settings.py 180 | db.sqlite3 181 | db.sqlite3-journal 182 | 183 | # Flask stuff: 184 | instance/ 185 | .webassets-cache 186 | 187 | # Scrapy stuff: 188 | .scrapy 189 | 190 | # Sphinx documentation 191 | docs/_build/ 192 | 193 | # PyBuilder 194 | target/ 195 | 196 | # Jupyter Notebook 197 | .ipynb_checkpoints 198 | 199 | # IPython 200 | profile_default/ 201 | ipython_config.py 202 | 203 | # pyenv 204 | .python-version 205 | 206 | # pipenv 207 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 208 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 209 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 210 | # install all needed dependencies. 211 | #Pipfile.lock 212 | 213 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 214 | __pypackages__/ 215 | 216 | # Celery stuff 217 | celerybeat-schedule 218 | celerybeat.pid 219 | 220 | # SageMath parsed files 221 | *.sage.py 222 | 223 | # Environments 224 | .env 225 | .venv 226 | env/ 227 | venv/ 228 | ENV/ 229 | env.bak/ 230 | venv.bak/ 231 | 232 | # Spyder project settings 233 | .spyderproject 234 | .spyproject 235 | 236 | # Rope project settings 237 | .ropeproject 238 | 239 | # mkdocs documentation 240 | /site 241 | 242 | # mypy 243 | .mypy_cache/ 244 | .dmypy.json 245 | dmypy.json 246 | 247 | # Pyre type checker 248 | .pyre/ 249 | -------------------------------------------------------------------------------- /SSMuLA/util.py: -------------------------------------------------------------------------------- 1 | """Util functions""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | import pickle 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from sklearn.metrics import ndcg_score 12 | 13 | 14 | def checkNgen_folder(folder_path: str) -> str: 15 | 16 | """ 17 | Check if the folder and its subfolder exists 18 | create a new directory if not 19 | Args: 20 | - folder_path: str, the folder path 21 | """ 22 | 23 | # if input path is file 24 | if bool(os.path.splitext(folder_path)[1]): 25 | folder_path = os.path.dirname(folder_path) 26 | 27 | split_list = os.path.normpath(folder_path).split("/") 28 | for p, _ in enumerate(split_list): 29 | subfolder_path = "/".join(split_list[: p + 1]) 30 | if not os.path.exists(subfolder_path): 31 | print(f"Making {subfolder_path} ...") 32 | os.mkdir(subfolder_path) 33 | return folder_path 34 | 35 | 36 | def pickle_save(what2save, where2save: str) -> None: 37 | 38 | """ 39 | Save variable to a pickle file 40 | Args: 41 | - what2save, the varible that needs to be saved 42 | - where2save: str, the .pkl path for saving 43 | """ 44 | 45 | with open(where2save, "wb") as f: 46 | pickle.dump(what2save, f) 47 | 48 | 49 | def pickle_load(path2load: str): 50 | 51 | """ 52 | Load pickle file 53 | Args: 54 | - path2load: str, the .pkl path for loading 55 | """ 56 | 57 | with open(path2load, "rb") as f: 58 | return pickle.load(f) 59 | 60 | 61 | def get_file_name(file_path: str) -> str: 62 | 63 | """ 64 | Extract file name without the extension 65 | Args: 66 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb 67 | Returns: 68 | - str, ie WT 69 | """ 70 | 71 | return os.path.splitext(os.path.basename(file_path))[0] 72 | 73 | 74 | def get_dir_name(file_path: str) -> str: 75 | 76 | """ 77 | Extract dir name 78 | Args: 79 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb 80 | Returns: 81 | - str, ie Tm9D8s_3siteA_fixed 82 | """ 83 | 84 | return os.path.basename(os.path.dirname(file_path)) 85 | 86 | 87 | def get_dirNfile_name(file_path: str) -> [str, str]: 88 | 89 | """ 90 | Extract file name without the extension and direct dir name 91 | Args: 92 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb 93 | Returns: 94 | - str, ie ['Tm9D8s_3siteA_fixed', 'WT'] 95 | """ 96 | 97 | return ( 98 | os.path.basename(os.path.dirname(file_path)), 99 | os.path.splitext(os.path.basename(file_path))[0], 100 | ) 101 | 102 | 103 | def get_fulldirNfile_name(file_path: str) -> [str, str]: 104 | 105 | """ 106 | Extract file name without the extension and full dir name 107 | Args: 108 | - file_path: str, ie. 
data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb 109 | Returns: 110 | - str, ie ['data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed', 'WT'] 111 | """ 112 | 113 | return os.path.dirname(file_path), os.path.splitext(os.path.basename(file_path))[0] 114 | 115 | 116 | def ndcg_scale(true: np.ndarray, pred: np.ndarray): 117 | """Calculate the ndcg_score with neg correction""" 118 | 119 | if min(true) < 0: 120 | true = true - min(true) 121 | return ndcg_score(true[None, :], pred[None, :]) 122 | 123 | 124 | def ecdf_transform(data: pd.Series) -> pd.Series: 125 | 126 | """ 127 | Transform a series of fitness values into an empirical cumulative distribution function 128 | 129 | Args: 130 | - data: pd.Series, the fitness values 131 | 132 | Returns: 133 | - pd.Series, the ECDF 134 | """ 135 | 136 | return data.rank(method="first") / len(data) 137 | 138 | 139 | 140 | def csv2fasta(csv: str) -> None: 141 | """ 142 | A function for converting a csv file to a fasta file 143 | ie /disk2/fli/SSMuLA/ev_esm2/DHFR/DHFR.csv 144 | 145 | """ 146 | df = pd.read_csv(csv) 147 | 148 | for col in ["muts", "seq"]: 149 | if col not in df.columns: 150 | raise ValueError(f"{col} column not found") 151 | 152 | fasta = csv.replace(".csv", ".fasta") 153 | with open(fasta, "w") as f: 154 | for mut, seq in zip(df["muts"].values, df["seq"].values): 155 | f.write(f">{mut}\n{seq}\n") -------------------------------------------------------------------------------- /esmif/score_log_likelihoods.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | # Scores sequences based on a given structure. 
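# Each variant sequence in the .fasta file is scored by its ESM-IF1
# (esm_if1_gvp4_t16_142M_UR50) log-likelihood conditioned on the backbone
# coordinates parsed from the .pdb/.cif file; esmif.sh above invokes this
# script once per landscape.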
7 | # 8 | # usage: 9 | # score_log_likelihoods.py [-h] [--outpath OUTPATH] [--chain CHAIN] pdbfile seqfile 10 | 11 | import argparse 12 | from biotite.sequence.io.fasta import FastaFile, get_sequences 13 | import numpy as np 14 | from pathlib import Path 15 | import torch 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | 19 | import esm 20 | import esm.inverse_folding 21 | 22 | 23 | def score_singlechain_backbone(model, alphabet, args): 24 | if torch.cuda.is_available() and not args.nogpu: 25 | model = model.cuda() 26 | print("Transferred model to GPU") 27 | coords, native_seq = esm.inverse_folding.util.load_coords(args.pdbfile, args.chain) 28 | print('Native sequence loaded from structure file:') 29 | print(native_seq) 30 | print('\n') 31 | 32 | ll, _ = esm.inverse_folding.util.score_sequence( 33 | model, alphabet, coords, native_seq) 34 | print('Native sequence') 35 | print(f'Log likelihood: {ll:.2f}') 36 | print(f'Perplexity: {np.exp(-ll):.2f}') 37 | 38 | print('\nScoring variant sequences from sequence file..\n') 39 | infile = FastaFile() 40 | infile.read(args.seqfile) 41 | seqs = get_sequences(infile) 42 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True) 43 | with open(args.outpath, 'w') as fout: 44 | fout.write('seqid,log_likelihood\n') 45 | for header, seq in tqdm(seqs.items()): 46 | ll, _ = esm.inverse_folding.util.score_sequence( 47 | model, alphabet, coords, str(seq)) 48 | fout.write(header + ',' + str(ll) + '\n') 49 | print(f'Results saved to {args.outpath}') 50 | 51 | 52 | def score_multichain_backbone(model, alphabet, args): 53 | if torch.cuda.is_available() and not args.nogpu: 54 | model = model.cuda() 55 | print("Transferred model to GPU") 56 | structure = esm.inverse_folding.util.load_structure(args.pdbfile) 57 | coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure) 58 | target_chain_id = args.chain 59 | native_seq = native_seqs[target_chain_id] 60 | print('Native sequence loaded from structure file:') 61 | print(native_seq) 62 | print('\n') 63 | 64 | ll, _ = esm.inverse_folding.multichain_util.score_sequence_in_complex( 65 | model, alphabet, coords, target_chain_id, native_seq) 66 | print('Native sequence') 67 | print(f'Log likelihood: {ll:.2f}') 68 | print(f'Perplexity: {np.exp(-ll):.2f}') 69 | 70 | print('\nScoring variant sequences from sequence file..\n') 71 | infile = FastaFile() 72 | infile.read(args.seqfile) 73 | seqs = get_sequences(infile) 74 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True) 75 | with open(args.outpath, 'w') as fout: 76 | fout.write('seqid,log_likelihood\n') 77 | for header, seq in tqdm(seqs.items()): 78 | ll, _ = esm.inverse_folding.multichain_util.score_sequence_in_complex( 79 | model, alphabet, coords, target_chain_id, str(seq)) 80 | fout.write(header + ',' + str(ll) + '\n') 81 | print(f'Results saved to {args.outpath}') 82 | 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser( 86 | description='Score sequences based on a given structure.' 
87 | ) 88 | parser.add_argument( 89 | 'pdbfile', type=str, 90 | help='input filepath, either .pdb or .cif', 91 | ) 92 | parser.add_argument( 93 | 'seqfile', type=str, 94 | help='input filepath for variant sequences in a .fasta file', 95 | ) 96 | parser.add_argument( 97 | '--outpath', type=str, 98 | help='output filepath for scores of variant sequences', 99 | default='output/sequence_scores.csv', 100 | ) 101 | parser.add_argument( 102 | '--chain', type=str, 103 | help='chain id for the chain of interest', default='A', 104 | ) 105 | parser.set_defaults(multichain_backbone=False) 106 | parser.add_argument( 107 | '--multichain-backbone', action='store_true', 108 | help='use the backbones of all chains in the input for conditioning' 109 | ) 110 | parser.add_argument( 111 | '--singlechain-backbone', dest='multichain_backbone', 112 | action='store_false', 113 | help='use the backbone of only target chain in the input for conditioning' 114 | ) 115 | 116 | parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available") 117 | 118 | args = parser.parse_args() 119 | 120 | model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() 121 | model = model.eval() 122 | 123 | if args.multichain_backbone: 124 | score_multichain_backbone(model, alphabet, args) 125 | else: 126 | score_singlechain_backbone(model, alphabet, args) 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | main() -------------------------------------------------------------------------------- /envs/frozen/coves.yml: -------------------------------------------------------------------------------- 1 | name: coves 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - aiohappyeyeballs=2.4.4=pyhd8ed1ab_1 11 | - aiohttp=3.11.11=py39h9399b63_0 12 | - aiosignal=1.3.2=pyhd8ed1ab_0 13 | - async-timeout=5.0.1=pyhd8ed1ab_1 14 | - attrs=25.1.0=pyh71513ae_0 15 | - blas=1.0=mkl 16 | - brotli-python=1.1.0=py39hf88036b_2 17 | - bzip2=1.0.8=h4bc722e_7 18 | - ca-certificates=2024.12.14=hbcca054_0 19 | - certifi=2024.12.14=pyhd8ed1ab_0 20 | - cffi=1.17.1=py39h15c3d72_0 21 | - charset-normalizer=3.4.1=pyhd8ed1ab_0 22 | - colorama=0.4.6=pyhd8ed1ab_1 23 | - cpython=3.9.21=py39hd8ed1ab_1 24 | - cuda-cudart=12.4.127=he02047a_2 25 | - cuda-cudart_linux-64=12.4.127=h85509e4_2 26 | - cuda-cupti=12.4.127=he02047a_2 27 | - cuda-libraries=12.4.1=ha770c72_1 28 | - cuda-nvrtc=12.4.127=he02047a_2 29 | - cuda-nvtx=12.4.127=he02047a_2 30 | - cuda-opencl=12.4.127=he02047a_1 31 | - cuda-runtime=12.4.1=ha804496_0 32 | - cuda-version=12.4=h3060b56_3 33 | - filelock=3.17.0=pyhd8ed1ab_0 34 | - frozenlist=1.5.0=py39h9399b63_1 35 | - fsspec=2024.12.0=pyhd8ed1ab_0 36 | - gmp=6.3.0=hac33072_2 37 | - gmpy2=2.1.5=py39h7196dd7_3 38 | - h2=4.1.0=pyhd8ed1ab_1 39 | - hpack=4.1.0=pyhd8ed1ab_0 40 | - hyperframe=6.1.0=pyhd8ed1ab_0 41 | - idna=3.10=pyhd8ed1ab_1 42 | - intel-openmp=2022.0.1=h06a4308_3633 43 | - jinja2=3.1.5=pyhd8ed1ab_0 44 | - joblib=1.4.2=pyhd8ed1ab_1 45 | - ld_impl_linux-64=2.43=h712a8e2_2 46 | - libblas=3.9.0=16_linux64_mkl 47 | - libcblas=3.9.0=16_linux64_mkl 48 | - libcublas=12.4.5.8=he02047a_2 49 | - libcufft=11.2.1.3=he02047a_2 50 | - libcufile=1.9.1.3=he02047a_2 51 | - libcurand=10.3.5.147=he02047a_2 52 | - libcusolver=11.6.1.9=he02047a_2 53 | - libcusparse=12.3.1.170=he02047a_2 54 | - libffi=3.4.2=h7f98852_5 55 | - libgcc=14.2.0=h77fa898_1 56 | - libgcc-ng=14.2.0=h69a702a_1 57 | - libgfortran=14.2.0=h69a702a_1 58 | - 
libgfortran-ng=14.2.0=h69a702a_1 59 | - libgfortran5=14.2.0=hd5240d6_1 60 | - libgomp=14.2.0=h77fa898_1 61 | - liblapack=3.9.0=16_linux64_mkl 62 | - liblzma=5.6.3=hb9d3cd8_1 63 | - libnpp=12.2.5.30=he02047a_2 64 | - libnsl=2.0.1=hd590300_0 65 | - libnvfatbin=12.4.127=he02047a_2 66 | - libnvjitlink=12.4.127=he02047a_2 67 | - libnvjpeg=12.3.1.117=he02047a_2 68 | - libsqlite=3.48.0=hee588c1_1 69 | - libstdcxx=14.2.0=hc0a3c3a_1 70 | - libstdcxx-ng=14.2.0=h4852527_1 71 | - libuuid=2.38.1=h0b41bf4_0 72 | - libxcrypt=4.4.36=hd590300_1 73 | - libzlib=1.3.1=hb9d3cd8_2 74 | - llvm-openmp=15.0.7=h0cdce71_0 75 | - markupsafe=3.0.2=py39h9399b63_1 76 | - mkl=2022.1.0=hc2b9512_224 77 | - mpc=1.3.1=h24ddda3_1 78 | - mpfr=4.2.1=h90cbb55_3 79 | - mpmath=1.3.0=pyhd8ed1ab_1 80 | - multidict=6.1.0=py39h9399b63_2 81 | - ncurses=6.5=h2d0b736_2 82 | - networkx=3.2.1=pyhd8ed1ab_0 83 | - ocl-icd=2.3.2=hb9d3cd8_2 84 | - opencl-headers=2024.10.24=h5888daf_0 85 | - openssl=3.4.0=h7b32b05_1 86 | - pandas=2.2.3=py39h3b40f6f_2 87 | - pip=24.3.1=pyh8b19718_2 88 | - propcache=0.2.1=py39h9399b63_1 89 | - psutil=6.1.1=py39h8cd3c5a_0 90 | - pycparser=2.22=pyh29332c3_1 91 | - pyg=2.6.1=py39_torch_2.4.0_cu124 92 | - pyparsing=3.2.1=pyhd8ed1ab_0 93 | - pysocks=1.7.1=pyha55dd90_7 94 | - python=3.9.21=h9c0c6dc_1_cpython 95 | - python-dateutil=2.9.0.post0=pyhff2d567_1 96 | - python-tzdata=2025.1=pyhd8ed1ab_0 97 | - python_abi=3.9=5_cp39 98 | - pytorch-cuda=12.4=hc786d27_7 99 | - pytorch-mutex=1.0=cuda 100 | - pytz=2024.1=pyhd8ed1ab_0 101 | - pyyaml=6.0.2=py39h9399b63_2 102 | - readline=8.2=h8228510_1 103 | - requests=2.32.3=pyhd8ed1ab_1 104 | - scikit-learn=1.6.1=py39h4b7350c_0 105 | - scipy=1.13.1=py39haf93ffa_0 106 | - setuptools=75.8.0=pyhff2d567_0 107 | - six=1.17.0=pyhd8ed1ab_0 108 | - sympy=1.13.3=pyh2585a3b_105 109 | - threadpoolctl=3.5.0=pyhc1e730c_0 110 | - tk=8.6.13=noxft_h4845f30_101 111 | - tqdm=4.67.1=pyhd8ed1ab_1 112 | - typing-extensions=4.12.2=hd8ed1ab_1 113 | - typing_extensions=4.12.2=pyha770c72_1 114 | - tzdata=2025a=h78e105d_0 115 | - urllib3=2.3.0=pyhd8ed1ab_0 116 | - wheel=0.45.1=pyhd8ed1ab_1 117 | - yaml=0.2.5=h7f98852_2 118 | - yarl=1.18.3=py39h9399b63_1 119 | - zstandard=0.23.0=py39h08a7858_1 120 | - zstd=1.5.6=ha6fb4c9_0 121 | - pip: 122 | - anyio==4.8.0 123 | - argon2-cffi==23.1.0 124 | - argon2-cffi-bindings==21.2.0 125 | - argparse==1.4.0 126 | - arrow==1.3.0 127 | - asttokens==3.0.0 128 | - async-lru==2.0.4 129 | - atom3d==0.2.6 130 | - babel==2.16.0 131 | - beautifulsoup4==4.12.3 132 | - biopython==1.85 133 | - black==21.12b0 134 | - blackcellmagic==0.0.3 135 | - bleach==6.2.0 136 | - blosc2==2.5.1 137 | - click==8.1.8 138 | - comm==0.2.2 139 | - debugpy==1.8.12 140 | - decorator==5.1.1 141 | - defusedxml==0.7.1 142 | - dill==0.3.9 143 | - easy-parallel==0.1.6 144 | - exceptiongroup==1.2.2 145 | - executing==2.2.0 146 | - fastjsonschema==2.21.1 147 | - fqdn==1.5.1 148 | - freesasa==2.2.1 149 | - h11==0.14.0 150 | - h5py==3.12.1 151 | - httpcore==1.0.7 152 | - httpx==0.28.1 153 | - importlib-metadata==8.6.1 154 | - ipykernel==6.29.5 155 | - ipython==8.18.1 156 | - ipywidgets==8.1.5 157 | - isoduration==20.11.0 158 | - jedi==0.19.2 159 | - json5==0.10.0 160 | - jsonpointer==3.0.0 161 | - jsonschema==4.23.0 162 | - jsonschema-specifications==2024.10.1 163 | - jupyter==1.1.1 164 | - jupyter-client==8.6.3 165 | - jupyter-console==6.6.3 166 | - jupyter-core==5.7.2 167 | - jupyter-events==0.11.0 168 | - jupyter-lsp==2.2.5 169 | - jupyter-server==2.15.0 170 | - jupyter-server-terminals==0.5.3 171 | - 
jupyterlab==4.3.4 172 | - jupyterlab-pygments==0.3.0 173 | - jupyterlab-server==2.27.3 174 | - jupyterlab-widgets==3.0.13 175 | - lmdb==1.6.2 176 | - matplotlib-inline==0.1.7 177 | - mistune==3.1.0 178 | - msgpack==1.1.0 179 | - multipledispatch==1.0.0 180 | - multiprocess==0.70.17 181 | - mypy-extensions==1.0.0 182 | - nbclient==0.10.2 183 | - nbconvert==7.16.5 184 | - nbformat==5.10.4 185 | - ndindex==1.9.2 186 | - nest-asyncio==1.6.0 187 | - notebook==7.3.2 188 | - notebook-shim==0.2.4 189 | - numexpr==2.10.2 190 | - numpy==1.23.5 191 | - overrides==7.7.0 192 | - packaging==24.2 193 | - pandocfilters==1.5.1 194 | - parso==0.8.4 195 | - pathos==0.3.3 196 | - pathspec==0.12.1 197 | - pexpect==4.9.0 198 | - pillow==11.1.0 199 | - platformdirs==4.3.6 200 | - pox==0.3.5 201 | - ppft==1.7.6.9 202 | - prometheus-client==0.21.1 203 | - prompt-toolkit==3.0.50 204 | - ptyprocess==0.7.0 205 | - pure-eval==0.2.3 206 | - py-cpuinfo==9.0.0 207 | - pygments==2.19.1 208 | - pyrr==0.10.3 209 | - python-dotenv==1.0.1 210 | - python-json-logger==3.2.1 211 | - pyzmq==26.2.0 212 | - rdkit==2024.9.4 213 | - referencing==0.36.2 214 | - rfc3339-validator==0.1.4 215 | - rfc3986-validator==0.1.1 216 | - rpds-py==0.22.3 217 | - send2trash==1.8.3 218 | - sniffio==1.3.1 219 | - soupsieve==2.6 220 | - stack-data==0.6.3 221 | - tables==3.9.2 222 | - terminado==0.18.1 223 | - tinycss2==1.4.0 224 | - tomli==1.2.3 225 | - torch==2.1.0+cu121 226 | - torch-cluster==1.6.3+pt21cu121 227 | - torchaudio==2.1.0+cu121 228 | - torchvision==0.16.0+cu121 229 | - tornado==6.4.2 230 | - traitlets==5.14.3 231 | - triton==2.1.0 232 | - types-python-dateutil==2.9.0.20241206 233 | - uri-template==1.3.0 234 | - wcwidth==0.2.13 235 | - webcolors==24.11.1 236 | - webencodings==0.5.1 237 | - websocket-client==1.8.0 238 | - widgetsnbextension==4.0.13 239 | - zipp==3.21.0 240 | prefix: /disk2/fli/miniconda3/envs/coves 241 | -------------------------------------------------------------------------------- /SSMuLA/zs_calc.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for generating zs scores gratefully adapted from EmreGuersoy's work 3 | """ 4 | 5 | # Import packages 6 | import os 7 | from glob import glob 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import random 12 | from collections import Counter 13 | 14 | 15 | import argparse 16 | from pathlib import Path 17 | 18 | from tqdm import tqdm 19 | from typing import List, Tuple, Optional 20 | 21 | 22 | import warnings 23 | 24 | from SSMuLA.landscape_global import lib2prot 25 | from SSMuLA.zs_models import ZeroShotPrediction, ESM, EvMutation 26 | from SSMuLA.zs_data import DataProcessor 27 | from SSMuLA.util import get_file_name, checkNgen_folder 28 | 29 | 30 | # TODO clean up evmutation model path 31 | 32 | 33 | def calc_zs( 34 | fit_df_path: str, 35 | scalefit : str = "scale2max", 36 | output_folder: str = "results/zs", 37 | zs_model_names: str = "all", 38 | ev_model_folder: str = "data/evmodels", 39 | regen_esm: bool = False, 40 | rerun_zs: bool = False, 41 | ) -> pd.DataFrame: 42 | 43 | """ 44 | A function for calculating zs scores and adding them to the fitness csv 45 | 46 | Args: 47 | - input_folder: str, input folder of landscape 48 | ie, 'data/processed/AAV2_Bryant_2021', 49 | contains csv and fasta with wt seq 50 | - output_folder: str = "results/zs", with landscape subfolders 51 | - zs_model_names: str, name(s) of zero-shot models to use seperated by comma, 52 | available: 'esm', 'ev' 53 | developing: 'ddg', 'Bert' 
54 | all: 'all' runs all currently available models 55 | ie, 'esm, ev' 56 | - ev_model_folder: str = "data/evmodels", folder for evmodels, 57 | with dataset name and dataset name.model 58 | ie. data/evmodels/AAV2_Bryant_2021/AAV2_Bryant_2021.model 59 | - regen_esm: str = False, if regenerate esm logits or load directly 60 | - rerun_zs: str = False, if append new zs to current csv or create new output 61 | """ 62 | 63 | # deal with the / 64 | input_folder = os.path.normpath(fit_df_path.split(scalefit)[0]) 65 | output_folder = os.path.normpath(output_folder) 66 | 67 | landscape_name = get_file_name(fit_df_path) 68 | 69 | if "DHFR" in landscape_name: 70 | append_fasta = "_trans" 71 | else: 72 | append_fasta = "" 73 | 74 | # fit_df_path = os.path.join(input_folder, landscape_name + ".csv") 75 | fasta_path = os.path.join(input_folder, lib2prot(landscape_name) + append_fasta + ".fasta") 76 | 77 | ev_model_path = os.path.join( 78 | os.path.normpath(ev_model_folder), lib2prot(landscape_name), lib2prot(landscape_name) + ".model" 79 | ) 80 | 81 | print(fasta_path, ev_model_path) 82 | 83 | # check if file exists 84 | assert os.path.exists(fit_df_path), f"{fit_df_path} does not exist" 85 | assert os.path.exists(fasta_path), f"{fasta_path} does not exist" 86 | 87 | output_folder = checkNgen_folder(output_folder) 88 | landscape_output_folder = checkNgen_folder( 89 | os.path.join(output_folder, landscape_name) 90 | ) 91 | 92 | zs_df_path = os.path.join(landscape_output_folder, landscape_name + ".csv") 93 | 94 | # Create an instance of the DataProcessor class 95 | data_processor = DataProcessor() 96 | 97 | # Call the prepare_zero_shot method 98 | data = data_processor.prepare_zero_shot( 99 | fit_df_path, fasta_path, _combo=True, _pos=True 100 | ) 101 | 102 | # init df 103 | existing_zs_df = data 104 | 105 | # check if exist 106 | if os.path.exists(zs_df_path): 107 | if rerun_zs: 108 | print(f"{zs_df_path} exists. Remove for rerun_zs = {rerun_zs}") 109 | os.remove(zs_df_path) 110 | else: 111 | print( 112 | f"{zs_df_path} exists. Append new zs {zs_model_names} for rerun_zs = {rerun_zs}" 113 | ) 114 | existing_zs_df = pd.read_csv(zs_df_path) 115 | 116 | # Ref Sequence 117 | wt = data_processor.get_Seq(fasta_path) 118 | 119 | if zs_model_names == "all": 120 | zs_model_list = ["esm", "ev"] 121 | else: 122 | zs_model_list = zs_model_names.split(",") 123 | 124 | # init zs_df_list 125 | zs_df_list = [] 126 | 127 | print(f"zs_model_list: {zs_model_list}") 128 | 129 | # TODO ESM load logits directly 130 | 131 | for zs_model_name in zs_model_list: 132 | 133 | # get max numb of muts 134 | max_numb_mut = max(data["combo"].str.len()) 135 | 136 | # Access Model 137 | if "esm" in zs_model_name: 138 | 139 | logits_path = os.path.join(landscape_output_folder, landscape_name + "_logits.npy") 140 | 141 | esm = ESM(data, wt, logits_path=logits_path, regen_esm=regen_esm) 142 | 143 | if os.path.exists(logits_path) and not(regen_esm): 144 | print(f"{logits_path} exists and regen_esm = {regen_esm}. 
Loading...") 145 | log_reprs = np.load(logits_path) 146 | else: 147 | print(f"Generating {logits_path}...") 148 | log_reprs = esm._get_logits() 149 | np.save(logits_path, log_reprs) 150 | 151 | score_esm = esm._get_n_score(list(range(max_numb_mut+1))[1:]) 152 | zs_df_list.append(score_esm) 153 | print(f"score_esm:\n{score_esm.head()}") 154 | 155 | elif "ev" in zs_model_name: 156 | if os.path.exists(ev_model_path): 157 | ev = EvMutation(data, wt, model_path=ev_model_path) 158 | score_ev = ev._get_n_score(list(range(max_numb_mut+1))[1:]) 159 | zs_df_list.append(score_ev) 160 | print(f"score_ev:\n{score_ev.head()}") 161 | else: 162 | print(f"{ev_model_path} does not exist yet. Skipping...") 163 | 164 | # TODO add remaining zs_models 165 | 166 | elif zs_model_name == "ddg": 167 | pass 168 | else: 169 | print("Model currently not available") 170 | continue 171 | 172 | # Add muts from data to df 173 | print(f"zs_df_list:\n{zs_df_list}") 174 | 175 | for zs_df in zs_df_list: 176 | 177 | df = pd.merge( 178 | existing_zs_df, 179 | zs_df[list(zs_df.columns.difference(existing_zs_df.columns)) + ["muts"]], 180 | left_on="muts", 181 | right_on="muts", 182 | how="outer", 183 | ) 184 | 185 | existing_zs_df = df 186 | 187 | return df.to_csv(zs_df_path, index=False) 188 | 189 | 190 | # TODO FIX EV MODEL PATH 191 | def calc_all_zs( 192 | landscape_folder: str = "data", 193 | scalefit : str = "scale2max", 194 | dataset_list: list[str] = [], 195 | output_folder: str = "results/zs", 196 | zs_model_names: str = "all", 197 | ev_model_folder: str = "data", 198 | regen_esm: str = False, 199 | rerun_zs: str = False, 200 | ): 201 | """ 202 | A function for calc same list of zs scores for all landscape datasets 203 | 204 | Args: 205 | - landscape_folder: str = "data/processed", folder path for all landscape data 206 | - dataset_list: list[str] = [], a list of encoders over write dataset_folder, 207 | ie. ['TrpB3I_Johnston_2023'] 208 | - output_folder: str = "results/zs", 209 | - zs_model_names: str = "all", 210 | - ev_model_folder: str = "data/evmodels", folder for evmodels, 211 | with dataset name and dataset name.model 212 | ie. 
data/evmodels/AAV2_Bryant_2021/AAV2_Bryant_2021.model 213 | - regen_esm: bool = False, if regenerate esm logits or load directly 214 | - rerun_zs: bool = False, if append new zs to current csv or create new output 215 | """ 216 | 217 | if len(dataset_list) == 0: 218 | landscape_paths = sorted(glob(os.path.normpath(landscape_folder) + "/*/" + scalefit + "/*.csv")) 219 | else: 220 | landscape_paths = [ 221 | os.path.join(landscape_folder, dataset) for dataset in dataset_list 222 | ] 223 | 224 | for landscape_path in landscape_paths: 225 | print(f"Calc zs {zs_model_names} for {landscape_path}...") 226 | 227 | _ = calc_zs( 228 | fit_df_path=landscape_path, 229 | scalefit=scalefit, 230 | output_folder=output_folder, 231 | zs_model_names=zs_model_names, 232 | ev_model_folder=ev_model_folder, 233 | regen_esm=regen_esm, 234 | rerun_zs=rerun_zs, 235 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSMuLA 2 | 3 | ## About 4 | * Code base termed "Site Saturation Mutagenesis Landscape Analysis (SSMuLA)" for our [paper](https://doi.org/10.1101/2024.10.24.619774) titled "Evaluation of Machine Learning-Assisted Directed Evolution Across Diverse Combinatorial Landscapes" 5 | * Data and results can be found at [Zenodo](https://doi.org/10.5281/zenodo.13910506) 6 | ![fig1](fig1.png) 7 | 8 | ### Environment 9 | * For the overall environment `SSMuLA` 10 | ``` 11 | conda env create -f SSMuLA.yml 12 | ``` 13 | * Then install EVmutation from the [develop branch](https://github.com/debbiemarkslab/EVcouplings/archive/develop.zip) after the environment is created 14 | * For the ESM-IF environment 15 | ``` 16 | conda create -n inverse python=3.9 17 | conda activate inverse 18 | conda install pytorch cudatoolkit=11.3 -c pytorch 19 | conda install pyg -c pyg -c conda-forge 20 | conda install pip 21 | pip install biotite 22 | pip install git+https://github.com/facebookresearch/esm.git 23 | ``` 24 | or install the ESM-IF environment `esmif` 25 | ``` 26 | conda env create -f esmif.yml 27 | ``` 28 | * For the CoVES environment `coves` 29 | ``` 30 | conda env create -f coves.yml 31 | ``` 32 | * For installing the Triad command line, see instructions [here](https://triad.protabit.com/api/static/doc/user/userGettingStarted.html) 33 | * For running ESM-2 fine-tuning simulations, use the `finetune.yml` environment 34 | ``` 35 | conda env create -f finetune.yml 36 | ``` 37 | * Frozen environments can be found in `envs/frozen` 38 | 39 | ### Datasets 40 | * The `data/` folder is organized by protein type. Each protein directory contains: 41 | - `.fasta`: FASTA file for the parent sequence 42 | - `.pdb`: PDB file for the parent structure 43 | - `.model`: EVmutation model file 44 | - `fitness_landscape/`: Folder containing CSV files for all fitness landscapes for this protein type, each listing amino acid substitutions and their corresponding fitness values from the original sources 45 | - `scale2max/`: Folder containing the processed fitness csv files returned from the `process_all` function in the `SSMuLA.fitness_process_vis` module, where the maximum fitness value is normalized to 1 for each landscape (an example layout is sketched after the table below) 46 | 47 | * Landscapes are summarized in the table below and described in detail in the paper: 48 | 49 | | Landscape | PDB ID | Sites | 50 | |-----------|--------|------------------------| 51 | | ParD2 | 6X0A | I61, L64, K80 | 52 | | ParD3 | 5CEG | D61, K64, E80 | 53 | | GB1 | 2GI9 | V39, D40, G41, V54 | 54 | | DHFR | 6XG5 | A26, D27, L28 | 55 | | T7 | 1CEZ | N748, R756, Q758 | 56 | | TEV | 1LVM | T146, D148, H167, S170 | 57 | | TrpB3A | 8VHH | A104, E105, T106 | 58 | | TrpB3B | | E105, T106, G107 | 59 | | TrpB3C | | T106, G107, A108 | 60 | | TrpB3D | | T117, A118, A119 | 61 | | TrpB3E | | F184, G185, S186 | 62 | | TrpB3F | | L162, I166, Y301 | 63 | | TrpB3G | | V227, S228, Y301 | 64 | | TrpB3H | | S228, G230, S231 | 65 | | TrpB3I | | Y182, V183, F184 | 66 | | TrpB4 | | V183, F184, V227, S228 | 67 | 68 | 
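For orientation, here is a minimal example layout of one protein directory (the GB1 file names below are illustrative assumptions; the actual names follow the parent protein and the PDB entries listed in the table above):

```
data/GB1/
├── GB1.fasta            # parent (wild-type) sequence
├── GB1.pdb              # parent structure (PDB ID 2GI9 in the table)
├── GB1.model            # EVmutation model parameters
├── fitness_landscape/   # original fitness CSVs from the source publications
└── scale2max/
    └── GB1.csv          # processed fitness CSV with the maximum fitness scaled to 1
```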
69 | ### Preprocessing 70 | * Run 71 | ``` 72 | python -m tests.test_preprocess 73 | ``` 74 | refer to the test file and the script documentation for further details 75 | * Processed with `fitness_process_vis` 76 | * Rename columns to be `AAs`, `AA1`, `AA2`, `AA3`, `AA4`, `fitness`, add `active` if not already there and add `muts` columns 77 | * Scale to `max` (with option to scale to `parent`) 78 | * Processed data saved in `scale2max` folder 79 | * The landscape stats will be saved 80 | 81 | 82 | ### Landscape attributes 83 | #### Local optima 84 | * Run 85 | ``` 86 | python -m tests.test_loc_opt 87 | ``` 88 | refer to the test file and the script documentation for further details 89 | * Calculate local optima with the `calc_local_optima` function in `SSMuLA.local_optima` 90 | 91 | #### Pairwise epistasis 92 | * Run 93 | ``` 94 | python -m tests.test_pairwise_epistasis 95 | ``` 96 | refer to the test file and the script documentation for further details 97 | * Calculate pairwise epistasis with the `calc_all_pairwise_epistasis` function in `SSMuLA.pairwise_epistasis` 98 | * Start from all active variants scaled to max fitness without post filtering 99 | * Initial results will be saved under the default `results/pairwise_epistasis` folder (corresponding to the `active_start` subfolder in the Zenodo repo) 100 | * Post-process the output with the `plot_pairwise_epistasis` function in `SSMuLA.pairwise_epistasis` 101 | * Post-processed results will be saved under the default `results/pairwise_epistasis_dets` folder with summary files (corresponding to the `processed` subfolder) and under `results/pairwise_epistasis_vis` for each landscape, with a master summary file across all landscapes (`pairwise_epistasis_summary.csv`) 102 | 103 | ### Zero-shot 104 | * The current pipeline runs EVmutation and ESM together, and then appends the rest of the zero-shot scores to the resulting CSV (see "Combine all zs" below); a minimal example of the combined EVmutation/ESM step is sketched below 105 | 
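A minimal sketch of invoking this combined EVmutation/ESM step programmatically (the call signature follows `calc_all_zs` in `SSMuLA.zs_calc`; the folder locations shown are assumptions and must match your local layout, in particular where the EVmutation `.model` files live):

```python
from SSMuLA.zs_calc import calc_all_zs

# Score every processed landscape CSV found under data/*/scale2max/ and write
# one CSV of ESM and EVmutation scores per landscape under results/zs/.
calc_all_zs(
    landscape_folder="data",   # parent folder of the per-protein data
    scalefit="scale2max",      # subfolder holding the processed fitness CSVs
    output_folder="results/zs",
    zs_model_names="all",      # or a comma-separated subset, e.g. "esm,ev"
    ev_model_folder="data",    # folder with the <protein>/<protein>.model files
    regen_esm=False,           # reuse cached ESM logits when present
    rerun_zs=False,            # append to an existing zs CSV instead of regenerating it
)
```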
106 | #### EVmutation 107 | * All EVmutation predictions run with [EVcouplings](https://v2.evcouplings.org/) 108 | * All settings remain default 109 | * Model parameters in the `.model` files are downloaded and renamed 110 | 111 | #### ESM 112 | * The logits will be generated and saved in the output folder 113 | * Run 114 | ``` 115 | python -m tests.test_ev_esm 116 | ``` 117 | refer to the test file and the script documentation for further details 118 | 119 | #### Hamming distance 120 | * Directly calculated from `n_mut` 121 | * For Hamming distance testing, run 122 | ``` 123 | python -m tests.test_hd 124 | ``` 125 | to deploy `run_hd_avg_fit` and `run_hd_avg_metric` from `SSMuLA.calc_hd` 126 | refer to the test file and the script documentation for further details 127 | 128 | #### ESM-IF 129 | * Run 130 | ``` 131 | python -m tests.test_esmif 132 | ``` 133 | refer to the test file and the script documentation for further details 134 | * Generate the input fasta files with `get_all_mutfasta` from `SSMuLA.zs_data` to be used in ESM-IF 135 | * Set up the environment for [ESM-IF](https://github.com/facebookresearch/esm?tab=readme-ov-file#invf) with 136 | ``` 137 | conda create -n inverse python=3.9 138 | conda activate inverse 139 | conda install pytorch cudatoolkit=11.3 -c pytorch 140 | conda install pyg -c pyg -c conda-forge 141 | conda install pip 142 | pip install biotite 143 | pip install git+https://github.com/facebookresearch/esm.git 144 | ``` 145 | or use the `esmif.yml` environment file as described above 146 | * Within the `esmif` folder and with the new environment activated, run 147 | ``` 148 | ./esmif.sh 149 | ``` 150 | * ESM-IF results will be saved in the same directory as the `esmif.sh` script 151 | 152 | #### CoVES 153 | * Follow the instructions in the [CoVES repository](https://github.com/ddingding/CoVES/tree/publish) 154 | * Prepare input data in the `coves_data` folder 155 | * Run `run_all_coves` from `SSMuLA.run_coves` to get all scores 156 | * Append scores with `append_all_coves_scores` from `SSMuLA.run_coves` 157 | 158 | #### Triad 159 | * Prep the mutation file in `.mut` format such as `A_1A+A_2A+A_3A+A_4A` with the `TriadGenMutFile` class in `SSMuLA.triad_prepost` 160 | * Run 161 | ``` 162 | python -m tests.test_triad_pre 163 | ``` 164 | refer to the test file and the script documentation for further details 165 | * With the `triad-2.1.3` local command line 166 | * Prepare structure with `2prep_structures.sh` 167 | * Run `3getfixed.sh` 168 | * Parse results with the `ParseTriadResults` class in `SSMuLA.triad_prepost` (see the sketch below) 169 | 170 | 
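A minimal sketch of the Triad pre- and post-processing steps in Python (the class signatures come from `SSMuLA.triad_prepost`; the GB1 file paths are illustrative assumptions and must match your local `data/` and `triad/` folders):

```python
from SSMuLA.triad_prepost import TriadGenMutFile, ParseTriadResults

# Pre: write the .mut input for the Triad command line (saved to triad/GB1/GB1.mut)
mut_file = TriadGenMutFile(input_csv="data/GB1/scale2max/GB1.csv", triad_folder="triad")
print(mut_file.mut_path)

# Post: after running Triad locally, parse the fixed-backbone output into a ranked table
parsed = ParseTriadResults(
    input_csv="data/GB1/scale2max/GB1.csv",
    triad_txt="triad/GB1/GB1_fixed.txt",
    triad_folder="triad",
)
print(parsed.triad_df.head())  # columns: AAs, Triad_score, Triad_rank
```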
`["none", "Triad", "ev", "esm"]` where `none` means not focused training and thus default MLDE runs; the list can be extended for non-Hamming distance ensemble, including `["Triad-esmif", "Triad-ev", "Triad-esm", "two-best"]` 207 | * `ft_lib_fracs`: list of floats for fraction of libraries to use for focused training, i.e. `[0.5, 0.25, 0.125]` 208 | * `encoding`: list of strings for encoding options, i.e. `["one-hot"] + DEFAULT_LEARNED_EMB_COMBO` 209 | * `model_classes`: list of strings for model classes, i.e. `["boosting", "ridge"]` 210 | * `n_samples`: list of integers for number of training samples to use, i.e. `[96, 384]` 211 | * `n_split`: integer for number of splits for cross-validation, i.e. `5` 212 | * `n_replicate`: integer for number of replicates for each model, i.e. `50` 213 | * `n_tops`: integer for number of variants to test the prediction, i.e. `[96, 384]` 214 | refer to the test file and the script documentation for further details 215 | * Run `MLDESum` from `SSMuLA.mlde_analysis` to get the summary dataframe and optional visualization 216 | ``` 217 | python -m tests.test_mlde_vis 218 | ``` 219 | 220 | #### ALDE and ftALDE 221 | * See details in [alde4ssmula](https://github.com/fhalab/alde4ssmula) repository 222 | * `aggregate_alde_df` from `SSMuLA.alde_analysis` to get the summary dataframe 223 | ``` 224 | python -m tests.test_alde 225 | ``` 226 | 227 | #### Fine-tuning 228 | * Run `train_predict_per_protein` from `SSMuLA.plm_finetune` for ESM-2 LoRA fine-tuning simulations 229 | 230 | ### Analysis and paper figures 231 | * All notebooks in `fig_notebooks` are used to reproduce figures in the paper with files downloaded from [Zenodo]((https://doi.org/10.5281/zenodo.13910506)) 232 | 233 | ## Contact 234 | * [Francesca-Zhoufan Li](mailto:fzl@caltech.edu) -------------------------------------------------------------------------------- /SSMuLA/alde_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for combining and analyzing the results of the ALDE analysis. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import os 8 | import pandas as pd 9 | import numpy as np 10 | 11 | 12 | from SSMuLA.landscape_global import N_SAMPLE_LIST, LOWN_DICT 13 | from SSMuLA.zs_analysis import ZS_OPTS, map_zs_labels 14 | 15 | 16 | def avg_alde_df( 17 | eq_n: int, 18 | lib_list: list, 19 | zs: str = "", 20 | alde_model: str = "Boosting Ensemble", 21 | alde_encoding: str = "onehot", 22 | alde_acq: str = "GREEDY", 23 | alde_dir: str = "/disk2/fli/alde4ssmula", 24 | ) -> pd.DataFrame: 25 | 26 | """ 27 | Average ALDE results for a given list of libraries and equal n. 28 | 29 | Args: 30 | - eq_n (int): Equal n for the libraries. 31 | - lib_list (list): List of libraries to aggregate. 32 | 33 | Returns: 34 | - df (pd.DataFrame): Aggregated ALDE results. 
35 | """ 36 | 37 | df = pd.DataFrame( 38 | columns=[ 39 | "n_sample", 40 | "top_maxes_mean", 41 | "top_maxes_std", 42 | "if_truemaxs_mean", 43 | "if_truemaxs_std", 44 | ] 45 | ) 46 | 47 | for n in N_SAMPLE_LIST: 48 | 49 | if zs != "": 50 | zs_append = f"{zs}_" 51 | else: 52 | zs_append = "" 53 | 54 | if eq_n == 1: 55 | csv_path = f"{alde_dir}/results/{zs_append}all_{str(n)}+96/all_results.csv" 56 | 57 | else: 58 | csv_path = f"{alde_dir}/results/{zs_append}{str(eq_n)}eq_{str(int((n+96)/eq_n))}/all_results.csv" 59 | 60 | if os.path.exists(csv_path): 61 | a_df = pd.read_csv(csv_path) 62 | 63 | # Get the max Timestep for each Protein 64 | max_timesteps = a_df.groupby("Protein")["Timestep"].transform("max") 65 | # DNN Ensemble 66 | # Boosting Ensemble 67 | slice_df = a_df[ 68 | (a_df["Encoding"] == alde_encoding) 69 | & (a_df["Acquisition"] == alde_acq) 70 | & (a_df["Model"] == alde_model) 71 | & (a_df["Protein"].isin(lib_list)) 72 | & (a_df["Timestep"] == max_timesteps) 73 | ] 74 | # for each Protein take the max of the timestep 75 | 76 | if len(lib_list) == 1: 77 | top_maxes_std = slice_df["Std"].mean() 78 | if_truemaxs_std = 0 79 | else: 80 | top_maxes_std = slice_df["Mean"].std() 81 | if_truemaxs_std = slice_df["Frac"].std() 82 | 83 | df = df._append( 84 | { 85 | "n_sample": n, 86 | "top_maxes_mean": slice_df["Mean"].mean(), 87 | "top_maxes_std": top_maxes_std, 88 | "if_truemaxs_mean": slice_df["Frac"].mean(), 89 | "if_truemaxs_std": if_truemaxs_std, 90 | }, 91 | ignore_index=True, 92 | ) 93 | elif "ds-ed" in csv_path: 94 | continue 95 | else: 96 | print(f"File not found: {csv_path}") 97 | 98 | df = df._append( 99 | { 100 | "n_sample": n, 101 | "top_maxes_mean": np.nan, 102 | "top_maxes_std": np.nan, 103 | "if_truemaxs_mean": np.nan, 104 | "if_truemaxs_std": np.nan, 105 | }, 106 | ignore_index=True, 107 | ) 108 | 109 | return df.set_index("n_sample") 110 | 111 | 112 | def aggregate_alde_df( 113 | eq_ns: list[int] = [1, 2, 3, 4], 114 | n_list: list[int] = N_SAMPLE_LIST, 115 | zs_opts: list[str] = ["esmif", "ev", "coves", "ed", "esm", "Triad", ""], 116 | alde_dir: str = "/disk2/fli/alde4ssmula", 117 | alde_res_folder: str = "results", 118 | alde_df_path: str = "results/alde/alde_all.csv", 119 | ) -> pd.DataFrame: 120 | 121 | """ 122 | Aggregate ALDE results for a given list of libraries and equal n. 123 | 124 | Args: 125 | - eq_ns (list): List of equal n values. 126 | - zs_opts (list): List of zero-shot options. 127 | - alde_dir (str): Directory containing ALDE results. 128 | - alde_df_path (str): Path to save the aggregated ALDE results. 129 | 130 | Returns: 131 | - df (pd.DataFrame): Aggregated ALDE results. 
132 | """ 133 | 134 | # initialize the dataframe 135 | alde_all = pd.DataFrame( 136 | columns=[ 137 | "n_mut_cutoff", 138 | "zs", 139 | "rounds", 140 | "n_samples", 141 | "Protein", 142 | "Encoding", 143 | "Model", 144 | "Acquisition", 145 | "Timestep", 146 | "Mean", 147 | "Std", 148 | "Frac", 149 | ] 150 | ) 151 | 152 | for eq_n in eq_ns: 153 | 154 | for zs in zs_opts + ["ds-" + z for z in zs_opts if z != ""]: 155 | 156 | if "ds-" in zs: 157 | n_mut = "double" 158 | else: 159 | n_mut = "all" 160 | 161 | for n in n_list: 162 | 163 | if isinstance(n, int): 164 | dir_det = str(int((n + 96) / eq_n)) 165 | n_sample = n + 96 166 | else: 167 | dir_det = n 168 | n_sample = LOWN_DICT[n] 169 | 170 | if zs != "": 171 | zs_append = f"{zs}_" 172 | else: 173 | zs_append = "" 174 | 175 | if eq_n == 1: 176 | csv_path = f"{alde_dir}/{alde_res_folder}/{zs_append}all_{str(n)}+96/all_results.csv" 177 | 178 | else: 179 | csv_path = f"{alde_dir}/{alde_res_folder}/{zs_append}{str(eq_n)}eq_{dir_det}/all_results.csv" 180 | 181 | if os.path.exists(csv_path): 182 | print(f"Reading {csv_path}...") 183 | a_df = pd.read_csv(csv_path) 184 | 185 | max_timesteps = a_df.groupby("Protein")["Timestep"].transform("max") 186 | slice_df = a_df[a_df["Timestep"] == max_timesteps].copy() 187 | 188 | slice_df["n_mut_cutoff"] = n_mut 189 | slice_df["zs"] = zs 190 | slice_df["rounds"] = eq_n 191 | slice_df["n_samples"] = n_sample 192 | 193 | # replace T7_2 with T7 194 | # slice_df = slice_df.replace("T7_2", "T7") 195 | 196 | alde_all = alde_all._append(slice_df, ignore_index=True) 197 | 198 | else: 199 | print(f"File not found: {csv_path}") 200 | 201 | alde_all = alde_all._append( 202 | { 203 | "n_mut_cutoff": n_mut, 204 | "zs": zs, 205 | "rounds": eq_n, 206 | "n_samples": n_sample, 207 | "Protein": np.nan, 208 | "Encoding": np.nan, 209 | "Model": np.nan, 210 | "Acquisition": np.nan, 211 | "Timestep": np.nan, 212 | "Mean": np.nan, 213 | "Std": np.nan, 214 | "Frac": np.nan, 215 | }, 216 | ignore_index=True, 217 | ) 218 | 219 | alde_all = alde_all.dropna(subset=["Protein"]) 220 | 221 | alde_all.to_csv(alde_df_path, index=False) 222 | 223 | return alde_all 224 | 225 | 226 | def get_ftalde_libavg( 227 | alde_csv: str, 228 | lib_list: list, 229 | n_total: int, 230 | n_round: int, 231 | models: list = ["Boosting Ensemble"], 232 | acquisition: list = ["GREEDY"], 233 | ): 234 | 235 | """ 236 | Get the FT-ALDE data for each of the library, number of rounds, models, and acquisition method. 237 | Args: 238 | alde_csv (str): Path to the FT-ALDE CSV file. 239 | lib_list (list): List of libraries to filter. 240 | n_total (int): Total number of samples. 241 | n_round (int): Number of rounds. 242 | models (list): List of models to filter. 243 | acquisition (list): List of acquisition methods to filter. 244 | Returns: 245 | pd.DataFrame: Filtered DataFrame containing FT-ALDE data. 
246 | """ 247 | # have all none zs and zs opt for MLDE, ALDE different rounds 248 | alde_all = pd.read_csv(alde_csv) 249 | # Replace NaN values in column 'zs' with the string "none" 250 | alde_all["zs"] = alde_all["zs"].fillna("none") 251 | 252 | slice_df = alde_all[ 253 | (alde_all["rounds"] == n_round) 254 | & (alde_all["Encoding"] == "onehot") 255 | & (alde_all["Model"].isin(models)) 256 | & (alde_all["Acquisition"].isin(acquisition)) 257 | & (alde_all["n_samples"] == n_total) 258 | & (alde_all["Protein"].isin(lib_list)) 259 | # & (alde_all["n_mut_cutoff"] == "all") 260 | ].copy() 261 | 262 | # Convert 'Category' column to categorical with defined order 263 | slice_df["zs"] = pd.Categorical( 264 | slice_df["zs"], 265 | categories=["none"] 266 | + [o.replace("_score", "") for o in ZS_OPTS] 267 | + [ 268 | "ds-esmif", 269 | "ds-ev", 270 | "ds-coves", 271 | "ds-Triad", 272 | "ds-esm", 273 | ], 274 | ordered=True, 275 | ) 276 | 277 | slice_df = slice_df.sort_values(by=["zs", "Protein"]) 278 | 279 | slice_df["zs"] = slice_df["zs"].apply(map_zs_labels) 280 | 281 | return ( 282 | slice_df[["Protein", "zs", "Mean", "Frac"]] 283 | .rename(columns={"Protein": "lib", "Mean": "top_maxes", "Frac": "if_truemaxs"}) 284 | .copy() 285 | ) 286 | 287 | 288 | def clean_alde_df( 289 | agg_alde_df_path: str = "results/alde/alde_all.csv", 290 | clean_alde_df_path: str = "results/alde/alde_results.csv", 291 | ): 292 | """ 293 | A function to clean up the aggregated ALDE results. 294 | """ 295 | 296 | alde_df = pd.read_csv(agg_alde_df_path) 297 | alde_df[ 298 | (alde_df["rounds"].isin([2, 3, 4])) 299 | & (alde_df["Model"].isin(["Boosting Ensemble", "DNN Ensemble"])) 300 | & (alde_df["Acquisition"] == "GREEDY") 301 | ].rename( 302 | columns={ 303 | "Protein": "lib", 304 | "Mean": "top_maxes_mean", 305 | "Std": "top_maxes_std", 306 | "Frac": "if_truemaxs_mean", 307 | "Encoding": "encoding", 308 | "Model": "model", 309 | "n_samples": "n_sample", 310 | } 311 | )[ 312 | [ 313 | "encoding", 314 | "model", 315 | "n_sample", 316 | "top_maxes_mean", 317 | "top_maxes_std", 318 | "if_truemaxs_mean", 319 | "n_mut_cutoff", 320 | "lib", 321 | "zs", 322 | "rounds", 323 | ] 324 | ].reset_index( 325 | drop=True 326 | ).to_csv( 327 | clean_alde_df_path, index=False 328 | ) -------------------------------------------------------------------------------- /SSMuLA/triad_prepost.py: -------------------------------------------------------------------------------- 1 | """A script for generating mut file needed for triad""" 2 | 3 | import re 4 | import os 5 | 6 | from copy import deepcopy 7 | from glob import glob 8 | 9 | import pandas as pd 10 | import numpy as np 11 | from itertools import product 12 | 13 | from SSMuLA.aa_global import ALL_AAS 14 | from SSMuLA.landscape_global import LIB_INFO_DICT, LIB_NAMES, TrpB_names 15 | from SSMuLA.util import checkNgen_folder, get_file_name 16 | 17 | 18 | # TrpB_TRIAD_FOLDER = "/home/shared_data/triad_structures" 19 | TrpB_LIB_FOLDER = "data/TrpB/scaled2max" 20 | # TrpB3_TRIAD_TXT = deepcopy(sorted(list(glob(f"{TrpB_TRIAD_FOLDER}/*3*/*/*.txt")))) 21 | TrpB4_TRIAD_TXT = deepcopy(sorted(list(glob("triad/TrpB4/*.txt")))) 22 | # /disk2/fli/SSMuLA/triad/TrpB4 23 | 24 | lib_triad_pair = {} 25 | 26 | # append the other two lib 27 | for lib in LIB_NAMES: 28 | if lib != "TrpB4": 29 | lib_triad_pair[ 30 | f"data/{lib}/scaled2max/{lib}.csv" 31 | ] = f"triad/{lib}/{lib}_fixed.txt" 32 | 33 | SORTED_LIB_TRIAD_PAIR = deepcopy( 34 | dict(sorted(lib_triad_pair.items(), key=lambda x: x[0])) 35 | ) 36 | 37 | 38 | 
class TriadLib: 39 | """ 40 | A class for common traid things for a given lib 41 | """ 42 | 43 | def __init__(self, input_csv: str, triad_folder: str = "triad") -> None: 44 | 45 | """ 46 | Args: 47 | - input_csv: str, the path to the input csv 48 | - output_folder: str, the path to the output folder 49 | """ 50 | 51 | self._input_csv = input_csv 52 | self._triad_folder = os.path.normpath(triad_folder) 53 | 54 | @property 55 | def lib_name(self) -> str: 56 | 57 | """ 58 | A property for the library name 59 | """ 60 | return get_file_name(self._input_csv) 61 | 62 | @property 63 | def site_num(self) -> int: 64 | 65 | """ 66 | A property for the site number 67 | """ 68 | return len(LIB_INFO_DICT[self.lib_name]["positions"]) 69 | 70 | @property 71 | def wt_aas(self) -> list: 72 | """ 73 | A property for the wildtype amino acids 74 | """ 75 | return list(LIB_INFO_DICT[self.lib_name]["AAs"].values()) 76 | 77 | @property 78 | def prefixes(self) -> list: 79 | """ 80 | A property for the prefixes 81 | """ 82 | return [ 83 | f"A_{pos}" for pos in LIB_INFO_DICT[self.lib_name]["positions"].values() 84 | ] 85 | 86 | @property 87 | def df(self) -> pd.DataFrame: 88 | """ 89 | A property for the dataframe and drop stop codons 90 | """ 91 | return pd.read_csv(self._input_csv) 92 | 93 | @property 94 | def df_no_stop(self) -> pd.DataFrame: 95 | """ 96 | A property for the dataframe and drop stop codons 97 | """ 98 | return self.df[~self.df["AAs"].str.contains("\*")] 99 | 100 | @property 101 | def variants(self) -> list: 102 | """ 103 | A AA sequence for the variants 104 | """ 105 | return self.df_no_stop["AAs"].values.tolist() 106 | 107 | @property 108 | def mut_numb(self) -> int: 109 | """ 110 | A property for the number of mutations 111 | """ 112 | return len(self.df_no_stop) 113 | 114 | 115 | class TriadGenMutFile(TriadLib): 116 | """ 117 | A class for generating a mut file for triad 118 | """ 119 | 120 | def __init__(self, input_csv: str, triad_folder: str = "triad") -> None: 121 | 122 | """ 123 | Args: 124 | - input_csv: str, the path to the input csv 125 | - output_folder: str, the path to the output folder 126 | """ 127 | 128 | super().__init__(input_csv, triad_folder) 129 | 130 | print(f"Generating {self.mut_path} from {self._input_csv}...") 131 | self._mutation_encodings = self._generate_mut_file() 132 | 133 | def _generate_mut_file(self) -> None: 134 | """ 135 | Generate the mut file 136 | """ 137 | 138 | # Loop over variants 139 | mutation_encodings = [] 140 | 141 | for variant in self.variants: 142 | 143 | # Loop over each character in the variant 144 | mut_encoding_list = [] 145 | for j, (var_char, wt_char) in enumerate(zip(variant, self.wt_aas)): 146 | 147 | # If the var_char does not equal the wt_char, append 148 | if var_char != wt_char: 149 | mut_encoding_list.append(self.prefixes[j] + var_char) 150 | 151 | # If the mut_encoding_list has no entries, continue (this is wild type) 152 | if len(mut_encoding_list) == 0: 153 | continue 154 | 155 | # Otherwise, append to mutation_encodings 156 | else: 157 | mutation_encodings.append("+".join(mut_encoding_list) + "\n") 158 | 159 | # check before saving 160 | # assert len(mutation_encodings) == self.mut_numb - 1 161 | 162 | # Save the mutants 163 | with open(self.mut_path, "w") as f: 164 | f.writelines(mutation_encodings) 165 | 166 | return mutation_encodings 167 | 168 | @property 169 | def mut_path(self) -> str: 170 | """ 171 | A property for the mut file path 172 | """ 173 | sub_folder = checkNgen_folder(os.path.join(self._triad_folder, 
self.lib_name)) 174 | return os.path.join(sub_folder, f"{self.lib_name}.mut") 175 | 176 | @property 177 | def mut_encoding(self) -> list: 178 | """ 179 | A property for the mutation encodings 180 | """ 181 | return self._mutation_encodings 182 | 183 | 184 | class ParseTriadResults(TriadLib): 185 | """ 186 | A class for parsing the triad results 187 | """ 188 | 189 | def __init__( 190 | self, 191 | input_csv: str, 192 | triad_txt: str, 193 | triad_folder: str = "triad", 194 | ) -> None: 195 | 196 | """ 197 | Args: 198 | - input_csv: str, the path to the input csv 199 | - triad_txt: str, the path to the triad txt file 200 | - triad_folder: str, the parent folder to all triad data 201 | """ 202 | 203 | super().__init__(input_csv, triad_folder) 204 | 205 | self._triad_txt = triad_txt 206 | 207 | print(f"Parsing {self._triad_txt} and save to {self.triad_csv}...") 208 | 209 | # extract triad score into dataframe 210 | self._triad_df = self._get_triad_score() 211 | 212 | # save the triad dataframe unless trpb4 as need to NOT overwrite 213 | if self.lib_name != "TrpB4": 214 | self._triad_df.to_csv(self.triad_csv, index=False) 215 | 216 | def _get_triad_score(self) -> float: 217 | 218 | """ 219 | A function to load the output of a triad analysis and get a score 220 | 221 | Args: 222 | - triad_output_file: str, the path to the triad output file 223 | - WT_combo: str, the wildtype combo 224 | - num_seqs: int, the number of sequences to load 225 | """ 226 | 227 | # Load the output file 228 | with open(self._triad_txt) as f: 229 | 230 | # Set some flags for starting analysis 231 | solutions_started = False 232 | record_start = False 233 | 234 | # Begin looping over the file 235 | summary_lines = [] 236 | for line in f: 237 | 238 | # Start looking at data once we hit "solution" 239 | # if "Solution" in line: 240 | if "All sequences:" in line: 241 | solutions_started = True 242 | 243 | # Once we have "Index" we can start recording the rest 244 | if solutions_started and "Index" in line: 245 | record_start = True 246 | 247 | # Record appropriate data 248 | if record_start: 249 | 250 | # Strip the newline and split on whitespace 251 | summary_line = line.strip().split() 252 | 253 | if summary_line[0] == "Average": 254 | break 255 | else: 256 | # Otherwise, append the line 257 | summary_lines.append(summary_line) 258 | 259 | # Build the dataframe with col ['Index', 'Tags', 'Score', 'Seq', 'Muts'] 260 | all_results = pd.DataFrame(summary_lines[1:], columns=summary_lines[0]) 261 | all_results["Triad_score"] = all_results["Score"].astype(float) 262 | 263 | wt_chars = self.wt_aas 264 | reconstructed_combos = [ 265 | "".join( 266 | [char if char != "-" else wt_chars[i] for i, char in enumerate(seq)] 267 | ) 268 | for seq in all_results.Seq.values 269 | ] 270 | all_results["AAs"] = reconstructed_combos 271 | 272 | # Get the order 273 | all_results["Triad_rank"] = np.arange(1, len(all_results) + 1) 274 | 275 | return all_results[["AAs", "Triad_score", "Triad_rank"]] 276 | 277 | @property 278 | def triad_csv(self) -> str: 279 | """ 280 | A property for the triad csv 281 | """ 282 | return os.path.join(self._triad_folder, self.lib_name, f"{self.lib_name}.csv") 283 | 284 | @property 285 | def triad_df(self) -> pd.DataFrame: 286 | """ 287 | A property for the triad dataframe 288 | """ 289 | return self._triad_df 290 | 291 | 292 | def run_traid_gen_mut_file(all_lib: bool = True, lib_list: list[str] = []): 293 | """ 294 | Run the triad gen mut file function for all libraries 295 | 296 | Args: 297 | - all_lib: bool, 
whether to run for all libraries 298 | - lib_list: list, a list of libraries to run for 299 | """ 300 | 301 | if all_lib or len(lib_list) == 0: 302 | lib_list = glob("data/*/scale2max/*.csv") 303 | 304 | for lib in lib_list: 305 | TriadGenMutFile(input_csv=lib) 306 | 307 | 308 | def run_parse_triad_results( 309 | triad_folder: str = "triad", 310 | all_lib: bool = True, 311 | lib_list: list[str] = [] 312 | ): 313 | 314 | """ 315 | Run the parse triad results function for all libraries 316 | 317 | Args: 318 | - triad_folder: str, the parent folder to all triad data 319 | - all_lib: bool, whether to run for all libraries 320 | - lib_list: list, a list of libraries to run for 321 | """ 322 | 323 | for lib, triad_txt in SORTED_LIB_TRIAD_PAIR.items(): 324 | ParseTriadResults(input_csv=lib, triad_txt=triad_txt, triad_folder=triad_folder) 325 | 326 | # need to merge 327 | trpb4_dfs = [] 328 | print(f"Parsing TrpB4 {TrpB4_TRIAD_TXT}...") 329 | 330 | for triad_txt in TrpB4_TRIAD_TXT: 331 | # ignore rank for now 332 | trpb4_dfs.append( 333 | ParseTriadResults( 334 | input_csv=os.path.join(TrpB_LIB_FOLDER, "TrpB4.csv"), 335 | triad_txt=triad_txt, 336 | triad_folder=triad_folder, 337 | ) 338 | .triad_df[["AAs", "Triad_score"]] 339 | .copy() 340 | ) 341 | 342 | # there will be multip wt from each file so need to drop them 343 | trpb4_df = pd.concat(trpb4_dfs).drop_duplicates().sort_values(["Triad_score"]) 344 | 345 | # resort and overwrite 346 | trpb4_df["Triad_rank"] = np.arange(1, len(trpb4_df) + 1) 347 | print("Added new rank for TrpB4") 348 | 349 | trpb4_df.to_csv(os.path.join(triad_folder, "TrpB4", "TrpB4.csv"), index=False) -------------------------------------------------------------------------------- /envs/finetune.yml: -------------------------------------------------------------------------------- 1 | # conda env update --file finetune.yml --prune 2 | 3 | name: finetune 4 | channels: 5 | - pytorch 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - abseil-cpp=20211102.0=hd4dd3e8_0 13 | - absl-py=1.3.0=py39h06a4308_0 14 | - aiohttp=3.8.3=py39h5eee18b_0 15 | - aiosignal=1.2.0=pyhd3eb1b0_0 16 | - anyio=3.5.0=py39h06a4308_0 17 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 18 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0 19 | - asttokens=2.0.5=pyhd3eb1b0_0 20 | - astunparse=1.6.3=py_0 21 | - async-timeout=4.0.2=py39h06a4308_0 22 | - attrs=22.1.0=py39h06a4308_0 23 | - babel=2.11.0=py39h06a4308_0 24 | - backcall=0.2.0=pyhd3eb1b0_0 25 | - beautifulsoup4=4.11.1=py39h06a4308_0 26 | - biopython=1.83=py39hd1e30aa_0 27 | - blas=1.0=mkl 28 | - bleach=4.1.0=pyhd3eb1b0_0 29 | - blinker=1.4=py39h06a4308_0 30 | - brotli=1.1.0=hd590300_1 31 | - brotli-bin=1.1.0=hd590300_1 32 | - brotlipy=0.7.0=py39h27cfd23_1003 33 | - c-ares=1.18.1=h7f8727e_0 34 | - ca-certificates=2024.2.2=hbcca054_0 35 | - cachetools=4.2.2=pyhd3eb1b0_0 36 | - certifi=2024.2.2=pyhd8ed1ab_0 37 | - cffi=1.15.1=py39h5eee18b_3 38 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 39 | - click=8.0.4=py39h06a4308_0 40 | - comm=0.1.2=py39h06a4308_0 41 | - cryptography=38.0.4=py39h9ce1e76_0 42 | - cuda=11.7.1=0 43 | - cuda-cccl=11.7.91=0 44 | - cuda-command-line-tools=11.7.1=0 45 | - cuda-compiler=11.7.1=0 46 | - cuda-cudart=11.7.99=0 47 | - cuda-cudart-dev=11.7.99=0 48 | - cuda-cuobjdump=11.7.91=0 49 | - cuda-cupti=11.7.101=0 50 | - cuda-cuxxfilt=11.7.91=0 51 | - cuda-demo-suite=12.0.140=0 52 | - cuda-documentation=12.0.140=0 53 | - cuda-driver-dev=11.7.99=0 54 | - 
cuda-gdb=12.0.140=0 55 | - cuda-libraries=11.7.1=0 56 | - cuda-libraries-dev=11.7.1=0 57 | - cuda-memcheck=11.8.86=0 58 | - cuda-nsight=12.0.140=0 59 | - cuda-nsight-compute=12.0.1=0 60 | - cuda-nvcc=11.7.99=0 61 | - cuda-nvdisasm=12.0.140=0 62 | - cuda-nvml-dev=11.7.91=0 63 | - cuda-nvprof=11.7.101=0 64 | - cuda-nvprune=11.7.91=0 65 | - cuda-nvrtc=11.7.99=0 66 | - cuda-nvrtc-dev=11.7.99=0 67 | - cuda-nvtx=11.7.91=0 68 | - cuda-nvvp=12.0.146=0 69 | - cuda-runtime=11.7.1=0 70 | - cuda-sanitizer-api=12.0.140=0 71 | - cuda-toolkit=11.7.1=0 72 | - cuda-tools=11.7.1=0 73 | - cuda-visual-tools=11.7.1=0 74 | - cudatoolkit=11.7.0=hd8887f6_11 75 | - cudatoolkit-dev=11.7.0=h1de0b5d_6 76 | - cudnn=8.4.1.50=hed8a83a_0 77 | - debugpy=1.5.1=py39h295c915_0 78 | - decorator=5.1.1=pyhd3eb1b0_0 79 | - defusedxml=0.7.1=pyhd3eb1b0_0 80 | - entrypoints=0.4=py39h06a4308_0 81 | - et_xmlfile=1.1.0=py39h06a4308_0 82 | - executing=0.8.3=pyhd3eb1b0_0 83 | - fftw=3.3.9=h27cfd23_1 84 | - flit-core=3.6.0=pyhd3eb1b0_0 85 | - freetype=2.12.1=hca18f0e_0 86 | - frozenlist=1.3.3=py39h5eee18b_0 87 | - gast=0.4.0=pyhd3eb1b0_0 88 | - gds-tools=1.5.1.14=0 89 | - giflib=5.2.1=h5eee18b_1 90 | - google-auth=2.6.0=pyhd3eb1b0_0 91 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 92 | - google-pasta=0.2.0=pyhd3eb1b0_0 93 | - grpc-cpp=1.46.4=h6fc47f4_3 94 | - grpcio=1.46.4=py39h4587e31_3 95 | - h5py=3.7.0=py39h737f45e_0 96 | - hdf5=1.10.6=h3ffc7dd_1 97 | - icu=70.1=h27087fc_0 98 | - idna=3.4=py39h06a4308_0 99 | - importlib-metadata=4.11.3=py39h06a4308_0 100 | - intel-openmp=2021.4.0=h06a4308_3561 101 | - ipykernel=6.19.2=py39hb070fc8_0 102 | - ipython=8.9.0=py39h06a4308_0 103 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 104 | - ipywidgets=8.0.4=pyhd8ed1ab_0 105 | - jedi=0.18.1=py39h06a4308_1 106 | - jinja2=3.1.2=py39h06a4308_0 107 | - joblib=1.1.1=py39h06a4308_0 108 | - jpeg=9e=h7f8727e_0 109 | - json5=0.9.6=pyhd3eb1b0_0 110 | - jsonschema=4.16.0=py39h06a4308_0 111 | - jupyter_client=7.4.9=py39h06a4308_0 112 | - jupyter_contrib_core=0.4.0=pyhd8ed1ab_0 113 | - jupyter_contrib_nbextensions=0.7.0=pyhd8ed1ab_0 114 | - jupyter_core=5.1.1=py39h06a4308_0 115 | - jupyter_highlight_selected_word=0.2.0=py39hf3d152e_1005 116 | - jupyter_latex_envs=1.4.6=pyhd8ed1ab_1002 117 | - jupyter_nbextensions_configurator=0.6.1=pyhd8ed1ab_0 118 | - jupyter_server=1.23.4=py39h06a4308_0 119 | - jupyterlab=3.5.3=py39h06a4308_0 120 | - jupyterlab_pygments=0.1.2=py_0 121 | - jupyterlab_server=2.16.5=py39h06a4308_0 122 | - jupyterlab_widgets=3.0.5=pyhd8ed1ab_0 123 | - keras=2.9.0=py39h06a4308_0 124 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 125 | - krb5=1.19.4=h568e23c_0 126 | - lcms2=2.12=h3be6417_0 127 | - ld_impl_linux-64=2.38=h1181459_1 128 | - lerc=3.0=h295c915_0 129 | - libabseil=20211102.0=cxx17_h48a1fff_2 130 | - libbrotlicommon=1.1.0=hd590300_1 131 | - libbrotlidec=1.1.0=hd590300_1 132 | - libbrotlienc=1.1.0=hd590300_1 133 | - libcublas=11.10.3.66=0 134 | - libcublas-dev=11.10.3.66=0 135 | - libcufft=10.7.2.124=h4fbf590_0 136 | - libcufft-dev=10.7.2.124=h98a8f43_0 137 | - libcufile=1.5.1.14=0 138 | - libcufile-dev=1.5.1.14=0 139 | - libcurand=10.3.1.124=0 140 | - libcurand-dev=10.3.1.124=0 141 | - libcurl=7.87.0=h91b91d3_0 142 | - libcusolver=11.4.0.1=0 143 | - libcusolver-dev=11.4.0.1=0 144 | - libcusparse=11.7.4.91=0 145 | - libcusparse-dev=11.7.4.91=0 146 | - libdeflate=1.17=h5eee18b_1 147 | - libedit=3.1.20221030=h5eee18b_0 148 | - libev=4.33=h7f8727e_1 149 | - libffi=3.4.2=h6a678d5_6 150 | - libgcc-ng=12.2.0=h65d4601_19 151 | - 
libgfortran-ng=11.2.0=h00389a5_1 152 | - libgfortran5=11.2.0=h1234567_1 153 | - libiconv=1.17=h166bdaf_0 154 | - libnghttp2=1.46.0=hce63b2e_0 155 | - libnpp=11.7.4.75=0 156 | - libnpp-dev=11.7.4.75=0 157 | - libnvjpeg=11.8.0.2=0 158 | - libnvjpeg-dev=11.8.0.2=0 159 | - libpng=1.6.37=hbc83047_0 160 | - libprotobuf=3.20.3=he621ea3_0 161 | - libsodium=1.0.18=h7b6447c_0 162 | - libssh2=1.10.0=h8f2d780_0 163 | - libstdcxx-ng=12.2.0=h46fd767_19 164 | - libtiff=4.5.1=h6a678d5_0 165 | - libwebp-base=1.3.2=hd590300_0 166 | - libxml2=2.10.3=h7463322_0 167 | - libxslt=1.1.37=h873f0b0_0 168 | - libzlib=1.2.13=h166bdaf_4 169 | - llvm-openmp=14.0.6=h9e868ea_0 170 | - lxml=4.9.2=py39h14694de_0 171 | - markdown=3.4.1=py39h06a4308_0 172 | - markupsafe=2.1.1=py39h7f8727e_0 173 | - matplotlib-base=3.5.3=py39h19d6b11_2 174 | - matplotlib-inline=0.1.6=py39h06a4308_0 175 | - mistune=0.8.4=py39h27cfd23_1000 176 | - mkl=2021.4.0=h06a4308_640 177 | - mkl-service=2.4.0=py39h7f8727e_0 178 | - mkl_fft=1.3.1=py39hd3c417c_0 179 | - mkl_random=1.2.2=py39h51133e4_0 180 | - multidict=6.0.2=py39h5eee18b_0 181 | - munkres=1.1.4=pyh9f0ad1d_0 182 | - nb_conda=2.2.1=unix_6 183 | - nb_conda_kernels=2.3.1=py39h06a4308_0 184 | - nbclassic=0.4.8=py39h06a4308_0 185 | - nbclient=0.5.13=py39h06a4308_0 186 | - nbconvert=6.4.4=py39h06a4308_0 187 | - nbformat=5.7.0=py39h06a4308_0 188 | - nccl=2.14.3.1=h0800d71_0 189 | - ncurses=6.4=h6a678d5_0 190 | - nest-asyncio=1.5.6=py39h06a4308_0 191 | - notebook=6.5.2=py39h06a4308_0 192 | - notebook-shim=0.2.2=py39h06a4308_0 193 | - nsight-compute=2022.4.1.6=0 194 | - numpy=1.22.3=py39he7a7128_0 195 | - numpy-base=1.22.3=py39hf524024_0 196 | - oauthlib=3.2.1=py39h06a4308_0 197 | - openjpeg=2.4.0=h3ad879b_0 198 | - openpyxl=3.0.10=py39h5eee18b_0 199 | - openssl=1.1.1w=hd590300_0 200 | - opt_einsum=3.3.0=pyhd3eb1b0_1 201 | - packaging=22.0=py39h06a4308_0 202 | - pandocfilters=1.5.0=pyhd3eb1b0_0 203 | - parso=0.8.3=pyhd3eb1b0_0 204 | - patsy=0.5.6=pyhd8ed1ab_0 205 | - pexpect=4.8.0=pyhd3eb1b0_3 206 | - pickleshare=0.7.5=pyhd3eb1b0_1003 207 | - pip=22.3.1=py39h06a4308_0 208 | - platformdirs=2.5.2=py39h06a4308_0 209 | - prometheus_client=0.14.1=py39h06a4308_0 210 | - prompt-toolkit=3.0.36=py39h06a4308_0 211 | - protobuf=3.20.3=py39h6a678d5_0 212 | - psutil=5.9.0=py39h5eee18b_0 213 | - ptyprocess=0.7.0=pyhd3eb1b0_2 214 | - pure_eval=0.2.2=pyhd3eb1b0_0 215 | - pyasn1=0.4.8=pyhd3eb1b0_0 216 | - pyasn1-modules=0.2.8=py_0 217 | - pycparser=2.21=pyhd3eb1b0_0 218 | - pygments=2.11.2=pyhd3eb1b0_0 219 | - pyjwt=2.4.0=py39h06a4308_0 220 | - pyopenssl=22.0.0=pyhd3eb1b0_0 221 | - pyrsistent=0.18.0=py39heee7806_0 222 | - pysocks=1.7.1=py39h06a4308_0 223 | - python=3.9.16=h7a1cb2a_0 224 | - python-dateutil=2.8.2=pyhd3eb1b0_0 225 | - python-fastjsonschema=2.16.2=py39h06a4308_0 226 | - python-flatbuffers=1.12=pyhd3eb1b0_0 227 | - python-tzdata=2024.1=pyhd8ed1ab_0 228 | - python_abi=3.9=2_cp39 229 | - pytorch=1.13.1=py3.9_cuda11.7_cudnn8.5.0_0 230 | - pytorch-cuda=11.7=h67b0de4_1 231 | - pytorch-mutex=1.0=cuda 232 | - pyyaml=6.0=py39h5eee18b_1 233 | - pyzmq=23.2.0=py39h6a678d5_0 234 | - re2=2022.06.01=h27087fc_1 235 | - readline=8.2=h5eee18b_0 236 | - requests=2.28.1=py39h06a4308_0 237 | - requests-oauthlib=1.3.0=py_0 238 | - rsa=4.7.2=pyhd3eb1b0_1 239 | - scikit-learn=1.2.0=py39h6a678d5_0 240 | - scipy=1.7.3=py39h6c91a56_2 241 | - seaborn=0.13.2=hd8ed1ab_0 242 | - seaborn-base=0.13.2=pyhd8ed1ab_0 243 | - send2trash=1.8.0=pyhd3eb1b0_1 244 | - sentencepiece=0.1.95=py39hd09550d_0 245 | - setuptools=65.6.3=py39h06a4308_0 
246 | - six=1.16.0=pyhd3eb1b0_1 247 | - snappy=1.1.9=h295c915_0 248 | - sniffio=1.2.0=py39h06a4308_1 249 | - soupsieve=2.3.2.post1=py39h06a4308_0 250 | - sqlite=3.40.1=h5082296_0 251 | - stack_data=0.2.0=pyhd3eb1b0_0 252 | - statsmodels=0.14.0=py39ha9d4c09_0 253 | - tensorboard=2.9.0=py39h06a4308_0 254 | - tensorboard-data-server=0.6.1=py39h52d8a92_0 255 | - tensorboard-plugin-wit=1.8.1=py39h06a4308_0 256 | - tensorflow=2.9.1=cuda112py39h01bd6f0_0 257 | - tensorflow-base=2.9.1=cuda112py39h81abfd3_0 258 | - tensorflow-estimator=2.9.1=cuda112py39hd320b7a_0 259 | - termcolor=2.1.0=py39h06a4308_0 260 | - terminado=0.17.1=py39h06a4308_0 261 | - testpath=0.6.0=py39h06a4308_0 262 | - threadpoolctl=2.2.0=pyh0d69192_0 263 | - tk=8.6.12=h1ccaba5_0 264 | - tomli=2.0.1=py39h06a4308_0 265 | - tornado=6.2=py39h5eee18b_0 266 | - traitlets=5.7.1=py39h06a4308_0 267 | - typing-extensions=4.4.0=py39h06a4308_0 268 | - typing_extensions=4.4.0=py39h06a4308_0 269 | - tzdata=2022g=h04d1e81_0 270 | - unicodedata2=15.1.0=py39hd1e30aa_0 271 | - urllib3=1.26.14=py39h06a4308_0 272 | - wcwidth=0.2.5=pyhd3eb1b0_0 273 | - webencodings=0.5.1=py39h06a4308_1 274 | - websocket-client=0.58.0=py39h06a4308_4 275 | - werkzeug=2.2.2=py39h06a4308_0 276 | - wheel=0.37.1=pyhd3eb1b0_0 277 | - widgetsnbextension=4.0.5=pyhd8ed1ab_0 278 | - wrapt=1.14.1=py39h5eee18b_0 279 | - xz=5.2.10=h5eee18b_1 280 | - yaml=0.2.5=h7b6447c_0 281 | - yarl=1.8.1=py39h5eee18b_0 282 | - zeromq=4.3.4=h2531618_0 283 | - zipp=3.11.0=py39h06a4308_0 284 | - zlib=1.2.13=h166bdaf_4 285 | - zstd=1.5.5=hfc55251_0 286 | - pip: 287 | - accelerate==0.28.0 288 | - arrow==1.2.3 289 | - contourpy==1.0.7 290 | - cpufeature==0.2.1 291 | - cycler==0.11.0 292 | - datasets==2.9.0 293 | - deepspeed==0.8.1 294 | - dill==0.3.6 295 | - evaluate==0.4.0 296 | - filelock==3.9.0 297 | - flatbuffers==1.12 298 | - fonttools==4.38.0 299 | - fqdn==1.5.1 300 | - fsspec==2024.2.0 301 | - gradient-accumulator==0.3.1 302 | - hdijupyterutils==0.20.4 303 | - hjson==3.1.0 304 | - huggingface-hub==0.21.4 305 | - ipyparallel==8.4.1 306 | - isoduration==20.11.0 307 | - jsonpointer==2.3 308 | - jupyter==1.0.0 309 | - jupyter-console==6.5.1 310 | - jupyter-contrib-core==0.4.2 311 | - jupyter-server==1.23.5 312 | - kiwisolver==1.4.4 313 | - lmdb==1.4.1 314 | - matplotlib==3.6.3 315 | - multiprocess==0.70.14 316 | - nglview==3.0.8 317 | - ninja==1.11.1 318 | - pandas==1.5.3 319 | - peft==0.9.0 320 | - pillow==9.4.0 321 | - py-cpuinfo==9.0.0 322 | - pyarrow==11.0.0 323 | - pydantic==1.10.5 324 | - pynmrstar==3.3.2 325 | - pyparsing==3.0.9 326 | - pytz==2022.7.1 327 | - qtconsole==5.4.0 328 | - qtpy==2.3.0 329 | - regex==2022.10.31 330 | - responses==0.18.0 331 | - safetensors==0.4.2 332 | - tensorflow-addons==0.19.0 333 | - tokenizers==0.13.2 334 | - tqdm==4.64.1 335 | - transformers==4.26.1 336 | - typeguard==2.13.3 337 | - uri-template==1.2.0 338 | - webcolors==1.12 339 | - xxhash==3.2.0 -------------------------------------------------------------------------------- /SSMuLA/aa_global.py: -------------------------------------------------------------------------------- 1 | """Parameters for training and testing""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections import Counter 6 | 7 | import re 8 | from copy import deepcopy 9 | 10 | RAND_SEED = 42 11 | 12 | TRANSLATE_DICT = { 13 | "AAA": "K", 14 | "AAT": "N", 15 | "AAC": "N", 16 | "AAG": "K", 17 | "ATA": "I", 18 | "ATT": "I", 19 | "ATC": "I", 20 | "ATG": "M", 21 | "ACA": "T", 22 | "ACT": "T", 23 | "ACC": "T", 24 | "ACG": "T", 25 | 
"AGA": "R", 26 | "AGT": "S", 27 | "AGC": "S", 28 | "AGG": "R", 29 | "TAA": "*", 30 | "TAT": "Y", 31 | "TAC": "Y", 32 | "TAG": "*", 33 | "TTA": "L", 34 | "TTT": "F", 35 | "TTC": "F", 36 | "TTG": "L", 37 | "TCA": "S", 38 | "TCT": "S", 39 | "TCC": "S", 40 | "TCG": "S", 41 | "TGA": "*", 42 | "TGT": "C", 43 | "TGC": "C", 44 | "TGG": "W", 45 | "CAA": "Q", 46 | "CAT": "H", 47 | "CAC": "H", 48 | "CAG": "Q", 49 | "CTA": "L", 50 | "CTT": "L", 51 | "CTC": "L", 52 | "CTG": "L", 53 | "CCA": "P", 54 | "CCT": "P", 55 | "CCC": "P", 56 | "CCG": "P", 57 | "CGA": "R", 58 | "CGT": "R", 59 | "CGC": "R", 60 | "CGG": "R", 61 | "GAA": "E", 62 | "GAT": "D", 63 | "GAC": "D", 64 | "GAG": "E", 65 | "GTA": "V", 66 | "GTT": "V", 67 | "GTC": "V", 68 | "GTG": "V", 69 | "GCA": "A", 70 | "GCT": "A", 71 | "GCC": "A", 72 | "GCG": "A", 73 | "GGA": "G", 74 | "GGT": "G", 75 | "GGC": "G", 76 | "GGG": "G", 77 | } 78 | 79 | # Amino acid code conversion 80 | AA_DICT = { 81 | "Ala": "A", 82 | "Cys": "C", 83 | "Asp": "D", 84 | "Glu": "E", 85 | "Phe": "F", 86 | "Gly": "G", 87 | "His": "H", 88 | "Ile": "I", 89 | "Lys": "K", 90 | "Leu": "L", 91 | "Met": "M", 92 | "Asn": "N", 93 | "Pro": "P", 94 | "Gln": "Q", 95 | "Arg": "R", 96 | "Ser": "S", 97 | "Thr": "T", 98 | "Val": "V", 99 | "Trp": "W", 100 | "Tyr": "Y", 101 | "Ter": "*", 102 | } 103 | 104 | # the upper case three letter code for the amino acids 105 | ALL_AAS_TLC_DICT = {k.upper(): v for k, v in AA_DICT.items() if v != "*"} 106 | 107 | # the upper case three letter code for the amino acids 108 | ALL_AAS_TLC = list(ALL_AAS_TLC_DICT.keys()) 109 | 110 | # All canonical amino acids 111 | ALL_AAS = list(ALL_AAS_TLC_DICT.values()) 112 | ALL_AA_STR = "".join(ALL_AAS) 113 | AA_NUMB = len(ALL_AAS) 114 | ALLOWED_AAS = set(ALL_AAS) 115 | 116 | # Create a new dictionary with values as keys and counts as values 117 | CODON_COUNT_PER_AA = {aa: Counter(TRANSLATE_DICT.values())[aa] for aa in ALL_AAS + ["*"]} 118 | 119 | # Create a dictionary that links each amino acid to an index 120 | AA_TO_IND = {aa: i for i, aa in enumerate(ALL_AAS)} 121 | 122 | # Define a expressions for parsing mutations in the format Wt##Mut 123 | MUT_REGEX = re.compile("^([A-Z])([0-9]+)([A-Z])$") 124 | 125 | START_AA_IND = 1 126 | 127 | 128 | # Copied from ProFET (Ofer & Linial, DOI: 10.1093/bioinformatics/btv345) 129 | # Original comment by the ProFET authors: 'Acquired from georgiev's paper of 130 | # AAscales using helper script "GetTextData.py". 
+ RegEx cleaning DOI: 10.1089/cmb.2008.0173' 131 | gg_1 = { 132 | "Q": -2.54, 133 | "L": 2.72, 134 | "T": -0.65, 135 | "C": 2.66, 136 | "I": 3.1, 137 | "G": 0.15, 138 | "V": 2.64, 139 | "K": -3.89, 140 | "M": 1.89, 141 | "F": 3.12, 142 | "N": -2.02, 143 | "R": -2.8, 144 | "H": -0.39, 145 | "E": -3.08, 146 | "W": 1.89, 147 | "A": 0.57, 148 | "D": -2.46, 149 | "Y": 0.79, 150 | "S": -1.1, 151 | "P": -0.58, 152 | } 153 | gg_2 = { 154 | "Q": 1.82, 155 | "L": 1.88, 156 | "T": -1.6, 157 | "C": -1.52, 158 | "I": 0.37, 159 | "G": -3.49, 160 | "V": 0.03, 161 | "K": 1.47, 162 | "M": 3.88, 163 | "F": 0.68, 164 | "N": -1.92, 165 | "R": 0.31, 166 | "H": 1, 167 | "E": 3.45, 168 | "W": -0.09, 169 | "A": 3.37, 170 | "D": -0.66, 171 | "Y": -2.62, 172 | "S": -2.05, 173 | "P": -4.33, 174 | } 175 | gg_3 = { 176 | "Q": -0.82, 177 | "L": 1.92, 178 | "T": -1.39, 179 | "C": -3.29, 180 | "I": 0.26, 181 | "G": -2.97, 182 | "V": -0.67, 183 | "K": 1.95, 184 | "M": -1.57, 185 | "F": 2.4, 186 | "N": 0.04, 187 | "R": 2.84, 188 | "H": -0.63, 189 | "E": 0.05, 190 | "W": 4.21, 191 | "A": -3.66, 192 | "D": -0.57, 193 | "Y": 4.11, 194 | "S": -2.19, 195 | "P": -0.02, 196 | } 197 | gg_4 = { 198 | "Q": -1.85, 199 | "L": 5.33, 200 | "T": 0.63, 201 | "C": -3.77, 202 | "I": 1.04, 203 | "G": 2.06, 204 | "V": 2.34, 205 | "K": 1.17, 206 | "M": -3.58, 207 | "F": -0.35, 208 | "N": -0.65, 209 | "R": 0.25, 210 | "H": -3.49, 211 | "E": 0.62, 212 | "W": -2.77, 213 | "A": 2.34, 214 | "D": 0.14, 215 | "Y": -0.63, 216 | "S": 1.36, 217 | "P": -0.21, 218 | } 219 | gg_5 = { 220 | "Q": 0.09, 221 | "L": 0.08, 222 | "T": 1.35, 223 | "C": 2.96, 224 | "I": -0.05, 225 | "G": 0.7, 226 | "V": 0.64, 227 | "K": 0.53, 228 | "M": -2.55, 229 | "F": -0.88, 230 | "N": 1.61, 231 | "R": 0.2, 232 | "H": 0.05, 233 | "E": -0.49, 234 | "W": 0.72, 235 | "A": -1.07, 236 | "D": 0.75, 237 | "Y": 1.89, 238 | "S": 1.78, 239 | "P": -8.31, 240 | } 241 | gg_6 = { 242 | "Q": 0.6, 243 | "L": 0.09, 244 | "T": -2.45, 245 | "C": -2.23, 246 | "I": -1.18, 247 | "G": 7.47, 248 | "V": -2.01, 249 | "K": 0.1, 250 | "M": 2.07, 251 | "F": 1.62, 252 | "N": 2.08, 253 | "R": -0.37, 254 | "H": 0.41, 255 | "E": 0, 256 | "W": 0.86, 257 | "A": -0.4, 258 | "D": 0.24, 259 | "Y": -0.53, 260 | "S": -3.36, 261 | "P": -1.82, 262 | } 263 | gg_7 = { 264 | "Q": 0.25, 265 | "L": 0.27, 266 | "T": -0.65, 267 | "C": 0.44, 268 | "I": -0.21, 269 | "G": 0.41, 270 | "V": -0.33, 271 | "K": 4.01, 272 | "M": 0.84, 273 | "F": -0.15, 274 | "N": 0.4, 275 | "R": 3.81, 276 | "H": 1.61, 277 | "E": -5.66, 278 | "W": -1.07, 279 | "A": 1.23, 280 | "D": -5.15, 281 | "Y": -1.3, 282 | "S": 1.39, 283 | "P": -0.12, 284 | } 285 | gg_8 = { 286 | "Q": 2.11, 287 | "L": -4.06, 288 | "T": 3.43, 289 | "C": -3.49, 290 | "I": 3.45, 291 | "G": 1.62, 292 | "V": 3.93, 293 | "K": -0.01, 294 | "M": 1.85, 295 | "F": -0.41, 296 | "N": -2.47, 297 | "R": 0.98, 298 | "H": -0.6, 299 | "E": -0.11, 300 | "W": -1.66, 301 | "A": -2.32, 302 | "D": -1.17, 303 | "Y": 1.31, 304 | "S": -1.21, 305 | "P": -1.18, 306 | } 307 | gg_9 = { 308 | "Q": -1.92, 309 | "L": 0.43, 310 | "T": 0.34, 311 | "C": 2.22, 312 | "I": 0.86, 313 | "G": -0.47, 314 | "V": -0.21, 315 | "K": -0.26, 316 | "M": -2.05, 317 | "F": 4.2, 318 | "N": -0.07, 319 | "R": 2.43, 320 | "H": 3.55, 321 | "E": 1.49, 322 | "W": -5.87, 323 | "A": -2.01, 324 | "D": 0.73, 325 | "Y": -0.56, 326 | "S": -2.83, 327 | "P": 0, 328 | } 329 | gg_10 = { 330 | "Q": -1.67, 331 | "L": -1.2, 332 | "T": 0.24, 333 | "C": -3.78, 334 | "I": 1.98, 335 | "G": -2.9, 336 | "V": 1.27, 337 | "K": -1.66, 338 | "M": 0.78, 339 | 
"F": 0.73, 340 | "N": 7.02, 341 | "R": -0.99, 342 | "H": 1.52, 343 | "E": -2.26, 344 | "W": -0.66, 345 | "A": 1.31, 346 | "D": 1.5, 347 | "Y": -0.95, 348 | "S": 0.39, 349 | "P": -0.66, 350 | } 351 | gg_11 = { 352 | "Q": 0.7, 353 | "L": 0.67, 354 | "T": -0.53, 355 | "C": 1.98, 356 | "I": 0.89, 357 | "G": -0.98, 358 | "V": 0.43, 359 | "K": 5.86, 360 | "M": 1.53, 361 | "F": -0.56, 362 | "N": 1.32, 363 | "R": -4.9, 364 | "H": -2.28, 365 | "E": -1.62, 366 | "W": -2.49, 367 | "A": -1.14, 368 | "D": 1.51, 369 | "Y": 1.91, 370 | "S": -2.92, 371 | "P": 0.64, 372 | } 373 | gg_12 = { 374 | "Q": -0.27, 375 | "L": -0.29, 376 | "T": 1.91, 377 | "C": -0.43, 378 | "I": -1.67, 379 | "G": -0.62, 380 | "V": -1.71, 381 | "K": -0.06, 382 | "M": 2.44, 383 | "F": 3.54, 384 | "N": -2.44, 385 | "R": 2.09, 386 | "H": -3.12, 387 | "E": -3.97, 388 | "W": -0.3, 389 | "A": 0.19, 390 | "D": 5.61, 391 | "Y": -1.26, 392 | "S": 1.27, 393 | "P": -0.92, 394 | } 395 | gg_13 = { 396 | "Q": -0.99, 397 | "L": -2.47, 398 | "T": 2.66, 399 | "C": -1.03, 400 | "I": -1.02, 401 | "G": -0.11, 402 | "V": -2.93, 403 | "K": 1.38, 404 | "M": -0.26, 405 | "F": 5.25, 406 | "N": 0.37, 407 | "R": -3.08, 408 | "H": -1.45, 409 | "E": 2.3, 410 | "W": -0.5, 411 | "A": 1.66, 412 | "D": -3.85, 413 | "Y": 1.57, 414 | "S": 2.86, 415 | "P": -0.37, 416 | } 417 | gg_14 = { 418 | "Q": -1.56, 419 | "L": -4.79, 420 | "T": -3.07, 421 | "C": 0.93, 422 | "I": -1.21, 423 | "G": 0.15, 424 | "V": 4.22, 425 | "K": 1.78, 426 | "M": -3.09, 427 | "F": 1.73, 428 | "N": -0.89, 429 | "R": 0.82, 430 | "H": -0.77, 431 | "E": -0.06, 432 | "W": 1.64, 433 | "A": 4.39, 434 | "D": 1.28, 435 | "Y": 0.2, 436 | "S": -1.88, 437 | "P": 0.17, 438 | } 439 | gg_15 = { 440 | "Q": 6.22, 441 | "L": 0.8, 442 | "T": 0.2, 443 | "C": 1.43, 444 | "I": -1.78, 445 | "G": -0.53, 446 | "V": 1.06, 447 | "K": -2.71, 448 | "M": -1.39, 449 | "F": 2.14, 450 | "N": 3.13, 451 | "R": 1.32, 452 | "H": -4.18, 453 | "E": -0.35, 454 | "W": -0.72, 455 | "A": 0.18, 456 | "D": -1.98, 457 | "Y": -0.76, 458 | "S": -2.42, 459 | "P": 0.36, 460 | } 461 | gg_16 = { 462 | "Q": -0.18, 463 | "L": -1.43, 464 | "T": -2.2, 465 | "C": 1.45, 466 | "I": 5.71, 467 | "G": 0.35, 468 | "V": -1.31, 469 | "K": 1.62, 470 | "M": -1.02, 471 | "F": 1.1, 472 | "N": 0.79, 473 | "R": 0.69, 474 | "H": -2.91, 475 | "E": 1.51, 476 | "W": 1.75, 477 | "A": -2.6, 478 | "D": 0.05, 479 | "Y": -5.19, 480 | "S": 1.75, 481 | "P": 0.08, 482 | } 483 | gg_17 = { 484 | "Q": 2.72, 485 | "L": 0.63, 486 | "T": 3.73, 487 | "C": -1.15, 488 | "I": 1.54, 489 | "G": 0.3, 490 | "V": -1.97, 491 | "K": 0.96, 492 | "M": -4.32, 493 | "F": 0.68, 494 | "N": -1.54, 495 | "R": -2.62, 496 | "H": 3.37, 497 | "E": -2.29, 498 | "W": 2.73, 499 | "A": 1.49, 500 | "D": 0.9, 501 | "Y": -2.56, 502 | "S": -2.77, 503 | "P": 0.16, 504 | } 505 | gg_18 = { 506 | "Q": 4.35, 507 | "L": -0.24, 508 | "T": -5.46, 509 | "C": -1.64, 510 | "I": 2.11, 511 | "G": 0.32, 512 | "V": -1.21, 513 | "K": -1.09, 514 | "M": -1.34, 515 | "F": 1.46, 516 | "N": -1.71, 517 | "R": -1.49, 518 | "H": 1.87, 519 | "E": -1.47, 520 | "W": -2.2, 521 | "A": 0.46, 522 | "D": 1.38, 523 | "Y": 2.87, 524 | "S": 3.36, 525 | "P": -0.34, 526 | } 527 | gg_19 = { 528 | "Q": 0.92, 529 | "L": 1.01, 530 | "T": -0.73, 531 | "C": -1.05, 532 | "I": -4.18, 533 | "G": 0.05, 534 | "V": 4.77, 535 | "K": 1.36, 536 | "M": 0.09, 537 | "F": 2.33, 538 | "N": -0.25, 539 | "R": -2.57, 540 | "H": 2.17, 541 | "E": 0.15, 542 | "W": 0.9, 543 | "A": -4.22, 544 | "D": -0.03, 545 | "Y": -3.43, 546 | "S": 2.67, 547 | "P": 0.04, 548 | } 549 | 550 | 
# Package all georgiev parameters 551 | georgiev_parameters = [ 552 | gg_1, 553 | gg_2, 554 | gg_3, 555 | gg_4, 556 | gg_5, 557 | gg_6, 558 | gg_7, 559 | gg_8, 560 | gg_9, 561 | gg_10, 562 | gg_11, 563 | gg_12, 564 | gg_13, 565 | gg_14, 566 | gg_15, 567 | gg_16, 568 | gg_17, 569 | gg_18, 570 | gg_19, 571 | ] 572 | 573 | DEFAULT_ESM = "esm2_t33_650M_UR50D" 574 | 575 | EMB_METHOD_COMBOS = [ 576 | {"flatten_emb": "flatten", "ifsite": True, "combo_folder": "flatten_site"}, 577 | {"flatten_emb": "mean", "ifsite": True, "combo_folder": "mean_site"}, 578 | {"flatten_emb": "mean", "ifsite": False, "combo_folder": "mean_all"}, 579 | ] 580 | 581 | EMB_COMBO_LIST = deepcopy(sorted(i["combo_folder"] for i in EMB_METHOD_COMBOS)) 582 | 583 | DEFAULT_LEARNED_EMB_COMBO = deepcopy([f"{DEFAULT_ESM}-{combo}" for combo in EMB_COMBO_LIST]) 584 | DEFAULT_LEARNED_EMB_DIR = "learned_emb" 585 | -------------------------------------------------------------------------------- /SSMuLA/vis.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | 5 | import pandas as pd 6 | 7 | from cairosvg import svg2png 8 | 9 | import seaborn as sns 10 | 11 | import matplotlib.pyplot as plt 12 | import matplotlib.colors as mcolors 13 | 14 | import colorcet as cc 15 | 16 | import bokeh 17 | from bokeh.io import export_svg 18 | from bokeh.models import NumeralTickFormatter 19 | from bokeh.themes.theme import Theme 20 | 21 | import holoviews as hv 22 | 23 | from SSMuLA.landscape_global import LIB_NAMES 24 | from SSMuLA.util import checkNgen_folder 25 | 26 | bokeh.io.output_notebook() 27 | hv.extension("bokeh", "matplotlib") 28 | 29 | # Select a colorcet colormap, for example, 'fire' or 'CET_CBL1' 30 | colormap = cc.cm["glasbey_category10"] 31 | 32 | # Extract a list of hex codes from the colormap 33 | glasbey_category10 = [mcolors.to_hex(colormap(i)) for i in range(colormap.N)] 34 | 35 | 36 | JSON_THEME = Theme( 37 | json={ 38 | "attrs": { 39 | "Title": { 40 | "align": "center", 41 | "text_font_size": "12px", 42 | "text_color": "black", 43 | "text_font": "arial", 44 | }, # title centered and bigger 45 | "Axis": { 46 | "axis_label_text_font_style": "normal", 47 | "axis_label_text_color": "black", 48 | "major_label_text_color": "black", 49 | "axis_label_text_font": "arial", 50 | "major_label_text_font": "arial", 51 | }, # no italic labels 52 | "Legend": { 53 | "title_text_font_style": "normal", 54 | "title_text_color": "black", 55 | "label_text_color": "black", 56 | "label_text_font": "arial", 57 | }, 58 | "ColorBar": { 59 | "title_text_font_style": "normal", 60 | "major_label_text_color": "black", 61 | "major_label_text_font": "arial", 62 | "title_text_color": "black", 63 | "title_text_font": "arial", 64 | }, 65 | } 66 | } 67 | ) 68 | 69 | hv.renderer("bokeh").theme = JSON_THEME 70 | 71 | # Grey for heatmap 72 | HM_GREY = "#76777B" 73 | 74 | # blue, orange, green, yellow, purple, gray 75 | FZL_PALETTE = { 76 | "blue": "#4bacc6", 77 | "orange": "#f79646ff", 78 | "light_orange": "#ffbb78", 79 | "red": "#ff8888", 80 | "maroon": "#7A303F", 81 | "green": "#9bbb59", 82 | "yellow": "#f9be00", 83 | "purple": "#8064a2", 84 | "brown": "#ae682f", 85 | "dark_brown": "#6e4a2eff", 86 | "gray": "#666666", 87 | "light_gray": "#D3D3D3", 88 | "light_blue": "#849895", 89 | "light_green": "#9DAE88", 90 | "light_yellow": "#F1D384", 91 | "light_brown": "#C7B784", 92 | "black": "#000000", 93 | } 94 | 95 | GRAY_COLORS = { 96 | "gray-blue": "#749aa3", 97 | "gray-orange": 
"#c58a6c", 98 | "gray-green": "#8b946e", 99 | "gray-yellow": "#d6b969", 100 | "gray-purple": "#897a8f", 101 | "gray-brown": "#8b6e57", 102 | } 103 | 104 | LIGHT_COLORS = {"yellow": "#F1D384"} 105 | 106 | PLOTEXTENTIONS = [".svg", ".png"] 107 | PLOTTYPES = [t[1:] for t in PLOTEXTENTIONS] 108 | 109 | LIB_COLORS = { 110 | n: c 111 | for (n, c) in zip( 112 | LIB_NAMES, 113 | [ 114 | FZL_PALETTE["orange"], 115 | FZL_PALETTE["light_orange"], 116 | FZL_PALETTE["brown"], 117 | FZL_PALETTE["yellow"], 118 | FZL_PALETTE["maroon"], 119 | FZL_PALETTE["purple"], 120 | ] 121 | + sns.color_palette("crest", 9).as_hex() 122 | + [FZL_PALETTE["gray"]], 123 | ) 124 | } 125 | 126 | LIB_COLORS_GLASBEY = { 127 | n: c 128 | for (n, c) in zip( 129 | LIB_NAMES, 130 | glasbey_category10[:6] 131 | + glasbey_category10[12:15] 132 | + glasbey_category10[6:10] 133 | + [glasbey_category10[15]] 134 | + glasbey_category10[10:12], 135 | ) 136 | } 137 | 138 | LIB_COLORS_CODON = {"DHFR": "#ffbb78"} # light orange 139 | 140 | MLDE_COLORS = ( 141 | [ 142 | FZL_PALETTE["orange"], 143 | FZL_PALETTE["yellow"], 144 | ] 145 | + sns.color_palette("crest", 9).as_hex() 146 | + [FZL_PALETTE["gray"]] 147 | ) 148 | 149 | SIMPLE_ZS_COLOR_MAP = { 150 | "none": FZL_PALETTE["gray"], 151 | "ed_score": FZL_PALETTE["blue"], 152 | "ev_score": FZL_PALETTE["green"], 153 | "esm_score": FZL_PALETTE["purple"], 154 | "esmif_score": FZL_PALETTE["yellow"], 155 | "coves_score": FZL_PALETTE["brown"], 156 | "Triad_score": FZL_PALETTE["orange"], 157 | } 158 | 159 | ZS_COLOR_MAP = { 160 | "none": FZL_PALETTE["gray"], 161 | "ev_score": FZL_PALETTE["green"], 162 | "esm_score": FZL_PALETTE["purple"], 163 | "esmif_score": FZL_PALETTE["yellow"], 164 | "coves_score": FZL_PALETTE["brown"], 165 | "Triad_score": FZL_PALETTE["orange"], 166 | "Triad-esmif_score": FZL_PALETTE["light_blue"], 167 | "ev-esm_score": FZL_PALETTE["light_green"], 168 | "ev-esm-esmif_score": FZL_PALETTE["light_yellow"], 169 | "Triad-ev-esm-esmif_score": FZL_PALETTE["light_brown"], 170 | "two-best-comb_score": FZL_PALETTE["light_gray"], 171 | } 172 | 173 | 174 | # define plot hooks 175 | def one_decimal_x(plot, element): 176 | plot.handles["plot"].xaxis[0].formatter = NumeralTickFormatter(format="0.0") 177 | 178 | 179 | def one_decimal_y(plot, element): 180 | plot.handles["plot"].yaxis[0].formatter = NumeralTickFormatter(format="0.0") 181 | 182 | 183 | def fixmargins(plot, element): 184 | plot.handles["plot"].min_border_right = 30 185 | plot.handles["plot"].min_border_left = 65 186 | plot.handles["plot"].min_border_top = 20 187 | plot.handles["plot"].min_border_bottom = 65 188 | plot.handles["plot"].outline_line_color = "black" 189 | plot.handles["plot"].outline_line_alpha = 1 190 | plot.handles["plot"].outline_line_width = 1 191 | plot.handles["plot"].toolbar.autohide = True 192 | 193 | 194 | def render_hv(hv_plot) -> bokeh.plotting.Figure: 195 | """Render a holoviews plot as a bokeh plot""" 196 | return hv.render(hv_plot) 197 | 198 | 199 | def save_bokeh_hv( 200 | plot_obj, 201 | plot_name: str, 202 | plot_path: str, 203 | bokehorhv: str = "hv", 204 | dpi: int = 300, 205 | scale: int = 2, 206 | skippng: bool = False, 207 | ): 208 | 209 | """ 210 | A function for exporting bokeh plots as svg 211 | 212 | Args: 213 | - plot_obj: hv or bokeh plot object 214 | - plot_name: str: name of the plot 215 | - plot_path: str: path to save the plot without the plot_name 216 | - bokehorhv: str: 'hv' or 'bokeh' 217 | - dpi: int: dpi 218 | - scale: int: scale 219 | - skippng: bool: skip png 220 | """ 221 | 222 | 
plot_name = plot_name.replace(" ", "_") 223 | plot_path = checkNgen_folder(plot_path.replace(" ", "_")) 224 | plot_noext = os.path.join(plot_path, plot_name) 225 | 226 | if bokehorhv == "hv": 227 | 228 | # save as html legend 229 | hv.save(plot_obj, plot_noext + ".html") 230 | 231 | # hv.save(plot_obj, plot_noext, mode="auto", fmt='svg', dpi=300, toolbar='disable') 232 | 233 | plot_obj = hv.render(plot_obj, backend="bokeh") 234 | 235 | plot_obj.toolbar_location = None 236 | plot_obj.toolbar.logo = None 237 | 238 | plot_obj.output_backend = "svg" 239 | export_svg(plot_obj, filename=plot_noext + ".svg", timeout=1200) 240 | 241 | if not skippng: 242 | svg2png( 243 | write_to=plot_noext + ".png", 244 | dpi=dpi, 245 | scale=scale, 246 | bytestring=open(plot_noext + ".svg").read().encode("utf-8"), 247 | ) 248 | else: 249 | print("Skipping png export") 250 | 251 | 252 | def save_svg(fig, plot_title: str, path2folder: str, ifshow: bool = True): 253 | """ 254 | A function for saving svg plots 255 | """ 256 | 257 | plot_title_no_space = plot_title.replace(" ", "_") 258 | plt.savefig( 259 | os.path.join(checkNgen_folder(path2folder), f"{plot_title_no_space}.svg"), 260 | bbox_inches="tight", 261 | dpi=300, 262 | format="svg", 263 | ) 264 | 265 | if ifshow: 266 | plt.show() 267 | 268 | 269 | def save_plt(fig, plot_title: str, path2folder: str): 270 | 271 | """ 272 | A helper function for saving plt plots 273 | Args: 274 | - fig: plt.figure: the figure to save 275 | - plot_title: str: the title of the plot 276 | - path2folder: str: the path to the folder to save the plot 277 | """ 278 | 279 | for ext in PLOTEXTENTIONS: 280 | plot_title_no_space = plot_title.replace(" ", "_") 281 | plt.savefig( 282 | os.path.join(checkNgen_folder(path2folder), f"{plot_title_no_space}{ext}"), 283 | bbox_inches="tight", 284 | dpi=300, 285 | ) 286 | 287 | plt.close() 288 | 289 | 290 | def plot_fit_dist( 291 | fitness: pd.Series, 292 | label: str, 293 | color: str = "", 294 | spike_length: float | None = None, 295 | ignore_line_label: bool = False, 296 | ) -> hv.Distribution: 297 | """ 298 | Plot the fitness distribution 299 | 300 | Args: 301 | - fitness: pd.Series: fitness values 302 | - label: str: label 303 | - color: str: color 304 | - ignore_line_label: bool: ignore line label 305 | 306 | Returns: 307 | - hv.Distribution: plot of the fitness distribution 308 | """ 309 | 310 | if label == "codon": 311 | cap_label = f"{label.capitalize()}-level" 312 | elif label == "AA": 313 | cap_label = f"{label.upper()}-level" 314 | else: 315 | cap_label = label 316 | 317 | if color == "": 318 | color = FZL_PALETTE["blue"] 319 | 320 | if ignore_line_label: 321 | mean_label = {} 322 | median_label = {} 323 | else: 324 | mean_label = {"label": f"Mean {label}"} 325 | median_label = {"label": f"Median {label}"} 326 | 327 | hv_dist = hv.Distribution(fitness, label=cap_label).opts( 328 | width=400, 329 | height=400, 330 | color=color, 331 | line_color=None, 332 | ) 333 | 334 | # get y_range for spike height 335 | y_range = ( 336 | hv.renderer("bokeh").instance(mode="server").get_plot(hv_dist).state.y_range 337 | ) 338 | 339 | # set spike length to be 5% of the y_range 340 | if spike_length is None: 341 | spike_length = (y_range.end - y_range.start) * 0.05 342 | 343 | return ( 344 | hv_dist 345 | * hv.Spikes([fitness.mean()], **mean_label).opts( 346 | line_dash="dotted", 347 | line_color=color, 348 | line_width=1.6, 349 | spike_length=spike_length, 350 | ) 351 | * hv.Spikes([fitness.median()], **median_label).opts( 352 | line_color=color, 
line_width=1.6, spike_length=spike_length 353 | ) 354 | ) 355 | 356 | 357 | def plot_zs_violin( 358 | all_df: pd.DataFrame, 359 | zs: str, 360 | encoding_list: list[str], 361 | model: str, 362 | n_sample: int, 363 | n_top: int, 364 | metric: str, 365 | plot_name: str, 366 | ) -> hv.Violin: 367 | 368 | return hv.Violin( 369 | all_df[ 370 | (all_df["zs"] == zs) 371 | & (all_df["encoding"].isin(encoding_list)) 372 | & (all_df["model"] == model) 373 | & (all_df["n_sample"] == n_sample) 374 | & (all_df["n_top"] == n_top) 375 | ] 376 | .sort_values(["lib", "n_mut_cutoff"], ascending=[True, False]) 377 | .copy(), 378 | kdims=["lib", "n_mut_cutoff"], 379 | vdims=[metric], 380 | ).opts( 381 | width=1200, 382 | height=400, 383 | violin_color="n_mut_cutoff", 384 | show_legend=True, 385 | legend_position="top", 386 | legend_offset=(0, 5), 387 | title=plot_name, 388 | ylim=(0, 1), 389 | hooks=[one_decimal_x, one_decimal_y, fixmargins, lib_ncut_hook], 390 | ) 391 | 392 | 393 | def lib_ncut_hook(plot, element): 394 | 395 | plot.handles["plot"].x_range.factors = [ 396 | (lib, n_mut) for lib in LIB_NAMES for n_mut in ["single", "double", "all"] 397 | ] 398 | plot.handles["xaxis"].major_label_text_font_size = "0pt" 399 | # plot.handles['xaxis'].group_text_font_size = '0pt' 400 | # plot.handles['yaxis'].axis_label_text_font_size = '10pt' 401 | # plot.handles['yaxis'].axis_label_text_font_style = 'normal' 402 | # plot.handles['xaxis'].axis_label_text_font_style = 'normal' 403 | 404 | 405 | def generate_related_color(reference_idx, base_idx, target_idx, palette="colorblind"): 406 | """ 407 | Generate a color that has the same relationship to palette[target_idx] 408 | as palette[base_idx] has to palette[reference_idx]. 409 | 410 | Parameters: 411 | - reference_idx: Index of the reference color in the palette. 412 | - base_idx: Index of the color that is related to reference_idx. 413 | - target_idx: Index of the color to which the transformation is applied. 414 | - palette: Name of the Seaborn palette (default: "colorblind"). 415 | 416 | Returns: 417 | - A tuple representing the new RGB color. 
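    Example (hypothetical indices): generate_related_color(0, 1, 2) returns
    palette[2] shifted by the same RGB offset that takes palette[0] to palette[1],
    clipped to the valid [0, 1] range.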
418 | """ 419 | import seaborn as sns 420 | import numpy as np 421 | 422 | # Get the palette 423 | colors = sns.color_palette(palette) 424 | 425 | # Compute transformation 426 | color_shift = np.array(colors[base_idx]) - np.array(colors[reference_idx]) 427 | 428 | # Apply the transformation 429 | new_color = np.array(colors[target_idx]) + color_shift 430 | 431 | # Clip to valid RGB range [0,1] 432 | new_color = np.clip(new_color, 0, 1) 433 | 434 | return tuple(new_color) -------------------------------------------------------------------------------- /SSMuLA/calc_hd.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for handling the calculation of the hamming distance cutoff fitness 3 | """ 4 | 5 | import os 6 | from glob import glob 7 | 8 | import pandas as pd 9 | import numpy as np 10 | 11 | from scipy.stats import spearmanr 12 | from sklearn.metrics import roc_curve, auc 13 | 14 | from concurrent.futures import ProcessPoolExecutor, as_completed 15 | from tqdm import tqdm 16 | 17 | import seaborn as sns 18 | import matplotlib.pyplot as plt 19 | import matplotlib.colors as mcolors 20 | from matplotlib.lines import Line2D 21 | 22 | from SSMuLA.landscape_global import LIB_INFO_DICT, hamming, lib2prot 23 | from SSMuLA.zs_analysis import ( 24 | ZS_METRIC_MAP_TITLE, 25 | ZS_METRIC_MAP_LABEL, 26 | ZS_METRIC_BASELINE, 27 | ) 28 | from SSMuLA.vis import LIB_COLORS, FZL_PALETTE, LIB_COLORS_GLASBEY, save_plt, save_svg 29 | from SSMuLA.util import checkNgen_folder, get_file_name 30 | 31 | 32 | # Define the function that will be executed in parallel 33 | def process_aa(aa, all_aas, all_fitnesses): 34 | hm2_fits = [] 35 | for aa2, fitness in zip(all_aas, all_fitnesses): 36 | if hamming(aa, aa2) > 2: 37 | continue 38 | hm2_fits.append(fitness) 39 | return aa, np.mean(hm2_fits), np.std(hm2_fits) 40 | 41 | 42 | def get_hd_avg_fit( 43 | df_csv: str, 44 | hd_dir: str = "results/hd", 45 | num_processes: None | int = None, 46 | ): 47 | 48 | df = pd.read_csv(df_csv) 49 | 50 | # no stop codons 51 | df = df[~df["AAs"].str.contains("\*")].copy() 52 | 53 | # only active variants 54 | active_df = df[df["active"]].copy() 55 | 56 | all_aas = active_df["AAs"].tolist() 57 | all_fitnesses = active_df.loc[active_df["AAs"].isin(all_aas), "fitness"].tolist() 58 | 59 | hm2_dict = {} 60 | # Set number of processes; if None, use all available cores 61 | if num_processes is None: 62 | num_processes = int(np.round(os.cpu_count() * 0.8)) 63 | 64 | with ProcessPoolExecutor(max_workers=num_processes) as executor: 65 | futures = [ 66 | executor.submit(process_aa, aa, all_aas, all_fitnesses) for aa in all_aas 67 | ] 68 | for future in tqdm(as_completed(futures), total=len(futures)): 69 | aa, mean, std = future.result() 70 | hm2_dict[aa] = {"mean": mean, "std": std} 71 | 72 | mean_df = pd.DataFrame.from_dict(hm2_dict, orient="index") 73 | 74 | # Set the index name to 'aa' 75 | mean_df.index.name = "AAs" 76 | 77 | checkNgen_folder(hd_dir) 78 | mean_df.to_csv(os.path.join(hd_dir, get_file_name(df_csv) + ".csv")) 79 | 80 | return hm2_dict 81 | 82 | 83 | def run_hd_avg_fit( 84 | data_dir: str = "data", 85 | scalefit: str = "max", 86 | hd_dir: str = "results/hd_fit", 87 | num_processes: None | int = None, 88 | all_lib: bool = True, 89 | lib_list: list[str] = [], 90 | ): 91 | 92 | """ 93 | Run the calculation of the average fitness for all sequences within a Hamming distance of 2 94 | 95 | Args: 96 | - data_dir: str, the directory containing the data 97 | - scalefit: str, the scale of 
the fitness values 98 | - hd_dir: str, the directory to save the results 99 | - num_processes: None | int, the number of processes to use 100 | - all_lib: bool, whether to use all libraries 101 | - lib_list: list[str], the list of libraries to use 102 | """ 103 | 104 | if all_lib or len(lib_list) == 0: 105 | df_csvs = sorted(glob(f"{os.path.normpath(data_dir)}/*/scale2{scalefit}/*.csv")) 106 | else: 107 | df_csvs = [ 108 | f"{os.path.normpath(data_dir)}/{lib}/scale2{scalefit}/{lib}.csv" 109 | for lib in lib_list 110 | ] 111 | 112 | for df_csv in df_csvs: 113 | print(f"Processing {df_csv} ...") 114 | df = get_hd_avg_fit(df_csv, hd_dir) 115 | 116 | del df 117 | 118 | 119 | def correlate_hd2fit(aa, all_aas, all_fitnesses, all_ifactive): 120 | 121 | """ 122 | A function to correlate the Hamming distance of a sequence 123 | with all other sequences with the fitness values 124 | """ 125 | 126 | hms = [-1 * hamming(aa, aa2) for aa2 in all_aas] 127 | rho = spearmanr(all_fitnesses, hms)[0] 128 | 129 | fpr, tpr, _ = roc_curve(all_ifactive, hms, pos_label=1) 130 | rocauc = auc(fpr, tpr) 131 | 132 | return aa, rho, rocauc 133 | 134 | 135 | def get_hd_avg_metric( 136 | df_csv: str, 137 | hd_dir: str = "results/hd_corr", 138 | num_processes: None | int = None, 139 | ): 140 | 141 | df = pd.read_csv(df_csv) 142 | 143 | # no stop codons 144 | df = df[~df["AAs"].str.contains("\*")].copy() 145 | 146 | all_aas = df["AAs"].tolist() 147 | all_fitnesses = df["fitness"].values 148 | all_ifactive = df["active"].values 149 | 150 | # only active variants 151 | active_df = df[df["active"]].copy() 152 | 153 | all_active_aas = active_df["AAs"].tolist() 154 | 155 | hm_dict = {} 156 | # Set number of processes; if None, use all available cores 157 | if num_processes is None: 158 | num_processes = int(np.round(os.cpu_count() * 0.8)) 159 | 160 | with ProcessPoolExecutor(max_workers=num_processes) as executor: 161 | futures = [ 162 | executor.submit(correlate_hd2fit, aa, all_aas, all_fitnesses, all_ifactive) 163 | for aa in all_active_aas 164 | ] 165 | for future in tqdm(as_completed(futures), total=len(futures)): 166 | aa, rho, rocauc = future.result() 167 | hm_dict[aa] = {"rho": rho, "rocauc": rocauc} 168 | 169 | hm_df = pd.DataFrame.from_dict(hm_dict, orient="index") 170 | 171 | # Set the index name to 'aa' 172 | hm_df.index.name = "AAs" 173 | 174 | checkNgen_folder(hd_dir) 175 | hm_df.to_csv(os.path.join(hd_dir, get_file_name(df_csv) + ".csv")) 176 | 177 | return hm_df["rho"].mean(), hm_df["rocauc"].mean() 178 | 179 | 180 | def run_hd_avg_metric( 181 | data_dir: str = "data", 182 | scalefit: str = "max", 183 | hd_dir: str = "results/hd_corr", 184 | num_processes: None | int = None, 185 | all_lib: bool = True, 186 | lib_list: list[str] = [], 187 | ): 188 | 189 | """ 190 | Run the calculation of the average fitness for all sequences within a Hamming distance of 2 191 | 192 | Args: 193 | - data_dir: str, the directory containing the data 194 | - scalefit: str, the scale of the fitness values 195 | - hd_dir: str, the directory to save the results 196 | - num_processes: None | int, the number of processes to use 197 | - all_lib: bool, whether to use all libraries 198 | - lib_list: list[str], the list of libraries to use 199 | """ 200 | 201 | hd_avg_metric = pd.DataFrame(columns=["lib", "rho", "rocauc"]) 202 | 203 | if all_lib or len(lib_list) == 0: 204 | df_csvs = sorted(glob(f"{os.path.normpath(data_dir)}/*/scale2{scalefit}/*.csv")) 205 | else: 206 | df_csvs = [ 207 | 
f"{os.path.normpath(data_dir)}/{lib}/scale2{scalefit}/{lib}.csv" 208 | for lib in lib_list 209 | ] 210 | 211 | for df_csv in df_csvs: 212 | print(f"Processing {df_csv} ...") 213 | rho, roc_aud = get_hd_avg_metric(df_csv, hd_dir) 214 | 215 | hd_avg_metric = hd_avg_metric._append( 216 | { 217 | "lib": get_file_name(df_csv), 218 | "rho": rho, 219 | "rocauc": roc_aud, 220 | }, 221 | ignore_index=True, 222 | ) 223 | 224 | checkNgen_folder(hd_dir) 225 | hd_avg_metric.to_csv(os.path.join(hd_dir, "hd_avg_metric.csv")) 226 | 227 | return hd_avg_metric 228 | 229 | 230 | def plot_hd_avg_fit( 231 | figname: str, 232 | hd_fit_dir: str = "results/hd/hd_fit", 233 | fit_dir: str = "data", 234 | fitscale: str = "scale2max", 235 | ifsave: bool = True, 236 | fig_dir: str = "figs", 237 | ): 238 | 239 | all_dfs = [] 240 | wt_mean = {} 241 | full_mean = {} 242 | 243 | for lib, lib_dict in LIB_INFO_DICT.items(): 244 | 245 | df = pd.read_csv(os.path.join(hd_fit_dir, f"{lib}.csv")) 246 | df["lib"] = lib 247 | all_dfs.append(df) 248 | 249 | wt_mean[lib] = df[df["AAs"] == "".join(lib_dict["AAs"].values())][ 250 | "mean" 251 | ].values[0] 252 | 253 | fit_df = pd.read_csv( 254 | os.path.join(fit_dir, lib2prot(lib), fitscale, f"{lib}.csv") 255 | ) 256 | full_mean[lib] = fit_df["fitness"].mean() 257 | 258 | all_df = pd.concat(all_dfs) 259 | 260 | fig = plt.figure(figsize=(16, 8)) 261 | ax = sns.violinplot( 262 | x="lib", y="mean", data=all_df, hue="lib", palette=LIB_COLORS_GLASBEY 263 | ) 264 | 265 | # Set the alpha value of the facecolor automatically to 0.8 266 | for violin in ax.collections[:]: # Access only the violin bodies 267 | facecolor = violin.get_facecolor().flatten() # Get the current facecolor 268 | violin.set_facecolor(mcolors.to_rgba(facecolor, alpha=0.4)) # Set new facecolor 269 | 270 | for lib in LIB_INFO_DICT.keys(): 271 | 272 | # Find the position of the violin to add the line to 273 | position = all_df["lib"].unique().tolist().index(lib) 274 | 275 | # Overlay the mean as a scatter plot 276 | ax.axhline( 277 | all_df[all_df["lib"] == lib]["mean"].mean(), 278 | color=FZL_PALETTE["light_gray"], 279 | linestyle="solid", 280 | marker="x", 281 | linewidth=2, 282 | xmin=position / len(LIB_INFO_DICT) + 0.03125, 283 | xmax=(position + 1) / len(LIB_INFO_DICT) - 0.03125, 284 | ) 285 | ax.axhline( 286 | wt_mean[lib], 287 | color=LIB_COLORS_GLASBEY[lib], 288 | linestyle="--", 289 | linewidth=2, 290 | xmin=position / len(LIB_INFO_DICT), 291 | xmax=(position + 1) / len(LIB_INFO_DICT), 292 | ) 293 | ax.axhline( 294 | full_mean[lib], 295 | color=LIB_COLORS_GLASBEY[lib], 296 | linestyle="dotted", 297 | linewidth=2, 298 | xmin=position / len(LIB_INFO_DICT), 299 | xmax=(position + 1) / len(LIB_INFO_DICT), 300 | ) 301 | 302 | lines = [ 303 | Line2D( 304 | [0], 305 | [0], 306 | color=FZL_PALETTE["light_gray"], 307 | linestyle="none", 308 | lw=2, 309 | marker="x", 310 | ), 311 | Line2D([0], [0], color="black", linestyle="--", lw=2), 312 | Line2D([0], [0], color="black", linestyle="dotted", lw=2), 313 | ] 314 | labels = [ 315 | "Mean of the mean variant fitness of double-site library\nconstructed with any active variant", 316 | "Mean variant fitness of double-site library\nconstruscted with the landscape parent", 317 | "Mean of all variants", 318 | ] 319 | 320 | ax.legend(lines, labels, loc="upper left", bbox_to_anchor=(1, 1)) 321 | 322 | ax.set_xlabel("Landscapes") 323 | ax.set_ylabel( 324 | "Mean variant fitness of double-site library constructed with an active variant" 325 | ) 326 | 327 | if ifsave: 328 | save_svg(fig, 
figname, fig_dir) 329 | 330 | 331 | def plot_hd_corr( 332 | metric: str, 333 | figname: str, 334 | hd_corr_dir: str = "results/hd/hd_corr", 335 | ifsave: bool = True, 336 | fig_dir: str = "figs", 337 | ): 338 | 339 | all_dfs = [] 340 | wt_mean = {} 341 | 342 | for lib, lib_dict in LIB_INFO_DICT.items(): 343 | 344 | df = pd.read_csv(os.path.join(hd_corr_dir, f"{lib}.csv")) 345 | df["lib"] = lib 346 | all_dfs.append(df) 347 | 348 | wt_mean[lib] = df[df["AAs"] == "".join(lib_dict["AAs"].values())][ 349 | metric 350 | ].values[0] 351 | 352 | all_df = pd.concat(all_dfs) 353 | 354 | fig = plt.figure(figsize=(16, 8)) 355 | ax = sns.violinplot( 356 | x="lib", y=metric, data=all_df, hue="lib", palette=LIB_COLORS_GLASBEY 357 | ) 358 | 359 | # Set the alpha value of the facecolor 360 | for violin in ax.collections[:]: # Access only the violin bodies 361 | facecolor = violin.get_facecolor().flatten() # Get the current facecolor 362 | violin.set_facecolor(mcolors.to_rgba(facecolor, alpha=0.4)) # Set new facecolor 363 | 364 | for lib in LIB_INFO_DICT.keys(): 365 | 366 | # Find the position of the violin to add the line to 367 | position = all_df["lib"].unique().tolist().index(lib) 368 | 369 | # Overlay the mean 370 | ax.axhline( 371 | all_df[all_df["lib"] == lib][metric].mean(), 372 | color=FZL_PALETTE["light_gray"], 373 | linestyle="solid", 374 | marker="x", 375 | linewidth=2, 376 | xmin=position / len(LIB_INFO_DICT) + 0.03125, 377 | xmax=(position + 1) / len(LIB_INFO_DICT) - 0.03125, 378 | ) 379 | ax.axhline( 380 | wt_mean[lib], 381 | color=LIB_COLORS_GLASBEY[lib], 382 | linestyle="--", 383 | linewidth=2, 384 | xmin=position / len(LIB_INFO_DICT), 385 | xmax=(position + 1) / len(LIB_INFO_DICT), 386 | ) 387 | 388 | lines = [ 389 | Line2D( 390 | [0], 391 | [0], 392 | color=FZL_PALETTE["light_gray"], 393 | linestyle="none", 394 | lw=2, 395 | marker="x", 396 | ), 397 | Line2D([0], [0], color="black", linestyle="--", lw=2), 398 | Line2D([0], [0], color="black", linestyle="dotted", lw=2), 399 | ] 400 | labels = [ 401 | f"Mean {ZS_METRIC_MAP_LABEL[metric]}\nfrom any active variant", 402 | "From the landscape parent", 403 | ] 404 | ax.axhline( 405 | ZS_METRIC_BASELINE[metric], 406 | color=FZL_PALETTE["light_gray"], 407 | linestyle="dotted", 408 | linewidth=2, 409 | ) 410 | ax.legend(lines, labels, loc="upper left", bbox_to_anchor=(1, 1)) 411 | 412 | ax.set_xlabel("Landscapes") 413 | y_dets = ( 414 | ZS_METRIC_MAP_TITLE[metric] 415 | .replace("\n", " ") 416 | .replace("F", "f") 417 | .replace("A", "a") 418 | ) 419 | ax.set_ylabel(f"Hamming distance {y_dets}") 420 | 421 | if ifsave: 422 | save_svg(fig, figname, fig_dir) -------------------------------------------------------------------------------- /SSMuLA/zs_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for generating zs scores gratefully adapted from EmreGuersoy's work 3 | """ 4 | 5 | # Import packages 6 | import os 7 | import numpy as np 8 | import pandas as pd 9 | from pathlib import Path 10 | from collections import Counter 11 | 12 | # EvCouplings 13 | from evcouplings.couplings import CouplingsModel 14 | from evcouplings.align.alignment import Alignment 15 | 16 | # ESM 17 | import esm 18 | 19 | # from esm.model.msa_transformer import MSATransformer 20 | 21 | # Pytorch 22 | import torch 23 | 24 | # Other 25 | from tqdm import tqdm 26 | from typing import List, Tuple, Optional 27 | 28 | 29 | import warnings 30 | 31 | # from transformers import BertTokenizer, BertModel, BertForMaskedLM, 
pipeline 32 | 33 | 34 | class ZeroShotPrediction: 35 | """ 36 | Zero-Shot Analysis on a given sequence and model. 37 | 38 | Input: 39 | - Model Path obtained from EvCouplings 40 | - Sequence or mutation list stored as csv, 41 | where variant column must be [wtAA]_[pos][newAA] 42 | Output: - Scores for each variant 43 | """ 44 | 45 | def __init__(self, df, wt_seq): 46 | # make sure no stop codons are present 47 | 48 | self.df = df 49 | self.wt_sequence = wt_seq 50 | 51 | def _get_n_df(self, n: int = 1): 52 | """Get n data frame with n mutants""" 53 | return self.df[self.df["combo"].apply(lambda x: len(x) == n)].copy() 54 | 55 | 56 | class ddG(ZeroShotPrediction): 57 | def __init__(self, df=None, wt_seq=None): 58 | # placeholder subclass for ddG-based zero-shot scores 59 | super().__init__(df, wt_seq) 60 | 61 | 62 | class EvMutation(ZeroShotPrediction): 63 | """ 64 | Perform the EvMutation analysis on a given sequence and model. 65 | 66 | Input: 67 | - Model Path obtained from EvCouplings 68 | - Sequence or mutation list stored as csv 69 | Output: - EvMutation (delta) score for each variant 70 | """ 71 | 72 | def __init__(self, df, wt_sequence, model_path): 73 | super().__init__(df, wt_sequence) 74 | self.model_path = model_path 75 | print("Loading model...") 76 | self.model = self.load_model() 77 | print("Model loaded") 78 | self.idx_map = self.model.index_map 79 | 80 | def check_idx(self, pos: list = [1]): 81 | """Check if the position is in the index map""" 82 | all_pos = all([p in self.idx_map for p in pos]) 83 | 84 | if all_pos: 85 | return True 86 | else: 87 | return False 88 | 89 | def load_model(self): 90 | self.model = CouplingsModel(self.model_path) 91 | return self.model 92 | 93 | def _get_hamiltonian(self, mt, wt, pos): 94 | delta_E, _, _ = self.model.delta_hamiltonian([(pos, wt, mt)]) 95 | return delta_E 96 | 97 | def _get_n_hamiltonian(self, combo: list): 98 | """Get the hamiltonian for n mutants""" 99 | delta_E, _, _ = self.model.delta_hamiltonian(combo) 100 | return delta_E 101 | 102 | def upload_dms_coupling(self, dms_coupling): 103 | """Input: - dms_coupling: Path to the dms_coupling file (csv)""" 104 | self.dms_coupling = dms_coupling 105 | 106 | return self.dms_coupling 107 | 108 | def get_single_mutant_scores(self): 109 | """Get the single mutant scores for the dms_coupling file""" 110 | 111 | # Left join the dms_coupling file with the data set 112 | self.single_mutant_scores = pd.merge( 113 | self.dms_coupling, 114 | self.df, 115 | how="left", 116 | left_on=["pos", "wt"], 117 | right_on=["pos", "wt"], 118 | ) 119 | 120 | return self.single_mutant_scores 121 | 122 | def run_evmutation(self, df, _multi=True, _epistatic=False): 123 | """ 124 | Run EvMutation for all variants in the data set 125 | 126 | Input: - df: Data set containing the variants, loops through column = 'combo' 127 | - _multi: If True, all mutations in a combo are scored jointly;
128 | If False, the sum of the probabilities is calculated 129 | Output: - Score for each variant""" 130 | score = np.zeros(len(df)) 131 | wt_sequence = list(self.wt_sequence) 132 | 133 | for i, combo in enumerate(df["combo"]): 134 | # Prepare mut list for EvMutation 135 | mut_list = [] 136 | # Check if positions are in index map 137 | if self.check_idx(df["pos"].iloc[i]): 138 | 139 | if _multi: 140 | 141 | single_mutant_scores = np.zeros((1, len(combo))) 142 | 143 | for j, mt in enumerate(combo): 144 | if mt == "WT" or mt == "NA": 145 | score[i] = np.nan 146 | continue 147 | 148 | else: 149 | pos_wt = int( 150 | df["pos"].iloc[i][j] - 1 151 | ) # Position of the mutation with python indexing 152 | pos_ev = int( 153 | df["pos"].iloc[i][j] 154 | ) # Position of the mutation with python indexing 155 | # Get single scores 156 | single_mutant_scores[0, j] = self._get_hamiltonian( 157 | mt, wt_sequence[pos_wt], pos_ev 158 | ) # TODO: Improve at one point 159 | mut_list.append((pos_ev, wt_sequence[pos_wt], mt)) 160 | 161 | # Run EvMutation 162 | score[i] = self._get_n_hamiltonian(mut_list) 163 | 164 | if _epistatic: 165 | score[i] = score[i] - np.sum(single_mutant_scores) 166 | 167 | # TODO: Get epistatic scores dE = dE_combo - sum(dE_single_mutant) 168 | else: 169 | if combo == "WT" or combo == "NA": 170 | score[i] = np.nan 171 | continue 172 | 173 | else: 174 | mt = combo[0] 175 | pos_wt = int(df["pos"].iloc[i][0] - 1) 176 | pos_ev = int(df["pos"].iloc[i][0]) 177 | mut_list.append((pos_ev, wt_sequence[pos_wt], mt)) 178 | # score[i] = self._get_n_hamiltonian(mut_list) 179 | score[i] = self._get_hamiltonian( 180 | mt, wt_sequence[pos_wt], pos_ev 181 | ) 182 | else: 183 | score[i] = np.nan 184 | continue 185 | 186 | return score 187 | 188 | def _get_n_score(self, n: list = [1]): 189 | """Get any score for each variant in the data set""" 190 | df_n_list = [] 191 | 192 | # Get the n mutant scores 193 | for i in n: 194 | # Filter out n mutants 195 | df_n = self._get_n_df(i) 196 | if df_n.empty: # Check if the DataFrame is empty after filtering 197 | assert "Data set is empty" 198 | continue 199 | 200 | if i == 1: 201 | score_n = self.run_evmutation(df_n, _multi=False) 202 | else: 203 | score_n = self.run_evmutation(df_n, _multi=True) 204 | 205 | # Add column with number of mutations 206 | 207 | df_n.loc[:, "ev_score"] = score_n 208 | df_n.loc[:, "n_mut"] = i 209 | # score_n = pd.DataFrame(score_n, columns=['ev_score']) 210 | 211 | # if fit or fitness 212 | if "fit" in df_n.columns: 213 | fit_col = "fit" 214 | elif "fitness" in df_n.columns: 215 | fit_col = "fitness" 216 | 217 | # Choose only the columns we want 218 | df_n = df_n[["muts", fit_col, "n_mut", "ev_score"]] 219 | df_n_list.append(df_n) 220 | 221 | return pd.concat(df_n_list, axis=0) 222 | 223 | 224 | class ESM(ZeroShotPrediction): 225 | def __init__(self, df, wt_seq, logits_path="", regen_esm=False): 226 | super().__init__(df, wt_seq) 227 | self.df = df 228 | self.wt_sequence = wt_seq 229 | ( 230 | self.model, 231 | self.alphabet, 232 | self.batch_converter, 233 | self.device, 234 | ) = self._infer_model() 235 | self.mask_string, self.cls_string, self.eos_string = ( 236 | self.alphabet.mask_idx, 237 | self.alphabet.cls_idx, 238 | self.alphabet.eos_idx, 239 | ) 240 | self.alphabet_size = len(self.alphabet) 241 | 242 | if logits_path != "" and os.path.exists(logits_path) and not(regen_esm): 243 | print(f"{logits_path} exists and regen_esm = {regen_esm}. 
Loading...") 244 | self.logits = np.load(logits_path) 245 | else: 246 | print(f"Generating {logits_path}...") 247 | self.logits = self._get_logits() 248 | 249 | def _infer_model(self): 250 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 251 | # model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() 252 | # model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1() 253 | model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() 254 | batch_converter = alphabet.get_batch_converter() 255 | model.eval() 256 | model = model.to(device) 257 | print("Using device:", device) 258 | return model, alphabet, batch_converter, device 259 | 260 | def _get_logits(self): 261 | 262 | data_wt = [("WT", self.wt_sequence)] 263 | # Get Batch tokens for WT data 264 | batch_labels_wt, batch_strs_wt, batch_tokens_wt = self.batch_converter(data_wt) 265 | 266 | logits = np.zeros((len(self.wt_sequence), self.alphabet_size)) 267 | 268 | for (i, seq) in enumerate(data_wt[0][1]): 269 | batch_tokens_masked = batch_tokens_wt.clone() 270 | batch_tokens_masked[0, i] = self.alphabet.mask_idx 271 | batch_tokens_masked = batch_tokens_masked.to(self.device) 272 | 273 | with torch.no_grad(): 274 | token_probs = torch.log_softmax( 275 | self.model(batch_tokens_masked)["logits"], dim=-1 276 | ).cpu().numpy() 277 | 278 | logits[i] = token_probs[0, i+1] 279 | 280 | return logits 281 | 282 | def _get_mutant_prob(self, mt, wt, pos): 283 | """Get the probability of the mutant given the wild type sequence at certain position.""" 284 | 285 | wt_idx = self.alphabet.get_idx(wt) 286 | mt_idx = self.alphabet.get_idx(mt) 287 | 288 | return self.logits[pos, mt_idx] - self.logits[pos, wt_idx] 289 | 290 | def run_esm(self, df, _sum=True): 291 | """ 292 | Run ESM model for all variants in the data set 293 | 294 | Input: - logits: Logits of the wild type sequence 295 | - df: Data set containing the variants, loops trough column = 'Combo' and 'Pos' 296 | - _sum: If True, the sum of the probabilities is calculated. 
297 | If False, the mean of the probabilities is calculated 298 | 299 | Output: - Score for each variant 300 | """ 301 | score = np.zeros(len(df)) 302 | wt_sequence = list(self.wt_sequence) 303 | 304 | if _sum: 305 | for i, combo in enumerate(df["combo"]): 306 | s = np.zeros(len(combo)) 307 | for j, mt in enumerate(combo): 308 | if mt == "WT": 309 | score[i] = 0 310 | continue 311 | 312 | elif mt == "NA": 313 | score[i] = np.nan 314 | continue 315 | 316 | else: 317 | pos = ( 318 | int(df["pos"].iloc[i][j]) - 1 319 | ) # Position of the mutation with python indexing 320 | wt = wt_sequence[pos] 321 | s[j] = self._get_mutant_prob(mt=mt, wt=wt, pos=pos) 322 | 323 | score[i] += s.sum() 324 | 325 | else: 326 | for i, combo in enumerate(df["combo"]): 327 | 328 | mt = combo[0] 329 | 330 | if mt == "WT": 331 | score[i] = 0 332 | continue 333 | 334 | elif mt == "NA": 335 | score[i] = np.nan 336 | continue 337 | 338 | else: 339 | pos = int(df["pos"].iloc[i][0] - 1) 340 | wt = wt_sequence[pos] 341 | score[i] = self._get_mutant_prob(mt=mt, wt=wt, pos=pos) 342 | 343 | return score 344 | 345 | def _get_n_df(self, n: int = 1): 346 | """Get n data frame with n mutants""" 347 | 348 | return self.df[self.df["combo"].apply(lambda x: len(x) == n)].copy() 349 | 350 | def _get_n_score(self, n: list = [1]): 351 | """Get any score for each variant in the data set""" 352 | df_n_list = [] 353 | 354 | # Get the n mutant scores 355 | for i in n: 356 | # Filter out n mutants 357 | df_n = self._get_n_df(i) 358 | if df_n.empty: # Check if the DataFrame is empty after filtering 359 | assert "Data set is empty" 360 | continue 361 | 362 | if i == 1: 363 | score_n = self.run_esm(df_n, _sum=False) 364 | else: 365 | score_n = self.run_esm(df_n, _sum=True) 366 | 367 | # Add column with number of mutations 368 | 369 | df_n.loc[:, "esm_score"] = score_n 370 | df_n.loc[:, "n_mut"] = i 371 | # score_n = pd.DataFrame(score_n, columns=['ev_score']) 372 | # if fit or fitness 373 | if "fit" in df_n.columns: 374 | fit_col = "fit" 375 | elif "fitness" in df_n.columns: 376 | fit_col = "fitness" 377 | 378 | # Choose only the columns we want 379 | df_n = df_n[["muts", fit_col, "n_mut", "esm_score"]] 380 | df_n_list.append(df_n) 381 | 382 | return pd.concat(df_n_list, axis=0) 383 | -------------------------------------------------------------------------------- /SSMuLA/plm_finetune.py: -------------------------------------------------------------------------------- 1 | # modify from repo: 2 | # https://github.com/RSchmirler/data-repo_plm-finetune-eval/blob/main/notebooks/finetune/Finetuning_per_protein.ipynb 3 | 4 | # import dependencies 5 | import os.path 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 11 | from torch.utils.data import DataLoader 12 | 13 | import re 14 | import numpy as np 15 | import pandas as pd 16 | from copy import deepcopy 17 | 18 | import transformers, datasets 19 | from transformers.modeling_outputs import SequenceClassifierOutput 20 | from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack 21 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 22 | from transformers import T5EncoderModel, T5Tokenizer 23 | from transformers import EsmModel, AutoTokenizer, AutoModelForSequenceClassification 24 | from transformers import TrainingArguments, Trainer, set_seed 25 | 26 | import peft 27 | from peft import ( 28 | get_peft_config, 29 | PeftModel, 30 | 
PeftConfig, 31 | inject_adapter_in_model, 32 | LoraConfig, 33 | ) 34 | 35 | from evaluate import load 36 | from datasets import Dataset 37 | 38 | from tqdm import tqdm 39 | import random 40 | 41 | from scipy import stats 42 | from sklearn.metrics import accuracy_score 43 | 44 | import matplotlib.pyplot as plt 45 | 46 | from SSMuLA.util import checkNgen_folder, get_file_name 47 | 48 | 49 | # # Set environment variables to run Deepspeed from a notebook 50 | # os.environ["MASTER_ADDR"] = "localhost" 51 | # os.environ["MASTER_PORT"] = "9994" # modify if RuntimeError: Address already in use 52 | # os.environ["RANK"] = "0" 53 | # os.environ["LOCAL_RANK"] = "0" 54 | # os.environ["WORLD_SIZE"] = "1" 55 | 56 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 57 | # Deepspeed config for optimizer CPU offload 58 | 59 | ds_config = { 60 | "fp16": { 61 | "enabled": "auto", 62 | "loss_scale": 0, 63 | "loss_scale_window": 1000, 64 | "initial_scale_power": 16, 65 | "hysteresis": 2, 66 | "min_loss_scale": 1, 67 | }, 68 | "optimizer": { 69 | "type": "AdamW", 70 | "params": { 71 | "lr": "auto", 72 | "betas": "auto", 73 | "eps": "auto", 74 | "weight_decay": "auto", 75 | }, 76 | }, 77 | "scheduler": { 78 | "type": "WarmupLR", 79 | "params": { 80 | "warmup_min_lr": "auto", 81 | "warmup_max_lr": "auto", 82 | "warmup_num_steps": "auto", 83 | }, 84 | }, 85 | "zero_optimization": { 86 | "stage": 2, 87 | "offload_optimizer": {"device": "cpu", "pin_memory": True}, 88 | "allgather_partitions": True, 89 | "allgather_bucket_size": 2e8, 90 | "overlap_comm": True, 91 | "reduce_scatter": True, 92 | "reduce_bucket_size": 2e8, 93 | "contiguous_gradients": True, 94 | }, 95 | "gradient_accumulation_steps": "auto", 96 | "gradient_clipping": "auto", 97 | "steps_per_print": 2000, 98 | "train_batch_size": "auto", 99 | "train_micro_batch_size_per_gpu": "auto", 100 | "wall_clock_breakdown": False, 101 | } 102 | 103 | # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 104 | 105 | RAND_SEED_LIST = deepcopy([random.randint(0, 1000000) for _ in range(50)]) 106 | 107 | # load ESM2 models 108 | def load_esm_model(checkpoint, num_labels, half_precision, full=False, deepspeed=True): 109 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 110 | 111 | if half_precision and deepspeed: 112 | model = AutoModelForSequenceClassification.from_pretrained( 113 | checkpoint, num_labels=num_labels, torch_dtype=torch.float16 114 | ) 115 | else: 116 | model = AutoModelForSequenceClassification.from_pretrained( 117 | checkpoint, num_labels=num_labels 118 | ) 119 | 120 | if full == True: 121 | return model, tokenizer 122 | 123 | peft_config = LoraConfig( 124 | r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"] 125 | ) 126 | 127 | model = inject_adapter_in_model(peft_config, model) 128 | 129 | # Unfreeze the prediction head 130 | for (param_name, param) in model.classifier.named_parameters(): 131 | param.requires_grad = True 132 | 133 | return model, tokenizer 134 | 135 | 136 | # Set random seeds for reproducibility of your trainings run 137 | def set_seeds(s): 138 | torch.manual_seed(s) 139 | np.random.seed(s) 140 | random.seed(s) 141 | set_seed(s) 142 | 143 | 144 | # Dataset creation 145 | def create_dataset(tokenizer, seqs, labels): 146 | tokenized = tokenizer(seqs, max_length=1024, padding=True, truncation=True) 147 | dataset = Dataset.from_dict(tokenized) 148 | dataset = dataset.add_column("labels", labels) 149 | 150 | return dataset 151 | 152 | 153 | def plot_train_val( 154 | history: 
list, # training history 155 | landscape: str, # landscape name 156 | ): 157 | 158 | # Get loss, val_loss, and the computed metric from history 159 | loss = [x["loss"] for x in history if "loss" in x] 160 | val_loss = [x["eval_loss"] for x in history if "eval_loss" in x] 161 | 162 | # Get spearman (for regression) or accuracy value (for classification) 163 | if [x["eval_spearmanr"] for x in history if "eval_spearmanr" in x] != []: 164 | metric = [x["eval_spearmanr"] for x in history if "eval_spearmanr" in x] 165 | else: 166 | metric = [x["eval_accuracy"] for x in history if "eval_accuracy" in x] 167 | 168 | epochs = [x["epoch"] for x in history if "loss" in x] 169 | 170 | # Create a figure with two y-axes 171 | fig, ax1 = plt.subplots(figsize=(10, 5)) 172 | ax2 = ax1.twinx() 173 | 174 | # Plot loss and val_loss on the first y-axis 175 | line1 = ax1.plot(epochs, loss, label="train_loss") 176 | line2 = ax1.plot(epochs, val_loss, label="val_loss") 177 | ax1.set_xlabel("Epoch") 178 | ax1.set_ylabel("Loss") 179 | 180 | # Plot the computed metric on the second y-axis 181 | line3 = ax2.plot(epochs, metric, color="red", label="val_metric") 182 | ax2.set_ylabel("Metric") 183 | ax2.set_ylim([0, 1]) 184 | 185 | # Combine the lines from both y-axes and create a single legend 186 | lines = line1 + line2 + line3 187 | labels = [line.get_label() for line in lines] 188 | ax1.legend(lines, labels, loc="lower left") 189 | 190 | # add title to the figure 191 | plt.title(f"Training curves for {landscape}") 192 | 193 | # return the plot but do not save it 194 | return fig 195 | 196 | 197 | def save_model(model, filepath): 198 | # Saves all parameters that were changed during finetuning 199 | 200 | # Create a dictionary to hold the non-frozen parameters 201 | non_frozen_params = {} 202 | 203 | # Iterate through all the model parameters 204 | for param_name, param in model.named_parameters(): 205 | # If the parameter has requires_grad=True, add it to the dictionary 206 | if param.requires_grad: 207 | non_frozen_params[param_name] = param 208 | 209 | # Save only the finetuned parameters 210 | torch.save(non_frozen_params, filepath) 211 | 212 | 213 | # Main training fuction 214 | def train_per_protein( 215 | checkpoint, # model checkpoint 216 | train_df, # training data 217 | valid_df, # validation data 218 | seed, # random seed 219 | device, # device to use 220 | num_labels=1, # 1 for regression, >1 for classification 221 | # effective training batch size is batch * accum 222 | # we recommend an effective batch size of 8 223 | batch=4, # for training 224 | accum=2, # gradient accumulation 225 | val_batch=16, # batch size for evaluation 226 | epochs=10, # training epochs 227 | lr=3e-4, # recommended learning rate 228 | deepspeed=False, # if gpu is large enough disable deepspeed for training speedup 229 | mixed=True, # enable mixed precision training 230 | full=False, # enable training of the full model (instead of LoRA) 231 | # gpu = 1 #gpu selection (1 for first gpu) 232 | ): 233 | 234 | print("Model used:", checkpoint, "\n") 235 | 236 | # Set all random seeds 237 | set_seeds(seed) 238 | 239 | # load model 240 | model, tokenizer = load_esm_model(checkpoint, num_labels, mixed, full, deepspeed) 241 | 242 | # Preprocess inputs 243 | # Replace uncommon AAs with "X" 244 | train_df["seq"] = train_df["seq"].str.replace( 245 | "|".join(["O", "B", "U", "Z", "J"]), "X", regex=True 246 | ) 247 | valid_df["seq"] = valid_df["seq"].str.replace( 248 | "|".join(["O", "B", "U", "Z", "J"]), "X", regex=True 249 | ) 250 | 251 | # 
Create Datasets 252 | train_set = create_dataset( 253 | tokenizer, list(train_df["seq"]), list(train_df["fitness"]) 254 | ) 255 | valid_set = create_dataset( 256 | tokenizer, list(valid_df["seq"]), list(valid_df["fitness"]) 257 | ) 258 | 259 | # Huggingface Trainer arguments 260 | args = TrainingArguments( 261 | "./", 262 | evaluation_strategy="epoch", 263 | logging_strategy="epoch", 264 | save_strategy="no", 265 | learning_rate=lr, 266 | per_device_train_batch_size=batch, 267 | per_device_eval_batch_size=val_batch, 268 | gradient_accumulation_steps=accum, 269 | num_train_epochs=epochs, 270 | seed=seed, 271 | deepspeed=ds_config if deepspeed else None, 272 | fp16=mixed, 273 | ) 274 | 275 | # Metric definition for validation data 276 | def compute_metrics(eval_pred): 277 | if num_labels > 1: # for classification 278 | metric = load("accuracy") 279 | predictions, labels = eval_pred 280 | predictions = np.argmax(predictions, axis=1) 281 | else: # for regression 282 | metric = load("spearmanr") 283 | predictions, labels = eval_pred 284 | 285 | return metric.compute(predictions=predictions, references=labels) 286 | 287 | # Trainer 288 | trainer = Trainer( 289 | model, 290 | args, 291 | train_dataset=train_set, 292 | eval_dataset=valid_set, 293 | tokenizer=tokenizer, 294 | compute_metrics=compute_metrics, 295 | ) 296 | 297 | # Train model 298 | trainer.train() 299 | 300 | return tokenizer, model, trainer.state.log_history 301 | 302 | 303 | def train_predict_per_protein( 304 | df_csv: str, # csv file with landscape data 305 | rep: int, # replicate number 306 | device: str, # device to use 307 | checkpoint: str = "facebook/esm2_t33_650M_UR50D", # model checkpoint 308 | n_sample: int = 384, # number of train+val 309 | zs_predictor: str = "none", # zero-shot predictor 310 | ft_frac: float = 0.125, # fraction of data for focused sampling 311 | save_dir: str = "results/finetuning", # directory to save the model, plot, and predictions 312 | train_kwargs: dict = {}, # additional training arguments 313 | rerun=False, # if True, the model will be trained again even if it already exists 314 | ): 315 | """Finetune a PLM on a sampled train/val split from one landscape and predict fitness for every variant in that landscape.""" 316 | 317 | landscape = get_file_name(df_csv) 318 | 319 | seed = RAND_SEED_LIST[rep] 320 | 321 | df = pd.read_csv(df_csv) 322 | 323 | if zs_predictor == "none": 324 | df_sorted = df.copy() 325 | n = n_sample 326 | elif zs_predictor not in df.columns: 327 | print(f"{zs_predictor} not in the dataframe") 328 | df_sorted = df.copy() 329 | n = n_sample 330 | else: 331 | n_cutoff = int(len(df) * ft_frac) 332 | df_sorted = ( 333 | df.sort_values(by=zs_predictor, ascending=False).copy()[:n_cutoff].copy() 334 | ) 335 | n = min(n_sample, n_cutoff) 336 | 337 | # output_csv_path 338 | output_csv_path = os.path.join( 339 | checkNgen_folder(os.path.join(save_dir, "predictions", landscape)), 340 | f"{landscape}_{str(n_sample)}_{str(rep)}.csv", 341 | ) 342 | 343 | # check if the output already exists 344 | if os.path.exists(output_csv_path) and not rerun: 345 | print(f"Output already exists for {landscape} with n={n_sample} and rep={rep}") 346 | return df 347 | 348 | # randomly sample rows from the dataframe 349 | train_val_df = ( 350 | df_sorted.sample(n=n, random_state=seed).reset_index(drop=True).copy() 351 | ) 352 | 353 | # split the train_val_df into 90% training and 10% validation sets without overlap 354 | train_df = train_val_df.sample(frac=0.9, random_state=seed) 355 | valid_df = train_val_df.drop(train_df.index).reset_index(drop=True).copy() 356 | train_df = train_df.reset_index(drop=True).copy() 357 | 358 | 359 | tokenizer, model, history
= train_per_protein( 360 | checkpoint, train_df, valid_df, seed=seed, device=device, **train_kwargs 361 | ) 362 | 363 | # save the model 364 | save_model( 365 | model, 366 | os.path.join( 367 | checkNgen_folder(os.path.join(save_dir, "model", landscape)), 368 | f"{landscape}_{str(n_sample)}_{str(rep)}.pth", 369 | ), 370 | ) 371 | 372 | # plot the training history 373 | fig = plot_train_val(history, landscape) 374 | 375 | # save the plot 376 | fig.savefig( 377 | os.path.join( 378 | checkNgen_folder(os.path.join(save_dir, "plot", landscape)), 379 | f"{landscape}_{str(n_sample)}_{str(rep)}.png", 380 | ) 381 | ) 382 | 383 | # create Dataset 384 | test_set = create_dataset(tokenizer, list(df["seq"]), list(df["fitness"])) 385 | # make compatible with torch DataLoader 386 | test_set = test_set.with_format("torch", device=device) 387 | 388 | # Create a dataloader for the test dataset 389 | test_dataloader = DataLoader(test_set, batch_size=16, shuffle=False) 390 | 391 | # Put the model in evaluation mode 392 | model.eval() 393 | 394 | # Make predictions on the test dataset 395 | predictions = [] 396 | with torch.no_grad(): 397 | for batch in tqdm(test_dataloader): 398 | input_ids = batch["input_ids"].to(device) 399 | attention_mask = batch["attention_mask"].to(device) 400 | # add batch results(logits) to predictions 401 | predictions += model.float()( 402 | input_ids, attention_mask=attention_mask 403 | ).logits.tolist() 404 | 405 | # flatten the prediction to one single list 406 | # save predictions as a new column in the test dataframe 407 | df["predictions"] = np.array(predictions).flatten() 408 | 409 | # save the dataframe 410 | df[["seq", "fitness", "predictions"]].to_csv( 411 | output_csv_path, 412 | index=False, 413 | ) 414 | 415 | return df -------------------------------------------------------------------------------- /SSMuLA/get_factor.py: -------------------------------------------------------------------------------- 1 | """A script for saving dataframes as pngs""" 2 | 3 | import os 4 | import pandas as pd 5 | from sklearn.ensemble import RandomForestRegressor 6 | from sklearn.model_selection import train_test_split 7 | from sklearn.metrics import mean_squared_error 8 | from sklearn.linear_model import LinearRegression 9 | 10 | from matplotlib.colors import LinearSegmentedColormap, to_rgb 11 | import matplotlib.pyplot as plt 12 | import seaborn as sns 13 | 14 | # for html to png 15 | from selenium import webdriver 16 | from selenium.webdriver.firefox.service import Service 17 | from selenium.webdriver.firefox.options import Options 18 | 19 | from SSMuLA.de_simulations import DE_TYPES 20 | from SSMuLA.zs_analysis import ZS_OPTS, ZS_COMB_OPTS 21 | from SSMuLA.vis_summary import ZS_METRICS 22 | from SSMuLA.get_corr import LANDSCAPE_ATTRIBUTES, val_list, zs_list 23 | from SSMuLA.vis import FZL_PALETTE, save_plt 24 | from SSMuLA.util import checkNgen_folder 25 | 26 | # Custom colormap for the MSE row, using greens 27 | cmap_mse = LinearSegmentedColormap.from_list( 28 | "mse_cmap_r", ["#FFFFFF", "#9bbb59"][::-1], N=100 29 | ) # dark to light green 30 | 31 | # Create the colormap 32 | custom_cmap = LinearSegmentedColormap.from_list( 33 | "bwg", 34 | [ 35 | FZL_PALETTE["blue"], 36 | "white", 37 | FZL_PALETTE["green"], 38 | ], 39 | N=100, 40 | ) 41 | 42 | geckodriver_path = "/disk2/fli/miniconda3/envs/SSMuLA/bin/geckodriver" 43 | 44 | de_metrics = ["mean_all", "fraction_max"] 45 | 46 | simple_des = { 47 | "recomb_SSM": "Recomb", 48 | "single_step_DE": "Single step", 49 | "top96_SSM": "Top96 
recomb", 50 | } 51 | 52 | simple_de_metric_map = {} 53 | 54 | for de_type in DE_TYPES: 55 | for de_metric in de_metrics: 56 | simple_de_metric_map[f"{de_type}_{de_metric}"] = simple_des[de_type] 57 | 58 | 59 | # Styling the DataFrame 60 | def style_dataframe(df): 61 | # Define a function to apply gradient selectively 62 | def apply_gradient(row): 63 | if row.name == "mse": 64 | # Generate colors for the MSE row based on its values 65 | norm = plt.Normalize(row.min(), row.max()) 66 | rgba_colors = [cmap_mse(norm(value)) for value in row] 67 | return [ 68 | f"background-color: rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {rgba[3]})" 69 | for rgba in rgba_colors 70 | ] 71 | else: 72 | return [""] * len(row) # No style for other rows 73 | 74 | # Apply gradient across all rows 75 | styled_df = df.style.background_gradient(cmap="Blues") 76 | # Apply the custom gradient to the MSE row 77 | styled_df = styled_df.apply(apply_gradient, axis=1) 78 | return styled_df.format("{:.2f}").apply( 79 | lambda x: ["color: black" if x.name == "mse" else "" for _ in x], axis=1 80 | ) 81 | 82 | 83 | def styledf2png( 84 | df, 85 | filename, 86 | sub_dir="results/style_dfs", 87 | absolute_dir="/disk2/fli/SSMuLA/", 88 | width=800, 89 | height=1600, 90 | ): 91 | 92 | html_path = os.path.join(sub_dir, filename + ".html") 93 | checkNgen_folder(html_path) 94 | 95 | # Create a HTML file 96 | html_file = open(html_path, "w") 97 | html_file.write(df.to_html()) 98 | html_file.close() 99 | 100 | options = Options() 101 | options.add_argument("--headless") # Run Firefox in headless mode. 102 | 103 | s = Service(geckodriver_path) 104 | driver = webdriver.Firefox(service=s, options=options) 105 | 106 | driver.get( 107 | f"file://{os.path.join(absolute_dir, html_path)}" 108 | ) # Update the path to your HTML file 109 | 110 | # Set the size of the window to your content (optional) 111 | driver.set_window_size(width, height) # You might need to adjust this 112 | 113 | # Take screenshot 114 | driver.save_screenshot(html_path.replace(".html", ".png")) 115 | driver.quit() 116 | 117 | 118 | def get_lib_stat( 119 | lib_csv: str = "results/corr_all/384/boosting|ridge-top96/merge_all.csv", 120 | sub_dir="results/style_dfs", 121 | absolute_dir="/disk2/fli/SSMuLA/", 122 | n_mut: str = "all", 123 | ): 124 | if n_mut != "all": 125 | corr_det = f"corr_{n_mut}" 126 | if corr_det not in lib_csv: 127 | lib_csv = lib_csv.replace("corr_all", corr_det) 128 | 129 | df = pd.read_csv(lib_csv) 130 | style_df = ( 131 | df[["lib"] + LANDSCAPE_ATTRIBUTES] 132 | .set_index("lib") 133 | .T.style.format("{:.2f}") 134 | .background_gradient(cmap="Blues", axis=1) 135 | ) 136 | 137 | return styledf2png( 138 | style_df, 139 | f"lib_stat_{n_mut}_heatmap_384-boosting|ridge-top96", 140 | sub_dir=sub_dir, 141 | absolute_dir=absolute_dir, 142 | width=1450, 143 | height=950, 144 | ) 145 | 146 | def get_zs_zs_corr( 147 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv", 148 | sub_dir="results/style_dfs", 149 | absolute_dir="/disk2/fli/SSMuLA/", 150 | n_mut: str = "all", 151 | metric: str = "rho" 152 | ): 153 | df = pd.read_csv(corr_csv) 154 | 155 | simple_zs = [zs for zs in zs_list if n_mut in zs and metric in zs] 156 | 157 | style_df = ( 158 | df[ 159 | df["descriptor"].isin( 160 | simple_zs 161 | ) 162 | ][["descriptor"] + simple_zs] 163 | .loc[df['descriptor'].isin(simple_zs)] 164 | .rename(columns={"descriptor": "ZS predictions"}) 165 | .set_index("ZS predictions") 166 | .rename(index=lambda x: x.replace("double", 
"hd2")) 167 | .rename(columns=lambda x: x.replace("double", "hd2")) 168 | .style.format("{:.2f}") 169 | .background_gradient(cmap=custom_cmap, vmin=0, vmax=1) 170 | ) 171 | 172 | return styledf2png( 173 | style_df, 174 | f"zs_{n_mut}_{metric}_heatmap_384-boosting|ridge-top96_zs", 175 | sub_dir=sub_dir, 176 | absolute_dir=absolute_dir, 177 | width=1250, 178 | height=450, 179 | ) 180 | 181 | def get_zs_ft_corr( 182 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv", 183 | sub_dir="results/style_dfs", 184 | absolute_dir="/disk2/fli/SSMuLA/", 185 | n_mut: str = "all", 186 | metric: str = "rho" 187 | ): 188 | """ 189 | Correlate ZS with ftMLDE 190 | """ 191 | df = pd.read_csv(corr_csv) 192 | 193 | simple_zs = [zs for zs in zs_list if n_mut in zs and metric in zs] 194 | # simple_ft = [] 195 | 196 | 197 | style_df = ( 198 | df[ 199 | df["descriptor"].isin( 200 | simple_zs 201 | ) 202 | ][["descriptor"] + simple_zs] 203 | .loc[df['descriptor'].isin(simple_zs)] 204 | .rename(columns={"descriptor": "ZS predictions"}) 205 | .set_index("ZS predictions") 206 | .rename(index=lambda x: x.replace("double", "hd2")) 207 | .rename(columns=lambda x: x.replace("double", "hd2")) 208 | .style.format("{:.2f}") 209 | .background_gradient(cmap=custom_cmap, vmin=0, vmax=1) 210 | ) 211 | 212 | return styledf2png( 213 | style_df, 214 | f"zs_{n_mut}_{metric}_heatmap_384-boosting|ridge-top96_zs", 215 | sub_dir=sub_dir, 216 | absolute_dir=absolute_dir, 217 | width=1250, 218 | height=450, 219 | ) 220 | 221 | 222 | def get_zs_corr_ls( 223 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv", 224 | sub_dir="results/style_dfs", 225 | absolute_dir="/disk2/fli/SSMuLA/", 226 | n_mut: str = "all", 227 | metric: str = "rho" 228 | ): 229 | 230 | df = pd.read_csv(corr_csv) 231 | 232 | style_df = ( 233 | df[ 234 | df["descriptor"].isin( 235 | [zs for zs in zs_list if n_mut in zs and metric in zs] 236 | ) 237 | ][["descriptor"] + LANDSCAPE_ATTRIBUTES] 238 | .rename(columns={"descriptor": "Landscape attributes", **simple_de_metric_map}) 239 | .iloc[0:33] 240 | .set_index("Landscape attributes") 241 | .rename(index=lambda x: x.replace("double", "hd2")).T 242 | .style.format("{:.2f}") 243 | .background_gradient(cmap=custom_cmap, vmin=-1, vmax=1) 244 | ) 245 | 246 | return styledf2png( 247 | style_df, 248 | f"zs_{n_mut}_{metric}_heatmap_384-boosting|ridge-top96_landscape_attributes", 249 | sub_dir=sub_dir, 250 | absolute_dir=absolute_dir, 251 | width=1450, 252 | height=950, 253 | ) 254 | 255 | 256 | def get_zs_corr( 257 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv", 258 | sub_dir="results/style_dfs", 259 | absolute_dir="/disk2/fli/SSMuLA/", 260 | deorls: str = "de", 261 | de_calc: str = "mean_all", # or fraction_max 262 | n_mut: str = "all", 263 | ): 264 | 265 | df = pd.read_csv(corr_csv) 266 | 267 | if deorls == "de": 268 | comp_list = [f"{de_type}_{de_calc}" for de_type in DE_TYPES] 269 | dets = de_calc 270 | width=625 271 | else: 272 | comp_list = LANDSCAPE_ATTRIBUTES 273 | dets = "landscape_attributes" 274 | width=1800 275 | 276 | style_df = ( 277 | df[ 278 | df["descriptor"].isin( 279 | [zs for zs in zs_list if n_mut in zs and "ndcg" not in zs] 280 | ) 281 | ][["descriptor"] + comp_list] 282 | .rename(columns={"descriptor": "Landscape attributes", **simple_de_metric_map}) 283 | .iloc[0:33] 284 | .set_index("Landscape attributes") 285 | .rename(index=lambda x: x.replace("double", "hd2")) 286 | .style.format("{:.2f}") 287 | 
.background_gradient(cmap=custom_cmap, vmin=-1, vmax=1) 288 | ) 289 | 290 | return styledf2png( 291 | style_df, 292 | f"zs_{n_mut}_heatmap_384-boosting|ridge-top96_{dets}", 293 | sub_dir=sub_dir, 294 | absolute_dir=absolute_dir, 295 | width=width, 296 | height=550, 297 | ) 298 | 299 | 300 | def get_corr_heatmap( 301 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv", 302 | sub_dir="results/style_dfs", 303 | absolute_dir="/disk2/fli/SSMuLA/", 304 | de_calc: str = "mean_all", # or fraction_max 305 | ): 306 | 307 | de_list = [f"{de_type}_{de_calc}" for de_type in DE_TYPES] 308 | 309 | df = pd.read_csv(corr_csv) 310 | style_df = ( 311 | df[["descriptor"] + de_list] 312 | .rename(columns={"descriptor": "Landscape attributes", **simple_de_metric_map}) 313 | .iloc[0:33] 314 | .set_index("Landscape attributes") 315 | .style.format("{:.2f}") 316 | .background_gradient(cmap=custom_cmap, vmin=-1, vmax=1) 317 | ) 318 | 319 | return styledf2png( 320 | style_df, 321 | f"corr_heatmap_384-boosting|ridge-top96_{de_calc}", 322 | sub_dir=sub_dir, 323 | absolute_dir=absolute_dir, 324 | width=720, 325 | height=975, 326 | ) 327 | 328 | 329 | def get_importance_heatmap( 330 | lib_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/merge_all.csv", 331 | sub_dir="results/style_dfs", 332 | absolute_dir="/disk2/fli/SSMuLA/", 333 | de_calc: str = "mean_all", # or fraction_max 334 | ): 335 | 336 | de_list = [f"{de_type}_{de_calc}" for de_type in DE_TYPES] 337 | 338 | df = pd.read_csv(lib_csv) 339 | 340 | # Load your dataset 341 | # data = pd.read_csv('path_to_your_data.csv') 342 | 343 | # Select features and targets 344 | features = df[LANDSCAPE_ATTRIBUTES] 345 | targets = df[val_list] 346 | 347 | lr_df_list = [] 348 | 349 | # Splitting the dataset for each target and fitting a model 350 | for target in targets.columns: 351 | lr_model = LinearRegression() 352 | lr_model.fit(features, df[target]) 353 | 354 | # Feature importance 355 | feature_importances = pd.DataFrame( 356 | lr_model.coef_, index=LANDSCAPE_ATTRIBUTES, columns=[target] 357 | ) 358 | 359 | lr_df_list.append(feature_importances) 360 | lr_df = pd.concat(lr_df_list, axis=1) 361 | lr_df.index.names = ["Landscape attributes"] 362 | 363 | style_df = ( 364 | lr_df[de_list] 365 | .rename(columns=simple_de_metric_map) 366 | .style.format("{:.2f}") 367 | .background_gradient(cmap=custom_cmap) 368 | ) 369 | 370 | return styledf2png( 371 | style_df, 372 | f"importance_heatmap_384-boosting|ridge-top96_{de_calc}", 373 | sub_dir=sub_dir, 374 | absolute_dir=absolute_dir, 375 | width=720, 376 | height=975, 377 | ) 378 | 379 | 380 | def plot_all_factor( 381 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv", 382 | sub_dir="results/style_dfs_actcut-1", 383 | absolute_dir="/disk2/fli/SSMuLA/", 384 | ): 385 | 386 | for n_mut in ["all", "double"]: 387 | 388 | for metric in ["rho", "rocauc"]: 389 | get_zs_zs_corr( 390 | corr_csv=corr_csv, 391 | n_mut=n_mut, 392 | metric=metric, 393 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_zs_corr")), 394 | absolute_dir=absolute_dir, 395 | ) 396 | 397 | get_zs_corr_ls( 398 | corr_csv=corr_csv, 399 | n_mut=n_mut, 400 | metric=metric, 401 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_corr_ls")), 402 | absolute_dir=absolute_dir, 403 | ) 404 | 405 | 406 | get_zs_ft_corr( 407 | corr_csv=corr_csv, 408 | n_mut=n_mut, 409 | metric=metric, 410 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_ft_corr")), 411 | absolute_dir=absolute_dir, 412 | ) 413 | 414 | for 
de_calc in ["mean_all", "fraction_max"]: 415 | for deorls in ["de", "ls"]: 416 | get_zs_corr( 417 | corr_csv=corr_csv, 418 | de_calc=de_calc, 419 | n_mut=n_mut, 420 | deorls=deorls, 421 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_corr")), 422 | absolute_dir=absolute_dir, 423 | ) 424 | 425 | get_corr_heatmap( 426 | corr_csv=corr_csv, 427 | de_calc=de_calc, 428 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "corr_heatmap")), 429 | absolute_dir=absolute_dir, 430 | ) 431 | 432 | get_importance_heatmap( 433 | lib_csv=corr_csv.replace("corr.csv", "merge_all.csv"), 434 | de_calc=de_calc, 435 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "importance_heatmap")), 436 | absolute_dir=absolute_dir, 437 | ) 438 | 439 | get_lib_stat( 440 | lib_csv=corr_csv.replace("corr.csv", "merge_all.csv"), 441 | n_mut=n_mut, 442 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "lib_stat")), 443 | absolute_dir=absolute_dir, 444 | ) 445 | -------------------------------------------------------------------------------- /envs/frozen/SSMuLA.yml: -------------------------------------------------------------------------------- 1 | name: SSMuLA 2 | channels: 3 | - pytorch 4 | - pyg 5 | - anaconda 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_gnu 12 | - _py-xgboost-mutex=2.0=gpu_0 13 | - absl-py=2.1.0=pyhd8ed1ab_1 14 | - aiohappyeyeballs=2.4.6=pyhd8ed1ab_0 15 | - aiohttp=3.11.13=py311h2dc5d0c_0 16 | - aiosignal=1.3.2=pyhd8ed1ab_0 17 | - alsa-lib=1.2.13=hb9d3cd8_0 18 | - anyio=4.8.0=pyhd8ed1ab_0 19 | - argon2-cffi=23.1.0=pyhd8ed1ab_1 20 | - argon2-cffi-bindings=21.2.0=py311h9ecbd09_5 21 | - arrow=1.3.0=pyhd8ed1ab_1 22 | - asttokens=3.0.0=pyhd8ed1ab_1 23 | - astunparse=1.6.3=pyhd8ed1ab_3 24 | - async-lru=2.0.4=pyhd8ed1ab_1 25 | - attrs=25.1.0=pyh71513ae_0 26 | - babel=2.17.0=pyhd8ed1ab_0 27 | - beautifulsoup4=4.13.3=pyha770c72_0 28 | - biopandas=0.5.1=pyhd8ed1ab_1 29 | - biopython=1.85=py311h9ecbd09_1 30 | - biotite=1.1.0=py311hfdbb021_1 31 | - biotraj=1.2.2=py311hfdbb021_0 32 | - blas=1.0=mkl 33 | - bleach=6.2.0=pyh29332c3_4 34 | - bleach-with-css=6.2.0=h82add2a_4 35 | - blosc=1.21.6=he440d0b_1 36 | - bokeh=3.6.3=pyhd8ed1ab_0 37 | - brokenaxes=0.4.2=pyhd8ed1ab_0 38 | - brotli=1.1.0=hb9d3cd8_2 39 | - brotli-bin=1.1.0=hb9d3cd8_2 40 | - brotli-python=1.1.0=py311hfdbb021_2 41 | - bzip2=1.0.8=h4bc722e_7 42 | - c-ares=1.34.4=hb9d3cd8_0 43 | - c-blosc2=2.15.2=h3122c55_1 44 | - ca-certificates=2025.1.31=hbcca054_0 45 | - cached-property=1.5.2=hd8ed1ab_1 46 | - cached_property=1.5.2=pyha770c72_1 47 | - cairo=1.18.2=h3394656_1 48 | - cairocffi=1.7.1=pyhd8ed1ab_1 49 | - cairosvg=2.7.1=pyhd8ed1ab_1 50 | - certifi=2025.1.31=pyhd8ed1ab_0 51 | - cffi=1.17.1=py311hf29c0ef_0 52 | - charset-normalizer=3.4.1=pyhd8ed1ab_0 53 | - colorama=0.4.6=pyhd8ed1ab_1 54 | - colorcet=3.1.0=pyhd8ed1ab_1 55 | - comm=0.2.2=pyhd8ed1ab_1 56 | - contourpy=1.3.1=py311hd18a35c_0 57 | - cpython=3.11.11=py311hd8ed1ab_1 58 | - cssselect2=0.2.1=pyh9f0ad1d_1 59 | - cuda-cudart=12.1.105=0 60 | - cuda-cupti=12.1.105=0 61 | - cuda-libraries=12.1.0=0 62 | - cuda-nvrtc=12.1.105=0 63 | - cuda-nvtx=12.1.105=0 64 | - cuda-opencl=12.4.127=0 65 | - cuda-runtime=12.1.0=0 66 | - cuda-version=11.8=h70ddcb2_3 67 | - cudatoolkit=11.8.0=h4ba93d1_13 68 | - cycler=0.12.1=pyhd8ed1ab_1 69 | - cyrus-sasl=2.1.27=h54b06d7_7 70 | - datashader=0.17.0=pyhd8ed1ab_0 71 | - dbus=1.13.6=h5008d03_3 72 | - debugpy=1.8.12=py311hfdbb021_0 73 | - decorator=5.2.1=pyhd8ed1ab_0 74 | - defusedxml=0.7.1=pyhd8ed1ab_0 
75 | - double-conversion=3.3.1=h5888daf_0 76 | - et_xmlfile=2.0.0=pyhd8ed1ab_1 77 | - exceptiongroup=1.2.2=pyhd8ed1ab_1 78 | - executing=2.1.0=pyhd8ed1ab_1 79 | - expat=2.6.4=h5888daf_0 80 | - filelock=3.17.0=pyhd8ed1ab_0 81 | - firefox=134.0=h5888daf_0 82 | - flake8=7.1.2=pyhd8ed1ab_0 83 | - flatbuffers=24.12.23=h8f4948b_0 84 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 85 | - font-ttf-inconsolata=3.000=h77eed37_0 86 | - font-ttf-source-code-pro=2.038=h77eed37_0 87 | - font-ttf-ubuntu=0.83=h77eed37_3 88 | - fontconfig=2.15.0=h7e30c49_1 89 | - fonts-conda-ecosystem=1=0 90 | - fonts-conda-forge=1=0 91 | - fonttools=4.56.0=py311h2dc5d0c_0 92 | - fqdn=1.5.1=pyhd8ed1ab_1 93 | - freetype=2.12.1=h267a509_2 94 | - frozenlist=1.5.0=py311h2dc5d0c_1 95 | - fsspec=2025.2.0=pyhd8ed1ab_0 96 | - gast=0.6.0=pyhd8ed1ab_0 97 | - geckodriver=0.36.0=hb7e49e0_0 98 | - giflib=5.2.2=hd590300_0 99 | - gmp=6.3.0=hac33072_2 100 | - gmpy2=2.1.5=py311h0f6cedb_3 101 | - google-pasta=0.2.0=pyhd8ed1ab_2 102 | - graphite2=1.3.13=h59595ed_1003 103 | - grpcio=1.67.1=py311h9789449_1 104 | - h11=0.14.0=pyhd8ed1ab_1 105 | - h2=4.2.0=pyhd8ed1ab_0 106 | - h5py=3.12.1=nompi_py311h5ed33ec_103 107 | - harfbuzz=10.3.0=h76408a6_0 108 | - hdf5=1.14.4=nompi_h2d575fe_105 109 | - holoviews=1.20.1=pyhd8ed1ab_0 110 | - hpack=4.1.0=pyhd8ed1ab_0 111 | - httpcore=1.0.7=pyh29332c3_1 112 | - httpx=0.28.1=pyhd8ed1ab_0 113 | - hvplot=0.11.2=pyhd8ed1ab_0 114 | - hyperframe=6.1.0=pyhd8ed1ab_0 115 | - icu=75.1=he02047a_0 116 | - idna=3.10=pyhd8ed1ab_1 117 | - importlib-metadata=8.6.1=pyha770c72_0 118 | - importlib_resources=6.5.2=pyhd8ed1ab_0 119 | - intel-openmp=2022.0.1=h06a4308_3633 120 | - ipykernel=6.29.5=pyh3099207_0 121 | - ipympl=0.9.6=pyhd8ed1ab_0 122 | - ipython=8.32.0=pyh907856f_0 123 | - ipython_genutils=0.2.0=pyhd8ed1ab_2 124 | - ipywidgets=8.1.5=pyhd8ed1ab_1 125 | - isoduration=20.11.0=pyhd8ed1ab_1 126 | - jedi=0.19.2=pyhd8ed1ab_1 127 | - jinja2=3.1.5=pyhd8ed1ab_0 128 | - joblib=1.4.2=pyhd8ed1ab_1 129 | - json5=0.10.0=pyhd8ed1ab_1 130 | - jsonpointer=3.0.0=py311h38be061_1 131 | - jsonschema=4.23.0=pyhd8ed1ab_1 132 | - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1 133 | - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1 134 | - jupyter-lsp=2.2.5=pyhd8ed1ab_1 135 | - jupyter_bokeh=4.0.5=pyhd8ed1ab_1 136 | - jupyter_client=8.6.3=pyhd8ed1ab_1 137 | - jupyter_core=5.7.2=pyh31011fe_1 138 | - jupyter_events=0.12.0=pyh29332c3_0 139 | - jupyter_server=2.15.0=pyhd8ed1ab_0 140 | - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1 141 | - jupyterlab=4.3.5=pyhd8ed1ab_0 142 | - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2 143 | - jupyterlab_server=2.27.3=pyhd8ed1ab_1 144 | - jupyterlab_widgets=3.0.13=pyhd8ed1ab_1 145 | - keras=3.8.0=pyh753f3f9_0 146 | - kernel-headers_linux-64=3.10.0=he073ed8_18 147 | - keyutils=1.6.1=h166bdaf_0 148 | - kiwisolver=1.4.7=py311hd18a35c_0 149 | - krb5=1.21.3=h659f571_0 150 | - lcms2=2.17=h717163a_0 151 | - ld_impl_linux-64=2.43=h712a8e2_4 152 | - lerc=4.0.0=h27087fc_0 153 | - libabseil=20240722.0=cxx17_hbbce691_4 154 | - libaec=1.1.3=h59595ed_0 155 | - libblas=3.9.0=16_linux64_mkl 156 | - libbrotlicommon=1.1.0=hb9d3cd8_2 157 | - libbrotlidec=1.1.0=hb9d3cd8_2 158 | - libbrotlienc=1.1.0=hb9d3cd8_2 159 | - libcblas=3.9.0=16_linux64_mkl 160 | - libclang-cpp19.1=19.1.7=default_hb5137d0_1 161 | - libclang13=19.1.7=default_h9c6a7e4_1 162 | - libcublas=12.1.0.26=0 163 | - libcufft=11.0.2.4=0 164 | - libcufile=1.9.1.3=0 165 | - libcups=2.3.3=h4637d8d_4 166 | - libcurand=10.3.5.147=0 167 | - libcurl=8.12.1=h332b0f4_0 168 | - 
libcusolver=11.4.4.55=0 169 | - libcusparse=12.0.2.55=0 170 | - libdeflate=1.23=h4ddbbb0_0 171 | - libdrm=2.4.124=hb9d3cd8_0 172 | - libedit=3.1.20250104=pl5321h7949ede_0 173 | - libegl=1.7.0=ha4b6fd6_2 174 | - libev=4.33=hd590300_2 175 | - libexpat=2.6.4=h5888daf_0 176 | - libffi=3.4.6=h2dba641_0 177 | - libgcc=14.2.0=h767d61c_2 178 | - libgcc-ng=14.2.0=h69a702a_2 179 | - libgfortran=14.2.0=h69a702a_2 180 | - libgfortran5=14.2.0=hf1ad2bd_2 181 | - libgl=1.7.0=ha4b6fd6_2 182 | - libglib=2.82.2=h2ff4ddf_1 183 | - libglvnd=1.7.0=ha4b6fd6_2 184 | - libglx=1.7.0=ha4b6fd6_2 185 | - libgomp=14.2.0=h767d61c_2 186 | - libgrpc=1.67.1=h25350d4_1 187 | - libiconv=1.18=h4ce23a2_1 188 | - libjpeg-turbo=3.0.0=hd590300_1 189 | - liblapack=3.9.0=16_linux64_mkl 190 | - libllvm15=15.0.7=ha7bfdaf_5 191 | - libllvm19=19.1.7=ha7bfdaf_1 192 | - liblzma=5.6.4=hb9d3cd8_0 193 | - libnghttp2=1.64.0=h161d5f1_0 194 | - libnpp=12.0.2.50=0 195 | - libnsl=2.0.1=hd590300_0 196 | - libntlm=1.8=hb9d3cd8_0 197 | - libnvjitlink=12.1.105=0 198 | - libnvjpeg=12.1.1.14=0 199 | - libopengl=1.7.0=ha4b6fd6_2 200 | - libpciaccess=0.18=hd590300_0 201 | - libpng=1.6.47=h943b412_0 202 | - libpq=17.4=h27ae623_0 203 | - libprotobuf=5.28.3=h6128344_1 204 | - libre2-11=2024.07.02=hbbce691_2 205 | - libsodium=1.0.20=h4ab18f5_0 206 | - libsqlite=3.49.1=hee588c1_1 207 | - libssh2=1.11.1=hf672d98_0 208 | - libstdcxx=14.2.0=h8f9b012_2 209 | - libstdcxx-ng=14.2.0=h4852527_2 210 | - libtiff=4.7.0=hd9ff511_3 211 | - libuuid=2.38.1=h0b41bf4_0 212 | - libuv=1.50.0=hb9d3cd8_0 213 | - libwebp-base=1.5.0=h851e524_0 214 | - libxcb=1.17.0=h8a09558_0 215 | - libxcrypt=4.4.36=hd590300_1 216 | - libxgboost=2.1.4=cuda118_h09a87be_0 217 | - libxkbcommon=1.8.0=hc4a0caf_0 218 | - libxml2=2.13.6=h8d12d68_0 219 | - libxslt=1.1.39=h76b75d6_0 220 | - libzlib=1.3.1=hb9d3cd8_2 221 | - linkify-it-py=2.0.3=pyhd8ed1ab_1 222 | - llvm-openmp=15.0.7=h0cdce71_0 223 | - llvmlite=0.44.0=py311h9c9ff8c_0 224 | - looseversion=1.3.0=pyhd8ed1ab_0 225 | - lz4-c=1.10.0=h5888daf_1 226 | - markdown=3.6=pyhd8ed1ab_0 227 | - markdown-it-py=3.0.0=pyhd8ed1ab_1 228 | - markupsafe=3.0.2=py311h2dc5d0c_1 229 | - matplotlib=3.10.0=py311h38be061_0 230 | - matplotlib-base=3.10.0=py311h2b939e6_0 231 | - matplotlib-inline=0.1.7=pyhd8ed1ab_1 232 | - mccabe=0.7.0=pyhd8ed1ab_1 233 | - mdit-py-plugins=0.4.2=pyhd8ed1ab_1 234 | - mdurl=0.1.2=pyhd8ed1ab_1 235 | - mistune=3.1.2=pyhd8ed1ab_0 236 | - mkl=2022.1.0=hc2b9512_224 237 | - ml_dtypes=0.4.0=py311h7db5c69_2 238 | - mmtf-python=1.1.3=pyhd8ed1ab_0 239 | - mpc=1.3.1=h24ddda3_1 240 | - mpfr=4.2.1=h90cbb55_3 241 | - mpmath=1.3.0=pyhd8ed1ab_1 242 | - msgpack-python=1.1.0=py311hd18a35c_0 243 | - multidict=6.1.0=py311h2dc5d0c_2 244 | - multipledispatch=0.6.0=pyhd8ed1ab_1 245 | - munkres=1.1.4=pyh9f0ad1d_0 246 | - mypy=1.15.0=py311h9ecbd09_0 247 | - mypy_extensions=1.0.0=pyha770c72_1 248 | - mysql-common=9.0.1=h266115a_4 249 | - mysql-libs=9.0.1=he0572af_4 250 | - namex=0.0.8=pyhd8ed1ab_1 251 | - nbclient=0.10.2=pyhd8ed1ab_0 252 | - nbconvert-core=7.16.6=pyh29332c3_0 253 | - nbformat=5.10.4=pyhd8ed1ab_1 254 | - nccl=2.25.1.1=h03a54cd_0 255 | - ncurses=6.5=h2d0b736_3 256 | - nest-asyncio=1.6.0=pyhd8ed1ab_1 257 | - networkx=3.4.2=pyh267e887_2 258 | - nodejs=22.12.0=hf235a45_0 259 | - nomkl=2.0=0 260 | - notebook-shim=0.2.4=pyhd8ed1ab_1 261 | - numba=0.61.0=py311h4e1c48f_1 262 | - numexpr=2.10.2=py311h38b10cd_100 263 | - numpy=2.1.3=py311h71ddf71_0 264 | - ocl-icd=2.3.2=hb9d3cd8_2 265 | - ocl-icd-system=1.0.0=1 266 | - opencl-headers=2024.10.24=h5888daf_0 267 
| - openjpeg=2.5.3=h5fbd93e_0 268 | - openldap=2.6.9=he970967_0 269 | - openmm=8.2.0=py311he5bdeac_2 270 | - openpyxl=3.1.5=py311h50c5138_1 271 | - openssl=3.4.1=h7b32b05_0 272 | - opt_einsum=3.4.0=pyhd8ed1ab_1 273 | - optree=0.14.0=py311hd18a35c_1 274 | - overrides=7.7.0=pyhd8ed1ab_1 275 | - packaging=24.2=pyhd8ed1ab_2 276 | - pandas=2.2.3=py311h7db5c69_1 277 | - pandocfilters=1.5.0=pyhd8ed1ab_0 278 | - panel=1.6.1=pyhd8ed1ab_0 279 | - param=2.2.0=pyhd8ed1ab_0 280 | - parso=0.8.4=pyhd8ed1ab_1 281 | - patsy=1.0.1=pyhd8ed1ab_1 282 | - pcre2=10.44=hba22ea6_2 283 | - pdbfixer=1.11=pyhd8ed1ab_0 284 | - pexpect=4.9.0=pyhd8ed1ab_1 285 | - pickleshare=0.7.5=pyhd8ed1ab_1004 286 | - pillow=11.1.0=py311h1322bbf_0 287 | - pip=25.0.1=pyh8b19718_0 288 | - pixman=0.44.2=h29eaf8c_0 289 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2 290 | - platformdirs=4.3.6=pyhd8ed1ab_1 291 | - prometheus_client=0.21.1=pyhd8ed1ab_0 292 | - prompt-toolkit=3.0.50=pyha770c72_0 293 | - propcache=0.2.1=py311h2dc5d0c_1 294 | - protobuf=5.28.3=py311hfdbb021_0 295 | - psutil=6.1.1=py311h9ecbd09_0 296 | - pthread-stubs=0.4=hb9d3cd8_1002 297 | - ptyprocess=0.7.0=pyhd8ed1ab_1 298 | - pure_eval=0.2.3=pyhd8ed1ab_1 299 | - py-cpuinfo=9.0.0=pyhd8ed1ab_1 300 | - py-xgboost=2.1.4=cuda118_pyh1b5bf1a_0 301 | - pycodestyle=2.12.1=pyhd8ed1ab_1 302 | - pycparser=2.22=pyh29332c3_1 303 | - pyct=0.5.0=pyhd8ed1ab_1 304 | - pyflakes=3.2.0=pyhd8ed1ab_1 305 | - pyg=2.5.2=py311_torch_2.1.0_cu121 306 | - pygments=2.19.1=pyhd8ed1ab_0 307 | - pyparsing=3.2.1=pyhd8ed1ab_0 308 | - pyside6=6.8.2=py311h9053184_0 309 | - pysocks=1.7.1=pyha55dd90_7 310 | - pytables=3.10.2=py311h3ebe2b2_0 311 | - python=3.11.11=h9e4cc4f_1_cpython 312 | - python-dateutil=2.9.0.post0=pyhff2d567_1 313 | - python-fastjsonschema=2.21.1=pyhd8ed1ab_0 314 | - python-flatbuffers=25.2.10=pyhbc23db3_0 315 | - python-json-logger=2.0.7=pyhd8ed1ab_0 316 | - python-tzdata=2025.1=pyhd8ed1ab_0 317 | - python_abi=3.11=5_cp311 318 | - pytorch=2.1.1=py3.11_cuda12.1_cudnn8.9.2_0 319 | - pytorch-cuda=12.1=ha16c6d3_6 320 | - pytorch-mutex=1.0=cuda 321 | - pytz=2024.1=pyhd8ed1ab_0 322 | - pyviz_comms=3.0.4=pyhd8ed1ab_1 323 | - pyyaml=6.0.2=py311h2dc5d0c_2 324 | - pyzmq=26.2.1=py311h7deb3e3_0 325 | - qhull=2020.2=h434a139_5 326 | - qt6-main=6.8.2=h588cce1_0 327 | - re2=2024.07.02=h9925aae_2 328 | - readline=8.2=h8c095d6_2 329 | - referencing=0.36.2=pyh29332c3_0 330 | - requests=2.32.3=pyhd8ed1ab_1 331 | - rfc3339-validator=0.1.4=pyhd8ed1ab_1 332 | - rfc3986-validator=0.1.1=pyh9f0ad1d_0 333 | - rich=13.9.4=pyhd8ed1ab_1 334 | - rpds-py=0.23.1=py311h687327b_0 335 | - scikit-learn=1.6.1=py311h57cc02b_0 336 | - scipy=1.15.2=py311h8f841c2_0 337 | - seaborn=0.13.2=hd8ed1ab_3 338 | - seaborn-base=0.13.2=pyhd8ed1ab_3 339 | - send2trash=1.8.3=pyh0d859eb_1 340 | - setuptools=75.8.0=pyhff2d567_0 341 | - six=1.17.0=pyhd8ed1ab_0 342 | - snappy=1.2.1=h8bd8927_1 343 | - sniffio=1.3.1=pyhd8ed1ab_1 344 | - soupsieve=2.5=pyhd8ed1ab_1 345 | - stack_data=0.6.3=pyhd8ed1ab_1 346 | - statsmodels=0.14.4=py311h9f3472d_0 347 | - sympy=1.13.3=pyh2585a3b_105 348 | - sysroot_linux-64=2.17=h0157908_18 349 | - tensorboard=2.18.0=pyhd8ed1ab_1 350 | - tensorboard-data-server=0.7.0=py311hafd3f86_2 351 | - tensorflow=2.18.0=cpu_py311h6ac8430_0 352 | - tensorflow-base=2.18.0=cpu_py311h50f7602_0 353 | - tensorflow-estimator=2.18.0=cpu_py311h1adda88_0 354 | - termcolor=2.5.0=pyhd8ed1ab_1 355 | - terminado=0.18.1=pyh0d859eb_0 356 | - threadpoolctl=3.5.0=pyhc1e730c_0 357 | - tinycss2=1.4.0=pyhd8ed1ab_0 358 | - tk=8.6.13=noxft_h4845f30_101 359 
| - toolz=1.0.0=pyhd8ed1ab_1 360 | - torchtriton=2.1.0=py311 361 | - tornado=6.4.2=py311h9ecbd09_0 362 | - tqdm=4.67.1=pyhd8ed1ab_1 363 | - traitlets=5.14.3=pyhd8ed1ab_1 364 | - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0 365 | - typing-extensions=4.12.2=hd8ed1ab_1 366 | - typing_extensions=4.12.2=pyha770c72_1 367 | - typing_utils=0.1.0=pyhd8ed1ab_1 368 | - tzdata=2025a=h78e105d_0 369 | - uc-micro-py=1.0.3=pyhd8ed1ab_1 370 | - unicodedata2=16.0.0=py311h9ecbd09_0 371 | - uri-template=1.3.0=pyhd8ed1ab_1 372 | - urllib3=2.3.0=pyhd8ed1ab_0 373 | - versioneer=0.29=pyhd8ed1ab_0 374 | - wayland=1.23.1=h3e06ad9_0 375 | - wcwidth=0.2.13=pyhd8ed1ab_1 376 | - webcolors=24.11.1=pyhd8ed1ab_0 377 | - webencodings=0.5.1=pyhd8ed1ab_3 378 | - websocket-client=1.8.0=pyhd8ed1ab_1 379 | - werkzeug=3.1.3=pyhd8ed1ab_1 380 | - wheel=0.45.1=pyhd8ed1ab_1 381 | - widgetsnbextension=4.0.13=pyhd8ed1ab_1 382 | - wrapt=1.17.2=py311h9ecbd09_0 383 | - xarray=2025.1.2=pyhd8ed1ab_0 384 | - xcb-util=0.4.1=hb711507_2 385 | - xcb-util-cursor=0.1.5=hb9d3cd8_0 386 | - xcb-util-image=0.4.0=hb711507_2 387 | - xcb-util-keysyms=0.4.1=hb711507_0 388 | - xcb-util-renderutil=0.3.10=hb711507_0 389 | - xcb-util-wm=0.4.2=hb711507_0 390 | - xgboost=2.1.4=cuda118_pyh7984362_0 391 | - xkeyboard-config=2.43=hb9d3cd8_0 392 | - xorg-libice=1.1.2=hb9d3cd8_0 393 | - xorg-libsm=1.2.5=he73a12e_0 394 | - xorg-libx11=1.8.11=h4f16b4b_0 395 | - xorg-libxau=1.0.12=hb9d3cd8_0 396 | - xorg-libxcomposite=0.4.6=hb9d3cd8_2 397 | - xorg-libxcursor=1.2.3=hb9d3cd8_0 398 | - xorg-libxdamage=1.1.6=hb9d3cd8_0 399 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0 400 | - xorg-libxext=1.3.6=hb9d3cd8_0 401 | - xorg-libxfixes=6.0.1=hb9d3cd8_0 402 | - xorg-libxi=1.8.2=hb9d3cd8_0 403 | - xorg-libxrandr=1.5.4=hb9d3cd8_0 404 | - xorg-libxrender=0.9.12=hb9d3cd8_0 405 | - xorg-libxtst=1.2.5=hb9d3cd8_3 406 | - xorg-libxxf86vm=1.1.6=hb9d3cd8_0 407 | - xyzservices=2025.1.0=pyhd8ed1ab_0 408 | - yaml=0.2.5=h7f98852_2 409 | - yarl=1.18.3=py311h2dc5d0c_1 410 | - zeromq=4.3.5=h3b0a872_7 411 | - zipp=3.21.0=pyhd8ed1ab_1 412 | - zlib=1.3.1=hb9d3cd8_2 413 | - zlib-ng=2.2.4=h7955e40_0 414 | - zstandard=0.19.0=py311hd4cff14_0 415 | - zstd=1.5.7=hb8e6e7a_1 416 | - pip: 417 | - black==21.12b0 418 | - blackcellmagic==0.0.3 419 | - click==8.1.8 420 | - fair-esm==2.0.0 421 | - hatch-cython==0.5.1 422 | - jupyter==1.1.1 423 | - jupyter-console==6.6.3 424 | - notebook==7.3.2 425 | - ordered-set==4.1.0 426 | - pathspec==0.12.1 427 | - rdkit==2024.9.5 428 | - tomli==1.2.3 429 | prefix: /disk2/fli/miniconda3/envs/SSMuLA 430 | --------------------------------------------------------------------------------
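Appendix (not part of the repository): a minimal usage sketch for the finetuning entry point shown above. It assumes train_predict_per_protein is importable from SSMuLA/plm_finetune.py as the directory tree suggests; the CSV path, device string, and zero-shot column choice are placeholders for illustration, not values taken from the repository.

# Minimal sketch: finetune ESM2 on one landscape and predict fitness for every sequence.
# The landscape CSV is assumed to provide "seq" and "fitness" columns (see create_dataset calls).
from SSMuLA.plm_finetune import train_predict_per_protein

if __name__ == "__main__":
    df_pred = train_predict_per_protein(
        df_csv="data/example_landscape.csv",  # placeholder landscape CSV
        rep=0,                                # replicate index into RAND_SEED_LIST
        device="cuda:0",                      # placeholder device
        checkpoint="facebook/esm2_t33_650M_UR50D",
        n_sample=384,                         # number of train+val sequences
        zs_predictor="none",                  # or a zero-shot score column for focused sampling
        save_dir="results/finetuning",
    )
    # Predictions are returned and also written to
    # results/finetuning/predictions/<landscape>/<landscape>_384_0.csv
    print(df_pred.head())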
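Likewise, a minimal sketch (not part of the repository) of how the heatmap figures from SSMuLA/get_factor.py might be generated. The corr.csv path and sub_dir mirror the defaults in plot_all_factor; absolute_dir is a placeholder for a local checkout rather than the hard-coded path in the module.

# Minimal sketch: render the ZS / landscape-attribute correlation heatmaps as PNGs.
# Styled DataFrames are written to HTML and screenshotted with headless Firefox,
# so Firefox and geckodriver (path set near the top of get_factor.py) must be available.
from SSMuLA.get_factor import plot_all_factor

if __name__ == "__main__":
    plot_all_factor(
        corr_csv="results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
        sub_dir="results/style_dfs_actcut-1",
        absolute_dir="/path/to/SSMuLA/",  # placeholder; module default is /disk2/fli/SSMuLA/
    )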