├── __init__.py
├── tests
│   ├── __init__.py
│   ├── test_preprocess.py
│   ├── test_triad_pre.py
│   ├── test_triad_post.py
│   ├── test_hd.py
│   ├── test_coves.py
│   ├── test_esmif.py
│   ├── test_de_vis.py
│   ├── test_de.py
│   ├── test_loc_opt.py
│   ├── test_gen_learned_emb.py
│   ├── test_ev_esm.py
│   ├── test_alde.py
│   ├── test_mlde_vis.py
│   ├── test_finetune.py
│   ├── test_pairwise_epistasis.py
│   ├── test_zs.py
│   └── test_corr.py
├── fig1.png
├── envs
│   ├── esmif.yml
│   ├── coves.yml
│   ├── SSMuLA.yml
│   ├── frozen
│   │   ├── esmif.yml
│   │   ├── coves.yml
│   │   └── SSMuLA.yml
│   └── finetune.yml
├── SSMuLA
│   ├── gen_atom3d.py
│   ├── __init__.py
│   ├── run_ev_esm.py
│   ├── finetune_analysis.py
│   ├── est_ep.py
│   ├── util.py
│   ├── zs_calc.py
│   ├── alde_analysis.py
│   ├── triad_prepost.py
│   ├── aa_global.py
│   ├── vis.py
│   ├── calc_hd.py
│   ├── zs_models.py
│   ├── plm_finetune.py
│   └── get_factor.py
├── esmif
│   ├── esmif.sh
│   └── score_log_likelihoods.py
├── .gitignore
└── README.md
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fhalab/SSMuLA/HEAD/fig1.png
--------------------------------------------------------------------------------
/envs/esmif.yml:
--------------------------------------------------------------------------------
1 | # For the general environment:
2 | # Install or update using
3 | # conda env update --file esmif.yml --prune
4 |
5 | name: esmif
6 | channels:
7 | - pytorch
8 | - pyg
9 | - conda-forge
10 | - defaults
11 | dependencies:
12 | - python=3.9
13 | - biopandas
14 | - biopython
15 | - biotite
16 | - cudatoolkit=11.3
17 | - pandas
18 | - pyg
19 | - pytorch
20 | - pip
21 | - pip:
22 | - git+https://github.com/facebookresearch/esm.git
23 |
--------------------------------------------------------------------------------
/envs/coves.yml:
--------------------------------------------------------------------------------
1 | # conda env update --file coves.yml --prune
2 | # need to be fixed
3 | name: coves
4 | channels:
5 | - pytorch
6 | - pyg
7 | - pytorch3d
8 | - conda-forge
9 | - defaults
10 | - anaconda
11 | dependencies:
12 | - python=3.9
13 | - numpy # =1.19.4
14 | - pip
15 | - pyg
16 | - pytorch # =1.8.1
17 | - scikit-learn # =0.24.1
18 | # - torch_geometric # =1.7.0
19 | # - torch_scatter # =2.0.6
20 | # - torch_cluster # =1.5.9
21 | - tqdm
22 |
23 | - pip:
24 | - atom3d # =0.2.1
25 | - blackcellmagic
26 | - rdkit
--------------------------------------------------------------------------------
/tests/test_preprocess.py:
--------------------------------------------------------------------------------
1 | """Test the preprocess module."""
2 |
3 | import sys
4 | import os
5 |
6 | from datetime import datetime
7 |
8 | from SSMuLA.fitness_process_vis import process_all, get_all_lib_stats
9 | from SSMuLA.util import checkNgen_folder
10 |
11 |
12 | if __name__ == "__main__":
13 |
14 | log_folder = checkNgen_folder("logs/fitness_process_vis")
15 |
16 | # log outputs
17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
18 | sys.stdout = f
19 |
20 | process_all(scale_fit="max")
21 | get_all_lib_stats()
22 |
23 | f.close()
--------------------------------------------------------------------------------
/tests/test_triad_pre.py:
--------------------------------------------------------------------------------
1 | """Test the triad pre and post processing."""
2 |
3 | import sys
4 | import os
5 |
6 | from glob import glob
7 |
8 | from datetime import datetime
9 |
10 | from SSMuLA.triad_prepost import run_traid_gen_mut_file
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/triad/pre"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 | run_traid_gen_mut_file(lib_list = ["kej"])
26 |
27 | f.close()
--------------------------------------------------------------------------------
/tests/test_triad_post.py:
--------------------------------------------------------------------------------
1 | """Test the triad pre and post processing."""
2 |
3 | import sys
4 | import os
5 |
6 | from glob import glob
7 |
8 | from datetime import datetime
9 |
10 | from SSMuLA.triad_prepost import run_parse_triad_results
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/triad/post"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 | run_parse_triad_results(all_lib = False, lib_list = ["T7", "TEV"])
26 |
27 | f.close()
--------------------------------------------------------------------------------
/tests/test_hd.py:
--------------------------------------------------------------------------------
1 | """A script for testing plotting for de"""
2 |
3 | import sys
4 | import os
5 |
6 | from datetime import datetime
7 |
8 | from SSMuLA.calc_hd import run_hd_avg_fit, run_hd_avg_metric
9 | from SSMuLA.util import checkNgen_folder
10 |
11 |
12 | if __name__ == "__main__":
13 |
14 | log_folder = checkNgen_folder("logs/hd")
15 |
16 | # log outputs
17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
18 | sys.stdout = f
19 |
20 | run_hd_avg_fit(
21 | data_dir = "data",
22 | num_processes=256,
23 | hd_dir = "results/hd_fit",
24 | )
25 |
26 | run_hd_avg_metric()
27 |
28 | f.close()
--------------------------------------------------------------------------------
/tests/test_coves.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for testing coves
3 | Use coves environment
4 | """
5 |
6 | import sys
7 | import os
8 |
9 | from datetime import datetime
10 |
11 | from SSMuLA.run_coves import run_all_coves, append_all_coves_scores
12 | from SSMuLA.util import checkNgen_folder
13 |
14 | if __name__ == "__main__":
15 |
16 | log_folder = checkNgen_folder("logs/coves")
17 |
18 | # log outputs
19 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
20 | sys.stdout = f
21 |
22 | run_all_coves(n_ave=100)
23 | append_all_coves_scores()
24 |
25 | f.close()
26 |
27 | """
28 | append_all_coves_scores(
29 | libs: list|str = "ev_esm2/*",
30 | ev_esm_dir: str = "ev_esm2",
31 | coves_dir: str = "coves/100",
32 | t: float = 0.1
33 | """
--------------------------------------------------------------------------------
/tests/test_esmif.py:
--------------------------------------------------------------------------------
1 | """Test ems inverse folding zs"""
2 |
3 | import sys
4 | import os
5 |
6 |
7 | from datetime import datetime
8 |
9 | # from SSMuLA.zs_analysis import run_zs_analysis
10 | from SSMuLA.zs_data import get_all_mutfasta
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/zs"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 | get_all_mutfasta(
26 | ev_esm_dir="ev_esm2",
27 | all_libs=True,
28 | )
29 |
30 | f.close()
31 |
32 | """
33 | get_all_mutfasta(
34 | ev_esm_dir: str = "ev_esm2",
35 | all_libs: bool = True,
36 | lib_list: list[str] = []
37 | )
38 | """
--------------------------------------------------------------------------------
/tests/test_de_vis.py:
--------------------------------------------------------------------------------
1 | """A script for testing plotting for de"""
2 |
3 | import sys
4 | import os
5 |
6 | from datetime import datetime
7 |
8 | from SSMuLA.de_simulations import run_plot_de
9 | from SSMuLA.util import checkNgen_folder
10 |
11 | if __name__ == "__main__":
12 |
13 | log_folder = checkNgen_folder("logs/plot_de")
14 |
15 | # log outputs
16 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
17 | sys.stdout = f
18 |
19 | run_plot_de(
20 | scale_types = ["scale2max"],
21 | de_opts = ["DE-active"],
22 | )
23 |
24 | """
25 | def run_plot_de(
26 | scale_types: list = ["scale2max", "scale2parent"],
27 | de_opts: list = ["DE-active"],
28 | sim_folder: str = "results/de",
29 | vis_folder: str = "results/de_vis",
30 | v_width: int = 400,
31 | all_lib: bool = True,
32 | lib_list: list[str] = [],
33 | ):
34 | """
35 |
36 | f.close()
--------------------------------------------------------------------------------
/tests/test_de.py:
--------------------------------------------------------------------------------
1 | """A script for testing plotting for de"""
2 |
3 | import sys
4 | import os
5 |
6 | from datetime import datetime
7 |
8 | from SSMuLA.de_simulations import run_all_lib_de_simulations
9 | from SSMuLA.util import checkNgen_folder
10 |
11 |
12 | if __name__ == "__main__":
13 |
14 | log_folder = checkNgen_folder("logs/de")
15 |
16 | # log outputs
17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
18 | sys.stdout = f
19 |
20 | run_all_lib_de_simulations(
21 | scale_types = ["scale2max"],
22 | de_opts = ["DE-active", "DE-0", "DE-all"],
23 | all_lib = True,
24 | lib_list = [],
25 | rerun = False
26 | )
27 |
28 | """
29 | run_all_lib_de_simulations(
30 | scale_types: list = ["scale2max", "scale2parent"],
31 | de_opts: list = ["DE-active", "DE-0", "DE-all"],
32 | save_dir: str = "results/de",
33 | all_lib: bool = True,
34 | lib_list: list[str] = [],
35 | rerun: bool = False,
36 | )
37 | """
38 |
39 | f.close()
--------------------------------------------------------------------------------
/tests/test_loc_opt.py:
--------------------------------------------------------------------------------
1 | """A script for testing the local optima"""
2 |
3 | import sys
4 | import os
5 |
6 | from datetime import datetime
7 |
8 | from SSMuLA.landscape_optima import run_loc_opt
9 | from SSMuLA.util import checkNgen_folder
10 |
11 |
12 | if __name__ == "__main__":
13 |
14 | log_folder = checkNgen_folder("logs/loc_opt")
15 |
16 | # log outputs
17 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
18 | sys.stdout = f
19 |
20 | run_loc_opt(input_folder = "data",
21 | fitness_process_type = "scale2max",
22 | output_folder = "results/local_optima",
23 | n_jobs = 16,
24 | if_append_escape = True,
25 | rerun = True,
26 | )
27 |
28 | f.close()
29 |
30 | """
31 | run_loc_opt(
32 | input_folder: str = "data",
33 | fitness_process_type: str = "scale2max",
34 | output_folder: str = "results/local_optima",
35 | n_jobs: int = 16,
36 | rerun: bool = False,
37 | ) -> None
38 | """
39 |
40 |
41 |
--------------------------------------------------------------------------------
/tests/test_gen_learned_emb.py:
--------------------------------------------------------------------------------
1 | """
2 | A script to test the learned embedding generation.
3 |
4 | """
5 |
6 | import sys
7 | import os
8 |
9 | from datetime import datetime
10 |
11 | from SSMuLA.gen_learned_emb import gen_all_learned_emb
12 | from SSMuLA.util import checkNgen_folder
13 |
14 | if __name__ == "__main__":
15 |
16 | # log outputs
17 | f = open(
18 | os.path.join(
19 | checkNgen_folder("logs/emb"),
20 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
21 | ),
22 | "w",
23 | )
24 | sys.stdout = f
25 |
26 | gen_all_learned_emb(
27 | input_folder = "results/zs_comb_6/none/scale2max/all",
28 | all_libs = False,
29 | regen = True,
30 | lib_list = ["TrpB3F"],
31 | )
32 |
33 | f.close()
34 |
35 | """
36 | input_folder: str = "results/zs_comb/none/scale2max",
37 | encoder_name: str = DEFAULT_ESM,
38 | batch_size: int = 128,
39 | regen: bool = False,
40 | emb_folder: str = "learned_emb",
41 | all_libs: bool = True,
42 | lib_list: list[str] = [],
43 | """
--------------------------------------------------------------------------------
/tests/test_ev_esm.py:
--------------------------------------------------------------------------------
1 | """Test ems inverse folding zs"""
2 |
3 | import sys
4 | import os
5 |
6 |
7 | from datetime import datetime
8 |
9 | # from SSMuLA.zs_analysis import run_zs_analysis
10 | from SSMuLA.zs_calc import calc_all_zs
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/zs"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 | calc_all_zs(
26 | landscape_folder="data",
27 | output_folder="results/zs",
28 | ev_model_folder="data",
29 | regen_esm = False,
30 | rerun_zs = False
31 | )
32 |
33 | f.close()
34 |
35 | """
36 | calc_all_zs(
37 | landscape_folder: str = "data/processed",
38 | dataset_list: list[str] = [],
39 | output_folder: str = "results/zs",
40 | zs_model_names: str = "all",
41 | ev_model_folder: str = "data/evmodels",
42 | regen_esm: bool = False,
43 | rerun_zs: bool = False,
44 | )
45 | """
--------------------------------------------------------------------------------
/tests/test_alde.py:
--------------------------------------------------------------------------------
1 |
2 | """Test alde file comb"""
3 |
4 | import sys
5 | import os
6 |
7 |
8 | from datetime import datetime
9 |
10 | from SSMuLA.alde_analysis import aggregate_alde_df
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/alde_comb"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 | aggregate_alde_df(
26 | eq_ns = [2, 3, 4],
27 | alde_dir = "results/alde",
28 | alde_df_path = "results/alde/alde_all.csv",
29 |
30 | )
31 |
32 | f.close()
33 |
34 | """
35 | aggregate_alde_df(
36 | eq_ns: list[int] = [1, 2, 3, 4],
37 | zs_opts: list[str] = ["esmif", "ev", "coves", "ed", "esm", "Triad", ""],
38 | alde_model: str = "Boosting Ensemble",
39 | alde_encoding: str = "onehot",
40 | alde_acq: str = "GREEDY",
41 | alde_dir: str = "/disk2/fli/alde4ssmula",
42 | alde_df_path: str = "results/alde/alde_all.csv",
43 | )
44 | """
45 |
46 |
--------------------------------------------------------------------------------
/tests/test_mlde_vis.py:
--------------------------------------------------------------------------------
1 | """Test MLDE."""
2 |
3 | import sys
4 | import os
5 |
6 | from glob import glob
7 |
8 | from datetime import datetime
9 |
10 | from SSMuLA.mlde_analysis import MLDESum
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/mlde_vis"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 |
26 | # MLDESum(
27 | # mlde_results_dir = "results_rev/mlde_coves_ens/saved",
28 | # mlde_vis_dir = "results_rev/mlde_coves_ens/vis"
29 | # )
30 | MLDESum(
31 | mlde_results_dir = "results_rev/mlde_lown/saved",
32 | mlde_vis_dir = "results_rev/mlde_lown/vis"
33 | )
34 |
35 | """
36 | MLDESum:
37 |
38 | def __init__(
39 | self,
40 | mlde_results_dir: str = "results/mlde/saved",
41 | mlde_vis_dir: str = "results/mlde/vis",
42 | all_encoding: bool = True,
43 | encoding_lists: list[str] = [],
44 | ifvis: bool = False,
45 | ) -> None:
46 | """
47 | f.close()
--------------------------------------------------------------------------------
/SSMuLA/gen_atom3d.py:
--------------------------------------------------------------------------------
1 | """
2 | For generating CoVES based zs scores
3 |
4 | NOTE:
5 | - have to use the atom3d env for atom3d gen
6 | - have to be run in the coves env for coves
7 | """
8 |
9 | from __future__ import annotations
10 |
11 | import os
12 | from glob import glob
13 |
14 | import atom3d.datasets as da
15 |
16 | from SSMuLA.util import checkNgen_folder, get_file_name
17 |
18 |
19 | def gen_lmdb_dataset(pdb_path: str, lmdb_path: str):
20 |
21 | """
22 | Generate LMDB dataset from PDB dataset
23 |
24 | Args:
25 | - pdb_path, str: Path to the PDB files
26 | - lmdb_path, str: Path to directory to save LMDB dataset
27 | """
28 |
29 | # Load dataset from directory of PDB files
30 | pdb_file = da.load_dataset(pdb_path, 'pdb')
31 | # Create LMDB dataset from PDB dataset
32 | checkNgen_folder(lmdb_path)
33 | da.make_lmdb_dataset(pdb_file, lmdb_path)
34 |
35 |
36 | def gen_all_lmdb(pdb_pattern: str = "data/*/*.pdb", lmdb_dir: str = "lmdb"):
37 |
38 | for pdb_path in sorted(glob(pdb_pattern)):
39 |
40 | print(f"Generating LMDB dataset for {pdb_path}...")
41 |
42 | protein_name = get_file_name(pdb_path)
43 |
44 | lmdb_path = os.path.join(lmdb_dir, protein_name, "lmdb")
45 | gen_lmdb_dataset(pdb_path, lmdb_path)
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
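A minimal usage sketch for the module above (the PDB paths and output directory are illustrative assumptions; run it inside the atom3d/coves environment as the docstring notes):

    from SSMuLA.gen_atom3d import gen_all_lmdb, gen_lmdb_dataset

    # convert every PDB matching the default pattern into an LMDB dataset under lmdb/
    gen_all_lmdb(pdb_pattern="data/*/*.pdb", lmdb_dir="lmdb")

    # or convert a single structure
    gen_lmdb_dataset(pdb_path="data/GB1/GB1.pdb", lmdb_path="lmdb/GB1/lmdb")

--------------------------------------------------------------------------------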
/esmif/esmif.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1  # expose only GPU 1 to the scoring runs below
2 |
3 | python score_log_likelihoods.py DHFR.pdb DHFR.fasta --chain A --outpath DHFR_esmif_scores.csv
4 | python score_log_likelihoods.py GB1.pdb GB1.fasta --chain A --outpath GB1_esmif_scores.csv
5 | python score_log_likelihoods.py ParD2.pdb ParD2.fasta --chain A --outpath ParD2_esmif_scores.csv
6 | python score_log_likelihoods.py ParD3.pdb ParD3.fasta --chain A --outpath ParD3_esmif_scores.csv
7 | python score_log_likelihoods.py TrpB.pdb TrpB3A.fasta --chain A --outpath TrpB3A_esmif_scores.csv
8 | python score_log_likelihoods.py TrpB.pdb TrpB3B.fasta --chain A --outpath TrpB3B_esmif_scores.csv
9 | python score_log_likelihoods.py TrpB.pdb TrpB3C.fasta --chain A --outpath TrpB3C_esmif_scores.csv
10 | python score_log_likelihoods.py TrpB.pdb TrpB3D.fasta --chain A --outpath TrpB3D_esmif_scores.csv
11 | python score_log_likelihoods.py TrpB.pdb TrpB3E.fasta --chain A --outpath TrpB3E_esmif_scores.csv
12 | python score_log_likelihoods.py TrpB.pdb TrpB3F.fasta --chain A --outpath TrpB3F_esmif_scores.csv
13 | python score_log_likelihoods.py TrpB.pdb TrpB3G.fasta --chain A --outpath TrpB3G_esmif_scores.csv
14 | python score_log_likelihoods.py TrpB.pdb TrpB3H.fasta --chain A --outpath TrpB3H_esmif_scores.csv
15 | python score_log_likelihoods.py TrpB.pdb TrpB3I.fasta --chain A --outpath TrpB3I_esmif_scores.csv
16 | python score_log_likelihoods.py TrpB.pdb TrpB4.fasta --chain A --outpath TrpB4_esmif_scores.csv
--------------------------------------------------------------------------------
/envs/SSMuLA.yml:
--------------------------------------------------------------------------------
1 | # For the general environment:
2 | # Install or update using
3 | # conda env update --file SSMuLA.yml --prune
4 |
5 | name: SSMuLA
6 | channels:
7 | - pytorch
8 | - pyg
9 | - pytorch3d
10 | - nvidia
11 | - conda-forge
12 | - salilab
13 | - pyviz
14 | - defaults
15 | - anaconda
16 | dependencies:
17 | - python=3.11
18 | - biopandas
19 | - biopython
20 | - biotite
21 | - bokeh
22 | - brokenaxes
23 | - cairosvg
24 | - colorcet
25 | - datashader
26 | - firefox
27 | - flake8
28 | - geckodriver
29 | - hdf5
30 | - holoviews
31 | - hvplot
32 | - ipykernel
33 | - ipympl
34 | - ipywidgets
35 | - jinja2
36 | - jupyterlab
37 | - jupyter_bokeh
38 | - matplotlib
39 | - multipledispatch
40 | - mypy
41 | - networkx
42 | - nodejs
43 | - numpy
44 | - openpyxl
45 | - pandas
46 | - panel
47 | - param
48 | - pdbfixer
49 | - pip
50 | - psutil
51 | - pyg>=2.0.3
52 | - pytables
53 | - pytorch=2.1.1 # for h100
54 | - pytorch-cuda=12.1 # change this as needed
55 | - pyviz_comms
56 | - pyyaml
57 | - requests
58 | - scikit-learn # might not be evcouplings compatible
59 | - scipy
60 | - seaborn
61 | - tensorflow
62 | - tqdm
63 | - versioneer
64 | - xarray
65 | - xgboost
66 |
67 | - pip:
68 | - blackcellmagic
69 | - fair-esm
70 | - ordered-set
71 | - rdkit
72 | # pip install https://github.com/debbiemarkslab/EVcouplings/archive/develop.zip install after the environment is created
73 | prefix: /disk2/fli/setup/miniconda3/
--------------------------------------------------------------------------------
/tests/test_finetune.py:
--------------------------------------------------------------------------------
1 | """Test MLDE."""
2 |
3 | import sys
4 | import os
5 |
6 | from glob import glob
7 |
8 | from datetime import datetime
9 |
10 | from SSMuLA.plm_finetune import train_predict_per_protein
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/finetune"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 |
26 | for landscape in sorted(glob("results/zs_comb/all/*.csv")):
27 | for i in range(5):
28 | train_predict_per_protein(
29 | df_csv=landscape,
30 | rep=i,
31 | )
32 |
33 | f.close()
34 |
35 | """
36 | train_predict_per_protein(
37 | df_csv: str, # csv file with landscape data
38 | rep: int, # replicate number
39 | checkpoint: str = "facebook/esm2_t33_650M_UR50D", # model checkpoint
40 | n_sample: int = 384, # number of train+val
41 | zs_predictor: str = "none", # zero-shot predictor
42 | ft_frac: float = 0.125, # fraction of data for focused sampling
43 | plot_dir: str = "results/finetuning/plot", # directory to save the plot
44 | model_dir: str = "results/finetuning/model", # directory to save the model
45 | pred_dir: str = "results/finetuning/predictions", # directory to save the predictions
46 | train_kwargs: dict = {}, # additional training arguments
47 | )
48 | """
--------------------------------------------------------------------------------
/SSMuLA/__init__.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # #
3 | # This program is free software: you can redistribute it and/or modify #
4 | # it under the terms of the GNU General Public License as published by #
5 | # the Free Software Foundation, either version 3 of the License, or #
6 | # (at your option) any later version. #
7 | # #
8 | # This program is distributed in the hope that it will be useful, #
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of #
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
11 | # GNU General Public License for more details. #
12 | # #
13 | # You should have received a copy of the GNU General Public License #
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>. #
15 | # #
16 | ###############################################################################
17 |
18 | __title__ = 'SSMuLA'
19 | __description__ = 'Site Saturation Mutagenesis Landscape Analysis'
20 | __url__ = 'https://github.com/fhalab/SSMuLA.git'
21 | __version__ = '1.0.0'
22 | __author__ = 'Francesca-Zhoufan Li'
23 | __author_email__ = 'fzl@caltech.edu'
24 | __license__ = 'GPL3'
25 |
--------------------------------------------------------------------------------
/tests/test_pairwise_epistasis.py:
--------------------------------------------------------------------------------
1 | """
2 | A script to test the pairwise epistasis calculation
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import sys
8 | import os
9 |
10 | from datetime import datetime
11 |
12 |
13 | from SSMuLA.pairwise_epistasis import calc_all_pairwise_epistasis, plot_pairwise_epistasis
14 | from SSMuLA.util import checkNgen_folder
15 |
16 |
17 | if __name__ == "__main__":
18 |
19 | log_folder = checkNgen_folder("logs/pairwise_epistasis")
20 |
21 | # log outputs
22 | f = open(os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"), 'w')
23 | sys.stdout = f
24 |
25 | fitness_process_type = "scale2max"
26 |
27 | calc_all_pairwise_epistasis(
28 | fitness_process_type=fitness_process_type,
29 | ifall=False,
30 | lib_list=["T7", "TEV"],
31 | output_folder="results/pairwise_epistasis",
32 | n_jobs=128,
33 | )
34 |
35 | """
36 | calc_all_pairwise_epistasis(
37 | input_folder: str = "data",
38 | fitness_process_type: str = "scale2max",
39 | activestart: bool = True,
40 | ifall: bool = True,
41 | lib_list: list[str] = [],
42 | output_folder: str = "results/pairwise_epistasis",
43 | n_jobs: int = 128,
44 | """
45 |
46 | pos_calc_filter_min = "none"
47 |
48 | plot_pairwise_epistasis(
49 | fitness_process_type=fitness_process_type,
50 | pos_calc_filter_min=pos_calc_filter_min,
51 | input_folder="results/pairwise_epistasis",
52 | output_folder="results/pairwise_epistasis_vis",
53 | dets_folder="results/pairwise_epistasis_dets",
54 | )
55 |
56 | f.close()
--------------------------------------------------------------------------------
/tests/test_zs.py:
--------------------------------------------------------------------------------
1 | """Test the triad pre and post processing."""
2 |
3 | import sys
4 | import os
5 |
6 | from glob import glob
7 |
8 | from datetime import datetime
9 |
10 | from SSMuLA.zs_analysis import run_zs_analysis
11 | from SSMuLA.util import checkNgen_folder
12 |
13 | if __name__ == "__main__":
14 |
15 | # log outputs
16 | f = open(
17 | os.path.join(
18 | checkNgen_folder("logs/zs"),
19 | f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out",
20 | ),
21 | "w",
22 | )
23 | sys.stdout = f
24 |
25 | # run_zs_analysis(
26 | # scale_types=["max"],
27 | # filter_min_by="none",
28 | # ev_esm_folder = "ev_esm2",
29 | # zs_comb_dir = "results/zs_comb_2",
30 | # zs_vis_dir = "results/zs_vis_2",
31 | # zs_sum_dir = "results/zs_sum_2",
32 | # )
33 |
34 | # run_zs_analysis(
35 | # scale_types=["max"],
36 | # filter_min_by="min0",
37 | # ev_esm_folder = "ev_esm2",
38 | # zs_comb_dir = "results/zs_comb_2",
39 | # zs_vis_dir = "results/zs_vis_2",
40 | # zs_sum_dir = "results/zs_sum_2",
41 | # )
42 |
43 | # run_zs_analysis(
44 | # scale_types=["max"],
45 | # filter_min_by="none",
46 | # ev_esm_folder = "ev_esm2",
47 | # zs_comb_dir = "results/zs_comb_6",
48 | # zs_vis_dir = "results/zs_vis_6",
49 | # zs_sum_dir = "results/zs_sum_6",
50 | # )
51 |
52 | run_zs_analysis(
53 | scale_types=["max"],
54 | filter_min_by="none",
55 | ev_esm_folder = "ev_esm2",
56 | zs_comb_dir = "results_old/zs_comb_7",
57 | zs_vis_dir = "results_old/zs_vis_7",
58 | zs_sum_dir = "results_old/zs_sum_7",
59 | )
60 |
61 | f.close()
62 |
63 | """
64 | run_zs_analysis(
65 | scale_types: list = ["max", "parent"],
66 | data_folder: str = "data",
67 | ev_esm_folder: str = "ev_esm",
68 | triad_folder: str = "triad",
69 | esmif_folder: str = "esmif",
70 | filter_min_by: str = "none",
71 | n_mut_cutoff_list: list[int] = [0, 1, 2],
72 | zs_comb_dir: str = "results/zs_comb",
73 | zs_vis_dir: str = "results/zs_vis",
74 | zs_sum_dir: str = "results/zs_sum",
75 | )
76 | """
--------------------------------------------------------------------------------
/SSMuLA/run_ev_esm.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for generating zs scores gratefully adapted from EmreGuersoy's work
3 | """
4 |
5 | # Import packages
6 | import glob
7 | import json
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import random
12 | from collections import Counter
13 |
14 |
15 | import argparse
16 | from pathlib import Path
17 |
18 | from tqdm import tqdm
19 | from typing import List, Tuple, Optional
20 |
21 |
22 | import warnings
23 |
24 | from SSMuLA.zs_calc import calc_all_zs
25 | from SSMuLA.util import get_file_name, checkNgen_folder
26 |
27 |
28 | # TODO clean up path to be dataset name independent
29 |
30 | def create_parser():
31 | parser = argparse.ArgumentParser(description="Run zero-shot predictions")
32 |
33 | parser.add_argument(
34 | "--landscape_folder",
35 | type=str,
36 | default="data/processed",
37 | metavar="LSF",
38 | help="A folder path for all landscape data"
39 | )
40 |
41 | parser.add_argument(
42 | "--dataset_list",
43 | type=json.loads,
44 | metavar="dsl",
45 | default=[],
46 | help="default dataset list empty to use glob for all",
47 | )
48 |
49 | parser.add_argument(
50 | "--output_folder",
51 | type=str,
52 | default="results/zs",
53 | metavar="OPF",
54 | help="A output folder path with landscape subfolders"
55 | )
56 |
57 | parser.add_argument(
58 | "--zs_model_names",
59 | type=str,
60 | metavar="ZSMN",
61 | help="A str of name(s) of zero-shot models to use seperated by comma, \
62 | available: 'esm', 'ev', developing: 'ddg', 'Bert', all: 'all' runs all currently available models \
63 | ie, 'esm, ev'",
64 | )
65 |
66 | parser.add_argument(
67 | "--ev_model_folder",
68 | type=str,
69 | default="data/",
70 | metavar="EVF",
71 | help="folder for evmodels"
72 | )
73 |
74 | parser.add_argument(
75 | "--regen_esm",
76 | action="store_true",
77 | default=False,
78 | # boolean flag: pass --regen_esm to regenerate ESM logits
79 | help="if set, regenerate esm logits instead of loading them directly"
80 | )
81 |
82 | parser.add_argument(
83 | "--rerun_zs",
84 | action="store_true",
85 | default=False,
86 | # boolean flag: pass --rerun_zs to create a new zs output
87 | help="if set, rerun zs and create a new output instead of appending to the current csv"
88 | )
89 |
90 | return parser
91 |
92 |
93 | def main(args):
94 |
95 | # Input processing
96 |
97 | calc_all_zs(landscape_folder = args.landscape_folder,
98 | dataset_list=args.dataset_list,
99 | output_folder = args.output_folder,
100 | zs_model_names = args.zs_model_names,
101 | ev_model_folder = args.ev_model_folder,
102 | regen_esm = args.regen_esm,
103 | rerun_zs = args.rerun_zs)
104 |
105 | # Run all zero-shot predictions from the command line
106 | if __name__ == "__main__":
107 | parser = create_parser()
108 | args = parser.parse_args()
109 | main(args)
--------------------------------------------------------------------------------
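The parser above simply forwards its arguments to calc_all_zs, so the same run can be scripted directly; a minimal sketch mirroring the CLI defaults (the model names and paths are assumptions):

    from SSMuLA.zs_calc import calc_all_zs

    calc_all_zs(
        landscape_folder="data/processed",
        dataset_list=[],
        output_folder="results/zs",
        zs_model_names="esm, ev",  # or "all" for every available model
        ev_model_folder="data/",
        regen_esm=False,
        rerun_zs=False,
    )

--------------------------------------------------------------------------------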
/SSMuLA/finetune_analysis.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for combining and analyzing the results of the LoRA fine-tuning analysis.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from glob import glob
8 | import pandas as pd
9 |
10 | from SSMuLA.landscape_global import N_SAMPLE_LIST
11 | from SSMuLA.util import get_file_name
12 |
13 |
14 | def parse_finetune_df(
15 | finetune_dir: str, # ie results/finetuning/ev or none
16 | lib_list: list,
17 | n_top: int = 96,
18 | ) -> pd.DataFrame:
19 |
20 | """
21 | Parse the finetune dataframe and return a summary dataframe.
22 |
23 | The finetune_dir should contain the results of the finetuning analysis,
24 | where each landscape is in a separate folder, and each folder contains
25 | the results of the finetuning analysis for each landscape
26 | with the format <landscape>_<n_sample>_<rep>.csv.
27 |
28 | Args:
29 | finetune_dir (str): The directory containing the finetuning results.
30 | lib_list (list): A list of libraries to include in the analysis.
31 | n_top (int): The number of top variants to consider.
32 | """
33 |
34 | sum_df_list = []
35 |
36 | for df_path in sorted(glob(f"{finetune_dir}/*/*.csv")):
37 |
38 | landscape, n_sample, rep = get_file_name(df_path).split("_")
39 |
40 | if landscape not in lib_list:
41 | continue
42 |
43 | df = pd.read_csv(df_path)
44 | max_fit_seq = df.loc[df["fitness"].idxmax()]["seq"]
45 |
46 | # get top 96 maxes
47 | top_df = (
48 | df.sort_values(by="predictions", ascending=False)
49 | .reset_index(drop=True)
50 | .iloc[:n_top, :]
51 | )
52 | top_seqs = top_df["seq"].astype(str).values
53 |
54 | # write to sum_df
55 | sum_df_list.append(
56 | {
57 | "landscape": landscape,
58 | "n_sample": int(n_sample),
59 | "rep": int(rep),
60 | "top_maxes": top_df["fitness"].max(),
61 | "if_truemaxs": int(max_fit_seq in top_seqs),
62 | }
63 | )
64 |
65 | return (
66 | pd.DataFrame(sum_df_list)
67 | .sort_values(by=["n_sample", "landscape", "rep"])
68 | .reset_index(drop=True)
69 | .copy()
70 | )
71 |
72 |
73 | def avg_finetune_df(
74 | finetune_df: pd.DataFrame,
75 | n_sample_list: list = N_SAMPLE_LIST,
76 | ) -> pd.DataFrame:
77 |
78 | """
79 | Average the finetune dataframe over the number of samples and repetitions.
80 |
81 | Args:
82 | finetune_df (pd.DataFrame): The dataframe containing the finetuning results.
83 | n_sample_list (list): A list of the number of samples to consider.
84 | """
85 |
86 | avg_sum_df = (
87 | finetune_df[["n_sample", "top_maxes", "if_truemaxs"]]
88 | .groupby("n_sample")
89 | .agg(["mean", "std"])
90 | .reset_index()
91 | )
92 | avg_sum_df.columns = ["{}_{}".format(i, j) for i, j in avg_sum_df.columns]
93 | return (
94 | avg_sum_df.rename(columns={"n_sample_": "n_sample"})
95 | .set_index("n_sample")
96 | .copy()
97 | )
--------------------------------------------------------------------------------
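A minimal sketch of chaining the two helpers above (finetune_dir and lib_list are assumptions; point them at an actual finetuning output folder and its landscapes):

    from SSMuLA.finetune_analysis import parse_finetune_df, avg_finetune_df

    # per-run summary: best fitness among the top predicted variants
    # and whether the true max was recovered
    sum_df = parse_finetune_df(
        finetune_dir="results/finetuning/none",
        lib_list=["DHFR", "GB1"],
        n_top=96,
    )

    # mean and std over replicates for each training-set size
    avg_df = avg_finetune_df(sum_df)

--------------------------------------------------------------------------------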
/tests/test_corr.py:
--------------------------------------------------------------------------------
1 | """A script for testing plotting for de"""
2 |
3 | import sys
4 | import os
5 |
6 | from datetime import datetime
7 |
8 | from SSMuLA.get_corr import MergeLandscapeAttributes, MergeMLDEAttributes, perfom_corr
9 | from SSMuLA.util import checkNgen_folder
10 |
11 |
12 | if __name__ == "__main__":
13 |
14 | log_folder = checkNgen_folder("logs/corr")
15 |
16 | # log outputs
17 | f = open(
18 | os.path.join(log_folder, f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.out"),
19 | "w",
20 | )
21 | sys.stdout = f
22 |
23 | # MergeLandscapeAttributes(
24 | # lib_stat_path="results/fitness_distribution/max/all_lib_stats.csv",
25 | # loc_opt_path="results/local_optima/scale2max.csv",
26 | # pwe_path="results/pairwise_epistasis_vis/none/scale2max.csv",
27 | # zs_path="results/zs_sum_5/none/zs_stat_scale2max.csv",
28 | # de_path="results/de/DE-active/scale2max/all_landscape_de_summary.csv",
29 | # merge_dir="results/merged",
30 | # )
31 |
32 | # MergeMLDEAttributes(mlde_path = "results/mlde/all_results.csv", merge_dir = "results/merged_2", models=["boosting"])
33 |
34 |
35 | perfom_corr(
36 | n_mut_cutoff=0,
37 | n_list=[384],
38 | zs_path="results/zs_sum/none/zs_stat_scale2max.csv",
39 | mlde_path="results/mlde/all_df_comb_onehot.csv",
40 | corr_dir="results/corr",
41 | ifplot=False,
42 | )
43 |
44 | """
45 |
46 | MergeLandscapeAttributes:
47 |
48 | def __init__(
49 | self,
50 | lib_stat_path: str = "results/fitness_distribution/max/all_lib_stats.csv",
51 | loc_opt_path: str = "results/local_optima/scale2max.csv",
52 | pwe_path: str = "results/pairwise_epistasis_vis/none/scale2max.csv",
53 | zs_path: str = "results/zs_sum/none/zs_stat_scale2max.csv",
54 | de_path: str = "results/de/DE-active/scale2max/all_landscape_de_summary.csv",
55 | merge_dir: str = "results/merged",
56 | )
57 |
58 | MergeMLDEAttributes(MergeLandscapeAttributes):
59 |
60 | def __init__(
61 | self,
62 | lib_stat_path: str = "results/fitness_distribution/max/all_lib_stats.csv",
63 | loc_opt_path: str = "results/local_optima/scale2max.csv",
64 | pwe_path: str = "results/pairwise_epistasis_vis/none/scale2max.csv",
65 | zs_path: str = "results/zs_sum/none/zs_stat_scale2max.csv",
66 | de_path: str = "results/de/DE-active/scale2max/all_landscape_de_summary.csv",
67 | mlde_path: str = "results/mlde/all_df_comb_onehot_2.csv",
68 | merge_dir: str = "results/merged",
69 | n_mut_cutoff: int = 0,
70 | n_sample: int = 384,
71 | n_top: int = 96,
72 | filter_active: float = 1,
73 | ft_frac=0.125,
74 | models: list[str] = ["boosting", "ridge"],
75 | ifplot: bool = True,
76 | )
77 |
78 | perfom_corr(
79 | lib_stat_path: str = "results/fitness_distribution/max/all_lib_stats.csv",
80 | loc_opt_path: str = "results/local_optima/scale2max.csv",
81 | pwe_path: str = "results/pairwise_epistasis_vis/none/scale2max.csv",
82 | zs_path: str = "results/zs_sum/none/zs_stat_scale2max.csv",
83 | de_path: str = "results/de/DE-active/scale2max/all_landscape_de_summary.csv",
84 | mlde_path: str = "results/mlde/all_df_comb_onehot.csv",
85 | corr_dir: str = "results/corr",
86 | n_mut_cutoff: int = 0,
87 | filter_active: float = 1,
88 | ft_frac: float = 0.125,
89 | n_top_list: list[int] = [96, 384],
90 | n_list: list[int] = N_SAMPLE_LIST,
91 | models_list: list[list[str]] = [["boosting", "ridge"], ["boosting"], ["ridge"]],
92 | ifplot: bool = True,
93 | )
94 | """
95 |
96 | f.close()
--------------------------------------------------------------------------------
/envs/frozen/esmif.yml:
--------------------------------------------------------------------------------
1 | name: esmif
2 | channels:
3 | - pyg
4 | - pytorch
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - aiohappyeyeballs=2.4.4=pyhd8ed1ab_1
11 | - aiohttp=3.11.11=py39h9399b63_0
12 | - aiosignal=1.3.2=pyhd8ed1ab_0
13 | - async-timeout=5.0.1=pyhd8ed1ab_1
14 | - attrs=25.1.0=pyh71513ae_0
15 | - biopandas=0.5.1=pyhd8ed1ab_1
16 | - biopython=1.85=py39h8cd3c5a_1
17 | - biotite=0.38.0=py39h44dd56e_0
18 | - blas=1.0=mkl
19 | - brotli-python=1.1.0=py39hf88036b_2
20 | - bzip2=1.0.8=h4bc722e_7
21 | - ca-certificates=2025.1.31=hbcca054_0
22 | - certifi=2024.12.14=pyhd8ed1ab_0
23 | - cffi=1.17.1=py39h15c3d72_0
24 | - charset-normalizer=3.4.1=pyhd8ed1ab_0
25 | - colorama=0.4.6=pyhd8ed1ab_1
26 | - cpuonly=2.0=0
27 | - cpython=3.9.21=py39hd8ed1ab_1
28 | - cudatoolkit=11.3.1=hb98b00a_13
29 | - filelock=3.17.0=pyhd8ed1ab_0
30 | - frozenlist=1.5.0=py39h9399b63_1
31 | - fsspec=2025.2.0=pyhd8ed1ab_0
32 | - gmp=6.3.0=hac33072_2
33 | - gmpy2=2.1.5=py39h7196dd7_3
34 | - h2=4.1.0=pyhd8ed1ab_1
35 | - hpack=4.1.0=pyhd8ed1ab_0
36 | - hyperframe=6.1.0=pyhd8ed1ab_0
37 | - idna=3.10=pyhd8ed1ab_1
38 | - intel-openmp=2022.0.1=h06a4308_3633
39 | - jinja2=3.1.5=pyhd8ed1ab_0
40 | - ld_impl_linux-64=2.43=h712a8e2_2
41 | - libblas=3.9.0=16_linux64_mkl
42 | - libcblas=3.9.0=16_linux64_mkl
43 | - libffi=3.4.2=h7f98852_5
44 | - libgcc=14.2.0=h77fa898_1
45 | - libgcc-ng=14.2.0=h69a702a_1
46 | - libgfortran=14.2.0=h69a702a_1
47 | - libgfortran-ng=14.2.0=h69a702a_1
48 | - libgfortran5=14.2.0=hd5240d6_1
49 | - libgomp=14.2.0=h77fa898_1
50 | - liblapack=3.9.0=16_linux64_mkl
51 | - liblzma=5.6.4=hb9d3cd8_0
52 | - libnsl=2.0.1=hd590300_0
53 | - libsqlite=3.48.0=hee588c1_1
54 | - libstdcxx=14.2.0=hc0a3c3a_1
55 | - libstdcxx-ng=14.2.0=h4852527_1
56 | - libuuid=2.38.1=h0b41bf4_0
57 | - libxcrypt=4.4.36=hd590300_1
58 | - libzlib=1.3.1=hb9d3cd8_2
59 | - llvm-openmp=15.0.7=h0cdce71_0
60 | - looseversion=1.3.0=pyhd8ed1ab_0
61 | - markupsafe=3.0.2=py39h9399b63_1
62 | - mkl=2022.1.0=hc2b9512_224
63 | - mmtf-python=1.1.3=pyhd8ed1ab_0
64 | - mpc=1.3.1=h24ddda3_1
65 | - mpfr=4.2.1=h90cbb55_3
66 | - mpmath=1.3.0=pyhd8ed1ab_1
67 | - msgpack-python=1.1.0=py39h74842e3_0
68 | - multidict=6.1.0=py39h9399b63_2
69 | - ncurses=6.5=h2d0b736_3
70 | - networkx=3.2.1=pyhd8ed1ab_0
71 | - numpy=1.26.4=py39h474f0d3_0
72 | - openssl=3.4.0=h7b32b05_1
73 | - pandas=2.2.3=py39h3b40f6f_2
74 | - pip=25.0=pyh8b19718_0
75 | - propcache=0.2.1=py39h9399b63_1
76 | - psutil=6.1.1=py39h8cd3c5a_0
77 | - pycparser=2.22=pyh29332c3_1
78 | - pyg=2.6.1=py39_torch_2.4.0_cpu
79 | - pyparsing=3.2.1=pyhd8ed1ab_0
80 | - pysocks=1.7.1=pyha55dd90_7
81 | - python=3.9.21=h9c0c6dc_1_cpython
82 | - python-dateutil=2.9.0.post0=pyhff2d567_1
83 | - python-tzdata=2025.1=pyhd8ed1ab_0
84 | - python_abi=3.9=5_cp39
85 | - pytorch=2.4.1=py3.9_cpu_0
86 | - pytorch-mutex=1.0=cpu
87 | - pytz=2024.1=pyhd8ed1ab_0
88 | - pyyaml=6.0.2=py39h9399b63_2
89 | - readline=8.2=h8228510_1
90 | - requests=2.32.3=pyhd8ed1ab_1
91 | - scipy=1.13.1=py39haf93ffa_0
92 | - setuptools=75.8.0=pyhff2d567_0
93 | - six=1.17.0=pyhd8ed1ab_0
94 | - sympy=1.13.3=pyh2585a3b_105
95 | - tk=8.6.13=noxft_h4845f30_101
96 | - tqdm=4.67.1=pyhd8ed1ab_1
97 | - typing-extensions=4.12.2=hd8ed1ab_1
98 | - typing_extensions=4.12.2=pyha770c72_1
99 | - tzdata=2025a=h78e105d_0
100 | - urllib3=2.3.0=pyhd8ed1ab_0
101 | - wheel=0.45.1=pyhd8ed1ab_1
102 | - yaml=0.2.5=h7f98852_2
103 | - yarl=1.18.3=py39h9399b63_1
104 | - zstandard=0.23.0=py39h08a7858_1
105 | - zstd=1.5.6=ha6fb4c9_0
106 | - pip:
107 | - fair-esm==2.0.1
108 | - pillow==11.1.0
109 | - rdkit==2024.9.4
110 | prefix: /disk2/fli/miniconda3/envs/esmif2
111 |
--------------------------------------------------------------------------------
/SSMuLA/est_ep.py:
--------------------------------------------------------------------------------
1 | """
2 | A script to estimate pairwise epistasis in a given dataset.
3 | """
4 |
5 | import os
6 | from glob import glob
7 | import pandas as pd
8 | import numpy as np
9 |
10 | from SSMuLA.landscape_global import LIB_INFO_DICT, lib2prot
11 | from SSMuLA.fitness_process_vis import parse_lib_stat
12 |
13 | from Bio.PDB import PDBParser, PDBIO
14 | from Bio.PDB.PDBExceptions import PDBConstructionWarning
15 | import warnings
16 | import itertools
17 |
18 | # Suppress PDB construction warnings
19 | warnings.simplefilter('ignore', PDBConstructionWarning)
20 |
21 | def get_ca_distance(structure, residue1, residue2, chain_id='A'):
22 | """
23 | Calculate the C-alpha distance between two residues within a structure object.
24 |
25 | Parameters:
26 | - structure: PDB structure object.
27 | - residue1, residue2: Residue numbers (integers) of the residues to measure.
28 | - chain_id: ID of the chain where the residues are located (default is chain 'A').
29 |
30 | Returns:
31 | - distance: Distance between the C-alpha atoms of the specified residues.
32 | """
33 | # Select chain and residues
34 | chain = structure[0][chain_id] # Assume using the first model
35 | res1 = chain[residue1]
36 | res2 = chain[residue2]
37 |
38 | # Fetch the 'CA' atoms if they exist
39 | if 'CA' in res1 and 'CA' in res2:
40 | ca1 = res1['CA']
41 | ca2 = res2['CA']
42 | # Calculate distance
43 | distance = ca1 - ca2
44 | return distance
45 | else:
46 | return None
47 |
48 | def calculate_pairwise_distances(pdb_file, residues_dict, chain_id='A'):
49 | """
50 | Calculate pairwise C-alpha distances for a set of residues specified in a dictionary.
51 |
52 | Parameters:
53 | - pdb_file: Path to the PDB file.
54 | - residues_dict: Dictionary mapping indices to residue numbers.
55 | - chain_id: Chain ID to look for residues.
56 |
57 | Returns:
58 | - distances: Dictionary of tuple (residue pair) to distance.
59 | """
60 | # Parse the PDB file
61 | parser = PDBParser()
62 | structure = parser.get_structure('PDB', pdb_file)
63 |
64 | # Calculate distances for all pairs
65 | distances = {}
66 | for (idx1, res1), (idx2, res2) in itertools.combinations(residues_dict.items(), 2):
67 | distance = get_ca_distance(structure, res1, res2, chain_id)
68 | distances[(res1, res2)] = distance
69 |
70 | return distances
71 |
72 |
73 | def all_lib_pairwise_dist(
74 | data_dir: str = "data",
75 | ):
76 |
77 | """
78 | Calculate pairwise distances for all libraries in the specified data directory.
79 |
80 | Args:
81 | - data_dir: Directory containing PDB files for each library.
82 |
83 | Returns:
84 | - pwd: DataFrame containing mean and standard deviation of distances for each library.
85 | """
86 | df = pd.DataFrame(columns=["lib", "res1", "res2", "dist"])
87 |
88 | chain_id = "A"
89 | for lib, l_d in LIB_INFO_DICT.items():
90 | pdb_path = os.path.join(data_dir, f"{lib2prot(lib)}/{lib2prot(lib)}.pdb")
91 |
92 | parser = PDBParser()
93 | structure = parser.get_structure("PDB", pdb_path)
94 |
95 | for (idx1, res_id1), (idx2, res_id2) in itertools.combinations(
96 | l_d["positions"].items(), 2
97 | ):
98 | df = df._append(
99 | {
100 | "lib": lib,
101 | "res1": res_id1,
102 | "res2": res_id2,
103 | "dist": get_ca_distance(structure, res_id1, res_id2, chain_id),
104 | },
105 | ignore_index=True,
106 | )
107 |
108 | pwd = df[["lib", "dist"]].groupby(["lib"]).agg(["mean", "std"])
109 | pwd.columns = ["mean", "std"]
110 |
111 | return pwd
--------------------------------------------------------------------------------
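A small illustration of the distance helpers above (the PDB file and residue numbers are placeholders):

    from SSMuLA.est_ep import all_lib_pairwise_dist, calculate_pairwise_distances

    # pairwise C-alpha distances between a handful of positions in one structure
    dists = calculate_pairwise_distances(
        pdb_file="data/GB1/GB1.pdb",
        residues_dict={1: 39, 2: 40, 3: 41, 4: 54},
        chain_id="A",
    )

    # or the mean/std of pairwise distances for every library in LIB_INFO_DICT
    pwd = all_lib_pairwise_dist(data_dir="data")

--------------------------------------------------------------------------------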
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom ignore
2 |
3 | *.zip
4 | *.tar.gz
5 |
6 | data
7 | data/
8 | data/*
9 |
10 | data_old
11 | data_old/
12 | data_old/*
13 |
14 | data4upload
15 | data4upload/
16 | data4upload/*
17 |
18 | coves_data
19 | coves_data/
20 | coves_data/*
21 |
22 | add_data
23 | add_data/
24 | add_data/*
25 |
26 | coves
27 | coves/
28 | coves/*
29 |
30 | results
31 | results/
32 | results/*
33 | results.tar.gz
34 |
35 | results_rev
36 | results_rev/
37 | results_rev/*
38 |
39 | results_raw
40 | results_raw/
41 | results_raw/*
42 |
43 | results4upload
44 | results4upload/
45 | results4upload/*
46 |
47 | figs
48 | figs/
49 | figs/*
50 |
51 | runs
52 | runs/
53 | runs/*
54 |
55 | logs
56 | logs/
57 | logs/*
58 |
59 | esmif/*csv
60 |
61 | ev_esm
62 | ev_esm/
63 | ev_esm/*
64 |
65 | ev_esm1v
66 | ev_esm1v/
67 | ev_esm1v/*
68 |
69 | ev_esm2
70 | ev_esm2/
71 | ev_esm2/*
72 |
73 | gvp-pytorch
74 | gvp-pytorch/
75 | gvp-pytorch/*
76 |
77 | lmdb
78 | lmdb/
79 | lmdb/*
80 |
81 | triad
82 | triad/
83 | triad/*
84 |
85 | mlde
86 | mlde/
87 | mlde/*
88 |
89 | mlde_messy
90 | mlde_messy/
91 | mlde_messy/*
92 |
93 | learned_emb
94 | learned_emb/
95 | learned_emb/*
96 |
97 | KEJ
98 | KEJ/
99 | KEJ/*
100 |
101 | sandbox
102 | sandbox/
103 | sandbox/*
104 |
105 | sandbox/esmif
106 | sandbox/esmif/
107 | sandbox/esmif/*
108 |
109 | sandbox/fig_svg
110 | sandbox/fig_svg/
111 | sandbox/fig_svg/*
112 |
113 | zs
114 | zs/
115 | zs/*
116 |
117 | *.idea
118 | *.npy
119 |
120 | # Byte-compiled / optimized / DLL files
121 | __pycache__/
122 | *.py[cod]
123 | *$py.class
124 |
125 | # C extensions
126 | *.so
127 |
128 | # Distribution / packaging
129 | .Python
130 | build/
131 | develop-eggs/
132 | dist/
133 | downloads/
134 | eggs/
135 | .eggs/
136 | lib/
137 | lib64/
138 | parts/
139 | sdist/
140 | var/
141 | wheels/
142 | pip-wheel-metadata/
143 | share/python-wheels/
144 | *.egg-info/
145 | .installed.cfg
146 | *.egg
147 | MANIFEST
148 |
149 | # PyInstaller
150 | # Usually these files are written by a python script from a template
151 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
152 | *.manifest
153 | *.spec
154 |
155 | # Installer logs
156 | pip-log.txt
157 | pip-delete-this-directory.txt
158 |
159 | # Unit test / coverage reports
160 | htmlcov/
161 | .tox/
162 | .nox/
163 | .coverage
164 | .coverage.*
165 | .cache
166 | nosetests.xml
167 | coverage.xml
168 | *.cover
169 | *.py,cover
170 | .hypothesis/
171 | .pytest_cache/
172 |
173 | # Translations
174 | *.mo
175 | *.pot
176 |
177 | # Django stuff:
178 | *.log
179 | local_settings.py
180 | db.sqlite3
181 | db.sqlite3-journal
182 |
183 | # Flask stuff:
184 | instance/
185 | .webassets-cache
186 |
187 | # Scrapy stuff:
188 | .scrapy
189 |
190 | # Sphinx documentation
191 | docs/_build/
192 |
193 | # PyBuilder
194 | target/
195 |
196 | # Jupyter Notebook
197 | .ipynb_checkpoints
198 |
199 | # IPython
200 | profile_default/
201 | ipython_config.py
202 |
203 | # pyenv
204 | .python-version
205 |
206 | # pipenv
207 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
208 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
209 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
210 | # install all needed dependencies.
211 | #Pipfile.lock
212 |
213 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
214 | __pypackages__/
215 |
216 | # Celery stuff
217 | celerybeat-schedule
218 | celerybeat.pid
219 |
220 | # SageMath parsed files
221 | *.sage.py
222 |
223 | # Environments
224 | .env
225 | .venv
226 | env/
227 | venv/
228 | ENV/
229 | env.bak/
230 | venv.bak/
231 |
232 | # Spyder project settings
233 | .spyderproject
234 | .spyproject
235 |
236 | # Rope project settings
237 | .ropeproject
238 |
239 | # mkdocs documentation
240 | /site
241 |
242 | # mypy
243 | .mypy_cache/
244 | .dmypy.json
245 | dmypy.json
246 |
247 | # Pyre type checker
248 | .pyre/
249 |
--------------------------------------------------------------------------------
/SSMuLA/util.py:
--------------------------------------------------------------------------------
1 | """Util functions"""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | import pickle
7 |
8 | import numpy as np
9 | import pandas as pd
10 |
11 | from sklearn.metrics import ndcg_score
12 |
13 |
14 | def checkNgen_folder(folder_path: str) -> str:
15 |
16 | """
17 | Check if the folder and its subfolder exists
18 | create a new directory if not
19 | Args:
20 | - folder_path: str, the folder path
21 | """
22 |
23 | # if input path is file
24 | if bool(os.path.splitext(folder_path)[1]):
25 | folder_path = os.path.dirname(folder_path)
26 |
27 | split_list = os.path.normpath(folder_path).split("/")
28 | for p, _ in enumerate(split_list):
29 | subfolder_path = "/".join(split_list[: p + 1])
30 | if not os.path.exists(subfolder_path):
31 | print(f"Making {subfolder_path} ...")
32 | os.mkdir(subfolder_path)
33 | return folder_path
34 |
35 |
36 | def pickle_save(what2save, where2save: str) -> None:
37 |
38 | """
39 | Save variable to a pickle file
40 | Args:
41 | - what2save, the variable that needs to be saved
42 | - where2save: str, the .pkl path for saving
43 | """
44 |
45 | with open(where2save, "wb") as f:
46 | pickle.dump(what2save, f)
47 |
48 |
49 | def pickle_load(path2load: str):
50 |
51 | """
52 | Load pickle file
53 | Args:
54 | - path2load: str, the .pkl path for loading
55 | """
56 |
57 | with open(path2load, "rb") as f:
58 | return pickle.load(f)
59 |
60 |
61 | def get_file_name(file_path: str) -> str:
62 |
63 | """
64 | Extract file name without the extension
65 | Args:
66 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb
67 | Returns:
68 | - str, ie WT
69 | """
70 |
71 | return os.path.splitext(os.path.basename(file_path))[0]
72 |
73 |
74 | def get_dir_name(file_path: str) -> str:
75 |
76 | """
77 | Extract dir name
78 | Args:
79 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb
80 | Returns:
81 | - str, ie Tm9D8s_3siteA_fixed
82 | """
83 |
84 | return os.path.basename(os.path.dirname(file_path))
85 |
86 |
87 | def get_dirNfile_name(file_path: str) -> tuple[str, str]:
88 |
89 | """
90 | Extract file name without the extension and direct dir name
91 | Args:
92 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb
93 | Returns:
94 | - str, ie ['Tm9D8s_3siteA_fixed', 'WT']
95 | """
96 |
97 | return (
98 | os.path.basename(os.path.dirname(file_path)),
99 | os.path.splitext(os.path.basename(file_path))[0],
100 | )
101 |
102 |
103 | def get_fulldirNfile_name(file_path: str) -> tuple[str, str]:
104 |
105 | """
106 | Extract file name without the extension and full dir name
107 | Args:
108 | - file_path: str, ie. data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed/WT.pdb
109 | Returns:
110 | - str, ie ['data/graph_nx/Tm9D8s/Tm9D8s_3siteA_fixed', 'WT']
111 | """
112 |
113 | return os.path.dirname(file_path), os.path.splitext(os.path.basename(file_path))[0]
114 |
115 |
116 | def ndcg_scale(true: np.ndarray, pred: np.ndarray):
117 | """Calculate the ndcg_score with neg correction"""
118 |
119 | if min(true) < 0:
120 | true = true - min(true)
121 | return ndcg_score(true[None, :], pred[None, :])
122 |
123 |
124 | def ecdf_transform(data: pd.Series) -> pd.Series:
125 |
126 | """
127 | Transform a series of fitness values into an empirical cumulative distribution function
128 |
129 | Args:
130 | - data: pd.Series, the fitness values
131 |
132 | Returns:
133 | - pd.Series, the ECDF
134 | """
135 |
136 | return data.rank(method="first") / len(data)
137 |
138 |
139 |
140 | def csv2fasta(csv: str) -> None:
141 | """
142 | A function for converting a csv file to a fasta file
143 | ie /disk2/fli/SSMuLA/ev_esm2/DHFR/DHFR.csv
144 |
145 | """
146 | df = pd.read_csv(csv)
147 |
148 | for col in ["muts", "seq"]:
149 | if col not in df.columns:
150 | raise ValueError(f"{col} column not found")
151 |
152 | fasta = csv.replace(".csv", ".fasta")
153 | with open(fasta, "w") as f:
154 | for mut, seq in zip(df["muts"].values, df["seq"].values):
155 | f.write(f">{mut}\n{seq}\n")
--------------------------------------------------------------------------------
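A quick illustration of the helpers above (paths and values are made up for the example):

    import pandas as pd
    from SSMuLA.util import checkNgen_folder, get_file_name, ecdf_transform

    out_dir = checkNgen_folder("results/demo/sub")       # creates results/demo/sub if missing
    name = get_file_name("data/GB1/GB1.csv")             # -> "GB1"
    ranks = ecdf_transform(pd.Series([0.1, 0.5, 0.3]))   # -> 0.33, 1.0, 0.67

--------------------------------------------------------------------------------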
/esmif/score_log_likelihoods.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | #
6 | # Scores sequences based on a given structure.
7 | #
8 | # usage:
9 | # score_log_likelihoods.py [-h] [--outpath OUTPATH] [--chain CHAIN] pdbfile seqfile
10 |
11 | import argparse
12 | from biotite.sequence.io.fasta import FastaFile, get_sequences
13 | import numpy as np
14 | from pathlib import Path
15 | import torch
16 | import torch.nn.functional as F
17 | from tqdm import tqdm
18 |
19 | import esm
20 | import esm.inverse_folding
21 |
22 |
23 | def score_singlechain_backbone(model, alphabet, args):
24 | if torch.cuda.is_available() and not args.nogpu:
25 | model = model.cuda()
26 | print("Transferred model to GPU")
27 | coords, native_seq = esm.inverse_folding.util.load_coords(args.pdbfile, args.chain)
28 | print('Native sequence loaded from structure file:')
29 | print(native_seq)
30 | print('\n')
31 |
32 | ll, _ = esm.inverse_folding.util.score_sequence(
33 | model, alphabet, coords, native_seq)
34 | print('Native sequence')
35 | print(f'Log likelihood: {ll:.2f}')
36 | print(f'Perplexity: {np.exp(-ll):.2f}')
37 |
38 | print('\nScoring variant sequences from sequence file..\n')
39 | infile = FastaFile()
40 | infile.read(args.seqfile)
41 | seqs = get_sequences(infile)
42 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
43 | with open(args.outpath, 'w') as fout:
44 | fout.write('seqid,log_likelihood\n')
45 | for header, seq in tqdm(seqs.items()):
46 | ll, _ = esm.inverse_folding.util.score_sequence(
47 | model, alphabet, coords, str(seq))
48 | fout.write(header + ',' + str(ll) + '\n')
49 | print(f'Results saved to {args.outpath}')
50 |
51 |
52 | def score_multichain_backbone(model, alphabet, args):
53 | if torch.cuda.is_available() and not args.nogpu:
54 | model = model.cuda()
55 | print("Transferred model to GPU")
56 | structure = esm.inverse_folding.util.load_structure(args.pdbfile)
57 | coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)
58 | target_chain_id = args.chain
59 | native_seq = native_seqs[target_chain_id]
60 | print('Native sequence loaded from structure file:')
61 | print(native_seq)
62 | print('\n')
63 |
64 | ll, _ = esm.inverse_folding.multichain_util.score_sequence_in_complex(
65 | model, alphabet, coords, target_chain_id, native_seq)
66 | print('Native sequence')
67 | print(f'Log likelihood: {ll:.2f}')
68 | print(f'Perplexity: {np.exp(-ll):.2f}')
69 |
70 | print('\nScoring variant sequences from sequence file..\n')
71 | infile = FastaFile()
72 | infile.read(args.seqfile)
73 | seqs = get_sequences(infile)
74 | Path(args.outpath).parent.mkdir(parents=True, exist_ok=True)
75 | with open(args.outpath, 'w') as fout:
76 | fout.write('seqid,log_likelihood\n')
77 | for header, seq in tqdm(seqs.items()):
78 | ll, _ = esm.inverse_folding.multichain_util.score_sequence_in_complex(
79 | model, alphabet, coords, target_chain_id, str(seq))
80 | fout.write(header + ',' + str(ll) + '\n')
81 | print(f'Results saved to {args.outpath}')
82 |
83 |
84 | def main():
85 | parser = argparse.ArgumentParser(
86 | description='Score sequences based on a given structure.'
87 | )
88 | parser.add_argument(
89 | 'pdbfile', type=str,
90 | help='input filepath, either .pdb or .cif',
91 | )
92 | parser.add_argument(
93 | 'seqfile', type=str,
94 | help='input filepath for variant sequences in a .fasta file',
95 | )
96 | parser.add_argument(
97 | '--outpath', type=str,
98 | help='output filepath for scores of variant sequences',
99 | default='output/sequence_scores.csv',
100 | )
101 | parser.add_argument(
102 | '--chain', type=str,
103 | help='chain id for the chain of interest', default='A',
104 | )
105 | parser.set_defaults(multichain_backbone=False)
106 | parser.add_argument(
107 | '--multichain-backbone', action='store_true',
108 | help='use the backbones of all chains in the input for conditioning'
109 | )
110 | parser.add_argument(
111 | '--singlechain-backbone', dest='multichain_backbone',
112 | action='store_false',
113 | help='use the backbone of only target chain in the input for conditioning'
114 | )
115 |
116 | parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
117 |
118 | args = parser.parse_args()
119 |
120 | model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
121 | model = model.eval()
122 |
123 | if args.multichain_backbone:
124 | score_multichain_backbone(model, alphabet, args)
125 | else:
126 | score_singlechain_backbone(model, alphabet, args)
127 |
128 |
129 |
130 | if __name__ == '__main__':
131 | main()
--------------------------------------------------------------------------------
/envs/frozen/coves.yml:
--------------------------------------------------------------------------------
1 | name: coves
2 | channels:
3 | - pytorch
4 | - pyg
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - aiohappyeyeballs=2.4.4=pyhd8ed1ab_1
11 | - aiohttp=3.11.11=py39h9399b63_0
12 | - aiosignal=1.3.2=pyhd8ed1ab_0
13 | - async-timeout=5.0.1=pyhd8ed1ab_1
14 | - attrs=25.1.0=pyh71513ae_0
15 | - blas=1.0=mkl
16 | - brotli-python=1.1.0=py39hf88036b_2
17 | - bzip2=1.0.8=h4bc722e_7
18 | - ca-certificates=2024.12.14=hbcca054_0
19 | - certifi=2024.12.14=pyhd8ed1ab_0
20 | - cffi=1.17.1=py39h15c3d72_0
21 | - charset-normalizer=3.4.1=pyhd8ed1ab_0
22 | - colorama=0.4.6=pyhd8ed1ab_1
23 | - cpython=3.9.21=py39hd8ed1ab_1
24 | - cuda-cudart=12.4.127=he02047a_2
25 | - cuda-cudart_linux-64=12.4.127=h85509e4_2
26 | - cuda-cupti=12.4.127=he02047a_2
27 | - cuda-libraries=12.4.1=ha770c72_1
28 | - cuda-nvrtc=12.4.127=he02047a_2
29 | - cuda-nvtx=12.4.127=he02047a_2
30 | - cuda-opencl=12.4.127=he02047a_1
31 | - cuda-runtime=12.4.1=ha804496_0
32 | - cuda-version=12.4=h3060b56_3
33 | - filelock=3.17.0=pyhd8ed1ab_0
34 | - frozenlist=1.5.0=py39h9399b63_1
35 | - fsspec=2024.12.0=pyhd8ed1ab_0
36 | - gmp=6.3.0=hac33072_2
37 | - gmpy2=2.1.5=py39h7196dd7_3
38 | - h2=4.1.0=pyhd8ed1ab_1
39 | - hpack=4.1.0=pyhd8ed1ab_0
40 | - hyperframe=6.1.0=pyhd8ed1ab_0
41 | - idna=3.10=pyhd8ed1ab_1
42 | - intel-openmp=2022.0.1=h06a4308_3633
43 | - jinja2=3.1.5=pyhd8ed1ab_0
44 | - joblib=1.4.2=pyhd8ed1ab_1
45 | - ld_impl_linux-64=2.43=h712a8e2_2
46 | - libblas=3.9.0=16_linux64_mkl
47 | - libcblas=3.9.0=16_linux64_mkl
48 | - libcublas=12.4.5.8=he02047a_2
49 | - libcufft=11.2.1.3=he02047a_2
50 | - libcufile=1.9.1.3=he02047a_2
51 | - libcurand=10.3.5.147=he02047a_2
52 | - libcusolver=11.6.1.9=he02047a_2
53 | - libcusparse=12.3.1.170=he02047a_2
54 | - libffi=3.4.2=h7f98852_5
55 | - libgcc=14.2.0=h77fa898_1
56 | - libgcc-ng=14.2.0=h69a702a_1
57 | - libgfortran=14.2.0=h69a702a_1
58 | - libgfortran-ng=14.2.0=h69a702a_1
59 | - libgfortran5=14.2.0=hd5240d6_1
60 | - libgomp=14.2.0=h77fa898_1
61 | - liblapack=3.9.0=16_linux64_mkl
62 | - liblzma=5.6.3=hb9d3cd8_1
63 | - libnpp=12.2.5.30=he02047a_2
64 | - libnsl=2.0.1=hd590300_0
65 | - libnvfatbin=12.4.127=he02047a_2
66 | - libnvjitlink=12.4.127=he02047a_2
67 | - libnvjpeg=12.3.1.117=he02047a_2
68 | - libsqlite=3.48.0=hee588c1_1
69 | - libstdcxx=14.2.0=hc0a3c3a_1
70 | - libstdcxx-ng=14.2.0=h4852527_1
71 | - libuuid=2.38.1=h0b41bf4_0
72 | - libxcrypt=4.4.36=hd590300_1
73 | - libzlib=1.3.1=hb9d3cd8_2
74 | - llvm-openmp=15.0.7=h0cdce71_0
75 | - markupsafe=3.0.2=py39h9399b63_1
76 | - mkl=2022.1.0=hc2b9512_224
77 | - mpc=1.3.1=h24ddda3_1
78 | - mpfr=4.2.1=h90cbb55_3
79 | - mpmath=1.3.0=pyhd8ed1ab_1
80 | - multidict=6.1.0=py39h9399b63_2
81 | - ncurses=6.5=h2d0b736_2
82 | - networkx=3.2.1=pyhd8ed1ab_0
83 | - ocl-icd=2.3.2=hb9d3cd8_2
84 | - opencl-headers=2024.10.24=h5888daf_0
85 | - openssl=3.4.0=h7b32b05_1
86 | - pandas=2.2.3=py39h3b40f6f_2
87 | - pip=24.3.1=pyh8b19718_2
88 | - propcache=0.2.1=py39h9399b63_1
89 | - psutil=6.1.1=py39h8cd3c5a_0
90 | - pycparser=2.22=pyh29332c3_1
91 | - pyg=2.6.1=py39_torch_2.4.0_cu124
92 | - pyparsing=3.2.1=pyhd8ed1ab_0
93 | - pysocks=1.7.1=pyha55dd90_7
94 | - python=3.9.21=h9c0c6dc_1_cpython
95 | - python-dateutil=2.9.0.post0=pyhff2d567_1
96 | - python-tzdata=2025.1=pyhd8ed1ab_0
97 | - python_abi=3.9=5_cp39
98 | - pytorch-cuda=12.4=hc786d27_7
99 | - pytorch-mutex=1.0=cuda
100 | - pytz=2024.1=pyhd8ed1ab_0
101 | - pyyaml=6.0.2=py39h9399b63_2
102 | - readline=8.2=h8228510_1
103 | - requests=2.32.3=pyhd8ed1ab_1
104 | - scikit-learn=1.6.1=py39h4b7350c_0
105 | - scipy=1.13.1=py39haf93ffa_0
106 | - setuptools=75.8.0=pyhff2d567_0
107 | - six=1.17.0=pyhd8ed1ab_0
108 | - sympy=1.13.3=pyh2585a3b_105
109 | - threadpoolctl=3.5.0=pyhc1e730c_0
110 | - tk=8.6.13=noxft_h4845f30_101
111 | - tqdm=4.67.1=pyhd8ed1ab_1
112 | - typing-extensions=4.12.2=hd8ed1ab_1
113 | - typing_extensions=4.12.2=pyha770c72_1
114 | - tzdata=2025a=h78e105d_0
115 | - urllib3=2.3.0=pyhd8ed1ab_0
116 | - wheel=0.45.1=pyhd8ed1ab_1
117 | - yaml=0.2.5=h7f98852_2
118 | - yarl=1.18.3=py39h9399b63_1
119 | - zstandard=0.23.0=py39h08a7858_1
120 | - zstd=1.5.6=ha6fb4c9_0
121 | - pip:
122 | - anyio==4.8.0
123 | - argon2-cffi==23.1.0
124 | - argon2-cffi-bindings==21.2.0
125 | - argparse==1.4.0
126 | - arrow==1.3.0
127 | - asttokens==3.0.0
128 | - async-lru==2.0.4
129 | - atom3d==0.2.6
130 | - babel==2.16.0
131 | - beautifulsoup4==4.12.3
132 | - biopython==1.85
133 | - black==21.12b0
134 | - blackcellmagic==0.0.3
135 | - bleach==6.2.0
136 | - blosc2==2.5.1
137 | - click==8.1.8
138 | - comm==0.2.2
139 | - debugpy==1.8.12
140 | - decorator==5.1.1
141 | - defusedxml==0.7.1
142 | - dill==0.3.9
143 | - easy-parallel==0.1.6
144 | - exceptiongroup==1.2.2
145 | - executing==2.2.0
146 | - fastjsonschema==2.21.1
147 | - fqdn==1.5.1
148 | - freesasa==2.2.1
149 | - h11==0.14.0
150 | - h5py==3.12.1
151 | - httpcore==1.0.7
152 | - httpx==0.28.1
153 | - importlib-metadata==8.6.1
154 | - ipykernel==6.29.5
155 | - ipython==8.18.1
156 | - ipywidgets==8.1.5
157 | - isoduration==20.11.0
158 | - jedi==0.19.2
159 | - json5==0.10.0
160 | - jsonpointer==3.0.0
161 | - jsonschema==4.23.0
162 | - jsonschema-specifications==2024.10.1
163 | - jupyter==1.1.1
164 | - jupyter-client==8.6.3
165 | - jupyter-console==6.6.3
166 | - jupyter-core==5.7.2
167 | - jupyter-events==0.11.0
168 | - jupyter-lsp==2.2.5
169 | - jupyter-server==2.15.0
170 | - jupyter-server-terminals==0.5.3
171 | - jupyterlab==4.3.4
172 | - jupyterlab-pygments==0.3.0
173 | - jupyterlab-server==2.27.3
174 | - jupyterlab-widgets==3.0.13
175 | - lmdb==1.6.2
176 | - matplotlib-inline==0.1.7
177 | - mistune==3.1.0
178 | - msgpack==1.1.0
179 | - multipledispatch==1.0.0
180 | - multiprocess==0.70.17
181 | - mypy-extensions==1.0.0
182 | - nbclient==0.10.2
183 | - nbconvert==7.16.5
184 | - nbformat==5.10.4
185 | - ndindex==1.9.2
186 | - nest-asyncio==1.6.0
187 | - notebook==7.3.2
188 | - notebook-shim==0.2.4
189 | - numexpr==2.10.2
190 | - numpy==1.23.5
191 | - overrides==7.7.0
192 | - packaging==24.2
193 | - pandocfilters==1.5.1
194 | - parso==0.8.4
195 | - pathos==0.3.3
196 | - pathspec==0.12.1
197 | - pexpect==4.9.0
198 | - pillow==11.1.0
199 | - platformdirs==4.3.6
200 | - pox==0.3.5
201 | - ppft==1.7.6.9
202 | - prometheus-client==0.21.1
203 | - prompt-toolkit==3.0.50
204 | - ptyprocess==0.7.0
205 | - pure-eval==0.2.3
206 | - py-cpuinfo==9.0.0
207 | - pygments==2.19.1
208 | - pyrr==0.10.3
209 | - python-dotenv==1.0.1
210 | - python-json-logger==3.2.1
211 | - pyzmq==26.2.0
212 | - rdkit==2024.9.4
213 | - referencing==0.36.2
214 | - rfc3339-validator==0.1.4
215 | - rfc3986-validator==0.1.1
216 | - rpds-py==0.22.3
217 | - send2trash==1.8.3
218 | - sniffio==1.3.1
219 | - soupsieve==2.6
220 | - stack-data==0.6.3
221 | - tables==3.9.2
222 | - terminado==0.18.1
223 | - tinycss2==1.4.0
224 | - tomli==1.2.3
225 | - torch==2.1.0+cu121
226 | - torch-cluster==1.6.3+pt21cu121
227 | - torchaudio==2.1.0+cu121
228 | - torchvision==0.16.0+cu121
229 | - tornado==6.4.2
230 | - traitlets==5.14.3
231 | - triton==2.1.0
232 | - types-python-dateutil==2.9.0.20241206
233 | - uri-template==1.3.0
234 | - wcwidth==0.2.13
235 | - webcolors==24.11.1
236 | - webencodings==0.5.1
237 | - websocket-client==1.8.0
238 | - widgetsnbextension==4.0.13
239 | - zipp==3.21.0
240 | prefix: /disk2/fli/miniconda3/envs/coves
241 |
--------------------------------------------------------------------------------
/SSMuLA/zs_calc.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for generating zs scores gratefully adapted from EmreGuersoy's work
3 | """
4 |
5 | # Import packages
6 | import os
7 | from glob import glob
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import random
12 | from collections import Counter
13 |
14 |
15 | import argparse
16 | from pathlib import Path
17 |
18 | from tqdm import tqdm
19 | from typing import List, Tuple, Optional
20 |
21 |
22 | import warnings
23 |
24 | from SSMuLA.landscape_global import lib2prot
25 | from SSMuLA.zs_models import ZeroShotPrediction, ESM, EvMutation
26 | from SSMuLA.zs_data import DataProcessor
27 | from SSMuLA.util import get_file_name, checkNgen_folder
28 |
29 |
30 | # TODO clean up evmutation model path
31 |
32 |
33 | def calc_zs(
34 | fit_df_path: str,
35 | scalefit : str = "scale2max",
36 | output_folder: str = "results/zs",
37 | zs_model_names: str = "all",
38 | ev_model_folder: str = "data/evmodels",
39 | regen_esm: bool = False,
40 | rerun_zs: bool = False,
41 | ) -> pd.DataFrame:
42 |
43 | """
44 | A function for calculating zs scores and adding them to the fitness csv
45 |
46 | Args:
47 |     - fit_df_path: str, path to the processed fitness csv of a landscape,
48 |         ie, a csv under a landscape folder's scale2max subfolder,
49 |         whose parent folder contains the fasta with the wt seq
50 | - output_folder: str = "results/zs", with landscape subfolders
51 |     - zs_model_names: str, name(s) of zero-shot models to use separated by comma,
52 | available: 'esm', 'ev'
53 | developing: 'ddg', 'Bert'
54 | all: 'all' runs all currently available models
55 | ie, 'esm, ev'
56 | - ev_model_folder: str = "data/evmodels", folder for evmodels,
57 | with dataset name and dataset name.model
58 | ie. data/evmodels/AAV2_Bryant_2021/AAV2_Bryant_2021.model
59 |     - regen_esm: bool = False, whether to regenerate esm logits or load them directly
60 |     - rerun_zs: bool = False, whether to remove and rerun the existing zs csv or append new zs to it
61 | """
62 |
63 | # deal with the /
64 | input_folder = os.path.normpath(fit_df_path.split(scalefit)[0])
65 | output_folder = os.path.normpath(output_folder)
66 |
67 | landscape_name = get_file_name(fit_df_path)
68 |
69 | if "DHFR" in landscape_name:
70 | append_fasta = "_trans"
71 | else:
72 | append_fasta = ""
73 |
74 | # fit_df_path = os.path.join(input_folder, landscape_name + ".csv")
75 | fasta_path = os.path.join(input_folder, lib2prot(landscape_name) + append_fasta + ".fasta")
76 |
77 | ev_model_path = os.path.join(
78 | os.path.normpath(ev_model_folder), lib2prot(landscape_name), lib2prot(landscape_name) + ".model"
79 | )
80 |
81 | print(fasta_path, ev_model_path)
82 |
83 | # check if file exists
84 | assert os.path.exists(fit_df_path), f"{fit_df_path} does not exist"
85 | assert os.path.exists(fasta_path), f"{fasta_path} does not exist"
86 |
87 | output_folder = checkNgen_folder(output_folder)
88 | landscape_output_folder = checkNgen_folder(
89 | os.path.join(output_folder, landscape_name)
90 | )
91 |
92 | zs_df_path = os.path.join(landscape_output_folder, landscape_name + ".csv")
93 |
94 | # Create an instance of the DataProcessor class
95 | data_processor = DataProcessor()
96 |
97 | # Call the prepare_zero_shot method
98 | data = data_processor.prepare_zero_shot(
99 | fit_df_path, fasta_path, _combo=True, _pos=True
100 | )
101 |
102 | # init df
103 | existing_zs_df = data
104 |
105 | # check if exist
106 | if os.path.exists(zs_df_path):
107 | if rerun_zs:
108 | print(f"{zs_df_path} exists. Remove for rerun_zs = {rerun_zs}")
109 | os.remove(zs_df_path)
110 | else:
111 | print(
112 | f"{zs_df_path} exists. Append new zs {zs_model_names} for rerun_zs = {rerun_zs}"
113 | )
114 | existing_zs_df = pd.read_csv(zs_df_path)
115 |
116 | # Ref Sequence
117 | wt = data_processor.get_Seq(fasta_path)
118 |
119 | if zs_model_names == "all":
120 | zs_model_list = ["esm", "ev"]
121 | else:
122 | zs_model_list = zs_model_names.split(",")
123 |
124 | # init zs_df_list
125 | zs_df_list = []
126 |
127 | print(f"zs_model_list: {zs_model_list}")
128 |
129 | # TODO ESM load logits directly
130 |
131 | for zs_model_name in zs_model_list:
132 |
133 | # get max numb of muts
134 | max_numb_mut = max(data["combo"].str.len())
135 |
136 | # Access Model
137 | if "esm" in zs_model_name:
138 |
139 | logits_path = os.path.join(landscape_output_folder, landscape_name + "_logits.npy")
140 |
141 | esm = ESM(data, wt, logits_path=logits_path, regen_esm=regen_esm)
142 |
143 | if os.path.exists(logits_path) and not(regen_esm):
144 | print(f"{logits_path} exists and regen_esm = {regen_esm}. Loading...")
145 | log_reprs = np.load(logits_path)
146 | else:
147 | print(f"Generating {logits_path}...")
148 | log_reprs = esm._get_logits()
149 | np.save(logits_path, log_reprs)
150 |
151 | score_esm = esm._get_n_score(list(range(max_numb_mut+1))[1:])
152 | zs_df_list.append(score_esm)
153 | print(f"score_esm:\n{score_esm.head()}")
154 |
155 | elif "ev" in zs_model_name:
156 | if os.path.exists(ev_model_path):
157 | ev = EvMutation(data, wt, model_path=ev_model_path)
158 | score_ev = ev._get_n_score(list(range(max_numb_mut+1))[1:])
159 | zs_df_list.append(score_ev)
160 | print(f"score_ev:\n{score_ev.head()}")
161 | else:
162 | print(f"{ev_model_path} does not exist yet. Skipping...")
163 |
164 | # TODO add remaining zs_models
165 |
166 | elif zs_model_name == "ddg":
167 | pass
168 | else:
169 | print("Model currently not available")
170 | continue
171 |
172 | # Add muts from data to df
173 | print(f"zs_df_list:\n{zs_df_list}")
174 |
175 | for zs_df in zs_df_list:
176 |
177 | df = pd.merge(
178 | existing_zs_df,
179 | zs_df[list(zs_df.columns.difference(existing_zs_df.columns)) + ["muts"]],
180 | left_on="muts",
181 | right_on="muts",
182 | how="outer",
183 | )
184 |
185 | existing_zs_df = df
186 |
187 |     existing_zs_df.to_csv(zs_df_path, index=False)
188 |     return existing_zs_df
189 |
190 | # TODO FIX EV MODEL PATH
191 | def calc_all_zs(
192 | landscape_folder: str = "data",
193 | scalefit : str = "scale2max",
194 | dataset_list: list[str] = [],
195 | output_folder: str = "results/zs",
196 | zs_model_names: str = "all",
197 | ev_model_folder: str = "data",
198 |     regen_esm: bool = False,
199 |     rerun_zs: bool = False,
200 | ):
201 | """
202 | A function for calc same list of zs scores for all landscape datasets
203 |
204 | Args:
205 |     - landscape_folder: str = "data", folder path for all landscape data
206 |     - dataset_list: list[str] = [], a list of landscape csv paths (relative to
207 |         landscape_folder) that overrides the default glob over all landscapes
208 | - output_folder: str = "results/zs",
209 | - zs_model_names: str = "all",
210 | - ev_model_folder: str = "data/evmodels", folder for evmodels,
211 | with dataset name and dataset name.model
212 | ie. data/evmodels/AAV2_Bryant_2021/AAV2_Bryant_2021.model
213 | - regen_esm: bool = False, if regenerate esm logits or load directly
214 | - rerun_zs: bool = False, if append new zs to current csv or create new output
215 | """
216 |
217 | if len(dataset_list) == 0:
218 | landscape_paths = sorted(glob(os.path.normpath(landscape_folder) + "/*/" + scalefit + "/*.csv"))
219 | else:
220 | landscape_paths = [
221 | os.path.join(landscape_folder, dataset) for dataset in dataset_list
222 | ]
223 |
224 | for landscape_path in landscape_paths:
225 | print(f"Calc zs {zs_model_names} for {landscape_path}...")
226 |
227 | _ = calc_zs(
228 | fit_df_path=landscape_path,
229 | scalefit=scalefit,
230 | output_folder=output_folder,
231 | zs_model_names=zs_model_names,
232 | ev_model_folder=ev_model_folder,
233 | regen_esm=regen_esm,
234 | rerun_zs=rerun_zs,
235 | )
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSMuLA
2 |
3 | ## About
4 | * Code base termed "Site Saturation Mutagenesis Landscape Analysis (SSMuLA)" for our [paper](https://doi.org/10.1101/2024.10.24.619774) titled "Evaluation of Machine Learning-Assisted Directed Evolution Across Diverse Combinatorial Landscapes"
5 | * Data and results can be found at [Zenodo](https://doi.org/10.5281/zenodo.13910506)
6 | 
7 |
8 | ### Environment
9 | * For the overall environment `SSMuLA`
10 | ```
11 | conda env create -f SSMuLA.yml
12 | ```
13 | * Then install EVmutation from the [develop branch](https://github.com/debbiemarkslab/EVcouplings/archive/develop.zip) after the environment is created
14 | * For the ESM-IF environment
15 | ```
16 | conda create -n inverse python=3.9
17 | conda activate inverse
18 | conda install pytorch cudatoolkit=11.3 -c pytorch
19 | conda install pyg -c pyg -c conda-forge
20 | conda install pip
21 | pip install biotite
22 | pip install git+https://github.com/facebookresearch/esm.git
23 | ```
24 | or install the ESM-IF environment `esmif`
25 | ```
26 | conda env create -f esmif.yml
27 | ```
28 | * For the CoVES environment `coves`
29 | ```
30 | conda env create -f coves.yml
31 | ```
32 | * For installing Triad command line, see instructions [here](https://triad.protabit.com/api/static/doc/user/userGettingStarted.html)
33 | * For running ESM-2 fine-tuning simulations, use the `finetune.yml` environment
34 | ```
35 | conda env create -f finetune.yml
36 | ```
37 | * Frozen environments can be found in `envs/frozen`
38 |
39 | ### Datasets
40 | * The `data/` folder is organized by protein type. Each protein directory contains:
41 | - `.fasta`: FASTA file for the parent sequence
42 | - `.pdb`: PDB file for the parent structure
43 | - `.model`: EVmutation model file
44 | - `fitness_landscape/`: Folder containing CSV files for all fitness landscapes for this protein type, each listing amino acid substitutions and their corresponding fitness values from the original sources
45 | - `scale2max/`: the folder containing processed fitness csv files returned from the `process_all` function in the `SSMuLA.fitness_process_vis` module where the maximum fitness value is normalized to 1 for each landscape
46 |
47 | * Landscapes are summarized in the table below and described in detail in the paper:
48 |
49 | | Landscape | PDB ID | Sites |
50 | |-----------|--------|------------------------|
51 | | ParD2 | 6X0A | I61, L64, K80 |
52 | | ParD3 | 5CEG | D61, K64, E80 |
53 | | GB1 | 2GI9 | V39, D40, G41, V54 |
54 | | DHFR | 6XG5 | A26, D27, L28 |
55 | | T7 | 1CEZ | N748, R756, Q758 |
56 | | TEV | 1LVM | T146, D148, H167, S170 |
57 | | TrpB3A | 8VHH | A104, E105, T106 |
58 | | TrpB3B | | E105, T106, G107 |
59 | | TrpB3C | | T106, G107, A108 |
60 | | TrpB3D | | T117, A118, A119 |
61 | | TrpB3E | | F184, G185, S186 |
62 | | TrpB3F | | L162, I166, Y301 |
63 | | TrpB3G | | V227, S228, Y301 |
64 | | TrpB3H | | S228, G230, S231 |
65 | | TrpB3I | | Y182, V183, F184 |
66 | | TrpB4 | | V183, F184, V227, S228 |
67 |
68 |
69 | ### Preprocessing
70 | * Run
71 | ```
72 | python -m tests.test_preprocess
73 | ```
74 | refer to the test file and the script documentation for further details
75 | * Processed with `fitness_process_vis`
76 | * Renames columns to `AAs`, `AA1`, `AA2`, `AA3`, `AA4`, and `fitness`, adds an `active` column if not already present, and adds a `muts` column
77 | * Scales fitness to `max` (with the option to scale to `parent`)
78 | * Processed data are saved in the `scale2max` folder
79 | * The landscape stats are also saved (see the sketch below)
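* A minimal programmatic sketch of the same step (the keyword argument mirrors the max-scaling described above and should be treated as illustrative; `get_all_lib_stats` is the assumed helper behind the saved stats)
```
# Hedged sketch: programmatic equivalent of `python -m tests.test_preprocess`
from SSMuLA.fitness_process_vis import process_all, get_all_lib_stats

process_all(scale_fit="max")  # writes processed csv files into each scale2max/ folder
get_all_lib_stats()           # saves the landscape summary statistics
```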
80 |
81 |
82 | ### Landscape attributes
83 | #### Local optima
84 | * Run
85 | ```
86 | python -m tests.local_optima
87 | ```
88 | refer to the test file and the script documentation for further details
89 | * Calculate local optima with `calc_local_optima` function in `SSMuLA.local_optima`
90 |
91 | #### Pairwise epistasis
92 | * Run
93 | ```
94 | python -m tests.pairwise_epistasis
95 | ```
96 | refer to the test file and the script documentation for further details
97 | * Calculate pairwise epistasis with the `calc_all_pairwise_epistasis` function in `SSMuLA.pairwise_epistasis`
98 | * Starts from all active variants scaled to max fitness, without post-filtering
99 | * Initial results will be saved under the default path `results/pairwise_epistasis` (corresponding to the `active_start` subfolder in the Zenodo repo)
100 | * Post-process the output with the `plot_pairwise_epistasis` function in `SSMuLA.pairwise_epistasis` (see the sketch below)
101 | * Post-processed results will be saved under the default path `results/pairwise_epistasis_dets` with summary files (corresponding to the `processed` subfolder), and under `results/pairwise_epistasis_vis` for each landscape, together with a master summary across all landscapes (`pairwise_epistasis_summary.csv`)
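* A hedged sketch of the two steps above, calling both functions with their defaults (argument names beyond the documented default output folders are not shown and may differ)
```
# Hedged sketch: pairwise-epistasis calculation followed by post-processing
from SSMuLA.pairwise_epistasis import calc_all_pairwise_epistasis, plot_pairwise_epistasis

calc_all_pairwise_epistasis()  # initial results under results/pairwise_epistasis/
plot_pairwise_epistasis()      # post-processed results under results/pairwise_epistasis_dets/ and _vis/
```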
102 |
103 | ### Zero-shot
104 | * The current pipeline runs EVmutation and ESM together, and then appends the rest of the zero-shot scores
105 |
106 | #### EVmutation
107 | * All EVmutation predictions run with [EVcouplings](https://v2.evcouplings.org/)
108 | * All settings remain default
109 | * Model parameters in the `.model` files are downloaded and renamed
110 |
111 | #### ESM
112 | * The logits will be generated and saved in the output folder
113 | * Run
114 | ```
115 | python -m tests.test_ev_esm
116 | ```
117 | refer to the test file and the script documentation for further details
118 |
119 | #### Hamming distance
120 | * Directly calculated from `n_mut`
121 | * For Hamming distance testing, run
122 | ```
123 | python -m tests.hamming_distance
124 | ```
125 | to deploy `run_hd_avg_fit` and `run_hd_avg_metric` from `SSMuLA.calc_hd` (see the sketch below)
126 | refer to the test file and the script documentation for further details
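* A hedged sketch of the two helpers named above, called with their defaults (their argument lists are not documented here, so treat the bare calls as illustrative)
```
# Hedged sketch: average-fitness and average-metric Hamming distance summaries
from SSMuLA.calc_hd import run_hd_avg_fit, run_hd_avg_metric

run_hd_avg_fit()
run_hd_avg_metric()
```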
127 |
128 | #### ESM-IF
129 | * Run
130 | ```
131 | python -m tests.test_esmif
132 | ```
133 | refer to the test file and the script documentation for further details
134 | * Generate the input fasta files with `get_all_mutfasta` from `SSMuLA.zs_data` to be used in ESM-IF
135 | * Set up the environment for [ESM-IF](https://github.com/facebookresearch/esm?tab=readme-ov-file#invf) as follows
136 | ```
137 | conda create -n inverse python=3.9
138 | conda activate inverse
139 | conda install pytorch cudatoolkit=11.3 -c pytorch
140 | conda install pyg -c pyg -c conda-forge
141 | conda install pip
142 | pip install biotite
143 | pip install git+https://github.com/facebookresearch/esm.git
144 | ```
145 | or create the `esmif` environment from `envs/esmif.yml` as described above
146 | * Within the `esmif` folder, with the new environment activated, run
147 | ```
148 | ./esmif.sh
149 | ```
150 | * ESM-IF results will be saved in the same directory as the `esmif.sh` script
151 |
152 | #### CoVES
153 | * Follow the instructions in the [CoVES](https://github.com/ddingding/CoVES/tree/publish) repository
154 | * Prepare input data in the `coves_data` folder
155 | * Run `run_all_coves` from `SSMuLA.run_coves` to get all scores
156 | * Append scores with `append_all_coves_scores` from `SSMuLA.run_coves`
157 |
158 | #### Triad
159 | * Prepare the mutation file in `.mut` format (such as `A_1A+A_2A+A_3A+A_4A`) with the `TriadGenMutFile` class in `SSMuLA.triad_prepost`
160 | * Run
161 | ```
162 | python -m tests.test_triad_pre
163 | ```
164 | refer to the test file and the script documentation for further details
165 | * With `triad-2.1.3` local command line
166 | * Prepare structure with `2prep_structures.sh`
167 | * Run `3getfixed.sh`
168 | * Parse results with the `ParseTriadResults` class in `SSMuLA.triad_prepost` (see the sketch below)
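* A hedged sketch of the Triad pre- and post-processing classes from `SSMuLA.triad_prepost` (constructor signatures follow that module; the example paths are illustrative only)
```
# Hedged sketch: generate a .mut file, then parse the Triad fixed-backbone output
from SSMuLA.triad_prepost import TriadGenMutFile, ParseTriadResults

# writes triad/<lib>/<lib>.mut with one A_<pos><aa>+... encoding per variant
TriadGenMutFile(input_csv="data/GB1/scale2max/GB1.csv", triad_folder="triad")

# after the Triad command-line run, parses the scores into triad/<lib>/<lib>.csv
ParseTriadResults(
    input_csv="data/GB1/scale2max/GB1.csv",
    triad_txt="triad/GB1/GB1_fixed.txt",
    triad_folder="triad",
)
```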
169 |
170 |
171 | #### Combine all zs
172 | * Run
173 | ```
174 | python -m tests.test_zs
175 | ```
176 | refer to the test file and the script documentation for further details
177 |
178 | ### Simulations
179 | #### DE
180 | * Run `de_simulations` and visualize with `plot_de_simulations`
181 | * Run
182 | ```
183 | python -m tests.test_de
184 | ```
185 | and
186 | ```
187 | python -m tests.test_de_vis
188 | ```
189 | refer to the test file and the script documentation for further details
190 |
191 |
192 | #### MLDE and ftMLDE
193 | * Use `MLDE_lite` environment
194 | * To use learned ESM embeddings, first run `gen_all_learned_emb` from `SSMuLA.gen_learned_emb`; otherwise skip this step
195 | * Run
196 | ```
197 | python -m tests.test_gen_learned_emb
198 | ```
199 | * Run `run_all_mlde_parallelized` from `SSMuLA.mlde_lite` to run simulations
200 | * Run
201 | ```
202 | python -m tests.test_mlde
203 | ```
204 | * Important options include:
205 |   * `n_mut_cutoffs`: list of integers for Hamming distance cutoff options, where `[0]` means no cutoff and `[2]` applies a Hamming distance cutoff of two for the ensemble
206 |   * `zs_predictors`: list of strings for zero-shot predictors, i.e. `["none", "Triad", "ev", "esm"]`, where `none` means no focused training and thus default MLDE runs; the list can be extended for non-Hamming-distance ensembles, including `["Triad-esmif", "Triad-ev", "Triad-esm", "two-best"]`
207 | * `ft_lib_fracs`: list of floats for fraction of libraries to use for focused training, i.e. `[0.5, 0.25, 0.125]`
208 | * `encoding`: list of strings for encoding options, i.e. `["one-hot"] + DEFAULT_LEARNED_EMB_COMBO`
209 | * `model_classes`: list of strings for model classes, i.e. `["boosting", "ridge"]`
210 | * `n_samples`: list of integers for number of training samples to use, i.e. `[96, 384]`
211 | * `n_split`: integer for number of splits for cross-validation, i.e. `5`
212 | * `n_replicate`: integer for number of replicates for each model, i.e. `50`
213 |   * `n_tops`: list of integers for the number of top-predicted variants used for evaluation, i.e. `[96, 384]`
214 | refer to the test file and the script documentation for further details; a hedged sketch of these options follows
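* A hedged sketch assuming `run_all_mlde_parallelized` accepts the options above as keyword arguments (names are taken from the list above and values are illustrative; check the module before relying on them)
```
# Hedged sketch: MLDE / ftMLDE simulations with the options listed above
from SSMuLA.mlde_lite import run_all_mlde_parallelized

run_all_mlde_parallelized(
    n_mut_cutoffs=[0, 2],
    zs_predictors=["none", "Triad", "ev", "esm"],
    ft_lib_fracs=[0.5, 0.25, 0.125],
    encoding=["one-hot"],
    model_classes=["boosting", "ridge"],
    n_samples=[96, 384],
    n_split=5,
    n_replicate=50,
    n_tops=[96, 384],
)
```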
215 | * Run `MLDESum` from `SSMuLA.mlde_analysis` to get the summary dataframe and optional visualization
216 | ```
217 | python -m tests.test_mlde_vis
218 | ```
219 |
220 | #### ALDE and ftALDE
221 | * See details in [alde4ssmula](https://github.com/fhalab/alde4ssmula) repository
222 | * Run `aggregate_alde_df` from `SSMuLA.alde_analysis` to get the summary dataframe (see the sketch below)
223 | ```
224 | python -m tests.test_alde
225 | ```
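* A hedged sketch of aggregating and cleaning the ALDE results (function names and keyword defaults follow `SSMuLA.alde_analysis`; the `alde_dir` path is machine-specific and shown only as an example)
```
# Hedged sketch: collect all ALDE runs into one summary csv, then clean it up
from SSMuLA.alde_analysis import aggregate_alde_df, clean_alde_df

aggregate_alde_df(
    eq_ns=[1, 2, 3, 4],
    zs_opts=["esmif", "ev", "coves", "ed", "esm", "Triad", ""],
    alde_dir="/disk2/fli/alde4ssmula",        # local path to the alde4ssmula results
    alde_df_path="results/alde/alde_all.csv",
)
clean_alde_df(
    agg_alde_df_path="results/alde/alde_all.csv",
    clean_alde_df_path="results/alde/alde_results.csv",
)
```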
226 |
227 | #### Fine-tuning
228 | * Run `train_predict_per_protein` from `SSMuLA.plm_finetune` for ESM-2 LoRA fine-tuning simulations
229 |
230 | ### Analysis and paper figures
231 | * All notebooks in `fig_notebooks` are used to reproduce figures in the paper with files downloaded from [Zenodo](https://doi.org/10.5281/zenodo.13910506)
232 |
233 | ## Contact
234 | * [Francesca-Zhoufan Li](mailto:fzl@caltech.edu)
--------------------------------------------------------------------------------
/SSMuLA/alde_analysis.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for combining and analyzing the results of the ALDE analysis.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import os
8 | import pandas as pd
9 | import numpy as np
10 |
11 |
12 | from SSMuLA.landscape_global import N_SAMPLE_LIST, LOWN_DICT
13 | from SSMuLA.zs_analysis import ZS_OPTS, map_zs_labels
14 |
15 |
16 | def avg_alde_df(
17 | eq_n: int,
18 | lib_list: list,
19 | zs: str = "",
20 | alde_model: str = "Boosting Ensemble",
21 | alde_encoding: str = "onehot",
22 | alde_acq: str = "GREEDY",
23 | alde_dir: str = "/disk2/fli/alde4ssmula",
24 | ) -> pd.DataFrame:
25 |
26 | """
27 | Average ALDE results for a given list of libraries and equal n.
28 |
29 | Args:
30 | - eq_n (int): Equal n for the libraries.
31 | - lib_list (list): List of libraries to aggregate.
32 |
33 | Returns:
34 | - df (pd.DataFrame): Aggregated ALDE results.
35 | """
36 |
37 | df = pd.DataFrame(
38 | columns=[
39 | "n_sample",
40 | "top_maxes_mean",
41 | "top_maxes_std",
42 | "if_truemaxs_mean",
43 | "if_truemaxs_std",
44 | ]
45 | )
46 |
47 | for n in N_SAMPLE_LIST:
48 |
49 | if zs != "":
50 | zs_append = f"{zs}_"
51 | else:
52 | zs_append = ""
53 |
54 | if eq_n == 1:
55 | csv_path = f"{alde_dir}/results/{zs_append}all_{str(n)}+96/all_results.csv"
56 |
57 | else:
58 | csv_path = f"{alde_dir}/results/{zs_append}{str(eq_n)}eq_{str(int((n+96)/eq_n))}/all_results.csv"
59 |
60 | if os.path.exists(csv_path):
61 | a_df = pd.read_csv(csv_path)
62 |
63 | # Get the max Timestep for each Protein
64 | max_timesteps = a_df.groupby("Protein")["Timestep"].transform("max")
65 | # DNN Ensemble
66 | # Boosting Ensemble
67 | slice_df = a_df[
68 | (a_df["Encoding"] == alde_encoding)
69 | & (a_df["Acquisition"] == alde_acq)
70 | & (a_df["Model"] == alde_model)
71 | & (a_df["Protein"].isin(lib_list))
72 | & (a_df["Timestep"] == max_timesteps)
73 | ]
74 | # for each Protein take the max of the timestep
75 |
76 | if len(lib_list) == 1:
77 | top_maxes_std = slice_df["Std"].mean()
78 | if_truemaxs_std = 0
79 | else:
80 | top_maxes_std = slice_df["Mean"].std()
81 | if_truemaxs_std = slice_df["Frac"].std()
82 |
83 | df = df._append(
84 | {
85 | "n_sample": n,
86 | "top_maxes_mean": slice_df["Mean"].mean(),
87 | "top_maxes_std": top_maxes_std,
88 | "if_truemaxs_mean": slice_df["Frac"].mean(),
89 | "if_truemaxs_std": if_truemaxs_std,
90 | },
91 | ignore_index=True,
92 | )
93 | elif "ds-ed" in csv_path:
94 | continue
95 | else:
96 | print(f"File not found: {csv_path}")
97 |
98 | df = df._append(
99 | {
100 | "n_sample": n,
101 | "top_maxes_mean": np.nan,
102 | "top_maxes_std": np.nan,
103 | "if_truemaxs_mean": np.nan,
104 | "if_truemaxs_std": np.nan,
105 | },
106 | ignore_index=True,
107 | )
108 |
109 | return df.set_index("n_sample")
110 |
111 |
112 | def aggregate_alde_df(
113 | eq_ns: list[int] = [1, 2, 3, 4],
114 | n_list: list[int] = N_SAMPLE_LIST,
115 | zs_opts: list[str] = ["esmif", "ev", "coves", "ed", "esm", "Triad", ""],
116 | alde_dir: str = "/disk2/fli/alde4ssmula",
117 | alde_res_folder: str = "results",
118 | alde_df_path: str = "results/alde/alde_all.csv",
119 | ) -> pd.DataFrame:
120 |
121 | """
122 | Aggregate ALDE results for a given list of libraries and equal n.
123 |
124 | Args:
125 | - eq_ns (list): List of equal n values.
126 | - zs_opts (list): List of zero-shot options.
127 | - alde_dir (str): Directory containing ALDE results.
128 | - alde_df_path (str): Path to save the aggregated ALDE results.
129 |
130 | Returns:
131 | - df (pd.DataFrame): Aggregated ALDE results.
132 | """
133 |
134 | # initialize the dataframe
135 | alde_all = pd.DataFrame(
136 | columns=[
137 | "n_mut_cutoff",
138 | "zs",
139 | "rounds",
140 | "n_samples",
141 | "Protein",
142 | "Encoding",
143 | "Model",
144 | "Acquisition",
145 | "Timestep",
146 | "Mean",
147 | "Std",
148 | "Frac",
149 | ]
150 | )
151 |
152 | for eq_n in eq_ns:
153 |
154 | for zs in zs_opts + ["ds-" + z for z in zs_opts if z != ""]:
155 |
156 | if "ds-" in zs:
157 | n_mut = "double"
158 | else:
159 | n_mut = "all"
160 |
161 | for n in n_list:
162 |
163 | if isinstance(n, int):
164 | dir_det = str(int((n + 96) / eq_n))
165 | n_sample = n + 96
166 | else:
167 | dir_det = n
168 | n_sample = LOWN_DICT[n]
169 |
170 | if zs != "":
171 | zs_append = f"{zs}_"
172 | else:
173 | zs_append = ""
174 |
175 | if eq_n == 1:
176 | csv_path = f"{alde_dir}/{alde_res_folder}/{zs_append}all_{str(n)}+96/all_results.csv"
177 |
178 | else:
179 | csv_path = f"{alde_dir}/{alde_res_folder}/{zs_append}{str(eq_n)}eq_{dir_det}/all_results.csv"
180 |
181 | if os.path.exists(csv_path):
182 | print(f"Reading {csv_path}...")
183 | a_df = pd.read_csv(csv_path)
184 |
185 | max_timesteps = a_df.groupby("Protein")["Timestep"].transform("max")
186 | slice_df = a_df[a_df["Timestep"] == max_timesteps].copy()
187 |
188 | slice_df["n_mut_cutoff"] = n_mut
189 | slice_df["zs"] = zs
190 | slice_df["rounds"] = eq_n
191 | slice_df["n_samples"] = n_sample
192 |
193 | # replace T7_2 with T7
194 | # slice_df = slice_df.replace("T7_2", "T7")
195 |
196 | alde_all = alde_all._append(slice_df, ignore_index=True)
197 |
198 | else:
199 | print(f"File not found: {csv_path}")
200 |
201 | alde_all = alde_all._append(
202 | {
203 | "n_mut_cutoff": n_mut,
204 | "zs": zs,
205 | "rounds": eq_n,
206 | "n_samples": n_sample,
207 | "Protein": np.nan,
208 | "Encoding": np.nan,
209 | "Model": np.nan,
210 | "Acquisition": np.nan,
211 | "Timestep": np.nan,
212 | "Mean": np.nan,
213 | "Std": np.nan,
214 | "Frac": np.nan,
215 | },
216 | ignore_index=True,
217 | )
218 |
219 | alde_all = alde_all.dropna(subset=["Protein"])
220 |
221 | alde_all.to_csv(alde_df_path, index=False)
222 |
223 | return alde_all
224 |
225 |
226 | def get_ftalde_libavg(
227 | alde_csv: str,
228 | lib_list: list,
229 | n_total: int,
230 | n_round: int,
231 | models: list = ["Boosting Ensemble"],
232 | acquisition: list = ["GREEDY"],
233 | ):
234 |
235 | """
236 | Get the FT-ALDE data for each of the library, number of rounds, models, and acquisition method.
237 | Args:
238 | alde_csv (str): Path to the FT-ALDE CSV file.
239 | lib_list (list): List of libraries to filter.
240 | n_total (int): Total number of samples.
241 | n_round (int): Number of rounds.
242 | models (list): List of models to filter.
243 | acquisition (list): List of acquisition methods to filter.
244 | Returns:
245 | pd.DataFrame: Filtered DataFrame containing FT-ALDE data.
246 | """
247 | # have all none zs and zs opt for MLDE, ALDE different rounds
248 | alde_all = pd.read_csv(alde_csv)
249 | # Replace NaN values in column 'zs' with the string "none"
250 | alde_all["zs"] = alde_all["zs"].fillna("none")
251 |
252 | slice_df = alde_all[
253 | (alde_all["rounds"] == n_round)
254 | & (alde_all["Encoding"] == "onehot")
255 | & (alde_all["Model"].isin(models))
256 | & (alde_all["Acquisition"].isin(acquisition))
257 | & (alde_all["n_samples"] == n_total)
258 | & (alde_all["Protein"].isin(lib_list))
259 | # & (alde_all["n_mut_cutoff"] == "all")
260 | ].copy()
261 |
262 | # Convert 'Category' column to categorical with defined order
263 | slice_df["zs"] = pd.Categorical(
264 | slice_df["zs"],
265 | categories=["none"]
266 | + [o.replace("_score", "") for o in ZS_OPTS]
267 | + [
268 | "ds-esmif",
269 | "ds-ev",
270 | "ds-coves",
271 | "ds-Triad",
272 | "ds-esm",
273 | ],
274 | ordered=True,
275 | )
276 |
277 | slice_df = slice_df.sort_values(by=["zs", "Protein"])
278 |
279 | slice_df["zs"] = slice_df["zs"].apply(map_zs_labels)
280 |
281 | return (
282 | slice_df[["Protein", "zs", "Mean", "Frac"]]
283 | .rename(columns={"Protein": "lib", "Mean": "top_maxes", "Frac": "if_truemaxs"})
284 | .copy()
285 | )
286 |
287 |
288 | def clean_alde_df(
289 | agg_alde_df_path: str = "results/alde/alde_all.csv",
290 | clean_alde_df_path: str = "results/alde/alde_results.csv",
291 | ):
292 | """
293 | A function to clean up the aggregated ALDE results.
294 | """
295 |
296 | alde_df = pd.read_csv(agg_alde_df_path)
297 | alde_df[
298 | (alde_df["rounds"].isin([2, 3, 4]))
299 | & (alde_df["Model"].isin(["Boosting Ensemble", "DNN Ensemble"]))
300 | & (alde_df["Acquisition"] == "GREEDY")
301 | ].rename(
302 | columns={
303 | "Protein": "lib",
304 | "Mean": "top_maxes_mean",
305 | "Std": "top_maxes_std",
306 | "Frac": "if_truemaxs_mean",
307 | "Encoding": "encoding",
308 | "Model": "model",
309 | "n_samples": "n_sample",
310 | }
311 | )[
312 | [
313 | "encoding",
314 | "model",
315 | "n_sample",
316 | "top_maxes_mean",
317 | "top_maxes_std",
318 | "if_truemaxs_mean",
319 | "n_mut_cutoff",
320 | "lib",
321 | "zs",
322 | "rounds",
323 | ]
324 | ].reset_index(
325 | drop=True
326 | ).to_csv(
327 | clean_alde_df_path, index=False
328 | )
--------------------------------------------------------------------------------
/SSMuLA/triad_prepost.py:
--------------------------------------------------------------------------------
1 | """A script for generating the mut files needed for Triad"""
2 |
3 | import re
4 | import os
5 |
6 | from copy import deepcopy
7 | from glob import glob
8 |
9 | import pandas as pd
10 | import numpy as np
11 | from itertools import product
12 |
13 | from SSMuLA.aa_global import ALL_AAS
14 | from SSMuLA.landscape_global import LIB_INFO_DICT, LIB_NAMES, TrpB_names
15 | from SSMuLA.util import checkNgen_folder, get_file_name
16 |
17 |
18 | # TrpB_TRIAD_FOLDER = "/home/shared_data/triad_structures"
19 | TrpB_LIB_FOLDER = "data/TrpB/scaled2max"
20 | # TrpB3_TRIAD_TXT = deepcopy(sorted(list(glob(f"{TrpB_TRIAD_FOLDER}/*3*/*/*.txt"))))
21 | TrpB4_TRIAD_TXT = deepcopy(sorted(list(glob("triad/TrpB4/*.txt"))))
22 | # /disk2/fli/SSMuLA/triad/TrpB4
23 |
24 | lib_triad_pair = {}
25 |
26 | # append the other two lib
27 | for lib in LIB_NAMES:
28 | if lib != "TrpB4":
29 | lib_triad_pair[
30 | f"data/{lib}/scaled2max/{lib}.csv"
31 | ] = f"triad/{lib}/{lib}_fixed.txt"
32 |
33 | SORTED_LIB_TRIAD_PAIR = deepcopy(
34 | dict(sorted(lib_triad_pair.items(), key=lambda x: x[0]))
35 | )
36 |
37 |
38 | class TriadLib:
39 | """
40 |     A class for common Triad utilities for a given library
41 | """
42 |
43 | def __init__(self, input_csv: str, triad_folder: str = "triad") -> None:
44 |
45 | """
46 | Args:
47 | - input_csv: str, the path to the input csv
48 |         - triad_folder: str, the parent folder for all triad data
49 | """
50 |
51 | self._input_csv = input_csv
52 | self._triad_folder = os.path.normpath(triad_folder)
53 |
54 | @property
55 | def lib_name(self) -> str:
56 |
57 | """
58 | A property for the library name
59 | """
60 | return get_file_name(self._input_csv)
61 |
62 | @property
63 | def site_num(self) -> int:
64 |
65 | """
66 | A property for the site number
67 | """
68 | return len(LIB_INFO_DICT[self.lib_name]["positions"])
69 |
70 | @property
71 | def wt_aas(self) -> list:
72 | """
73 | A property for the wildtype amino acids
74 | """
75 | return list(LIB_INFO_DICT[self.lib_name]["AAs"].values())
76 |
77 | @property
78 | def prefixes(self) -> list:
79 | """
80 | A property for the prefixes
81 | """
82 | return [
83 | f"A_{pos}" for pos in LIB_INFO_DICT[self.lib_name]["positions"].values()
84 | ]
85 |
86 | @property
87 | def df(self) -> pd.DataFrame:
88 | """
89 |         A property for the full input dataframe
90 | """
91 | return pd.read_csv(self._input_csv)
92 |
93 | @property
94 | def df_no_stop(self) -> pd.DataFrame:
95 | """
96 |         A property for the dataframe with stop-codon variants dropped
97 |         """
98 |         return self.df[~self.df["AAs"].str.contains(r"\*")]
99 |
100 | @property
101 | def variants(self) -> list:
102 | """
103 |         The AA sequences of all variants (stop codons dropped)
104 | """
105 | return self.df_no_stop["AAs"].values.tolist()
106 |
107 | @property
108 | def mut_numb(self) -> int:
109 | """
110 | A property for the number of mutations
111 | """
112 | return len(self.df_no_stop)
113 |
114 |
115 | class TriadGenMutFile(TriadLib):
116 | """
117 | A class for generating a mut file for triad
118 | """
119 |
120 | def __init__(self, input_csv: str, triad_folder: str = "triad") -> None:
121 |
122 | """
123 | Args:
124 | - input_csv: str, the path to the input csv
125 |         - triad_folder: str, the parent folder for all triad data
126 | """
127 |
128 | super().__init__(input_csv, triad_folder)
129 |
130 | print(f"Generating {self.mut_path} from {self._input_csv}...")
131 | self._mutation_encodings = self._generate_mut_file()
132 |
133 |     def _generate_mut_file(self) -> list:
134 | """
135 | Generate the mut file
136 | """
137 |
138 | # Loop over variants
139 | mutation_encodings = []
140 |
141 | for variant in self.variants:
142 |
143 | # Loop over each character in the variant
144 | mut_encoding_list = []
145 | for j, (var_char, wt_char) in enumerate(zip(variant, self.wt_aas)):
146 |
147 | # If the var_char does not equal the wt_char, append
148 | if var_char != wt_char:
149 | mut_encoding_list.append(self.prefixes[j] + var_char)
150 |
151 | # If the mut_encoding_list has no entries, continue (this is wild type)
152 | if len(mut_encoding_list) == 0:
153 | continue
154 |
155 | # Otherwise, append to mutation_encodings
156 | else:
157 | mutation_encodings.append("+".join(mut_encoding_list) + "\n")
158 |
159 | # check before saving
160 | # assert len(mutation_encodings) == self.mut_numb - 1
161 |
162 | # Save the mutants
163 | with open(self.mut_path, "w") as f:
164 | f.writelines(mutation_encodings)
165 |
166 | return mutation_encodings
167 |
168 | @property
169 | def mut_path(self) -> str:
170 | """
171 | A property for the mut file path
172 | """
173 | sub_folder = checkNgen_folder(os.path.join(self._triad_folder, self.lib_name))
174 | return os.path.join(sub_folder, f"{self.lib_name}.mut")
175 |
176 | @property
177 | def mut_encoding(self) -> list:
178 | """
179 | A property for the mutation encodings
180 | """
181 | return self._mutation_encodings
182 |
183 |
184 | class ParseTriadResults(TriadLib):
185 | """
186 | A class for parsing the triad results
187 | """
188 |
189 | def __init__(
190 | self,
191 | input_csv: str,
192 | triad_txt: str,
193 | triad_folder: str = "triad",
194 | ) -> None:
195 |
196 | """
197 | Args:
198 | - input_csv: str, the path to the input csv
199 | - triad_txt: str, the path to the triad txt file
200 | - triad_folder: str, the parent folder to all triad data
201 | """
202 |
203 | super().__init__(input_csv, triad_folder)
204 |
205 | self._triad_txt = triad_txt
206 |
207 | print(f"Parsing {self._triad_txt} and save to {self.triad_csv}...")
208 |
209 | # extract triad score into dataframe
210 | self._triad_df = self._get_triad_score()
211 |
212 |         # save the triad dataframe unless TrpB4, which is merged from multiple files and saved later to avoid overwriting
213 | if self.lib_name != "TrpB4":
214 | self._triad_df.to_csv(self.triad_csv, index=False)
215 |
216 |     def _get_triad_score(self) -> pd.DataFrame:
217 |
218 | """
219 |         A function to load the output of a triad analysis and extract the scores
220 | 
221 |         Returns:
222 |         - pd.DataFrame with columns ['AAs', 'Triad_score', 'Triad_rank'],
223 |             one row per sequence in the triad output file,
224 |             in the order the sequences appear in that file
225 |         """
226 |
227 | # Load the output file
228 | with open(self._triad_txt) as f:
229 |
230 | # Set some flags for starting analysis
231 | solutions_started = False
232 | record_start = False
233 |
234 | # Begin looping over the file
235 | summary_lines = []
236 | for line in f:
237 |
238 | # Start looking at data once we hit "solution"
239 | # if "Solution" in line:
240 | if "All sequences:" in line:
241 | solutions_started = True
242 |
243 | # Once we have "Index" we can start recording the rest
244 | if solutions_started and "Index" in line:
245 | record_start = True
246 |
247 | # Record appropriate data
248 | if record_start:
249 |
250 | # Strip the newline and split on whitespace
251 | summary_line = line.strip().split()
252 |
253 | if summary_line[0] == "Average":
254 | break
255 | else:
256 | # Otherwise, append the line
257 | summary_lines.append(summary_line)
258 |
259 | # Build the dataframe with col ['Index', 'Tags', 'Score', 'Seq', 'Muts']
260 | all_results = pd.DataFrame(summary_lines[1:], columns=summary_lines[0])
261 | all_results["Triad_score"] = all_results["Score"].astype(float)
262 |
263 | wt_chars = self.wt_aas
264 | reconstructed_combos = [
265 | "".join(
266 | [char if char != "-" else wt_chars[i] for i, char in enumerate(seq)]
267 | )
268 | for seq in all_results.Seq.values
269 | ]
270 | all_results["AAs"] = reconstructed_combos
271 |
272 | # Get the order
273 | all_results["Triad_rank"] = np.arange(1, len(all_results) + 1)
274 |
275 | return all_results[["AAs", "Triad_score", "Triad_rank"]]
276 |
277 | @property
278 | def triad_csv(self) -> str:
279 | """
280 | A property for the triad csv
281 | """
282 | return os.path.join(self._triad_folder, self.lib_name, f"{self.lib_name}.csv")
283 |
284 | @property
285 | def triad_df(self) -> pd.DataFrame:
286 | """
287 | A property for the triad dataframe
288 | """
289 | return self._triad_df
290 |
291 |
292 | def run_traid_gen_mut_file(all_lib: bool = True, lib_list: list[str] = []):
293 | """
294 | Run the triad gen mut file function for all libraries
295 |
296 | Args:
297 | - all_lib: bool, whether to run for all libraries
298 | - lib_list: list, a list of libraries to run for
299 | """
300 |
301 | if all_lib or len(lib_list) == 0:
302 | lib_list = glob("data/*/scale2max/*.csv")
303 |
304 | for lib in lib_list:
305 | TriadGenMutFile(input_csv=lib)
306 |
307 |
308 | def run_parse_triad_results(
309 | triad_folder: str = "triad",
310 | all_lib: bool = True,
311 | lib_list: list[str] = []
312 | ):
313 |
314 | """
315 | Run the parse triad results function for all libraries
316 |
317 | Args:
318 | - triad_folder: str, the parent folder to all triad data
319 | - all_lib: bool, whether to run for all libraries
320 | - lib_list: list, a list of libraries to run for
321 | """
322 |
323 | for lib, triad_txt in SORTED_LIB_TRIAD_PAIR.items():
324 | ParseTriadResults(input_csv=lib, triad_txt=triad_txt, triad_folder=triad_folder)
325 |
326 | # need to merge
327 | trpb4_dfs = []
328 | print(f"Parsing TrpB4 {TrpB4_TRIAD_TXT}...")
329 |
330 | for triad_txt in TrpB4_TRIAD_TXT:
331 | # ignore rank for now
332 | trpb4_dfs.append(
333 | ParseTriadResults(
334 | input_csv=os.path.join(TrpB_LIB_FOLDER, "TrpB4.csv"),
335 | triad_txt=triad_txt,
336 | triad_folder=triad_folder,
337 | )
338 | .triad_df[["AAs", "Triad_score"]]
339 | .copy()
340 | )
341 |
342 |     # there will be multiple wt rows from each file, so they need to be dropped
343 | trpb4_df = pd.concat(trpb4_dfs).drop_duplicates().sort_values(["Triad_score"])
344 |
345 | # resort and overwrite
346 | trpb4_df["Triad_rank"] = np.arange(1, len(trpb4_df) + 1)
347 | print("Added new rank for TrpB4")
348 |
349 | trpb4_df.to_csv(os.path.join(triad_folder, "TrpB4", "TrpB4.csv"), index=False)
--------------------------------------------------------------------------------
/envs/finetune.yml:
--------------------------------------------------------------------------------
1 | # conda env update --file finetune.yml --prune
2 |
3 | name: finetune
4 | channels:
5 | - pytorch
6 | - nvidia
7 | - conda-forge
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=conda_forge
11 | - _openmp_mutex=4.5=2_kmp_llvm
12 | - abseil-cpp=20211102.0=hd4dd3e8_0
13 | - absl-py=1.3.0=py39h06a4308_0
14 | - aiohttp=3.8.3=py39h5eee18b_0
15 | - aiosignal=1.2.0=pyhd3eb1b0_0
16 | - anyio=3.5.0=py39h06a4308_0
17 | - argon2-cffi=21.3.0=pyhd3eb1b0_0
18 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0
19 | - asttokens=2.0.5=pyhd3eb1b0_0
20 | - astunparse=1.6.3=py_0
21 | - async-timeout=4.0.2=py39h06a4308_0
22 | - attrs=22.1.0=py39h06a4308_0
23 | - babel=2.11.0=py39h06a4308_0
24 | - backcall=0.2.0=pyhd3eb1b0_0
25 | - beautifulsoup4=4.11.1=py39h06a4308_0
26 | - biopython=1.83=py39hd1e30aa_0
27 | - blas=1.0=mkl
28 | - bleach=4.1.0=pyhd3eb1b0_0
29 | - blinker=1.4=py39h06a4308_0
30 | - brotli=1.1.0=hd590300_1
31 | - brotli-bin=1.1.0=hd590300_1
32 | - brotlipy=0.7.0=py39h27cfd23_1003
33 | - c-ares=1.18.1=h7f8727e_0
34 | - ca-certificates=2024.2.2=hbcca054_0
35 | - cachetools=4.2.2=pyhd3eb1b0_0
36 | - certifi=2024.2.2=pyhd8ed1ab_0
37 | - cffi=1.15.1=py39h5eee18b_3
38 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
39 | - click=8.0.4=py39h06a4308_0
40 | - comm=0.1.2=py39h06a4308_0
41 | - cryptography=38.0.4=py39h9ce1e76_0
42 | - cuda=11.7.1=0
43 | - cuda-cccl=11.7.91=0
44 | - cuda-command-line-tools=11.7.1=0
45 | - cuda-compiler=11.7.1=0
46 | - cuda-cudart=11.7.99=0
47 | - cuda-cudart-dev=11.7.99=0
48 | - cuda-cuobjdump=11.7.91=0
49 | - cuda-cupti=11.7.101=0
50 | - cuda-cuxxfilt=11.7.91=0
51 | - cuda-demo-suite=12.0.140=0
52 | - cuda-documentation=12.0.140=0
53 | - cuda-driver-dev=11.7.99=0
54 | - cuda-gdb=12.0.140=0
55 | - cuda-libraries=11.7.1=0
56 | - cuda-libraries-dev=11.7.1=0
57 | - cuda-memcheck=11.8.86=0
58 | - cuda-nsight=12.0.140=0
59 | - cuda-nsight-compute=12.0.1=0
60 | - cuda-nvcc=11.7.99=0
61 | - cuda-nvdisasm=12.0.140=0
62 | - cuda-nvml-dev=11.7.91=0
63 | - cuda-nvprof=11.7.101=0
64 | - cuda-nvprune=11.7.91=0
65 | - cuda-nvrtc=11.7.99=0
66 | - cuda-nvrtc-dev=11.7.99=0
67 | - cuda-nvtx=11.7.91=0
68 | - cuda-nvvp=12.0.146=0
69 | - cuda-runtime=11.7.1=0
70 | - cuda-sanitizer-api=12.0.140=0
71 | - cuda-toolkit=11.7.1=0
72 | - cuda-tools=11.7.1=0
73 | - cuda-visual-tools=11.7.1=0
74 | - cudatoolkit=11.7.0=hd8887f6_11
75 | - cudatoolkit-dev=11.7.0=h1de0b5d_6
76 | - cudnn=8.4.1.50=hed8a83a_0
77 | - debugpy=1.5.1=py39h295c915_0
78 | - decorator=5.1.1=pyhd3eb1b0_0
79 | - defusedxml=0.7.1=pyhd3eb1b0_0
80 | - entrypoints=0.4=py39h06a4308_0
81 | - et_xmlfile=1.1.0=py39h06a4308_0
82 | - executing=0.8.3=pyhd3eb1b0_0
83 | - fftw=3.3.9=h27cfd23_1
84 | - flit-core=3.6.0=pyhd3eb1b0_0
85 | - freetype=2.12.1=hca18f0e_0
86 | - frozenlist=1.3.3=py39h5eee18b_0
87 | - gast=0.4.0=pyhd3eb1b0_0
88 | - gds-tools=1.5.1.14=0
89 | - giflib=5.2.1=h5eee18b_1
90 | - google-auth=2.6.0=pyhd3eb1b0_0
91 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
92 | - google-pasta=0.2.0=pyhd3eb1b0_0
93 | - grpc-cpp=1.46.4=h6fc47f4_3
94 | - grpcio=1.46.4=py39h4587e31_3
95 | - h5py=3.7.0=py39h737f45e_0
96 | - hdf5=1.10.6=h3ffc7dd_1
97 | - icu=70.1=h27087fc_0
98 | - idna=3.4=py39h06a4308_0
99 | - importlib-metadata=4.11.3=py39h06a4308_0
100 | - intel-openmp=2021.4.0=h06a4308_3561
101 | - ipykernel=6.19.2=py39hb070fc8_0
102 | - ipython=8.9.0=py39h06a4308_0
103 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
104 | - ipywidgets=8.0.4=pyhd8ed1ab_0
105 | - jedi=0.18.1=py39h06a4308_1
106 | - jinja2=3.1.2=py39h06a4308_0
107 | - joblib=1.1.1=py39h06a4308_0
108 | - jpeg=9e=h7f8727e_0
109 | - json5=0.9.6=pyhd3eb1b0_0
110 | - jsonschema=4.16.0=py39h06a4308_0
111 | - jupyter_client=7.4.9=py39h06a4308_0
112 | - jupyter_contrib_core=0.4.0=pyhd8ed1ab_0
113 | - jupyter_contrib_nbextensions=0.7.0=pyhd8ed1ab_0
114 | - jupyter_core=5.1.1=py39h06a4308_0
115 | - jupyter_highlight_selected_word=0.2.0=py39hf3d152e_1005
116 | - jupyter_latex_envs=1.4.6=pyhd8ed1ab_1002
117 | - jupyter_nbextensions_configurator=0.6.1=pyhd8ed1ab_0
118 | - jupyter_server=1.23.4=py39h06a4308_0
119 | - jupyterlab=3.5.3=py39h06a4308_0
120 | - jupyterlab_pygments=0.1.2=py_0
121 | - jupyterlab_server=2.16.5=py39h06a4308_0
122 | - jupyterlab_widgets=3.0.5=pyhd8ed1ab_0
123 | - keras=2.9.0=py39h06a4308_0
124 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0
125 | - krb5=1.19.4=h568e23c_0
126 | - lcms2=2.12=h3be6417_0
127 | - ld_impl_linux-64=2.38=h1181459_1
128 | - lerc=3.0=h295c915_0
129 | - libabseil=20211102.0=cxx17_h48a1fff_2
130 | - libbrotlicommon=1.1.0=hd590300_1
131 | - libbrotlidec=1.1.0=hd590300_1
132 | - libbrotlienc=1.1.0=hd590300_1
133 | - libcublas=11.10.3.66=0
134 | - libcublas-dev=11.10.3.66=0
135 | - libcufft=10.7.2.124=h4fbf590_0
136 | - libcufft-dev=10.7.2.124=h98a8f43_0
137 | - libcufile=1.5.1.14=0
138 | - libcufile-dev=1.5.1.14=0
139 | - libcurand=10.3.1.124=0
140 | - libcurand-dev=10.3.1.124=0
141 | - libcurl=7.87.0=h91b91d3_0
142 | - libcusolver=11.4.0.1=0
143 | - libcusolver-dev=11.4.0.1=0
144 | - libcusparse=11.7.4.91=0
145 | - libcusparse-dev=11.7.4.91=0
146 | - libdeflate=1.17=h5eee18b_1
147 | - libedit=3.1.20221030=h5eee18b_0
148 | - libev=4.33=h7f8727e_1
149 | - libffi=3.4.2=h6a678d5_6
150 | - libgcc-ng=12.2.0=h65d4601_19
151 | - libgfortran-ng=11.2.0=h00389a5_1
152 | - libgfortran5=11.2.0=h1234567_1
153 | - libiconv=1.17=h166bdaf_0
154 | - libnghttp2=1.46.0=hce63b2e_0
155 | - libnpp=11.7.4.75=0
156 | - libnpp-dev=11.7.4.75=0
157 | - libnvjpeg=11.8.0.2=0
158 | - libnvjpeg-dev=11.8.0.2=0
159 | - libpng=1.6.37=hbc83047_0
160 | - libprotobuf=3.20.3=he621ea3_0
161 | - libsodium=1.0.18=h7b6447c_0
162 | - libssh2=1.10.0=h8f2d780_0
163 | - libstdcxx-ng=12.2.0=h46fd767_19
164 | - libtiff=4.5.1=h6a678d5_0
165 | - libwebp-base=1.3.2=hd590300_0
166 | - libxml2=2.10.3=h7463322_0
167 | - libxslt=1.1.37=h873f0b0_0
168 | - libzlib=1.2.13=h166bdaf_4
169 | - llvm-openmp=14.0.6=h9e868ea_0
170 | - lxml=4.9.2=py39h14694de_0
171 | - markdown=3.4.1=py39h06a4308_0
172 | - markupsafe=2.1.1=py39h7f8727e_0
173 | - matplotlib-base=3.5.3=py39h19d6b11_2
174 | - matplotlib-inline=0.1.6=py39h06a4308_0
175 | - mistune=0.8.4=py39h27cfd23_1000
176 | - mkl=2021.4.0=h06a4308_640
177 | - mkl-service=2.4.0=py39h7f8727e_0
178 | - mkl_fft=1.3.1=py39hd3c417c_0
179 | - mkl_random=1.2.2=py39h51133e4_0
180 | - multidict=6.0.2=py39h5eee18b_0
181 | - munkres=1.1.4=pyh9f0ad1d_0
182 | - nb_conda=2.2.1=unix_6
183 | - nb_conda_kernels=2.3.1=py39h06a4308_0
184 | - nbclassic=0.4.8=py39h06a4308_0
185 | - nbclient=0.5.13=py39h06a4308_0
186 | - nbconvert=6.4.4=py39h06a4308_0
187 | - nbformat=5.7.0=py39h06a4308_0
188 | - nccl=2.14.3.1=h0800d71_0
189 | - ncurses=6.4=h6a678d5_0
190 | - nest-asyncio=1.5.6=py39h06a4308_0
191 | - notebook=6.5.2=py39h06a4308_0
192 | - notebook-shim=0.2.2=py39h06a4308_0
193 | - nsight-compute=2022.4.1.6=0
194 | - numpy=1.22.3=py39he7a7128_0
195 | - numpy-base=1.22.3=py39hf524024_0
196 | - oauthlib=3.2.1=py39h06a4308_0
197 | - openjpeg=2.4.0=h3ad879b_0
198 | - openpyxl=3.0.10=py39h5eee18b_0
199 | - openssl=1.1.1w=hd590300_0
200 | - opt_einsum=3.3.0=pyhd3eb1b0_1
201 | - packaging=22.0=py39h06a4308_0
202 | - pandocfilters=1.5.0=pyhd3eb1b0_0
203 | - parso=0.8.3=pyhd3eb1b0_0
204 | - patsy=0.5.6=pyhd8ed1ab_0
205 | - pexpect=4.8.0=pyhd3eb1b0_3
206 | - pickleshare=0.7.5=pyhd3eb1b0_1003
207 | - pip=22.3.1=py39h06a4308_0
208 | - platformdirs=2.5.2=py39h06a4308_0
209 | - prometheus_client=0.14.1=py39h06a4308_0
210 | - prompt-toolkit=3.0.36=py39h06a4308_0
211 | - protobuf=3.20.3=py39h6a678d5_0
212 | - psutil=5.9.0=py39h5eee18b_0
213 | - ptyprocess=0.7.0=pyhd3eb1b0_2
214 | - pure_eval=0.2.2=pyhd3eb1b0_0
215 | - pyasn1=0.4.8=pyhd3eb1b0_0
216 | - pyasn1-modules=0.2.8=py_0
217 | - pycparser=2.21=pyhd3eb1b0_0
218 | - pygments=2.11.2=pyhd3eb1b0_0
219 | - pyjwt=2.4.0=py39h06a4308_0
220 | - pyopenssl=22.0.0=pyhd3eb1b0_0
221 | - pyrsistent=0.18.0=py39heee7806_0
222 | - pysocks=1.7.1=py39h06a4308_0
223 | - python=3.9.16=h7a1cb2a_0
224 | - python-dateutil=2.8.2=pyhd3eb1b0_0
225 | - python-fastjsonschema=2.16.2=py39h06a4308_0
226 | - python-flatbuffers=1.12=pyhd3eb1b0_0
227 | - python-tzdata=2024.1=pyhd8ed1ab_0
228 | - python_abi=3.9=2_cp39
229 | - pytorch=1.13.1=py3.9_cuda11.7_cudnn8.5.0_0
230 | - pytorch-cuda=11.7=h67b0de4_1
231 | - pytorch-mutex=1.0=cuda
232 | - pyyaml=6.0=py39h5eee18b_1
233 | - pyzmq=23.2.0=py39h6a678d5_0
234 | - re2=2022.06.01=h27087fc_1
235 | - readline=8.2=h5eee18b_0
236 | - requests=2.28.1=py39h06a4308_0
237 | - requests-oauthlib=1.3.0=py_0
238 | - rsa=4.7.2=pyhd3eb1b0_1
239 | - scikit-learn=1.2.0=py39h6a678d5_0
240 | - scipy=1.7.3=py39h6c91a56_2
241 | - seaborn=0.13.2=hd8ed1ab_0
242 | - seaborn-base=0.13.2=pyhd8ed1ab_0
243 | - send2trash=1.8.0=pyhd3eb1b0_1
244 | - sentencepiece=0.1.95=py39hd09550d_0
245 | - setuptools=65.6.3=py39h06a4308_0
246 | - six=1.16.0=pyhd3eb1b0_1
247 | - snappy=1.1.9=h295c915_0
248 | - sniffio=1.2.0=py39h06a4308_1
249 | - soupsieve=2.3.2.post1=py39h06a4308_0
250 | - sqlite=3.40.1=h5082296_0
251 | - stack_data=0.2.0=pyhd3eb1b0_0
252 | - statsmodels=0.14.0=py39ha9d4c09_0
253 | - tensorboard=2.9.0=py39h06a4308_0
254 | - tensorboard-data-server=0.6.1=py39h52d8a92_0
255 | - tensorboard-plugin-wit=1.8.1=py39h06a4308_0
256 | - tensorflow=2.9.1=cuda112py39h01bd6f0_0
257 | - tensorflow-base=2.9.1=cuda112py39h81abfd3_0
258 | - tensorflow-estimator=2.9.1=cuda112py39hd320b7a_0
259 | - termcolor=2.1.0=py39h06a4308_0
260 | - terminado=0.17.1=py39h06a4308_0
261 | - testpath=0.6.0=py39h06a4308_0
262 | - threadpoolctl=2.2.0=pyh0d69192_0
263 | - tk=8.6.12=h1ccaba5_0
264 | - tomli=2.0.1=py39h06a4308_0
265 | - tornado=6.2=py39h5eee18b_0
266 | - traitlets=5.7.1=py39h06a4308_0
267 | - typing-extensions=4.4.0=py39h06a4308_0
268 | - typing_extensions=4.4.0=py39h06a4308_0
269 | - tzdata=2022g=h04d1e81_0
270 | - unicodedata2=15.1.0=py39hd1e30aa_0
271 | - urllib3=1.26.14=py39h06a4308_0
272 | - wcwidth=0.2.5=pyhd3eb1b0_0
273 | - webencodings=0.5.1=py39h06a4308_1
274 | - websocket-client=0.58.0=py39h06a4308_4
275 | - werkzeug=2.2.2=py39h06a4308_0
276 | - wheel=0.37.1=pyhd3eb1b0_0
277 | - widgetsnbextension=4.0.5=pyhd8ed1ab_0
278 | - wrapt=1.14.1=py39h5eee18b_0
279 | - xz=5.2.10=h5eee18b_1
280 | - yaml=0.2.5=h7b6447c_0
281 | - yarl=1.8.1=py39h5eee18b_0
282 | - zeromq=4.3.4=h2531618_0
283 | - zipp=3.11.0=py39h06a4308_0
284 | - zlib=1.2.13=h166bdaf_4
285 | - zstd=1.5.5=hfc55251_0
286 | - pip:
287 | - accelerate==0.28.0
288 | - arrow==1.2.3
289 | - contourpy==1.0.7
290 | - cpufeature==0.2.1
291 | - cycler==0.11.0
292 | - datasets==2.9.0
293 | - deepspeed==0.8.1
294 | - dill==0.3.6
295 | - evaluate==0.4.0
296 | - filelock==3.9.0
297 | - flatbuffers==1.12
298 | - fonttools==4.38.0
299 | - fqdn==1.5.1
300 | - fsspec==2024.2.0
301 | - gradient-accumulator==0.3.1
302 | - hdijupyterutils==0.20.4
303 | - hjson==3.1.0
304 | - huggingface-hub==0.21.4
305 | - ipyparallel==8.4.1
306 | - isoduration==20.11.0
307 | - jsonpointer==2.3
308 | - jupyter==1.0.0
309 | - jupyter-console==6.5.1
310 | - jupyter-contrib-core==0.4.2
311 | - jupyter-server==1.23.5
312 | - kiwisolver==1.4.4
313 | - lmdb==1.4.1
314 | - matplotlib==3.6.3
315 | - multiprocess==0.70.14
316 | - nglview==3.0.8
317 | - ninja==1.11.1
318 | - pandas==1.5.3
319 | - peft==0.9.0
320 | - pillow==9.4.0
321 | - py-cpuinfo==9.0.0
322 | - pyarrow==11.0.0
323 | - pydantic==1.10.5
324 | - pynmrstar==3.3.2
325 | - pyparsing==3.0.9
326 | - pytz==2022.7.1
327 | - qtconsole==5.4.0
328 | - qtpy==2.3.0
329 | - regex==2022.10.31
330 | - responses==0.18.0
331 | - safetensors==0.4.2
332 | - tensorflow-addons==0.19.0
333 | - tokenizers==0.13.2
334 | - tqdm==4.64.1
335 | - transformers==4.26.1
336 | - typeguard==2.13.3
337 | - uri-template==1.2.0
338 | - webcolors==1.12
339 | - xxhash==3.2.0
--------------------------------------------------------------------------------
/SSMuLA/aa_global.py:
--------------------------------------------------------------------------------
1 | """Global amino acid constants and parameters: codon table, AA codes, Georgiev descriptors, and embedding defaults"""
2 |
3 | from __future__ import annotations
4 |
5 | from collections import Counter
6 |
7 | import re
8 | from copy import deepcopy
9 |
10 | RAND_SEED = 42
11 |
12 | TRANSLATE_DICT = {
13 | "AAA": "K",
14 | "AAT": "N",
15 | "AAC": "N",
16 | "AAG": "K",
17 | "ATA": "I",
18 | "ATT": "I",
19 | "ATC": "I",
20 | "ATG": "M",
21 | "ACA": "T",
22 | "ACT": "T",
23 | "ACC": "T",
24 | "ACG": "T",
25 | "AGA": "R",
26 | "AGT": "S",
27 | "AGC": "S",
28 | "AGG": "R",
29 | "TAA": "*",
30 | "TAT": "Y",
31 | "TAC": "Y",
32 | "TAG": "*",
33 | "TTA": "L",
34 | "TTT": "F",
35 | "TTC": "F",
36 | "TTG": "L",
37 | "TCA": "S",
38 | "TCT": "S",
39 | "TCC": "S",
40 | "TCG": "S",
41 | "TGA": "*",
42 | "TGT": "C",
43 | "TGC": "C",
44 | "TGG": "W",
45 | "CAA": "Q",
46 | "CAT": "H",
47 | "CAC": "H",
48 | "CAG": "Q",
49 | "CTA": "L",
50 | "CTT": "L",
51 | "CTC": "L",
52 | "CTG": "L",
53 | "CCA": "P",
54 | "CCT": "P",
55 | "CCC": "P",
56 | "CCG": "P",
57 | "CGA": "R",
58 | "CGT": "R",
59 | "CGC": "R",
60 | "CGG": "R",
61 | "GAA": "E",
62 | "GAT": "D",
63 | "GAC": "D",
64 | "GAG": "E",
65 | "GTA": "V",
66 | "GTT": "V",
67 | "GTC": "V",
68 | "GTG": "V",
69 | "GCA": "A",
70 | "GCT": "A",
71 | "GCC": "A",
72 | "GCG": "A",
73 | "GGA": "G",
74 | "GGT": "G",
75 | "GGC": "G",
76 | "GGG": "G",
77 | }
78 |
79 | # Amino acid code conversion
80 | AA_DICT = {
81 | "Ala": "A",
82 | "Cys": "C",
83 | "Asp": "D",
84 | "Glu": "E",
85 | "Phe": "F",
86 | "Gly": "G",
87 | "His": "H",
88 | "Ile": "I",
89 | "Lys": "K",
90 | "Leu": "L",
91 | "Met": "M",
92 | "Asn": "N",
93 | "Pro": "P",
94 | "Gln": "Q",
95 | "Arg": "R",
96 | "Ser": "S",
97 | "Thr": "T",
98 | "Val": "V",
99 | "Trp": "W",
100 | "Tyr": "Y",
101 | "Ter": "*",
102 | }
103 |
104 | # the upper case three letter code for the amino acids
105 | ALL_AAS_TLC_DICT = {k.upper(): v for k, v in AA_DICT.items() if v != "*"}
106 |
107 | # list of the upper case three letter codes for the amino acids
108 | ALL_AAS_TLC = list(ALL_AAS_TLC_DICT.keys())
109 |
110 | # All canonical amino acids
111 | ALL_AAS = list(ALL_AAS_TLC_DICT.values())
112 | ALL_AA_STR = "".join(ALL_AAS)
113 | AA_NUMB = len(ALL_AAS)
114 | ALLOWED_AAS = set(ALL_AAS)
115 |
116 | # Count the number of codons encoding each amino acid (and the stop codon)
117 | CODON_COUNT_PER_AA = {aa: Counter(TRANSLATE_DICT.values())[aa] for aa in ALL_AAS + ["*"]}
118 |
119 | # Create a dictionary that links each amino acid to an index
120 | AA_TO_IND = {aa: i for i, aa in enumerate(ALL_AAS)}
121 |
122 | # Define a regular expression for parsing mutations in the format Wt##Mut
123 | MUT_REGEX = re.compile("^([A-Z])([0-9]+)([A-Z])$")
124 |
125 | START_AA_IND = 1
126 |
127 |
128 | # Copied from ProFET (Ofer & Linial, DOI: 10.1093/bioinformatics/btv345)
129 | # Original comment by the ProFET authors: 'Acquired from georgiev's paper of
130 | # AAscales using helper script "GetTextData.py". + RegEx cleaning DOI: 10.1089/cmb.2008.0173'
131 | gg_1 = {
132 | "Q": -2.54,
133 | "L": 2.72,
134 | "T": -0.65,
135 | "C": 2.66,
136 | "I": 3.1,
137 | "G": 0.15,
138 | "V": 2.64,
139 | "K": -3.89,
140 | "M": 1.89,
141 | "F": 3.12,
142 | "N": -2.02,
143 | "R": -2.8,
144 | "H": -0.39,
145 | "E": -3.08,
146 | "W": 1.89,
147 | "A": 0.57,
148 | "D": -2.46,
149 | "Y": 0.79,
150 | "S": -1.1,
151 | "P": -0.58,
152 | }
153 | gg_2 = {
154 | "Q": 1.82,
155 | "L": 1.88,
156 | "T": -1.6,
157 | "C": -1.52,
158 | "I": 0.37,
159 | "G": -3.49,
160 | "V": 0.03,
161 | "K": 1.47,
162 | "M": 3.88,
163 | "F": 0.68,
164 | "N": -1.92,
165 | "R": 0.31,
166 | "H": 1,
167 | "E": 3.45,
168 | "W": -0.09,
169 | "A": 3.37,
170 | "D": -0.66,
171 | "Y": -2.62,
172 | "S": -2.05,
173 | "P": -4.33,
174 | }
175 | gg_3 = {
176 | "Q": -0.82,
177 | "L": 1.92,
178 | "T": -1.39,
179 | "C": -3.29,
180 | "I": 0.26,
181 | "G": -2.97,
182 | "V": -0.67,
183 | "K": 1.95,
184 | "M": -1.57,
185 | "F": 2.4,
186 | "N": 0.04,
187 | "R": 2.84,
188 | "H": -0.63,
189 | "E": 0.05,
190 | "W": 4.21,
191 | "A": -3.66,
192 | "D": -0.57,
193 | "Y": 4.11,
194 | "S": -2.19,
195 | "P": -0.02,
196 | }
197 | gg_4 = {
198 | "Q": -1.85,
199 | "L": 5.33,
200 | "T": 0.63,
201 | "C": -3.77,
202 | "I": 1.04,
203 | "G": 2.06,
204 | "V": 2.34,
205 | "K": 1.17,
206 | "M": -3.58,
207 | "F": -0.35,
208 | "N": -0.65,
209 | "R": 0.25,
210 | "H": -3.49,
211 | "E": 0.62,
212 | "W": -2.77,
213 | "A": 2.34,
214 | "D": 0.14,
215 | "Y": -0.63,
216 | "S": 1.36,
217 | "P": -0.21,
218 | }
219 | gg_5 = {
220 | "Q": 0.09,
221 | "L": 0.08,
222 | "T": 1.35,
223 | "C": 2.96,
224 | "I": -0.05,
225 | "G": 0.7,
226 | "V": 0.64,
227 | "K": 0.53,
228 | "M": -2.55,
229 | "F": -0.88,
230 | "N": 1.61,
231 | "R": 0.2,
232 | "H": 0.05,
233 | "E": -0.49,
234 | "W": 0.72,
235 | "A": -1.07,
236 | "D": 0.75,
237 | "Y": 1.89,
238 | "S": 1.78,
239 | "P": -8.31,
240 | }
241 | gg_6 = {
242 | "Q": 0.6,
243 | "L": 0.09,
244 | "T": -2.45,
245 | "C": -2.23,
246 | "I": -1.18,
247 | "G": 7.47,
248 | "V": -2.01,
249 | "K": 0.1,
250 | "M": 2.07,
251 | "F": 1.62,
252 | "N": 2.08,
253 | "R": -0.37,
254 | "H": 0.41,
255 | "E": 0,
256 | "W": 0.86,
257 | "A": -0.4,
258 | "D": 0.24,
259 | "Y": -0.53,
260 | "S": -3.36,
261 | "P": -1.82,
262 | }
263 | gg_7 = {
264 | "Q": 0.25,
265 | "L": 0.27,
266 | "T": -0.65,
267 | "C": 0.44,
268 | "I": -0.21,
269 | "G": 0.41,
270 | "V": -0.33,
271 | "K": 4.01,
272 | "M": 0.84,
273 | "F": -0.15,
274 | "N": 0.4,
275 | "R": 3.81,
276 | "H": 1.61,
277 | "E": -5.66,
278 | "W": -1.07,
279 | "A": 1.23,
280 | "D": -5.15,
281 | "Y": -1.3,
282 | "S": 1.39,
283 | "P": -0.12,
284 | }
285 | gg_8 = {
286 | "Q": 2.11,
287 | "L": -4.06,
288 | "T": 3.43,
289 | "C": -3.49,
290 | "I": 3.45,
291 | "G": 1.62,
292 | "V": 3.93,
293 | "K": -0.01,
294 | "M": 1.85,
295 | "F": -0.41,
296 | "N": -2.47,
297 | "R": 0.98,
298 | "H": -0.6,
299 | "E": -0.11,
300 | "W": -1.66,
301 | "A": -2.32,
302 | "D": -1.17,
303 | "Y": 1.31,
304 | "S": -1.21,
305 | "P": -1.18,
306 | }
307 | gg_9 = {
308 | "Q": -1.92,
309 | "L": 0.43,
310 | "T": 0.34,
311 | "C": 2.22,
312 | "I": 0.86,
313 | "G": -0.47,
314 | "V": -0.21,
315 | "K": -0.26,
316 | "M": -2.05,
317 | "F": 4.2,
318 | "N": -0.07,
319 | "R": 2.43,
320 | "H": 3.55,
321 | "E": 1.49,
322 | "W": -5.87,
323 | "A": -2.01,
324 | "D": 0.73,
325 | "Y": -0.56,
326 | "S": -2.83,
327 | "P": 0,
328 | }
329 | gg_10 = {
330 | "Q": -1.67,
331 | "L": -1.2,
332 | "T": 0.24,
333 | "C": -3.78,
334 | "I": 1.98,
335 | "G": -2.9,
336 | "V": 1.27,
337 | "K": -1.66,
338 | "M": 0.78,
339 | "F": 0.73,
340 | "N": 7.02,
341 | "R": -0.99,
342 | "H": 1.52,
343 | "E": -2.26,
344 | "W": -0.66,
345 | "A": 1.31,
346 | "D": 1.5,
347 | "Y": -0.95,
348 | "S": 0.39,
349 | "P": -0.66,
350 | }
351 | gg_11 = {
352 | "Q": 0.7,
353 | "L": 0.67,
354 | "T": -0.53,
355 | "C": 1.98,
356 | "I": 0.89,
357 | "G": -0.98,
358 | "V": 0.43,
359 | "K": 5.86,
360 | "M": 1.53,
361 | "F": -0.56,
362 | "N": 1.32,
363 | "R": -4.9,
364 | "H": -2.28,
365 | "E": -1.62,
366 | "W": -2.49,
367 | "A": -1.14,
368 | "D": 1.51,
369 | "Y": 1.91,
370 | "S": -2.92,
371 | "P": 0.64,
372 | }
373 | gg_12 = {
374 | "Q": -0.27,
375 | "L": -0.29,
376 | "T": 1.91,
377 | "C": -0.43,
378 | "I": -1.67,
379 | "G": -0.62,
380 | "V": -1.71,
381 | "K": -0.06,
382 | "M": 2.44,
383 | "F": 3.54,
384 | "N": -2.44,
385 | "R": 2.09,
386 | "H": -3.12,
387 | "E": -3.97,
388 | "W": -0.3,
389 | "A": 0.19,
390 | "D": 5.61,
391 | "Y": -1.26,
392 | "S": 1.27,
393 | "P": -0.92,
394 | }
395 | gg_13 = {
396 | "Q": -0.99,
397 | "L": -2.47,
398 | "T": 2.66,
399 | "C": -1.03,
400 | "I": -1.02,
401 | "G": -0.11,
402 | "V": -2.93,
403 | "K": 1.38,
404 | "M": -0.26,
405 | "F": 5.25,
406 | "N": 0.37,
407 | "R": -3.08,
408 | "H": -1.45,
409 | "E": 2.3,
410 | "W": -0.5,
411 | "A": 1.66,
412 | "D": -3.85,
413 | "Y": 1.57,
414 | "S": 2.86,
415 | "P": -0.37,
416 | }
417 | gg_14 = {
418 | "Q": -1.56,
419 | "L": -4.79,
420 | "T": -3.07,
421 | "C": 0.93,
422 | "I": -1.21,
423 | "G": 0.15,
424 | "V": 4.22,
425 | "K": 1.78,
426 | "M": -3.09,
427 | "F": 1.73,
428 | "N": -0.89,
429 | "R": 0.82,
430 | "H": -0.77,
431 | "E": -0.06,
432 | "W": 1.64,
433 | "A": 4.39,
434 | "D": 1.28,
435 | "Y": 0.2,
436 | "S": -1.88,
437 | "P": 0.17,
438 | }
439 | gg_15 = {
440 | "Q": 6.22,
441 | "L": 0.8,
442 | "T": 0.2,
443 | "C": 1.43,
444 | "I": -1.78,
445 | "G": -0.53,
446 | "V": 1.06,
447 | "K": -2.71,
448 | "M": -1.39,
449 | "F": 2.14,
450 | "N": 3.13,
451 | "R": 1.32,
452 | "H": -4.18,
453 | "E": -0.35,
454 | "W": -0.72,
455 | "A": 0.18,
456 | "D": -1.98,
457 | "Y": -0.76,
458 | "S": -2.42,
459 | "P": 0.36,
460 | }
461 | gg_16 = {
462 | "Q": -0.18,
463 | "L": -1.43,
464 | "T": -2.2,
465 | "C": 1.45,
466 | "I": 5.71,
467 | "G": 0.35,
468 | "V": -1.31,
469 | "K": 1.62,
470 | "M": -1.02,
471 | "F": 1.1,
472 | "N": 0.79,
473 | "R": 0.69,
474 | "H": -2.91,
475 | "E": 1.51,
476 | "W": 1.75,
477 | "A": -2.6,
478 | "D": 0.05,
479 | "Y": -5.19,
480 | "S": 1.75,
481 | "P": 0.08,
482 | }
483 | gg_17 = {
484 | "Q": 2.72,
485 | "L": 0.63,
486 | "T": 3.73,
487 | "C": -1.15,
488 | "I": 1.54,
489 | "G": 0.3,
490 | "V": -1.97,
491 | "K": 0.96,
492 | "M": -4.32,
493 | "F": 0.68,
494 | "N": -1.54,
495 | "R": -2.62,
496 | "H": 3.37,
497 | "E": -2.29,
498 | "W": 2.73,
499 | "A": 1.49,
500 | "D": 0.9,
501 | "Y": -2.56,
502 | "S": -2.77,
503 | "P": 0.16,
504 | }
505 | gg_18 = {
506 | "Q": 4.35,
507 | "L": -0.24,
508 | "T": -5.46,
509 | "C": -1.64,
510 | "I": 2.11,
511 | "G": 0.32,
512 | "V": -1.21,
513 | "K": -1.09,
514 | "M": -1.34,
515 | "F": 1.46,
516 | "N": -1.71,
517 | "R": -1.49,
518 | "H": 1.87,
519 | "E": -1.47,
520 | "W": -2.2,
521 | "A": 0.46,
522 | "D": 1.38,
523 | "Y": 2.87,
524 | "S": 3.36,
525 | "P": -0.34,
526 | }
527 | gg_19 = {
528 | "Q": 0.92,
529 | "L": 1.01,
530 | "T": -0.73,
531 | "C": -1.05,
532 | "I": -4.18,
533 | "G": 0.05,
534 | "V": 4.77,
535 | "K": 1.36,
536 | "M": 0.09,
537 | "F": 2.33,
538 | "N": -0.25,
539 | "R": -2.57,
540 | "H": 2.17,
541 | "E": 0.15,
542 | "W": 0.9,
543 | "A": -4.22,
544 | "D": -0.03,
545 | "Y": -3.43,
546 | "S": 2.67,
547 | "P": 0.04,
548 | }
549 |
550 | # Package all georgiev parameters
551 | georgiev_parameters = [
552 | gg_1,
553 | gg_2,
554 | gg_3,
555 | gg_4,
556 | gg_5,
557 | gg_6,
558 | gg_7,
559 | gg_8,
560 | gg_9,
561 | gg_10,
562 | gg_11,
563 | gg_12,
564 | gg_13,
565 | gg_14,
566 | gg_15,
567 | gg_16,
568 | gg_17,
569 | gg_18,
570 | gg_19,
571 | ]
572 |
573 | DEFAULT_ESM = "esm2_t33_650M_UR50D"
574 |
575 | EMB_METHOD_COMBOS = [
576 | {"flatten_emb": "flatten", "ifsite": True, "combo_folder": "flatten_site"},
577 | {"flatten_emb": "mean", "ifsite": True, "combo_folder": "mean_site"},
578 | {"flatten_emb": "mean", "ifsite": False, "combo_folder": "mean_all"},
579 | ]
580 |
581 | EMB_COMBO_LIST = deepcopy(sorted(i["combo_folder"] for i in EMB_METHOD_COMBOS))
582 |
583 | DEFAULT_LEARNED_EMB_COMBO = deepcopy([f"{DEFAULT_ESM}-{combo}" for combo in EMB_COMBO_LIST])
584 | DEFAULT_LEARNED_EMB_DIR = "learned_emb"
585 |
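# --- Usage sketch (illustrative only, not part of the original module) ---------
# Runs only when the module is executed directly, so importing SSMuLA.aa_global
# stays side-effect free; the 4-site combo below is a made-up example.
if __name__ == "__main__":
    # translate a codon and parse a mutation string of the form Wt<position>Mut
    assert TRANSLATE_DICT["ATG"] == "M"
    wt_aa, pos, mut_aa = MUT_REGEX.match("V39D").groups()
    print(wt_aa, int(pos), mut_aa)  # -> V 39 D

    # featurize a 4-site combo with the 19 Georgiev descriptor scales
    combo = "VDGV"
    features = [gg[aa] for aa in combo for gg in georgiev_parameters]
    print(len(features))  # -> 4 sites x 19 scales = 76 features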
--------------------------------------------------------------------------------
/SSMuLA/vis.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 |
5 | import pandas as pd
6 |
7 | from cairosvg import svg2png
8 |
9 | import seaborn as sns
10 |
11 | import matplotlib.pyplot as plt
12 | import matplotlib.colors as mcolors
13 |
14 | import colorcet as cc
15 |
16 | import bokeh
17 | from bokeh.io import export_svg
18 | from bokeh.models import NumeralTickFormatter
19 | from bokeh.themes.theme import Theme
20 |
21 | import holoviews as hv
22 |
23 | from SSMuLA.landscape_global import LIB_NAMES
24 | from SSMuLA.util import checkNgen_folder
25 |
26 | bokeh.io.output_notebook()
27 | hv.extension("bokeh", "matplotlib")
28 |
29 | # Select a colorcet colormap, for example, 'fire' or 'CET_CBL1'
30 | colormap = cc.cm["glasbey_category10"]
31 |
32 | # Extract a list of hex codes from the colormap
33 | glasbey_category10 = [mcolors.to_hex(colormap(i)) for i in range(colormap.N)]
34 |
35 |
36 | JSON_THEME = Theme(
37 | json={
38 | "attrs": {
39 | "Title": {
40 | "align": "center",
41 | "text_font_size": "12px",
42 | "text_color": "black",
43 | "text_font": "arial",
44 | }, # title centered and bigger
45 | "Axis": {
46 | "axis_label_text_font_style": "normal",
47 | "axis_label_text_color": "black",
48 | "major_label_text_color": "black",
49 | "axis_label_text_font": "arial",
50 | "major_label_text_font": "arial",
51 | }, # no italic labels
52 | "Legend": {
53 | "title_text_font_style": "normal",
54 | "title_text_color": "black",
55 | "label_text_color": "black",
56 | "label_text_font": "arial",
57 | },
58 | "ColorBar": {
59 | "title_text_font_style": "normal",
60 | "major_label_text_color": "black",
61 | "major_label_text_font": "arial",
62 | "title_text_color": "black",
63 | "title_text_font": "arial",
64 | },
65 | }
66 | }
67 | )
68 |
69 | hv.renderer("bokeh").theme = JSON_THEME
70 |
71 | # Grey for heatmap
72 | HM_GREY = "#76777B"
73 |
74 | # blue, orange, green, yellow, purple, gray
75 | FZL_PALETTE = {
76 | "blue": "#4bacc6",
77 | "orange": "#f79646ff",
78 | "light_orange": "#ffbb78",
79 | "red": "#ff8888",
80 | "maroon": "#7A303F",
81 | "green": "#9bbb59",
82 | "yellow": "#f9be00",
83 | "purple": "#8064a2",
84 | "brown": "#ae682f",
85 | "dark_brown": "#6e4a2eff",
86 | "gray": "#666666",
87 | "light_gray": "#D3D3D3",
88 | "light_blue": "#849895",
89 | "light_green": "#9DAE88",
90 | "light_yellow": "#F1D384",
91 | "light_brown": "#C7B784",
92 | "black": "#000000",
93 | }
94 |
95 | GRAY_COLORS = {
96 | "gray-blue": "#749aa3",
97 | "gray-orange": "#c58a6c",
98 | "gray-green": "#8b946e",
99 | "gray-yellow": "#d6b969",
100 | "gray-purple": "#897a8f",
101 | "gray-brown": "#8b6e57",
102 | }
103 |
104 | LIGHT_COLORS = {"yellow": "#F1D384"}
105 |
106 | PLOTEXTENTIONS = [".svg", ".png"]
107 | PLOTTYPES = [t[1:] for t in PLOTEXTENTIONS]
108 |
109 | LIB_COLORS = {
110 | n: c
111 | for (n, c) in zip(
112 | LIB_NAMES,
113 | [
114 | FZL_PALETTE["orange"],
115 | FZL_PALETTE["light_orange"],
116 | FZL_PALETTE["brown"],
117 | FZL_PALETTE["yellow"],
118 | FZL_PALETTE["maroon"],
119 | FZL_PALETTE["purple"],
120 | ]
121 | + sns.color_palette("crest", 9).as_hex()
122 | + [FZL_PALETTE["gray"]],
123 | )
124 | }
125 |
126 | LIB_COLORS_GLASBEY = {
127 | n: c
128 | for (n, c) in zip(
129 | LIB_NAMES,
130 | glasbey_category10[:6]
131 | + glasbey_category10[12:15]
132 | + glasbey_category10[6:10]
133 | + [glasbey_category10[15]]
134 | + glasbey_category10[10:12],
135 | )
136 | }
137 |
138 | LIB_COLORS_CODON = {"DHFR": "#ffbb78"} # light orange
139 |
140 | MLDE_COLORS = (
141 | [
142 | FZL_PALETTE["orange"],
143 | FZL_PALETTE["yellow"],
144 | ]
145 | + sns.color_palette("crest", 9).as_hex()
146 | + [FZL_PALETTE["gray"]]
147 | )
148 |
149 | SIMPLE_ZS_COLOR_MAP = {
150 | "none": FZL_PALETTE["gray"],
151 | "ed_score": FZL_PALETTE["blue"],
152 | "ev_score": FZL_PALETTE["green"],
153 | "esm_score": FZL_PALETTE["purple"],
154 | "esmif_score": FZL_PALETTE["yellow"],
155 | "coves_score": FZL_PALETTE["brown"],
156 | "Triad_score": FZL_PALETTE["orange"],
157 | }
158 |
159 | ZS_COLOR_MAP = {
160 | "none": FZL_PALETTE["gray"],
161 | "ev_score": FZL_PALETTE["green"],
162 | "esm_score": FZL_PALETTE["purple"],
163 | "esmif_score": FZL_PALETTE["yellow"],
164 | "coves_score": FZL_PALETTE["brown"],
165 | "Triad_score": FZL_PALETTE["orange"],
166 | "Triad-esmif_score": FZL_PALETTE["light_blue"],
167 | "ev-esm_score": FZL_PALETTE["light_green"],
168 | "ev-esm-esmif_score": FZL_PALETTE["light_yellow"],
169 | "Triad-ev-esm-esmif_score": FZL_PALETTE["light_brown"],
170 | "two-best-comb_score": FZL_PALETTE["light_gray"],
171 | }
172 |
173 |
174 | # define plot hooks
175 | def one_decimal_x(plot, element):
176 | plot.handles["plot"].xaxis[0].formatter = NumeralTickFormatter(format="0.0")
177 |
178 |
179 | def one_decimal_y(plot, element):
180 | plot.handles["plot"].yaxis[0].formatter = NumeralTickFormatter(format="0.0")
181 |
182 |
183 | def fixmargins(plot, element):
184 | plot.handles["plot"].min_border_right = 30
185 | plot.handles["plot"].min_border_left = 65
186 | plot.handles["plot"].min_border_top = 20
187 | plot.handles["plot"].min_border_bottom = 65
188 | plot.handles["plot"].outline_line_color = "black"
189 | plot.handles["plot"].outline_line_alpha = 1
190 | plot.handles["plot"].outline_line_width = 1
191 | plot.handles["plot"].toolbar.autohide = True
192 |
193 |
194 | def render_hv(hv_plot) -> bokeh.plotting.Figure:
195 | """Render a holoviews plot as a bokeh plot"""
196 | return hv.render(hv_plot)
197 |
198 |
199 | def save_bokeh_hv(
200 | plot_obj,
201 | plot_name: str,
202 | plot_path: str,
203 | bokehorhv: str = "hv",
204 | dpi: int = 300,
205 | scale: int = 2,
206 | skippng: bool = False,
207 | ):
208 |
209 | """
210 |     A function for exporting holoviews or bokeh plots as svg (and optionally png)
211 |
212 | Args:
213 | - plot_obj: hv or bokeh plot object
214 | - plot_name: str: name of the plot
215 | - plot_path: str: path to save the plot without the plot_name
216 | - bokehorhv: str: 'hv' or 'bokeh'
217 | - dpi: int: dpi
218 | - scale: int: scale
219 | - skippng: bool: skip png
220 | """
221 |
222 | plot_name = plot_name.replace(" ", "_")
223 | plot_path = checkNgen_folder(plot_path.replace(" ", "_"))
224 | plot_noext = os.path.join(plot_path, plot_name)
225 |
226 | if bokehorhv == "hv":
227 |
228 |         # save as html to keep the interactive legend
229 | hv.save(plot_obj, plot_noext + ".html")
230 |
231 | # hv.save(plot_obj, plot_noext, mode="auto", fmt='svg', dpi=300, toolbar='disable')
232 |
233 | plot_obj = hv.render(plot_obj, backend="bokeh")
234 |
235 | plot_obj.toolbar_location = None
236 | plot_obj.toolbar.logo = None
237 |
238 | plot_obj.output_backend = "svg"
239 | export_svg(plot_obj, filename=plot_noext + ".svg", timeout=1200)
240 |
241 | if not skippng:
242 | svg2png(
243 | write_to=plot_noext + ".png",
244 | dpi=dpi,
245 | scale=scale,
246 | bytestring=open(plot_noext + ".svg").read().encode("utf-8"),
247 | )
248 | else:
249 | print("Skipping png export")
250 |
251 |
252 | def save_svg(fig, plot_title: str, path2folder: str, ifshow: bool = True):
253 | """
254 | A function for saving svg plots
255 | """
256 |
257 | plot_title_no_space = plot_title.replace(" ", "_")
258 | plt.savefig(
259 | os.path.join(checkNgen_folder(path2folder), f"{plot_title_no_space}.svg"),
260 | bbox_inches="tight",
261 | dpi=300,
262 | format="svg",
263 | )
264 |
265 | if ifshow:
266 | plt.show()
267 |
268 |
269 | def save_plt(fig, plot_title: str, path2folder: str):
270 |
271 | """
272 | A helper function for saving plt plots
273 | Args:
274 | - fig: plt.figure: the figure to save
275 | - plot_title: str: the title of the plot
276 | - path2folder: str: the path to the folder to save the plot
277 | """
278 |
279 | for ext in PLOTEXTENTIONS:
280 | plot_title_no_space = plot_title.replace(" ", "_")
281 | plt.savefig(
282 | os.path.join(checkNgen_folder(path2folder), f"{plot_title_no_space}{ext}"),
283 | bbox_inches="tight",
284 | dpi=300,
285 | )
286 |
287 | plt.close()
288 |
289 |
290 | def plot_fit_dist(
291 | fitness: pd.Series,
292 | label: str,
293 | color: str = "",
294 | spike_length: float | None = None,
295 | ignore_line_label: bool = False,
296 | ) -> hv.Distribution:
297 | """
298 | Plot the fitness distribution
299 |
300 | Args:
301 | - fitness: pd.Series: fitness values
302 | - label: str: label
303 | - color: str: color
304 | - ignore_line_label: bool: ignore line label
305 |
306 | Returns:
307 | - hv.Distribution: plot of the fitness distribution
308 | """
309 |
310 | if label == "codon":
311 | cap_label = f"{label.capitalize()}-level"
312 | elif label == "AA":
313 | cap_label = f"{label.upper()}-level"
314 | else:
315 | cap_label = label
316 |
317 | if color == "":
318 | color = FZL_PALETTE["blue"]
319 |
320 | if ignore_line_label:
321 | mean_label = {}
322 | median_label = {}
323 | else:
324 | mean_label = {"label": f"Mean {label}"}
325 | median_label = {"label": f"Median {label}"}
326 |
327 | hv_dist = hv.Distribution(fitness, label=cap_label).opts(
328 | width=400,
329 | height=400,
330 | color=color,
331 | line_color=None,
332 | )
333 |
334 | # get y_range for spike height
335 | y_range = (
336 | hv.renderer("bokeh").instance(mode="server").get_plot(hv_dist).state.y_range
337 | )
338 |
339 | # set spike length to be 5% of the y_range
340 | if spike_length is None:
341 | spike_length = (y_range.end - y_range.start) * 0.05
342 |
343 | return (
344 | hv_dist
345 | * hv.Spikes([fitness.mean()], **mean_label).opts(
346 | line_dash="dotted",
347 | line_color=color,
348 | line_width=1.6,
349 | spike_length=spike_length,
350 | )
351 | * hv.Spikes([fitness.median()], **median_label).opts(
352 | line_color=color, line_width=1.6, spike_length=spike_length
353 | )
354 | )
355 |
356 |
357 | def plot_zs_violin(
358 | all_df: pd.DataFrame,
359 | zs: str,
360 | encoding_list: list[str],
361 | model: str,
362 | n_sample: int,
363 | n_top: int,
364 | metric: str,
365 | plot_name: str,
366 | ) -> hv.Violin:
367 |
368 | return hv.Violin(
369 | all_df[
370 | (all_df["zs"] == zs)
371 | & (all_df["encoding"].isin(encoding_list))
372 | & (all_df["model"] == model)
373 | & (all_df["n_sample"] == n_sample)
374 | & (all_df["n_top"] == n_top)
375 | ]
376 | .sort_values(["lib", "n_mut_cutoff"], ascending=[True, False])
377 | .copy(),
378 | kdims=["lib", "n_mut_cutoff"],
379 | vdims=[metric],
380 | ).opts(
381 | width=1200,
382 | height=400,
383 | violin_color="n_mut_cutoff",
384 | show_legend=True,
385 | legend_position="top",
386 | legend_offset=(0, 5),
387 | title=plot_name,
388 | ylim=(0, 1),
389 | hooks=[one_decimal_x, one_decimal_y, fixmargins, lib_ncut_hook],
390 | )
391 |
392 |
393 | def lib_ncut_hook(plot, element):
394 |
395 | plot.handles["plot"].x_range.factors = [
396 | (lib, n_mut) for lib in LIB_NAMES for n_mut in ["single", "double", "all"]
397 | ]
398 | plot.handles["xaxis"].major_label_text_font_size = "0pt"
399 | # plot.handles['xaxis'].group_text_font_size = '0pt'
400 | # plot.handles['yaxis'].axis_label_text_font_size = '10pt'
401 | # plot.handles['yaxis'].axis_label_text_font_style = 'normal'
402 | # plot.handles['xaxis'].axis_label_text_font_style = 'normal'
403 |
404 |
405 | def generate_related_color(reference_idx, base_idx, target_idx, palette="colorblind"):
406 | """
407 | Generate a color that has the same relationship to palette[target_idx]
408 | as palette[base_idx] has to palette[reference_idx].
409 |
410 | Parameters:
411 | - reference_idx: Index of the reference color in the palette.
412 | - base_idx: Index of the color that is related to reference_idx.
413 | - target_idx: Index of the color to which the transformation is applied.
414 | - palette: Name of the Seaborn palette (default: "colorblind").
415 |
416 | Returns:
417 | - A tuple representing the new RGB color.
418 | """
419 | import seaborn as sns
420 | import numpy as np
421 |
422 | # Get the palette
423 | colors = sns.color_palette(palette)
424 |
425 | # Compute transformation
426 | color_shift = np.array(colors[base_idx]) - np.array(colors[reference_idx])
427 |
428 | # Apply the transformation
429 | new_color = np.array(colors[target_idx]) + color_shift
430 |
431 | # Clip to valid RGB range [0,1]
432 | new_color = np.clip(new_color, 0, 1)
433 |
434 | return tuple(new_color)
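
# --- Usage sketch (illustrative only, not part of the original module) ---------
# Plot a made-up fitness distribution with plot_fit_dist and export it with
# save_bokeh_hv; the output path is arbitrary, and the SVG export additionally
# requires selenium plus a browser driver for bokeh's export_svg.
if __name__ == "__main__":
    toy_fitness = pd.Series([0.0, 0.1, 0.2, 0.4, 0.4, 0.7, 1.0], name="fitness")
    dist = plot_fit_dist(toy_fitness, label="AA", color=FZL_PALETTE["blue"])
    save_bokeh_hv(dist, plot_name="toy_fit_dist", plot_path="figs/demo", skippng=True)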
--------------------------------------------------------------------------------
/SSMuLA/calc_hd.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for calculating fitness statistics and correlations under a Hamming distance cutoff
3 | """
4 |
5 | import os
6 | from glob import glob
7 |
8 | import pandas as pd
9 | import numpy as np
10 |
11 | from scipy.stats import spearmanr
12 | from sklearn.metrics import roc_curve, auc
13 |
14 | from concurrent.futures import ProcessPoolExecutor, as_completed
15 | from tqdm import tqdm
16 |
17 | import seaborn as sns
18 | import matplotlib.pyplot as plt
19 | import matplotlib.colors as mcolors
20 | from matplotlib.lines import Line2D
21 |
22 | from SSMuLA.landscape_global import LIB_INFO_DICT, hamming, lib2prot
23 | from SSMuLA.zs_analysis import (
24 | ZS_METRIC_MAP_TITLE,
25 | ZS_METRIC_MAP_LABEL,
26 | ZS_METRIC_BASELINE,
27 | )
28 | from SSMuLA.vis import LIB_COLORS, FZL_PALETTE, LIB_COLORS_GLASBEY, save_plt, save_svg
29 | from SSMuLA.util import checkNgen_folder, get_file_name
30 |
31 |
32 | # Define the function that will be executed in parallel
33 | def process_aa(aa, all_aas, all_fitnesses):
34 | hm2_fits = []
35 | for aa2, fitness in zip(all_aas, all_fitnesses):
36 | if hamming(aa, aa2) > 2:
37 | continue
38 | hm2_fits.append(fitness)
39 | return aa, np.mean(hm2_fits), np.std(hm2_fits)
40 |
41 |
42 | def get_hd_avg_fit(
43 | df_csv: str,
44 | hd_dir: str = "results/hd",
45 | num_processes: None | int = None,
46 | ):
47 |
48 | df = pd.read_csv(df_csv)
49 |
50 | # no stop codons
51 |     df = df[~df["AAs"].str.contains(r"\*")].copy()
52 |
53 | # only active variants
54 | active_df = df[df["active"]].copy()
55 |
56 | all_aas = active_df["AAs"].tolist()
57 | all_fitnesses = active_df.loc[active_df["AAs"].isin(all_aas), "fitness"].tolist()
58 |
59 | hm2_dict = {}
60 | # Set number of processes; if None, use all available cores
61 | if num_processes is None:
62 | num_processes = int(np.round(os.cpu_count() * 0.8))
63 |
64 | with ProcessPoolExecutor(max_workers=num_processes) as executor:
65 | futures = [
66 | executor.submit(process_aa, aa, all_aas, all_fitnesses) for aa in all_aas
67 | ]
68 | for future in tqdm(as_completed(futures), total=len(futures)):
69 | aa, mean, std = future.result()
70 | hm2_dict[aa] = {"mean": mean, "std": std}
71 |
72 | mean_df = pd.DataFrame.from_dict(hm2_dict, orient="index")
73 |
74 | # Set the index name to 'aa'
75 | mean_df.index.name = "AAs"
76 |
77 | checkNgen_folder(hd_dir)
78 | mean_df.to_csv(os.path.join(hd_dir, get_file_name(df_csv) + ".csv"))
79 |
80 | return hm2_dict
81 |
82 |
83 | def run_hd_avg_fit(
84 | data_dir: str = "data",
85 | scalefit: str = "max",
86 | hd_dir: str = "results/hd_fit",
87 | num_processes: None | int = None,
88 | all_lib: bool = True,
89 | lib_list: list[str] = [],
90 | ):
91 |
92 | """
93 | Run the calculation of the average fitness for all sequences within a Hamming distance of 2
94 |
95 | Args:
96 | - data_dir: str, the directory containing the data
97 | - scalefit: str, the scale of the fitness values
98 | - hd_dir: str, the directory to save the results
99 | - num_processes: None | int, the number of processes to use
100 | - all_lib: bool, whether to use all libraries
101 | - lib_list: list[str], the list of libraries to use
102 | """
103 |
104 | if all_lib or len(lib_list) == 0:
105 | df_csvs = sorted(glob(f"{os.path.normpath(data_dir)}/*/scale2{scalefit}/*.csv"))
106 | else:
107 | df_csvs = [
108 | f"{os.path.normpath(data_dir)}/{lib}/scale2{scalefit}/{lib}.csv"
109 | for lib in lib_list
110 | ]
111 |
112 | for df_csv in df_csvs:
113 | print(f"Processing {df_csv} ...")
114 | df = get_hd_avg_fit(df_csv, hd_dir)
115 |
116 | del df
117 |
118 |
119 | def correlate_hd2fit(aa, all_aas, all_fitnesses, all_ifactive):
120 |
121 | """
122 | A function to correlate the Hamming distance of a sequence
123 | with all other sequences with the fitness values
124 | """
125 |
126 | hms = [-1 * hamming(aa, aa2) for aa2 in all_aas]
127 | rho = spearmanr(all_fitnesses, hms)[0]
128 |
129 | fpr, tpr, _ = roc_curve(all_ifactive, hms, pos_label=1)
130 | rocauc = auc(fpr, tpr)
131 |
132 | return aa, rho, rocauc
133 |
134 |
135 | def get_hd_avg_metric(
136 | df_csv: str,
137 | hd_dir: str = "results/hd_corr",
138 | num_processes: None | int = None,
139 | ):
140 |
141 | df = pd.read_csv(df_csv)
142 |
143 | # no stop codons
144 |     df = df[~df["AAs"].str.contains(r"\*")].copy()
145 |
146 | all_aas = df["AAs"].tolist()
147 | all_fitnesses = df["fitness"].values
148 | all_ifactive = df["active"].values
149 |
150 | # only active variants
151 | active_df = df[df["active"]].copy()
152 |
153 | all_active_aas = active_df["AAs"].tolist()
154 |
155 | hm_dict = {}
156 | # Set number of processes; if None, use all available cores
157 | if num_processes is None:
158 | num_processes = int(np.round(os.cpu_count() * 0.8))
159 |
160 | with ProcessPoolExecutor(max_workers=num_processes) as executor:
161 | futures = [
162 | executor.submit(correlate_hd2fit, aa, all_aas, all_fitnesses, all_ifactive)
163 | for aa in all_active_aas
164 | ]
165 | for future in tqdm(as_completed(futures), total=len(futures)):
166 | aa, rho, rocauc = future.result()
167 | hm_dict[aa] = {"rho": rho, "rocauc": rocauc}
168 |
169 | hm_df = pd.DataFrame.from_dict(hm_dict, orient="index")
170 |
171 | # Set the index name to 'aa'
172 | hm_df.index.name = "AAs"
173 |
174 | checkNgen_folder(hd_dir)
175 | hm_df.to_csv(os.path.join(hd_dir, get_file_name(df_csv) + ".csv"))
176 |
177 | return hm_df["rho"].mean(), hm_df["rocauc"].mean()
178 |
179 |
180 | def run_hd_avg_metric(
181 | data_dir: str = "data",
182 | scalefit: str = "max",
183 | hd_dir: str = "results/hd_corr",
184 | num_processes: None | int = None,
185 | all_lib: bool = True,
186 | lib_list: list[str] = [],
187 | ):
188 |
189 | """
190 |     Run the calculation of the average Spearman rho and ROC-AUC relating Hamming distance to fitness
191 |
192 | Args:
193 | - data_dir: str, the directory containing the data
194 | - scalefit: str, the scale of the fitness values
195 | - hd_dir: str, the directory to save the results
196 | - num_processes: None | int, the number of processes to use
197 | - all_lib: bool, whether to use all libraries
198 | - lib_list: list[str], the list of libraries to use
199 | """
200 |
201 | hd_avg_metric = pd.DataFrame(columns=["lib", "rho", "rocauc"])
202 |
203 | if all_lib or len(lib_list) == 0:
204 | df_csvs = sorted(glob(f"{os.path.normpath(data_dir)}/*/scale2{scalefit}/*.csv"))
205 | else:
206 | df_csvs = [
207 | f"{os.path.normpath(data_dir)}/{lib}/scale2{scalefit}/{lib}.csv"
208 | for lib in lib_list
209 | ]
210 |
211 | for df_csv in df_csvs:
212 | print(f"Processing {df_csv} ...")
213 |         rho, roc_auc = get_hd_avg_metric(df_csv, hd_dir)
214 |
215 | hd_avg_metric = hd_avg_metric._append(
216 | {
217 | "lib": get_file_name(df_csv),
218 | "rho": rho,
219 |                 "rocauc": roc_auc,
220 | },
221 | ignore_index=True,
222 | )
223 |
224 | checkNgen_folder(hd_dir)
225 | hd_avg_metric.to_csv(os.path.join(hd_dir, "hd_avg_metric.csv"))
226 |
227 | return hd_avg_metric
228 |
229 |
230 | def plot_hd_avg_fit(
231 | figname: str,
232 | hd_fit_dir: str = "results/hd/hd_fit",
233 | fit_dir: str = "data",
234 | fitscale: str = "scale2max",
235 | ifsave: bool = True,
236 | fig_dir: str = "figs",
237 | ):
238 |
239 | all_dfs = []
240 | wt_mean = {}
241 | full_mean = {}
242 |
243 | for lib, lib_dict in LIB_INFO_DICT.items():
244 |
245 | df = pd.read_csv(os.path.join(hd_fit_dir, f"{lib}.csv"))
246 | df["lib"] = lib
247 | all_dfs.append(df)
248 |
249 | wt_mean[lib] = df[df["AAs"] == "".join(lib_dict["AAs"].values())][
250 | "mean"
251 | ].values[0]
252 |
253 | fit_df = pd.read_csv(
254 | os.path.join(fit_dir, lib2prot(lib), fitscale, f"{lib}.csv")
255 | )
256 | full_mean[lib] = fit_df["fitness"].mean()
257 |
258 | all_df = pd.concat(all_dfs)
259 |
260 | fig = plt.figure(figsize=(16, 8))
261 | ax = sns.violinplot(
262 | x="lib", y="mean", data=all_df, hue="lib", palette=LIB_COLORS_GLASBEY
263 | )
264 |
265 | # Set the alpha value of the facecolor automatically to 0.8
266 | for violin in ax.collections[:]: # Access only the violin bodies
267 | facecolor = violin.get_facecolor().flatten() # Get the current facecolor
268 | violin.set_facecolor(mcolors.to_rgba(facecolor, alpha=0.4)) # Set new facecolor
269 |
270 | for lib in LIB_INFO_DICT.keys():
271 |
272 | # Find the position of the violin to add the line to
273 | position = all_df["lib"].unique().tolist().index(lib)
274 |
275 | # Overlay the mean as a scatter plot
276 | ax.axhline(
277 | all_df[all_df["lib"] == lib]["mean"].mean(),
278 | color=FZL_PALETTE["light_gray"],
279 | linestyle="solid",
280 | marker="x",
281 | linewidth=2,
282 | xmin=position / len(LIB_INFO_DICT) + 0.03125,
283 | xmax=(position + 1) / len(LIB_INFO_DICT) - 0.03125,
284 | )
285 | ax.axhline(
286 | wt_mean[lib],
287 | color=LIB_COLORS_GLASBEY[lib],
288 | linestyle="--",
289 | linewidth=2,
290 | xmin=position / len(LIB_INFO_DICT),
291 | xmax=(position + 1) / len(LIB_INFO_DICT),
292 | )
293 | ax.axhline(
294 | full_mean[lib],
295 | color=LIB_COLORS_GLASBEY[lib],
296 | linestyle="dotted",
297 | linewidth=2,
298 | xmin=position / len(LIB_INFO_DICT),
299 | xmax=(position + 1) / len(LIB_INFO_DICT),
300 | )
301 |
302 | lines = [
303 | Line2D(
304 | [0],
305 | [0],
306 | color=FZL_PALETTE["light_gray"],
307 | linestyle="none",
308 | lw=2,
309 | marker="x",
310 | ),
311 | Line2D([0], [0], color="black", linestyle="--", lw=2),
312 | Line2D([0], [0], color="black", linestyle="dotted", lw=2),
313 | ]
314 | labels = [
315 | "Mean of the mean variant fitness of double-site library\nconstructed with any active variant",
316 |     "Mean variant fitness of double-site library\nconstructed with the landscape parent",
317 | "Mean of all variants",
318 | ]
319 |
320 | ax.legend(lines, labels, loc="upper left", bbox_to_anchor=(1, 1))
321 |
322 | ax.set_xlabel("Landscapes")
323 | ax.set_ylabel(
324 | "Mean variant fitness of double-site library constructed with an active variant"
325 | )
326 |
327 | if ifsave:
328 | save_svg(fig, figname, fig_dir)
329 |
330 |
331 | def plot_hd_corr(
332 | metric: str,
333 | figname: str,
334 | hd_corr_dir: str = "results/hd/hd_corr",
335 | ifsave: bool = True,
336 | fig_dir: str = "figs",
337 | ):
338 |
339 | all_dfs = []
340 | wt_mean = {}
341 |
342 | for lib, lib_dict in LIB_INFO_DICT.items():
343 |
344 | df = pd.read_csv(os.path.join(hd_corr_dir, f"{lib}.csv"))
345 | df["lib"] = lib
346 | all_dfs.append(df)
347 |
348 | wt_mean[lib] = df[df["AAs"] == "".join(lib_dict["AAs"].values())][
349 | metric
350 | ].values[0]
351 |
352 | all_df = pd.concat(all_dfs)
353 |
354 | fig = plt.figure(figsize=(16, 8))
355 | ax = sns.violinplot(
356 | x="lib", y=metric, data=all_df, hue="lib", palette=LIB_COLORS_GLASBEY
357 | )
358 |
359 | # Set the alpha value of the facecolor
360 | for violin in ax.collections[:]: # Access only the violin bodies
361 | facecolor = violin.get_facecolor().flatten() # Get the current facecolor
362 | violin.set_facecolor(mcolors.to_rgba(facecolor, alpha=0.4)) # Set new facecolor
363 |
364 | for lib in LIB_INFO_DICT.keys():
365 |
366 | # Find the position of the violin to add the line to
367 | position = all_df["lib"].unique().tolist().index(lib)
368 |
369 | # Overlay the mean
370 | ax.axhline(
371 | all_df[all_df["lib"] == lib][metric].mean(),
372 | color=FZL_PALETTE["light_gray"],
373 | linestyle="solid",
374 | marker="x",
375 | linewidth=2,
376 | xmin=position / len(LIB_INFO_DICT) + 0.03125,
377 | xmax=(position + 1) / len(LIB_INFO_DICT) - 0.03125,
378 | )
379 | ax.axhline(
380 | wt_mean[lib],
381 | color=LIB_COLORS_GLASBEY[lib],
382 | linestyle="--",
383 | linewidth=2,
384 | xmin=position / len(LIB_INFO_DICT),
385 | xmax=(position + 1) / len(LIB_INFO_DICT),
386 | )
387 |
388 | lines = [
389 | Line2D(
390 | [0],
391 | [0],
392 | color=FZL_PALETTE["light_gray"],
393 | linestyle="none",
394 | lw=2,
395 | marker="x",
396 | ),
397 | Line2D([0], [0], color="black", linestyle="--", lw=2),
398 | Line2D([0], [0], color="black", linestyle="dotted", lw=2),
399 | ]
400 | labels = [
401 | f"Mean {ZS_METRIC_MAP_LABEL[metric]}\nfrom any active variant",
402 | "From the landscape parent",
403 | ]
404 | ax.axhline(
405 | ZS_METRIC_BASELINE[metric],
406 | color=FZL_PALETTE["light_gray"],
407 | linestyle="dotted",
408 | linewidth=2,
409 | )
410 | ax.legend(lines, labels, loc="upper left", bbox_to_anchor=(1, 1))
411 |
412 | ax.set_xlabel("Landscapes")
413 | y_dets = (
414 | ZS_METRIC_MAP_TITLE[metric]
415 | .replace("\n", " ")
416 | .replace("F", "f")
417 | .replace("A", "a")
418 | )
419 | ax.set_ylabel(f"Hamming distance {y_dets}")
420 |
421 | if ifsave:
422 | save_svg(fig, figname, fig_dir)
--------------------------------------------------------------------------------
/SSMuLA/zs_models.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for generating zs scores gratefully adapted from EmreGuersoy's work
3 | """
4 |
5 | # Import packages
6 | import os
7 | import numpy as np
8 | import pandas as pd
9 | from pathlib import Path
10 | from collections import Counter
11 |
12 | # EvCouplings
13 | from evcouplings.couplings import CouplingsModel
14 | from evcouplings.align.alignment import Alignment
15 |
16 | # ESM
17 | import esm
18 |
19 | # from esm.model.msa_transformer import MSATransformer
20 |
21 | # Pytorch
22 | import torch
23 |
24 | # Other
25 | from tqdm import tqdm
26 | from typing import List, Tuple, Optional
27 |
28 |
29 | import warnings
30 |
31 | # from transformers import BertTokenizer, BertModel, BertForMaskedLM, pipeline
32 |
33 |
34 | class ZeroShotPrediction:
35 | """
36 | Zero-Shot Analysis on a given sequence and model.
37 |
38 | Input:
39 | - Model Path obtained from EvCouplings
40 | - Sequence or mutation list stored as csv,
41 | where variant column must be [wtAA]_[pos][newAA]
42 | Output: - Scores for each variant
43 | """
44 |
45 | def __init__(self, df, wt_seq):
46 | # make sure no stopbonds are present
47 |         # make sure no stop codons are present
48 | self.df = df
49 | self.wt_sequence = wt_seq
50 |
51 | def _get_n_df(self, n: int = 1):
52 | """Get n data frame with n mutants"""
53 | return self.df[self.df["combo"].apply(lambda x: len(x) == n)].copy()
54 |
55 |
56 | class ddG(ZeroShotPrediction):
57 |     def __init__(self, df=None, wt_seq=None):
58 |         super().__init__(df, wt_seq)
59 | pass
60 |
61 |
62 | class EvMutation(ZeroShotPrediction):
63 | """
64 | Perform the EvMutation analysis on a given sequence and model.
65 |
66 | Input:
67 | - Model Path obtained from EvCouplings
68 | - Sequence or mutation list stored as csv
69 | Output: - EvMutation (delta) score for each variant
70 | """
71 |
72 | def __init__(self, df, wt_sequence, model_path):
73 | super().__init__(df, wt_sequence)
74 | self.model_path = model_path
75 | print("Loading model...")
76 | self.model = self.load_model()
77 | print("Model loaded")
78 | self.idx_map = self.model.index_map
79 |
80 | def check_idx(self, pos: list = [1]):
81 | """Check if the position is in the index map"""
82 | all_pos = all([p in self.idx_map for p in pos])
83 |
84 | if all_pos:
85 | return True
86 | else:
87 | return False
88 |
89 | def load_model(self):
90 | self.model = CouplingsModel(self.model_path)
91 | return self.model
92 |
93 | def _get_hamiltonian(self, mt, wt, pos):
94 | delta_E, _, _ = self.model.delta_hamiltonian([(pos, wt, mt)])
95 | return delta_E
96 |
97 | def _get_n_hamiltonian(self, combo: list):
98 | """Get the hamiltonian for n mutants"""
99 | delta_E, _, _ = self.model.delta_hamiltonian(combo)
100 | return delta_E
101 |
102 | def upload_dms_coupling(self, dms_coupling):
103 | """Input: - dms_coupling: Path to the dms_coupling file (csv)"""
104 | self.dms_coupling = dms_coupling
105 |
106 | return self.dms_coupling
107 |
108 | def get_single_mutant_scores(self):
109 | """Get the single mutant scores for the dms_coupling file"""
110 |
111 | # Left join the dms_coupling file with the data set
112 | self.single_mutant_scores = pd.merge(
113 | self.dms_coupling,
114 | self.df,
115 | how="left",
116 | left_on=["pos", "wt"],
117 | right_on=["pos", "wt"],
118 | )
119 |
120 | return self.single_mutant_scores
121 |
122 | def run_evmutation(self, df, _multi=True, _epistatic=False):
123 | """
124 | Run EvMutation for all variants in the data set
125 |
126 |         Input: - df: Data set containing the variants; loops through column = 'combo'
127 |                - _multi: If True, score each combo with the full delta Hamiltonian over all of its mutations
128 |                - _epistatic: If True, subtract the sum of the single-mutant scores from the combo score
129 |         Output: - Score for each variant"""
130 | score = np.zeros(len(df))
131 | wt_sequence = list(self.wt_sequence)
132 |
133 | for i, combo in enumerate(df["combo"]):
134 | # Prepare mut list for EvMutation
135 | mut_list = []
136 | # Check if positions are in index map
137 | if self.check_idx(df["pos"].iloc[i]):
138 |
139 | if _multi:
140 |
141 | single_mutant_scores = np.zeros((1, len(combo)))
142 |
143 | for j, mt in enumerate(combo):
144 | if mt == "WT" or mt == "NA":
145 | score[i] = np.nan
146 | continue
147 |
148 | else:
149 | pos_wt = int(
150 | df["pos"].iloc[i][j] - 1
151 | ) # Position of the mutation with python indexing
152 | pos_ev = int(
153 | df["pos"].iloc[i][j]
154 |                         ) # Position of the mutation with EvCouplings (1-based) indexing
155 | # Get single scores
156 | single_mutant_scores[0, j] = self._get_hamiltonian(
157 | mt, wt_sequence[pos_wt], pos_ev
158 | ) # TODO: Improve at one point
159 | mut_list.append((pos_ev, wt_sequence[pos_wt], mt))
160 |
161 | # Run EvMutation
162 | score[i] = self._get_n_hamiltonian(mut_list)
163 |
164 | if _epistatic:
165 | score[i] = score[i] - np.sum(single_mutant_scores)
166 |
167 | # TODO: Get epistatic scores dE = dE_combo - sum(dE_single_mutant)
168 | else:
169 | if combo == "WT" or combo == "NA":
170 | score[i] = np.nan
171 | continue
172 |
173 | else:
174 | mt = combo[0]
175 | pos_wt = int(df["pos"].iloc[i][0] - 1)
176 | pos_ev = int(df["pos"].iloc[i][0])
177 | mut_list.append((pos_ev, wt_sequence[pos_wt], mt))
178 | # score[i] = self._get_n_hamiltonian(mut_list)
179 | score[i] = self._get_hamiltonian(
180 | mt, wt_sequence[pos_wt], pos_ev
181 | )
182 | else:
183 | score[i] = np.nan
184 | continue
185 |
186 | return score
187 |
188 | def _get_n_score(self, n: list = [1]):
189 | """Get any score for each variant in the data set"""
190 | df_n_list = []
191 |
192 | # Get the n mutant scores
193 | for i in n:
194 | # Filter out n mutants
195 | df_n = self._get_n_df(i)
196 | if df_n.empty: # Check if the DataFrame is empty after filtering
197 |                 warnings.warn("Data set is empty")
198 | continue
199 |
200 | if i == 1:
201 | score_n = self.run_evmutation(df_n, _multi=False)
202 | else:
203 | score_n = self.run_evmutation(df_n, _multi=True)
204 |
205 | # Add column with number of mutations
206 |
207 | df_n.loc[:, "ev_score"] = score_n
208 | df_n.loc[:, "n_mut"] = i
209 | # score_n = pd.DataFrame(score_n, columns=['ev_score'])
210 |
211 | # if fit or fitness
212 | if "fit" in df_n.columns:
213 | fit_col = "fit"
214 | elif "fitness" in df_n.columns:
215 | fit_col = "fitness"
216 |
217 | # Choose only the columns we want
218 | df_n = df_n[["muts", fit_col, "n_mut", "ev_score"]]
219 | df_n_list.append(df_n)
220 |
221 | return pd.concat(df_n_list, axis=0)
222 |
223 |
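# A minimal EvMutation usage sketch (illustrative only): it assumes a plmc/EvCouplings
# model file and a dataframe whose "combo" column holds lists of mutant AAs, with "pos"
# holding the matching 1-based positions plus "muts" and "fitness" columns, as consumed
# by run_evmutation above; variant_df, wt_seq, and the model path are placeholders.
#
#   ev = EvMutation(df=variant_df, wt_sequence=wt_seq, model_path="couplings.model")
#   scored = ev._get_n_score(n=[1, 2, 3, 4])  # appends an "ev_score" per variant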
224 | class ESM(ZeroShotPrediction):
225 | def __init__(self, df, wt_seq, logits_path="", regen_esm=False):
226 | super().__init__(df, wt_seq)
227 | self.df = df
228 | self.wt_sequence = wt_seq
229 | (
230 | self.model,
231 | self.alphabet,
232 | self.batch_converter,
233 | self.device,
234 | ) = self._infer_model()
235 | self.mask_string, self.cls_string, self.eos_string = (
236 | self.alphabet.mask_idx,
237 | self.alphabet.cls_idx,
238 | self.alphabet.eos_idx,
239 | )
240 | self.alphabet_size = len(self.alphabet)
241 |
242 | if logits_path != "" and os.path.exists(logits_path) and not(regen_esm):
243 | print(f"{logits_path} exists and regen_esm = {regen_esm}. Loading...")
244 | self.logits = np.load(logits_path)
245 | else:
246 | print(f"Generating {logits_path}...")
247 | self.logits = self._get_logits()
248 |
249 | def _infer_model(self):
250 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
251 | # model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
252 | # model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1()
253 | model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
254 | batch_converter = alphabet.get_batch_converter()
255 | model.eval()
256 | model = model.to(device)
257 | print("Using device:", device)
258 | return model, alphabet, batch_converter, device
259 |
260 | def _get_logits(self):
261 |
262 | data_wt = [("WT", self.wt_sequence)]
263 | # Get Batch tokens for WT data
264 | batch_labels_wt, batch_strs_wt, batch_tokens_wt = self.batch_converter(data_wt)
265 |
266 | logits = np.zeros((len(self.wt_sequence), self.alphabet_size))
267 |
268 | for (i, seq) in enumerate(data_wt[0][1]):
269 | batch_tokens_masked = batch_tokens_wt.clone()
270 | batch_tokens_masked[0, i] = self.alphabet.mask_idx
271 | batch_tokens_masked = batch_tokens_masked.to(self.device)
272 |
273 | with torch.no_grad():
274 | token_probs = torch.log_softmax(
275 | self.model(batch_tokens_masked)["logits"], dim=-1
276 | ).cpu().numpy()
277 |
278 | logits[i] = token_probs[0, i+1]
279 |
280 | return logits
281 |
282 | def _get_mutant_prob(self, mt, wt, pos):
283 | """Get the probability of the mutant given the wild type sequence at certain position."""
284 |
285 | wt_idx = self.alphabet.get_idx(wt)
286 | mt_idx = self.alphabet.get_idx(mt)
287 |
288 | return self.logits[pos, mt_idx] - self.logits[pos, wt_idx]
289 |
290 | def run_esm(self, df, _sum=True):
291 | """
292 | Run ESM model for all variants in the data set
293 |
294 | Input: - logits: Logits of the wild type sequence
295 |                - df: Data set containing the variants; loops through columns 'combo' and 'pos'
296 | - _sum: If True, the sum of the probabilities is calculated.
297 | If False, the mean of the probabilities is calculated
298 |
299 | Output: - Score for each variant
300 | """
301 | score = np.zeros(len(df))
302 | wt_sequence = list(self.wt_sequence)
303 |
304 | if _sum:
305 | for i, combo in enumerate(df["combo"]):
306 | s = np.zeros(len(combo))
307 | for j, mt in enumerate(combo):
308 | if mt == "WT":
309 | score[i] = 0
310 | continue
311 |
312 | elif mt == "NA":
313 | score[i] = np.nan
314 | continue
315 |
316 | else:
317 | pos = (
318 | int(df["pos"].iloc[i][j]) - 1
319 | ) # Position of the mutation with python indexing
320 | wt = wt_sequence[pos]
321 | s[j] = self._get_mutant_prob(mt=mt, wt=wt, pos=pos)
322 |
323 | score[i] += s.sum()
324 |
325 | else:
326 | for i, combo in enumerate(df["combo"]):
327 |
328 | mt = combo[0]
329 |
330 | if mt == "WT":
331 | score[i] = 0
332 | continue
333 |
334 | elif mt == "NA":
335 | score[i] = np.nan
336 | continue
337 |
338 | else:
339 | pos = int(df["pos"].iloc[i][0] - 1)
340 | wt = wt_sequence[pos]
341 | score[i] = self._get_mutant_prob(mt=mt, wt=wt, pos=pos)
342 |
343 | return score
344 |
345 | def _get_n_df(self, n: int = 1):
346 | """Get n data frame with n mutants"""
347 |
348 | return self.df[self.df["combo"].apply(lambda x: len(x) == n)].copy()
349 |
350 | def _get_n_score(self, n: list = [1]):
351 | """Get any score for each variant in the data set"""
352 | df_n_list = []
353 |
354 | # Get the n mutant scores
355 | for i in n:
356 | # Filter out n mutants
357 | df_n = self._get_n_df(i)
358 | if df_n.empty: # Check if the DataFrame is empty after filtering
359 |                 warnings.warn("Data set is empty")
360 | continue
361 |
362 | if i == 1:
363 | score_n = self.run_esm(df_n, _sum=False)
364 | else:
365 | score_n = self.run_esm(df_n, _sum=True)
366 |
367 | # Add column with number of mutations
368 |
369 | df_n.loc[:, "esm_score"] = score_n
370 | df_n.loc[:, "n_mut"] = i
371 | # score_n = pd.DataFrame(score_n, columns=['ev_score'])
372 | # if fit or fitness
373 | if "fit" in df_n.columns:
374 | fit_col = "fit"
375 | elif "fitness" in df_n.columns:
376 | fit_col = "fitness"
377 |
378 | # Choose only the columns we want
379 | df_n = df_n[["muts", fit_col, "n_mut", "esm_score"]]
380 | df_n_list.append(df_n)
381 |
382 | return pd.concat(df_n_list, axis=0)
383 |
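# --- Usage sketch (illustrative only, not part of the original script) ---------
# Score a toy variant table with the masked-marginal ESM scores computed above; the
# wild-type sequence and variants are placeholders, and instantiating ESM downloads
# the esm2_t33_650M_UR50D checkpoint (a GPU is strongly recommended).
if __name__ == "__main__":
    toy_wt = "MSTAPKV"  # placeholder sequence
    toy_df = pd.DataFrame(
        {
            "muts": ["T3A", "A4G:P5L"],
            "combo": [["A"], ["G", "L"]],
            "pos": [[3], [4, 5]],
            "fitness": [0.8, 0.2],
        }
    )
    esm_scorer = ESM(df=toy_df, wt_seq=toy_wt)
    print(esm_scorer._get_n_score(n=[1, 2]))  # adds an "esm_score" column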
--------------------------------------------------------------------------------
/SSMuLA/plm_finetune.py:
--------------------------------------------------------------------------------
1 | # modified from the repo:
2 | # https://github.com/RSchmirler/data-repo_plm-finetune-eval/blob/main/notebooks/finetune/Finetuning_per_protein.ipynb
3 |
4 | # import dependencies
5 | import os.path
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11 | from torch.utils.data import DataLoader
12 |
13 | import re
14 | import numpy as np
15 | import pandas as pd
16 | from copy import deepcopy
17 |
18 | import transformers, datasets
19 | from transformers.modeling_outputs import SequenceClassifierOutput
20 | from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
21 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
22 | from transformers import T5EncoderModel, T5Tokenizer
23 | from transformers import EsmModel, AutoTokenizer, AutoModelForSequenceClassification
24 | from transformers import TrainingArguments, Trainer, set_seed
25 |
26 | import peft
27 | from peft import (
28 | get_peft_config,
29 | PeftModel,
30 | PeftConfig,
31 | inject_adapter_in_model,
32 | LoraConfig,
33 | )
34 |
35 | from evaluate import load
36 | from datasets import Dataset
37 |
38 | from tqdm import tqdm
39 | import random
40 |
41 | from scipy import stats
42 | from sklearn.metrics import accuracy_score
43 |
44 | import matplotlib.pyplot as plt
45 |
46 | from SSMuLA.util import checkNgen_folder, get_file_name
47 |
48 |
49 | # # Set environment variables to run Deepspeed from a notebook
50 | # os.environ["MASTER_ADDR"] = "localhost"
51 | # os.environ["MASTER_PORT"] = "9994" # modify if RuntimeError: Address already in use
52 | # os.environ["RANK"] = "0"
53 | # os.environ["LOCAL_RANK"] = "0"
54 | # os.environ["WORLD_SIZE"] = "1"
55 |
56 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
57 | # Deepspeed config for optimizer CPU offload
58 |
59 | ds_config = {
60 | "fp16": {
61 | "enabled": "auto",
62 | "loss_scale": 0,
63 | "loss_scale_window": 1000,
64 | "initial_scale_power": 16,
65 | "hysteresis": 2,
66 | "min_loss_scale": 1,
67 | },
68 | "optimizer": {
69 | "type": "AdamW",
70 | "params": {
71 | "lr": "auto",
72 | "betas": "auto",
73 | "eps": "auto",
74 | "weight_decay": "auto",
75 | },
76 | },
77 | "scheduler": {
78 | "type": "WarmupLR",
79 | "params": {
80 | "warmup_min_lr": "auto",
81 | "warmup_max_lr": "auto",
82 | "warmup_num_steps": "auto",
83 | },
84 | },
85 | "zero_optimization": {
86 | "stage": 2,
87 | "offload_optimizer": {"device": "cpu", "pin_memory": True},
88 | "allgather_partitions": True,
89 | "allgather_bucket_size": 2e8,
90 | "overlap_comm": True,
91 | "reduce_scatter": True,
92 | "reduce_bucket_size": 2e8,
93 | "contiguous_gradients": True,
94 | },
95 | "gradient_accumulation_steps": "auto",
96 | "gradient_clipping": "auto",
97 | "steps_per_print": 2000,
98 | "train_batch_size": "auto",
99 | "train_micro_batch_size_per_gpu": "auto",
100 | "wall_clock_breakdown": False,
101 | }
102 |
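# The config above combines fp16 with automatic loss scaling and ZeRO stage-2 with
# optimizer states offloaded to CPU, which reduces GPU memory so that larger
# checkpoints (e.g. the default esm2_t33_650M_UR50D) can be finetuned on a single
# GPU; the "auto" fields are resolved from the HuggingFace TrainingArguments at runtime.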
103 | # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
104 |
105 | RAND_SEED_LIST = deepcopy([random.randint(0, 1000000) for _ in range(50)])
106 |
107 | # load ESM2 models
108 | def load_esm_model(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
109 | tokenizer = AutoTokenizer.from_pretrained(checkpoint)
110 |
111 | if half_precision and deepspeed:
112 | model = AutoModelForSequenceClassification.from_pretrained(
113 | checkpoint, num_labels=num_labels, torch_dtype=torch.float16
114 | )
115 | else:
116 | model = AutoModelForSequenceClassification.from_pretrained(
117 | checkpoint, num_labels=num_labels
118 | )
119 |
120 |     if full:
121 | return model, tokenizer
122 |
123 | peft_config = LoraConfig(
124 | r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
125 | )
126 |
127 | model = inject_adapter_in_model(peft_config, model)
128 |
129 | # Unfreeze the prediction head
130 | for (param_name, param) in model.classifier.named_parameters():
131 | param.requires_grad = True
132 |
133 | return model, tokenizer
134 |
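# load_esm_model injects small LoRA adapters (r=4, lora_alpha=1) into the query, key,
# value, and dense projections and then explicitly unfreezes the classification head,
# so only the adapters and the head are finetuned unless full=True requests full-model
# training; with half_precision and deepspeed the base weights are loaded in float16.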
135 |
136 | # Set random seeds for reproducibility of your training runs
137 | def set_seeds(s):
138 | torch.manual_seed(s)
139 | np.random.seed(s)
140 | random.seed(s)
141 | set_seed(s)
142 |
143 |
144 | # Dataset creation
145 | def create_dataset(tokenizer, seqs, labels):
146 | tokenized = tokenizer(seqs, max_length=1024, padding=True, truncation=True)
147 | dataset = Dataset.from_dict(tokenized)
148 | dataset = dataset.add_column("labels", labels)
149 |
150 | return dataset
151 |
152 |
153 | def plot_train_val(
154 | history: list, # training history
155 | landscape: str, # landscape name
156 | ):
157 |
158 | # Get loss, val_loss, and the computed metric from history
159 | loss = [x["loss"] for x in history if "loss" in x]
160 | val_loss = [x["eval_loss"] for x in history if "eval_loss" in x]
161 |
162 | # Get spearman (for regression) or accuracy value (for classification)
163 | if [x["eval_spearmanr"] for x in history if "eval_spearmanr" in x] != []:
164 | metric = [x["eval_spearmanr"] for x in history if "eval_spearmanr" in x]
165 | else:
166 | metric = [x["eval_accuracy"] for x in history if "eval_accuracy" in x]
167 |
168 | epochs = [x["epoch"] for x in history if "loss" in x]
169 |
170 | # Create a figure with two y-axes
171 | fig, ax1 = plt.subplots(figsize=(10, 5))
172 | ax2 = ax1.twinx()
173 |
174 | # Plot loss and val_loss on the first y-axis
175 | line1 = ax1.plot(epochs, loss, label="train_loss")
176 | line2 = ax1.plot(epochs, val_loss, label="val_loss")
177 | ax1.set_xlabel("Epoch")
178 | ax1.set_ylabel("Loss")
179 |
180 | # Plot the computed metric on the second y-axis
181 | line3 = ax2.plot(epochs, metric, color="red", label="val_metric")
182 | ax2.set_ylabel("Metric")
183 | ax2.set_ylim([0, 1])
184 |
185 | # Combine the lines from both y-axes and create a single legend
186 | lines = line1 + line2 + line3
187 | labels = [line.get_label() for line in lines]
188 | ax1.legend(lines, labels, loc="lower left")
189 |
190 | # add title to the figure
191 | plt.title(f"Training curves for {landscape}")
192 |
193 | # return the plot but do not save it
194 | return fig
195 |
196 |
197 | def save_model(model, filepath):
198 | # Saves all parameters that were changed during finetuning
199 |
200 | # Create a dictionary to hold the non-frozen parameters
201 | non_frozen_params = {}
202 |
203 | # Iterate through all the model parameters
204 | for param_name, param in model.named_parameters():
205 | # If the parameter has requires_grad=True, add it to the dictionary
206 | if param.requires_grad:
207 | non_frozen_params[param_name] = param
208 |
209 | # Save only the finetuned parameters
210 | torch.save(non_frozen_params, filepath)
211 |
212 |
213 | # Main training function
214 | def train_per_protein(
215 | checkpoint, # model checkpoint
216 | train_df, # training data
217 | valid_df, # validation data
218 | seed, # random seed
219 | device, # device to use
220 | num_labels=1, # 1 for regression, >1 for classification
221 | # effective training batch size is batch * accum
222 | # we recommend an effective batch size of 8
223 | batch=4, # for training
224 | accum=2, # gradient accumulation
225 | val_batch=16, # batch size for evaluation
226 | epochs=10, # training epochs
227 | lr=3e-4, # recommended learning rate
228 |     deepspeed=False, # if the GPU has enough memory, disabling deepspeed speeds up training
229 | mixed=True, # enable mixed precision training
230 | full=False, # enable training of the full model (instead of LoRA)
231 | # gpu = 1 #gpu selection (1 for first gpu)
232 | ):
233 |
234 | print("Model used:", checkpoint, "\n")
235 |
236 | # Set all random seeds
237 | set_seeds(seed)
238 |
239 | # load model
240 | model, tokenizer = load_esm_model(checkpoint, num_labels, mixed, full, deepspeed)
241 |
242 | # Preprocess inputs
243 | # Replace uncommon AAs with "X"
244 | train_df["seq"] = train_df["seq"].str.replace(
245 | "|".join(["O", "B", "U", "Z", "J"]), "X", regex=True
246 | )
247 | valid_df["seq"] = valid_df["seq"].str.replace(
248 | "|".join(["O", "B", "U", "Z", "J"]), "X", regex=True
249 | )
250 |
251 | # Create Datasets
252 | train_set = create_dataset(
253 | tokenizer, list(train_df["seq"]), list(train_df["fitness"])
254 | )
255 | valid_set = create_dataset(
256 | tokenizer, list(valid_df["seq"]), list(valid_df["fitness"])
257 | )
258 |
259 | # Huggingface Trainer arguments
260 | args = TrainingArguments(
261 | "./",
262 | evaluation_strategy="epoch",
263 | logging_strategy="epoch",
264 | save_strategy="no",
265 | learning_rate=lr,
266 | per_device_train_batch_size=batch,
267 | per_device_eval_batch_size=val_batch,
268 | gradient_accumulation_steps=accum,
269 | num_train_epochs=epochs,
270 | seed=seed,
271 | deepspeed=ds_config if deepspeed else None,
272 | fp16=mixed,
273 | )
274 |
275 | # Metric definition for validation data
276 | def compute_metrics(eval_pred):
277 | if num_labels > 1: # for classification
278 | metric = load("accuracy")
279 | predictions, labels = eval_pred
280 | predictions = np.argmax(predictions, axis=1)
281 | else: # for regression
282 | metric = load("spearmanr")
283 | predictions, labels = eval_pred
284 |
285 | return metric.compute(predictions=predictions, references=labels)
286 |
287 | # Trainer
288 | trainer = Trainer(
289 | model,
290 | args,
291 | train_dataset=train_set,
292 | eval_dataset=valid_set,
293 | tokenizer=tokenizer,
294 | compute_metrics=compute_metrics,
295 | )
296 |
297 | # Train model
298 | trainer.train()
299 |
300 | return tokenizer, model, trainer.state.log_history
301 |
302 |
303 | def train_predict_per_protein(
304 | df_csv: str, # csv file with landscape data
305 | rep: int, # replicate number
306 | device: str, # device to use
307 | checkpoint: str = "facebook/esm2_t33_650M_UR50D", # model checkpoint
308 | n_sample: int = 384, # number of train+val
309 | zs_predictor: str = "none", # zero-shot predictor
310 | ft_frac: float = 0.125, # fraction of data for focused sampling
311 | save_dir: str = "results/finetuning", # directory to save the model, plot, and predictions
312 | train_kwargs: dict = {}, # additional training arguments
313 | rerun=False, # if True, the model will be trained again even if it already exists
314 | ):
315 | """ """
316 |
317 | landscape = get_file_name(df_csv)
318 |
319 | seed = RAND_SEED_LIST[rep]
320 |
321 | df = pd.read_csv(df_csv)
322 |
323 | if zs_predictor == "none":
324 | df_sorted = df.copy()
325 | n = n_sample
326 | elif zs_predictor not in df.columns:
327 | print(f"{zs_predictor} not in the dataframe")
328 | df_sorted = df.copy()
329 | n = n_sample
330 | else:
331 | n_cutoff = int(len(df) * ft_frac)
332 | df_sorted = (
333 | df.sort_values(by=zs_predictor, ascending=False).copy()[:n_cutoff].copy()
334 | )
335 | n = min(n_sample, n_cutoff)
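    # Worked example of the cutoff above (illustrative numbers only): with len(df) = 8000
    # variants and ft_frac = 0.125, n_cutoff = 1000, so sampling draws from the top 1000
    # ZS-ranked variants and n = min(384, 1000) = 384.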
336 |
337 | # output_csv_path
338 | output_csv_path = os.path.join(
339 | checkNgen_folder(os.path.join(save_dir, "predictions", landscape)),
340 | f"{landscape}_{str(n_sample)}_{str(rep)}.csv",
341 | )
342 |
343 | # check if the output already exists
344 | if os.path.exists(output_csv_path) and not rerun:
345 | print(f"Output already exists for {landscape} with n={n_sample} and rep={rep}")
346 | return df
347 |
348 | # randomly sample rows from the dataframe
349 | train_val_df = (
350 | df_sorted.sample(n=n, random_state=seed).reset_index(drop=True).copy()
351 | )
352 |
353 |     # split train_val_df into 90% training and 10% validation sets; sample without
354 |     # resetting the index first so the drop below removes exactly the training rows
355 |     train_df = train_val_df.sample(frac=0.9, random_state=seed).copy()
356 |     valid_df = train_val_df.drop(train_df.index).reset_index(drop=True).copy()
357 |     train_df = train_df.reset_index(drop=True)
358 |
359 | tokenizer, model, history = train_per_protein(
360 | checkpoint, train_df, valid_df, seed=seed, device=device, **train_kwargs
361 | )
362 |
363 | # save the model
364 | save_model(
365 | model,
366 | os.path.join(
367 | checkNgen_folder(os.path.join(save_dir, "model", landscape)),
368 | f"{landscape}_{str(n_sample)}_{str(rep)}.pth",
369 | ),
370 | )
371 |
372 | # plot the training history
373 | fig = plot_train_val(history, landscape)
374 |
375 | # save the plot
376 | fig.savefig(
377 | os.path.join(
378 | checkNgen_folder(os.path.join(save_dir, "plot", landscape)),
379 | f"{landscape}_{str(n_sample)}_{str(rep)}.png",
380 | )
381 | )
382 |
383 | # create Dataset
384 | test_set = create_dataset(tokenizer, list(df["seq"]), list(df["fitness"]))
385 | # make compatible with torch DataLoader
386 | test_set = test_set.with_format("torch", device=device)
387 |
388 | # Create a dataloader for the test dataset
389 | test_dataloader = DataLoader(test_set, batch_size=16, shuffle=False)
390 |
391 | # Put the model in evaluation mode
392 | model.eval()
393 |
394 | # Make predictions on the test dataset
395 | predictions = []
396 | with torch.no_grad():
397 | for batch in tqdm(test_dataloader):
398 | input_ids = batch["input_ids"].to(device)
399 | attention_mask = batch["attention_mask"].to(device)
400 |             # add batch results (logits) to predictions
401 | predictions += model.float()(
402 | input_ids, attention_mask=attention_mask
403 | ).logits.tolist()
404 |
405 | # flatten the prediction to one single list
406 | # save predictions as a new column in the test dataframe
407 | df["predictions"] = np.array(predictions).flatten()
408 |
409 | # save the dataframe
410 | df[["seq", "fitness", "predictions"]].to_csv(
411 | output_csv_path,
412 | index=False,
413 | )
414 |
415 | return df
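# Illustrative end-to-end driver (a sketch, not part of the original module): the csv path,
# output directory, and ZS column below are placeholders.
if __name__ == "__main__":
    df_pred = train_predict_per_protein(
        df_csv="data/landscape_example.csv",  # placeholder; needs "seq" and "fitness" columns
        rep=0,                                # picks the seed from RAND_SEED_LIST
        device="cuda",
        n_sample=384,
        zs_predictor="none",                  # or a ZS score column for focused sampling
        save_dir="results/finetuning",
        train_kwargs={"batch": 4, "accum": 2, "epochs": 10},  # effective batch size 4 * 2 = 8
    )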
--------------------------------------------------------------------------------
/SSMuLA/get_factor.py:
--------------------------------------------------------------------------------
1 | """A script for saving dataframes as pngs"""
2 |
3 | import os
4 | import pandas as pd
5 | from sklearn.ensemble import RandomForestRegressor
6 | from sklearn.model_selection import train_test_split
7 | from sklearn.metrics import mean_squared_error
8 | from sklearn.linear_model import LinearRegression
9 |
10 | from matplotlib.colors import LinearSegmentedColormap, to_rgb
11 | import matplotlib.pyplot as plt
12 | import seaborn as sns
13 |
14 | # for html to png
15 | from selenium import webdriver
16 | from selenium.webdriver.firefox.service import Service
17 | from selenium.webdriver.firefox.options import Options
18 |
19 | from SSMuLA.de_simulations import DE_TYPES
20 | from SSMuLA.zs_analysis import ZS_OPTS, ZS_COMB_OPTS
21 | from SSMuLA.vis_summary import ZS_METRICS
22 | from SSMuLA.get_corr import LANDSCAPE_ATTRIBUTES, val_list, zs_list
23 | from SSMuLA.vis import FZL_PALETTE, save_plt
24 | from SSMuLA.util import checkNgen_folder
25 |
26 | # Custom colormap for the MSE row, using greens
27 | cmap_mse = LinearSegmentedColormap.from_list(
28 | "mse_cmap_r", ["#FFFFFF", "#9bbb59"][::-1], N=100
29 | )  # green (low MSE) to white (high MSE)
30 |
31 | # Create the colormap
32 | custom_cmap = LinearSegmentedColormap.from_list(
33 | "bwg",
34 | [
35 | FZL_PALETTE["blue"],
36 | "white",
37 | FZL_PALETTE["green"],
38 | ],
39 | N=100,
40 | )
41 |
42 | geckodriver_path = "/disk2/fli/miniconda3/envs/SSMuLA/bin/geckodriver"
43 |
44 | de_metrics = ["mean_all", "fraction_max"]
45 |
46 | simple_des = {
47 | "recomb_SSM": "Recomb",
48 | "single_step_DE": "Single step",
49 | "top96_SSM": "Top96 recomb",
50 | }
51 |
52 | simple_de_metric_map = {}
53 |
54 | for de_type in DE_TYPES:
55 | for de_metric in de_metrics:
56 | simple_de_metric_map[f"{de_type}_{de_metric}"] = simple_des[de_type]
57 |
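# For reference (assuming DE_TYPES matches the keys of simple_des), the loop above expands to:
#   {"recomb_SSM_mean_all": "Recomb", "recomb_SSM_fraction_max": "Recomb",
#    "single_step_DE_mean_all": "Single step", "single_step_DE_fraction_max": "Single step",
#    "top96_SSM_mean_all": "Top96 recomb", "top96_SSM_fraction_max": "Top96 recomb"}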
58 |
59 | # Styling the DataFrame
60 | def style_dataframe(df):
61 | # Define a function to apply gradient selectively
62 | def apply_gradient(row):
63 | if row.name == "mse":
64 | # Generate colors for the MSE row based on its values
65 | norm = plt.Normalize(row.min(), row.max())
66 | rgba_colors = [cmap_mse(norm(value)) for value in row]
67 | return [
68 | f"background-color: rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {rgba[3]})"
69 | for rgba in rgba_colors
70 | ]
71 | else:
72 | return [""] * len(row) # No style for other rows
73 |
74 | # Apply gradient across all rows
75 | styled_df = df.style.background_gradient(cmap="Blues")
76 | # Apply the custom gradient to the MSE row
77 | styled_df = styled_df.apply(apply_gradient, axis=1)
78 | return styled_df.format("{:.2f}").apply(
79 | lambda x: ["color: black" if x.name == "mse" else "" for _ in x], axis=1
80 | )
81 |
82 |
83 | def styledf2png(
84 | df,
85 | filename,
86 | sub_dir="results/style_dfs",
87 | absolute_dir="/disk2/fli/SSMuLA/",
88 | width=800,
89 | height=1600,
90 | ):
91 |
92 | html_path = os.path.join(sub_dir, filename + ".html")
93 | checkNgen_folder(html_path)
94 |
95 |     # write the styled dataframe out to an HTML file
96 |     with open(html_path, "w") as html_file:
97 |         html_file.write(df.to_html())
98 |     # render the HTML in headless Firefox and screenshot it
99 |
100 | options = Options()
101 | options.add_argument("--headless") # Run Firefox in headless mode.
102 |
103 | s = Service(geckodriver_path)
104 | driver = webdriver.Firefox(service=s, options=options)
105 |
106 | driver.get(
107 | f"file://{os.path.join(absolute_dir, html_path)}"
108 | ) # Update the path to your HTML file
109 |
110 | # Set the size of the window to your content (optional)
111 | driver.set_window_size(width, height) # You might need to adjust this
112 |
113 | # Take screenshot
114 | driver.save_screenshot(html_path.replace(".html", ".png"))
115 | driver.quit()
116 |
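# Illustrative usage sketch (not part of the original module); _demo_styledf2png is a
# hypothetical helper showing how any pandas Styler can be rendered to PNG via the
# headless-Firefox pipeline above. Directories are placeholders and geckodriver must
# resolve to the path configured at the top of this file.
def _demo_styledf2png():
    toy = pd.DataFrame({"libA": [0.1, 0.9], "libB": [0.4, 0.6]}, index=["rho", "mse"])
    styledf2png(
        toy.style.format("{:.2f}").background_gradient(cmap="Blues"),
        filename="toy_heatmap",
        sub_dir="results/style_dfs",
        absolute_dir="/disk2/fli/SSMuLA/",
        width=400,
        height=300,
    )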
117 |
118 | def get_lib_stat(
119 | lib_csv: str = "results/corr_all/384/boosting|ridge-top96/merge_all.csv",
120 | sub_dir="results/style_dfs",
121 | absolute_dir="/disk2/fli/SSMuLA/",
122 | n_mut: str = "all",
123 | ):
124 | if n_mut != "all":
125 | corr_det = f"corr_{n_mut}"
126 | if corr_det not in lib_csv:
127 | lib_csv = lib_csv.replace("corr_all", corr_det)
128 |
129 | df = pd.read_csv(lib_csv)
130 | style_df = (
131 | df[["lib"] + LANDSCAPE_ATTRIBUTES]
132 | .set_index("lib")
133 | .T.style.format("{:.2f}")
134 | .background_gradient(cmap="Blues", axis=1)
135 | )
136 |
137 | return styledf2png(
138 | style_df,
139 | f"lib_stat_{n_mut}_heatmap_384-boosting|ridge-top96",
140 | sub_dir=sub_dir,
141 | absolute_dir=absolute_dir,
142 | width=1450,
143 | height=950,
144 | )
145 |
146 | def get_zs_zs_corr(
147 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
148 | sub_dir="results/style_dfs",
149 | absolute_dir="/disk2/fli/SSMuLA/",
150 | n_mut: str = "all",
151 | metric: str = "rho"
152 | ):
153 | df = pd.read_csv(corr_csv)
154 |
155 | simple_zs = [zs for zs in zs_list if n_mut in zs and metric in zs]
156 |
157 |     style_df = (
158 |         df[
159 |             df["descriptor"].isin(
160 |                 simple_zs
161 |             )
162 |         ][["descriptor"] + simple_zs]
163 |         # the rows are already filtered to simple_zs above; relabel for the heatmap
164 |         .rename(columns={"descriptor": "ZS predictions"})
165 |         .set_index("ZS predictions")
166 |         .rename(index=lambda x: x.replace("double", "hd2"))
167 |         .rename(columns=lambda x: x.replace("double", "hd2"))
168 |         .style.format("{:.2f}")
169 |         .background_gradient(cmap=custom_cmap, vmin=0, vmax=1)
170 |     )
171 |
172 | return styledf2png(
173 | style_df,
174 | f"zs_{n_mut}_{metric}_heatmap_384-boosting|ridge-top96_zs",
175 | sub_dir=sub_dir,
176 | absolute_dir=absolute_dir,
177 | width=1250,
178 | height=450,
179 | )
180 |
181 | def get_zs_ft_corr(
182 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
183 | sub_dir="results/style_dfs",
184 | absolute_dir="/disk2/fli/SSMuLA/",
185 | n_mut: str = "all",
186 | metric: str = "rho"
187 | ):
188 | """
189 | Correlate ZS with ftMLDE
190 | """
191 | df = pd.read_csv(corr_csv)
192 |
193 | simple_zs = [zs for zs in zs_list if n_mut in zs and metric in zs]
194 | # simple_ft = []
195 |
196 |
197 |     style_df = (
198 |         df[
199 |             df["descriptor"].isin(
200 |                 simple_zs
201 |             )
202 |         ][["descriptor"] + simple_zs]
203 |         # NOTE: until the ftMLDE descriptor list above is filled in, this mirrors get_zs_zs_corr
204 |         .rename(columns={"descriptor": "ZS predictions"})
205 |         .set_index("ZS predictions")
206 |         .rename(index=lambda x: x.replace("double", "hd2"))
207 |         .rename(columns=lambda x: x.replace("double", "hd2"))
208 |         .style.format("{:.2f}")
209 |         .background_gradient(cmap=custom_cmap, vmin=0, vmax=1)
210 |     )
211 |
212 | return styledf2png(
213 | style_df,
214 |         f"zs_{n_mut}_{metric}_heatmap_384-boosting|ridge-top96_zs_ft",
215 | sub_dir=sub_dir,
216 | absolute_dir=absolute_dir,
217 | width=1250,
218 | height=450,
219 | )
220 |
221 |
222 | def get_zs_corr_ls(
223 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
224 | sub_dir="results/style_dfs",
225 | absolute_dir="/disk2/fli/SSMuLA/",
226 | n_mut: str = "all",
227 | metric: str = "rho"
228 | ):
229 |
230 | df = pd.read_csv(corr_csv)
231 |
232 | style_df = (
233 | df[
234 | df["descriptor"].isin(
235 | [zs for zs in zs_list if n_mut in zs and metric in zs]
236 | )
237 | ][["descriptor"] + LANDSCAPE_ATTRIBUTES]
238 | .rename(columns={"descriptor": "Landscape attributes", **simple_de_metric_map})
239 | .iloc[0:33]
240 | .set_index("Landscape attributes")
241 | .rename(index=lambda x: x.replace("double", "hd2")).T
242 | .style.format("{:.2f}")
243 | .background_gradient(cmap=custom_cmap, vmin=-1, vmax=1)
244 | )
245 |
246 | return styledf2png(
247 | style_df,
248 | f"zs_{n_mut}_{metric}_heatmap_384-boosting|ridge-top96_landscape_attributes",
249 | sub_dir=sub_dir,
250 | absolute_dir=absolute_dir,
251 | width=1450,
252 | height=950,
253 | )
254 |
255 |
256 | def get_zs_corr(
257 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
258 | sub_dir="results/style_dfs",
259 | absolute_dir="/disk2/fli/SSMuLA/",
260 | deorls: str = "de",
261 | de_calc: str = "mean_all", # or fraction_max
262 | n_mut: str = "all",
263 | ):
264 |
265 | df = pd.read_csv(corr_csv)
266 |
267 | if deorls == "de":
268 | comp_list = [f"{de_type}_{de_calc}" for de_type in DE_TYPES]
269 | dets = de_calc
270 |         width = 625
271 | else:
272 | comp_list = LANDSCAPE_ATTRIBUTES
273 | dets = "landscape_attributes"
274 |         width = 1800
275 |
276 | style_df = (
277 | df[
278 | df["descriptor"].isin(
279 | [zs for zs in zs_list if n_mut in zs and "ndcg" not in zs]
280 | )
281 | ][["descriptor"] + comp_list]
282 | .rename(columns={"descriptor": "Landscape attributes", **simple_de_metric_map})
283 | .iloc[0:33]
284 | .set_index("Landscape attributes")
285 | .rename(index=lambda x: x.replace("double", "hd2"))
286 | .style.format("{:.2f}")
287 | .background_gradient(cmap=custom_cmap, vmin=-1, vmax=1)
288 | )
289 |
290 | return styledf2png(
291 | style_df,
292 | f"zs_{n_mut}_heatmap_384-boosting|ridge-top96_{dets}",
293 | sub_dir=sub_dir,
294 | absolute_dir=absolute_dir,
295 | width=width,
296 | height=550,
297 | )
298 |
299 |
300 | def get_corr_heatmap(
301 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
302 | sub_dir="results/style_dfs",
303 | absolute_dir="/disk2/fli/SSMuLA/",
304 | de_calc: str = "mean_all", # or fraction_max
305 | ):
306 |
307 | de_list = [f"{de_type}_{de_calc}" for de_type in DE_TYPES]
308 |
309 | df = pd.read_csv(corr_csv)
310 | style_df = (
311 | df[["descriptor"] + de_list]
312 | .rename(columns={"descriptor": "Landscape attributes", **simple_de_metric_map})
313 | .iloc[0:33]
314 | .set_index("Landscape attributes")
315 | .style.format("{:.2f}")
316 | .background_gradient(cmap=custom_cmap, vmin=-1, vmax=1)
317 | )
318 |
319 | return styledf2png(
320 | style_df,
321 | f"corr_heatmap_384-boosting|ridge-top96_{de_calc}",
322 | sub_dir=sub_dir,
323 | absolute_dir=absolute_dir,
324 | width=720,
325 | height=975,
326 | )
327 |
328 |
329 | def get_importance_heatmap(
330 | lib_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/merge_all.csv",
331 | sub_dir="results/style_dfs",
332 | absolute_dir="/disk2/fli/SSMuLA/",
333 | de_calc: str = "mean_all", # or fraction_max
334 | ):
335 |
336 | de_list = [f"{de_type}_{de_calc}" for de_type in DE_TYPES]
337 |
338 | df = pd.read_csv(lib_csv)
339 |
340 |     # the merged results csv already provides both the features (landscape attributes)
341 |     # and the regression targets as columns
342 |
343 | # Select features and targets
344 | features = df[LANDSCAPE_ATTRIBUTES]
345 | targets = df[val_list]
346 |
347 | lr_df_list = []
348 |
349 |     # fit a linear model per target; its coefficients serve as feature importances
350 | for target in targets.columns:
351 | lr_model = LinearRegression()
352 | lr_model.fit(features, df[target])
353 |
354 | # Feature importance
355 | feature_importances = pd.DataFrame(
356 | lr_model.coef_, index=LANDSCAPE_ATTRIBUTES, columns=[target]
357 | )
358 |
359 | lr_df_list.append(feature_importances)
360 | lr_df = pd.concat(lr_df_list, axis=1)
361 | lr_df.index.names = ["Landscape attributes"]
362 |
363 | style_df = (
364 | lr_df[de_list]
365 | .rename(columns=simple_de_metric_map)
366 | .style.format("{:.2f}")
367 | .background_gradient(cmap=custom_cmap)
368 | )
369 |
370 | return styledf2png(
371 | style_df,
372 | f"importance_heatmap_384-boosting|ridge-top96_{de_calc}",
373 | sub_dir=sub_dir,
374 | absolute_dir=absolute_dir,
375 | width=720,
376 | height=975,
377 | )
378 |
379 |
380 | def plot_all_factor(
381 | corr_csv: str = "results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
382 | sub_dir="results/style_dfs_actcut-1",
383 | absolute_dir="/disk2/fli/SSMuLA/",
384 | ):
385 |
386 | for n_mut in ["all", "double"]:
387 |
388 | for metric in ["rho", "rocauc"]:
389 | get_zs_zs_corr(
390 | corr_csv=corr_csv,
391 | n_mut=n_mut,
392 | metric=metric,
393 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_zs_corr")),
394 | absolute_dir=absolute_dir,
395 | )
396 |
397 | get_zs_corr_ls(
398 | corr_csv=corr_csv,
399 | n_mut=n_mut,
400 | metric=metric,
401 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_corr_ls")),
402 | absolute_dir=absolute_dir,
403 | )
404 |
405 |
406 | get_zs_ft_corr(
407 | corr_csv=corr_csv,
408 | n_mut=n_mut,
409 | metric=metric,
410 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_ft_corr")),
411 | absolute_dir=absolute_dir,
412 | )
413 |
414 | for de_calc in ["mean_all", "fraction_max"]:
415 | for deorls in ["de", "ls"]:
416 | get_zs_corr(
417 | corr_csv=corr_csv,
418 | de_calc=de_calc,
419 | n_mut=n_mut,
420 | deorls=deorls,
421 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "zs_corr")),
422 | absolute_dir=absolute_dir,
423 | )
424 |
425 | get_corr_heatmap(
426 | corr_csv=corr_csv,
427 | de_calc=de_calc,
428 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "corr_heatmap")),
429 | absolute_dir=absolute_dir,
430 | )
431 |
432 | get_importance_heatmap(
433 | lib_csv=corr_csv.replace("corr.csv", "merge_all.csv"),
434 | de_calc=de_calc,
435 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "importance_heatmap")),
436 | absolute_dir=absolute_dir,
437 | )
438 |
439 | get_lib_stat(
440 | lib_csv=corr_csv.replace("corr.csv", "merge_all.csv"),
441 | n_mut=n_mut,
442 | sub_dir=checkNgen_folder(os.path.join(sub_dir, "lib_stat")),
443 | absolute_dir=absolute_dir,
444 | )
445 |
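# Illustrative entry point (a sketch, not part of the original module): with the default
# arguments plot_all_factor regenerates every styled heatmap under results/style_dfs_actcut-1.
if __name__ == "__main__":
    plot_all_factor(
        corr_csv="results/corr_all/384/boosting|ridge-top96/actcut-1/corr.csv",
        sub_dir="results/style_dfs_actcut-1",
        absolute_dir="/disk2/fli/SSMuLA/",
    )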
--------------------------------------------------------------------------------
/envs/frozen/SSMuLA.yml:
--------------------------------------------------------------------------------
1 | name: SSMuLA
2 | channels:
3 | - pytorch
4 | - pyg
5 | - anaconda
6 | - nvidia
7 | - conda-forge
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=conda_forge
11 | - _openmp_mutex=4.5=2_gnu
12 | - _py-xgboost-mutex=2.0=gpu_0
13 | - absl-py=2.1.0=pyhd8ed1ab_1
14 | - aiohappyeyeballs=2.4.6=pyhd8ed1ab_0
15 | - aiohttp=3.11.13=py311h2dc5d0c_0
16 | - aiosignal=1.3.2=pyhd8ed1ab_0
17 | - alsa-lib=1.2.13=hb9d3cd8_0
18 | - anyio=4.8.0=pyhd8ed1ab_0
19 | - argon2-cffi=23.1.0=pyhd8ed1ab_1
20 | - argon2-cffi-bindings=21.2.0=py311h9ecbd09_5
21 | - arrow=1.3.0=pyhd8ed1ab_1
22 | - asttokens=3.0.0=pyhd8ed1ab_1
23 | - astunparse=1.6.3=pyhd8ed1ab_3
24 | - async-lru=2.0.4=pyhd8ed1ab_1
25 | - attrs=25.1.0=pyh71513ae_0
26 | - babel=2.17.0=pyhd8ed1ab_0
27 | - beautifulsoup4=4.13.3=pyha770c72_0
28 | - biopandas=0.5.1=pyhd8ed1ab_1
29 | - biopython=1.85=py311h9ecbd09_1
30 | - biotite=1.1.0=py311hfdbb021_1
31 | - biotraj=1.2.2=py311hfdbb021_0
32 | - blas=1.0=mkl
33 | - bleach=6.2.0=pyh29332c3_4
34 | - bleach-with-css=6.2.0=h82add2a_4
35 | - blosc=1.21.6=he440d0b_1
36 | - bokeh=3.6.3=pyhd8ed1ab_0
37 | - brokenaxes=0.4.2=pyhd8ed1ab_0
38 | - brotli=1.1.0=hb9d3cd8_2
39 | - brotli-bin=1.1.0=hb9d3cd8_2
40 | - brotli-python=1.1.0=py311hfdbb021_2
41 | - bzip2=1.0.8=h4bc722e_7
42 | - c-ares=1.34.4=hb9d3cd8_0
43 | - c-blosc2=2.15.2=h3122c55_1
44 | - ca-certificates=2025.1.31=hbcca054_0
45 | - cached-property=1.5.2=hd8ed1ab_1
46 | - cached_property=1.5.2=pyha770c72_1
47 | - cairo=1.18.2=h3394656_1
48 | - cairocffi=1.7.1=pyhd8ed1ab_1
49 | - cairosvg=2.7.1=pyhd8ed1ab_1
50 | - certifi=2025.1.31=pyhd8ed1ab_0
51 | - cffi=1.17.1=py311hf29c0ef_0
52 | - charset-normalizer=3.4.1=pyhd8ed1ab_0
53 | - colorama=0.4.6=pyhd8ed1ab_1
54 | - colorcet=3.1.0=pyhd8ed1ab_1
55 | - comm=0.2.2=pyhd8ed1ab_1
56 | - contourpy=1.3.1=py311hd18a35c_0
57 | - cpython=3.11.11=py311hd8ed1ab_1
58 | - cssselect2=0.2.1=pyh9f0ad1d_1
59 | - cuda-cudart=12.1.105=0
60 | - cuda-cupti=12.1.105=0
61 | - cuda-libraries=12.1.0=0
62 | - cuda-nvrtc=12.1.105=0
63 | - cuda-nvtx=12.1.105=0
64 | - cuda-opencl=12.4.127=0
65 | - cuda-runtime=12.1.0=0
66 | - cuda-version=11.8=h70ddcb2_3
67 | - cudatoolkit=11.8.0=h4ba93d1_13
68 | - cycler=0.12.1=pyhd8ed1ab_1
69 | - cyrus-sasl=2.1.27=h54b06d7_7
70 | - datashader=0.17.0=pyhd8ed1ab_0
71 | - dbus=1.13.6=h5008d03_3
72 | - debugpy=1.8.12=py311hfdbb021_0
73 | - decorator=5.2.1=pyhd8ed1ab_0
74 | - defusedxml=0.7.1=pyhd8ed1ab_0
75 | - double-conversion=3.3.1=h5888daf_0
76 | - et_xmlfile=2.0.0=pyhd8ed1ab_1
77 | - exceptiongroup=1.2.2=pyhd8ed1ab_1
78 | - executing=2.1.0=pyhd8ed1ab_1
79 | - expat=2.6.4=h5888daf_0
80 | - filelock=3.17.0=pyhd8ed1ab_0
81 | - firefox=134.0=h5888daf_0
82 | - flake8=7.1.2=pyhd8ed1ab_0
83 | - flatbuffers=24.12.23=h8f4948b_0
84 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
85 | - font-ttf-inconsolata=3.000=h77eed37_0
86 | - font-ttf-source-code-pro=2.038=h77eed37_0
87 | - font-ttf-ubuntu=0.83=h77eed37_3
88 | - fontconfig=2.15.0=h7e30c49_1
89 | - fonts-conda-ecosystem=1=0
90 | - fonts-conda-forge=1=0
91 | - fonttools=4.56.0=py311h2dc5d0c_0
92 | - fqdn=1.5.1=pyhd8ed1ab_1
93 | - freetype=2.12.1=h267a509_2
94 | - frozenlist=1.5.0=py311h2dc5d0c_1
95 | - fsspec=2025.2.0=pyhd8ed1ab_0
96 | - gast=0.6.0=pyhd8ed1ab_0
97 | - geckodriver=0.36.0=hb7e49e0_0
98 | - giflib=5.2.2=hd590300_0
99 | - gmp=6.3.0=hac33072_2
100 | - gmpy2=2.1.5=py311h0f6cedb_3
101 | - google-pasta=0.2.0=pyhd8ed1ab_2
102 | - graphite2=1.3.13=h59595ed_1003
103 | - grpcio=1.67.1=py311h9789449_1
104 | - h11=0.14.0=pyhd8ed1ab_1
105 | - h2=4.2.0=pyhd8ed1ab_0
106 | - h5py=3.12.1=nompi_py311h5ed33ec_103
107 | - harfbuzz=10.3.0=h76408a6_0
108 | - hdf5=1.14.4=nompi_h2d575fe_105
109 | - holoviews=1.20.1=pyhd8ed1ab_0
110 | - hpack=4.1.0=pyhd8ed1ab_0
111 | - httpcore=1.0.7=pyh29332c3_1
112 | - httpx=0.28.1=pyhd8ed1ab_0
113 | - hvplot=0.11.2=pyhd8ed1ab_0
114 | - hyperframe=6.1.0=pyhd8ed1ab_0
115 | - icu=75.1=he02047a_0
116 | - idna=3.10=pyhd8ed1ab_1
117 | - importlib-metadata=8.6.1=pyha770c72_0
118 | - importlib_resources=6.5.2=pyhd8ed1ab_0
119 | - intel-openmp=2022.0.1=h06a4308_3633
120 | - ipykernel=6.29.5=pyh3099207_0
121 | - ipympl=0.9.6=pyhd8ed1ab_0
122 | - ipython=8.32.0=pyh907856f_0
123 | - ipython_genutils=0.2.0=pyhd8ed1ab_2
124 | - ipywidgets=8.1.5=pyhd8ed1ab_1
125 | - isoduration=20.11.0=pyhd8ed1ab_1
126 | - jedi=0.19.2=pyhd8ed1ab_1
127 | - jinja2=3.1.5=pyhd8ed1ab_0
128 | - joblib=1.4.2=pyhd8ed1ab_1
129 | - json5=0.10.0=pyhd8ed1ab_1
130 | - jsonpointer=3.0.0=py311h38be061_1
131 | - jsonschema=4.23.0=pyhd8ed1ab_1
132 | - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
133 | - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
134 | - jupyter-lsp=2.2.5=pyhd8ed1ab_1
135 | - jupyter_bokeh=4.0.5=pyhd8ed1ab_1
136 | - jupyter_client=8.6.3=pyhd8ed1ab_1
137 | - jupyter_core=5.7.2=pyh31011fe_1
138 | - jupyter_events=0.12.0=pyh29332c3_0
139 | - jupyter_server=2.15.0=pyhd8ed1ab_0
140 | - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
141 | - jupyterlab=4.3.5=pyhd8ed1ab_0
142 | - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
143 | - jupyterlab_server=2.27.3=pyhd8ed1ab_1
144 | - jupyterlab_widgets=3.0.13=pyhd8ed1ab_1
145 | - keras=3.8.0=pyh753f3f9_0
146 | - kernel-headers_linux-64=3.10.0=he073ed8_18
147 | - keyutils=1.6.1=h166bdaf_0
148 | - kiwisolver=1.4.7=py311hd18a35c_0
149 | - krb5=1.21.3=h659f571_0
150 | - lcms2=2.17=h717163a_0
151 | - ld_impl_linux-64=2.43=h712a8e2_4
152 | - lerc=4.0.0=h27087fc_0
153 | - libabseil=20240722.0=cxx17_hbbce691_4
154 | - libaec=1.1.3=h59595ed_0
155 | - libblas=3.9.0=16_linux64_mkl
156 | - libbrotlicommon=1.1.0=hb9d3cd8_2
157 | - libbrotlidec=1.1.0=hb9d3cd8_2
158 | - libbrotlienc=1.1.0=hb9d3cd8_2
159 | - libcblas=3.9.0=16_linux64_mkl
160 | - libclang-cpp19.1=19.1.7=default_hb5137d0_1
161 | - libclang13=19.1.7=default_h9c6a7e4_1
162 | - libcublas=12.1.0.26=0
163 | - libcufft=11.0.2.4=0
164 | - libcufile=1.9.1.3=0
165 | - libcups=2.3.3=h4637d8d_4
166 | - libcurand=10.3.5.147=0
167 | - libcurl=8.12.1=h332b0f4_0
168 | - libcusolver=11.4.4.55=0
169 | - libcusparse=12.0.2.55=0
170 | - libdeflate=1.23=h4ddbbb0_0
171 | - libdrm=2.4.124=hb9d3cd8_0
172 | - libedit=3.1.20250104=pl5321h7949ede_0
173 | - libegl=1.7.0=ha4b6fd6_2
174 | - libev=4.33=hd590300_2
175 | - libexpat=2.6.4=h5888daf_0
176 | - libffi=3.4.6=h2dba641_0
177 | - libgcc=14.2.0=h767d61c_2
178 | - libgcc-ng=14.2.0=h69a702a_2
179 | - libgfortran=14.2.0=h69a702a_2
180 | - libgfortran5=14.2.0=hf1ad2bd_2
181 | - libgl=1.7.0=ha4b6fd6_2
182 | - libglib=2.82.2=h2ff4ddf_1
183 | - libglvnd=1.7.0=ha4b6fd6_2
184 | - libglx=1.7.0=ha4b6fd6_2
185 | - libgomp=14.2.0=h767d61c_2
186 | - libgrpc=1.67.1=h25350d4_1
187 | - libiconv=1.18=h4ce23a2_1
188 | - libjpeg-turbo=3.0.0=hd590300_1
189 | - liblapack=3.9.0=16_linux64_mkl
190 | - libllvm15=15.0.7=ha7bfdaf_5
191 | - libllvm19=19.1.7=ha7bfdaf_1
192 | - liblzma=5.6.4=hb9d3cd8_0
193 | - libnghttp2=1.64.0=h161d5f1_0
194 | - libnpp=12.0.2.50=0
195 | - libnsl=2.0.1=hd590300_0
196 | - libntlm=1.8=hb9d3cd8_0
197 | - libnvjitlink=12.1.105=0
198 | - libnvjpeg=12.1.1.14=0
199 | - libopengl=1.7.0=ha4b6fd6_2
200 | - libpciaccess=0.18=hd590300_0
201 | - libpng=1.6.47=h943b412_0
202 | - libpq=17.4=h27ae623_0
203 | - libprotobuf=5.28.3=h6128344_1
204 | - libre2-11=2024.07.02=hbbce691_2
205 | - libsodium=1.0.20=h4ab18f5_0
206 | - libsqlite=3.49.1=hee588c1_1
207 | - libssh2=1.11.1=hf672d98_0
208 | - libstdcxx=14.2.0=h8f9b012_2
209 | - libstdcxx-ng=14.2.0=h4852527_2
210 | - libtiff=4.7.0=hd9ff511_3
211 | - libuuid=2.38.1=h0b41bf4_0
212 | - libuv=1.50.0=hb9d3cd8_0
213 | - libwebp-base=1.5.0=h851e524_0
214 | - libxcb=1.17.0=h8a09558_0
215 | - libxcrypt=4.4.36=hd590300_1
216 | - libxgboost=2.1.4=cuda118_h09a87be_0
217 | - libxkbcommon=1.8.0=hc4a0caf_0
218 | - libxml2=2.13.6=h8d12d68_0
219 | - libxslt=1.1.39=h76b75d6_0
220 | - libzlib=1.3.1=hb9d3cd8_2
221 | - linkify-it-py=2.0.3=pyhd8ed1ab_1
222 | - llvm-openmp=15.0.7=h0cdce71_0
223 | - llvmlite=0.44.0=py311h9c9ff8c_0
224 | - looseversion=1.3.0=pyhd8ed1ab_0
225 | - lz4-c=1.10.0=h5888daf_1
226 | - markdown=3.6=pyhd8ed1ab_0
227 | - markdown-it-py=3.0.0=pyhd8ed1ab_1
228 | - markupsafe=3.0.2=py311h2dc5d0c_1
229 | - matplotlib=3.10.0=py311h38be061_0
230 | - matplotlib-base=3.10.0=py311h2b939e6_0
231 | - matplotlib-inline=0.1.7=pyhd8ed1ab_1
232 | - mccabe=0.7.0=pyhd8ed1ab_1
233 | - mdit-py-plugins=0.4.2=pyhd8ed1ab_1
234 | - mdurl=0.1.2=pyhd8ed1ab_1
235 | - mistune=3.1.2=pyhd8ed1ab_0
236 | - mkl=2022.1.0=hc2b9512_224
237 | - ml_dtypes=0.4.0=py311h7db5c69_2
238 | - mmtf-python=1.1.3=pyhd8ed1ab_0
239 | - mpc=1.3.1=h24ddda3_1
240 | - mpfr=4.2.1=h90cbb55_3
241 | - mpmath=1.3.0=pyhd8ed1ab_1
242 | - msgpack-python=1.1.0=py311hd18a35c_0
243 | - multidict=6.1.0=py311h2dc5d0c_2
244 | - multipledispatch=0.6.0=pyhd8ed1ab_1
245 | - munkres=1.1.4=pyh9f0ad1d_0
246 | - mypy=1.15.0=py311h9ecbd09_0
247 | - mypy_extensions=1.0.0=pyha770c72_1
248 | - mysql-common=9.0.1=h266115a_4
249 | - mysql-libs=9.0.1=he0572af_4
250 | - namex=0.0.8=pyhd8ed1ab_1
251 | - nbclient=0.10.2=pyhd8ed1ab_0
252 | - nbconvert-core=7.16.6=pyh29332c3_0
253 | - nbformat=5.10.4=pyhd8ed1ab_1
254 | - nccl=2.25.1.1=h03a54cd_0
255 | - ncurses=6.5=h2d0b736_3
256 | - nest-asyncio=1.6.0=pyhd8ed1ab_1
257 | - networkx=3.4.2=pyh267e887_2
258 | - nodejs=22.12.0=hf235a45_0
259 | - nomkl=2.0=0
260 | - notebook-shim=0.2.4=pyhd8ed1ab_1
261 | - numba=0.61.0=py311h4e1c48f_1
262 | - numexpr=2.10.2=py311h38b10cd_100
263 | - numpy=2.1.3=py311h71ddf71_0
264 | - ocl-icd=2.3.2=hb9d3cd8_2
265 | - ocl-icd-system=1.0.0=1
266 | - opencl-headers=2024.10.24=h5888daf_0
267 | - openjpeg=2.5.3=h5fbd93e_0
268 | - openldap=2.6.9=he970967_0
269 | - openmm=8.2.0=py311he5bdeac_2
270 | - openpyxl=3.1.5=py311h50c5138_1
271 | - openssl=3.4.1=h7b32b05_0
272 | - opt_einsum=3.4.0=pyhd8ed1ab_1
273 | - optree=0.14.0=py311hd18a35c_1
274 | - overrides=7.7.0=pyhd8ed1ab_1
275 | - packaging=24.2=pyhd8ed1ab_2
276 | - pandas=2.2.3=py311h7db5c69_1
277 | - pandocfilters=1.5.0=pyhd8ed1ab_0
278 | - panel=1.6.1=pyhd8ed1ab_0
279 | - param=2.2.0=pyhd8ed1ab_0
280 | - parso=0.8.4=pyhd8ed1ab_1
281 | - patsy=1.0.1=pyhd8ed1ab_1
282 | - pcre2=10.44=hba22ea6_2
283 | - pdbfixer=1.11=pyhd8ed1ab_0
284 | - pexpect=4.9.0=pyhd8ed1ab_1
285 | - pickleshare=0.7.5=pyhd8ed1ab_1004
286 | - pillow=11.1.0=py311h1322bbf_0
287 | - pip=25.0.1=pyh8b19718_0
288 | - pixman=0.44.2=h29eaf8c_0
289 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
290 | - platformdirs=4.3.6=pyhd8ed1ab_1
291 | - prometheus_client=0.21.1=pyhd8ed1ab_0
292 | - prompt-toolkit=3.0.50=pyha770c72_0
293 | - propcache=0.2.1=py311h2dc5d0c_1
294 | - protobuf=5.28.3=py311hfdbb021_0
295 | - psutil=6.1.1=py311h9ecbd09_0
296 | - pthread-stubs=0.4=hb9d3cd8_1002
297 | - ptyprocess=0.7.0=pyhd8ed1ab_1
298 | - pure_eval=0.2.3=pyhd8ed1ab_1
299 | - py-cpuinfo=9.0.0=pyhd8ed1ab_1
300 | - py-xgboost=2.1.4=cuda118_pyh1b5bf1a_0
301 | - pycodestyle=2.12.1=pyhd8ed1ab_1
302 | - pycparser=2.22=pyh29332c3_1
303 | - pyct=0.5.0=pyhd8ed1ab_1
304 | - pyflakes=3.2.0=pyhd8ed1ab_1
305 | - pyg=2.5.2=py311_torch_2.1.0_cu121
306 | - pygments=2.19.1=pyhd8ed1ab_0
307 | - pyparsing=3.2.1=pyhd8ed1ab_0
308 | - pyside6=6.8.2=py311h9053184_0
309 | - pysocks=1.7.1=pyha55dd90_7
310 | - pytables=3.10.2=py311h3ebe2b2_0
311 | - python=3.11.11=h9e4cc4f_1_cpython
312 | - python-dateutil=2.9.0.post0=pyhff2d567_1
313 | - python-fastjsonschema=2.21.1=pyhd8ed1ab_0
314 | - python-flatbuffers=25.2.10=pyhbc23db3_0
315 | - python-json-logger=2.0.7=pyhd8ed1ab_0
316 | - python-tzdata=2025.1=pyhd8ed1ab_0
317 | - python_abi=3.11=5_cp311
318 | - pytorch=2.1.1=py3.11_cuda12.1_cudnn8.9.2_0
319 | - pytorch-cuda=12.1=ha16c6d3_6
320 | - pytorch-mutex=1.0=cuda
321 | - pytz=2024.1=pyhd8ed1ab_0
322 | - pyviz_comms=3.0.4=pyhd8ed1ab_1
323 | - pyyaml=6.0.2=py311h2dc5d0c_2
324 | - pyzmq=26.2.1=py311h7deb3e3_0
325 | - qhull=2020.2=h434a139_5
326 | - qt6-main=6.8.2=h588cce1_0
327 | - re2=2024.07.02=h9925aae_2
328 | - readline=8.2=h8c095d6_2
329 | - referencing=0.36.2=pyh29332c3_0
330 | - requests=2.32.3=pyhd8ed1ab_1
331 | - rfc3339-validator=0.1.4=pyhd8ed1ab_1
332 | - rfc3986-validator=0.1.1=pyh9f0ad1d_0
333 | - rich=13.9.4=pyhd8ed1ab_1
334 | - rpds-py=0.23.1=py311h687327b_0
335 | - scikit-learn=1.6.1=py311h57cc02b_0
336 | - scipy=1.15.2=py311h8f841c2_0
337 | - seaborn=0.13.2=hd8ed1ab_3
338 | - seaborn-base=0.13.2=pyhd8ed1ab_3
339 | - send2trash=1.8.3=pyh0d859eb_1
340 | - setuptools=75.8.0=pyhff2d567_0
341 | - six=1.17.0=pyhd8ed1ab_0
342 | - snappy=1.2.1=h8bd8927_1
343 | - sniffio=1.3.1=pyhd8ed1ab_1
344 | - soupsieve=2.5=pyhd8ed1ab_1
345 | - stack_data=0.6.3=pyhd8ed1ab_1
346 | - statsmodels=0.14.4=py311h9f3472d_0
347 | - sympy=1.13.3=pyh2585a3b_105
348 | - sysroot_linux-64=2.17=h0157908_18
349 | - tensorboard=2.18.0=pyhd8ed1ab_1
350 | - tensorboard-data-server=0.7.0=py311hafd3f86_2
351 | - tensorflow=2.18.0=cpu_py311h6ac8430_0
352 | - tensorflow-base=2.18.0=cpu_py311h50f7602_0
353 | - tensorflow-estimator=2.18.0=cpu_py311h1adda88_0
354 | - termcolor=2.5.0=pyhd8ed1ab_1
355 | - terminado=0.18.1=pyh0d859eb_0
356 | - threadpoolctl=3.5.0=pyhc1e730c_0
357 | - tinycss2=1.4.0=pyhd8ed1ab_0
358 | - tk=8.6.13=noxft_h4845f30_101
359 | - toolz=1.0.0=pyhd8ed1ab_1
360 | - torchtriton=2.1.0=py311
361 | - tornado=6.4.2=py311h9ecbd09_0
362 | - tqdm=4.67.1=pyhd8ed1ab_1
363 | - traitlets=5.14.3=pyhd8ed1ab_1
364 | - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
365 | - typing-extensions=4.12.2=hd8ed1ab_1
366 | - typing_extensions=4.12.2=pyha770c72_1
367 | - typing_utils=0.1.0=pyhd8ed1ab_1
368 | - tzdata=2025a=h78e105d_0
369 | - uc-micro-py=1.0.3=pyhd8ed1ab_1
370 | - unicodedata2=16.0.0=py311h9ecbd09_0
371 | - uri-template=1.3.0=pyhd8ed1ab_1
372 | - urllib3=2.3.0=pyhd8ed1ab_0
373 | - versioneer=0.29=pyhd8ed1ab_0
374 | - wayland=1.23.1=h3e06ad9_0
375 | - wcwidth=0.2.13=pyhd8ed1ab_1
376 | - webcolors=24.11.1=pyhd8ed1ab_0
377 | - webencodings=0.5.1=pyhd8ed1ab_3
378 | - websocket-client=1.8.0=pyhd8ed1ab_1
379 | - werkzeug=3.1.3=pyhd8ed1ab_1
380 | - wheel=0.45.1=pyhd8ed1ab_1
381 | - widgetsnbextension=4.0.13=pyhd8ed1ab_1
382 | - wrapt=1.17.2=py311h9ecbd09_0
383 | - xarray=2025.1.2=pyhd8ed1ab_0
384 | - xcb-util=0.4.1=hb711507_2
385 | - xcb-util-cursor=0.1.5=hb9d3cd8_0
386 | - xcb-util-image=0.4.0=hb711507_2
387 | - xcb-util-keysyms=0.4.1=hb711507_0
388 | - xcb-util-renderutil=0.3.10=hb711507_0
389 | - xcb-util-wm=0.4.2=hb711507_0
390 | - xgboost=2.1.4=cuda118_pyh7984362_0
391 | - xkeyboard-config=2.43=hb9d3cd8_0
392 | - xorg-libice=1.1.2=hb9d3cd8_0
393 | - xorg-libsm=1.2.5=he73a12e_0
394 | - xorg-libx11=1.8.11=h4f16b4b_0
395 | - xorg-libxau=1.0.12=hb9d3cd8_0
396 | - xorg-libxcomposite=0.4.6=hb9d3cd8_2
397 | - xorg-libxcursor=1.2.3=hb9d3cd8_0
398 | - xorg-libxdamage=1.1.6=hb9d3cd8_0
399 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0
400 | - xorg-libxext=1.3.6=hb9d3cd8_0
401 | - xorg-libxfixes=6.0.1=hb9d3cd8_0
402 | - xorg-libxi=1.8.2=hb9d3cd8_0
403 | - xorg-libxrandr=1.5.4=hb9d3cd8_0
404 | - xorg-libxrender=0.9.12=hb9d3cd8_0
405 | - xorg-libxtst=1.2.5=hb9d3cd8_3
406 | - xorg-libxxf86vm=1.1.6=hb9d3cd8_0
407 | - xyzservices=2025.1.0=pyhd8ed1ab_0
408 | - yaml=0.2.5=h7f98852_2
409 | - yarl=1.18.3=py311h2dc5d0c_1
410 | - zeromq=4.3.5=h3b0a872_7
411 | - zipp=3.21.0=pyhd8ed1ab_1
412 | - zlib=1.3.1=hb9d3cd8_2
413 | - zlib-ng=2.2.4=h7955e40_0
414 | - zstandard=0.19.0=py311hd4cff14_0
415 | - zstd=1.5.7=hb8e6e7a_1
416 | - pip:
417 | - black==21.12b0
418 | - blackcellmagic==0.0.3
419 | - click==8.1.8
420 | - fair-esm==2.0.0
421 | - hatch-cython==0.5.1
422 | - jupyter==1.1.1
423 | - jupyter-console==6.6.3
424 | - notebook==7.3.2
425 | - ordered-set==4.1.0
426 | - pathspec==0.12.1
427 | - rdkit==2024.9.5
428 | - tomli==1.2.3
429 | prefix: /disk2/fli/miniconda3/envs/SSMuLA
430 |
--------------------------------------------------------------------------------