├── data ├── metrics │ └── .placeholder ├── wt.fa.dvc ├── experimental_data.csv.dvc ├── starting_sequences.fa.dvc ├── .gitignore └── starting_sequences.a2m.dvc ├── .dvc ├── config └── .gitignore ├── requirements_files ├── requirements-transformers.txt ├── requirements-fair-esm.txt ├── requirements-vespa.txt ├── requirements-badass.txt └── requirements-evmutation.txt ├── images └── fig1.png ├── .codecarbon.config ├── docs ├── _build │ └── latex │ │ └── aide.pdf ├── modules.rst ├── aide_predict.io.rst ├── aide_predict.rst ├── Makefile ├── aide_predict.bespoke_models.rst ├── aide_predict.utils.data_structures.rst ├── index.rst ├── aide_predict.bespoke_models.embedders.rst ├── user_guide │ ├── pipelines.md │ ├── roadmap.md │ ├── saturation_mutagenesis.md │ ├── resource_test.md │ ├── caching.md │ ├── structure_pred.md │ ├── supervised.md │ ├── msa_search.md │ ├── installation.md │ ├── position_specific.md │ ├── model_compatibility.md │ ├── contributing_models.md │ └── badass.md ├── aide_predict.utils.rst ├── aide_predict.bespoke_models.predictors.rst └── conf.py ├── showcase ├── p740_precision.png ├── p740_best_pipeline.pkl ├── epistatic_benchmarking.pdf ├── p740_model_comparison.png ├── nonlinear_vs_linear_top10.png ├── wt_petase_model_creation.pdf ├── nonlinear_vs_linear_kendall.png ├── zs_vs_supervised_vs_augmented.png ├── README.md └── fig3_data.csv ├── tests ├── __init__.py ├── test_utils │ ├── __init__.py │ ├── test_data_structures │ │ └── __init__.py │ ├── test_msa.py │ └── test_conservation.py ├── test_not_base_models │ ├── __init__.py │ ├── test_evmutation.py │ ├── test_vespa.py │ ├── test_eve.py │ ├── test_ssemb_pred.py │ ├── test_esm2_loglike.py │ ├── test_msatrans_loglike.py │ ├── test_saprot_loglike.py │ ├── test_esm2_embedding.py │ └── test_badass.py ├── test_bespoke_models │ ├── __init__.py │ ├── test_predictors │ │ ├── __init__.py │ │ └── test_hmm.py │ └── test_embedders │ │ └── test_kmer.py └── data │ ├── two_sequences.fa │ ├── some_sequences.fasta │ 
└── hmm-17.fa ├── .dvcignore ├── aide_predict ├── io │ ├── __init__.py │ └── bio_files.py ├── utils │ ├── __init__.py │ ├── data_structures │ │ └── __init__.py │ ├── constants.py │ ├── common.py │ ├── checks.py │ ├── alignment_calls.py │ └── plotting.py ├── bespoke_models │ ├── embedders │ │ ├── __init__.py │ │ └── kmer.py │ ├── predictors │ │ ├── __init__.py │ │ └── vespa.py │ └── __init__.py ├── __init__.py └── patches_.py ├── setup.py ├── environment.yaml ├── pytest.ini ├── slurm_submit_dvc.sh ├── .github └── workflows │ ├── docs.yaml │ └── ci-tests.yml ├── .coveragerc ├── LICENSE ├── dvc.yaml ├── external_calls └── eve │ ├── _default_model_params.json │ ├── _train_VAE_one.py │ └── _compute_evol_indices_one.py ├── scripts └── process_msa.py ├── .gitignore ├── dvc.lock └── params.yaml /data/metrics/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dvc/config: -------------------------------------------------------------------------------- 1 | [core] 2 | autostage = true 3 | -------------------------------------------------------------------------------- /.dvc/.gitignore: -------------------------------------------------------------------------------- 1 | /config.local 2 | /tmp 3 | /cache 4 | -------------------------------------------------------------------------------- /requirements_files/requirements-transformers.txt: -------------------------------------------------------------------------------- 1 | transformers[torch] 2 | -------------------------------------------------------------------------------- /images/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/images/fig1.png -------------------------------------------------------------------------------- /requirements_files/requirements-fair-esm.txt: 
-------------------------------------------------------------------------------- 1 | fair-esm 2 | transformers[torch] 3 | 4 | -------------------------------------------------------------------------------- /requirements_files/requirements-vespa.txt: -------------------------------------------------------------------------------- 1 | transformers[torch] 2 | vespa-effect 3 | h5py 4 | -------------------------------------------------------------------------------- /.codecarbon.config: -------------------------------------------------------------------------------- 1 | [codecarbon] 2 | experiment_id = c34265b9-d544-47e6-b82b-6a928334f0c7 3 | 4 | -------------------------------------------------------------------------------- /docs/_build/latex/aide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/docs/_build/latex/aide.pdf -------------------------------------------------------------------------------- /requirements_files/requirements-badass.txt: -------------------------------------------------------------------------------- 1 | badass @ git+https://github.com/EvanKomp/BADASS@69f3723 2 | 3 | 4 | -------------------------------------------------------------------------------- /showcase/p740_precision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/p740_precision.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # File: /protein-prediction/tests/__init__.py 2 | 3 | # This file is intentionally left blank. 
-------------------------------------------------------------------------------- /data/wt.fa.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: 592b18586d4530aaef6557b77c963ee3 3 | size: 78 4 | hash: md5 5 | path: wt.fa 6 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | aide_predict 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | aide_predict 8 | -------------------------------------------------------------------------------- /showcase/p740_best_pipeline.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/p740_best_pipeline.pkl -------------------------------------------------------------------------------- /showcase/epistatic_benchmarking.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/epistatic_benchmarking.pdf -------------------------------------------------------------------------------- /showcase/p740_model_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/p740_model_comparison.png -------------------------------------------------------------------------------- /showcase/nonlinear_vs_linear_top10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/nonlinear_vs_linear_top10.png -------------------------------------------------------------------------------- /showcase/wt_petase_model_creation.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/wt_petase_model_creation.pdf -------------------------------------------------------------------------------- /showcase/nonlinear_vs_linear_kendall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/nonlinear_vs_linear_kendall.png -------------------------------------------------------------------------------- /showcase/zs_vs_supervised_vs_augmented.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvanKomp/aide_predict/HEAD/showcase/zs_vs_supervised_vs_augmented.png -------------------------------------------------------------------------------- /data/experimental_data.csv.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: d41d8cd98f00b204e9800998ecf8427e 3 | size: 0 4 | hash: md5 5 | path: experimental_data.csv 6 | -------------------------------------------------------------------------------- /data/starting_sequences.fa.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: d41d8cd98f00b204e9800998ecf8427e 3 | size: 0 4 | hash: md5 5 | path: starting_sequences.fa 6 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | /jackhmmer 2 | /wt.fa 3 | /experimental_data.csv 4 | /starting_sequences.a2m 5 | /starting_sequences.fa 6 | /run_msa 7 | /process_msa 8 | -------------------------------------------------------------------------------- /data/starting_sequences.a2m.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: 624275f98f9f67c5939e3aba6e4a873d 3 | size: 173269062 4 | hash: md5 5 | path: starting_sequences.a2m 6 | 
-------------------------------------------------------------------------------- /requirements_files/requirements-evmutation.txt: -------------------------------------------------------------------------------- 1 | evcouplings @ git+https://github.com/debbiemarkslab/EVcouplings@374d4c5 2 | ruamel-yaml<0.18.0 3 | numba 4 | -------------------------------------------------------------------------------- /.dvcignore: -------------------------------------------------------------------------------- 1 | # Add patterns of files dvc should ignore, which could improve 2 | # the performance. Learn more at 3 | # https://dvc.org/doc/user-guide/dvcignore 4 | -------------------------------------------------------------------------------- /aide_predict/io/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/io/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 5/7/2024 5 | * (c) Copyright by Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | ''' -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # tests/utils/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/26/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' -------------------------------------------------------------------------------- /aide_predict/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/utils/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 5/7/2024 5 | * (c) Copyright by Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | ''' 7 | 8 | -------------------------------------------------------------------------------- /tests/test_not_base_models/__init__.py: 
-------------------------------------------------------------------------------- 1 | # tests/test_not_base_models/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 7/3/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup( 3 | name='aide_predict', 4 | version='1.0', 5 | packages=['aide_predict'], 6 | extras_require={ 7 | 'test': ['pytest', 'pytest-cov'] 8 | } 9 | ) 10 | -------------------------------------------------------------------------------- /aide_predict/bespoke_models/embedders/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/bespoke_models/embedders/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 7/5/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' -------------------------------------------------------------------------------- /tests/test_bespoke_models/__init__.py: -------------------------------------------------------------------------------- 1 | # tests/test_bespoke_models/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/26/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' -------------------------------------------------------------------------------- /tests/test_utils/test_data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | # tests/test_utils/data_structures/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 7/10/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 
-------------------------------------------------------------------------------- /aide_predict/bespoke_models/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/bespoke_models/predictors/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/26/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' -------------------------------------------------------------------------------- /tests/test_bespoke_models/test_predictors/__init__.py: -------------------------------------------------------------------------------- 1 | # tests/test_bespoke_models/test_predictors/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/26/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: aidep 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - bioconda 6 | dependencies: 7 | - python<=3.11 8 | - hmmer=3.3.2 9 | - mmseqs2 10 | - plmc 11 | - cffi 12 | - mafft 13 | - pip 14 | - foldseek 15 | - pip: 16 | - numpy<2.0 17 | - scikit-learn 18 | - pandas 19 | - h5py 20 | - tqdm 21 | - biopython==1.84 22 | 23 | 24 | -------------------------------------------------------------------------------- /aide_predict/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 5/7/2024 5 | * (c) Copyright by Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | ''' 7 | from .patches_ import * 8 | from .bespoke_models import * 9 | from .utils.data_structures import * 10 | from .utils.checks import check_model_compatibility, 
get_supported_tools 11 | -------------------------------------------------------------------------------- /aide_predict/utils/data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/utils/data_structures/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 7/10/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | from .sequences import ProteinCharacter, ProteinSequence, ProteinSequences, ProteinSequencesOnFile 9 | from .structures import ProteinStructure, StructureMapper -------------------------------------------------------------------------------- /tests/data/two_sequences.fa: -------------------------------------------------------------------------------- 1 | >PET1 2 | MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGA 3 | IAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWS 4 | MGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCA 5 | NSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEHHHHHH 6 | >PET2 7 | ANPYERGPNPTDALLEARSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGFV 8 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -v --cov=aide_predict --cov-report=xml --cov-config=.coveragerc 3 | testpaths = tests/ 4 | python_files = test_*.py 5 | python_classes = Test* 6 | python_functions = test_* 7 | 8 | # Exclude specific test directories 9 | norecursedirs = tests/test_not_base_models 10 | 11 | # Custom marks 12 | markers = 13 | slow: marks tests as slow (deselect with '-m "not slow"') 14 | optional: marks tests that require optional dependencies 15 | -------------------------------------------------------------------------------- /docs/aide_predict.io.rst: 
-------------------------------------------------------------------------------- 1 | aide\_predict.io package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | aide\_predict.io.bio\_files module 8 | ---------------------------------- 9 | 10 | .. automodule:: aide_predict.io.bio_files 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: aide_predict.io 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/aide_predict.rst: -------------------------------------------------------------------------------- 1 | aide\_predict package 2 | ===================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | aide_predict.bespoke_models 11 | aide_predict.io 12 | aide_predict.utils 13 | 14 | Submodules 15 | ---------- 16 | 17 | aide\_predict.patches\_ module 18 | ------------------------------ 19 | 20 | .. automodule:: aide_predict.patches_ 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | Module contents 26 | --------------- 27 | 28 | .. automodule:: aide_predict 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/aide_predict.bespoke_models.rst: -------------------------------------------------------------------------------- 1 | aide\_predict.bespoke\_models package 2 | ===================================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | aide_predict.bespoke_models.embedders 11 | aide_predict.bespoke_models.predictors 12 | 13 | Submodules 14 | ---------- 15 | 16 | aide\_predict.bespoke\_models.base module 17 | ----------------------------------------- 18 | 19 | .. automodule:: aide_predict.bespoke_models.base 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: aide_predict.bespoke_models 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /slurm_submit_dvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=standard 3 | #SBATCH --account=proteinml 4 | #SBATCH --job-name=dvc_stage 5 | #SBATCH --output=dvc_stage.out 6 | #SBATCH --nodes=1 7 | #SBATCH --mem=100GB # memory per node 8 | #SBATCH --time=0-12:30 # Max time (DD-HH:MM) 9 | 10 | ############### CARBON TRACKING 11 | # start codecarbon 12 | PID=$(/projects/bpms/ekomp_tmp/software/carbon/start_tracker.sh) 13 | echo main 14 | echo $PID 15 | # Save its PID 16 | 17 | # Define a cleanup function 18 | cleanup() { 19 | echo "Cleaning up..." 
20 | kill -SIGINT $PID 21 | sleep 10 22 | } 23 | # # Set the trap 24 | trap cleanup EXIT 25 | #################### END CARBON TRACKING 26 | dvc repro -s $1 --force 27 | -------------------------------------------------------------------------------- /aide_predict/utils/constants.py: -------------------------------------------------------------------------------- 1 | # aide_predict/utils/constants.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/11/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | 9 | 10 | AA_MAP = { 11 | 'A': 'ALA', 12 | 'R': 'ARG', 13 | 'N': 'ASN', 14 | 'D': 'ASP', 15 | 'C': 'CYS', 16 | 'Q': 'GLN', 17 | 'E': 'GLU', 18 | 'G': 'GLY', 19 | 'H': 'HIS', 20 | 'I': 'ILE', 21 | 'L': 'LEU', 22 | 'K': 'LYS', 23 | 'M': 'MET', 24 | 'F': 'PHE', 25 | 'P': 'PRO', 26 | 'S': 'SER', 27 | 'T': 'THR', 28 | 'W': 'TRP', 29 | 'Y': 'TYR', 30 | 'V': 'VAL', 31 | } 32 | AA_SINGLE = set(AA_MAP.keys()) 33 | GAP_CHARACTERS = set(['-', '.']) 34 | NON_CONONICAL_AA_SINGLE = set(['B', 'Z', 'X', 'J', 'O', 'U']) 35 | -------------------------------------------------------------------------------- /docs/aide_predict.utils.data_structures.rst: -------------------------------------------------------------------------------- 1 | aide\_predict.utils.data\_structures package 2 | ============================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | aide\_predict.utils.data\_structures.sequences module 8 | ----------------------------------------------------- 9 | 10 | .. automodule:: aide_predict.utils.data_structures.sequences 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | aide\_predict.utils.data\_structures.structures module 16 | ------------------------------------------------------ 17 | 18 | .. 
automodule:: aide_predict.utils.data_structures.structures 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: aide_predict.utils.data_structures 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: documentation 2 | 3 | on: [push, pull_request, workflow_dispatch] 4 | 5 | permissions: 6 | contents: write 7 | pages: write # Add this 8 | id-token: write # Add this 9 | 10 | jobs: 11 | docs: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-python@v5 16 | 17 | - name: Install dependencies 18 | run: | 19 | pip install sphinx sphinx_rtd_theme myst_parser 20 | 21 | - name: Sphinx build 22 | run: | 23 | sphinx-build -b html docs _build/html 24 | touch _build/html/.nojekyll 25 | 26 | - name: Deploy to GitHub Pages 27 | uses: peaceiris/actions-gh-pages@v3 28 | with: 29 | github_token: ${{ secrets.GITHUB_TOKEN }} 30 | publish_dir: _build/html 31 | force_orphan: true 32 | clean: true 33 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = aide_predict 3 | 4 | [report] 5 | exclude_lines = 6 | pragma: no cover 7 | def __repr__ 8 | if self.debug: 9 | if __name__ == .__main__.: 10 | raise NotImplementedError 11 | pass 12 | except ImportError: 13 | def _transform 14 | def _fit 15 | def _partial_fit 16 | 17 | omit = 18 | aide_predict/bespoke_models/predictors/esm2.py 19 | aide_predict/bespoke_models/predictors/msa_transformer.py 20 | aide_predict/bespoke_models/predictors/vespa.py 21 | aide_predict/bespoke_models/predictors/eve.py 22 | aide_predict/bespoke_models/predictors/ssemb.py 23 | 24 | 
aide_predict/bespoke_models/embedders/esm2.py 25 | aide_predict/bespoke_models/embedders/msa_transformer.py 26 | aide_predict/bespoke_models/embedders/saprot.py 27 | aide_predict/bespoke_models/predictors/saprot.py 28 | aide_predict/bespoke_models/embedders/ssemb.py 29 | 30 | aide_predict/utils/badass.py 31 | aide_predict/utils/soloseq.py 32 | aide_predict/utils/mmseqs_msa_search.py 33 | -------------------------------------------------------------------------------- /showcase/README.md: -------------------------------------------------------------------------------- 1 | # A couple of tangible use cases for the porject. 2 | 3 | For a brief and wide demo of what the package can do, see `demo` 4 | 5 | ## 1. Benchmark unsupervised, supervised, and combination models on a 4 site epistatic combinatorial library 6 | Paper for the data: https://www.biorxiv.org/content/10.1101/2024.06.23.600144v1 7 | 8 | Conducted in `epistatic_benchmarking.ipynb`. In this notebook, the dataset is tested in cross validation for: 9 | 10 | 1. Unsupervised models: MSATransformer, ESM2, EVMutation 11 | 2. Supervised models: One-hot, ESM2 embeddings, into Linear model or MLP. Random search conducted for hyperparameter optimization. 12 | 13 | 14 | ## 2. Creating a WT sequence PETase acitivity prediction model 15 | 16 | Data from our recent paper: [] 17 | 18 | Conducted in `wt_petase_model_creation.ipynb`. In this notebook, we try a number of embedding and modeling strategies with extensive hyperparameter optimization to predict PETase activity at low pH on crystaline powder. The final model is dumped and can be opened to make predictions. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EvanKomp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | AIDE: Protein Property Prediction Made Simple 2 | ============================================= 3 | 4 | .. include:: ./header.md 5 | :parser: myst_parser.sphinx_ 6 | 7 | User Guide 8 | --------- 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: Getting Started 13 | 14 | user_guide/installation.md 15 | user_guide/api_examples.md 16 | 17 | .. 
toctree:: 18 | :maxdepth: 2 19 | :caption: User Guide 20 | 21 | user_guide/data_structures.md 22 | user_guide/model_compatibility.md 23 | user_guide/protein_model.md 24 | user_guide/zero_shot.md 25 | user_guide/supervised.md 26 | user_guide/saturation_mutagenesis.md 27 | user_guide/pipelines.md 28 | user_guide/caching.md 29 | user_guide/position_specific.md 30 | user_guide/contributing_models.md 31 | user_guide/structure_pred.md 32 | user_guide/msa_search.md 33 | user_guide/badass.md 34 | user_guide/roadmap.md 35 | user_guide/resource_test.md 36 | 37 | .. toctree:: 38 | :maxdepth: 2 39 | :caption: API Reference 40 | 41 | modules 42 | 43 | Indices and tables 44 | ================== 45 | 46 | * :ref:`genindex` 47 | * :ref:`modindex` 48 | * :ref:`search` -------------------------------------------------------------------------------- /tests/test_bespoke_models/test_predictors/test_hmm.py: -------------------------------------------------------------------------------- 1 | # tests/test_bespoke_models/test_predictors/test_hmm.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/26/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | import os 9 | import pytest 10 | import pandas as pd 11 | 12 | from aide_predict.bespoke_models.predictors.hmm import HMMWrapper 13 | from aide_predict.utils.data_structures import ProteinSequences 14 | 15 | from sklearn.metrics import roc_auc_score 16 | 17 | def test_hmm(): 18 | """Score HMM 17 on p740 data, should get 0.52 auroc""" 19 | 20 | sequences = ProteinSequences.from_fasta(os.path.join('tests', 'data', 'hmm-17.fa')) 21 | 22 | test_data = pd.read_csv(os.path.join('tests', 'data', 'p740_labels.csv')) 23 | labels = test_data['target'].values > 1e-5 # testing for nonzero activity 24 | seq_dict = test_data.set_index('id')['sequence'].to_dict() 25 | test_sequences = ProteinSequences.from_dict(seq_dict) 26 | 27 | model = 
HMMWrapper(metadata_folder='./tmp/hmm', threshold=0.0) 28 | model.fit(sequences) 29 | predictions = model.predict(test_sequences) 30 | 31 | auroc = roc_auc_score(labels, predictions) 32 | assert abs(0.52 - auroc) < 0.03 33 | 34 | if __name__ == "__main__": 35 | test_hmm() -------------------------------------------------------------------------------- /aide_predict/utils/common.py: -------------------------------------------------------------------------------- 1 | # aide_predict/utils/common.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 6/11/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | 8 | Common utility functions 9 | 10 | ''' 11 | from types import SimpleNamespace 12 | 13 | 14 | def convert_dvc_params(dvc_params_dict: dict): 15 | """DVC Creates a nested dict with the parameters. 16 | 17 | We want an object that has nested attributes so that we can 18 | access parameters with dot notation. 19 | """ 20 | def _dict_to_obj(d): 21 | if isinstance(d, dict): 22 | return SimpleNamespace(**{k: _dict_to_obj(v) for k, v in d.items()}) 23 | elif isinstance(d, list): 24 | return [_dict_to_obj(x) for x in d] 25 | else: 26 | return d 27 | return _dict_to_obj(dvc_params_dict) 28 | 29 | class MessageBool: 30 | def __init__(self, value, message): 31 | self.message = message 32 | self.value = value 33 | 34 | def __bool__(self): 35 | return self.value 36 | 37 | def wrap(text, width=80): 38 | """ 39 | Wraps a string at a fixed width. 
40 | 41 | Arguments 42 | --------- 43 | text : str 44 | Text to be wrapped 45 | width : int 46 | Line width 47 | 48 | Returns 49 | ------- 50 | str 51 | Wrapped string 52 | """ 53 | return "\n".join( 54 | [text[i:i + width] for i in range(0, len(text), width)] 55 | ) 56 | 57 | -------------------------------------------------------------------------------- /aide_predict/bespoke_models/__init__.py: -------------------------------------------------------------------------------- 1 | # aide_predict/bespoke_models/__init__.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 5/7/2024 5 | * (c) Copyright by Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | ''' 7 | from .predictors.hmm import HMMWrapper 8 | from .predictors.esm2 import ESM2LikelihoodWrapper 9 | from .predictors.msa_transformer import MSATransformerLikelihoodWrapper 10 | from .predictors.evmutation import EVMutationWrapper 11 | from .predictors.pretrained_transformers import model_device_context 12 | from .predictors.saprot import SaProtLikelihoodWrapper 13 | from .predictors.vespa import VESPAWrapper 14 | from .predictors.eve import EVEWrapper 15 | from .predictors.ssemb import SSEmbWrapper 16 | 17 | from .embedders.esm2 import ESM2Embedding 18 | from .embedders.ohe import OneHotAlignedEmbedding, OneHotProteinEmbedding 19 | from .embedders.msa_transformer import MSATransformerEmbedding 20 | from .embedders.saprot import SaProtEmbedding 21 | from .embedders.kmer import KmerEmbedding 22 | from .embedders.ssemb import SSEmbEmbedding 23 | 24 | 25 | TOOLS = [ 26 | HMMWrapper, 27 | ESM2LikelihoodWrapper, 28 | MSATransformerLikelihoodWrapper, 29 | EVMutationWrapper, 30 | SaProtLikelihoodWrapper, 31 | VESPAWrapper, 32 | EVEWrapper, 33 | SSEmbWrapper, 34 | 35 | # embedders 36 | ESM2Embedding, 37 | OneHotAlignedEmbedding, 38 | OneHotProteinEmbedding, 39 | MSATransformerEmbedding, 40 | SaProtEmbedding, 41 | KmerEmbedding, 42 | SSEmbEmbedding, 43 | ] 44 | 
-------------------------------------------------------------------------------- /docs/aide_predict.bespoke_models.embedders.rst: -------------------------------------------------------------------------------- 1 | aide\_predict.bespoke\_models.embedders package 2 | =============================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | aide\_predict.bespoke\_models.embedders.esm2 module 8 | --------------------------------------------------- 9 | 10 | .. automodule:: aide_predict.bespoke_models.embedders.esm2 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | aide\_predict.bespoke\_models.embedders.kmer module 16 | --------------------------------------------------- 17 | 18 | .. automodule:: aide_predict.bespoke_models.embedders.kmer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | aide\_predict.bespoke\_models.embedders.msa\_transformer module 24 | --------------------------------------------------------------- 25 | 26 | .. automodule:: aide_predict.bespoke_models.embedders.msa_transformer 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | aide\_predict.bespoke\_models.embedders.ohe module 32 | -------------------------------------------------- 33 | 34 | .. automodule:: aide_predict.bespoke_models.embedders.ohe 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | aide\_predict.bespoke\_models.embedders.saprot module 40 | ----------------------------------------------------- 41 | 42 | .. automodule:: aide_predict.bespoke_models.embedders.saprot 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. 
def test_evcouplings_zero_shot():
    """Fit EVMutation on the ENVZ_ECOLI test MSA and score the DMS assay.

    The standard-protocol model is expected to reach a Spearman
    correlation of roughly 0.1 against the published DMS scores.
    """
    # Assay measurements: mutated sequences paired with their DMS scores.
    assay_data = pd.read_csv(
        os.path.join('tests', 'data', 'ENVZ_ECOLI_Ghose_2023.csv'))
    variants = ProteinSequences.from_list(assay_data['mutated_sequence'].tolist())
    dms_scores = assay_data['DMS_score'].tolist()

    # The wild type is taken from the first record of the test MSA file.
    msa_file = os.path.join('tests', 'data', 'ENVZ_ECOLI_extreme_filtered.a2m')
    wt = ProteinSequence.from_fasta(msa_file)

    # Standard EVCouplings protocol with a reduced iteration budget for speed.
    model = EVMutationWrapper(
        metadata_folder='./tmp/evcouplings',
        wt=wt,
        protocol="standard",
        theta=0.8,
        iterations=100,
    )
    model.fit()
    print('EVCouplings model fitted!')

    predictions = model.predict(variants)
    spearman = spearmanr(dms_scores, predictions)[0]
    print(f"EVCouplings Spearman (standard): {spearman}")
    assert abs(spearman - 0.1) < 0.05
/docs/user_guide/pipelines.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Combining models into scikit learn pipelines 3 | --- 4 | 5 | # Building ML Pipelines 6 | 7 | AIDE models can be combined with standard scikit-learn components into pipelines. Here's an example that combines one-hot encoding and ESM2 ZS predictions with a random forest: 8 | 9 | ```python 10 | from aide_predict import OneHotProteinEmbedding, ESM2LikelihoodWrapper, ProteinSequence, ProteinSequences 11 | from sklearn.pipeline import Pipeline, FeatureUnion 12 | from sklearn.preprocessing import StandardScaler, FunctionTransformer 13 | from sklearn.ensemble import RandomForestRegressor 14 | 15 | # Load data 16 | sequences = ProteinSequences.from_fasta("sequences.fasta") 17 | y = np.load("activity_values.npy") 18 | 19 | # Create wild type reference 20 | wt = sequences["wild_type"] 21 | 22 | # Create feature union that combines raw OHE with scaled ESM2 scores 23 | features = FeatureUnion([ 24 | # One-hot encoding (keep as binary) 25 | ('ohe', OneHotProteinEmbedding(flatten=True)), 26 | 27 | # ESM2 features (apply scaling) 28 | ('esm2', Pipeline([ 29 | ('predictor', ESM2LikelihoodWrapper(wt=wt, marginal_method="masked_marginal")), 30 | ('reshaper', FunctionTransformer(lambda x: x.reshape(-1, 1))), 31 | ('scaler', StandardScaler()) 32 | ])) 33 | ]) 34 | 35 | # Create and train pipeline 36 | pipeline = Pipeline([ 37 | ('features', features), 38 | ('rf', RandomForestRegressor()) 39 | ]) 40 | 41 | pipeline.fit(sequences, y) 42 | predictions = pipeline.predict(sequences) 43 | ``` 44 | 45 | The pipeline can be saved and loaded like any scikit-learn model: 46 | 47 | ```python 48 | from joblib import dump, load 49 | dump(pipeline, 'protein_model.joblib') 50 | ``` 51 | 52 | All standard scikit-learn tools like `GridSearchCV` or `cross_val_score` can be used with these pipelines. 
53 | -------------------------------------------------------------------------------- /docs/user_guide/roadmap.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Roadmap 3 | --- 4 | 5 | # Roadmap 6 | 7 | ## Additional predictors from maintainers 8 | __Quarter 2, 2026__ 9 | 10 | We have the following predictors planned to be wrapped: 11 | 12 | 1. ProteinMPNN (as a zero shot predictor) 13 | - RequiresStructureMixin 14 | - Does not require wildtype, e.g. it can do mutant marginal on variable length structures. 15 | - Type 3 dependencies (sub environment) 16 | 2. NOMELT (as a zero shot predictor) 17 | - Temperature specific scores 18 | - Type 3 dependencies 19 | 3. ESM3 (Zero shot prediction and embedder) 20 | - Structure aware if available 21 | - For now, the annotation data mode will not be considered. This will require an additional attribute of the ProteinSequence class (in addition to the sequence, structure, and MSA data types already supported). 22 | 23 | ## Contributions from the community 24 | While we will maintain the codebase and address bugs identified by the community, the usefulness of the tool will ultimately require contributions from the community _a la_ huggingface, scikit-learn, etc. When/if we add two community contributed models we will start developing the next major version (v2). 25 | 26 | ## Major update v2 27 | __Undetermined__ 28 | 29 | The component specification and software engineering exercise conducted in AIDE for different types of predictors would also be helpful for the variable and disparate __generator__ methods available, e.g. methods for producing new sequences. These broadly categorize into: 30 | - Unconditional generators 31 | - Conditional generators (e.g. infilling, homolog aware, structure or other property conditioning) 32 | - ProteinMPNN 33 | - NOMELT 34 | - Tranception 35 | - Score optimizers (black box or maybe gradient aware), e.g.
# CI workflow: run the fast (non-slow, non-optional) pytest suite on a conda
# environment built from environment.yaml, and upload coverage to Codecov.
name: Python Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.9]

    steps:
    - uses: actions/checkout@v2

    # Miniconda with bioconda available; the environment is populated later.
    - name: Set up conda ${{ matrix.python-version }}
      uses: conda-incubator/setup-miniconda@v2
      with:
        miniconda-version: "latest"
        auto-update-conda: true
        python-version: ${{ matrix.python-version }}
        channels: conda-forge,defaults,bioconda
        channel-priority: flexible
        activate-environment: test_env

    # Use the libmamba solver + mamba for faster dependency resolution.
    # pytest/pytest-cov are installed via both mamba and pip as a belt-and-
    # suspenders measure so the test runner is guaranteed to be present.
    - name: Install dependencies
      shell: bash -l {0}
      run: |
        conda install -n base conda-libmamba-solver
        conda config --set solver libmamba
        conda install mamba -n base -c conda-forge
        mamba env update --file environment.yaml --name test_env
        conda activate test_env
        echo "Current Python: $(which python)"
        mamba install pytest pytest-cov
        pip install pytest pytest-cov
        echo "Installed packages:"
        conda list

    # Slow/optional markers (heavyweight model downloads) are excluded in CI.
    - name: Run tests with pytest and coverage
      shell: bash -l {0}
      run: |
        conda activate test_env
        echo "Python being used: $(which python)"
        echo "Pytest version: $(pytest --version)"
        pytest -v -m "not slow and not optional" --cov=aide_predict --cov-report=xml --cov-config=.coveragerc

    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v1
      env:
        CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
      with:
        file: ./coverage.xml
        flags: unittests
        fail_ci_if_error: false
# DVC pipeline: build an MSA for the wild-type protein (run_msa), then clean
# and weight it (process_msa). Each stage reruns only when its params/deps change.
stages:
  run_msa:
    cmd: python scripts/run_msa.py
    params:
      - use_msa
      - msa_creation.msa_mode

      # if msa_mode is starting_sequences
      - msa_creation.starting_sequences.prealigned
      - msa_creation.starting_sequences.add_training_sequences
      - msa_creation.starting_sequences.activity_targets
      - msa_creation.starting_sequences.activity_threshold

      # if msa_mode is jackhmmer
      - msa_creation.jackhmmer.seqdb
      - msa_creation.jackhmmer.iterations
      - msa_creation.jackhmmer.domain_threshold
      - msa_creation.jackhmmer.sequence_threshold
      - msa_creation.jackhmmer.use_bitscores
      - msa_creation.jackhmmer.sequence_identity_filter
      - msa_creation.jackhmmer.theta
      - msa_creation.jackhmmer.minimum_column_coverage
      - msa_creation.jackhmmer.minimum_sequence_coverage
      - msa_creation.jackhmmer.cpus
      - msa_creation.jackhmmer.mx
    deps:
      - data/starting_sequences.fa
      - data/starting_sequences.a2m
      - data/experimental_data.csv
      - data/wt.fa
      - scripts/run_msa.py
    outs:
      - data/run_msa/
    metrics:
      # cache: false keeps the small metrics JSON in git, not the DVC cache.
      - data/metrics/run_msa.json:
          cache: false
  process_msa:
    cmd: python scripts/process_msa.py
    params:
      - msaprocessing.theta
      - msaprocessing.use_weights
      - msaprocessing.preprocess
      - msaprocessing.threshold_sequence_frac_gaps
      - msaprocessing.threshold_focus_cols_frac_gaps
      - msaprocessing.remove_sequences_with_indeterminate_AA_in_focus_cols
      - msaprocessing.additional_weights
    deps:
      - data/run_msa/
      - scripts/process_msa.py
      - data/wt.fa
    outs:
      - data/process_msa/
    metrics:
      - data/metrics/process_msa.json:
          cache: false
/external_calls/eve/_default_model_params.json: -------------------------------------------------------------------------------- 1 | { "encoder_parameters": { 2 | "hidden_layers_sizes" : [2000,1000,300], 3 | "z_dim" : 50, 4 | "convolve_input" : false, 5 | "convolution_input_depth" : 40, 6 | "nonlinear_activation" : "relu", 7 | "dropout_proba" : 0.0 8 | }, 9 | "decoder_parameters": { 10 | "hidden_layers_sizes" : [300,1000,2000], 11 | "z_dim" : 50, 12 | "bayesian_decoder" : true, 13 | "first_hidden_nonlinearity" : "relu", 14 | "last_hidden_nonlinearity" : "relu", 15 | "dropout_proba" : 0.1, 16 | "convolve_output" : true, 17 | "convolution_output_depth" : 40, 18 | "include_temperature_scaler" : true, 19 | "include_sparsity" : false, 20 | "num_tiles_sparsity" : 0, 21 | "logit_sparsity_p" : 0 22 | }, 23 | "training_parameters": { 24 | "num_training_steps" : 400000, 25 | "learning_rate" : 1e-4, 26 | "batch_size" : 256, 27 | "annealing_warm_up" : 0, 28 | "kl_latent_scale" : 1.0, 29 | "kl_global_params_scale" : 1.0, 30 | "l2_regularization" : 0.0, 31 | "use_lr_scheduler" : false, 32 | "use_validation_set" : false, 33 | "validation_set_pct" : 0.2, 34 | "validation_freq" : 1000, 35 | "log_training_info" : true, 36 | "log_training_freq" : 1000, 37 | "save_model_params_freq" : 500000 38 | } 39 | } 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /docs/user_guide/saturation_mutagenesis.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: In silico Saturation Mutagenesis 3 | --- 4 | # Saturation Mutagenesis 5 | 6 | ## Overview 7 | 8 | We provide tools to quickly run in silico saturation mutagenesis. 9 | 10 | 11 | Create a `ProteinSequences` object of all single point mutations. 
12 | 13 | ```python 14 | from aide_predict import ProteinSequence, ESM2LikelihoodWrapper 15 | import pandas as pd 16 | 17 | # Define wild type sequence 18 | wt = ProteinSequence( 19 | "MKLLVLGLPGAGKGT", 20 | id="wild_type" 21 | ) 22 | 23 | # Generate all single mutants 24 | mutant_library = wt.saturation_mutagenesis() 25 | print(f"Generated {len(mutant_library)} variants") 26 | >>> Generated 285 variants # (15 positions × 19 possible mutations) 27 | ``` 28 | 29 | Then pass these to a zero shot predictor of your choice: 30 | 31 | ```python 32 | # Score variants using a zero-shot predictor 33 | model = ESM2LikelihoodWrapper( 34 | wt=wt, 35 | marginal_method="masked_marginal", 36 | pool=True # Get one score per variant 37 | ) 38 | model.fit() # No training needed 39 | scores = model.predict(mutant_library) 40 | 41 | # Create results dataframe 42 | results = pd.DataFrame({ 43 | 'mutation': mutant_library.ids, # e.g., "M1A", "K2R", etc. 44 | 'sequence': mutant_library, 45 | 'prediction': scores 46 | }) 47 | 48 | # Sort by predicted effect 49 | results = results.sort_values('prediction', ascending=False) 50 | print("Top 5 predicted beneficial mutations:") 51 | print(results.head()) 52 | ``` 53 | 54 | ## Visualizing Results 55 | 56 | AIDE provides built-in visualization tools for mutation effects: 57 | 58 | ```python 59 | from aide_predict.utils.plotting import plot_mutation_heatmap 60 | 61 | # Create heatmap of mutation effects 62 | plot_mutation_heatmap(results['mutation'], results['prediction']) 63 | ``` 64 | 65 | The heatmap shows the predicted effect of each possible amino acid substitution at each position, making it easy to identify patterns and hotspots for engineering. 
66 | 67 | ## Notes 68 | - The `mutation` IDs follow standard notation: "M1A" means the M at position 1 was mutated to A -------------------------------------------------------------------------------- /docs/user_guide/resource_test.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Resource benchmarking 3 | --- 4 | # Resource testing 5 | 6 | See below for the cost it takes to run each tool. This test was run with a Dual socket Intel Xeon Sapphire Rapids 52 core CPU. When the model supports GPU, one NVIDIA H100 was provided. 7 | 8 | The test system was a GFP (238) amino acids, MSA depth (when applicable) was 201. Times measure the total time to fit the model (when applicable) and run prediction on 50 variants. Missing values are either because the core model does not support that type of prediction or because AIDE's wrapper does not support it. 9 | 10 | NOTE: The cost of some of these (*) is significantly impacted by hyperparameters. 11 | 12 | ## Zero shot predictors 13 | 14 | | Model Name | Marginal Method | GPU Total Time (s) | CPU Total Time (s) | 15 | |------------|----------------|-------------------|-------------------| 16 | | HMMWrapper | - | - | 0.136 | 17 | | ESM2LikelihoodWrapper | wildtype_marginal | 0.980 | 2.560 | 18 | | ESM2LikelihoodWrapper | mutant_marginal | 0.534 | 30.837 | 19 | | ESM2LikelihoodWrapper | masked_marginal | 0.718 | 62.507 | 20 | | MSATransformerLikelihoodWrapper | wildtype_marginal | 4.067 | 33.974 | 21 | | MSATransformerLikelihoodWrapper | mutant_marginal | 57.297 | Timeout (>1800s) | 22 | | MSATransformerLikelihoodWrapper | masked_marginal | 110.086 | Timeout (>1800s) | 23 | | EVMutationWrapper | - | - | 96.697 | 24 | | SaProtLikelihoodWrapper | wildtype_marginal | 5.356 | 24.291 | 25 | | SaProtLikelihoodWrapper | mutant_marginal | 7.326 | 220.906 | 26 | | SaProtLikelihoodWrapper | masked_marginal | 14.814 | 429.626 | 27 | | VESPAWrapper | - | 244.852 | - | 28 | | EVEWrapper * | 
- | 925.930 | - | 29 | | SSEmbWrapper | - | 192.999 | - | 30 | 31 | ## Embedders 32 | 33 | Cost for embedding 21 GFP sequences. 34 | 35 | | Model Name | GPU Total Time (s) | CPU Total Time (s) | 36 | |------------|-------------------|-------------------| 37 | | ESM2Embedding | 0.887 | 1.477 | 38 | | OneHotAlignedEmbedding | - | 0.092 | 39 | | OneHotProteinEmbedding | - | 0.023 | 40 | | MSATransformerEmbedding | 18.962 | 62.653 | 41 | | SaProtEmbedding | 10.439 | 32.360 | 42 | | KmerEmbedding | - | 0.005 | 43 | | SSEmbEmbedding | 665.772 | - | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/user_guide/caching.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Caching expensive model calls 3 | --- 4 | 5 | # Caching Model Outputs 6 | 7 | ## Overview 8 | 9 | Some AIDE models support caching their outputs to disk to avoid recomputing expensive transformations. This is made available with the `CacheMixin` class, which is inherited by models that support caching. You can check if a model supports caching by checking if it inherits from `CacheMixin`: 10 | 11 | ```python 12 | from aide_predict.bespoke_models.base import CacheMixin 13 | assert isinstance(model, CacheMixin) # True if model supports caching 14 | ``` 15 | 16 | ## Using Caches 17 | 18 | Caching is enabled by default for models that support it. 
To explicitly control caching: 19 | 20 | ```python 21 | from aide_predict import ESM2Embedding 22 | 23 | # Disable caching 24 | model = ESM2Embedding(use_cache=False) 25 | 26 | # Enable caching (default) 27 | model = ESM2Embedding(use_cache=True) 28 | ``` 29 | 30 | ## How It Works 31 | 32 | - Each protein sequence gets a unique hash based on its sequence, ID, and structure (if present) 33 | - Outputs are stored in HDF5 format for efficient retrieval 34 | - Cache also hashes the model parameters, so if model parameters change it will not use previous cache values 35 | - Stores metadata in SQLite for quick cache checking 36 | - Caches are stored in the model's metadata folder 37 | 38 | ## Models Supporting Caching 39 | 40 | You can check if a model supports caching by checking if it inherits from `CacheMixin`: 41 | ```python 42 | from aide_predict.bespoke_models.base import CacheMixin 43 | 44 | isinstance(model, CacheMixin) # True if model supports caching 45 | ``` 46 | 47 | NOTE: When wrapping a new model, it is recommended that `CacheMixin` be inherited first behind `ProteinModelWrapper`. This ensures that the final model outputs after any processing conducted by other mixins is what get cached, preventing any unnecessary recomputation. 
@pytest.mark.optional
def test_vespa_zero_shot():
    """Zero-shot DMS prediction with VESPA on ENVZ_ECOLI_Ghose_2023.

    Both VESPAl (light=True) and full VESPA (light=False) are expected to
    reach a Spearman correlation of about 0.135 against the published DMS
    scores, per the literature values noted in the module docstring.
    """
    # this model requires no MSAs
    from aide_predict.bespoke_models.predictors.vespa import VESPAWrapper

    wt_sequence = "LADDRTLLMAGVSHDLRTPLTRIRLATEMMSEQDGYLAESINKDIEECNAIIEQFIDYLR"
    wt = ProteinSequence(wt_sequence, id="ENVZ_ECOLI")

    assay_data = pd.read_csv(
        os.path.join('tests', 'data', 'ENVZ_ECOLI_Ghose_2023.csv'))
    sequences = ProteinSequences.from_list(assay_data['mutated_sequence'].tolist())
    scores = assay_data['DMS_score'].tolist()

    def _fit_and_score(light: bool) -> float:
        # Fit one VESPA variant (fit initializes the predictor) and return
        # its Spearman correlation against the assay scores. Deduplicates
        # the previously copy-pasted light/full code paths.
        model = VESPAWrapper(
            wt=wt,
            light=light,
            metadata_folder='./tmp/vespa',
        )
        model.fit(sequences)
        predictions = model.predict(sequences)
        return spearmanr(scores, predictions)[0]

    # VESPAl (light) model
    spearman = _fit_and_score(light=True)
    print(f"VESPA Spearman: {spearman}")
    assert abs(spearman - 0.135) < 0.03

    # Full VESPA model
    spearman = _fit_and_score(light=False)
    print(f"Full VESPA Spearman: {spearman}")
    assert abs(spearman - 0.135) < 0.03
automodule:: aide_predict.utils.common 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | aide\_predict.utils.conservation module 48 | --------------------------------------- 49 | 50 | .. automodule:: aide_predict.utils.conservation 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | aide\_predict.utils.constants module 56 | ------------------------------------ 57 | 58 | .. automodule:: aide_predict.utils.constants 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | aide\_predict.utils.mmseqs\_msa\_search module 64 | ---------------------------------------------- 65 | 66 | .. automodule:: aide_predict.utils.mmseqs_msa_search 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | aide\_predict.utils.msa module 72 | ------------------------------ 73 | 74 | .. automodule:: aide_predict.utils.msa 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | aide\_predict.utils.plotting module 80 | ----------------------------------- 81 | 82 | .. automodule:: aide_predict.utils.plotting 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | aide\_predict.utils.soloseq module 88 | ---------------------------------- 89 | 90 | .. automodule:: aide_predict.utils.soloseq 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | Module contents 96 | --------------- 97 | 98 | .. automodule:: aide_predict.utils 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | -------------------------------------------------------------------------------- /scripts/process_msa.py: -------------------------------------------------------------------------------- 1 | # scripts/prepare_msa.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 5/9/2024 5 | * (c) Copyright by Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | 7 | Cleanup and compute MSA weights from MSA built on WT protein. 
def main():
    """Process and weight the MSA produced by the run_msa stage.

    Reads the ``msaprocessing`` section of the DVC params (module-level
    ``PARAMS``), cleans/weights ``data/run_msa/alignment.a2m`` focused on
    the wild-type sequence from ``data/wt.fa``, and writes the processed
    alignment, weights, and a metrics JSON.

    Raises
    ------
    ValueError
        If ``data/wt.fa`` contains no sequences, since a focus sequence ID
        is required to normalize Neff.
    """
    # first prepare directories (exist_ok avoids the check-then-create race
    # and also covers data/metrics, which was previously assumed to exist)
    EXECDIR = os.getcwd()
    out_dir = os.path.join(EXECDIR, 'data', 'process_msa')
    metrics_dir = os.path.join(EXECDIR, 'data', 'metrics')
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    # prepare the arguments
    args = MSAProcessingArgs(
        theta=PARAMS['theta'],
        use_weights=PARAMS['use_weights'],
        preprocess_MSA=PARAMS['preprocess'],
        threshold_focus_cols_frac_gaps=PARAMS['threshold_focus_cols_frac_gaps'],
        threshold_sequence_frac_gaps=PARAMS['threshold_sequence_frac_gaps'],
        remove_sequences_with_indeterminate_AA_in_focus_cols=PARAMS['remove_sequences_with_indeterminate_AA_in_focus_cols'],
    )
    msa = MSAProcessing(args)

    # get the sequence ID of the wild type; the MSA is focused on it
    with open(os.path.join(EXECDIR, 'data', 'wt.fa'), 'r') as f:
        try:
            iterator = read_fasta(f)
            wt_id, _ = next(iterator)
        except StopIteration:
            wt_id = None
    logger.info('wt_id: %s', wt_id)
    if wt_id is None:
        # Fail fast with a clear message: the Neff normalization below would
        # otherwise die with an opaque KeyError on seq_name_to_sequence[None].
        raise ValueError(
            "data/wt.fa contains no sequences; a wild-type focus sequence is required")

    # process the MSA
    msa.process(
        MSA_location=os.path.join(EXECDIR, 'data', 'run_msa', 'alignment.a2m'),
        weights_location=os.path.join(out_dir, 'weights.npy'),
        focus_seq_id=wt_id,
        additional_weights=None,
        new_a2m_location=os.path.join(out_dir, 'alignment.a2m'),
    )
    metrics = {
        'msa_Neff': msa.Neff,
        'msa_num_seqs': msa.num_sequences,
        # Neff normalized by the focus-sequence length
        'msa_Neff_norm': msa.Neff / len(msa.seq_name_to_sequence[wt_id]),
    }
    with open(os.path.join(metrics_dir, 'process_msa.json'), 'w') as f:
        json.dump(metrics, f)

if __name__ == '__main__':
    main()
automodule:: aide_predict.bespoke_models.predictors.msa_transformer 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | aide\_predict.bespoke\_models.predictors.pretrained\_transformers module 48 | ------------------------------------------------------------------------ 49 | 50 | .. automodule:: aide_predict.bespoke_models.predictors.pretrained_transformers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | aide\_predict.bespoke\_models.predictors.saprot module 56 | ------------------------------------------------------ 57 | 58 | .. automodule:: aide_predict.bespoke_models.predictors.saprot 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | aide\_predict.bespoke\_models.predictors.ssemb module 64 | ----------------------------------------------------- 65 | 66 | .. automodule:: aide_predict.bespoke_models.predictors.ssemb 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | aide\_predict.bespoke\_models.predictors.vespa module 72 | ----------------------------------------------------- 73 | 74 | .. automodule:: aide_predict.bespoke_models.predictors.vespa 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: aide_predict.bespoke_models.predictors 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/user_guide/structure_pred.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Structure Prediction with SoloSeq 3 | --- 4 | 5 | # Structure Prediction with SoloSeq 6 | 7 | We provide a wrapper interface to get protein structure predictions using SoloSeq, a deep learning model for protein structure prediction that requires no MSAs. It is recommended to use crystal structures or run AlphaFold2 for more accurate predictions if your task is deemed very structure sensitive. 
8 | 9 | ## Installation 10 | 11 | SoloSeq requires additional setup beyond the base AIDE installation. 12 | 13 | 1. Follow the setup steps [here](https://openfold.readthedocs.io/en/latest/Installation.html) 14 | 15 | Once the environment is setup and unit tests pass: 16 | 17 | 2. Download the SoloSeq model weights: 18 | ```bash 19 | bash scripts/download_openfold_soloseq_params.sh openfold/resources 20 | ``` 21 | 22 | 3. Set environment variables (add to your `.bashrc` or equivalent): 23 | ```bash 24 | export OPENFOLD_CONDA_ENV=openfold_env # Name of conda environment 25 | export OPENFOLD_REPO=/path/to/openfold # Full path to OpenFold repo 26 | ``` 27 | 28 | ## Basic Usage 29 | 30 | AIDE provides a simplified interface to SoloSeq for predicting protein structures: 31 | 32 | ```python 33 | from aide_predict import ProteinSequences 34 | from aide_predict.utils.soloseq import run_soloseq 35 | 36 | # Load sequences 37 | sequences = ProteinSequences.from_fasta("proteins.fasta") 38 | 39 | # Run prediction 40 | pdb_paths = run_soloseq( 41 | sequences=sequences, 42 | output_dir="./predicted_structures" 43 | ) 44 | 45 | # attach predicted structures to sequence using structure mapper 46 | from aide_predict.utils.data_structures.structures import StructureMapper 47 | mapper = StructureMapper("./predicted_structures") 48 | mapper.assign_structures(sequences) 49 | ``` 50 | 51 | ### Command Line Interface 52 | 53 | You can also run predictions directly from the command line: 54 | 55 | ```bash 56 | python -m aide_predict.utils.soloseq proteins.fasta predicted_structures 57 | ``` 58 | 59 | ## Advanced Options 60 | 61 | The function provides several options to control prediction: 62 | 63 | ```python 64 | pdb_paths = run_soloseq( 65 | sequences=sequences, 66 | output_dir="predicted_structures", 67 | use_gpu=True, # Set to False for CPU-only 68 | skip_relaxation=False, # Skip refinement step 69 | save_embeddings=True, # Keep ESM embeddings 70 | device="cuda:0", # Specific GPU device 
71 | force=False # Force rerun of existing predictions 72 | ) 73 | ``` 74 | 75 | Command line equivalents: 76 | 77 | ```bash 78 | python -m aide_predict.utils.soloseq proteins.fasta predicted_structures \ 79 | --no_gpu \ 80 | --skip_relaxation \ 81 | --save_embeddings \ 82 | --device cuda:1 \ 83 | --force 84 | ``` 85 | -------------------------------------------------------------------------------- /docs/user_guide/supervised.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Supervised training and prediction 3 | --- 4 | # Supervised Learning 5 | 6 | ## Overview 7 | 8 | AIDE supports supervised machine learning by converting protein sequences into numerical features using embedding models. These features can then be used with any scikit-learn compatible model. 9 | 10 | ## Basic Example 11 | 12 | Here's a complete example using ESM2 embeddings and random forest regression with hyperparameter optimization: 13 | 14 | ```python 15 | from aide_predict import ESM2Embedding, ProteinSequences 16 | from sklearn.ensemble import RandomForestRegressor 17 | from sklearn.model_selection import train_test_split, RandomizedSearchCV 18 | from sklearn.pipeline import Pipeline 19 | from scipy.stats import randint, uniform 20 | import numpy as np 21 | 22 | # Load data 23 | sequences = ProteinSequences.from_fasta("sequences.fasta") 24 | y = np.load("activity_values.npy") 25 | 26 | # Split data 27 | X_train, X_test, y_train, y_test = train_test_split( 28 | sequences, y, test_size=0.2, random_state=42 29 | ) 30 | 31 | # Create pipeline 32 | pipeline = Pipeline([ 33 | ('embedder', ESM2Embedding(pool='max', use_cache=True)), # Create sequence-level embeddings 34 | ('rf', RandomForestRegressor(random_state=42)) 35 | ]) 36 | 37 | # Define parameter space 38 | param_distributions = { 39 | 'rf__n_estimators': randint(100, 500), 40 | 'rf__max_depth': [None] + list(range(10, 50, 10)), 41 | 'rf__min_samples_split': randint(2, 20), 42 | 
'rf__min_samples_leaf': randint(1, 10) 43 | } 44 | 45 | # Random search 46 | search = RandomizedSearchCV( 47 | pipeline, 48 | param_distributions=param_distributions, 49 | n_iter=20, # Number of parameter settings sampled 50 | cv=5, # 5-fold cross-validation 51 | n_jobs=-1, # Use all available cores 52 | scoring='r2', 53 | verbose=1 54 | ) 55 | 56 | # Fit model 57 | search.fit(X_train, y_train) 58 | 59 | # Print results 60 | print("\nBest parameters:", search.best_params_) 61 | print("Best CV score:", search.best_score_) 62 | print("Test score:", search.score(X_test, y_test)) 63 | 64 | # Make predictions on new sequences 65 | new_sequences = ProteinSequences.from_fasta("new_sequences.fasta") 66 | predictions = search.predict(new_sequences) 67 | ``` 68 | 69 | ## Saving and loading models 70 | Models can be dumped and loaded with joblib like any other scikit-learn model: 71 | 72 | ```python 73 | import joblib 74 | 75 | # Save the best model 76 | joblib.dump(search.best_estimator_, 'protein_model.joblib') 77 | 78 | # Load the model later 79 | loaded_model = joblib.load('protein_model.joblib') 80 | ``` 81 | 82 | Note that this may currently break the `metadata_folder` attribute of models, unless it is loaded on the same machine in the same location. 83 | In future, protocols to zip up this folder with the model during saving and loading will be provided. 
-------------------------------------------------------------------------------- /tests/test_not_base_models/test_eve.py: -------------------------------------------------------------------------------- 1 | # tests/test_not_base_models/test_eve.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 10/28/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | import os 9 | import sys 10 | import pytest 11 | import pandas as pd 12 | import numpy as np 13 | from scipy.stats import spearmanr 14 | 15 | from aide_predict.utils.data_structures import ProteinSequences, ProteinSequencesOnFile, ProteinSequence 16 | from aide_predict.bespoke_models.predictors.eve import EVEWrapper 17 | 18 | import logging 19 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 20 | 21 | @pytest.mark.skipif( 22 | os.environ.get('EVE_CONDA_ENV') is None or os.environ.get('EVE_REPO') is None, 23 | reason="EVE environment variables not set" 24 | ) 25 | def test_eve_zero_shot(): 26 | """ 27 | Test EVE's performance on the ENVZ_ECOLI benchmark dataset from ProteinGym. 28 | 29 | This test: 30 | 1. Uses a small MSA and minimal training steps for quick testing 31 | 2. Verifies the correlation is in the expected range 32 | 3. 
Checks basic model functionality 33 | """ 34 | # Load the benchmark data 35 | assay_data = pd.read_csv( 36 | os.path.join('tests', 'data', 'ENVZ_ECOLI_Ghose_2023.csv')) 37 | sequences = ProteinSequences.from_list(assay_data['mutated_sequence'].tolist()) 38 | scores = assay_data['DMS_score'].tolist() 39 | 40 | # Load MSA 41 | msa_file = os.path.join('tests', 'data', 'ENVZ_ECOLI_extreme_filtered.a2m') 42 | wt = ProteinSequence.from_fasta(msa_file) 43 | 44 | # Initialize EVE with minimal training parameters for testing 45 | model = EVEWrapper( 46 | metadata_folder='./tmp/eve', 47 | wt=wt, 48 | # Reduce training time for testing 49 | # take default values 50 | training_steps=30000 51 | ) 52 | 53 | print('Fitting EVE model...') 54 | model.fit() 55 | print('EVE model fitted!') 56 | 57 | # ensure briefly that the model is capable of handling multiple mutations 58 | test_sequences = ProteinSequences([ 59 | wt.upper()._mutate(10, 'C')._mutate(11, 'A'), 60 | wt.upper()._mutate(10, 'C')._mutate(11, 'C') 61 | ]) 62 | _ = model.predict(test_sequences) 63 | 64 | print('Making predictions...') 65 | predictions = model.predict(sequences) 66 | spearman = spearmanr(scores, predictions, nan_policy='omit')[0] 67 | print(f"EVE Spearman correlation: {spearman}") 68 | 69 | # The correlation should be in a reasonable range 70 | # Note: This is a minimal model, so we expect lower performance than the full model 71 | assert not np.isnan(spearman), "Correlation should not be NaN" 72 | assert spearman > -1 and spearman < 1, "Correlation should be between -1 and 1" 73 | assert abs(spearman - 0.03) < 0.18, "Correlation should be in expected range" 74 | 75 | if __name__ == '__main__': 76 | test_eve_zero_shot() -------------------------------------------------------------------------------- /tests/test_not_base_models/test_ssemb_pred.py: -------------------------------------------------------------------------------- 1 | # tests/test_not_base_models/test_ssemb.py 2 | ''' 3 | * Author: Evan Komp 4 
| * Created: 4/2/2025 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | import os 9 | import sys 10 | import pytest 11 | import pandas as pd 12 | import numpy as np 13 | from scipy.stats import spearmanr 14 | 15 | from aide_predict.utils.data_structures import ProteinSequences, ProteinSequencesOnFile, ProteinSequence 16 | from aide_predict.utils.data_structures.structures import ProteinStructure 17 | from aide_predict.bespoke_models.predictors.ssemb import SSEmbWrapper 18 | 19 | import logging 20 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 21 | 22 | @pytest.mark.skipif( 23 | os.environ.get('SSEMB_CONDA_ENV') is None or os.environ.get('SSEMB_REPO') is None, 24 | reason="SSEmb environment variables not set" 25 | ) 26 | def test_ssemb_zero_shot(): 27 | """ 28 | Test SSEmb's performance on the ENVZ_ECOLI benchmark dataset from ProteinGym. 29 | 30 | This test: 31 | 1. Uses a small MSA and protein structure for quick testing 32 | 2. Verifies the correlation is in the expected range 33 | 3. 
Checks basic model functionality 34 | """ 35 | # Load the benchmark data 36 | assay_data = pd.read_csv( 37 | os.path.join('tests', 'data', 'GFP_AEQVI_Sarkisyan_2016.csv')) 38 | sequences = ProteinSequences.from_list(assay_data['mutated_sequence'].tolist()) 39 | scores = assay_data['DMS_score'].tolist() 40 | 41 | # Define wild type sequence with structure 42 | pdb_path = os.path.join('tests', 'data', 'GFP_AEQVI_Sarkisyan_2016.pdb') 43 | structure = ProteinStructure(pdb_file=pdb_path) 44 | 45 | # Load MSA 46 | msa_file = os.path.join('tests', 'data', 'GFP_AEQVI_Sarkisyan_2016.a3m') 47 | wt = ProteinSequence.from_fasta(msa_file).upper() 48 | wt.msa = wt.msa.upper() 49 | assert wt.msa.aligned 50 | wt.structure = structure 51 | 52 | # Initialize SSEmb model 53 | model = SSEmbWrapper( 54 | metadata_folder='./tmp/ssemb', 55 | wt=wt, 56 | gpu_id=0 # Use first GPU if available, otherwise CPU will be used 57 | ) 58 | 59 | print('Fitting SSEmb model...') 60 | model.fit() 61 | print('SSEmb model fitted!') 62 | 63 | print('Making predictions...') 64 | predictions = model.predict(sequences) 65 | spearman = spearmanr(scores, predictions, nan_policy='omit')[0] 66 | print(f"SSEmb Spearman correlation: {spearman}") 67 | 68 | # The correlation should be in a reasonable range 69 | assert not np.isnan(spearman), "Correlation should not be NaN" 70 | assert spearman > -1 and spearman < 1, "Correlation should be between -1 and 1" 71 | # Adjust expected correlation range based on SSEmb performance on this dataset 72 | assert abs(spearman - 0.72) < 0.02, "Correlation should be in expected range" 73 | 74 | if __name__ == '__main__': 75 | test_ssemb_zero_shot() -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import os 7 | import sys 8 | sys.path.insert(0, os.path.abspath('..')) 9 | 10 | # -- Project information ----------------------------------------------------- 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 12 | 13 | project = 'aide' 14 | copyright = '2025, Gregg T. Beckham' 15 | author = 'Evan Komp, Gregg T. Beckham' 16 | release = '1.1.01' 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.todo', 'myst_parser'] 22 | 23 | templates_path = ['_templates'] 24 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 25 | 26 | # Add autodoc options to exclude private members 27 | autodoc_default_options = { 28 | 'members': True, 29 | 'undoc-members': False, # Don't include members without docstrings 30 | 'private-members': False, # Don't include _private members 31 | 'special-members': False, # Don't include __special__ members 32 | 'inherited-members': True, 33 | 'show-inheritance': True, 34 | 'exclude-members': '_abc_impl' # Exclude specific members if needed 35 | } 36 | 37 | # Skip all _private members 38 | def skip_private_members(app, what, name, obj, skip, options): 39 | if name.startswith('_'): 40 | return True 41 | return skip 42 | 43 | # Connect the function to the autodoc-skip-member event 44 | def setup(app): 45 | app.connect('autodoc-skip-member', skip_private_members) 46 | 47 | 48 | napoleon_google_docstrings = True 49 | napoleon_numpy_docstrings = True 50 | napoleon_include_init_with_doc = True 51 | napoleon_include_private_with_doc = False 52 | napoleon_include_special_with_doc = False 53 | 
napoleon_use_admonition_for_examples = True 54 | napoleon_use_admonition_for_notes = True 55 | napoleon_use_admonition_for_references = True 56 | napoleon_use_ivar = True 57 | napoleon_use_param = True 58 | napoleon_use_rtype = True 59 | 60 | # Add these configurations 61 | source_suffix = { 62 | '.rst': 'restructuredtext', 63 | '.md': 'markdown', 64 | } 65 | 66 | # Myst settings 67 | myst_enable_extensions = [ 68 | "colon_fence", 69 | "deflist", 70 | "dollarmath", 71 | "amsmath", 72 | "fieldlist", 73 | "html_admonition", 74 | "html_image", 75 | "replacements", 76 | "smartquotes", 77 | "substitution", 78 | "tasklist", 79 | ] 80 | # Enable math rendering and processing 81 | mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" 82 | myst_dmath_double_inline = True 83 | myst_update_mathjax = True 84 | mathjax3_config = { 85 | "tex": { 86 | "inlineMath": [["\\(", "\\)"]], 87 | "displayMath": [["\\[", "\\]"]], 88 | }, 89 | } 90 | 91 | # -- Options for HTML output ------------------------------------------------- 92 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 93 | 94 | html_theme = 'sphinx_rtd_theme' 95 | html_static_path = ['_static'] 96 | -------------------------------------------------------------------------------- /docs/user_guide/msa_search.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Generating MSAs with MMseqs2 3 | --- 4 | 5 | # Generating MSAs with MMseqs2 6 | 7 | For problems where you have not already determined an MSA with another tool (eg. Jackhmmer, EVCouplings, MMseqs, etc.) AIDE provides a high lavel wrapper for generating Multiple Sequence Alignments (MSAs) using MMseqs2, implementing the sensitive search similar to colabfold. This can be useful when you need MSAs for models like EVMutation, MSATransformer, or EVE. 
This is literally just calling MMseqs with a few parameters set - all credit should go to the authors of MMseqs and Colabfold: 8 | 9 | Steinegger M and Soeding J. MMseqs2 enables sensitive protein sequence searching for the analysis of massive data sets. Nature Biotechnology, doi: 10.1038/nbt.3988 (2017). 10 | 11 | Mirdita, M., Schütze, K., Moriwaki, Y. et al. ColabFold: making protein folding accessible to all. Nat Methods 19, 679–682 (2022). https://doi.org/10.1038/s41592-022-01488-1 12 | 13 | ## Installation 14 | 15 | 1. Ensure MMseqs2 is installed and available in your PATH: 16 | ```bash 17 | conda install -c bioconda mmseqs2 18 | ``` 19 | 20 | 2. Download the ColabFold database(s): https://colabfold.mmseqs.com/. You will need to point towards this database to run the search. 21 | 22 | ## Basic Usage 23 | 24 | ### Python Interface 25 | 26 | ```python 27 | from aide_predict import ProteinSequences 28 | from aide_predict.utils.mmseqs_msa_search import run_mmseqs_search 29 | 30 | # Load sequences 31 | sequences = ProteinSequences.from_fasta("proteins.fasta") 32 | 33 | # Generate MSAs 34 | msa_paths = run_mmseqs_search( 35 | sequences=sequences, 36 | uniref_db="path/to/uniref30_2302", 37 | output_dir="./msas" 38 | ) 39 | 40 | # Load MSAs for use with models 41 | from aide_predict import ProteinSequences 42 | msas = [ProteinSequences.from_a3m(path) for path in msa_paths] 43 | ``` 44 | 45 | ### Command Line Interface 46 | 47 | You can also run MSA generation directly from the command line: 48 | 49 | ```bash 50 | python -m aide_predict.utils.mmseqs_msa_search \ 51 | proteins.fasta \ 52 | path/to/uniref30_2302 \ 53 | ./msas 54 | ``` 55 | 56 | ## Advanced Options 57 | 58 | The search can be customized with several parameters: 59 | 60 | ```python 61 | msa_paths = run_mmseqs_search( 62 | sequences=sequences, 63 | uniref_db="path/to/uniref30_2302", 64 | output_dir="./msas", 65 | mode='sensitive', # Search sensitivity: 'fast', 'standard', or 'sensitive' 66 | threads=8, # 
Number of CPU threads 67 | ) 68 | ``` 69 | 70 | Command line equivalents: 71 | 72 | ```bash 73 | python -m aide_predict.utils.mmseqs_msa_search \ 74 | proteins.fasta \ 75 | path/to/uniref30_2302 \ 76 | ./msas \ 77 | --mode sensitive \ 78 | --threads 8 \ 79 | --keep-tmp 80 | ``` 81 | 82 | ## Search Modes 83 | 84 | Three sensitivity modes are available: 85 | - `fast`: Quick search with sensitivity 4.0 86 | - `standard`: Balanced approach with sensitivity 5.7 (default) 87 | - `sensitive`: More thorough search with sensitivity 7.5 88 | 89 | Higher sensitivity will find more distant homologs but takes longer to run. 90 | 91 | ## Output Format 92 | 93 | MSAs are generated in A3M format, one file per input sequence. The files are named based on the sequence IDs in your input FASTA file. These files can be directly used with AIDE's MSA-based models: 94 | 95 | ```python 96 | # Use MSA with a model 97 | from aide_predict import MSATransformerLikelihoodWrapper 98 | 99 | msa = ProteinSequences.from_a3m("msas/sequence1.a3m") 100 | model = MSATransformerLikelihoodWrapper(wt=wt) 101 | model.fit(msa) 102 | ``` -------------------------------------------------------------------------------- /docs/user_guide/installation.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Installation Guide 3 | --- 4 | 5 | # Installing AIDE 6 | 7 | AIDE is designed with a modular architecture to minimize dependency conflicts while providing access to a wide range of protein prediction tools. The base package has minimal dependencies and provides core functionality, while additional components can be installed based on your specific needs. 
8 | 9 | ## Quick Install 10 | 11 | The package is not currently available on PyPI, please clone the repo: 12 | 13 | ```bash 14 | git clone https://github.com/beckham-lab/aide_predict 15 | ``` 16 | 17 | For basic functionality, simply install AIDE using: 18 | 19 | ```bash 20 | # Create and activate a new conda environment 21 | conda env create -f environment.yaml 22 | 23 | # Install AIDE 24 | pip install . 25 | ``` 26 | 27 | ## Supported Tools by Installation Level 28 | 29 | AIDE provides bespoke embedders and predictors as additional modules that can be installed. These fall into three categories, with environment weight in mind: those available in the base package, those that can be installed with minimal additional pip dependencies, and those that should be built as an independant environment. 30 | 31 | ### Base Installation 32 | The base installation provides: 33 | - Core data structures for protein sequences and structures 34 | - Sequence alignment utilities 35 | - One-hot encoding embeddings 36 | - K-mer based embeddings 37 | - Basic Hidden Markov Model support 38 | - mmseqs2 MSA generation pipeline 39 | 40 | ### Minor Pip Dependencies 41 | #### Pure `transformers` models 42 | ESM2 and SaProt can be defined with the transformers library. 
To install these models: 43 | ```bash 44 | pip install -r requirements-transformers.txt 45 | ``` 46 | This enables: 47 | - ESM2 embeddings and likelihood scoring 48 | - SaProt structure-aware embeddings and scoring 49 | 50 | #### MSA Transformer 51 | MSA transformer requires bespoke components from fair-esm: 52 | ```bash 53 | pip install -r requirements-fair-esm.txt 54 | ``` 55 | This enables: 56 | - MSA transformer embeddings and likelihood scoring 57 | 58 | #### EVmutation 59 | For evolutionary coupling analysis: 60 | ```bash 61 | pip install -r requirements-evmutation.txt 62 | ``` 63 | This enables: 64 | - EVMutation for protein mutation effect prediction 65 | 66 | #### VESPA Integration 67 | For conservation-based variant effect prediction: 68 | ```bash 69 | pip install -r requirements-vespa.txt 70 | ``` 71 | 72 | ### Independent Environment 73 | #### EVE Integration 74 | 75 | EVE requires special handling due to its complex environment requirements: 76 | 77 | 1. Clone the EVE repository outside your AIDE directory: 78 | ```bash 79 | git clone https://github.com/OATML/EVE.git 80 | ``` 81 | 82 | 2. Set required environment variables: 83 | ```bash 84 | export EVE_REPO=/path/to/eve/repo 85 | ``` 86 | 87 | 3. Create a dedicated conda environment for EVE following their installation instructions. 88 | 89 | 4. 
Set the EVE environment name: 90 | ```bash 91 | export EVE_CONDA_ENV=eve_env 92 | ``` 93 | 94 | ## Verifying Your Installation 95 | 96 | You can check which components are available in your installation: 97 | 98 | ```python 99 | from aide_predict.utils.checks import get_supported_tools 100 | print(get_supported_tools()) 101 | ``` 102 | 103 | ## Common Installation Issues 104 | 105 | ### CUDA Compatibility 106 | If you're using GPU-accelerated components (ESMFold, transformers), ensure your CUDA drivers are compatible: 107 | - Check CUDA version: `nvidia-smi` 108 | - Match PyTorch installation with CUDA version 109 | - For Apple Silicon users: Some components may require alternative installations 110 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /docs/user_guide/position_specific.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Position-Specific Models 3 | --- 4 | 5 | # Position-Specific Models 6 | 7 | ## Overview 8 | 9 | Some protein models can generate outputs for each amino acid position in a sequence. These models use the `PositionSpecificMixin` to handle position selection and output formatting. EG. lanmguage models or one hot encodings. 
You might want to do this if only a few positions are changing among variants or you have a specific hypothesis about the importance of certain positions. 10 | 11 | ## Using Position-Specific Models 12 | 13 | Position-specific models have three key parameters that control their output. Flatten and pool are mutually exclusive. 14 | 15 | ```python 16 | from aide_predict import ESM2Embedding 17 | 18 | # Basic usage - outputs pooled across all positions 19 | model = ESM2Embedding( 20 | positions=None, # Consider all positions 21 | pool='mean', # Average across positions 22 | flatten=False # because pooling by mean 23 | ) 24 | 25 | # Position-specific - get embeddings for specific positions 26 | model = ESM2Embedding( 27 | positions=[0, 1, 2], # Only these positions 28 | pool=False, # Keep positions separate 29 | flatten=True # Flatten features for each position so we get a single vector 30 | ) 31 | ``` 32 | 33 | ## Output Shapes 34 | 35 | The output shape depends on the parameter combination: 36 | 37 | ```python 38 | # Example with ESM2 (1280-dimensional embeddings) 39 | X = ProteinSequences.from_fasta("sequences.fasta") 40 | 41 | # Default: pooled across positions 42 | model = ESM2Embedding(pool=True) 43 | output = model.transform(X) # Shape: (n_sequences, 1280) 44 | 45 | # Selected positions, no pooling 46 | model = ESM2Embedding( 47 | positions=[0, 1, 2], 48 | pool=False 49 | ) 50 | output = model.transform(X) # Shape: (n_sequences, 3, 1280) 51 | 52 | # Selected positions, no pooling, flattened 53 | model = ESM2Embedding( 54 | positions=[0, 1, 2], 55 | pool=False, 56 | flatten=True 57 | ) 58 | output = model.transform(X) # Shape: (n_sequences, 3*1280) 59 | ``` 60 | 61 | ## Position Specificity for Variable Length Sequences 62 | 63 | In some cases models can be position specific even if not all sequences are the same length, such as when working with homologs. However, to map positions between sequences properly, we need to: 64 | 1. 
@pytest.mark.optional
def test_esm_zero_shot():
    """ESM2 zero-shot sanity check against the ProteinGym assay ENVZ_ECOLI_Ghose_2023.

    Runs all three marginal methods pooled (expected Spearman about 0.2 per the
    literature) and one position-specific, unpooled configuration.
    """
    # this model requires no MSAs
    from aide_predict.bespoke_models.predictors.esm2 import ESM2LikelihoodWrapper

    # Shared wild-type reference; previously duplicated four times inline.
    WT = "LADDRTLLMAGVSHDLRTPLTRIRLATEMMSEQDGYLAESINKDIEECNAIIEQFIDYLR"

    assay_data = pd.read_csv(
        os.path.join('tests', 'data', 'ENVZ_ECOLI_Ghose_2023.csv'))
    sequences = ProteinSequences.from_list(assay_data['mutated_sequence'].tolist())
    scores = assay_data['DMS_score'].tolist()

    def _make_model(marginal_method, positions=None, pool=True):
        # Single construction point so the four scenarios stay consistent.
        return ESM2LikelihoodWrapper(
            model_checkpoint="esm2_t6_8M_UR50D",
            marginal_method=marginal_method,
            positions=positions,
            device=DEVICE,
            pool=pool,
            wt=WT,
            metadata_folder='./tmp/esm',
        )

    # Pooled predictions for each marginal method should correlate with the assay.
    for method in ("masked_marginal", "wildtype_marginal", "mutant_marginal"):
        model = _make_model(method)
        model.fit(sequences)  # no-op for a zero-shot model
        predictions = model.predict(sequences)
        spearman = spearmanr(scores, predictions)[0]
        print(f"ESM Spearman ({method}): {spearman}")
        assert abs(spearman - 0.2) < 0.05

    # Position-specific scores: one score per requested position, no pooling.
    model = _make_model("wildtype_marginal", positions=[8, 9, 10], pool=False)
    model.fit(sequences)  # no-op for a zero-shot model
    predictions = model.predict(sequences)
    assert len(predictions) == len(sequences)
    assert len(predictions[0]) == 3

if __name__ == "__main__":
    test_esm_zero_shot()
@pytest.mark.optional
def test_msa_transformer_zero_shot():
    """MSA Transformer zero-shot scores should correlate with the ENVZ DMS assay.

    Exercises the three marginal methods (pooled) plus a position-specific,
    unpooled configuration. Fixes swapped print labels from the original
    version, which reported the wrong marginal method in each message.
    """
    from aide_predict.bespoke_models.predictors.msa_transformer import MSATransformerLikelihoodWrapper

    # Load the MSA. NOTE(review): from_fasta presumably returns the query
    # sequence with the alignment attached -- the wrapper reads the MSA
    # through `wt`; confirm against ProteinSequence.from_fasta.
    msa_file = os.path.join('tests', 'data', 'ENVZ_ECOLI_extreme_filtered.a2m')
    wt_sequence = ProteinSequence.from_fasta(msa_file)

    # Load the assay data
    assay_data = pd.read_csv(os.path.join('tests', 'data', 'ENVZ_ECOLI_Ghose_2023.csv'))
    sequences = ProteinSequences.from_list(assay_data['mutated_sequence'].tolist())
    scores = assay_data['DMS_score'].tolist()

    # Test wildtype marginal method
    model = MSATransformerLikelihoodWrapper(
        marginal_method="wildtype_marginal",
        positions=None,
        device=DEVICE,
        pool=True,
        wt=wt_sequence,
        batch_size=128,
        metadata_folder='./tmp/msa_transformer',
    )

    model.fit()
    print('wt marginal method fitted')
    predictions = model.predict(sequences)
    spearman = spearmanr(scores, predictions)[0]
    # Label fixed: this block tests the wildtype marginal, not masked.
    print(f"MSA Transformer (wildtype marginal) Spearman: {spearman}")
    assert abs(spearman - 0.2) < 0.02  # Adjust the expected correlation as needed

    # Test masked marginal method
    model = MSATransformerLikelihoodWrapper(
        marginal_method="masked_marginal",
        positions=None,
        device=DEVICE,
        pool=True,
        wt=wt_sequence,
        batch_size=128,
        metadata_folder='./tmp/msa_transformer',
    )

    model.fit()
    print('masked marginal method fitted')
    predictions = model.predict(sequences)
    spearman = spearmanr(scores, predictions)[0]
    # Label fixed: this block tests the masked marginal.
    print(f"MSA Transformer (masked marginal) Spearman: {spearman}")
    assert abs(spearman - 0.2) < 0.02  # Adjust the expected correlation as needed

    # Test mutant marginal method
    # THIS ONE TAKES A LONG TIME, 1k calls
    model = MSATransformerLikelihoodWrapper(
        marginal_method="mutant_marginal",
        positions=None,
        device=DEVICE,
        pool=True,
        wt=wt_sequence,
        batch_size=128,
        metadata_folder='./tmp/msa_transformer',
    )

    model.fit()
    print('mutant marginal method fitted')
    predictions = model.predict(sequences)
    spearman = spearmanr(scores, predictions)[0]
    # Label fixed: this block tests the mutant marginal.
    print(f"MSA Transformer (mutant marginal) Spearman: {spearman}")
    assert abs(spearman - 0.2) < 0.02  # Adjust the expected correlation as needed

    # Test with specific positions and no pooling
    model = MSATransformerLikelihoodWrapper(
        marginal_method="wildtype_marginal",
        positions=[8, 9, 10],
        device=DEVICE,
        pool=False,
        wt=wt_sequence,
        batch_size=128,
        metadata_folder='./tmp/msa_transformer',
    )

    model.fit()
    print('position-specific model fitted')
    predictions = model.predict(sequences)
    assert len(predictions) == len(sequences)
    assert predictions.shape[1] == 3

if __name__ == "__main__":
    test_msa_transformer_zero_shot()
@pytest.mark.optional
def test_saprot_zero_shot():
    """Zero-shot SaProt likelihood scores should track the ENVZ DMS assay."""
    from aide_predict.bespoke_models.predictors.saprot import SaProtLikelihoodWrapper

    # Structure plus wild-type sequence the wrapper scores against.
    structure = ProteinStructure('tests/data/ENVZ_ECOLI.pdb')
    wt_seq_str = "LADDRTLLMAGVSHDLRTPLTRIRLATEMMSEQDGYLAESINKDIEECNAIIEQFIDYLR"
    wt = ProteinSequences([ProteinSequence(wt_seq_str, structure=structure)])[0]

    assay = pd.read_csv(os.path.join('tests', 'data', 'ENVZ_ECOLI_Ghose_2023.csv'))
    variants = ProteinSequences.from_list(assay['mutated_sequence'].tolist())
    dms_scores = assay['DMS_score'].tolist()

    def build(method, positions=None, pool=True):
        # Every scenario shares the same checkpoint/device/structure setup.
        return SaProtLikelihoodWrapper(
            model_checkpoint="westlake-repl/SaProt_650M_AF2",
            marginal_method=method,
            positions=positions,
            device=DEVICE,
            pool=pool,
            wt=wt,
            metadata_folder='./tmp/saprot',
            foldseek_path='foldseek'  # Adjust this path if necessary
        )

    # Wildtype marginal, pooled.
    wrapper = build("wildtype_marginal")
    wrapper.fit(variants)  # no-op for a zero-shot model
    preds = wrapper.predict(variants)
    rho = spearmanr(dms_scores, preds)[0]
    print(f"SaProt Spearman (wildtype_marginal): {rho}")
    assert abs(rho - 0.15) < 0.05  # Adjust this threshold as needed

    # Masked marginal, pooled.
    wrapper = build("masked_marginal")
    wrapper.fit(variants)  # no-op for a zero-shot model
    print('SaProt model fitted!')
    preds = wrapper.predict(variants)
    rho = spearmanr(dms_scores, preds)[0]
    print(f"SaProt Spearman (masked_marginal): {rho}")
    assert abs(rho - 0.15) < 0.05  # Adjust this threshold as needed

    # Position-specific scores: one value per requested position.
    wrapper = build("wildtype_marginal", positions=[8, 9, 10], pool=False)
    wrapper.fit(variants)  # no-op for a zero-shot model
    preds = wrapper.predict(variants)
    assert len(preds) == len(variants)
    assert len(preds[0]) == 3

    # Mutant marginal, pooled.
    wrapper = build("mutant_marginal")
    wrapper.fit(variants)  # no-op for a zero-shot model
    preds = wrapper.predict(variants)
    rho = spearmanr(dms_scores, preds)[0]
    print(f"SaProt Spearman (mutant_marginal): {rho}")
    assert abs(rho - 0.15) < 0.05  # Adjust this threshold as needed

if __name__ == "__main__":
    test_saprot_zero_shot()
SaProt,RandomForest,spearman,0.4838059318494942 20 | SaProt,RandomForest,spearman,0.5073158909080261 21 | SaProt,RandomForest,spearman,0.4142238300741247 22 | MSATransformer,Ridge,spearman,0.13592487261375027 23 | MSATransformer,Ridge,spearman,0.21124989124870575 24 | MSATransformer,Ridge,spearman,0.10028395916309496 25 | MSATransformer,Ridge,spearman,-0.0839354758679946 26 | MSATransformer,Ridge,spearman,-0.019688631704976395 27 | MSATransformer,RandomForest,spearman,-0.03629287088697565 28 | MSATransformer,RandomForest,spearman,0.012473779378420575 29 | MSATransformer,RandomForest,spearman,-0.12530860035842473 30 | MSATransformer,RandomForest,spearman,0.008118262826279489 31 | MSATransformer,RandomForest,spearman,-0.07255363115560627 32 | AlignedOneHot,Ridge,spearman,0.7699835113403921 33 | AlignedOneHot,Ridge,spearman,0.7171753360032541 34 | AlignedOneHot,Ridge,spearman,0.3987460159132769 35 | AlignedOneHot,Ridge,spearman,0.5145320630414737 36 | AlignedOneHot,Ridge,spearman,0.4715327855807983 37 | AlignedOneHot,RandomForest,spearman,0.6803966634813294 38 | AlignedOneHot,RandomForest,spearman,0.5283345512159724 39 | AlignedOneHot,RandomForest,spearman,0.33057440073358924 40 | AlignedOneHot,RandomForest,spearman,0.40241285030294066 41 | AlignedOneHot,RandomForest,spearman,0.4351562790551842 42 | ESM2,Ridge,roc_auc,0.8959276018099548 43 | ESM2,Ridge,roc_auc,0.7697368421052632 44 | ESM2,Ridge,roc_auc,0.764172335600907 45 | ESM2,Ridge,roc_auc,0.7333333333333334 46 | ESM2,Ridge,roc_auc,0.6710875331564987 47 | ESM2,RandomForest,roc_auc,0.8981900452488687 48 | ESM2,RandomForest,roc_auc,0.7543859649122806 49 | ESM2,RandomForest,roc_auc,0.5895691609977325 50 | ESM2,RandomForest,roc_auc,0.7234567901234568 51 | ESM2,RandomForest,roc_auc,0.6923076923076923 52 | SaProt,Ridge,roc_auc,0.8914027149321267 53 | SaProt,Ridge,roc_auc,0.7653508771929824 54 | SaProt,Ridge,roc_auc,0.7936507936507936 55 | SaProt,Ridge,roc_auc,0.7481481481481482 56 | 
SaProt,Ridge,roc_auc,0.7506631299734748 57 | SaProt,RandomForest,roc_auc,0.8800904977375567 58 | SaProt,RandomForest,roc_auc,0.8552631578947368 59 | SaProt,RandomForest,roc_auc,0.7551020408163265 60 | SaProt,RandomForest,roc_auc,0.7703703703703704 61 | SaProt,RandomForest,roc_auc,0.773209549071618 62 | MSATransformer,Ridge,roc_auc,0.5678733031674208 63 | MSATransformer,Ridge,roc_auc,0.6513157894736842 64 | MSATransformer,Ridge,roc_auc,0.6054421768707483 65 | MSATransformer,Ridge,roc_auc,0.43209876543209874 66 | MSATransformer,Ridge,roc_auc,0.46684350132625996 67 | MSATransformer,RandomForest,roc_auc,0.5180995475113123 68 | MSATransformer,RandomForest,roc_auc,0.5087719298245614 69 | MSATransformer,RandomForest,roc_auc,0.4399092970521542 70 | MSATransformer,RandomForest,roc_auc,0.5049382716049383 71 | MSATransformer,RandomForest,roc_auc,0.4098143236074271 72 | AlignedOneHot,Ridge,roc_auc,0.9298642533936651 73 | AlignedOneHot,Ridge,roc_auc,0.9078947368421053 74 | AlignedOneHot,Ridge,roc_auc,0.7097505668934241 75 | AlignedOneHot,Ridge,roc_auc,0.7876543209876543 76 | AlignedOneHot,Ridge,roc_auc,0.7877984084880637 77 | AlignedOneHot,RandomForest,roc_auc,0.8687782805429864 78 | AlignedOneHot,RandomForest,roc_auc,0.8070175438596492 79 | AlignedOneHot,RandomForest,roc_auc,0.6417233560090703 80 | AlignedOneHot,RandomForest,roc_auc,0.7160493827160492 81 | AlignedOneHot,RandomForest,roc_auc,0.7612732095490716 82 | -------------------------------------------------------------------------------- /dvc.lock: -------------------------------------------------------------------------------- 1 | schema: '2.0' 2 | stages: 3 | run_jackhmmer: 4 | cmd: python scripts/run_jackhmmer.py 5 | deps: 6 | - path: aide_predict/utils/jackhmmer.py 7 | hash: md5 8 | md5: 29311d7958e0f3f02f2265184ca0d074 9 | size: 7534 10 | - path: data/wt.fasta 11 | hash: md5 12 | md5: a1598310bd55723f8c9255b8160c1816 13 | size: 263 14 | - path: scripts/run_jackhmmer.py 15 | hash: md5 16 | md5: 
fa97a43d299c54f9bab3bc81c69d2fc2 17 | size: 1726 18 | params: 19 | params.yaml: 20 | jackhmmer: 21 | seqdb: /kfs2/projects/bpms/ekomp_tmp/datasets/uniref/uniref100.fasta 22 | iterations: 5 23 | evalue: 0.0001 24 | tvalue: 0.5 25 | use_bitscores: true 26 | cpus: 16 27 | mx: BLOSUM62 28 | popen: 0.02 29 | pextend: 0.4 30 | outs: 31 | - path: data/jackhmmer/ 32 | hash: md5 33 | md5: ea7a9ff5cdb4a795b63d838c42e73b50.dir 34 | size: 42340850998 35 | nfiles: 2 36 | run_msa: 37 | cmd: python scripts/run_msa.py 38 | deps: 39 | - path: data/experimental_data.csv 40 | hash: md5 41 | md5: d41d8cd98f00b204e9800998ecf8427e 42 | size: 0 43 | - path: data/starting_sequences.a2m 44 | hash: md5 45 | md5: 624275f98f9f67c5939e3aba6e4a873d 46 | size: 173269062 47 | - path: data/starting_sequences.fa 48 | hash: md5 49 | md5: d41d8cd98f00b204e9800998ecf8427e 50 | size: 0 51 | - path: data/wt.fa 52 | hash: md5 53 | md5: 592b18586d4530aaef6557b77c963ee3 54 | size: 78 55 | - path: scripts/run_msa.py 56 | hash: md5 57 | md5: 0ef6b9083d6ebd257d39cbb9bff52734 58 | size: 7401 59 | params: 60 | params.yaml: 61 | msa_creation.jackhmmer.cpus: 16 62 | msa_creation.jackhmmer.domain_threshold: 100 63 | msa_creation.jackhmmer.iterations: 1 64 | msa_creation.jackhmmer.mx: BLOSUM62 65 | msa_creation.jackhmmer.seqdb: uniref100 66 | msa_creation.jackhmmer.sequence_identity_filter: 0.8 67 | msa_creation.jackhmmer.sequence_threshold: 100 68 | msa_creation.jackhmmer.use_bitscores: true 69 | msa_creation.msa_mode: starting_sequences 70 | msa_creation.starting_sequences.activity_targets: 71 | - 0 72 | msa_creation.starting_sequences.activity_threshold: 0.0 73 | msa_creation.starting_sequences.add_training_sequences: false 74 | msa_creation.starting_sequences.prealigned: true 75 | use_msa: true 76 | outs: 77 | - path: data/metrics/run_msa.json 78 | hash: md5 79 | md5: 8d205654d9b81dbfb76c9799add5bc43 80 | size: 61 81 | - path: data/run_msa/ 82 | hash: md5 83 | md5: 95ce7fb78568780b4c41bee915e29b69.dir 84 | 
size: 173269062 85 | nfiles: 1 86 | process_msa: 87 | cmd: python scripts/process_msa.py 88 | deps: 89 | - path: data/run_msa/ 90 | hash: md5 91 | md5: 95ce7fb78568780b4c41bee915e29b69.dir 92 | size: 173269062 93 | nfiles: 1 94 | - path: data/wt.fa 95 | hash: md5 96 | md5: 592b18586d4530aaef6557b77c963ee3 97 | size: 78 98 | - path: scripts/process_msa.py 99 | hash: md5 100 | md5: 2c829feb65d96f4a77628fe0d42bcfa2 101 | size: 2433 102 | params: 103 | params.yaml: 104 | msaprocessing.additional_weights: false 105 | msaprocessing.preprocess: true 106 | msaprocessing.remove_sequences_with_indeterminate_AA_in_focus_cols: true 107 | msaprocessing.theta: 0.2 108 | msaprocessing.threshold_focus_cols_frac_gaps: 0.5 109 | msaprocessing.threshold_sequence_frac_gaps: 0.5 110 | msaprocessing.use_weights: true 111 | outs: 112 | - path: data/metrics/process_msa.json 113 | hash: md5 114 | md5: 3f86e048d9c8f117fdfd3e89c83f1c6b 115 | size: 83 116 | - path: data/process_msa/ 117 | hash: md5 118 | md5: 844813c4a68d4abe187d81facf8c2cbf.dir 119 | size: 188291636 120 | nfiles: 2 121 | -------------------------------------------------------------------------------- /aide_predict/io/bio_files.py: -------------------------------------------------------------------------------- 1 | # aide_predict/io/bio_files.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 5/22/2024 5 | * Company: Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | 8 | Some functions are copied from EVcoupling, as to avoid the additional required dependancy. All credit goes to the EVcouples team: 9 | 10 | Hopf T. A., Green A. G., Schubert B., et al. The EVcouplings Python framework for coevolutionary sequence analysis. 
def read_fasta(fileobj):
    """
    Generator function to read a FASTA-format file
    (includes aligned FASTA, A2M, A3M formats)

    Credit to EVcouplings

    Args:
    - fileobj: file object opened for reading

    Yields:
    - Tuple of (sequence_id, sequence) for each entry
    """
    current_sequence = ""
    current_id = None

    for line in fileobj:
        # A ">" header starts a new entry. If we already have
        # seen an entry before, return it first.
        if line.startswith(">"):
            if current_id is not None:
                yield current_id, current_sequence

            current_id = line.rstrip()[1:]
            current_sequence = ""

        # ";" lines are FASTA comments and are skipped.
        elif not line.startswith(";"):
            current_sequence += line.rstrip()

    # Emit the final entry. Guard against empty input, which previously
    # produced a spurious (None, "") record; headerless content is still
    # yielded to preserve the original behavior for malformed files.
    if current_id is not None or current_sequence:
        yield current_id, current_sequence

def write_fasta(sequences, fileobj, width=80):
    """
    Write a list of IDs/sequences to a FASTA-format file

    Credit to EVcouplings

    Args:
    - sequences: list of (sequence_id, sequence) tuples
    - fileobj: file object opened for writing
    - width: width of sequence lines in FASTA file

    Returns:
    - None
    """
    for (seq_id, seq) in sequences:
        fileobj.write(">{}\n".format(seq_id))
        # `wrap` (from aide_predict.utils.common) hard-wraps the sequence
        # to `width` characters per line.
        fileobj.write(wrap(seq, width=width) + "\n")
def read_a3m(fileobj, inserts="first"):
    """
    Read an alignment in compressed a3m format and expand
    into a2m format.

    Credit to EVcouplings

    Args:
    - fileobj: file object opened for reading
    - inserts: how to handle insert gaps in alignment
      (either "first" or "delete")

    Returns:
    - OrderedDict of sequence_id -> sequence
    """
    seqs = OrderedDict()

    for i, (seq_id, seq) in enumerate(read_fasta(fileobj)):
        # remove any insert gaps that may still be in alignment
        # (just to be sure)
        seq = seq.replace(".", "")

        if inserts == "first":
            # define "spacing" of uppercase columns in
            # final alignment based on target sequence;
            # remaining columns will be filled with insert
            # gaps in the other sequences
            if i == 0:
                # First (target) sequence fixes the column layout:
                # uppercase letters and match gaps "-" mark match columns.
                uppercase_cols = [
                    j for (j, c) in enumerate(seq)
                    if (c == c.upper() or c == "-")
                ]
                gap_template = np.array(["."] * len(seq))
                filled_seq = seq
            else:
                # Scatter this sequence's match-state characters into the
                # target's match columns; all other columns stay ".".
                # NOTE(review): assumes each sequence carries exactly as many
                # match-state characters as the target -- the numpy fancy
                # assignment below raises otherwise; confirm inputs are a3m.
                uppercase_chars = [
                    c for c in seq if c == c.upper() or c == "-"
                ]
                filled = np.copy(gap_template)
                filled[uppercase_cols] = uppercase_chars
                filled_seq = "".join(filled)

        elif inserts == "delete":
            # remove all lowercase letters and insert gaps .;
            # since each sequence must have same number of
            # uppercase letters or match gaps -, this gives
            # the final sequence in alignment
            filled_seq = "".join([c for c in seq if c == c.upper() and c != "."])
        else:
            raise ValueError(
                "Invalid option for inserts: {}".format(inserts)
            )

        seqs[seq_id] = filled_seq

    return seqs
class TestESM2Embedding:
    """Tests for the ESM2Embedding transformer (shape and feature-name behavior).

    NOTE(review): the fixtures are class-scoped but several tests mutate
    ``embedder`` attributes (positions/pool/flatten) and re-fit, so the tests
    share state and are order-dependent by design.
    """

    @pytest.fixture(scope="class")
    def embedder(self):
        """One embedder shared by all tests in this class."""
        return ESM2Embedding(
            model_checkpoint="esm2_t6_8M_UR50D", # Using a small model for faster tests
            layer=-1,
            batch_size=2
        )

    @pytest.fixture(scope="class")
    def sequences(self):
        """Two fixed-length (20-residue) unaligned sequences."""
        return ProteinSequences([
            ProteinSequence("ACDEFGHIKLMNPQRSTVWY"),
            ProteinSequence("LLLLLLLLLLLLLLLLLLLL")
        ])

    @pytest.fixture(scope="class")
    def aligned_sequences(self):
        """Same sequences with a shared gap column at index 4."""
        return ProteinSequences([
            ProteinSequence("ACDE-FGHIKLMNPQRSTVWY"),
            ProteinSequence("LLLL-LLLLLLLLLLLLLLLL")
        ])

    def test_initialization(self, embedder):
        """Constructor arguments and defaults are stored as given."""
        assert embedder.model_checkpoint == 'esm2_t6_8M_UR50D'
        assert embedder.layer == -1
        assert embedder.positions is None
        assert not embedder.flatten
        assert not embedder.pool
        assert embedder.batch_size == 2

    def test_fit(self, embedder, sequences):
        """fit() marks the embedder as fitted via the fitted_ attribute."""
        embedder.fit(sequences)
        assert hasattr(embedder, 'fitted_')

    @pytest.mark.parametrize("positions,pool,flatten", [
        (None, False, False),
        ([0, 1, 2], False, False),
        (None, True, False),
        (None, False, True),
    ])
    def test_transform(self, embedder, sequences, positions, pool, flatten):
        """Output shape follows the positions/pool/flatten configuration."""
        embedder.positions = positions
        embedder.pool = pool
        embedder.flatten = flatten
        embedder.fit([])
        print(embedder.flatten)
        embeddings = embedder.transform(sequences)

        assert isinstance(embeddings, np.ndarray)
        if pool:
            assert embeddings.shape == (2, 320) # ESM2 t6 8M model has 320 hidden dimensions
        elif positions:
            assert embeddings.shape == (2, len(positions), 320)
        else:
            if not flatten:
                assert embeddings.shape == (2, 20, 320) # 20 amino acids in each sequence
            else:
                assert embeddings.shape == (2, 6400)  # 20 positions x 320 hidden dims

    def test_transform_aligned(self, embedder, aligned_sequences):
        """Selected positions on aligned input; the gap column embeds to zeros."""
        embedder.positions = [0,4]
        embedder.pool = False
        embedder.flatten = False
        embedder.fit([])
        embeddings = embedder.transform(aligned_sequences)

        assert isinstance(embeddings, np.ndarray)
        assert embeddings.shape == (2, 2, 320)  # two requested positions (0 and the gap at 4)
        # check for zeros at the gap
        assert np.all(embeddings[:, 1] == 0)

    def test_transform_variable_length_error(self, embedder):
        """Position selection on variable-length, unaligned input must fail."""
        sequences = ProteinSequences([ProteinSequence("ACGT"), ProteinSequence("ACGTAA")])
        embedder.positions = [0, 1]

        with pytest.raises(ValueError):
            embedder.transform(sequences)

    def test_get_feature_names_out_pooled(self, embedder, sequences):
        """Pooled output yields one feature name per hidden dimension."""
        embedder.positions = None
        embedder.pool = True
        embedder.flatten = False
        embedder.fit(sequences)

        feature_names = embedder.get_feature_names_out()
        assert len(feature_names) == 320 # ESM2 t6 8M model has 320 hidden dimensions

    def test_get_feature_names_out_flattened(self, embedder, sequences):
        """Flattened output yields positions x hidden-dim feature names."""
        embedder.positions = [0, 1]
        embedder.pool = False
        embedder.flatten = True
        embedder.fit(sequences)


        feature_names = embedder.get_feature_names_out()
        assert len(feature_names) == 2 * 320
# Compatibility shims for dependencies that still rely on removed numpy
# aliases, the old plmc log format, and the pandas DataFrame.append API.

# Check if np.str is not available
if not hasattr(np, 'str'):
    # numpy >= 1.24 removed the deprecated np.str alias; restore it for
    # older third-party code (e.g. evcouplings) that still references it.
    np.str = str

if not hasattr(np, 'int'):
    # Same for np.int; np.int32 stands in for the removed alias.
    np.int = np.int32

# Re-register the patched module object (defensive; numpy is already in
# sys.modules, so this is effectively a no-op).
sys.modules['numpy'] = np

# PLMC installed via conda has a new output format the the evcouplings does not read properly
try:
    from evcouplings.couplings import tools

    # Store the original function so it can be inspected/restored if needed.
    original_parse_plmc_log = tools.parse_plmc_log

    def patched_parse_plmc_log(log):
        """
        A patched version of parse_plmc_log that handles the new output format.
        """
        # Copy the original regular expressions
        res = {
            "focus": re.compile(r"Found focus (.+) as sequence (\d+)"),
            "seqs": re.compile(r"(\d+) valid sequences out of (\d+)"),
            "sites": re.compile(r"(\d+) sites out of (\d+)"),
            "region": re.compile(r"Region starts at (\d+)"),
            "samples": re.compile(r"Effective number of samples[^:]*: (\d+\.\d+)"),
            "optimization": re.compile(r"Gradient optimization: (.+)")
        }

        # One iteration row: an integer followed by six floats.
        re_iter = re.compile(r"(\d+){}".format(
            "".join([r"\s+(\d+\.\d+)"] * 6)
        ))

        matches = {}
        fields = None
        iters = []

        for line in log.split("\n"):
            for (name, pattern) in res.items():
                m = re.search(pattern, line)
                if m:
                    matches[name] = m.groups()
            # The "iter" header line names the columns of the iteration table.
            if line.startswith("iter"):
                fields = line.split()
            m_iter = re.search(re_iter, line)
            if m_iter:
                iters.append(m_iter.groups())

        if fields is not None and iters:
            iter_df = pd.DataFrame(iters, columns=fields)
        else:
            iter_df = None

        # some output only defined in focus mode
        focus_index = None
        valid_sites, total_sites = None, None
        region_start = 1
        try:
            focus_index = int(matches["focus"][1])
            valid_sites, total_sites = map(int, matches["sites"])
            region_start = int(matches["region"][0])
        except KeyError:
            pass

        valid_seqs, total_seqs = map(int, matches["seqs"])
        eff_samples = float(matches["samples"][0])
        opt_status = matches["optimization"][0]

        return (
            iter_df,
            (
                focus_index, valid_seqs, total_seqs,
                valid_sites, total_sites, region_start,
                eff_samples, opt_status
            )
        )

    # Replace the original function with our patched version
    tools.parse_plmc_log = patched_parse_plmc_log
except ImportError:
    pass

def patch_pandas_append():
    """Restore DataFrame.append (removed in pandas 2.0) for legacy callers.

    Installs a replacement that mirrors the removed method's semantics:
    it returns a NEW frame and leaves ``self`` untouched. The previous
    monkey-patch mutated ``self`` in place, misused ``other``'s index labels
    as row offsets (wrong or crashing for non-default indices), implemented
    ``verify_integrity`` as drop_duplicates, and ``sort`` as an index sort.
    """

    def patched_append(self, other, ignore_index=False, verify_integrity=False, sort=False):
        if isinstance(other, pd.DataFrame):
            frames = [self, other]
        elif isinstance(other, pd.Series):
            # A Series appends as a single row whose labels become columns.
            frames = [self, other.to_frame().T]
        else:
            raise TypeError("Unsupported type for 'other'. Expected DataFrame or Series.")

        # pd.concat implements the removed append contract directly:
        # verify_integrity raises on duplicate index labels, sort orders columns.
        return pd.concat(
            frames,
            ignore_index=ignore_index,
            verify_integrity=verify_integrity,
            sort=sort,
        )

    pd.DataFrame.append = patched_append

# Apply the patch
patch_pandas_append()
import os
import argparse
import json
import sys
import torch

# Add EVE repo to Python path so its modules resolve regardless of where
# this script is launched from.
eve_repo = os.environ.get('EVE_REPO')
if eve_repo:
    sys.path.insert(0, eve_repo)

from EVE import VAE_model
from utils import data_utils

if __name__=='__main__':
    parser = argparse.ArgumentParser(description='Train EVE VAE model on a single MSA')

    # Core arguments for single MSA processing
    parser.add_argument('--msa_file', type=str, required=True,
                        help='Path to the MSA file to process')
    parser.add_argument('--model_name', type=str, required=True,
                        help='Name for the model checkpoint')

    # Optional processing parameters
    parser.add_argument('--theta_reweighting', type=float, default=0.2,
                        help='Parameter for MSA sequence re-weighting (default: 0.2)')
    parser.add_argument('--weights_folder', type=str, default='weights',
                        help='Folder to store sequence weights (default: weights/)')

    # Model parameters and checkpoint locations
    parser.add_argument('--model_parameters', type=str, required=True,
                        help='Path to JSON file containing model parameters')
    parser.add_argument('--checkpoint_folder', type=str, default='checkpoints',
                        help='Folder to store model checkpoints (default: checkpoints/)')
    parser.add_argument('--logs_folder', type=str, default='logs',
                        help='Folder to store training logs (default: logs/)')

    args = parser.parse_args()

    # Create necessary directories if they don't exist
    os.makedirs(args.weights_folder, exist_ok=True)
    os.makedirs(args.checkpoint_folder, exist_ok=True)
    os.makedirs(args.logs_folder, exist_ok=True)

    # Sequence weights are cached per model+theta so re-runs skip recomputation.
    weights_path = os.path.join(args.weights_folder,
                                f'{args.model_name}_theta_{args.theta_reweighting}.npy')

    print(f"Processing MSA file: {args.msa_file}")
    print(f"Using theta={args.theta_reweighting} for MSA re-weighting")
    print(f"Model name: {args.model_name}")

    # Process MSA data
    data = data_utils.MSA_processing(
        MSA_location=args.msa_file,
        theta=args.theta_reweighting,
        use_weights=True,
        weights_location=weights_path
    )

    # Load model parameters
    try:
        with open(args.model_parameters, 'r') as f:
            model_params = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"Error loading model parameters from {args.model_parameters}: {str(e)}")
        # sys.exit instead of the site-injected exit() builtin, which is not
        # guaranteed to exist in all execution contexts.
        sys.exit(1)

    # Initialize model
    model = VAE_model.VAE_model(
        model_name=args.model_name,
        data=data,
        encoder_parameters=model_params["encoder_parameters"],
        decoder_parameters=model_params["decoder_parameters"],
        random_seed=42
    )
    # Route to Apple MPS when no CUDA device was picked up.
    # NOTE(review): assumes EVE's VAE_model honors the .device attributes on
    # the model and its submodules -- confirm against the EVE source.
    if str(model.device) == 'cpu' and torch.backends.mps.is_available():
        model.device = 'mps'
        model.encoder.device = 'mps'
        model.decoder.device = 'mps'
        model = model.to(model.device)
    print("Using device:", model.device)

    # Update training parameters with new paths
    model_params["training_parameters"].update({
        'training_logs_location': args.logs_folder,
        'model_checkpoint_location': args.checkpoint_folder
    })

    # Train model
    print(f"Starting to train model: {args.model_name}")
    model.train_model(data=data, training_parameters=model_params["training_parameters"])

    # Save final model
    print(f"Saving model: {args.model_name}")
    final_checkpoint_path = os.path.join(args.checkpoint_folder, f"{args.model_name}_final")
    model.save(
        model_checkpoint=final_checkpoint_path,
        encoder_parameters=model_params["encoder_parameters"],
        decoder_parameters=model_params["decoder_parameters"],
        training_parameters=model_params["training_parameters"]
    )
-------------------------------------------------------------------------------- /aide_predict/bespoke_models/embedders/kmer.py: -------------------------------------------------------------------------------- 1 | # aide_predict/bespoke_models/embedders/kmer.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 8/9/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | import numpy as np 9 | from typing import List, Union, Optional 10 | from collections import defaultdict 11 | 12 | from aide_predict.bespoke_models.base import ProteinModelWrapper, CanHandleAlignedSequencesMixin 13 | from aide_predict.utils.data_structures import ProteinSequences, ProteinSequence 14 | from aide_predict.utils.common import MessageBool 15 | AVAILABLE = MessageBool(True, "Available") 16 | 17 | class KmerEmbedding(CanHandleAlignedSequencesMixin, ProteinModelWrapper): 18 | """ 19 | A fast K-mer embedding class for protein sequences. 20 | 21 | This class generates K-mer embeddings for protein sequences, handling both 22 | aligned and unaligned sequences efficiently. 23 | 24 | Attributes: 25 | k (int): The size of the K-mers. 26 | normalize (bool): Whether to normalize the K-mer counts. 27 | """ 28 | _available=AVAILABLE 29 | def __init__(self, metadata_folder: str = None, 30 | k: int = 3, 31 | normalize: bool = True, 32 | wt: ProteinSequence = None): 33 | """ 34 | Initialize the KmerEmbedding. 35 | 36 | Args: 37 | metadata_folder (str): Folder to store metadata. 38 | k (int): The size of the K-mers. 39 | normalize (bool): Whether to normalize the K-mer counts. 
40 | """ 41 | super().__init__(metadata_folder=metadata_folder, wt=None) 42 | if k < 1: 43 | raise ValueError("K must be a positive integer.") 44 | if not isinstance(k, int): 45 | raise ValueError("K must be an integer.") 46 | 47 | self.k = k 48 | self.normalize = normalize 49 | self._kmer_to_index = {} 50 | 51 | def _fit(self, X: ProteinSequences, y: Optional[np.ndarray] = None) -> 'KmerEmbedding': 52 | """ 53 | Fit the K-mer embedding model. 54 | 55 | This method creates a mapping of K-mers to indices. 56 | 57 | Args: 58 | X (ProteinSequences): The input protein sequences. 59 | y (Optional[np.ndarray]): Ignored. Present for API consistency. 60 | 61 | Returns: 62 | KmerEmbedding: The fitted model. 63 | """ 64 | if len(X) == 0: 65 | raise ValueError("Cannot fit KmerEmbedding with no sequences.") 66 | unique_kmers = set() 67 | for seq in X: 68 | seq_str = str(seq).upper().replace('-', '') # Remove gaps 69 | if len(seq_str) < self.k: 70 | raise ValueError(f"Sequence {seq.id} is too short for K={self.k}.") 71 | unique_kmers.update(seq_str[i:i+self.k] for i in range(len(seq_str) - self.k + 1)) 72 | 73 | self._kmer_to_index = {kmer: i for i, kmer in enumerate(sorted(unique_kmers))} 74 | self.n_features_ = len(self._kmer_to_index) 75 | self.fitted_ = True 76 | return self 77 | 78 | def _transform(self, X: ProteinSequences) -> np.ndarray: 79 | """ 80 | Transform the protein sequences into K-mer embeddings. 81 | 82 | Args: 83 | X (ProteinSequences): The input protein sequences. 84 | 85 | Returns: 86 | np.ndarray: The K-mer embeddings for the sequences. 
87 | """ 88 | embeddings = np.zeros((len(X), self.n_features_), dtype=np.float32) 89 | 90 | for i, seq in enumerate(X): 91 | seq_str = str(seq).upper().replace('-', '') # Remove gaps 92 | if len(seq_str) < self.k: 93 | raise ValueError(f"Sequence {seq.id} is too short for K={self.k}.") 94 | kmer_counts = defaultdict(int) 95 | for j in range(len(seq_str) - self.k + 1): 96 | kmer = seq_str[j:j+self.k] 97 | if kmer in self._kmer_to_index: 98 | kmer_counts[self._kmer_to_index[kmer]] += 1 99 | 100 | for idx, count in kmer_counts.items(): 101 | embeddings[i, idx] = count 102 | 103 | if self.normalize: 104 | row_sums = embeddings.sum(axis=1) 105 | embeddings = embeddings / row_sums[:, np.newaxis] 106 | 107 | return embeddings 108 | 109 | def get_feature_names_out(self, input_features: Optional[List[str]] = None) -> List[str]: 110 | """ 111 | Get output feature names for transformation. 112 | 113 | Args: 114 | input_features (Optional[List[str]]): Ignored. Present for API consistency. 115 | 116 | Returns: 117 | List[str]: Output feature names. 
118 | """ 119 | return [f"kmer_{kmer}" for kmer in self._kmer_to_index.keys()] -------------------------------------------------------------------------------- /tests/data/some_sequences.fasta: -------------------------------------------------------------------------------- 1 | >PET1 2 | MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGA 3 | IAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWS 4 | MGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCA 5 | NSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEHHHHHH 6 | >PET2 7 | ANPYERGPNPTDALLEARSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGFV 8 | VITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGSLRLASQRPDLKAAIPLTPWHLNKN 9 | WSSVTVPTLIIGADLDTIAPVATHAKPFYNSLPSSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQFL 10 | CPGPRDGLFGEVEEYRSTCPFYPNSSSVDKLAAALEHHHHHH 11 | >PET3 12 | MDGVLWRVRTAALMAALLALAAWALVWASPSVEAQSNPYQRGPNPTRSALTADGPFSVATYTVSRLSVSGFGGGVIYYPT 13 | GTSLTFGGIAMSPGYTADASSLAWLGRRLASHGFVVLVINTNSRFDYPDSRASQLSAALNYLRTSSPSAVRARLDANRLA 14 | VAGHSMGGGGTLRIAEQNPSLKAAVPLTPWHTDKTFNTSVPVLIVGAEADTVAPVSQHAIPFYQNLPSTTPKVYVELDNA 15 | SHFAPNSNNAAISVYTISWMKLWVDNDTRYRQFLCNVNDPALSDFRTNNRHCQ 16 | >PET4 17 | LPTSNPAQELEARQLGRTTRDDLINGNSASCADVIFIYARGSTETGNLGTLGPSIASNLESAFGKDGVWIQGVGGAYRAT 18 | LGDNALPRGTSSAAIREMLGLFQQANTKCPDATLIAGGYSQGAALAAASIEDLDSAIRDKIAGTVLFGYTKNLQNRGRIP 19 | NYPADRTKVFCNTGDLVCTGSLIVAAPHLAYGPDARGPAPEFLIEKVRAVRGSA 20 | >PET5 21 | MANPYERGPNPTDALLEASSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGF 22 | VVITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGTLRLASQRPDLKAAIPLTPWHLNK 23 | NWSSVTVPTLIIGADLDTIAPVATHAKPFYNSLPSSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 24 | LCPGPRDGLFGEVEEYRSTCPFALE 25 | >PET6 26 | MANPYERGPNPTDALLEARSGPFSVSEERASRFGADGFGGGTIYYPRENNTYGAVAISPGYTGTQASVAWLGERIASHGF 27 | VVITIDTNTTLDQPDSRARQLNAALDYMINDASSAVRSRIDSSRLAVMGHSMGGGGTLRLASQRPDLKAAIPLTPWHLNK 28 | 
NWSSVRVPTLIIGADLDTIAPVLTHARPFYNSLPTSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 29 | LCPGPRDGLFGEVEEYRSTCPFALE 30 | >PET7 31 | MANPYERGPNPTDALLEARSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGF 32 | VVITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGSLRLASQRPDLKAAIPLTPWHLNK 33 | NWSSVRVPTLIIGADLDTIAPVLTHARPFYNSLPTSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 34 | LCPGPRDGLFGEVEEYRSTCPF 35 | >PET8 36 | MANPYERGPNPTDALLEASSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGGRIASHGF 37 | VVITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGTPRLASQRPDLKAAIPLTPWHLNK 38 | NRSSVTVPTLIIGADLDTIAPVATHAKPFYNSLPSSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 39 | LCPGPRDGLFGEVEEYCSTCPF 40 | >PET9 41 | MANPYERGPNPTNSSIEALRGPFRVDEERVSRLQARGFGGGTIYYPTDNNTFGAVAISPGYTGTQSSISWLGERLASHGF 42 | VVMTIDTNTTLDQPDSRASQLDAALDYMVEDSSYSVRNRIDSSRLAAMGHSMGGGGTLRLAERRPDLQAAIPLTPWHTDK 43 | TWGSVRVPTLIIGAENDTIASVRSHSEPFYNSLPGSLDKAYLELDGASHFAPNLSNTTIAKYSISWLKRFVDDDTRYTQF 44 | LCPGPSTGWGSDVEEYRSTCPF 45 | >PET11 46 | QLGAIENGLESGSANACPDAILIFARGSTEPGNMGITVGPALANGLESHIRNIWIQGVGGPYDAALATNFLPRGTSQANI 47 | DEGKRLFALANQKCPNTPVVAGGYSQGAALIAAAVSELSGAVKEQVKGVALFGYTQNLQNRGGIPNYPRERTKVFCNVGD 48 | AVCTGTLIITPAHLSYTIEARGEAARFLRDRIRA 49 | >PET12 50 | MTHQIVTTQYGKVKGTTENGVHKWKGIPYAKPPVGQWRFKAPEPPEVWEDVLDATAYGPICPQPSDLLSLSYTELPRQSE 51 | DCLYVNVFAPDTPSQNLPVMVWIHGGAFYLGAGSEPLYDGSKLAAQGEVIVVTLNYRLGPFGFLHLSSFNEAYSDNLGLL 52 | DQAAALKWVRENISAFGGDPDNVTVFGESAGGMSIAALLAMPAAKGLFQKAIMESGASRTMTKEQAASTSAAFLQVLGIN 53 | EGQLDKLHTVSAEDSLKAADQLRIAEKENIFQLFFQPALDPKTLPEEPEKAIAEGAASGIPLLIGTTRDEGYLFFTPDSD 54 | VHSQETLDAALEYLLGKPLAEKAADLYPRSLESQIHMMTDLLFWRPAVAYASAQSHYAPVWMYRFDWHPKKPPYNKAFHA 55 | LELPFVFGNLDGLERMAKAEITDEVKQLSHTIQSAWITFAKTGNPSTEAVNWPAYHEETRETLILDSEITIENDPESEKR 56 | QKLFPSKGE 57 | >PET13 58 | MSLRKSFGLLSATAALVAGLVAAPPAQAAANPYQRGPDPTESLLRAARGPFAVSEQSVSRLSVSGFGGGRIYYPTTTSQG 59 | TFGAIAISPGFTASWSSLAWLGPRLASHGFVVIGIETNTRLDQPDSRGRQLLAALDYLTQRSSVRNRVDASRLAVAGHSM 60 | 
GGGGTLEAAKSRTSLKAAIPIAPWNLDKTWPEVRTPTLIIGGELDSIAPVATHSIPFYNSLTNAREKAYLELNNASHFFP 61 | QFSNDTMAKFMISWMKRFIDDDTRYDQFLCPPPRAIGDISDYRDTCPHT 62 | >PET14 63 | MPITARNTLASLLLASSALLLSGTAFAANPPGGDPDPGCQTDCNYQRGPDPTDAYLEAASGPYTVSTIRVSSLVPGFGGG 64 | TIHYPTNAGGGKMAGIVVIPGYLSFESSIEWWGPRLASHGFVVMTIDTNTIYDQPSQRRDQIEAALQYLVNQSNSSSSPI 65 | SGMVDSSRLAAVGWSMGGGGTLQLAADGGIKAAIALAPWNSSINDFNRIQVPTLIFACQLDAIAPVALHASPFYNRIPNT 66 | TPKAFFEMTGGDHWCANGGNIYSALLGKYGVSWMKLHLDQDTRYAPFLCGPNHAAQTLISEYRGNCPY 67 | >PET15 68 | MNKSILKKLSFGTSVLLVSMNALSWTPSPTPNPDPIPDPTPCQDDCDFTRGPNPTPSSLEASTGPYSVATRSVASSVSGF 69 | GGGTLHYPTNTTGTMGAIAVVPGFLLQESSIDFWGPKLASHGFVVITISANSGFDQPASRATQLGRALDYVINQSNGSNS 70 | PISGMVDTTRLGVVGWSMGGGGALQLASGDRLSAAIPIAPWNQGGNRFDQIETPTLVIACENDVVASVNSHASPFYNRIP 71 | STTDKAYLEINGGSHFCANDGGSIGGLLGKYGVSWMKRFIDNDLRYDAFLCGPDHAANRSVSEYRDTCNY 72 | >PET16 73 | MNVLTKCKLALGIIAIFFSLPSFAVPCSDCSNGFERGQVPRVDQLESSRGPYSVKTINVSRLARGFGGGTIHYSTESGGQ 74 | QGIIAVVPGYVSLEGSIKWWGPRLASWGFTVITIDTNTIYDQPDSRASQLSAAIDYVIDKGNDRSSPIYGLVDPNRVGVI 75 | GWSMGGGGSLKLATDRKIDAVIPQAPWYLGLSRFSSITSPTMIIACQADVVAPVSVHASRFYNQIPGTTPKAYFEIALGS 76 | HFCANTGYPSEDILGRNGVAWMKRFIDKDERYTQFLCGQNFDSSLRVSEYRDNCSYY 77 | >PET17 78 | MPPDCVLPRRLAAAALLASATLVPLSAAAQTNPYQRGPDPTTRDLEDSRGPFRYASTNVRSPSGYGAGTIYYPTDVSGSV 79 | GAVAVVPGYLARQSSIRWWGPRLASHGFVVITLDTRSTSDQPASRSAQQMAALRQVVALSETRSSPIYGKVDPNRLAVMG 80 | WSMGGGGTLISARDNPSLKAAVPFAPWHNTANFSGVQVPTLVIACENDTVAPISRHASSFYNSFSSSLAKAYLEINNGSH 81 | TCANTGNSNQALIGKYGVAWIKRFVDNDTRYSPFLCGAPHQADLRSSRLSEYRESCPY 82 | -------------------------------------------------------------------------------- /tests/test_utils/test_msa.py: -------------------------------------------------------------------------------- 1 | # tests/test_utils/test_msa.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 8/16/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | ''' 8 | import pytest 9 | import numpy as np 10 | from aide_predict.utils.data_structures import 
ProteinSequences, ProteinSequence
from aide_predict.utils.msa import MSAProcessing

class TestMSAProcessing:
    """Tests for MSAProcessing: gap filtering, focus columns, and sequence re-weighting."""

    @pytest.fixture
    def sample_msa(self):
        # Five aligned 29-residue sequences; seq5 carries a long gap run so it
        # exceeds the 0.2 per-sequence gap threshold configured below.
        sequences = [
            ProteinSequence("MKYKVLPQGTVMKVLPK-TVMKVLPQGTV", id="seq1"),
            ProteinSequence("MKY-VLPQGTV-KVLPKGTVMKVLPQGTV", id="seq2"),
            ProteinSequence("MKYKVL-QGTVMKVLPKGTVMKVLPQGTV", id="seq3"),
            ProteinSequence("MKYKVLPQGTVMKVLPKGTVMKVLPQGTV", id="seq4"),
            ProteinSequence("MKY-VLP------VLPKGTVMKVLPQGTV", id="seq5"),
        ]
        return ProteinSequences(sequences)

    @pytest.fixture
    def msa_processor(self):
        # Processor configured with the thresholds the assertions below rely on.
        return MSAProcessing(
            theta=0.2,
            threshold_sequence_frac_gaps=0.2,
            threshold_focus_cols_frac_gaps=0.3
        )

    def test_initialization(self, msa_processor):
        # Constructor should store the configuration verbatim.
        assert msa_processor.theta == 0.2
        assert msa_processor.threshold_sequence_frac_gaps == 0.2
        assert msa_processor.threshold_focus_cols_frac_gaps == 0.3

    def test_process_with_focus(self, sample_msa, msa_processor):
        processed_msa = msa_processor.process(sample_msa, focus_seq_id="seq1")
        assert len(processed_msa) == 4  # seq5 should be removed due to high gap fraction
        # Weights must be computed and aligned one-to-one with surviving sequences.
        assert processed_msa.weights is not None
        assert len(processed_msa.weights) == 4

    def test_process_without_focus(self, sample_msa, msa_processor):
        processed_msa = msa_processor.process(sample_msa)
        assert len(processed_msa) == 4  # seq5 should be removed due to high gap fraction
        assert processed_msa.weights is not None
        assert len(processed_msa.weights) == 4
        # Check that all sequences are uppercase when no focus is provided
        assert all(seq.isupper() for seq in processed_msa)

    def test_preprocess_msa(self, sample_msa, msa_processor):
        # Exercise the private preprocessing step directly; it expects
        # focus_seq to be set on the processor beforehand.
        msa_processor.focus_seq = sample_msa[0]
        preprocessed_msa = msa_processor._preprocess_msa(sample_msa)
        assert len(preprocessed_msa) == 4  # seq5 should be removed
        # The focus sequence should come back gap-free.
        assert '-' not in str(preprocessed_msa[0])

    def test_get_focus_columns(self, sample_msa, msa_processor):
        focus_cols = msa_processor._get_focus_columns(sample_msa[0])
        assert len(focus_cols) == 29
        # seq1 has exactly one gap, so all but one column is a focus column.
        assert np.sum(focus_cols) == 28

    def test_compute_weights(self, sample_msa, msa_processor):
        weights = msa_processor._compute_weights(sample_msa)
        assert len(weights) == 5
        # Weights should sum to less than 5.0 as some similar sequences are downweighted.
        assert sum(weights) < 5.0

    def test_no_focus_all_uppercase(self, sample_msa, msa_processor):
        processed_msa = msa_processor.process(sample_msa)
        assert all(seq.isupper() for seq in processed_msa)

    def test_with_focus_lowercase_non_focus(self, sample_msa, msa_processor):
        # With a focus sequence, non-focus columns are rendered lowercase
        # (a2m-style output), so some lowercase characters must appear.
        processed_msa = msa_processor.process(sample_msa, focus_seq_id="seq1")
        assert any(not char.isupper() for seq in processed_msa for char in str(seq))

    def test_remove_high_gap_sequences(self):

        sequences = [
            ProteinSequence("MKYKVLPQGTV", id="seq1"),
            ProteinSequence("MKY-V-PQG-V", id="seq2"),  # 3/11 ≈ 27% gaps, above the 0.2 threshold
            ProteinSequence("MKYKVLPQGTV", id="seq3"),
        ]
        msa = ProteinSequences(sequences)
        msa_processor = MSAProcessing(threshold_sequence_frac_gaps=0.2)
        processed_msa = msa_processor.process(msa)
        assert len(processed_msa) == 2
        assert "seq2" not in [seq.id for seq in processed_msa]

    def test_weight_computation_batch_size(self, sample_msa):
        # Batched weight computation must still yield one weight per sequence.
        msa_processor = MSAProcessing(weight_computation_batch_size=2)
        weights = msa_processor._compute_weights(sample_msa)
        assert len(weights) == 5

    def test_indeterminate_aa_removal(self):
        # Sequences with indeterminate residues ('X') in focus columns are dropped.
        sequences = [
            ProteinSequence("MKYKVLPQGTV", id="seq1"),
            ProteinSequence("MKYXVLPQGTV", id="seq2"),  # Contains 'X'
            ProteinSequence("MKYKVLPQGTV", id="seq3"),
        ]
        msa = ProteinSequences(sequences)
        msa_processor = 
MSAProcessing(remove_sequences_with_indeterminate_aa_in_focus_cols=True) 102 | processed_msa = msa_processor.process(msa) 103 | assert len(processed_msa) == 2 104 | assert "seq2" not in [seq.id for seq in processed_msa] -------------------------------------------------------------------------------- /tests/data/hmm-17.fa: -------------------------------------------------------------------------------- 1 | >PET1 2 | MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGA 3 | IAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWS 4 | MGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCA 5 | NSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEHHHHHH 6 | >PET2 7 | ANPYERGPNPTDALLEARSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGFV 8 | VITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGSLRLASQRPDLKAAIPLTPWHLNKN 9 | WSSVTVPTLIIGADLDTIAPVATHAKPFYNSLPSSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQFL 10 | CPGPRDGLFGEVEEYRSTCPFYPNSSSVDKLAAALEHHHHHH 11 | >PET3 12 | MDGVLWRVRTAALMAALLALAAWALVWASPSVEAQSNPYQRGPNPTRSALTADGPFSVATYTVSRLSVSGFGGGVIYYPT 13 | GTSLTFGGIAMSPGYTADASSLAWLGRRLASHGFVVLVINTNSRFDYPDSRASQLSAALNYLRTSSPSAVRARLDANRLA 14 | VAGHSMGGGGTLRIAEQNPSLKAAVPLTPWHTDKTFNTSVPVLIVGAEADTVAPVSQHAIPFYQNLPSTTPKVYVELDNA 15 | SHFAPNSNNAAISVYTISWMKLWVDNDTRYRQFLCNVNDPALSDFRTNNRHCQ 16 | >PET4 17 | LPTSNPAQELEARQLGRTTRDDLINGNSASCADVIFIYARGSTETGNLGTLGPSIASNLESAFGKDGVWIQGVGGAYRAT 18 | LGDNALPRGTSSAAIREMLGLFQQANTKCPDATLIAGGYSQGAALAAASIEDLDSAIRDKIAGTVLFGYTKNLQNRGRIP 19 | NYPADRTKVFCNTGDLVCTGSLIVAAPHLAYGPDARGPAPEFLIEKVRAVRGSA 20 | >PET5 21 | MANPYERGPNPTDALLEASSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGF 22 | VVITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGTLRLASQRPDLKAAIPLTPWHLNK 23 | NWSSVTVPTLIIGADLDTIAPVATHAKPFYNSLPSSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 24 | LCPGPRDGLFGEVEEYRSTCPFALE 25 | >PET6 26 | 
MANPYERGPNPTDALLEARSGPFSVSEERASRFGADGFGGGTIYYPRENNTYGAVAISPGYTGTQASVAWLGERIASHGF 27 | VVITIDTNTTLDQPDSRARQLNAALDYMINDASSAVRSRIDSSRLAVMGHSMGGGGTLRLASQRPDLKAAIPLTPWHLNK 28 | NWSSVRVPTLIIGADLDTIAPVLTHARPFYNSLPTSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 29 | LCPGPRDGLFGEVEEYRSTCPFALE 30 | >PET7 31 | MANPYERGPNPTDALLEARSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGERIASHGF 32 | VVITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGSLRLASQRPDLKAAIPLTPWHLNK 33 | NWSSVRVPTLIIGADLDTIAPVLTHARPFYNSLPTSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 34 | LCPGPRDGLFGEVEEYRSTCPF 35 | >PET8 36 | MANPYERGPNPTDALLEASSGPFSVSEENVSRLSASGFGGGTIYYPRENNTYGAVAISPGYTGTEASIAWLGGRIASHGF 37 | VVITIDTITTLDQPDSRAEQLNAALNHMINRASSTVRSRIDSSRLAVMGHSMGGGGTPRLASQRPDLKAAIPLTPWHLNK 38 | NRSSVTVPTLIIGADLDTIAPVATHAKPFYNSLPSSISKAYLELDGATHFAPNIPNKIIGKYSVAWLKRFVDNDTRYTQF 39 | LCPGPRDGLFGEVEEYCSTCPF 40 | >PET9 41 | MANPYERGPNPTNSSIEALRGPFRVDEERVSRLQARGFGGGTIYYPTDNNTFGAVAISPGYTGTQSSISWLGERLASHGF 42 | VVMTIDTNTTLDQPDSRASQLDAALDYMVEDSSYSVRNRIDSSRLAAMGHSMGGGGTLRLAERRPDLQAAIPLTPWHTDK 43 | TWGSVRVPTLIIGAENDTIASVRSHSEPFYNSLPGSLDKAYLELDGASHFAPNLSNTTIAKYSISWLKRFVDDDTRYTQF 44 | LCPGPSTGWGSDVEEYRSTCPF 45 | >PET10 46 | XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXMRGSHHHHHHGSNPYERGPDPTEDSIEAIRGPFSVATERVSSFASG 47 | FGGGTIYYPRETDEGTFGAVAVAPGFTASQGSMSWYGERVASQGFIVFTIDTNTRLDQPGQRGRQLLAALDYLVERSDRK 48 | VRERLDPNRLAVMGHSMGGGGSLEATVMRPSLKASIPLTPWNLDKTWGQVQVPTFIIGAELDTIAPVRTHAKPFYESLPS 49 | SLPKAYMELDGATHFAPNIPNTTIAKYVISWLKRFVDEDTRYSQFLCPNPTDRAIEEYRSTCPYKLN 50 | >PET11 51 | QLGAIENGLESGSANACPDAILIFARGSTEPGNMGITVGPALANGLESHIRNIWIQGVGGPYDAALATNFLPRGTSQANI 52 | DEGKRLFALANQKCPNTPVVAGGYSQGAALIAAAVSELSGAVKEQVKGVALFGYTQNLQNRGGIPNYPRERTKVFCNVGD 53 | AVCTGTLIITPAHLSYTIEARGEAARFLRDRIRA 54 | >PET12 55 | MTHQIVTTQYGKVKGTTENGVHKWKGIPYAKPPVGQWRFKAPEPPEVWEDVLDATAYGPICPQPSDLLSLSYTELPRQSE 56 | DCLYVNVFAPDTPSQNLPVMVWIHGGAFYLGAGSEPLYDGSKLAAQGEVIVVTLNYRLGPFGFLHLSSFNEAYSDNLGLL 57 | 
DQAAALKWVRENISAFGGDPDNVTVFGESAGGMSIAALLAMPAAKGLFQKAIMESGASRTMTKEQAASTSAAFLQVLGIN 58 | EGQLDKLHTVSAEDSLKAADQLRIAEKENIFQLFFQPALDPKTLPEEPEKAIAEGAASGIPLLIGTTRDEGYLFFTPDSD 59 | VHSQETLDAALEYLLGKPLAEKAADLYPRSLESQIHMMTDLLFWRPAVAYASAQSHYAPVWMYRFDWHPKKPPYNKAFHA 60 | LELPFVFGNLDGLERMAKAEITDEVKQLSHTIQSAWITFAKTGNPSTEAVNWPAYHEETRETLILDSEITIENDPESEKR 61 | QKLFPSKGE 62 | >PET13 63 | MSLRKSFGLLSATAALVAGLVAAPPAQAAANPYQRGPDPTESLLRAARGPFAVSEQSVSRLSVSGFGGGRIYYPTTTSQG 64 | TFGAIAISPGFTASWSSLAWLGPRLASHGFVVIGIETNTRLDQPDSRGRQLLAALDYLTQRSSVRNRVDASRLAVAGHSM 65 | GGGGTLEAAKSRTSLKAAIPIAPWNLDKTWPEVRTPTLIIGGELDSIAPVATHSIPFYNSLTNAREKAYLELNNASHFFP 66 | QFSNDTMAKFMISWMKRFIDDDTRYDQFLCPPPRAIGDISDYRDTCPHT 67 | >PET14 68 | MPITARNTLASLLLASSALLLSGTAFAANPPGGDPDPGCQTDCNYQRGPDPTDAYLEAASGPYTVSTIRVSSLVPGFGGG 69 | TIHYPTNAGGGKMAGIVVIPGYLSFESSIEWWGPRLASHGFVVMTIDTNTIYDQPSQRRDQIEAALQYLVNQSNSSSSPI 70 | SGMVDSSRLAAVGWSMGGGGTLQLAADGGIKAAIALAPWNSSINDFNRIQVPTLIFACQLDAIAPVALHASPFYNRIPNT 71 | TPKAFFEMTGGDHWCANGGNIYSALLGKYGVSWMKLHLDQDTRYAPFLCGPNHAAQTLISEYRGNCPY 72 | >PET15 73 | MNKSILKKLSFGTSVLLVSMNALSWTPSPTPNPDPIPDPTPCQDDCDFTRGPNPTPSSLEASTGPYSVATRSVASSVSGF 74 | GGGTLHYPTNTTGTMGAIAVVPGFLLQESSIDFWGPKLASHGFVVITISANSGFDQPASRATQLGRALDYVINQSNGSNS 75 | PISGMVDTTRLGVVGWSMGGGGALQLASGDRLSAAIPIAPWNQGGNRFDQIETPTLVIACENDVVASVNSHASPFYNRIP 76 | STTDKAYLEINGGSHFCANDGGSIGGLLGKYGVSWMKRFIDNDLRYDAFLCGPDHAANRSVSEYRDTCNY 77 | >PET16 78 | MNVLTKCKLALGIIAIFFSLPSFAVPCSDCSNGFERGQVPRVDQLESSRGPYSVKTINVSRLARGFGGGTIHYSTESGGQ 79 | QGIIAVVPGYVSLEGSIKWWGPRLASWGFTVITIDTNTIYDQPDSRASQLSAAIDYVIDKGNDRSSPIYGLVDPNRVGVI 80 | GWSMGGGGSLKLATDRKIDAVIPQAPWYLGLSRFSSITSPTMIIACQADVVAPVSVHASRFYNQIPGTTPKAYFEIALGS 81 | HFCANTGYPSEDILGRNGVAWMKRFIDKDERYTQFLCGQNFDSSLRVSEYRDNCSYY 82 | >PET17 83 | MPPDCVLPRRLAAAALLASATLVPLSAAAQTNPYQRGPDPTTRDLEDSRGPFRYASTNVRSPSGYGAGTIYYPTDVSGSV 84 | GAVAVVPGYLARQSSIRWWGPRLASHGFVVITLDTRSTSDQPASRSAQQMAALRQVVALSETRSSPIYGKVDPNRLAVMG 85 | WSMGGGGTLISARDNPSLKAAVPFAPWHNTANFSGVQVPTLVIACENDTVAPISRHASSFYNSFSSSLAKAYLEINNGSH 86 | 
TCANTGNSNQALIGKYGVAWIKRFVDNDTRYSPFLCGAPHQADLRSSRLSEYRESCPY 87 | -------------------------------------------------------------------------------- /docs/user_guide/model_compatibility.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Model Compatibility 3 | --- 4 | 5 | # Model Compatibility 6 | 7 | ## Understanding Model Requirements 8 | 9 | AIDE models have different requirements and capabilities that determine whether they can be used with your data. Key considerations include: 10 | 11 | - Whether the model requires training data (supervised vs zero-shot) 12 | - Whether sequences must be aligned or of fixed length 13 | - Whether the model requires a Multiple Sequence Alignment (MSA) 14 | - Whether the model requires or can use structural information 15 | - Whether the model needs a wild-type sequence for comparison 16 | - Whether the model can handle variable-length sequences 17 | 18 | ## Checking Model Compatibility 19 | 20 | AIDE provides a utility function to check which models are compatible with your data: 21 | 22 | ```python 23 | from aide_predict.utils.checks import check_model_compatibility 24 | from aide_predict import ProteinSequences, ProteinSequence 25 | 26 | # Example setup 27 | sequences = ProteinSequences.from_fasta("my_sequences.fasta") 28 | msa = ProteinSequences.from_fasta("family_msa.fasta") 29 | wt = ProteinSequence("MKLLVLGLPGAGKGT", id="wild_type") 30 | wt.msa = msa 31 | 32 | # Check compatibility 33 | compatibility = check_model_compatibility( 34 | training_sequences=sequences, # Optional: sequences for supervised learning 35 | testing_sequences=None, # Optional: test sequences if different from training 36 | wt=wt # Optional: wild-type sequence, may have structure, MSA 37 | ) 38 | 39 | print("Compatible models:", compatibility["compatible"]) 40 | print("Incompatible models:", compatibility["incompatible"]) 41 | ``` 42 | 43 | The compatibility checker performs several validation steps: 44 
| - Verifies if structure information is available (either in sequences or wild-type) 45 | - Checks if MSAs are available and properly aligned 46 | - Validates that sequence lengths match requirements 47 | - Ensures wild-type sequences are available when needed 48 | - Verifies that per-sequence MSAs match sequence lengths when required 49 | 50 | You can also check which tools are available in your current installation: 51 | 52 | ```python 53 | from aide_predict.utils.checks import get_supported_tools 54 | print(get_supported_tools()) 55 | ``` 56 | 57 | ## Model Categories 58 | 59 | AIDE models fall into several categories: 60 | 61 | ### 1. Zero-Shot Predictors 62 | These models don't require training data but may have other requirements: 63 | 64 | ```python 65 | # ESM2 - Requires only sequences 66 | from aide_predict import ESM2LikelihoodWrapper 67 | model = ESM2LikelihoodWrapper(wt=wt) 68 | model.fit() # No training needed 69 | scores = model.predict(sequences) 70 | 71 | # MSATransformer - Requires MSA for the WT 72 | from aide_predict import MSATransformerLikelihoodWrapper 73 | model = MSATransformerLikelihoodWrapper(wt=wt) 74 | model.fit() 75 | scores = model.predict(sequences) 76 | 77 | # SaProt - Can use structural information 78 | from aide_predict import SaProtLikelihoodWrapper 79 | model = SaProtLikelihoodWrapper(wt=wt) # wt must have structure 80 | model.fit() 81 | scores = model.predict(sequences) # Will use structure if available 82 | ``` 83 | 84 | Other zero-shot predictors include: 85 | - **HMM**: Creates Hidden Markov Models from MSAs 86 | - **EVMutation**: Uses evolutionary couplings from MSAs 87 | - **VESPA**: Pre-trained model for variant effect prediction 88 | - **EVE**: Evolutionary model using latent space representations 89 | - **SSEmb**: Structure and sequence-based variant effect predictor 90 | 91 | ### 2. 
Embedding Models 92 | These models convert sequences into numerical features for downstream ML: 93 | 94 | ```python 95 | # Simple one-hot encoding 96 | from aide_predict import OneHotProteinEmbedding 97 | embedder = OneHotProteinEmbedding() 98 | X = embedder.fit_transform(sequences) 99 | 100 | # Advanced language model embeddings 101 | from aide_predict import ESM2Embedding 102 | embedder = ESM2Embedding(pool=True) # pool=True for sequence-level embeddings 103 | X = embedder.fit_transform(sequences) 104 | 105 | # K-mer based embeddings 106 | from aide_predict import KmerEmbedding 107 | embedder = KmerEmbedding(k=3) 108 | X = embedder.fit_transform(sequences) 109 | ``` 110 | 111 | Other embedding models include: 112 | - **MSATransformerEmbedding**: Produces embeddings using MSAs 113 | - **SaProtEmbedding**: Structure-aware protein language model embeddings 114 | - **OneHotAlignedEmbedding**: One-hot encodings for aligned sequences 115 | 116 | ## Importance of Data Structure 117 | 118 | The compatibility of models depends heavily on the structure of your data: 119 | 120 | | Data Characteristic | Compatible Models | Incompatible Models | 121 | |---------------------|-------------------|---------------------| 122 | | Fixed-length sequences | All models | - | 123 | | Variable-length sequences | Models without `RequiresFixedLengthMixin` | Models with `RequiresFixedLengthMixin` | 124 | | Has MSA | All models | - | 125 | | No MSA | Models without MSA requirements | MSATransformer, EVMutation, EVE | 126 | | Has structure | All models | - | 127 | | No structure | Models without structure requirements | SaProt, SSEmb | 128 | | Has wild-type | All models | - | 129 | | No wild-type | Models without WT requirements | Models with `RequiresWTToFunctionMixin` | 130 | 131 | Using the appropriate data structure for your specific modeling task ensures that AIDE can provide the most accurate predictions. 
# if the modeling you want to do does not require an MSA you can skip those steps
Note that not all models support this 34 | # it is also incompatible with training and testing data that are not fixed length. 35 | # if you do, a list of positions can be passed eg [3, 4, 5] 1 indexed
Preprocessing of MSA for model training 81 | ########################################################### 82 | msaprocessing: 83 | theta: 0.2 84 | use_weights: true 85 | preprocess: true 86 | threshold_sequence_frac_gaps: 0.5 87 | threshold_focus_cols_frac_gaps: 0.5 88 | remove_sequences_with_indeterminate_AA_in_focus_cols: true 89 | additional_weights: false 90 | 91 | ########################################################### 92 | # Training and prediction 93 | ########################################################### 94 | modeling: 95 | protein_models: 96 | # these are models that use protein sequences directly as input 97 | # Eg. HMM score, ESM log likelihood 98 | # If supervised is used, these will be used as input to the supervised model 99 | models: ['HMM'] 100 | models_kwargs: '{"HMM": {}}' 101 | supervised: 102 | type: regression 103 | model: 'KernelRidge' 104 | model_kwargs: '{alpha: 5.0}' 105 | embeddings: ['ESM'] 106 | embeddings_kwargs: '{"ESM": {}}' 107 | run_pca: true 108 | min_pca_variance_explained: 0.95 109 | standardize_X: true 110 | scale_y: true 111 | 112 | ########################################################### 113 | # If training data given (zero only does not require it) 114 | ########################################################### 115 | validation: 116 | do_cv: false 117 | cv: 5 118 | split_type: 'random' 119 | 120 | ########################################################### 121 | # Prediction of variants 122 | ########################################################### 123 | # Requires `data/variants.csv` to be not empty 124 | # must have columns `id`, `sequence` 125 | 126 | # Whether you are comparing the values to the input wild type or not 127 | prediction_mode: 128 | compare_to_wt: false 129 | -------------------------------------------------------------------------------- /external_calls/eve/_compute_evol_indices_one.py: -------------------------------------------------------------------------------- 1 | # 
external_calls/eve/_compute_evol_indices_one.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 10/28/2024 5 | * (c) Copyright by Bottle Institute @ National Renewable Energy Lab, Bioeneergy Science and Technology 6 | 7 | 8 | Wrapper script for EVE. Expects to be called from EVE environment. 9 | 10 | Adapted from Script by P. Notin: https://github.com/OATML/EVE/tree/master 11 | ''' 12 | import os 13 | import sys 14 | import json 15 | import argparse 16 | import pandas as pd 17 | import torch 18 | 19 | # Add EVE repo to Python path 20 | eve_repo = os.environ.get('EVE_REPO') 21 | if eve_repo: 22 | sys.path.insert(0, eve_repo) 23 | 24 | from EVE import VAE_model 25 | from utils import data_utils 26 | 27 | if __name__=='__main__': 28 | parser = argparse.ArgumentParser(description='Compute evolutionary indices using EVE') 29 | 30 | # Core arguments for data and model 31 | parser.add_argument('--msa_file', type=str, required=True, 32 | help='Path to the MSA file') 33 | parser.add_argument('--model_name', type=str, required=True, 34 | help='Name for the model') 35 | parser.add_argument('--model_parameters', type=str, required=True, 36 | help='Path to JSON file containing model parameters') 37 | parser.add_argument('--checkpoint', type=str, required=True, 38 | help='Path to trained model checkpoint') 39 | 40 | # Processing parameters 41 | parser.add_argument('--theta_reweighting', type=float, default=0.2, 42 | help='Parameter for MSA sequence re-weighting (default: 0.2)') 43 | parser.add_argument('--weights_folder', type=str, default='weights', 44 | help='Folder to store sequence weights (default: weights/)') 45 | 46 | # Mutation analysis parameters 47 | parser.add_argument('--computation_mode', type=str, choices=['all_singles', 'input_mutations_list'], required=True, 48 | help='Compute indices for all single mutations or from input list') 49 | parser.add_argument('--mutations_file', type=str, 50 | help='Path to CSV file containing mutations to analyze (required if mode is 
input_mutations_list)') 51 | parser.add_argument('--output_folder', type=str, default='results', 52 | help='Folder to store results (default: results/)') 53 | parser.add_argument('--output_suffix', type=str, default='', 54 | help='Suffix to add to output filename') 55 | 56 | # Computational parameters 57 | parser.add_argument('--num_samples', type=int, default=10, 58 | help='Number of samples to approximate delta ELBO (default: 10)') 59 | parser.add_argument('--batch_size', type=int, default=256, 60 | help='Batch size for computing indices (default: 256)') 61 | 62 | args = parser.parse_args() 63 | 64 | # Create necessary directories 65 | os.makedirs(args.weights_folder, exist_ok=True) 66 | os.makedirs(args.output_folder, exist_ok=True) 67 | 68 | print(f"Processing MSA file: {args.msa_file}") 69 | print(f"Using theta={args.theta_reweighting} for MSA re-weighting") 70 | print(f"Model name: {args.model_name}") 71 | 72 | # Process MSA data 73 | weights_path = os.path.join(args.weights_folder, 74 | f'{args.model_name}_theta_{args.theta_reweighting}.npy') 75 | data = data_utils.MSA_processing( 76 | MSA_location=args.msa_file, 77 | theta=args.theta_reweighting, 78 | use_weights=True, 79 | weights_location=weights_path 80 | ) 81 | 82 | # Handle mutations based on computation mode 83 | if args.computation_mode == "all_singles": 84 | mutations_file = os.path.join(args.output_folder, f"{args.model_name}_all_singles.csv") 85 | data.save_all_singles(output_filename=mutations_file) 86 | else: 87 | if not args.mutations_file: 88 | raise ValueError("mutations_file must be provided when using input_mutations_list mode") 89 | mutations_file = args.mutations_file 90 | 91 | # Load model parameters and initialize model 92 | try: 93 | with open(args.model_parameters, 'r') as f: 94 | model_params = json.load(f) 95 | except (json.JSONDecodeError, FileNotFoundError) as e: 96 | print(f"Error loading model parameters: {str(e)}") 97 | sys.exit(1) 98 | 99 | model = VAE_model.VAE_model( 100 | 
model_name=args.model_name, 101 | data=data, 102 | encoder_parameters=model_params["encoder_parameters"], 103 | decoder_parameters=model_params["decoder_parameters"], 104 | random_seed=42 105 | ) 106 | if str(model.device) == 'cpu' and torch.backends.mps.is_available(): 107 | model.device = 'mps' 108 | model.encoder.device = 'mps' 109 | model.decoder.device = 'mps' 110 | model = model.to(model.device) 111 | 112 | # Load model checkpoint 113 | try: 114 | checkpoint = torch.load(args.checkpoint) 115 | model.load_state_dict(checkpoint['model_state_dict']) 116 | print(f"Initialized VAE with checkpoint: {args.checkpoint}") 117 | except Exception as e: 118 | print(f"Error loading model checkpoint: {str(e)}") 119 | sys.exit(1) 120 | 121 | # Compute evolutionary indices 122 | list_valid_mutations, evol_indices, _, _ = model.compute_evol_indices( 123 | msa_data=data, 124 | list_mutations_location=mutations_file, 125 | num_samples=args.num_samples, 126 | batch_size=args.batch_size 127 | ) 128 | 129 | # Create results dataframe 130 | results = pd.DataFrame({ 131 | 'protein_name': args.model_name, 132 | 'mutations': list_valid_mutations, 133 | 'evol_indices': evol_indices 134 | }) 135 | 136 | # Save results 137 | output_file = os.path.join( 138 | args.output_folder, 139 | f"{args.model_name}_{args.num_samples}_samples{args.output_suffix}.csv" 140 | ) 141 | 142 | # Append to existing file if it exists and isn't empty 143 | try: 144 | keep_header = os.stat(output_file).st_size == 0 145 | except: 146 | keep_header = True 147 | 148 | results.to_csv(output_file, index=False, mode='a', header=keep_header) 149 | print(f"Results saved to: {output_file}") 150 | 151 | -------------------------------------------------------------------------------- /tests/test_utils/test_conservation.py: -------------------------------------------------------------------------------- 1 | # tests/test_utils/test_conservation.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 2/10/2025 5 | * Company: 
National Renewable Energy Lab, Bioenergy Science and Technology
* License: MIT
'''

import pytest
import numpy as np
from aide_predict.utils.conservation import ConservationAnalysis
from aide_predict.utils.data_structures import ProteinSequence, ProteinSequences

class TestConservationAnalysis:
    """Unit tests for ConservationAnalysis over small hand-built alignments.

    The fixture alignment is constructed so that the first columns are
    hydrophobic-dominated, one column is mostly gaps, and the trailing
    columns are charged-dominated, giving known conservation patterns
    that the assertions below can check against.
    """

    @pytest.fixture
    def aligned_sequences(self):
        """Create a small alignment with known conservation patterns."""
        return ProteinSequences([
            ProteinSequence("AILV-KRDE", id="seq1"),  # Hydrophobic -> Charged
            ProteinSequence("VLIL-EKDR", id="seq2"),  # Hydrophobic -> Charged
            ProteinSequence("LIVI-DERK", id="seq3"),  # Hydrophobic -> Charged
            ProteinSequence("IVL-ARDEK", id="seq4"),  # Hydrophobic -> Charged
        ])

    @pytest.fixture
    def unaligned_sequences(self):
        """Create unaligned sequences to test validation."""
        # Different lengths, so ProteinSequences.aligned should be False.
        return ProteinSequences([
            ProteinSequence("AILVKR", id="seq1"),
            ProteinSequence("VLILEKDR", id="seq2"),
        ])

    def test_initialization(self, aligned_sequences, unaligned_sequences):
        """Test initialization with valid and invalid inputs."""
        # Valid initialization
        analyzer = ConservationAnalysis(aligned_sequences)
        assert analyzer.sequences == aligned_sequences
        assert analyzer.ignore_gaps == True

        # Test with unaligned sequences
        with pytest.raises(ValueError, match="Input ProteinSequences must be aligned"):
            ConservationAnalysis(unaligned_sequences)


    def test_compute_conservation(self, aligned_sequences):
        """Test computation of conservation scores."""
        analyzer = ConservationAnalysis(aligned_sequences)
        scores = analyzer.compute_conservation()

        # Check that we got scores for all properties
        assert len(scores) == len(ConservationAnalysis.PROPERTIES)

        # Test that scores are between 0 and 1
        for prop, prop_scores in scores.items():
            assert np.all((prop_scores >= 0) & (prop_scores <= 1))

        # Test specific conservation patterns
        # First 4 positions should be highly hydrophobic
        assert np.mean(scores['Hydrophobic'][:4]) > 0.8
        # Last 4 positions should be highly charged
        assert np.mean(scores['Charged'][-4:]) > 0.8

        # Test gap handling: the gap column must still yield a numeric score
        assert not np.isnan(scores['Hydrophobic'][4])  # Gap position

    def test_compute_significance(self, aligned_sequences):
        """Test computation of statistical significance."""
        analyzer = ConservationAnalysis(aligned_sequences)
        significant_positions, p_values = analyzer.compute_significance(alpha=0.05)

        # Check output shapes: one entry per alignment column
        assert len(significant_positions) == aligned_sequences.width
        assert all(len(pvals) == aligned_sequences.width for pvals in p_values.values())

        # Check p-values are valid
        for prop, pvals in p_values.items():
            assert np.all((pvals >= 0) & (pvals <= 1))

    def test_compare_alignments(self, aligned_sequences):
        """Test comparison between two alignments."""
        # Create a second alignment with different conservation patterns
        second_alignment = ProteinSequences([
            ProteinSequence("KRDE-AILV", id="seq1"),  # Reversed pattern
            ProteinSequence("EKDR-VLIL", id="seq2"),
            ProteinSequence("DERK-LIVI", id="seq3"),
        ])

        differences, p_values = ConservationAnalysis.compare_alignments(
            aligned_sequences,
            second_alignment,
            ignore_gaps=True,
            alpha=0.05
        )

        # Check output shapes
        assert all(len(diff) == aligned_sequences.width for diff in differences.values())
        assert all(len(pvals) == aligned_sequences.width for pvals in p_values.values())

        # Test specific difference patterns
        # Hydrophobic conservation should be higher in first alignment at start
        assert np.mean(differences['Hydrophobic'][:4]) < 0.2
        # Charged conservation should be higher in second alignment at start
        assert np.mean(differences['Charged'][:4]) > 0.5

        # Test with mismatched alignments: widths differ, so comparison must fail
        mismatched_alignment = ProteinSequences([
            ProteinSequence("KRDEAILV", id="seq1"),  # Different length
        ])
        with pytest.raises(ValueError, match="Alignments must have the same length"):
            ConservationAnalysis.compare_alignments(aligned_sequences, mismatched_alignment)

    def test_gap_handling(self, aligned_sequences):
        """Test how gaps are handled in conservation calculations."""
        # Test with and without gap ignoring
        analyzer_ignore_gaps = ConservationAnalysis(aligned_sequences, ignore_gaps=True)
        analyzer_with_gaps = ConservationAnalysis(aligned_sequences, ignore_gaps=False)

        scores_ignore = analyzer_ignore_gaps.compute_conservation()
        scores_with = analyzer_with_gaps.compute_conservation()

        # Scores should be different at gap positions (column 4 of the fixture)
        assert not np.allclose(
            scores_ignore['Hydrophobic'][4],
            scores_with['Hydrophobic'][4]
        )

        # Gap positions should have lower conservation when not ignored,
        # because gaps count against the property fraction.
        assert scores_with['Hydrophobic'][4] < scores_ignore['Hydrophobic'][4]

    def test_inverse_properties(self, aligned_sequences):
        """Test that inverse properties behave correctly."""
        analyzer = ConservationAnalysis(aligned_sequences)
        scores = analyzer.compute_conservation()

        # Test that property and its inverse sum to 1 (approximately)
        for prop in ['Hydrophobic', 'Polar', 'Small', 'Charged']:
            prop_sum = scores[prop] + scores[f'not_{prop}']
            assert np.allclose(prop_sum, 1.0)

# tests/test_not_base_models/test_badass.py
'''
* Author: Evan Komp
* Created: 1/23/2025
* Company: National Renewable Energy
Lab, Bioenergy Science and Technology
* License: MIT
'''
import pytest
import numpy as np
import pandas as pd
from aide_predict.utils.badass import BADASSOptimizer, BADASSOptimizerParams
from aide_predict.utils.data_structures import ProteinSequence, ProteinSequences
from unittest.mock import Mock, patch

class TestBADASSOptimizerParams:
    """Tests for the BADASSOptimizerParams dataclass-style parameter holder."""

    def test_default_initialization(self):
        # Defaults documented by BADASSOptimizerParams.
        params = BADASSOptimizerParams()
        assert params.seqs_per_iter == 500
        assert params.num_iter == 200
        assert params.sites_to_ignore == []  # Tests empty list conversion
        assert params.temperature == 1.5

    def test_custom_initialization(self):
        # User-supplied values must be stored unchanged.
        params = BADASSOptimizerParams(
            seqs_per_iter=100,
            num_iter=50,
            sites_to_ignore=[1, 2, 3],
            temperature=2.0
        )
        assert params.seqs_per_iter == 100
        assert params.num_iter == 50
        assert params.sites_to_ignore == [1, 2, 3]
        assert params.temperature == 2.0

    def test_to_dict(self):
        # to_dict renames fields to the keys the underlying BADASS
        # optimizer expects (e.g. temperature -> 'T').
        params = BADASSOptimizerParams(seqs_per_iter=100, temperature=2.0)
        param_dict = params.to_dict()
        assert param_dict['seqs_per_iter'] == 100
        assert param_dict['T'] == 2.0  # Tests correct key conversion
        assert 'sites_to_ignore' in param_dict
        assert isinstance(param_dict, dict)

class TestBADASSOptimizer:
    """Tests for BADASSOptimizer using a mocked predictor so no real model runs."""

    @pytest.fixture
    def reference_sequence(self):
        """Wild-type sequence that mutations are applied against."""
        return ProteinSequence("MKLLVLGLPGAGKGT", id="wild_type")

    @pytest.fixture
    def mock_predictor(self):
        """Predictor stub returning a constant score per sequence."""
        def predict(sequences):
            # Return predictable values instead of random
            return np.array([0.5] * len(sequences))
        return Mock(predict=Mock(side_effect=predict))

    @pytest.fixture
    def optimizer_params(self):
        """Small iteration counts to keep the test fast."""
        return BADASSOptimizerParams(
            seqs_per_iter=10,
            num_iter=2,
            init_score_batch_size=5
        )

    @pytest.fixture
    def optimizer(self, mock_predictor, reference_sequence, optimizer_params):
        return BADASSOptimizer(
            predictor=mock_predictor,
            reference_sequence=reference_sequence,
            params=optimizer_params
        )

    def test_initialization(self, optimizer, reference_sequence, optimizer_params):
        assert isinstance(optimizer.reference_sequence, ProteinSequence)
        assert str(optimizer.reference_sequence) == str(reference_sequence)
        assert optimizer.params == optimizer_params
        assert hasattr(optimizer, '_optimizer')

    def test_wrapped_predictor(self, optimizer, mock_predictor):
        # _wrapped_predictor adapts the aide-style predictor to the plain
        # list-of-floats interface BADASS expects.
        test_sequences = ["MKLLVLGLPGAGKGT", "AKLLVLGLPGAGKGT"]
        scores = optimizer._wrapped_predictor(test_sequences)
        assert isinstance(scores, list)
        assert len(scores) == len(test_sequences)
        assert all(s == 0.5 for s in scores)  # Check for expected values
        # Don't test number of calls since BADASS uses it internally

    @patch('aide_predict.utils.badass.GeneralProteinOptimizer')
    def test_optimize(self, mock_general_optimizer_class, optimizer):
        # Create mock instance
        mock_optimizer = Mock()
        mock_general_optimizer_class.return_value = mock_optimizer

        # Mock the optimization results
        mock_results = pd.DataFrame({
            'sequences': ['M1A-K2R', 'M1V'],
            'scores': [0.5, 0.6]
        })
        mock_stats = pd.DataFrame({
            'iteration': [1, 2],
            'mean_score': [0.5, 0.6]
        })
        mock_optimizer.optimize.return_value = (mock_results, mock_stats)

        # Replace optimizer's _optimizer with our mock
        optimizer._optimizer = mock_optimizer

        results_df, stats_df = optimizer.optimize()

        assert isinstance(results_df, pd.DataFrame)
        assert isinstance(stats_df, pd.DataFrame)
        # optimize() should expand mutation codes into full sequences
        assert 'full_sequence' in results_df.columns
        assert all(isinstance(seq, ProteinSequence) for seq in results_df['full_sequence'])

    def test_mutations_to_sequence(self, optimizer):
        # Mutation codes are 1-indexed "<wt><pos><mut>" tokens joined by '-'.
        # Test single mutation
        mut_seq = optimizer._mutations_to_sequence("M1A")
        assert mut_seq == "AKLLVLGLPGAGKGT"

        # Test multiple mutations
        mut_seq = optimizer._mutations_to_sequence("M1A-K2R")
        assert mut_seq == "ARLLVLGLPGAGKGT"

        # Test empty mutations
        mut_seq = optimizer._mutations_to_sequence("")
        assert mut_seq == str(optimizer.reference_sequence)

    def test_results_property(self, optimizer):
        # Test before optimization (no df attribute)
        assert optimizer.results == (None, None)

        # Test after optimization (mock df attributes)
        mock_df = pd.DataFrame({'test': [1]})
        mock_df_stats = pd.DataFrame({'stats': [1]})
        optimizer._optimizer.df = mock_df
        optimizer._optimizer.df_stats = mock_df_stats

        results, stats = optimizer.results
        assert results is mock_df
        assert stats is mock_df_stats

    def test_plot_and_save_methods(self, optimizer):
        # Create mock optimizer with required methods
        mock_optimizer = Mock()
        mock_optimizer.plot_scores = Mock()
        mock_optimizer.save_results = Mock()

        # Replace the real optimizer with our mock
        optimizer._optimizer = mock_optimizer

        # Test plot method delegates to the wrapped optimizer
        optimizer.plot(save_figs=True)
        mock_optimizer.plot_scores.assert_called_once_with(save_figs=True)

        # Test save_results method delegates to the wrapped optimizer
        optimizer.save_results("test_file")
        mock_optimizer.save_results.assert_called_once_with(filename="test_file")

# aide_predict/utils/checks.py
'''
* Author: Evan Komp
* Created: 6/13/2024
* Company: Bottle Institute @ National Renewable Energy Lab, Bioenergy Science and Technology
* License: MIT

Common checks to ensure that different pipeline components are
compatable. 9 | ''' 10 | import inspect 11 | from typing import Optional, List, Dict, Type 12 | 13 | import aide_predict 14 | from aide_predict.utils.data_structures import ProteinSequences, ProteinSequence 15 | from aide_predict.bespoke_models.base import ProteinModelWrapper 16 | 17 | def check_model_compatibility( 18 | training_sequences: Optional[ProteinSequences] = None, 19 | testing_sequences: Optional[ProteinSequences] = None, 20 | wt: Optional[ProteinSequence] = None 21 | ) -> Dict[str, List[str]]: 22 | """ 23 | Check which models are compatible with the given data. 24 | 25 | Args: 26 | training_sequences (Optional[ProteinSequences]): Training protein sequences. 27 | testing_sequences (Optional[ProteinSequences]): Testing protein sequences. 28 | wt (Optional[ProteinSequence]): Wild-type protein sequence. 29 | 30 | Returns: 31 | Dict[str, List[str]]: A dictionary with two keys: 'compatible' and 'incompatible', 32 | each containing a list of compatible and incompatible model names respectively. 
33 | """ 34 | def load_models() -> List[Type[ProteinModelWrapper]]: 35 | models = [] 36 | for name, obj in inspect.getmembers(aide_predict): 37 | if inspect.isclass(obj) and issubclass(obj, ProteinModelWrapper) and obj != ProteinModelWrapper: 38 | models.append(obj) 39 | return models 40 | 41 | def check_structures_available() -> bool: 42 | """Check if any structure information is available in the provided data.""" 43 | if wt and wt.structure is not None: 44 | return True 45 | for seq_set in [training_sequences, testing_sequences]: 46 | if seq_set: 47 | if any(seq.structure is not None for seq in seq_set): 48 | return True 49 | return False 50 | 51 | def check_msa_for_fit_available() -> bool: 52 | """Check if an MSA is available for fitting.""" 53 | # First check if training sequences are already aligned 54 | if training_sequences and training_sequences.aligned: 55 | return True 56 | # Then check if WT has an MSA 57 | if wt and wt.has_msa: 58 | return True 59 | return False 60 | 61 | def check_wt_msa_available() -> bool: 62 | """Check if wild-type has an associated MSA.""" 63 | return wt is not None and wt.has_msa 64 | 65 | def check_msa_per_sequence_available_and_same_length() -> bool: 66 | """Check if each sequence has its own MSA.""" 67 | msa_avail = False 68 | msa_same_length = False 69 | if training_sequences: 70 | if all(seq.has_msa for seq in training_sequences): 71 | msa_avail = True 72 | else: 73 | msa_avail = False 74 | if testing_sequences: 75 | if all(seq.has_msa for seq in testing_sequences): 76 | msa_avail = True 77 | else: 78 | msa_avail = False 79 | 80 | # check if the wild type has an MSA 81 | if wt and wt.has_msa: 82 | msa_avail = True 83 | 84 | # check if all sequence match msa length 85 | if training_sequences and msa_avail: 86 | for seq in training_sequences: 87 | len_seq = len(seq) 88 | msa = seq.msa if seq.has_msa else wt.msa 89 | if msa.width != len_seq: 90 | msa_same_length = False 91 | break 92 | else: 93 | msa_same_length = True 94 | if 
testing_sequences and msa_avail: 95 | for seq in testing_sequences: 96 | len_seq = len(seq) 97 | msa = seq.msa if seq.has_msa else wt.msa 98 | if msa.width != len_seq: 99 | msa_same_length = False 100 | break 101 | else: 102 | msa_same_length = True 103 | return msa_avail, msa_same_length 104 | 105 | def check_compatibility(model: Type[ProteinModelWrapper]) -> bool: 106 | """Check if the model is compatible with the provided data.""" 107 | if not model._available: 108 | return False 109 | 110 | # Check fixed length requirement 111 | if model._requires_fixed_length: 112 | if (training_sequences and not training_sequences.fixed_length) or \ 113 | (testing_sequences and not testing_sequences.fixed_length): 114 | return False 115 | 116 | # Check MSA-related requirements 117 | if model._requires_msa_for_fit and not check_msa_for_fit_available(): 118 | return False 119 | if model._requires_wt_msa and not check_wt_msa_available(): 120 | return False 121 | 122 | msa_avail, msa_same_length = check_msa_per_sequence_available_and_same_length() 123 | if model._requires_msa_per_sequence and not msa_avail: 124 | return False 125 | if model._requires_msa_per_sequence and not model._can_handle_aligned_sequences and not msa_same_length: 126 | return False 127 | 128 | # Check wild-type requirement 129 | if model._requires_wt_to_function and wt is None: 130 | return False 131 | 132 | # Check structure requirement 133 | if model._requires_structure and not check_structures_available(): 134 | return False 135 | 136 | return True 137 | 138 | models = load_models() 139 | compatibility = {"compatible": [], "incompatible": []} 140 | 141 | for model in models: 142 | if check_compatibility(model): 143 | compatibility["compatible"].append(model.__name__) 144 | else: 145 | compatibility["incompatible"].append(model.__name__) 146 | 147 | return compatibility 148 | 149 | def get_supported_tools(): 150 | from aide_predict.bespoke_models import TOOLS 151 | out_string = "" 152 | for tool in TOOLS: 
153 | avail = tool._available 154 | if avail: 155 | message = 'AVAILABLE' 156 | else: 157 | message = tool._available.message 158 | out_string += tool.__name__ +f": {message}\n" 159 | print(out_string) 160 | return out_string -------------------------------------------------------------------------------- /docs/user_guide/contributing_models.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Contributing Models to AIDE 3 | --- 4 | 5 | # Contributing Models to AIDE 6 | 7 | ## Overview 8 | 9 | AIDE is designed to make it easy to wrap new protein prediction models into a scikit-learn compatible interface. This guide walks through the process of contributing a new model. 10 | 11 | ### 1. Setting Up Development Environment 12 | 13 | ```bash 14 | git clone https://github.com/beckham-lab/aide_predict 15 | cd aide_predict 16 | conda env create -f environment.yaml 17 | conda activate aide_predict 18 | pip install -e ".[dev]" # Installs in editable mode with development dependencies 19 | ``` 20 | 21 | ### 2. Understanding Model Dependencies 22 | 23 | AIDE uses a tiered dependency system to minimize conflicts and installation complexity: 24 | 25 | 1. **Base Dependencies**: If your model only needs numpy, scipy, scikit-learn, etc., it can be included in the base package. 26 | 27 | 2. **Optional Dependencies**: If your model needs additional pip-installable packages: 28 | - Create or update a `requirements-.txt` file 29 | - Example: `requirements-transformers.txt` for models using HuggingFace transformers 30 | 31 | 3. **Complex Dependencies**: If your model requires a specific environment or complex setup: 32 | - Package should be installed separately 33 | - AIDE will call it via subprocess 34 | - Model checks for environment variables pointing to installation 35 | - Example: EVE model checking for `EVE_REPO` and `EVE_CONDA_ENV` 36 | 37 | ### 3. 
Creating the Model Class 38 | 39 | Models should be placed in one of two directories: 40 | - `aide_predict/bespoke_models/embedders/`: For models that create numerical features 41 | - `aide_predict/bespoke_models/predictors/`: For models that predict protein properties 42 | 43 | Basic structure: 44 | 45 | ```python 46 | from aide_predict.bespoke_models.base import ProteinModelWrapper 47 | from aide_predict.utils.common import MessageBool 48 | 49 | # Check dependencies 50 | try: 51 | import some_required_package 52 | AVAILABLE = MessageBool(True, "Model is available") 53 | except ImportError: 54 | AVAILABLE = MessageBool(False, "Requires some_required_package") 55 | 56 | class MyModel(ProteinModelWrapper): 57 | """Documentation in NumPy style. 58 | 59 | Parameters 60 | ---------- 61 | param1 : type 62 | Description 63 | metadata_folder : str, optional 64 | Directory for model files 65 | wt : ProteinSequence, optional 66 | Wild-type sequence for comparative predictions 67 | 68 | Attributes 69 | ---------- 70 | fitted_ : bool 71 | Whether model has been fitted 72 | """ 73 | _available = AVAILABLE # Class attribute for availability 74 | 75 | def __init__(self, param1, metadata_folder=None, wt=None, **kwargs): 76 | super().__init__(metadata_folder=metadata_folder, wt=wt, **kwargs) 77 | self.param1 = param1 # Save user parameters as attributes 78 | 79 | def _fit(self, X, y=None): 80 | """Fit the model. Called by public fit() method.""" 81 | # Implementation 82 | self.fitted_ = True # Mark as fitted 83 | return self 84 | 85 | def _transform(self, X): 86 | """Transform sequences. Called by public transform() method.""" 87 | # Implementation 88 | return features 89 | ``` 90 | 91 | ### 4. Adding Model Requirements with Mixins 92 | 93 | AIDE uses mixins to declare model requirements and capabilities. Common mixins: 94 | 95 | ```python 96 | # Input requirements 97 | RequiresMSAForFitMixin # Needs MSA for fit method. 
If not found, will attempt to fall back to WT sequence MSA
RequiresWTMSAMixin # Needs a WT sequence with an MSA
RequiresMSAPerSequenceMixin # The model needs MSAs, but can handle having different MSAs for each input. If inputs do not have MSAs, will attempt to fall back to the WT sequence MSA
RequiresFixedLengthMixin # Sequences must be same length
RequiresStructureMixin # Uses structural information
RequiresWTToFunctionMixin # Needs wild-type sequence
RequiresWTDuringInferenceMixin # Model does its own normalization to any WT internally. If not inherited, AIDE will automatically normalize outputs to any WT sequence provided


# Output capabilities
CanRegressMixin # Can predict numeric values
PositionSpecificMixin # Outputs per-position scores or embeddings

# Processing behavior
CacheMixin # Enables result caching
AcceptsLowerCaseMixin # Handles lowercase sequences
ExpectsNoFitMixin # Does not require any inputs to the fit method
ShouldRefitOnSequencesMixin # Restores the sklearn default behaviour to refit when fit is called or params are set. By default, models do not refit.

```

Example with mixins:

```python
class MyModel(
    RequiresMSAForFitMixin,  # Needs MSA for training
    CanRegressMixin,         # Makes predictions
    PositionSpecificMixin,   # Per-position outputs
    CacheMixin,              # Caches results
    ProteinModelWrapper      # Always last
):
    pass
```

Ensure that the `_available` attribute is set to a valid `MessageBool` object that is computed on import based on the availability of the model's dependencies.

### 5.
Testing Your Model

If applicable, add scientific validation tests in `tests/test_not_base_models/`:
```python
from aide_predict.bespoke_models.embedders.my_model import MyModel
def test_my_model_benchmark():
    """Test against published benchmark."""
    model = MyModel()
    score = model.score(benchmark_data)
    assert score >= expected_performance
```

Run the tests with `pytest tests/test_not_base_models/test_my_model.py`, and copy the results.

Ensure that this test is not tracked by coverage, as we do not run CI on non-base models that have additional dependencies:

Update `.coveragerc`:
```
omit =
    ... other omitted files are here ...
    aide_predict/bespoke_models/embedders/my_model.py
```

### 6. Expose your model so that AIDE can find it and test it against user data

Update `aide_predict/bespoke_models/__init__.py` to include your model in the `TOOLS` list:

```python
from .embedders.my_model import MyModel

TOOLS = [
    ...other tools are here...
    MyModel
]
```

### 7. Submitting Your Contribution

1. Create a new branch
2. Implement your model in its own module
3. Add any tests
4.
Submit a pull request, add any test results to the pull request so the expected performance can be verified

# tests/test_bespoke_models/test_embedders/test_kmer.py
'''
* Author: Evan Komp
* Created: 10/23/2024
* Company: National Renewable Energy Lab, Bioenergy Science and Technology
* License: MIT
'''
import pytest
import numpy as np
from aide_predict.bespoke_models.embedders.kmer import KmerEmbedding
from aide_predict.utils.data_structures import ProteinSequences, ProteinSequence

@pytest.fixture
def simple_sequences():
    """Fixture providing simple protein sequences for testing."""
    # Fixed-length 5-mers sharing the 'ACDE' prefix so 3-mer vocabularies overlap.
    seqs = [
        ProteinSequence("ACDEG", id="seq1"),
        ProteinSequence("ACDEF", id="seq2"),
        ProteinSequence("ACDEK", id="seq3")
    ]
    return ProteinSequences(seqs)

@pytest.fixture
def aligned_sequences():
    """Fixture providing aligned protein sequences for testing."""
    # Same residues as simple_sequences with a gap column inserted; the
    # embedder is expected to ignore gaps when extracting k-mers.
    seqs = [
        ProteinSequence("ACD-EG", id="seq1"),
        ProteinSequence("ACD-EF", id="seq2"),
        ProteinSequence("ACD-EK", id="seq3")
    ]
    return ProteinSequences(seqs)

@pytest.fixture
def varied_length_sequences():
    """Fixture providing sequences of different lengths."""
    seqs = [
        ProteinSequence("ACDEG", id="seq1"),
        ProteinSequence("ACDEFGH", id="seq2"),
        ProteinSequence("ACDEKIJ", id="seq3")
    ]
    return ProteinSequences(seqs)

class TestKmerEmbedding:
    """Unit tests for the KmerEmbedding sklearn-style transformer."""

    def test_initialization(self):
        """Test proper initialization of KmerEmbedding."""
        embedder = KmerEmbedding(k=3, normalize=True)
        assert embedder.k == 3
        assert embedder.normalize == True
        # Vocabulary is only populated during fit.
        assert len(embedder._kmer_to_index) == 0

    def test_fitting_creates_kmer_mapping(self, simple_sequences):
        """Test that fitting creates proper kmer to index mapping."""
        embedder = KmerEmbedding(k=3)
        embedder.fit(simple_sequences)

        # Check that all possible 3-mers from sequences are in mapping
        expected_kmers = {'ACD', 'CDE', 'DEG', 'DEF', 'DEK'}
        assert set(embedder._kmer_to_index.keys()) == expected_kmers
        assert embedder.n_features_ == len(expected_kmers)
        assert embedder.fitted_

    def test_transform_shape(self, simple_sequences):
        """Test that transform returns correct shape."""
        embedder = KmerEmbedding(k=3)
        embedder.fit(simple_sequences)
        embeddings = embedder.transform(simple_sequences)

        # One row per sequence, one column per fitted k-mer.
        assert embeddings.shape == (len(simple_sequences), embedder.n_features_)
        assert embeddings.dtype == np.float32

    def test_normalization(self, simple_sequences):
        """Test that normalization works correctly."""
        embedder = KmerEmbedding(k=3, normalize=True)
        embedder.fit(simple_sequences)
        embeddings = embedder.transform(simple_sequences)

        # Check that each row sums to 1 (counts become frequencies)
        row_sums = np.sum(embeddings, axis=1)
        assert np.allclose(row_sums, 1.0)

    def test_no_normalization(self, simple_sequences):
        """Test without normalization."""
        embedder = KmerEmbedding(k=3, normalize=False)
        embedder.fit(simple_sequences)
        embeddings = embedder.transform(simple_sequences)

        # Each kmer should appear exactly once in each sequence
        assert np.all(embeddings <= 1)
        assert np.all(embeddings >= 0)

    def test_aligned_sequences(self, aligned_sequences):
        """Test handling of aligned sequences with gaps."""
        embedder = KmerEmbedding(k=3)
        embedder.fit(aligned_sequences)
        embeddings = embedder.transform(aligned_sequences)

        # Should produce same results as unaligned sequences (gaps removed)
        expected_kmers = {'ACD', 'CDE', 'DEG', 'DEF', 'DEK'}
        assert set(embedder._kmer_to_index.keys()) == expected_kmers

    def test_varied_length_sequences(self, varied_length_sequences):
        """Test handling of sequences with different lengths."""
        embedder = KmerEmbedding(k=3)
        embedder.fit(varied_length_sequences)
        embeddings = embedder.transform(varied_length_sequences)

        # Should handle different lengths properly: every row is populated.
        assert embeddings.shape[0] == len(varied_length_sequences)
        assert all(np.sum(embeddings[i]) > 0 for i in range(len(varied_length_sequences)))

    def test_feature_names(self, simple_sequences):
        """Test that feature names are generated correctly."""
        embedder = KmerEmbedding(k=3)
        embedder.fit(simple_sequences)
        feature_names = embedder.get_feature_names_out()

        # Names follow the "kmer_<KMER>" convention and cover the vocabulary.
        assert len(feature_names) == embedder.n_features_
        assert all(name.startswith("kmer_") for name in feature_names)
        assert set(name[5:] for name in feature_names) == set(embedder._kmer_to_index.keys())

    def test_transform_new_sequences(self, simple_sequences):
        """Test transforming sequences not seen during fitting."""
        embedder = KmerEmbedding(k=3)
        embedder.fit(simple_sequences)

        # 'DEM' was never seen during fit; unseen k-mers must not crash transform.
        new_seqs = ProteinSequences([ProteinSequence("ACDEM", id="new_seq")])
        embeddings = embedder.transform(new_seqs)

        assert embeddings.shape == (1, embedder.n_features_)

    def test_invalid_k(self):
        """Test handling of invalid k values."""
        with pytest.raises(ValueError):
            KmerEmbedding(k=0)
        with pytest.raises(ValueError):
            KmerEmbedding(k=-1)

    def test_empty_sequences(self):
        """Test handling of empty sequence list."""
        embedder = KmerEmbedding(k=3)
        empty_seqs = ProteinSequences([])

        with pytest.raises(ValueError):
            embedder.fit(empty_seqs)

    def test_sequence_shorter_than_k(self):
        """Test handling of sequences shorter than k."""
        embedder
= KmerEmbedding(k=5) 149 | short_seqs = ProteinSequences([ProteinSequence("ACD", id="short")]) 150 | 151 | with pytest.raises(ValueError): 152 | embedder.fit(short_seqs) 153 | 154 | def test_transform_before_fit(self, simple_sequences): 155 | """Test that transform raises error if called before fit.""" 156 | embedder = KmerEmbedding(k=3) 157 | with pytest.raises(ValueError): 158 | embedder.transform(simple_sequences) 159 | -------------------------------------------------------------------------------- /aide_predict/bespoke_models/predictors/vespa.py: -------------------------------------------------------------------------------- 1 | # aide_predict/bespoke_models/predictors/vespa.py 2 | ''' 3 | * Author: Evan Komp 4 | * Created: 8/1/2024 5 | * Company: National Renewable Energy Lab, Bioeneergy Science and Technology 6 | * License: MIT 7 | 8 | Wrapper of VESPA: 9 | Marquet, C. et al. Embeddings from protein language models predict conservation and variant effects. Hum Genet 141, 1629–1647 (2022). 10 | 11 | This model embeds the sequences with a PLM, then uses the embeddings for a pretrained logistic regression model for conservation. These 12 | are input into a model to predict single mutation effects. 13 | ''' 14 | import os 15 | from typing import Optional, Union, List 16 | import numpy as np 17 | import pandas as pd 18 | import subprocess 19 | 20 | from aide_predict.utils.data_structures import ProteinSequence, ProteinSequences 21 | from aide_predict.bespoke_models.base import ProteinModelWrapper, CanRegressMixin, RequiresWTDuringInferenceMixin, RequiresWTToFunctionMixin, ExpectsNoFitMixin, CacheMixin 22 | from aide_predict.utils.common import MessageBool 23 | 24 | try: 25 | import vespa.predict.config 26 | from vespa.predict.vespa import VespaPred 27 | from vespa.predict import utils 28 | import torch 29 | AVAILABLE = MessageBool(True, "VESPA is available.") 30 | except ImportError: 31 | AVAILABLE = MessageBool(False, "VESPA is not available. 
class VESPAWrapper(CanRegressMixin, ExpectsNoFitMixin, RequiresWTDuringInferenceMixin, RequiresWTToFunctionMixin, CacheMixin, ProteinModelWrapper):
    """
    A wrapper class for the VESPA (Variant Effect Score Prediction using Attention) model.

    This class provides an interface to use VESPA within the AIDE framework,
    allowing for prediction of single-mutant variant effects on protein sequences.

    Attributes:
        light (bool): If True, uses the lighter VESPAl model. If False, uses the full VESPA model.
    """

    _available = AVAILABLE

    def __init__(self, metadata_folder: Optional[str] = None,
                 wt: Optional[Union[str, ProteinSequence]] = None,
                 use_cache: bool = False,
                 light: bool = True) -> None:
        """
        Initialize the VESPAWrapper.

        Args:
            metadata_folder (Optional[str]): Folder to store metadata.
            wt (Optional[Union[str, ProteinSequence]]): Wild-type protein sequence.
            use_cache (bool): Whether to cache transform results.
            light (bool): If True, use the lighter VESPAl model. If False, use the full VESPA model.
        """
        super().__init__(metadata_folder=metadata_folder, use_cache=use_cache, wt=wt)
        self.light = light

    def _fit(self, X: ProteinSequences, y: Optional[np.ndarray] = None) -> 'VESPAWrapper':
        """
        Fit the VESPA model. This method is a placeholder as VESPA doesn't require fitting.

        Args:
            X (ProteinSequences): The input protein sequences.
            y (Optional[np.ndarray]): Ignored. Present for API consistency.

        Returns:
            VESPAWrapper: The fitted model (self).
        """
        self.fitted_ = True
        return self

    def _mutation_codes(self, X: ProteinSequences) -> List[str]:
        """
        Compute VESPA-style mutation codes (0-indexed positions) for each sequence.

        Args:
            X (ProteinSequences): Variant sequences, each a single point mutation of the WT.

        Returns:
            List[str]: One "{fromAA}{pos0}{toAA}" code per input sequence.

        Raises:
            ValueError: If any sequence is not exactly one point mutation from the wild type.
        """
        codes = []
        for seq in X:
            mutations = self.wt.get_mutations(seq)
            if len(mutations) != 1:
                raise ValueError(f"Sequence {seq} is not a single point mutation from the wild type sequence {self.wt}")
            mutation = mutations[0]
            fromAA, pos, toAA = mutation[0], mutation[1:-1], mutation[-1]
            # AIDE mutation strings are 1-indexed; VESPA expects 0-indexed positions.
            codes.append(f"{fromAA}{int(pos) - 1}{toAA}")
        return codes

    def _transform(self, X: ProteinSequences) -> np.ndarray:
        """
        Transform the input sequences using the VESPA model.

        This method checks that each input sequence is a single point mutation from the wild type,
        writes the mutations to a file, runs the VESPA model, and processes the results.

        Args:
            X (ProteinSequences): The input protein sequences to transform.

        Returns:
            np.ndarray: The log-transformed VESPA scores for each input sequence, shape (n, 1).

        Raises:
            ValueError: If any input sequence is not a single point mutation from the wild type,
                        or if VESPA fails to return predictions.
        """
        # Validate all sequences up front and parse each mutation exactly once.
        mutation_codes = self._mutation_codes(X)

        mutation_file = os.path.join(self.metadata_folder, "mutations.txt")
        wt_fasta_file = os.path.join(self.metadata_folder, "wt.fasta")

        # Write mutations to file in the "{wt_id}_{code}" form VESPA expects
        with open(mutation_file, 'w') as f:
            for code in mutation_codes:
                f.write(f"{self.wt.id}_{code}\n")

        # Create a temporary file for the wild type sequence
        ProteinSequences([self.wt]).to_fasta(wt_fasta_file)

        # Run the model from within the metadata folder (VESPA writes its run
        # directory relative to the cwd).
        cmd = ["vespa", os.path.basename(wt_fasta_file), "-m", os.path.basename(mutation_file), "--prott5_weights_cache", '.']
        cmd.append("--vespal" if self.light else "--vespa")
        logger.info(f"Running command: {cmd}")
        proc = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            cwd=self.metadata_folder
        )
        if proc.stderr:
            # VESPA writes progress/diagnostics to stderr; surface it for debugging.
            logger.error(f"VESPA gave: {proc.stderr.decode()}")

        outpath = os.path.join(self.metadata_folder, "vespa_run_directory", "output", "0.csv")
        if not os.path.exists(outpath):
            raise ValueError("VESPA did not return predictions, check logs.")
        column = "VESPA" if not self.light else "VESPAl"
        df = pd.read_csv(outpath, sep=';').set_index('Mutant')
        results = [df.loc[code, column] for code in mutation_codes]

        # VESPA outputs effect probabilities in [0, 1]; clip so a probability of
        # exactly 1 does not produce -inf under the log transform.
        complements = np.clip(1.0 - np.array(results, dtype=float), 1e-12, None)
        return np.log(complements.reshape(-1, 1))

    def get_feature_names_out(self, input_features: Optional[List[str]] = None) -> List[str]:
        """
        Get the names of the output features.

        Args:
            input_features (Optional[List[str]]): Ignored. Present for API consistency.

        Returns:
            List[str]: A list containing the name of the output feature.
        """
        return ["VESPA_score"]

def mafft_align(sequences: "ProteinSequences",
                existing_alignment: Optional["ProteinSequences"] = None,
                realign: bool = False,
                output_fasta: Optional[str] = None) -> "ProteinSequences":
    """
    Perform multiple sequence alignment using MAFFT.

    Args:
        sequences (ProteinSequences): The sequences to align.
        existing_alignment (Optional[ProteinSequences]): An existing alignment to add sequences to.
        realign (bool): If True, realign all sequences from scratch. If False, add new sequences to existing alignment.
        output_fasta (Optional[str]): Path to save the alignment. If None, a temporary file is used.

    Returns:
        ProteinSequences: The aligned sequences, either in memory or on file depending on output_fasta.

    Raises:
        RuntimeError: If MAFFT execution fails.
        FileNotFoundError: If MAFFT is not installed or not in PATH.
        ValueError: If input sequences contain gaps, or the existing alignment is not aligned.
    """
    from aide_predict.utils.data_structures import ProteinSequences, ProteinSequencesOnFile

    # check that the sequences are not already gap containing
    if sequences.has_gaps:
        raise ValueError("Input sequences should not contain gaps to be aligned")

    # Create a temporary directory for input and output files
    with tempfile.TemporaryDirectory() as temp_dir:
        # Prepare input file
        if isinstance(sequences, ProteinSequencesOnFile):
            input_fasta = sequences.file_path
        else:
            input_fasta = os.path.join(temp_dir, "input.fasta")
            sequences.to_fasta(input_fasta)

        # Prepare output file
        output_file = output_fasta if output_fasta else os.path.join(temp_dir, "output.fasta")

        # Prepare MAFFT command
        mafft_cmd = ["mafft"]

        if existing_alignment is not None:
            if not existing_alignment.aligned:
                raise ValueError("Existing alignment must be aligned")
            if isinstance(existing_alignment, ProteinSequencesOnFile):
                existing_fasta = existing_alignment.file_path
            else:
                existing_fasta = os.path.join(temp_dir, "existing.fasta")
                existing_alignment.to_fasta(existing_fasta)

            if realign:
                # Re-align everything (columns of the existing alignment may change)
                mafft_cmd.extend(["--add", input_fasta, existing_fasta])
            else:
                # Add to the existing alignment without changing its column count
                mafft_cmd.extend(["--add", input_fasta, "--keeplength", existing_fasta])
        else:
            # New alignment
            mafft_cmd.append(input_fasta)

        # Run MAFFT with shell=False and an explicit stdout handle instead of a
        # shell ">" redirect: this is robust to spaces in paths and lets a
        # missing mafft binary actually raise FileNotFoundError (under
        # shell=True the shell swallowed it into a generic non-zero exit).
        try:
            with open(output_file, "w") as out_handle:
                subprocess.run(mafft_cmd, check=True, stdout=out_handle, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"MAFFT alignment failed: {e.stderr.decode()}") from e
        except FileNotFoundError:
            raise FileNotFoundError("MAFFT is not installed or not in PATH")

        # Load aligned sequences
        if output_fasta:
            aligned_sequences = ProteinSequencesOnFile(output_fasta)
        else:
            aligned_sequences = ProteinSequences.from_fasta(output_file)

        # Transfer MSAs and structures from original sequences to aligned ones,
        # matching first by sequence ID, then by ungapped sequence content.
        id_to_seq = {seq.id: seq for seq in sequences if seq.id is not None}
        hash_to_seq = {hash(str(seq)): seq for seq in sequences}

        for i, aligned_seq in enumerate(aligned_sequences):
            candidates = []
            by_id = id_to_seq.get(aligned_seq.id)
            if by_id is not None:
                candidates.append(by_id)
            # Compute the content hash once (the original recomputed it up to 4x)
            by_hash = hash_to_seq.get(hash(str(aligned_seq.with_no_gaps())))
            if by_hash is not None and by_hash is not by_id:
                candidates.append(by_hash)

            for source in candidates:
                if source.has_msa:
                    aligned_sequences[i]._msa = source.msa
                    break
            for source in candidates:
                if source.has_structure:
                    aligned_sequences[i]._structure = source.structure
                    break

        return aligned_sequences
Optimization with BADASS 6 | 7 | ## Overview 8 | 9 | AIDE integrates BADASS, an adaptive simulated annealing algorithm that efficiently explores protein sequence space to find variants with optimal properties. The BADASS algorithm was introduced in [this paper](https://www.biorxiv.org/content/10.1101/2024.10.25.620340v1) and has been adapted in AIDE to work with any of its protein prediction models. 10 | 11 | ## Installation 12 | 13 | To use BADASS with AIDE, install the required dependencies: 14 | 15 | ```bash 16 | pip install -r requirements-badass.txt 17 | ``` 18 | 19 | ## Basic Usage 20 | 21 | Here's a complete example of using BADASS with an ESM2 zero-shot predictor: 22 | 23 | ```python 24 | from aide_predict import ProteinSequence, ESM2LikelihoodWrapper 25 | from aide_predict.utils.badass import BADASSOptimizer, BADASSOptimizerParams 26 | 27 | # 1. Define your protein sequence and prediction model 28 | wt = ProteinSequence("MKLLVLGLPGAGKGTQAEKIVAAYGIPHISTGDMFRAAMKEGTPLGLQAKQYMDEGDLVPDEVTIGIVRERLSKDDCQNGFLLDGFPRTVAQAEALETMLADIASRLSALPPATQTRMILMVEDELRNLHRGQVLPSENTFRVADDNEETIKKIRQKYGNSSGVI") 29 | 30 | # 2. Set up a prediction model 31 | # Note that this can be a supervised model. In general, any ProteinModel or 32 | # scikit-learn pipeline whose input models are ProteinModelWrapper can be used. 33 | model = ESM2LikelihoodWrapper(wt=wt) 34 | model.fit([]) # No training needed for zero-shot model 35 | 36 | # 3. Configure optimization parameters 37 | params = BADASSOptimizerParams( 38 | num_mutations=3, # Maximum mutations per variant 39 | num_iter=100, # Number of optimization iterations 40 | seqs_per_iter=200 # Sequences evaluated per iteration 41 | ) 42 | 43 | # 4. Create and run the optimizer 44 | optimizer = BADASSOptimizer( 45 | predictor=model.predict, 46 | reference_sequence=wt, 47 | params=params 48 | ) 49 | 50 | # 5. 
Run optimization 51 | # This returns protein variants as well as scores from the optimizer 52 | # (which may be scaled and not equal to direct model outputs) 53 | results_df, stats_df = optimizer.optimize() 54 | 55 | # 6. Visualize the optimization process 56 | optimizer.plot() 57 | 58 | # 7. Print top variants 59 | print(results_df.sort_values('scores', ascending=False).head(10)) 60 | ``` 61 | 62 | ## Optimization Parameters 63 | 64 | BADASS behavior can be extensively customized through the `BADASSOptimizerParams` class: 65 | 66 | ```python 67 | params = BADASSOptimizerParams( 68 | # Core parameters 69 | seqs_per_iter=500, # Sequences per iteration 70 | num_iter=200, # Total optimization iterations 71 | num_mutations=5, # Maximum mutations per variant 72 | init_score_batch_size=500, # Batch size for initial scoring 73 | 74 | # Algorithm behavior 75 | temperature=1.5, # Initial temperature 76 | cooling_rate=0.92, # Cooling rate for SA 77 | seed=42, # Random seed 78 | gamma=0.5, # Variance boosting weight 79 | 80 | # Constraints 81 | sites_to_ignore=[1, 2, 3], # Positions to exclude from mutation (1-indexed) 82 | 83 | # Advanced options 84 | normalize_scores=True, # Normalize scores 85 | simple_simulated_annealing=False, # Use simple SA without adaptation 86 | cool_then_heat=False, # Use cooling-then-heating schedule 87 | adaptive_upper_threshold=None, # Threshold for adaptivity (float for quantile, int for top N) 88 | n_seqs_to_keep=None, # Number of sequences to keep in results 89 | score_threshold=None, # Score threshold for phase transitions (auto-computed if None) 90 | reversal_threshold=None # Score threshold for phase reversals (auto-computed if None) 91 | ) 92 | ``` 93 | 94 | ## How BADASS Works 95 | 96 | BADASS operates through the following key mechanisms: 97 | 98 | 1. **Initialization**: Computes a score matrix of all single-point mutations 99 | 2. **Sampling**: Uses Boltzmann sampling to generate candidate sequences 100 | 3. 
**Scoring**: Evaluates candidates with the provided predictor function 101 | 4. **Phase detection**: Identifies when the optimizer has found a promising region 102 | 5. **Adaptive temperature**: Adjusts temperature to balance exploration/exploitation 103 | 6. **Score normalization**: Standardizes scores for better comparison 104 | 105 | During optimization, BADASS maintains several tracking matrices: 106 | - Score matrix for each amino acid at each position 107 | - Observation counts for statistical significance 108 | - Variance estimates for uncertainty quantification 109 | 110 | ## Optimization Results 111 | 112 | The `optimize()` method returns two DataFrames: 113 | 114 | 1. `results_df`: Contains information about all evaluated sequences: 115 | - `sequences`: Compact mutation representation (e.g., "M1L-K5R") 116 | - `scores`: Predicted fitness scores 117 | - `full_sequence`: Complete protein sequence 118 | - `counts`: Number of times each sequence was evaluated 119 | - `num_mutations`: Number of mutations in each sequence 120 | - `iteration`: When the sequence was first observed 121 | 122 | 2. 
`stats_df`: Contains statistics for each iteration: 123 | - `iteration`: Iteration number 124 | - `avg_score`: Average score per iteration 125 | - `var_score`: Variance of scores 126 | - `n_eff_joint`: Effective number of joint samples 127 | - `n_eff_sites`: Effective number of sites explored 128 | - `n_eff_aa`: Effective number of amino acids explored 129 | - `T`: Temperature at each iteration 130 | - `n_seqs`: Number of sequences evaluated 131 | - `n_new_seqs`: Number of new sequences evaluated 132 | - `num_phase_transitions`: Cumulative number of phase transitions 133 | 134 | ## Analyzing Results 135 | 136 | After optimization, BADASS offers several visualization and analysis options: 137 | 138 | ```python 139 | # Plot optimization progress 140 | optimizer.plot() # Creates multiple plots showing optimization trajectory 141 | 142 | # Save results to CSV 143 | optimizer.save_results("optimization_run") 144 | 145 | # Get best sequences 146 | best_sequences = results_df.sort_values('scores', ascending=False).head(10) 147 | 148 | # Create a ProteinSequences object from best variants 149 | from aide_predict import ProteinSequences 150 | top_variants = ProteinSequences(best_sequences['full_sequence'].tolist()) 151 | 152 | # Further analyze with other AIDE tools 153 | from aide_predict.utils.plotting import plot_mutation_heatmap 154 | mutations = [seq.get_mutations(wt)[0] for seq in top_variants] 155 | scores = best_sequences['scores'].values 156 | plot_mutation_heatmap(mutations, scores) 157 | ``` 158 | 159 | The visualization includes: 160 | 1. Statistics by iteration (scores, effective samples, temperature) 161 | 2. Score distributions vs temperature 162 | 3. 
def plot_protein_sequence_heatmap(sequences: ProteinSequences, 
                                  figsize: tuple = (20, 5), 
                                  cmap: str = 'viridis', 
                                  title: str = 'Protein Sequence Heatmap') -> plt.Figure:
    """
    Create a heatmap visualization of protein sequences.

    Each amino acid is mapped to a fixed integer (A=0 ... Y=19, unknown/gap=-1)
    so colors are comparable across figures regardless of which residues occur.

    Args:
        sequences (ProteinSequences): A ProteinSequences object containing the protein sequences.
        figsize (tuple): Figure size (width, height) in inches.
        cmap (str): Colormap to use for the heatmap.
        title (str): Title of the plot.

    Returns:
        plt.Figure: The matplotlib Figure object containing the heatmap.
    """
    # Convert sequences to numeric representation
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    aa_to_num = {aa: i for i, aa in enumerate(alphabet)}
    numeric_sequences = [[aa_to_num.get(aa, -1) for aa in str(seq)] for seq in sequences]

    # Create heatmap. Pin the color range to the full alphabet (-1 for unknown
    # residues) so the colorbar labels below stay aligned even when some amino
    # acids never occur in the data (previously the norm spanned only the
    # observed values, mislabeling the colorbar).
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(numeric_sequences, cmap=cmap, cbar=False, ax=ax,
                vmin=-1, vmax=len(alphabet) - 1)

    # Customize plot
    ax.set_xlabel('Amino Acid Position')
    ax.set_ylabel('Sequence Number')
    ax.set_title(title)

    # Add color legend: one tick per amino acid, placed at its numeric value
    cbar = plt.colorbar(ax.collections[0], ax=ax, orientation="vertical", pad=0.02)
    cbar.set_ticks(np.arange(len(alphabet)))
    cbar.set_ticklabels(list(alphabet))
    cbar.set_label('Amino Acid', rotation=270, labelpad=15)

    plt.tight_layout()
    return fig
def plot_conservation(
    conservation_scores: Dict[str, np.ndarray],
    p_values: Optional[Dict[str, np.ndarray]] = None,
    alpha: float = 1e-10,
    stacked: bool = False,
    figsize: tuple = (20, 6),
    title: str = "Conservation Scores Across Alignment Positions"
) -> plt.Figure:
    """
    Create a bar plot of conservation scores across alignment positions.

    Args:
        conservation_scores (Dict[str, np.ndarray]): Dictionary of conservation scores for each property.
        p_values (Optional[Dict[str, np.ndarray]]): Dictionary of p-values for each property. If provided,
            insignificant positions are zeroed (stacked mode) or greyed out (single-bar mode).
        alpha (float): Significance level for p-values. Default is 1e-10.
        stacked (bool): If True, create a stacked bar plot with colors for different properties.
            If False, create a single bar plot with height determined by sum of conservation scores.
        figsize (tuple): Figure size (width, height) in inches. Default is (20, 6).
        title (str): Title of the plot. Default is "Conservation Scores Across Alignment Positions".

    Returns:
        plt.Figure: The matplotlib Figure object containing the plot.
    """
    # Work on a copy: the significance filtering below zeroes scores in place.
    conservation_scores = copy.deepcopy(conservation_scores)

    # Set up the plot
    fig, ax = plt.subplots(figsize=figsize)
    sns.set_style("whitegrid")

    # Prepare data
    positions = range(len(next(iter(conservation_scores.values()))))
    total_scores = np.sum([scores for scores in conservation_scores.values()], axis=0)

    if stacked:
        # Create stacked bar plot
        bottom = np.zeros(len(positions))
        for prop, scores in conservation_scores.items():
            # Zero positions whose p-value is not significant for this property
            if p_values is not None:
                insignificant_mask = p_values[prop] > alpha
                scores[insignificant_mask] = 0.0

            ax.bar(positions, scores, bottom=bottom, label=prop, linewidth=0)
            bottom += scores
    else:
        # Create single bar plot
        bars = ax.bar(positions, total_scores)

        # Color bars by conservation score, normalized by the observed maximum.
        # (Previously normalized by a hard-coded 10.0, which washed out the
        # colormap whenever scores did not span [0, 10].)
        denom = max(float(np.max(total_scores)), np.finfo(float).eps)
        colors = plt.cm.viridis(total_scores / denom)
        for bar, color in zip(bars, colors):
            bar.set_color(color)

        if p_values is not None:
            # Grey out positions that are insignificant for every property
            significant = np.any(np.array([p < alpha for p in p_values.values()]), axis=0)
            for i, bar in enumerate(bars):
                if not significant[i]:
                    bar.set_color("grey")
                    bar.set_alpha(0.2)

    # Set labels and title
    ax.set_xlabel("Alignment Position")
    ax.set_ylabel("Conservation Score")
    ax.set_title(title)

    # Add legend if stacked
    if stacked:
        ax.legend(title="Properties", bbox_to_anchor=(1.05, 1), loc='upper left')

    # Adjust layout and return figure
    plt.tight_layout()
    return fig