├── .github
│   └── workflows
│       └── cicd.yaml
├── .gitignore
├── .idea
│   └── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── app
│   ├── __init__.py
│   ├── app.py
│   ├── logo.png
│   └── tooltips.py
├── demo
│   ├── MLDE_benchmark.py
│   ├── clustering_demo.py
│   ├── demo_data
│   │   ├── 1zb6.pdb
│   │   ├── GB1.pdb
│   │   ├── Nitric_Oxide_Dioxygenase.csv
│   │   ├── Nitric_Oxide_Dioxygenase_raw.csv
│   │   ├── Nitric_Oxide_Dioxygenase_wt.fasta
│   │   ├── binding_data
│   │   │   ├── dataset_creation.py
│   │   │   ├── replicate_1.csv
│   │   │   ├── replicate_2.csv
│   │   │   └── replicate_3.csv
│   │   └── methyltransfereases.csv
│   ├── design_demo.py
│   ├── discovery_demo.py
│   ├── fold_design_demo.py
│   ├── mlde_demo.py
│   ├── plot_benchmark.py
│   ├── plot_library_demo.py
│   ├── test
│   └── zero_shot_demo.py
├── docs
│   ├── .gitignore
│   ├── README.md
│   ├── conf.py
│   ├── index.rst
│   └── tutorial
│       └── tutorial.py
├── environment.yml
├── favicon.ico
├── proteusEnvironment.yml
├── pyproject.toml
├── run_benchmark.sh
├── setup.py
├── src
│   └── proteusAI
│       ├── Library
│       │   ├── __init__.py
│       │   └── library.py
│       ├── Model
│       │   ├── __init__.py
│       │   └── model.py
│       ├── Protein
│       │   ├── __init__.py
│       │   └── protein.py
│       ├── __init__.py
│       ├── data_tools
│       │   ├── MSA.py
│       │   ├── __init__.py
│       │   └── pdb.py
│       ├── design_tools
│       │   ├── Constraints.py
│       │   ├── MCMC.py
│       │   ├── ZeroShot.py
│       │   └── __init__.py
│       ├── io_tools
│       │   ├── __init__.py
│       │   ├── embeddings.py
│       │   ├── fasta.py
│       │   └── matrices
│       │       ├── BLOSUM50
│       │       ├── BLOSUM62
│       │       └── alphabet
│       ├── mining_tools
│       │   ├── __init__.py
│       │   ├── alphafoldDB.py
│       │   ├── blast.py
│       │   └── uniprot.py
│       ├── ml_tools
│       │   ├── __init__.py
│       │   ├── bo_tools
│       │   │   ├── __init__.py
│       │   │   ├── acq_fn.py
│       │   │   └── genetic_algorithm.py
│       │   ├── esm_tools
│       │   │   ├── __init__.py
│       │   │   ├── alphabet.pt
│       │   │   └── esm_tools.py
│       │   ├── sklearn_tools
│       │   │   ├── __init__.py
│       │   │   └── grid_search.py
│       │   └── torch_tools
│       │       ├── __init__.py
│       │       ├── matrices
│       │       │   ├── BLOSUM50
│       │       │   ├── BLOSUM62
│       │       │   └── alphabet
│       │       └── torch_tools.py
│       ├── struc
│       │   ├── __init__.py
│       │   └── struc.py
│       └── visual_tools
│           ├── __init__.py
│           └── plots.py
└── tests
    └── test_module.py
/.github/workflows/cicd.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | pull_request:
9 | branches: [ "main", "shiny_server"]
10 | schedule:
11 | - cron: '0 2 * * 3'
12 |
13 | permissions:
14 | contents: read
15 |
16 |
17 | jobs:
18 | format:
19 | runs-on: ubuntu-latest
20 | steps:
21 | - uses: actions/checkout@v4
22 | - uses: psf/black@stable
23 | lint:
24 | name: Lint with ruff
25 | runs-on: ubuntu-latest
26 | steps:
27 | - uses: actions/checkout@v4
28 |
29 | - uses: actions/setup-python@v5
30 | with:
31 | python-version: "3.11"
32 | - name: Install ruff
33 | run: |
34 | pip install ruff
35 | - name: Lint with ruff
36 | run: |
37 | # stop the build if there are Python syntax errors or undefined names
38 | ruff check .
39 | test:
40 | name: Test
41 | runs-on: ubuntu-latest
42 | strategy:
43 | matrix:
44 | python-version: ["3.8"]
45 | steps:
46 | - uses: actions/checkout@v4
47 | - name: Set up Python ${{ matrix.python-version }}
48 | uses: actions/setup-python@v5
49 | with:
50 | python-version: ${{ matrix.python-version }}
51 | cache: 'pip' # caching pip dependencies
52 | # cache-dependency-path: '**/pyproject.toml'
53 | - name: Install dependencies
54 | run: |
55 | python -m pip install --upgrade pip
56 | pip install pytest
57 | pip install -e . --find-links https://data.pyg.org/whl/torch-2.4.1+cpu.html
58 | - name: Run tests
59 | run: python -m pytest tests
60 |
61 | build_source_dist:
62 | name: Build source distribution
63 | if: startsWith(github.ref, 'refs/heads/main') || startsWith(github.ref, 'refs/tags')
64 | runs-on: ubuntu-latest
65 | steps:
66 | - uses: actions/checkout@v4
67 |
68 | - uses: actions/setup-python@v5
69 | with:
70 | python-version: "3.10"
71 |
72 | - name: Install build
73 | run: python -m pip install build
74 |
75 | - name: Run build
76 | run: python -m build --sdist
77 |
78 | - uses: actions/upload-artifact@v4
79 | with:
80 | path: ./dist/*.tar.gz
81 | # Needed in case of building packages with external binaries (e.g., Cython, Rust extensions, etc.)
82 | # build_wheels:
83 | # name: Build wheels on ${{ matrix.os }}
84 | # if: startsWith(github.ref, 'refs/heads/main') || startsWith(github.ref, 'refs/tags')
85 | # runs-on: ${{ matrix.os }}
86 | # strategy:
87 | # matrix:
88 | # os: [ubuntu-20.04, windows-2019, macOS-10.15]
89 |
90 | # steps:
91 | # - uses: actions/checkout@v4
92 |
93 | # - uses: actions/setup-python@v5
94 | # with:
95 | # python-version: "3.10"
96 |
97 | # - name: Install cibuildwheel
98 | # run: python -m pip install cibuildwheel==2.3.1
99 |
100 | # - name: Build wheels
101 | # run: python -m cibuildwheel --output-dir wheels
102 |
103 | # - uses: actions/upload-artifact@v4
104 | # with:
105 | # path: ./wheels/*.whl
106 |
107 | publish:
108 | name: Publish package
109 | if: startsWith(github.ref, 'refs/tags')
110 | needs:
111 | - format
112 | - lint
113 | - test
114 | - build_source_dist
115 | # - build_wheels
116 | runs-on: ubuntu-latest
117 |
118 | steps:
119 | - uses: actions/download-artifact@v4
120 | with:
121 | name: artifact
122 | path: ./dist
123 |
124 | - uses: pypa/gh-action-pypi-publish@release/v1
125 | with:
126 | user: __token__
127 | password: ${{ secrets.PYPI_API_TOKEN }}
128 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # biodata
132 | *.pdb
133 | *.fasta
134 | *.xml
135 | *.pt
136 | *.png
137 | *.sdf
138 | *.pth
139 | *.zip
140 | *.gif
141 | *.pml
142 | *.dat
143 | *.csv
144 | *.txt
145 | *.pse
146 | .DS_Store
147 | *.json
148 | *.pt
149 |
150 | # from server
151 | *.err
152 | *.out
153 | 4a6d.cif
154 | .idea/ProteusAI.iml
155 | Untitled.ipynb
156 | *.joblib
157 | demo/example_project/models/rf/params.json
158 | demo/example_project/models/rf/results.json
159 |
160 | # For env testing
161 | test.py
162 | demo/mpnn_test.py
163 | demo/mpnn_validation.py
164 | demo/design_demo.py
165 | mpnn.sh
166 | test1.cif
167 | test2.cif
168 | shiny-server-1.5.22.1017-amd64.deb
169 | proteusAI_developer.key
170 | proteusai_developer.key
171 | proteusAI_developer.key.pub
172 | proteusai_developer.key.pub
173 |
174 | google_analytics.html
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the OS, Python version and other tools you might need
9 | build:
10 | os: "ubuntu-20.04"
11 | tools:
12 | python: "mambaforge-22.9"
13 |
14 |
15 | # Build documentation in the "docs/" directory with Sphinx
16 | sphinx:
17 | configuration: docs/conf.py
18 |
19 | # Optionally build your docs in additional formats such as PDF and ePub
20 | # formats:
21 | # - pdf
22 | # - epub
23 |
24 | # Optional but recommended, declare the Python requirements required
25 | # to build your documentation
26 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
27 | # python:
28 | # install:
29 | # - method: pip
30 | # path: .
31 | # extra_requirements:
32 | # - docs
33 |
34 | conda:
35 | environment: environment.yml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ProteusAI contributors All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft src
2 | recursive-exclude __pycache__ *.py[cod]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 |
4 |
5 |
6 | # ProteusAI
7 | ProteusAI is a library for machine learning-guided protein design and engineering.
8 | The library enables workflows ranging from protein structure prediction and zero-shot
9 | prediction of mutational effects to ML-guided directed evolution.
10 | The goal is to provide state-of-the-art machine learning for protein engineering in a central library.
11 |
12 | ProteusAI is primarily powered by [PyTorch](https://pytorch.org/get-started/locally/),
13 | [scikit-learn](https://scikit-learn.org/stable/),
14 | and [ESM](https://github.com/facebookresearch/esm) protein language models.
15 |
16 | Cite our preprint [here](https://www.biorxiv.org/content/10.1101/2024.10.01.616114v1).
17 |
18 |
19 | Test out the ProteusAI web app at [proteusai.bio](http://proteusai.bio/).
20 |
21 | ## Getting started
22 |
23 | ----
24 | The commands used below are tested on Ubuntu 20.04 and macOS. Some tweaks may be needed for other operating systems.
25 | We recommend using conda environments to install ProteusAI.
26 |
27 |
28 | Clone the repository and cd to ProteusAI:
29 | ```bash
30 | git clone https://github.com/jonfunk21/ProteusAI.git
31 | cd ProteusAI
32 | ```
33 |
34 | Install the latest version of proteusAI in a new environment:
35 | ```bash
36 | conda env create -n proteusAI
37 | conda activate proteusAI
38 | ```
39 |
40 | This uses the `environment.yml` file to install the dependencies.
41 |
42 | ## GPU support
43 | By default, ProteusAI installs `torch-scatter` using CPU-compatible binaries.
44 | If you want to take full advantage of GPUs for the Design module, consider
45 | uninstalling the default `torch-scatter` and replacing it with the CUDA-
46 | compatible version:
47 |
48 | ```bash
49 | pip uninstall torch-scatter
50 | pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_VERSION}+cuda.html
51 | ```
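To pick the matching wheel and confirm that the GPU is actually visible, a quick check with standard PyTorch calls (a minimal sketch) is:

```python
import torch

print(torch.__version__)          # substitute this version for ${TORCH_VERSION} in the wheel URL above
print(torch.cuda.is_available())  # should print True once the CUDA build can see a GPU
```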
52 |
53 | If you have access to GPUs, you can run ESM-Fold. Unfortunately, the installation of
54 | openfold can be unstable for some people. Please follow the installation instructions
55 | in the [ESM repository](https://github.com/facebookresearch/esm) and use
56 | the discussions in its issues section for help.
57 |
58 | ### Install using pip locally for development
59 |
60 | Install a local version which picks up the latest changes using an editable install:
61 |
62 | ```bash
63 | # conda env create -n proteusAI python=3.8
64 | # conda activate proteusAI
65 | pip install -e . --find-links https://data.pyg.org/whl/torch-2.4.1+cpu.html
66 | ```
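As a quick smoke test (assuming the editable install succeeded), the package should import from any directory and resolve to your checkout:

```python
import proteusAI as pai

print(pai.__file__)  # should point into your ProteusAI/src/proteusAI source tree
```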
67 |
68 | ### Troubleshooting
69 | You can check a working configuration for an Ubuntu machine (VM)
70 | in the `proteusEnvironment.yml` file. The latest working versions are
71 | verified continuously by our GitHub Actions workflows.
72 |
73 | ### Setting up the Shiny server
74 |
75 | Install Shiny Server on Ubuntu 18.04+ (instructions for other systems are available at posit.co; please skip the section about installing R Shiny packages) with the following commands:
76 | ```bash
77 | sudo apt-get install gdebi-core
78 | wget https://download3.rstudio.org/ubuntu-18.04/x86_64/shiny-server-1.5.22.1017-amd64.deb
79 | sudo gdebi shiny-server-1.5.22.1017-amd64.deb
80 | ```
81 |
82 | Edit the default config file `/etc/shiny-server/shiny-server.conf` for Shiny Server (the `sudo` command or root privileges are required):
83 | ```bash
84 | # Use python from the virtual environment to run Shiny apps
85 | python /home/proteus_developer/miniforge3/envs/proteusAI_depl/bin/python;
86 |
87 | # Instruct Shiny Server to run applications as the user "shiny"
88 | run_as shiny;
89 |
90 | # Never delete logs regardless of their exit code
91 | preserve_logs true;
92 |
93 | # Do not replace errors with the generic error message, show them as they are
94 | sanitize_errors false;
95 |
96 | # Define a server that listens on port 80
97 | server {
98 | listen 80;
99 |
100 | # Define a location at the base URL
101 | location / {
102 |
103 | # Host the directory of Shiny Apps stored in this directory
104 | site_dir /srv/shiny-server;
105 |
106 | # Log all Shiny output to files in this directory
107 | log_dir /var/log/shiny-server;
108 |
109 | # When a user visits the base URL rather than a particular application,
110 | # an index of the applications available in this directory will be shown.
111 | directory_index on;
112 | }
113 | }
114 | ```
115 | Restart the shiny server with the following command to apply the server configuration changes:
116 | ```bash
117 | sudo systemctl restart shiny-server
118 | ```
119 | If you deploy the app on your local machine, make sure that port 80 is open and not blocked by a firewall. You can check that the server is listening on it, e.g. with `netstat`:
120 | ```bash
121 | netstat -tuln | grep :80
122 | ```
123 | If you deploy the app on your Azure Virtual Machine (VM), please add an Inbound Port rule in the Networking - Network Settings section on Azure Portal. Set the following properties:
124 | ```yaml
125 | Source: Any
126 | Source port ranges: *
127 | Destination: Any
128 | Service: HTTP
129 | Destination port ranges: 80
130 | Protocol: TCP
131 | Action: Allow
132 | ```
133 | Other fields can be left at their default values.
134 |
135 | Finally, create symlinks to your app files in the default Shiny Server folder `/srv/shiny-server/`:
136 |
137 | ```bash
138 | sudo ln -s /home/proteus_developer/ProteusAI/app/app.py /srv/shiny-server/app.py
139 | sudo ln -s /home/proteus_developer/ProteusAI/app/logo.png /srv/shiny-server/logo.png
140 | ```
141 | If everything has been done correctly, you should see the application index page at `http://127.0.0.1` (if you deploy your app locally) or at the VM's public IP address (if you deploy your app on an Azure VM). Additionally, the remote app can still be opened in your local browser (the Shiny extension for Visual Studio Code must be enabled) if you run the following terminal command on the VM:
142 | ```bash
143 | /home/proteus_developer/miniforge3/envs/proteusAI_depl/bin/python -m shiny run --port 33015 --reload --autoreload-port 43613 /home/proteus_developer/ProteusAI/app/app.py
144 | ```
145 | If you get warnings, debug messages, or "Disconnected from the server" messages, it is likely due to:
146 | - absent Python modules,
147 | - updated versions of the currently installed Python modules,
148 | - using relative paths instead of absolute paths (Shiny Server resolves relative paths starting from the `/srv/shiny-server/` folder), or
149 | - logical errors in the code.
151 |
152 | To debug the application, check the server logs under `/var/log/shiny-server` (the `log_dir` parameter can be changed in the Shiny Server config file `/etc/shiny-server/shiny-server.conf`).
153 |
154 | ### Note on permissions
155 | The app may run into problems when it lacks permission to create directories or to write files in certain locations. When this happens, one workaround is to make the affected directory writable:
156 |
157 | ```bash
158 | chmod 777 directory_name
159 | ```
160 |
161 |
162 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/app/__init__.py
--------------------------------------------------------------------------------
/app/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/app/logo.png
--------------------------------------------------------------------------------
/app/tooltips.py:
--------------------------------------------------------------------------------
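# Tooltip texts shown in the ProteusAI web app.
# Runnable counterparts of the workflows described below live in the demo/ folder
# (e.g., demo/zero_shot_demo.py, demo/discovery_demo.py, demo/mlde_demo.py, demo/design_demo.py).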
1 | data_tooltips = """
2 | ProteusAI, a user-friendly and open-source ML platform, streamlines protein engineering and design tasks.
3 | ProteusAI offers modules to support researchers in various stages of the design-build-test-learn (DBTL) cycle,
4 | including protein discovery, structure-based design, zero-shot predictions, and ML-guided directed evolution (MLDE).
5 | Our benchmarking results demonstrate ProteusAI’s efficiency in improving proteins and enzymes within a few DBTL-cycle
6 | iterations. ProteusAI democratizes access to ML-guided protein engineering and is freely available for academic and
7 | commercial use.
8 | You can upload different data types to get started with ProteusAI. Click on the other module tabs to learn about their
9 | functionality and the expected data types.
10 | """
11 |
12 | zs_file_type = """
13 | Upload a FASTA file containing the protein sequence, or a PDB file containing the structure of the protein
14 | for which you want to generate zero-shot predictions.
15 | """
16 |
17 | zs_tooltips = """
18 | The ProteusAI Zero-Shot Module is designed to create a mutant library with no prior data.
19 | The module uses scores generated by large protein language models, such as ESM-1v, that have
20 | been trained to predict hidden residues in hundreds of millions of protein sequences.
21 | Often, you will find that several residues in your protein sequence have low predicted probabilities.
22 | It has been previously shown that swapping these residues for residues with higher probabilities
23 | has beneficial effects on the candidate protein. In ProteusAI, we provide access to several language
24 | models which have been trained under slightly different conditions. The best models to produce Zero-Shot
25 | scores are ESM-1v and ESM-2 (650M). However, these models will take a long time to compute the results.
26 | Consider using ESM-2 (35M) to get familiar with the module first before moving to the larger models.
27 | """
28 |
29 | discovery_file_type = """
30 | Upload a CSV or EXCEL file in the 'Data' tab under 'Library' to proceed with the Discovery module.
31 | The file should contain a column for protein sequences, a column with the protein names, and a column for
32 | annotations, which can also be empty or partially populated with annotations.
33 | """
34 |
35 | discovery_tooltips = """
36 | The Discovery Module offers a structured approach to identifying proteins even with little to no
37 | experimental data to start with. The goal of the module is to identify proteins with similar
38 | functions and to propose novel sequences that are likely to have similar functions.
39 | The module relies on representations generated by large protein language models that transform protein
40 | sequences into meaningful vector representations. It has been shown that these vector representations often
41 | cluster based on function. Clustering should be used if all, very few, or no sequences have annotations.
42 | Classification should be used if some or all sequences are annotated. To find out if you have enough
43 | sequences for classification, we recommend using the model statistics on the validation set, which are
44 | automatically generated by the module after training.
45 | """
46 |
47 | mlde_file_type = """
48 | Upload a CSV or EXCEL file in the 'Data' tab under 'Library' to proceed with the MLDE module.
49 | The file should contain a column for protein sequences, a column with the protein names (e.g.,
50 | mutant descriptions M15V), and a column for experimental values (e.g., enzyme activity,
51 | fluorescence, etc.).
52 | """
53 |
54 | mlde_tooltips = """
55 | The Machine Learning Guided Directed Evolution (MLDE) module offers a structured approach to
56 | improve protein function through iterative mutagenesis, inspired by Directed Evolution.
57 | Here, machine learning models are trained on previously generated experimental results. The
58 | 'Search' algorithm will then propose novel sequences that will be evaluated and ranked by the
59 | trained model to predict mutants that are likely to improve function. The Bayesian optimization
60 | algorithms used to search for novel mutants are based on models trained on protein representations
61 | that can either be generated from large language models, which is currently very slow, or from
62 | classical algorithms such as BLOSUM62. For now, we recommend the use of BLOSUM62 representations
63 | combined with Random Forest models for the best trade-off between speed and quality. However, we encourage
64 | experimentation with other models and representations.
65 | """
66 |
67 | design_file_type = """
68 | Upload a PDB file containing the structure of the protein
69 | to use the (structure-based) Design module.
70 | """
71 |
72 | design_tooltips = """
73 | The Protein Design module is a structure-based approach to predict novel sequences using 'Inverse Folding'
74 | algorithms. The designed sequences are likely to preserve the fold of the protein while improving
75 | the thermal stability and solubility of proteins. To preserve important functions of the protein, we recommend
77 | the preservation of protein-protein and protein-ligand interfaces, as well as evolutionarily conserved sites, which can be
77 | entered manually. The temperature factor influences the diversity of designs. We recommend the generation of at
78 | least 1,000 sequences and rigorous filtering before ordering variants for validation. To give an example: Sort
79 | the sequences from the lowest to highest score, predict the structure of the lowest-scoring variants, and proceed
80 | with the designs that preserve the geometry of the active site (in the case of an enzyme). Experiment with a small
81 | sample size to explore temperature values that yield desired levels of diversification before generating large
82 | numbers of sequences.
83 | """
84 |
85 | representations_tooltips = """
86 | The Representations module offers methods to compute and visualize vector representations of proteins. These are primarily
87 | used by the MLDE and Discovery modules to make training more data-efficient. The representations are generated from
88 | classical algorithms such as BLOSUM62 or large protein language models that infuse helpful inductive biases into protein
89 | sequence representations. In some cases, the representations can be used to cluster proteins based on function or to
90 | predict protein properties. The module offers several visualization techniques to explore the representations and to
91 | understand the underlying structure of the protein data. Advanced analysis and predictions can be made by using the
92 | MLDE or Discovery modules in combination with the Representations module.
93 | """
94 |
95 | zs_entropy_tooltips = """
96 | This plot shows the entropy values across the protein sequence, providing insights into the diversity tolerated at each position.
97 | The higher the entropy, the greater the variety of amino acids tolerated at that position.
98 | """
99 |
100 | zs_heatmap_tooltips = """
101 | This heatmap visualizes the computed zero-shot scores for the protein. The scores at each position are normalised to the score
102 | of the original amino acid, which is set to zero (white) and highlighted by a black box. A positive score (blue) indicates that mutating the position to that amino acid could
103 | have beneficial effects, while a negative score (red) indicates the mutation would not be favourable.
104 | """
105 |
--------------------------------------------------------------------------------
/demo/MLDE_benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append("src/")
5 | import proteusAI as pai
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import json
9 | import argparse
10 | import pandas as pd
11 |
12 | # Initialize the argparse parser
13 | parser = argparse.ArgumentParser(description="Benchmarking ProteusAI MLDE")
14 |
15 | # Add arguments corresponding to the variables
16 | parser.add_argument(
17 | "--user", type=str, default="benchmark", help="User name or identifier."
18 | )
19 | parser.add_argument(
20 | "--sample-sizes",
21 | type=int,
22 | nargs="+",
23 | default=[5, 10, 20, 100],
24 | help="List of sample sizes.",
25 | )
26 | parser.add_argument("--model", type=str, default="gp", help="Model name.")
27 | parser.add_argument("--rep", type=str, default="esm2", help="Representation type.")
28 | parser.add_argument(
29 | "--zs-model", type=str, default="esm1v", help="Zero-shot model name."
30 | )
31 | parser.add_argument(
32 | "--benchmark-folder",
33 | type=str,
34 | default="demo/demo_data/DMS/",
35 | help="Path to the benchmark folder.",
36 | )
37 | parser.add_argument("--seed", type=int, default=42, help="Random seed.")
38 | parser.add_argument(
39 | "--max-iter",
40 | type=int,
41 | default=100,
42 | help="Maximum number of iterations (None for unlimited).",
43 | )
44 | parser.add_argument(
45 | "--device",
46 | type=str,
47 | default="cuda",
48 | help="Device to run the model on (e.g., cuda, cpu).",
49 | )
50 | parser.add_argument(
51 | "--batch-size", type=int, default=1, help="Batch size for processing."
52 | )
53 | parser.add_argument(
54 | "--improvement",
55 | type=str,
56 | nargs="+",
57 | default=[5, 10, 20, 50, "improved"],
58 | help="List of improvements.",
59 | )
60 | parser.add_argument(
61 | "--acquisition_fn", type=str, default="ei", help="ProteusAI acquisition functions"
62 | )
63 | parser.add_argument(
64 | "--k_folds", type=int, default=5, help="K-fold cross validation for sk_learn models"
65 | )
66 |
67 |
68 | def benchmark(dataset, fasta, model, embedding, name, sample_size, results_df):
69 | iteration = 1
70 |
71 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num
72 | lib = pai.Library(
73 | user=USER,
74 | source=dataset,
75 | seqs_col="mutated_sequence",
76 | y_col="DMS_score",
77 | y_type="num",
78 | names_col="mutant",
79 | )
80 |
81 | # compute representations for this dataset
82 | lib.compute(method=REP, batch_size=BATCH_SIZE, device=DEVICE)
83 |
84 | # plot destination
85 | # plot_dest = os.path.join(lib.user, name, embedding, 'plots', str(sample_size))
86 | # os.makedirs(plot_dest, exist_ok=True)
87 |
88 | # wt sequence
89 | protein = pai.Protein(user=USER, source=fasta)
90 |
91 | # zero-shot scores
92 | out = protein.zs_prediction(model=ZS_MODEL, batch_size=BATCH_SIZE, device=DEVICE)
93 | zs_lib = pai.Library(user=USER, source=out)
94 |
95 | # Simulate selection of the top N ZS-predictions for the initial library
96 | zs_prots = [prot for prot in zs_lib.proteins if prot.name in lib.names]
97 | sorted_zs_prots = sorted(zs_prots, key=lambda prot: prot.y_pred, reverse=True)
98 |
99 | # train on the ZS-selected initial library (assume it has been assayed by now) with an 80:10:10 split
100 | # if the sample size is too small, the initial batch has to be large enough to give meaningful statistics. Here 15
101 | if sample_size <= 15:
102 | top_N_zs_names = [prot.name for prot in sorted_zs_prots[:15]]
103 | zs_selected = [prot for prot in lib.proteins if prot.name in top_N_zs_names]
104 | n_train = 11
105 | n_test = 2
106 | _sample_size = 15
107 | else:
108 | top_N_zs_names = [prot.name for prot in sorted_zs_prots[:sample_size]]
109 | zs_selected = [prot for prot in lib.proteins if prot.name in top_N_zs_names]
110 | n_train = int(sample_size * 0.8)
111 | n_test = int(sample_size * 0.1)
112 | _sample_size = sample_size
113 |
114 | # train model
115 | model = pai.Model(model_type=model)
116 | model.train(
117 | library=lib,
118 | x=REP,
119 | split={
120 | "train": zs_selected[:n_train],
121 | "test": zs_selected[n_train : n_train + n_test],
122 | "val": zs_selected[n_train + n_test : _sample_size],
123 | },
124 | seed=SEED,
125 | model_type=MODEL,
126 | k_folds=K_FOLD,
127 | )
128 |
129 | # add to results df
130 | results_df = add_to_data(
131 | data=results_df,
132 | proteins=zs_selected,
133 | iteration=iteration,
134 | sample_size=_sample_size,
135 | dataset=name,
136 | model=model,
137 | )
138 |
139 | # use the model to make predictions on the remaining search space
140 | search_space = [prot for prot in lib.proteins if prot.name not in top_N_zs_names]
141 | ranked_search_space, _, _, _, _ = model.predict(search_space, acq_fn=ACQ_FN)
142 |
143 | # Prepare the tracking of top N variants,
144 | top_variants_counts = IMPROVEMENT
145 | found_counts = {count: 0 for count in top_variants_counts}
146 | first_discovered = [None] * len(top_variants_counts)
147 |
148 | # Identify the top variants by score
149 | actual_top_variants = sorted(lib.proteins, key=lambda prot: prot.y, reverse=True)
150 | actual_top_variants = {
151 | count: [prot.name for prot in actual_top_variants[:count]]
152 | for count in top_variants_counts
153 | if isinstance(count, int)
154 | }
155 |
156 | # count variants that are improved over wt, 1 standard deviation above wt
157 | y_std = np.std(lib.y, ddof=1)
158 | y_mean = np.mean(lib.y)
159 | actual_top_variants["improved"] = [
160 | prot.name for prot in lib.proteins if prot.y > y_mean + 1 * y_std
161 | ]
162 |
163 | # add sequences to the new dataset, and continue the loop until the dataset is exhausted
164 | sampled_data = zs_selected
165 |
166 | while len(ranked_search_space) >= sample_size:
167 | # Check if we have found all top variants, including the 1st zero-shot round
168 | sampled_names = [prot.name for prot in sampled_data] # noqa: F841
169 | for c, count in enumerate(found_counts):
170 | found = len(
171 | [
172 | prot.name
173 | for prot in sampled_data
174 | if prot.name in actual_top_variants[count]
175 | ]
176 | )
177 | found_counts[count] = found
178 | if found > 0 and first_discovered[c] is None:
179 | first_discovered[c] = iteration
180 |
181 | # Break the loop if all elements in first_discovered are not None
182 | if all(value is not None for value in first_discovered):
183 | break
184 |
185 | # Break if maximum number of iterations have been reached
186 | if iteration == MAX_ITER:
187 | break
188 |
189 | iteration += 1
190 |
191 | # select variants that have now been 'assayed'
192 | sample = ranked_search_space[:sample_size]
193 |
194 | # Remove the selected top N elements from the ranked search space
195 | ranked_search_space = ranked_search_space[sample_size:]
196 |
197 | # add new 'assayed' sample to sampled data
198 | sampled_data += sample
199 |
200 | # split into train, test and val
201 | n_train = int(len(sampled_data) * 0.8)
202 | n_test = int(len(sampled_data) * 0.1)
203 |
204 | # handle the very low data results
205 | if n_test == 0:
206 | n_test = 1
207 | n_train = n_train - n_test
208 |
209 | split = {
210 | "train": sampled_data[:n_train],
211 | "test": sampled_data[n_train : n_train + n_test],
212 | "val": sampled_data[n_train + n_test :],
213 | }
214 |
215 | # train model on new data
216 | model.train(
217 | library=lib, x=REP, split=split, seed=SEED, model_type=MODEL, k_folds=K_FOLD
218 | )
219 |
220 | # add to results
221 | results_df = add_to_data(
222 | data=results_df,
223 | proteins=sample,
224 | iteration=iteration,
225 | sample_size=sample_size,
226 | dataset=name,
227 | model=model,
228 | )
229 |
230 | # re-score the new search space
231 | (
232 | ranked_search_space,
233 | sorted_y_pred,
234 | sorted_sigma_pred,
235 | y_val,
236 | sorted_acq_score,
237 | ) = model.predict(ranked_search_space, acq_fn=ACQ_FN)
238 |
239 | # save when the first datapoints for each dataset and category have been discovered
240 | first_discovered_data[name][sample_size] = first_discovered
241 |
242 | return found_counts, results_df
243 |
244 |
245 | def plot_results(found_counts, name, iter, dest, sample_size):
246 | counts = list(found_counts.keys())
247 | found = [found_counts[count] for count in counts]
248 | x_positions = range(len(counts)) # Create fixed distance positions for the x-axis
249 |
250 | plt.figure(figsize=(10, 6))
251 | plt.bar(x_positions, found, color="skyblue")
252 | plt.xlabel("Top N Variants")
253 | plt.ylabel("Number of Variants Found")
254 | plt.title(f"{name}: Number of Top N Variants Found After {iter} Iterations")
255 | plt.xticks(x_positions, counts) # Set custom x-axis positions and labels
256 | plt.ylim(0, max(found) + 1)
257 |
258 | for i, count in enumerate(found):
259 | plt.text(x_positions[i], count + 0.1, str(count), ha="center")
260 |
261 | plt.savefig(os.path.join(dest, f"top_variants_{iter}_iterations_{name}.png"))
262 |
263 |
264 | def add_to_data(data: pd.DataFrame, proteins, iteration, sample_size, dataset, model):
265 | """Add sampling results to dataframe"""
266 | names = [prot.name for prot in proteins]
267 | ys = [prot.y for prot in proteins]
268 | y_preds = [prot.y_pred for prot in proteins]
269 | y_sigmas = [prot.y_sigma for prot in proteins]
270 | models = [MODEL] * len(names)
271 | reps = [REP] * len(names)
272 | acq_fns = [ACQ_FN] * len(names)
273 | rounds = [iteration] * len(names)
274 | datasets = [dataset] * len(names)
275 | sample_sizes = [sample_size] * len(names)
276 | test_r2s = [model.test_r2] * len(names)
277 | val_r2s = [model.val_r2] * len(names)
278 |
279 | new_data = pd.DataFrame(
280 | {
281 | "name": names,
282 | "y": ys,
283 | "y_pred": y_preds,
284 | "y_sigma": y_sigmas,
285 | "test_r2": test_r2s,
286 | "val_r2": val_r2s,
287 | "model": models,
288 | "rep": reps,
289 | "acq_fn": acq_fns,
290 | "round": rounds,
291 | "sample_size": sample_sizes,
292 | "dataset": datasets,
293 | }
294 | )
295 |
296 | # Append the new data to the existing DataFrame
297 | updated_data = pd.concat([data, new_data], ignore_index=True)
298 | return updated_data
299 |
300 |
301 | # Parse the arguments
302 | args = parser.parse_args()
303 |
304 | # Assign parsed arguments to capitalized variable names
305 | USER = args.user
306 | SAMPLE_SIZES = args.sample_sizes
307 | MODEL = args.model
308 | REP = args.rep
309 | ZS_MODEL = args.zs_model
310 | BENCHMARK_FOLDER = args.benchmark_folder
311 | SEED = args.seed
312 | MAX_ITER = args.max_iter
313 | DEVICE = args.device
314 | BATCH_SIZE = args.batch_size
315 | IMPROVEMENT = args.improvement
316 | ACQ_FN = args.acquisition_fn
317 | K_FOLD = args.k_folds
318 |
319 | # benchmark data
320 | datasets = [f for f in os.listdir(BENCHMARK_FOLDER) if f.endswith(".csv")]
321 | fastas = [f for f in os.listdir(BENCHMARK_FOLDER) if f.endswith(".fasta")]
322 | datasets.sort()
323 | fastas.sort()
324 |
325 | # save sampled data
326 | results_df = pd.DataFrame(
327 | {
328 | "name": [],
329 | "y": [],
330 | "y_pred": [],
331 | "y_sigma": [],
332 | "test_r2": [],
333 | "val_r2": [],
334 | "model": [],
335 | "rep": [],
336 | "acq_fn": [],
337 | "round": [],
338 | "sample_size": [],
339 | "dataset": [],
340 | }
341 | )
342 |
343 | first_discovered_data = {}
344 | for i in range(len(datasets)):
345 | for N in SAMPLE_SIZES:
346 | print(
347 | f"RUNNING model:{MODEL}, rep:{REP}, acq:{ACQ_FN}, sample_size:{N}",
348 | flush=True,
349 | )
350 | d = os.path.join(BENCHMARK_FOLDER, datasets[i])
351 | f = os.path.join(BENCHMARK_FOLDER, fastas[i])
352 | name = datasets[i][:-4]
353 |
354 | if N == SAMPLE_SIZES[0]:
355 | first_discovered_data[name] = {N: []}
356 | else:
357 | first_discovered_data[name][N] = []
358 |
359 | found_counts, results_df = benchmark(
360 | d,
361 | f,
362 | model=MODEL,
363 | embedding=REP,
364 | name=name,
365 | sample_size=N,
366 | results_df=results_df,
367 | )
368 | # save first discovered data
369 | with open(
370 | os.path.join(
371 | "usrs/benchmark/", f"first_discovered_data_{MODEL}_{REP}_{ACQ_FN}.json"
372 | ),
373 | "w",
374 | ) as file:
375 | json.dump(first_discovered_data, file)
376 |
377 | results_df.to_csv(
378 | os.path.join("usrs/benchmark/", f"results_df_{MODEL}_{REP}_{ACQ_FN}.csv"),
379 | index=False,
380 | )
381 |
--------------------------------------------------------------------------------
/demo/clustering_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import matplotlib.pyplot as plt
4 |
5 | sys.path.append("src/")  # make the local src/ package importable before importing proteusAI
6 | import proteusAI as pai
7 |
8 | print(os.getcwd())
9 |
10 |
11 | # will initiate storage space - else in memory
12 | dataset = "demo/demo_data/methyltransfereases.csv"
13 | y_column = "class"
14 |
15 | results_dictionary = {}
16 |
17 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num
18 | library = pai.Library(
19 | source=dataset,
20 | seqs_col="sequence",
21 | y_col=y_column,
22 | y_type="class",
23 | names_col="uid",
24 | )
25 |
26 | # compute and save ESM-2 representations at example_lib/representations/esm2
27 | library.compute(method="esm2_8M", batch_size=10)
28 |
29 | # dimensionality reduction
30 | dr_method = "tsne"
31 |
32 | # random seed
33 | seed = 42
34 |
35 | # define a model
36 | model = pai.Model(
37 | library=library,
38 | k_folds=5,
39 | model_type="hdbscan",
40 | rep="esm2_8M",
41 | min_cluster_size=30,
42 | min_samples=50,
43 | dr_method=dr_method,
44 | seed=seed,
45 | )
46 |
47 | # train model
48 | model.train()
49 |
50 | # use search to predict the classes of unknown sequences
51 | out = model.search()
52 | search_mask = out["mask"]
53 |
54 | # save results
55 | if not os.path.exists("demo/demo_data/out/"):
56 | os.makedirs("demo/demo_data/out/", exist_ok=True)
57 |
58 | out["df"].to_csv("demo/demo_data/out/clustering_search_results.csv")
59 |
60 | model_lib = pai.Library(source=out)
61 |
62 | # plot results
63 | fig, ax, plot_df = model.library.plot(
64 | rep="esm2_8M",
65 | method=dr_method,
66 | use_y_pred=True,
67 | highlight_mask=search_mask,
68 | seed=seed,
69 | )
70 | plt.savefig(f"demo/demo_data/out/clustering_results_{dr_method}.png")
71 |
--------------------------------------------------------------------------------
/demo/demo_data/Nitric_Oxide_Dioxygenase_wt.fasta:
--------------------------------------------------------------------------------
1 | >wt|NOD
2 | MAPTLSEQTRQLVRASVPALQKHSVAISATMYRLLFERYPETRSLCELPERQIHKIASALLAYARSIDNPSALQAAIRRMVLSHARAGVQAVHYPLYWECLRDAIKEVLGPDATETLLQAWKEAYDFLAHLLSTKEAQVYAVLAE
--------------------------------------------------------------------------------
/demo/demo_data/binding_data/dataset_creation.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | # Load the dataframes
4 | df1 = pd.read_csv("demo/demo_data/binding_data/replicate_1.csv")
5 | df2 = pd.read_csv("demo/demo_data/binding_data/replicate_2.csv")
6 | df3 = pd.read_csv("demo/demo_data/binding_data/replicate_3.csv")
7 |
8 | # Merge the dataframes on CDR1H_AA and CDR3H_AA to find common rows
9 | merged_df = pd.merge(df1, df2, on=["CDR1H_AA", "CDR3H_AA"], suffixes=("_df1", "_df2"))
10 | merged_df = pd.merge(merged_df, df3, on=["CDR1H_AA", "CDR3H_AA"])
11 |
12 | # Calculate the average of fit_KD from the three dataframes
13 | merged_df["fit_KD_avg"] = merged_df[["fit_KD_df1", "fit_KD_df2", "fit_KD"]].mean(axis=1)
14 |
15 | # Select the relevant columns for the final dataframe
16 | result_df = merged_df[["CDR1H_AA", "CDR3H_AA", "fit_KD_avg"]]
17 |
18 | # Group by CDR1H_AA and CDR3H_AA to remove duplicates and average the fit_KD_avg
19 | final_df = (
20 | result_df.groupby(["CDR1H_AA", "CDR3H_AA"])
21 | .agg({"fit_KD_avg": "mean"})
22 | .reset_index()
23 | )
24 |
25 | # Display the final result
26 | print(final_df)
27 |
28 | wt_seq = "EVKLDETGGGLVQPGRPMKLSCVASGFTFSDYWMNWVRQSPEKGLEWVAQIRNKPYNYETYYSDSVKGRFTISRDDSKSSVYLQMNNLRVEDMGIYYCTGSYYGMDYWGQGTSVTVSSAKTTAPSVYPLAPVCGDTTGSSVTLGCLVKGYFPEPVTLTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVTSSTWPSQSITCNVAHPASSTKVDKKIEPRG"
29 |
30 | # CDRH1 and CDR3H wild-type sequences
31 | cdr1h_wt = wt_seq[27:37] # TFSDYWMNWV
32 | cdr3h_wt = wt_seq[99:109] # GSYYGMDYWG
33 |
34 | seqs = []
35 | mutant_names = []
36 | for i, row in final_df.iterrows():
37 | seq = list(wt_seq)
38 | CDR1H_AA = row["CDR1H_AA"]
39 | CDR3H_AA = row["CDR3H_AA"]
40 | seq[27:37] = list(CDR1H_AA)
41 | seq[99:109] = list(CDR3H_AA)
42 | mutated_seq = "".join(seq)
43 | seqs.append(mutated_seq)
44 |
45 | # Determine the mutations for naming
46 | mutations = []
47 |
48 | # Compare CDR1H sequence
49 | for j in range(len(cdr1h_wt)):
50 | if CDR1H_AA[j] != cdr1h_wt[j]:
51 | mutations.append(f"{cdr1h_wt[j]}{27+j+1}{CDR1H_AA[j]}")
52 |
53 | # Compare CDR3H sequence
54 | for j in range(len(cdr3h_wt)):
55 | if CDR3H_AA[j] != cdr3h_wt[j]:
56 | mutations.append(f"{cdr3h_wt[j]}{99+j+1}{CDR3H_AA[j]}")
57 |
58 | # Combine mutations into a mutant name
59 | if mutations:
60 | mutant_name = ":".join(mutations)
61 | else:
62 | mutant_name = "WT"
63 |
64 | mutant_names.append(mutant_name)
65 |
66 | data = {
67 | "mutant": mutant_names,
68 | "mutated_sequence": seqs,
69 | "DMS_score": final_df["fit_KD_avg"].to_list(),
70 | "DMS_score_bin": [None] * len(mutant_names),
71 | }
72 |
73 | results_df = pd.DataFrame(data)
74 | results_df.to_csv("demo/demo_data/DMS/SCFV_HUMAN_Adams_2016_affinity.csv", index=False)
75 |
76 | with open("demo/demo_data/DMS/SCFV_HUMAN_Adams_2016_affinity.fasta", "w") as f:
77 | f.write(">SCFV_HUMAN_Adams_2016_affinity\n")
78 | f.write(f"{wt_seq}")
79 |
--------------------------------------------------------------------------------
/demo/design_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import proteusAI as pai
4 |
5 | files = ["demo/demo_data/GB1.pdb"] # input files
6 | temps = [1.0] # sampling temperatures
7 | fixed = {"GB1": [1, 2, 3]} # fixed residues
8 | num_samples = 100 # number of samples
9 |
10 | for f in files:
11 | protein = pai.Protein(source=f)
12 |
13 | fname = f.split("/")[-1][:-4]
14 |
15 | for temp in temps:
16 | for key, fix in fixed.items():
17 | out = protein.esm_if(
18 | num_samples=num_samples, target_chain="A", temperature=temp, fixed=fix
19 | )
20 |
21 | # save results
22 | if not os.path.exists("demo/demo_data/out/"):
23 | os.makedirs("demo/demo_data/out/", exist_ok=True)
24 |
25 | out["df"].to_csv(
26 | f"demo/demo_data/out_{fname}_temp_{temp}_{key}_out.csv", index=False
27 | )
28 |
29 |
30 | ### UNCOMMENT THESE LINES TO FOLD DESIGNS IF YOU HAVE ESM-FOLD INSTALLED ###
31 |
32 | # Designs can be folded to check the structure and confidence of the designs
33 | # import shutil
34 | # design_library = pai.Library(source=out)
35 |
36 | # this will fold the designs
37 | # fold_out = design_library.fold()
38 |
39 | # save the folded designs
40 | # fold_out["df"].to_csv(
41 | # f"demo/demo_data/out{fname}_temp_{temp}_{key}_folded.csv", index=False
42 | # )
43 |
44 | # save the structures
45 | # os.makedirs(f"demo/demo_data/out/{fname}_temp_{temp}_{key}_pdb/", exist_ok=True)
46 |
47 | # Move all files from source to destination
48 | # for file in os.listdir(out["struc_path"]):
49 | # shutil.move(
50 | # os.path.join(out["struc_path"], file),
51 | # f"demo/demo_data/out/{fname}_temp_{temp}_{key}_pdb/",
52 | # )
53 |
--------------------------------------------------------------------------------
/demo/discovery_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import matplotlib.pyplot as plt
4 |
5 | sys.path.append("src/")  # make the local src/ package importable before importing proteusAI
6 | import proteusAI as pai
7 |
8 | print(os.getcwd())
9 |
10 |
11 | # will initiate storage space - else in memory
12 | dataset = "demo/demo_data/methyltransfereases.csv"
13 | y_column = "coverage_5"
14 |
15 | results_dictionary = {}
16 |
17 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num
18 | library = pai.Library(
19 | source=dataset,
20 | seqs_col="sequence",
21 | y_col=y_column,
22 | y_type="class",
23 | names_col="uid",
24 | )
25 |
26 | # compute and save ESM-2 representations at example_lib/representations/esm2
27 | library.compute(method="esm2_8M", batch_size=10)
28 |
29 | # define a model
30 | model = pai.Model(library=library, k_folds=5, model_type="rf", rep="esm2_8M")
31 |
32 | # train model
33 | model.train()
34 |
35 | # use search to predict the classes of unknown sequences
36 | out = model.search()
37 | search_mask = out["mask"]
38 |
39 | # save results
40 | if not os.path.exists("demo/demo_data/out/"):
41 | os.makedirs("demo/demo_data/out/", exist_ok=True)
42 |
43 | out["df"].to_csv("demo/demo_data/out/discovery_search_results.csv")
44 |
45 | model_lib = pai.Library(source=out)
46 |
47 | # plot results
48 | fig, ax, plot_df = model.library.plot(
49 | rep="esm2_8M", use_y_pred=True, highlight_mask=search_mask
50 | )
51 | plt.savefig("demo/demo_data/out/search_results.png")
52 |
--------------------------------------------------------------------------------
/demo/fold_design_demo.py:
--------------------------------------------------------------------------------
1 | ### THIS DEMO ONLY WORKS IF ESM-FOLD IS INSTALLED ###
2 |
3 | import os
4 | import shutil
5 |
6 | import proteusAI as pai
7 |
8 | files = ["demo/demo_data/Nitric_Oxide_Dioxygenase_wt.fasta"] # input files
9 |
10 | for f in files:
11 | protein = pai.Protein(source=f)
12 |
13 | # fold the protein
14 | fold_out = protein.esm_fold(relax=True)
15 |
16 | # design the now folded protein
17 | design_out = protein.esm_if(
18 | num_samples=100, target_chain="A", temperature=1.0, fixed=[]
19 | )
20 |
21 | # move results to demo folder
22 | if not os.path.exists("demo/demo_data/out/"):
23 | os.makedirs("demo/demo_data/out/", exist_ok=True)
24 |
25 | # save the design
26 | design_out["df"].to_csv("demo/demo_data/fold_design_out.csv", index=False)
27 |
28 | # move the structures
29 | os.makedirs("demo/demo_data/out/fold_design_pdb/", exist_ok=True)
30 |
31 | # Copy all files from the source to the destination directory
32 | for file in os.listdir(fold_out["struc_path"]):
33 | src = os.path.join(fold_out["struc_path"], file)
34 | dst = os.path.join("demo/demo_data/out/fold_design_pdb/", file)
35 | shutil.copy(src, dst)
36 |
--------------------------------------------------------------------------------
/demo/mlde_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append("src/")  # make the local src/ package importable before importing proteusAI
5 | import proteusAI as pai
6 |
7 | print(os.getcwd())
8 |
9 |
10 | # will initiate storage space - else in memory
11 | dataset = "demo/demo_data/Nitric_Oxide_Dioxygenase.csv"
12 |
13 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num
14 | library = pai.Library(
15 | source=dataset,
16 | seqs_col="Sequence",
17 | y_col="Data",
18 | y_type="num",
19 | names_col="Description",
20 | )
21 |
22 | # compute and save ESM-2 representations at example_lib/representations/esm2
23 | library.compute(method="esm2_8M", batch_size=10)
24 |
25 | # define a model
26 | model = pai.Model(library=library, k_folds=5, model_type="rf", x="vhse")
27 |
28 | # train model
29 | model.train()
30 |
31 | # search for new mutants
32 | out = model.search(optim_problem="max")
33 |
34 | # save results
35 | if not os.path.exists("demo/demo_data/out/"):
36 | os.makedirs("demo/demo_data/out/", exist_ok=True)
37 |
38 | out.to_csv("demo/demo_data/out/demo_search_results.csv")
39 |
--------------------------------------------------------------------------------
/demo/plot_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import matplotlib.pyplot as plt
5 | import pandas as pd
6 | import seaborn as sns
7 |
8 | # Initialize the argparse parser
9 | parser = argparse.ArgumentParser(description="Plot the benchmark results.")
10 |
11 | # Add arguments
12 | parser.add_argument("--rep", type=str, default="blosum62", help="Representation type.")
13 | parser.add_argument("--max-sample", type=int, default=100, help="Maximum sample size.")
14 | parser.add_argument("--model", type=str, default="gp", help="Model name.")
15 | parser.add_argument(
16 | "--acquisition-fn", type=str, default="ei", help="Acquisition function name."
17 | )
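
# Example invocation (a sketch; flags should match the MLDE_benchmark.py run so that the
# expected usrs/benchmark/first_discovered_data_{model}_{rep}_{acq}.json file exists):
#   python demo/plot_benchmark.py --model gp --rep blosum62 --acquisition-fn ei --max-sample 100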
18 |
19 | # Parse the arguments
20 | args = parser.parse_args()
21 |
22 | # Assign parsed arguments to variables
23 | REP = args.rep
24 | MAX_SAMPLE = args.max_sample
25 | MODEL = args.model
26 | ACQ_FN = args.acquisition_fn
27 |
28 | # Dictionaries for pretty names
29 | rep_dict = {
30 | "ohe": "One-hot",
31 | "blosum50": "BLOSUM50",
32 | "blosum62": "BLOSUM62",
33 | "esm2": "ESM-2",
34 | "esm1v": "ESM-1v",
35 | "vae": "VAE",
36 | }
37 | model_dict = {
38 | "rf": "Random Forest",
39 | "knn": "KNN",
40 | "svm": "SVM",
41 | "esm2": "ESM-2",
42 | "esm1v": "ESM-1v",
43 | "gp": "Gaussian Process",
44 | "ridge": "Ridge Regression",
45 | }
46 | acq_dict = {
47 | "ei": "Expected Improvement",
48 | "ucb": "Upper Confidence Bound",
49 | "greedy": "Greedy",
50 | "random": "Random",
51 | }
52 |
53 | # Load data from JSON file
54 | with open(f"usrs/benchmark/first_discovered_data_{MODEL}_{REP}_{ACQ_FN}.json") as f:
55 | data = json.load(f)
56 |
57 | # Define the top N variants and sample sizes
58 | datasets = list(data.keys())
59 | sample_sizes = [int(i) for i in list(data[datasets[0]].keys())]
60 | top_n_variants = ["5", "10", "20", "50", "improved"]
61 |
62 | # Initialize a list to store the processed data
63 | plot_data = []
64 |
65 | # Process the data
66 | for dataset, results in data.items():
67 | for sample_size, rounds in results.items():
68 | if rounds is not None:
69 | for i, round_count in enumerate(rounds):
70 | if round_count is not None:
71 | plot_data.append(
72 | {
73 | "Dataset": dataset,
74 | "Sample Size": int(sample_size),
75 | "Top N Variants": top_n_variants[i],
76 | "Rounds": round_count,
77 | }
78 | )
79 |
80 | # Create a DataFrame from the processed data
81 | df = pd.DataFrame(plot_data)
82 | df["Top N Variants"] = pd.Categorical(
83 | df["Top N Variants"], categories=top_n_variants, ordered=True
84 | )
85 |
86 | # Set plot style
87 | sns.set(style="whitegrid")
88 |
89 | # Create a grid of box plots for each sample size with shared y-axis
90 | g = sns.FacetGrid(
91 | df, col="Sample Size", col_wrap=2, height=5.5, aspect=1.8, sharey=True
92 | ) # Adjusted height and aspect
93 |
94 |
95 | # Function to create box plot and strip plot
96 | def plot_box_and_strip(data, x, y, **kwargs):
97 | ax = plt.gca()
98 | sns.boxplot(data=data, x=x, y=y, color="lightgray", width=0.5, ax=ax) # noqa: F841
99 | sns.stripplot(data=data, x=x, y=y, **kwargs)
100 |
101 | # Add median values as text
102 | medians = data.groupby(x)[y].median()
103 | for i, median in enumerate(medians):
104 | ax.text(
105 | i, median + 1, f"{median:.2f}", ha="center", va="bottom", color="black"
106 | ) # Adjusting the position to be just above the median line
107 |
108 | # Set y-axis limit to MAX_SAMPLE
109 | ax.set_ylim(0, MAX_SAMPLE)
110 |
111 |
112 | # Apply the function to each subplot
113 | g.map_dataframe(
114 | plot_box_and_strip,
115 | x="Top N Variants",
116 | y="Rounds",
117 | hue="Dataset",
118 | dodge=True,
119 | jitter=True,
120 | palette="muted",
121 | )
122 |
123 | # Draw vertical separators between the Top N variant categories
124 | for ax in g.axes.flatten():
125 | for x in range(len(top_n_variants) - 1):
126 | ax.axvline(x + 0.5, color="grey", linestyle="-", linewidth=0.8)
127 |
128 | # Adjust the title and layout
129 | g.set_titles(col_template="Sample Size: {col_name}")
130 | g.set_axis_labels("Top N Variants", "Rounds")
131 |
132 | rep = rep_dict[REP]
133 | model = model_dict[MODEL]
134 | acq_fn = acq_dict[ACQ_FN]
135 | subtitle = f"Representation: {rep}, Model: {model}, Acquisition Function: {acq_fn}"
136 |
137 | # Move legend to the right side
138 | g.add_legend(
139 | title="Dataset", bbox_to_anchor=(0.87, 0.47), loc="center left", borderaxespad=0
140 | )
141 |
142 | # Set the main title and subtitle
143 | g.fig.suptitle(
144 | "Rounds to Discover Top N Variants Across Different Sample Sizes", y=0.94, x=0.44
145 | )
146 | plt.text(
147 | 0.44,
148 | 0.9,
149 | subtitle,
150 | ha="center",
151 | va="center",
152 | fontsize=10,
153 | transform=g.fig.transFigure,
154 | )
155 |
156 | # Adjust the layout to optimize spacing
157 | plt.tight_layout(pad=2.0) # Adjusted padding
158 | g.fig.subplots_adjust(top=0.85, right=0.85, hspace=0.3, wspace=0.3) # Adjusted spacing
159 |
160 | # Save the plot
161 | plt.savefig(
162 | f"usrs/benchmark/first_discovered_{MODEL}_{REP}_{ACQ_FN}.png",
163 | bbox_inches="tight",
164 | dpi=300,
165 | )
166 |
167 | # Show the plot
168 | plt.show()
169 |
--------------------------------------------------------------------------------
/demo/plot_library_demo.py:
--------------------------------------------------------------------------------
1 | import proteusAI as pai
2 | import matplotlib.pyplot as plt
3 |
4 | # Initialize the library
5 | library = pai.Library(
6 | source="demo/demo_data/Nitric_Oxide_Dioxygenase.csv",
7 | seqs_col="Sequence",
8 | y_col="Data",
9 | y_type="num",
10 | names_col="Description", # Column containing the sequence descriptions
11 | )
12 | print(library.seqs)
13 | # Compute embeddings using the specified method
14 | library.compute(method="esm2_8M")
15 |
16 | # Generate a UMAP plot
17 | fig, ax, df = library.plot_umap(rep="esm2_8M")
18 |
19 | # Display the plot
20 | plt.show()
21 |
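22 | # Optional: persist the figure and the UMAP coordinates. A minimal sketch, assuming
23 | # `fig` is a matplotlib Figure and `df` a pandas DataFrame, as unpacked above;
24 | # the file names are arbitrary.
25 | # fig.savefig("umap_esm2_8M.png", dpi=300, bbox_inches="tight")
26 | # df.to_csv("umap_esm2_8M.csv", index=False)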
--------------------------------------------------------------------------------
/demo/test:
--------------------------------------------------------------------------------
1 | >sampled_seq_1 recovery: 0.5272727272727272
2 | ELYTIWWDGTDLRGSETVTAKDAATAEKTFKAAAAENGVTGTWADDTQKTFTVTG
3 | >sampled_seq_2 recovery: 0.4727272727272727
4 | MIYTLKWQGGTLTGEETVTAADSSTAEAQFKSSAKNAGYSGQWTNQFKKTYTVTG
5 | >sampled_seq_3 recovery: 0.45454545454545453
6 | STYTLVKAGAKQKGQSTVEADDANDAEAAFRRGAADKGISGQWEDATGKTFTVTS
7 | >sampled_seq_4 recovery: 0.41818181818181815
8 | SEYTLEWDGALDTGVAEVEAANAKAATKVFASAAKDNGLKGKWTDRIDKTFTAKG
9 | >sampled_seq_5 recovery: 0.36363636363636365
10 | GTYTLLWDGDTVTGERTVSAASETSAEDKFNRAATEKGITGDWGDRHANTFTATG
11 | >sampled_seq_6 recovery: 0.4727272727272727
12 | GTYTLIWDAATLKGTSTVKASDIAAAQKIFNNMAKDLGVTGTWTDRKDRTYTVTG
13 | >sampled_seq_7 recovery: 0.43636363636363634
14 | GTYTLIWAGDKVHGTSTIAAADQATATALFKKGAKEAGMDGEWDDKASNTYTVVG
15 | >sampled_seq_8 recovery: 0.5636363636363636
16 | GQYTLKLSGGDVTGTTTVSADDKATAEKLFESAAKENGLQGEWNDFATKTFTVTG
17 | >sampled_seq_9 recovery: 0.4
18 | MTYTLVWAGEKAKGTQTVKAKNSAAATADFKESAAANGLSGKWSNHKTKTYTVSG
19 | >sampled_seq_10 recovery: 0.4727272727272727
20 | GIYTIKLDGKTYNGTATVTAADAATAEAKFESLAEQNGMTGAWTDATARTYTVTA
21 |
--------------------------------------------------------------------------------
/demo/zero_shot_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import proteusAI as pai
4 |
5 | files = ["demo/demo_data/Nitric_Oxide_Dioxygenase_wt.fasta"] # input files
6 | models = [
7 | "esm2_8M"
8 | ]  # esm2_8M is fast and used for demo purposes; ideally use esm1v, which is more accurate.
9 |
10 | for f in files:
11 | protein = pai.Protein(source=f)
12 |
13 | for model in models:
14 | out = protein.zs_prediction(model=model)
15 |
16 | # save results
17 |         # exist_ok=True already covers the case where the directory exists
18 |         os.makedirs("demo/demo_data/out/", exist_ok=True)
19 | 
20 |         out["df"].to_csv(
21 |             f"demo/demo_data/out/zs_demo_{f.split('/')[-1][:-6]}.csv", index=False
22 |         )
23 |
24 | print(out["df"])
25 |
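26 | # To run the more accurate (but slower) model recommended above, swap the model list:
27 | # models = ["esm1v"]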
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build
2 | jupyter_execute
3 | reference
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Docs creation
2 |
3 | In order to build the docs you need to:
4 |
5 | 1. install sphinx and additional support packages
6 | 2. build the package reference files
7 | 3. run sphinx to create a local html version
8 |
9 | The documentation is built automatically on Read the Docs.
10 |
11 | Install the docs dependencies of the package (as specified in `pyproject.toml`):
12 |
13 | ```bash
14 | # in main folder
15 | pip install ".[docs]"
16 | ```
17 |
18 | ## Build docs using Sphinx command line tools
19 |
20 | Commands to be run from `path/to/docs`, i.e. from within the `docs` folder:
21 |
22 | Options:
23 | - `--separate` to build separate pages for each (sub-)module
24 |
25 | ```bash
26 | # pwd: docs
27 | # apidoc
28 | sphinx-apidoc --force --implicit-namespaces --module-first -o reference ../src/proteusAI
29 | # build docs
30 | sphinx-build -n -W --keep-going -b html ./ ./_build/
31 | ```
32 |
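33 | To build one page per (sub-)module, add the `--separate` flag from the options above:
34 | 
35 | ```bash
36 | # pwd: docs
37 | sphinx-apidoc --force --implicit-namespaces --module-first --separate -o reference ../src/proteusAI
38 | ```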
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | from importlib import metadata
15 |
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | project = "proteusAI"
20 | copyright = "2024, Jonathan Funk"
21 | author = "Jonathan Funk"
22 | PACKAGE_VERSION = metadata.version("proteusAI")
23 | version = PACKAGE_VERSION
24 | release = PACKAGE_VERSION
25 |
26 |
27 | # -- General configuration ---------------------------------------------------
28 |
29 | # Add any Sphinx extension module names here, as strings. They can be
30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31 | # ones.
32 | extensions = [
33 | "sphinx.ext.autodoc",
34 | "sphinx.ext.autodoc.typehints",
35 | "sphinx.ext.viewcode",
36 | "sphinx.ext.napoleon",
37 | "sphinx.ext.intersphinx",
38 | "sphinx_new_tab_link",
39 | "myst_nb",
40 | ]
41 |
42 | # https://myst-nb.readthedocs.io/en/latest/computation/execute.html
43 | nb_execution_mode = "auto"
44 |
45 | myst_enable_extensions = ["dollarmath", "amsmath"]
46 |
47 | # Plotly support through the require.js javascript library
48 | # https://myst-nb.readthedocs.io/en/latest/render/interactive.html#plotly
49 | html_js_files = [
50 | "https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"
51 | ]
52 |
53 | # https://myst-nb.readthedocs.io/en/latest/configuration.html
54 | # Execution
55 | nb_execution_raise_on_error = True
56 | # Rendering
57 | nb_merge_streams = True
58 |
59 | # https://myst-nb.readthedocs.io/en/latest/authoring/custom-formats.html#write-custom-formats
60 | nb_custom_formats = {".py": ["jupytext.reads", {"fmt": "py:percent"}]}
61 |
62 | # Add any paths that contain templates here, relative to this directory.
63 | templates_path = ["_templates"]
64 |
65 | # List of patterns, relative to source directory, that match files and
66 | # directories to ignore when looking for source files.
67 | # This pattern also affects html_static_path and html_extra_path.
68 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "jupyter_execute", "conf.py"]
69 |
70 |
71 | # Intersphinx options
72 | intersphinx_mapping = {
73 | "python": ("https://docs.python.org/3", None),
74 | # "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
75 | # "scikit-learn": ("https://scikit-learn.org/stable/", None),
76 | # "matplotlib": ("https://matplotlib.org/stable/", None),
77 | }
78 |
79 | # -- Options for HTML output -------------------------------------------------
80 |
81 | # The theme to use for HTML and HTML Help pages. See the documentation for
82 | # a list of builtin themes.
83 | # See:
84 | # https://github.com/executablebooks/MyST-NB/blob/master/docs/conf.py
85 | # html_title = ""
86 | html_theme = "sphinx_book_theme"
87 | # html_logo = "_static/logo-wide.svg"
88 | # html_favicon = "_static/logo-square.svg"
89 | html_theme_options = {
90 | "github_url": "https://github.com/jonfunk21/proteusAI",
91 | "repository_url": "https://github.com/jonfunk21/proteusAI",
92 | "repository_branch": "main",
93 | "home_page_in_toc": True,
94 | "path_to_docs": "docs",
95 | "show_navbar_depth": 1,
96 | "use_edit_page_button": True,
97 | "use_repository_button": True,
98 | "use_download_button": True,
99 | "launch_buttons": {
100 | "colab_url": "https://colab.research.google.com"
101 | # "binderhub_url": "https://mybinder.org",
102 | # "notebook_interface": "jupyterlab",
103 | },
104 | "navigation_with_keys": False,
105 | }
106 |
107 | # Add any paths that contain custom static files (such as style sheets) here,
108 | # relative to this directory. They are copied after the builtin static files,
109 | # so a file named "default.css" will overwrite the builtin "default.css".
110 | # html_static_path = ["_static"]
111 |
112 |
113 | # -- Setup for sphinx-apidoc -------------------------------------------------
114 |
115 | # Read the Docs doesn't support running arbitrary commands like tox.
116 | # sphinx-apidoc needs to be called manually if Sphinx is running there.
117 | # https://github.com/readthedocs/readthedocs.org/issues/1139
118 |
119 | if os.environ.get("READTHEDOCS") == "True":
120 | from pathlib import Path
121 |
122 | PROJECT_ROOT = Path(__file__).parent.parent
123 | PACKAGE_ROOT = PROJECT_ROOT / "src" / "proteusAI"
124 |
125 | def run_apidoc(_):
126 | from sphinx.ext import apidoc
127 |
128 | apidoc.main(
129 | [
130 | "--force",
131 | "--implicit-namespaces",
132 | "--module-first",
133 | "--separate",
134 | "-o",
135 | str(PROJECT_ROOT / "docs" / "reference"),
136 | str(PACKAGE_ROOT),
137 | str(PACKAGE_ROOT / "*.c"),
138 | str(PACKAGE_ROOT / "*.so"),
139 | ]
140 | )
141 |
142 | def setup(app):
143 | app.connect("builder-inited", run_apidoc)
144 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | The proteusAI package
2 | =====================
3 |
4 | .. include:: ../README.md
5 | :parser: myst_parser.sphinx_
6 | :start-line: 1
7 |
8 |
9 | .. toctree::
10 | :hidden:
11 | :maxdepth: 2
12 | :caption: Tutorial
13 |
14 | tutorial/tutorial
15 |
16 | .. toctree::
17 | :hidden:
18 | :maxdepth: 2
19 | :caption: Contents:
20 |
21 | reference/modules
22 |
23 | .. toctree::
24 | :hidden:
25 | :caption: Technical notes
26 |
27 | README.md
28 |
29 |
30 |
31 | Indices and tables
32 | ==================
33 |
34 | * :ref:`genindex`
35 | * :ref:`modindex`
36 | * :ref:`search`
37 |
--------------------------------------------------------------------------------
/docs/tutorial/tutorial.py:
--------------------------------------------------------------------------------
1 | # %% [markdown]
2 | # # Tutorial
3 |
4 | # %%
5 | import proteusAI as pai
6 |
7 | pai.__version__
8 |
9 | # %% [markdown]
10 | # # MLDE Tutorial
11 | # ## Loading data from csv or excel
12 | # The data must contain the mutant sequences and y values for the MLDE workflow. It is recommended to use descriptive sequence names for later interpretability, and to define the data type explicitly.
13 |
14 | # %%
15 | library = pai.Library(
16 | source="demo/demo_data/Nitric_Oxide_Dioxygenase_raw.csv",
17 | seqs_col="Sequence",
18 | y_col="Data",
19 | y_type="num",
20 | names_col="Description",
21 | )
22 |
23 |
24 | # %% [markdown]
25 | # ## Compute representations (skipped here to save time)
26 | # The available representations are ('esm2', 'esm1v', 'blosum62', 'blosum50', and 'ohe')
27 |
28 | # %%
29 | # library.compute(method='esm2', batch_size=10) # uncomment this line to compute esm2 representations
30 |
31 |
32 | # %% [markdown]
33 | # ## Define a model, using fast-to-compute BLOSUM62 representations (good for demo purposes and surprisingly competitive with esm2 representations).
34 |
35 | # %%
36 | model = pai.Model(library=library)
37 |
38 | # %% [markdown]
39 | # ## Training the model
40 |
41 | # %%
42 | _ = model.train(k_folds=5, model_type="rf", x="blosum62")
43 |
44 |
45 | # %% [markdown]
46 | # ## Search for new mutants
47 | # Searching for new mutants produces an output dataframe containing the new predictions. Here we use the expected improvement ('ei') acquisition function.
48 |
49 | # %%
50 | out = model.search(acq_fn="ei")
51 | out
52 |
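53 | # %% [markdown]
54 | # ## Saving the results
55 | # The search output is a dataframe, so it can be persisted with standard pandas calls.
56 | # A minimal sketch (the file name is arbitrary; uncomment to write the file):
57 | 
58 | # %%
59 | # out.to_csv("mlde_search_results.csv", index=False)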
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: proteusAI
2 | channels:
3 | - pyg
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - anaconda
8 | dependencies:
9 | - python=3.8
10 | # - pytorch-scatter
11 | - pytorch=2.4.1
12 | # - pyg
13 | # - cpuonly
14 | - pip
15 | - sphinx
16 | - pdbfixer
17 | - pip:
18 | - -e .[docs] -f https://data.pyg.org/whl/torch-2.4.1+cpu.html
19 |
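20 | # To build this environment locally (standard conda workflow):
21 | #   conda env create -f environment.yml
22 | #   conda activate proteusAI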
--------------------------------------------------------------------------------
/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/favicon.ico
--------------------------------------------------------------------------------
/proteusEnvironment.yml:
--------------------------------------------------------------------------------
1 | name: proteusAI_depl
2 | channels:
3 | - anaconda
4 | - pyg
5 | - pytorch
6 | - nvidia
7 | - conda-forge
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=2_kmp_llvm
11 | - aiohttp=3.9.3=py38h01eb140_1
12 | - aiosignal=1.3.1=pyhd8ed1ab_0
13 | - anyio=4.4.0=pyhd8ed1ab_0
14 | - appdirs=1.4.4=pyh9f0ad1d_0
15 | - asgiref=3.8.1=pyhd8ed1ab_0
16 | - asttokens=2.4.1=pyhd8ed1ab_0
17 | - async-timeout=4.0.3=pyhd8ed1ab_0
18 | - attrs=23.2.0=pyh71513ae_0
19 | - backcall=0.2.0=pyh9f0ad1d_0
20 | - blas=2.116=mkl
21 | - blas-devel=3.9.0=16_linux64_mkl
22 | - brotli-python=1.1.0=py38h17151c0_1
23 | - bzip2=1.0.8=hd590300_5
24 | - ca-certificates=2024.7.4=hbcca054_0
25 | - certifi=2024.7.4=pyhd8ed1ab_0
26 | - charset-normalizer=3.3.2=pyhd8ed1ab_0
27 | - click=8.1.7=unix_pyh707e725_0
28 | - colorama=0.4.6=pyhd8ed1ab_0
29 | - comm=0.2.2=pyhd8ed1ab_0
30 | - cuda-cudart=12.1.105=0
31 | - cuda-cupti=12.1.105=0
32 | - cuda-libraries=12.1.0=0
33 | - cuda-nvrtc=12.1.105=0
34 | - cuda-nvtx=12.1.105=0
35 | - cuda-opencl=12.4.127=0
36 | - cuda-runtime=12.1.0=0
37 | - cudatoolkit=11.8.0=h4ba93d1_13
38 | - debugpy=1.8.1=py38h17151c0_0
39 | - decorator=5.1.1=pyhd8ed1ab_0
40 | - exceptiongroup=1.2.0=pyhd8ed1ab_2
41 | - executing=2.0.1=pyhd8ed1ab_0
42 | - ffmpeg=4.3=hf484d3e_0
43 | - filelock=3.13.3=pyhd8ed1ab_0
44 | - freetype=2.12.1=h267a509_2
45 | - frozenlist=1.4.1=py38h01eb140_0
46 | - fsspec=2024.3.1=pyhca7485f_0
47 | - gmp=6.3.0=h59595ed_1
48 | - gmpy2=2.1.2=py38h793c122_1
49 | - gnutls=3.6.13=h85f3911_1
50 | - gpytorch=1.12=pyhd8ed1ab_0
51 | - h11=0.14.0=pyhd8ed1ab_0
52 | - htmltools=0.5.2=pyhd8ed1ab_0
53 | - icu=73.2=h59595ed_0
54 | - idna=3.6=pyhd8ed1ab_0
55 | - importlib-metadata=7.1.0=pyha770c72_0
56 | - importlib_metadata=7.1.0=hd8ed1ab_0
57 | - ipykernel=6.29.3=pyhd33586a_0
58 | - ipython=8.12.2=pyh41d4057_0
59 | - jaxtyping=0.2.19=pyhd8ed1ab_0
60 | - jedi=0.19.1=pyhd8ed1ab_0
61 | - jinja2=3.1.3=pyhd8ed1ab_0
62 | - joblib=1.3.2=pyhd8ed1ab_0
63 | - jpeg=9e=h166bdaf_2
64 | - jupyter_client=8.6.2=pyhd8ed1ab_0
65 | - jupyter_core=5.7.2=py38h578d9bd_0
66 | - keyutils=1.6.1=h166bdaf_0
67 | - krb5=1.21.2=h659d440_0
68 | - lame=3.100=h166bdaf_1003
69 | - lcms2=2.15=hfd0df8a_0
70 | - ld_impl_linux-64=2.40=h41732ed_0
71 | - lerc=4.0.0=h27087fc_0
72 | - libblas=3.9.0=16_linux64_mkl
73 | - libcblas=3.9.0=16_linux64_mkl
74 | - libcublas=12.1.0.26=0
75 | - libcufft=11.0.2.4=0
76 | - libcufile=1.9.1.3=0
77 | - libcurand=10.3.5.147=0
78 | - libcusolver=11.4.4.55=0
79 | - libcusparse=12.0.2.55=0
80 | - libdeflate=1.17=h0b41bf4_0
81 | - libedit=3.1.20191231=he28a2e2_2
82 | - libffi=3.4.2=h7f98852_5
83 | - libgcc-ng=13.2.0=h807b86a_5
84 | - libgfortran-ng=13.2.0=h69a702a_5
85 | - libgfortran5=13.2.0=ha4646dd_5
86 | - libgomp=13.2.0=h807b86a_5
87 | - libhwloc=2.9.3=default_h554bfaf_1009
88 | - libiconv=1.17=hd590300_2
89 | - libjpeg-turbo=2.0.0=h9bf148f_0
90 | - liblapack=3.9.0=16_linux64_mkl
91 | - liblapacke=3.9.0=16_linux64_mkl
92 | - libllvm14=14.0.6=hcd5def8_4
93 | - libnpp=12.0.2.50=0
94 | - libnsl=2.0.1=hd590300_0
95 | - libnvjitlink=12.1.105=0
96 | - libnvjpeg=12.1.1.14=0
97 | - libpng=1.6.43=h2797004_0
98 | - libsodium=1.0.18=h36c2ea0_1
99 | - libsqlite=3.45.2=h2797004_0
100 | - libstdcxx-ng=13.2.0=h7e041cc_5
101 | - libtiff=4.5.0=h6adf6a1_2
102 | - libuuid=2.38.1=h0b41bf4_0
103 | - libwebp-base=1.3.2=hd590300_0
104 | - libxcb=1.13=h7f98852_1004
105 | - libxcrypt=4.4.36=hd590300_1
106 | - libxml2=2.12.6=h232c23b_1
107 | - libzlib=1.2.13=hd590300_5
108 | - linear_operator=0.5.2=pyhd8ed1ab_0
109 | - linkify-it-py=2.0.3=pyhd8ed1ab_0
110 | - llvm-openmp=15.0.7=h0cdce71_0
111 | - llvmlite=0.41.1=py38h94a1851_0
112 | - markdown-it-py=3.0.0=pyhd8ed1ab_0
113 | - markupsafe=2.1.5=py38h01eb140_0
114 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0
115 | - mdit-py-plugins=0.4.1=pyhd8ed1ab_0
116 | - mdurl=0.1.2=pyhd8ed1ab_0
117 | - mkl=2022.1.0=h84fe81f_915
118 | - mkl-devel=2022.1.0=ha770c72_916
119 | - mkl-include=2022.1.0=h84fe81f_915
120 | - mpc=1.3.1=hfe3b2da_0
121 | - mpfr=4.2.1=h9458935_1
122 | - mpmath=1.3.0=pyhd8ed1ab_0
123 | - multidict=6.0.5=py38h01eb140_0
124 | - ncurses=6.4.20240210=h59595ed_0
125 | - nest-asyncio=1.6.0=pyhd8ed1ab_0
126 | - nettle=3.6=he412f7d_0
127 | - networkx=3.1=pyhd8ed1ab_0
128 | - numba=0.58.1=py38h4144172_0
129 | - numpy=1.23.5=py38h7042d01_0
130 | - ocl-icd=2.3.2=hd590300_1
131 | - ocl-icd-system=1.0.0=1
132 | - openh264=2.1.1=h780b84a_0
133 | - openjpeg=2.5.0=hfec8fc6_2
134 | - openmm=8.1.1=py38h10fed7f_1
135 | - openssl=3.3.1=h4ab18f5_1
136 | - packaging=24.0=pyhd8ed1ab_0
137 | - pandas=2.0.3=py38h01efb38_1
138 | - parso=0.8.4=pyhd8ed1ab_0
139 | - pdbfixer=1.9=pyh1a96a4e_0
140 | - pexpect=4.9.0=pyhd8ed1ab_0
141 | - pickleshare=0.7.5=py_1003
142 | - pillow=9.4.0=py38hde6dc18_1
143 | - pip=24.0=pyhd8ed1ab_0
144 | - platformdirs=4.2.0=pyhd8ed1ab_0
145 | - pooch=1.8.1=pyhd8ed1ab_0
146 | - prompt-toolkit=3.0.42=pyha770c72_0
147 | - prompt_toolkit=3.0.42=hd8ed1ab_0
148 | - psutil=5.9.8=py38h01eb140_0
149 | - pthread-stubs=0.4=h36c2ea0_1001
150 | - ptyprocess=0.7.0=pyhd3deb0d_0
151 | - pure_eval=0.2.2=pyhd8ed1ab_0
152 | - pyg=2.5.2=py38_torch_2.2.0_cu121
153 | - pygments=2.18.0=pyhd8ed1ab_0
154 | - pynndescent=0.5.13=pyhff2d567_0
155 | - pyparsing=3.1.2=pyhd8ed1ab_0
156 | - pysocks=1.7.1=pyha2e5f31_6
157 | - python=3.8.19=hd12c33a_0_cpython
158 | - python-dateutil=2.9.0=pyhd8ed1ab_0
159 | - python-multipart=0.0.9=pyhd8ed1ab_0
160 | - python-tzdata=2024.1=pyhd8ed1ab_0
161 | - python_abi=3.8=4_cp38
162 | - pytorch=2.2.2=py3.8_cuda12.1_cudnn8.9.2_0
163 | - pytorch-cuda=12.1=ha16c6d3_5
164 | - pytorch-mutex=1.0=cuda
165 | - pytz=2024.1=pyhd8ed1ab_0
166 | - pyyaml=6.0.1=py38h01eb140_1
167 | - pyzmq=26.0.3=py38ha44f8e3_0
168 | - questionary=2.0.1=pyhd8ed1ab_0
169 | - readline=8.2=h8228510_1
170 | - requests=2.31.0=pyhd8ed1ab_0
171 | - scikit-learn=1.3.2=py38ha25d942_2
172 | - scipy=1.10.1=py38h59b608b_3
173 | - setuptools=69.2.0=pyhd8ed1ab_0
174 | - shiny=0.10.2=pyhd8ed1ab_0
175 | - six=1.16.0=pyh6c4a22f_0
176 | - sniffio=1.3.1=pyhd8ed1ab_0
177 | - stack_data=0.6.2=pyhd8ed1ab_0
178 | - starlette=0.37.2=pyhd8ed1ab_0
179 | - sympy=1.12=pypyh9d50eac_103
180 | - tbb=2021.11.0=h00ab1b0_1
181 | - threadpoolctl=3.4.0=pyhc1e730c_0
182 | - tk=8.6.13=noxft_h4845f30_101
183 | - torchaudio=2.2.2=py38_cu121
184 | - torchtriton=2.2.0=py38
185 | - torchvision=0.17.2=py38_cu121
186 | - tornado=6.4=py38h01eb140_0
187 | - tqdm=4.66.2=pyhd8ed1ab_0
188 | - traitlets=5.14.3=pyhd8ed1ab_0
189 | - typeguard=2.13.3=pyhd8ed1ab_0
190 | - typing-extensions=4.11.0=hd8ed1ab_0
191 | - typing_extensions=4.11.0=pyha770c72_0
192 | - uc-micro-py=1.0.3=pyhd8ed1ab_0
193 | - umap-learn=0.5.4=py38h06a4308_0
194 | - urllib3=2.2.1=pyhd8ed1ab_0
195 | - uvicorn=0.30.1=py38h578d9bd_0
196 | - watchfiles=0.22.0=py38h31a4407_0
197 | - wcwidth=0.2.13=pyhd8ed1ab_0
198 | - websockets=12.0=py38h01eb140_0
199 | - wheel=0.43.0=pyhd8ed1ab_1
200 | - xorg-libxau=1.0.11=hd590300_0
201 | - xorg-libxdmcp=1.1.3=h7f98852_0
202 | - xz=5.2.6=h166bdaf_0
203 | - yaml=0.2.5=h7f98852_2
204 | - yarl=1.9.4=py38h01eb140_0
205 | - zeromq=4.3.5=h75354e8_4
206 | - zipp=3.17.0=pyhd8ed1ab_0
207 | - zlib=1.2.13=hd590300_5
208 | - zstd=1.5.5=hfc55251_0
209 | - biotite=0.35.0
210 | - seaborn=0.13.2
211 | - pyg::pytorch-scatter=2.1.2
212 | - py3dmol=2.3.0
213 | - pip:
214 |     - fair-esm==2.0.0
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "proteusAI"
7 | version = "0.1.1"
8 | requires-python = ">= 3.8"
9 | description = "ProteusAI is a Python package designed for AI-driven protein engineering."
10 | readme = "README.md"
11 |
12 | dependencies = [
13 | "torch==2.4.1",
14 | "torch_geometric",
15 | "torch-scatter",
16 | "uvicorn",
17 | "asgiref",
18 | "starlette",
19 | #"pdbfixer",
20 | "shiny",
21 | "pandas",
22 | "numpy",
23 | "requests",
24 | "scipy",
25 | "fair-esm",
26 | "matplotlib",
27 | "biopython",
28 | "biotite",
29 | "scikit-learn",
30 | "optuna",
31 | "seaborn",
32 | "plotly",
33 | "openpyxl",
34 | "py3Dmol",
35 | "gpytorch",
36 | "openmm",
37 | "umap-learn",
38 | "hdbscan",
39 | "proteusAI",
40 | #"pdbfixer @ git+https://github.com/openmm/pdbfixer@1.9"
41 | ]
42 |
43 | [project.optional-dependencies]
44 | docs = [
45 | "sphinx",
46 | "sphinx-book-theme",
47 | "myst-nb",
48 | "ipywidgets",
49 | "sphinx-new-tab-link!=0.2.2",
50 | "jupytext",
51 | ]
52 | dev = ["black", "ruff", "pytest", "flake8", "flake8-import-order", "flake8-builtins", "flake8-bugbear"]
53 |
54 | # Handle imports with flake8
55 | [tool.flake8]
56 | max-line-length = 88
57 | import-order-style = "google"
58 | application-import-names = ["proteusAI"]
59 |
60 |
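61 | # The optional dependency groups above map to pip extras, e.g. from the repo root:
62 | #   pip install -e ".[docs]"
63 | #   pip install -e ".[dev]"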
--------------------------------------------------------------------------------
/run_benchmark.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the possible combinations of models and embeddings
4 | representations=("esm2" "blosum62" "ohe") # "blosum50"
5 | acq_fns=("ei" "greedy" "ucb") # "random"
6 | models=($1) # "gp" "rf" "ridge" "svm" "knn"
7 |
8 | for acq_fn in "${acq_fns[@]}"; do
9 | for rep in "${representations[@]}"; do
10 | for model in "${models[@]}"; do
11 | # Execute the script with the current combination of model and embedding
12 | python demo/MLDE_benchmark.py --model "$model" --rep "$rep" --acquisition_fn "$acq_fn" --max-iter 100
13 | python demo/plot_benchmark.py --rep "$rep" --max-sample 100 --model "$model" --acquisition-fn "$acq_fn"
14 | done
15 | done
16 | done
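17 | 
18 | # Example: pass one or more model names (from the comment above) as the first argument:
19 | #   bash run_benchmark.sh "rf gp"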
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | author="Jonathan Funk",
5 | author_email="funk.jonathan21@gmail.com",
6 | url="https://github.com/jonfunk21/ProteusAI",
7 | packages=find_packages("src"),
8 | package_dir={"": "src"},
9 | classifiers=[
10 | "Development Status :: 3 - Alpha",
11 | "Intended Audience :: Developers",
12 | "License :: OSI Approved :: MIT License",
13 | "Programming Language :: Python :: 3",
14 | "Programming Language :: Python :: 3.6",
15 | "Programming Language :: Python :: 3.7",
16 | "Programming Language :: Python :: 3.8",
17 | "Programming Language :: Python :: 3.9",
18 | ],
19 | python_requires=">=3.8",
20 | )
21 |
--------------------------------------------------------------------------------
/src/proteusAI/Library/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | This subpackage is concerned with the Library object of ProteusAI.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from .library import Library # noqa: F401
12 |
--------------------------------------------------------------------------------
/src/proteusAI/Model/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | This subpackage is concerned with the Model object of ProteusAI.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from .model import Model # noqa: F401
12 |
--------------------------------------------------------------------------------
/src/proteusAI/Protein/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | This subpackage is concerned with the Protein object of ProteusAI.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from .protein import Protein # noqa: F401
12 |
--------------------------------------------------------------------------------
/src/proteusAI/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 |
5 | from importlib import metadata
6 |
7 | __version__ = metadata.version("proteusAI")
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from .Protein import * # noqa: F403
12 | from .Library import * # noqa: F403
13 | from .Model import * # noqa: F403
14 |
--------------------------------------------------------------------------------
/src/proteusAI/data_tools/MSA.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import biotite.sequence.graphics as graphics
8 | import biotite.application.muscle as muscle
9 | import matplotlib.pyplot as plt
10 | from Bio.SeqRecord import SeqRecord
11 | from Bio.Seq import Seq
12 | from Bio.Align.Applications import ClustalwCommandline
13 | from Bio import AlignIO
14 | from Bio import SeqIO
15 | from collections import Counter
16 | from biotite.sequence import ProteinSequence
17 |
18 |
19 | def align_proteins(
20 | names: list,
21 | seqs: list,
22 | plot_results: bool = False,
23 | plt_range: tuple = (0, 200),
24 | muscle_version: str = "5",
25 | save_fig: str = None,
26 | save_fasta: str = None,
27 | figsize: tuple = (10.0, 8.0),
28 | ):
29 | """
30 |     Performs multiple sequence alignment given a list of sequence names (e.g. BLAST hit IDs) and the corresponding sequences.
31 |
32 | Parameters:
33 | names (list): list of sequence names
34 | seqs (list): list of sequences for MSA
35 | plot_results (bool, optional): plot results.
36 | plt_range (tuple, optional): range of sequence which is plotted. Default first 200 amino acids.
37 | muscle_version (str, optional): which muscle version is installed on your machine. Default 5.
38 | save_fig (str): saves fig if path is provided. Default None - won't save the figure.
39 | save_fasta (str): saves fasta if path is provided. Default None - won't save fasta.
40 | figsize (tuple): dimensions of the figure.
41 |
42 | Returns:
43 | dict: MSA results of sequence names/ids and gapped sequences
44 | """
45 |
46 | # Convert sequences to ProteinSequence objects if they are strings
47 | seqs = [ProteinSequence(seq) if isinstance(seq, str) else seq for seq in seqs]
48 |
49 | if muscle_version == "5":
50 | app = muscle.Muscle5App(seqs)
51 | elif muscle_version == "3":
52 | app = muscle.MuscleApp(seqs)
53 | else:
54 | raise ValueError("Muscle version must be either 3 or 5")
55 |
56 | app.start()
57 | app.join()
58 | alignment = app.get_alignment()
59 |
60 | # Print the MSA with hit IDs
61 | gapped_seqs = alignment.get_gapped_sequences()
62 |
63 | MSA_results = {}
64 | for i in range(len(gapped_seqs)):
65 | MSA_results[names[i]] = gapped_seqs[i]
66 |
67 | # Reorder alignments to reflect sequence distance
68 | fig = plt.figure(figsize=figsize)
69 | ax = fig.add_subplot(111)
70 | order = app.get_alignment_order()
71 | graphics.plot_alignment_type_based(
72 | ax,
73 | alignment[plt_range[0] : plt_range[1], order.tolist()],
74 | labels=[names[i] for i in order],
75 | show_numbers=True,
76 | color_scheme="clustalx",
77 | )
78 | fig.tight_layout()
79 |
80 | if save_fig is not None:
81 | plt.savefig(save_fig)
82 |
83 | if plot_results:
84 | plt.show()
85 | else:
86 | plt.close()
87 |
88 | if save_fasta is not None:
89 | with open(save_fasta, "w") as f:
90 | for i, key in enumerate(MSA_results.keys()):
91 | s = MSA_results[key]
92 | if i < len(MSA_results) - 1:
93 | f.writelines(f">{key}\n")
94 | f.writelines(f"{s}\n")
95 | else:
96 | f.writelines(f">{key}\n")
97 | f.writelines(f"{s}")
98 |
99 | return MSA_results
100 |
101 |
102 | def MSA_results_to_fasta(MSA_results: dict, fname: str):
103 | """
104 |     Takes MSA results from the align_proteins function and writes them in fasta format.
105 |
106 | Parameters:
107 | MSA_results (dict): Dictionary of MSA results.
108 | fname (str): file name.
109 |
110 | Returns:
111 | None
112 | """
113 | with open(fname, "w") as f:
114 | for i, key in enumerate(MSA_results.keys()):
115 | s = MSA_results[key]
116 | if i < len(MSA_results) - 1:
117 | f.writelines(f">{key}\n")
118 | f.writelines(f"{s}\n")
119 | else:
120 | f.writelines(f">{key}\n")
121 | f.writelines(f"{s}")
122 |
123 |
124 | def align_dna(dna_sequences: list, verbose: bool = False):
125 | """
126 |     Performs multiple sequence alignment given a list of DNA sequences.
127 |     The function uses the clustalw2 executable, which implements the global Needleman-Wunsch algorithm.
128 | 
129 |     Parameters:
130 |         dna_sequences (list): list of DNA sequences for MSA
131 | verbose (bool): print std out and std error if True. Default False
132 |
133 | Returns:
134 | list: MSA list of sequences
135 | """
136 | # Create SeqRecord objects for each DNA sequence
137 | seq_records = [
138 | SeqRecord(Seq(seq), id=f"seq{i + 1}") for i, seq in enumerate(dna_sequences)
139 | ]
140 |
141 | # Write the DNA sequences to a temporary file in FASTA format
142 | temp_input_file = "temp_input.fasta"
143 | with open(temp_input_file, "w") as handle:
144 | SeqIO.write(seq_records, handle, "fasta")
145 |
146 | # Run ClustalW to perform the multiple sequence alignment
147 | temp_output_file = "temp_output.aln"
148 | clustalw_cline = ClustalwCommandline(
149 | "clustalw2", infile=temp_input_file, outfile=temp_output_file
150 | )
151 | stdout, stderr = clustalw_cline()
152 |
153 | if verbose:
154 | print("std out:")
155 | print("-------")
156 | print(stdout)
157 | print("std error:")
158 | print("----------")
159 | print(stderr)
160 |
161 | # Read in the aligned sequences from the output file
162 | aligned_seqs = []
163 | with open(temp_output_file) as handle:
164 | alignment = AlignIO.read(handle, "clustal")
165 | for record in alignment:
166 | aligned_seqs.append(str(record.seq))
167 |
168 | return aligned_seqs
169 |
170 |
171 | def get_consensus_sequence(dna_sequences: list):
172 | """
173 |     Calculates the consensus sequence of a multiple sequence alignment.
174 |     It takes the most common character at each position across all sequences.
175 |     All sequences need to be the same length.
176 |
177 | Parameters:
178 | dna_sequences (list): list of DNA sequences.
179 |
180 | Returns:
181 | str: consensus sequence.
182 | """
183 | # Get the length of the sequences
184 | sequence_length = len(dna_sequences[0])
185 |
186 | # Iterate over each position in the sequences
187 | consensus_sequence = ""
188 | for i in range(sequence_length):
189 | # Create a list of the characters at this position
190 | char_list = [seq[i] for seq in dna_sequences]
191 |
192 | # Count the occurrences of each character
193 | char_counts = Counter(char_list)
194 |
195 | # Get the most common character
196 | most_common_char = char_counts.most_common(1)[0][0]
197 |
198 | # Add the most common character to the consensus sequence
199 | consensus_sequence += most_common_char
200 |
201 | return consensus_sequence
202 |
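203 | # Minimal usage sketch for the consensus helper (toy sequences, not real data):
204 | #   >>> get_consensus_sequence(["ATGC", "ATGA", "ATGC"])
205 | #   'ATGC'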
--------------------------------------------------------------------------------
/src/proteusAI/data_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for data engineering, visualization and analysis.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.data_tools.pdb import * # noqa: F403
12 | from proteusAI.data_tools.MSA import * # noqa: F403
13 |
--------------------------------------------------------------------------------
/src/proteusAI/data_tools/pdb.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | from biotite.structure.io.mol import MOLFile
8 | from biotite.structure.io.pdb import PDBFile
9 | import numpy as np
10 | import biotite.structure as struc
11 | import biotite.structure.io as strucio
12 | from Bio import SeqIO
13 | import py3Dmol
14 | from string import ascii_uppercase, ascii_lowercase
15 |
16 |
17 | alphabet_list = list(ascii_uppercase + ascii_lowercase)
18 | pymol_color_list = [
19 | "#33ff33",
20 | "#00ffff",
21 | "#ff33cc",
22 | "#ffff00",
23 | "#ff9999",
24 | "#e5e5e5",
25 | "#7f7fff",
26 | "#ff7f00",
27 | "#7fff7f",
28 | "#199999",
29 | "#ff007f",
30 | "#ffdd5e",
31 | "#8c3f99",
32 | "#b2b2b2",
33 | "#007fff",
34 | "#c4b200",
35 | "#8cb266",
36 | "#00bfbf",
37 | "#b27f7f",
38 | "#fcd1a5",
39 | "#ff7f7f",
40 | "#ffbfdd",
41 | "#7fffff",
42 | "#ffff7f",
43 | "#00ff7f",
44 | "#337fcc",
45 | "#d8337f",
46 | "#bfff3f",
47 | "#ff7fff",
48 | "#d8d8ff",
49 | "#3fffbf",
50 | "#b78c4c",
51 | "#339933",
52 | "#66b2b2",
53 | "#ba8c84",
54 | "#84bf00",
55 | "#b24c66",
56 | "#7f7f7f",
57 | "#3f3fa5",
58 | "#a5512b",
59 | ]
60 |
61 |
62 | def get_atom_array(file_path):
63 | """
64 | Returns atom array for a pdb file.
65 |
66 | Parameters:
67 | file_path (str): path to pdb file
68 |
69 | Returns:
70 | atom array
71 | """
72 | if file_path.endswith("pdb"):
73 | atom_mol = PDBFile.read(file_path)
74 | atom_array = atom_mol.get_structure()
75 | else:
76 | try:
77 | atom_mol = MOLFile.read(file_path)
78 | atom_array = atom_mol.get_structure()
79 | except Exception as e:
80 | raise ValueError(f"File: {file_path} has an invalid format. Error: {e}")
81 |
82 | return atom_array
83 |
84 |
85 | def mol_contacts(mols, protein, dist=4.0):
86 | """
87 | Get residue ids of contacts of small molecule(s) to protein.
88 |
89 | Parameters:
90 | mols (str, list of strings): path to molecule file
91 | protein (str): path to protein file
92 | dist (float): distance to be considered contact
93 |
94 | Returns:
95 | set: res_ids of protein residues which are in contact with molecules
96 | """
97 | if isinstance(mols, list) or isinstance(mols, tuple):
98 | mols = [get_atom_array(m) for m in mols]
99 | else:
100 | mols = [get_atom_array(mols)]
101 |
102 | protein = get_atom_array(protein)
103 |
104 | res_ids = set()
105 | for mol in mols:
106 | cell_list = struc.CellList(mol, cell_size=dist)
107 | for prot in protein:
108 | contacts = cell_list.get_atoms(prot.coord, radius=dist)
109 |
110 | contact_indices = np.where((contacts != -1).any(axis=1))[0]
111 |
112 | contact_res_ids = prot.res_id[contact_indices]
113 | res_ids.update(contact_res_ids)
114 |
115 | res_ids = sorted(res_ids)
116 | return res_ids
117 |
118 |
119 | def pdb_to_fasta(pdb_path: str):
120 | """
121 | Returns fasta sequence of pdb file.
122 |
123 | Parameters:
124 | pdb_path (str): path to pdb file
125 |
126 | Returns:
127 | str: fasta sequence string
128 | """
129 | seq = ""
130 | with open(pdb_path, "r") as pdb_file:
131 | for record in SeqIO.parse(pdb_file, "pdb-atom"):
132 | seq = "".join([">", str(record.id), "\n", str(record.seq)])
133 |
134 | return seq
135 |
136 |
137 | def get_sse(pdb_path: str):
138 | """
139 | Returns the secondary structure of a protein given the pdb file.
140 |     The secondary structure is inferred using the P-SEA algorithm
141 |     as implemented by the biotite Python package.
142 |
143 | Parameters:
144 | pdb_path (str): path to pdb
145 | """
146 | array = strucio.load_structure(pdb_path)
147 | sse = struc.annotate_sse(array)
148 | sse = "".join(sse)
149 | return sse
150 |
151 |
152 | ### Visualization
153 | def show_pdb(
154 | pdb_path,
155 | color="confidence",
156 | vmin=50,
157 | vmax=90,
158 | chains=None,
159 | Ls=None,
160 | size=(800, 480),
161 | show_sidechains=False,
162 | show_mainchains=False,
163 | highlight=None,
164 | ):
165 | """
166 | This function displays the 3D structure of a protein from a given PDB file in a Jupyter notebook.
167 | The protein structure can be colored by chain, rainbow, pLDDT, or confidence value. The size of the
168 | display can be changed. The sidechains and mainchains can be displayed or hidden.
169 | Parameters:
170 | pdb_path (str): The filename of the PDB file that contains the protein structure.
171 |         color (str, optional): The color scheme for the protein structure. Can be "chain", "rainbow", "pLDDT", or "confidence". Defaults to "confidence".
172 | vmin (float, optional): The minimum value of pLDDT or confidence value. Defaults to 50.
173 | vmax (float, optional): The maximum value of pLDDT or confidence value. Defaults to 90.
174 | chains (int, optional): The number of chains to be displayed. Defaults to None.
175 | Ls (list, optional): A list of the chains to be displayed. Defaults to None.
176 | size (tuple, optional): The size of the display window. Defaults to (800, 480).
177 | show_sidechains (bool, optional): Whether to display the sidechains. Defaults to False.
178 | show_mainchains (bool, optional): Whether to display the mainchains. Defaults to False.
179 | Returns:
180 | view: The 3Dmol view object that displays the protein structure.
181 | """
182 |
183 | with open(pdb_path) as ifile:
184 | system = "".join([x for x in ifile])
185 |
186 | view = py3Dmol.view(
187 | js="https://3dmol.org/build/3Dmol.js", width=size[0], height=size[1]
188 | )
189 |
190 | if chains is None:
191 | chains = 1 if Ls is None else len(Ls)
192 |
193 | view.addModelsAsFrames(system)
194 | if color == "pLDDT" or color == "confidence":
195 | view.setStyle(
196 | {
197 | "cartoon": {
198 | "colorscheme": {
199 | "prop": "b",
200 | "gradient": "rwb",
201 | "min": vmin,
202 | "max": vmax,
203 | }
204 | }
205 | }
206 | )
207 | elif color == "rainbow":
208 | view.setStyle({"cartoon": {"color": "spectrum"}})
209 | elif color == "chain":
210 | for n, chain, color in zip(range(chains), alphabet_list, pymol_color_list):
211 | view.setStyle({"chain": chain}, {"cartoon": {"color": color}})
212 |
213 | if show_sidechains:
214 | BB = ["C", "O", "N"]
215 | view.addStyle(
216 | {
217 | "and": [
218 | {"resn": ["GLY", "PRO"], "invert": True},
219 | {"atom": BB, "invert": True},
220 | ]
221 | },
222 | {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
223 | )
224 | view.addStyle(
225 | {"and": [{"resn": "GLY"}, {"atom": "CA"}]},
226 | {"sphere": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
227 | )
228 | view.addStyle(
229 | {"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]},
230 | {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
231 | )
232 | if show_mainchains:
233 | BB = ["C", "O", "N", "CA"]
234 | view.addStyle(
235 | {"atom": BB}, {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}}
236 | )
237 |     # fit the view to the loaded structure
238 |     view.zoomTo()
239 | 
240 |
241 | return view
242 |
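243 | # Example (sketch): render a demo structure in a Jupyter notebook, run from the
244 | # repository root; the styling arguments are optional.
245 | #   view = show_pdb("demo/demo_data/1zb6.pdb", color="rainbow", show_sidechains=True)
246 | #   view.show()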
--------------------------------------------------------------------------------
/src/proteusAI/design_tools/Constraints.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import numpy as np
8 | import esm
9 | import typing as T
10 | import biotite.sequence as seq
11 | import biotite.sequence.align as align
12 | from biotite.structure.io.pdb import PDBFile
13 | import biotite.structure as struc
14 | from biotite.structure import sasa
15 | import tempfile
16 |
17 | from proteusAI.data_tools import pdb # TODO: double check with Johny!
18 |
19 |
20 | # _____Sequence Constraints_____
21 | def length_constraint(seqs: list, max_len: int = 200):
22 | """
23 |     Constraint on the length of sequences.
24 |
25 | Parameters:
26 | seqs (list): sequences to be scored
27 |         max_len (int): maximum length that a sequence is allowed to be. Default = 200
28 |
29 | Returns:
30 | np.array: Energy values
31 | """
32 | energies = np.zeros(len(seqs))
33 |
34 | for i, sequence in enumerate(seqs):
35 | if len(sequence) > max_len:
36 | energies[i] = float(len(sequence) - max_len)
37 | else:
38 | energies[i] = 0.0
39 |
40 | return energies
41 |
42 |
43 | def seq_identity(seqs, ref, matrix="BLOSUM62", local=False):
44 | """
45 | Calculates sequence identity of sequences against a reference sequence based on alignment.
46 | By default a global alignment is performed using the BLOSUM62 matrix.
47 |
48 | Parameters:
49 |         seqs (list): query sequences
50 |         ref (str): reference sequence
51 |         matrix (str): alignment matrix {BLOSUM62, BLOSUM50, BLOSUM30}. Default BLOSUM62
52 | local (bool): Local alignment if True, else global alignment.
53 |
54 | Returns:
55 | numpy.ndarray: identity scores of sequences
56 | """
57 | alph = seq.ProteinSequence.alphabet
58 | matrix = align.SubstitutionMatrix(alph, alph, matrix)
59 |
60 | seqs = [seq.ProteinSequence(s) for s in seqs]
61 | ref = seq.ProteinSequence(ref)
62 |
63 | scores = np.zeros(len(seqs))
64 | for i, s in enumerate(seqs):
65 | alignments = align.align_optimal(s, ref, matrix, local=local)
66 | score = align.get_sequence_identity(alignments[0])
67 | scores[i] = score
68 |
69 | return scores
70 |
71 |
72 | # _____Structure Constraints_____
73 | def string_to_tempfile(data):
74 | # create a temporary file
75 | with tempfile.NamedTemporaryFile(delete=False) as temp_file:
76 | # write the string to the file
77 | temp_file.write(data.encode("utf-8"))
78 |         # flush the file to make sure the data is written
79 | temp_file.flush()
80 | # return the file object
81 | return temp_file
82 |
83 |
84 | def create_batched_sequence_datasest(
85 | sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
86 | ) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
87 | """
88 | Taken from https://github.com/facebookresearch/esm/blob/main/scripts/esmfold_inference.py
89 | """
90 | batch_headers, batch_sequences, num_tokens = [], [], 0
91 | for header, sequence in sequences:
92 | if (len(sequence) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
93 | yield batch_headers, batch_sequences
94 | batch_headers, batch_sequences, num_tokens = [], [], 0
95 | batch_headers.append(header)
96 | batch_sequences.append(sequence)
97 | num_tokens += len(sequence)
98 |
99 | yield batch_headers, batch_sequences
100 |
101 |
102 | def structure_prediction(
103 | sequences: list,
104 | names: list,
105 | chunk_size: int = 124,
106 | max_tokens_per_batch: int = 1024,
107 | num_recycles: int = None,
108 | ):
109 | """
110 | Predict the structure of proteins.
111 |
112 | Parameters:
113 | sequences (list): all sequences for structure prediction
114 | names (list): names of the sequences
115 |         chunk_size (int): Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). Recommended values: 128, 64, 32.
116 |         max_tokens_per_batch (int): Maximum number of tokens per GPU forward-pass. This will group shorter sequences together.
117 |         num_recycles (int): Number of recycles to run. Defaults to the number used in training (4).
118 |
119 | Returns:
120 | all_headers, all_sequences, all_pdbs, pTMs, mean_pLDDTs
121 | """
122 | model = esm.pretrained.esmfold_v1()
123 | model = model.eval().cuda()
124 | model.set_chunk_size(chunk_size)
125 | all_sequences = list(zip(names, sequences))
126 |
127 | batched_sequences = create_batched_sequence_datasest(
128 | all_sequences, max_tokens_per_batch
129 | )
130 | all_headers = []
131 | all_sequences = []
132 | all_pdbs = []
133 | pTMs = []
134 | mean_pLDDTs = []
135 | for headers, sequences in batched_sequences:
136 | output = model.infer(sequences, num_recycles=num_recycles)
137 | output = {key: value.cpu() for key, value in output.items()}
138 | pdbs = model.output_to_pdb(output)
139 | for header, sequence, pdb_string, mean_plddt, ptm in zip(
140 | headers, sequences, pdbs, output["mean_plddt"], output["ptm"]
141 | ):
142 | all_headers.append(header)
143 | all_sequences.append(sequence)
144 | all_pdbs.append(
145 | PDBFile.read(string_to_tempfile(pdb_string).name)
146 | ) # biotite pdb file name
147 | mean_pLDDTs.append(mean_plddt.item())
148 | pTMs.append(ptm.item())
149 |
150 | return all_headers, all_sequences, all_pdbs, pTMs, mean_pLDDTs
151 |
152 |
153 | def globularity(pdbs):
154 | """
155 | globularity constraint
156 |
157 | Parameters:
158 | pdb (list): list of biotite pdb file
159 |
160 | Returns:
161 | np.array: variances of coordinates for each structure
162 | """
163 | variances = np.zeros(len(pdbs))
164 |
165 | for i, pdb_list in enumerate(pdbs):
166 | variance = pdb_list.get_coord().var()
167 | variances[i] = variance.item()
168 |
169 | return variances
170 |
171 |
172 | def surface_exposed_hydrophobics(pdbs):
173 | """
174 | Calculate the surface exposed hydrophobics using the Shrake-Rupley (“rolling probe”) algorithm.
175 |
176 | Parameters:
177 | pdbs (list): list of biotite pdb file
178 |
179 | Returns:
180 | np.array: average sasa values for each structure
181 | """
182 | avrg_sasa_values = np.zeros(len(pdbs))
183 | for i, pdb_list in enumerate(pdbs):
184 | struc = pdb_list.get_structure()
185 | sasa_val = sasa(
186 | struc[0],
187 | probe_radius=1.4,
188 | atom_filter=None,
189 | ignore_ions=True,
190 | point_number=1000,
191 | point_distr="Fibonacci",
192 | vdw_radii="ProtOr",
193 | )
194 | sasa_mean = sasa_val.mean()
195 | avrg_sasa_values[i] = sasa_mean.item()
196 |
197 | return avrg_sasa_values
198 |
199 |
200 | def backbone_coordination(samples: list, refs: list):
201 | """
202 | Superimpose structures and calculate the RMSD of the backbones.
203 |     Sample structures will be aligned against reference structures.
204 |
205 | Parameters:
206 | -----------
207 | samples (list): list of sample structures
208 | refs (list): list of reference structures
209 |
210 | Returns:
211 | --------
212 | np.array: RMSD values of alignments
213 | """
214 | if len(samples) != len(refs):
215 | raise "samples and refs must have the same length"
216 |
217 | rmsds = np.zeros(len(samples))
218 |
219 | for i in range(len(samples)):
220 | _, _, rmsd = pdb.struc_align(refs[i], samples[i])
221 | rmsds[i] = rmsd.item()
222 |
223 | return rmsds
224 |
225 |
226 | def all_atom_coordination(samples, refs, sample_consts, ref_consts):
227 | """
228 | Calculate the RMSD of residues with all atom constraints.
229 | All atomic positions will be taken into consideration.
230 |
231 | Parameters:
232 | -----------
233 | samples (list): list of biotite pdb files (mutated)
234 | refs (list): list of biotite pdb files (reference)
235 | sample_consts (list): list of constraints on the sample
236 | (potentially shifts over time due to deletions and insertions)
237 | ref_consts (list): list of constraints of the reference
238 |
239 | Returns:
240 | --------
241 | np.array (len(samples),): calculated RMSD values for each sequence
242 | """
243 |
244 | rmsds = np.zeros(len(samples))
245 |
246 | for i in range(len(samples)):
247 | sample = samples[i]
248 | ref = refs[i]
249 | sample_const = sample_consts[i]
250 | ref_const = ref_consts[i]
251 |
252 | # get structures
253 | sample_struc = sample.get_structure()[0]
254 | ref_struc = ref.get_structure()[0]
255 |
256 | # get indices for alignment
257 | sample_indices = np.where(
258 | np.isin(sample_struc.res_id, [i + 1 for i in sample_const["all_atm"]])
259 | )
260 | ref_indices = np.where(
261 | np.isin(ref_struc.res_id, [i + 1 for i in ref_const["all_atm"]])
262 | )
263 |
264 | sample_struc_common = sample_struc[sample_indices[0]]
265 | ref_struc_common = ref_struc[ref_indices[0]]
266 |
267 | sample_superimposed, transformation = struc.superimpose(
268 | ref_struc_common, sample_struc_common
269 | )
270 |
271 | sample_superimposed = struc.superimpose_apply(sample_struc, transformation)
272 |
273 | sample_pdb = PDBFile()
274 | PDBFile.set_structure(sample_pdb, sample_superimposed)
275 | sample_coord = sample_pdb.get_coord()[0][sample_indices]
276 |
277 | ref_pdb = PDBFile()
278 | PDBFile.set_structure(ref_pdb, ref_struc)
279 | ref_coord = ref_pdb.get_coord()[0][ref_indices]
280 |
281 | rmsd = struc.rmsd(sample_coord, ref_coord)
282 |
283 | rmsds[i] = rmsd
284 |
285 | return rmsds
286 |
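287 | # Minimal usage sketch for the length constraint (toy sequences): energies are 0
288 | # up to max_len and grow linearly with the excess length beyond it.
289 | #   >>> length_constraint(["M" * 250, "M" * 150], max_len=200)
290 | #   array([50.,  0.])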
--------------------------------------------------------------------------------
/src/proteusAI/design_tools/MCMC.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import os
8 | import random
9 | import numpy as np
10 | from proteusAI.design_tools import Constraints
11 | import pandas as pd
12 |
13 |
14 | class ProteinDesign:
15 | """
16 |     Optimizes a protein sequence based on a custom energy function. Set the weights of constraints you don't want to use to 0.
17 |
18 | Parameters:
19 | -----------
20 | native_seq (str): native sequence to be optimized
21 | constraints (dict): constraints on sequence.
22 | Keys describe the kind of constraint and values the position on which they act.
23 |         sampler (str): choose between simulated_annealing and substitution sampling.
24 |             Default 'simulated_annealing'
25 | n_traj (int): number of independent trajectories per sampling step. Lowest energy mutant will be selected when
26 | multiple are viable. Default 16
27 | steps (int): number of sampling steps per trajectory.
28 | For simulated annealing, the number of iterations is often chosen in the range of [1,000, 10,000].
29 | T (float): sampling temperature.
30 | For simulated annealing, T0 is often chosen in the range [1, 100]. default 1
31 | M (float): rate of temperature decay.
32 |             For simulated annealing, the decay rate is often chosen in the range [0.01, 0.1] or [0.001, 0.01]. Default 0.01
33 | mut_p (tuple): probabilities for substitution, insertion and deletion.
34 | Default [0.6, 0.2, 0.2]
35 | pred_struc (bool): if True predict the structure of the protein at every step and use structure
36 | based constraints in the energy function. Default True.
37 |         max_len (int): maximum sequence length for the length constraint.
38 | Default 300.
39 | w_len (float): weight of length constraint.
40 | Default 0.01
41 | w_identity (float): Weight of sequence identity constraint. Positive values reward low sequence identity to native sequence.
42 | Default 0.04
43 | w_ptm (float): weight for ptm. pTM is calculated as 1-pTM, because lower energies should be better.
44 | Default 1.
45 | w_plddt (float): weight for plddt. The mean pLDDT is calculated as 1-mean_pLDDT, because lower energies should be better.
46 | Default 1.
47 | w_globularity (float): weight of globularity constraint
48 | Default 0.001
49 | w_bb_coord (float): weight on backbone coordination constraint. Constraints backbone to native structure.
50 | Default 0.02
51 | w_all_atm (float): weight on all atom coordination constraint. Acts on all atoms which are constrained
52 | Default 0.15
53 | w_sasa (float): weight of surface exposed hydrophobics constraint
54 | Default 0.02
55 | outdir (str): path to output directory.
56 | Default None
57 | verbose (bool): if verbose print information
58 | """
59 |
60 | def __init__(
61 | self,
62 | native_seq: str = None,
63 | constraints=None,
64 | sampler: str = "simulated_annealing",
65 | n_traj: int = 16,
66 | steps: int = 1000,
67 | T: float = 10.0,
68 | M: float = 0.01,
69 |         mut_p: tuple = (0.6, 0.2, 0.2),
70 | pred_struc: bool = True,
71 | max_len: int = 300,
72 | w_len: float = 0.01,
73 | w_identity: float = 0.1,
74 | w_ptm: float = 1,
75 | w_plddt: float = 1,
76 | w_globularity: float = 0.001,
77 | w_bb_coord: float = 0.02,
78 | w_all_atm: float = 0.15,
79 | w_sasa: float = 0.02,
80 | outdir: str = None,
81 | verbose: bool = False,
82 | ):
83 |
84 | if constraints is None:
85 | constraints = {"no_mut": [], "all_atm": []}
86 | self.native_seq = native_seq
87 | self.sampler = sampler
88 | self.n_traj = n_traj
89 | self.steps = steps
90 | self.mut_p = mut_p
91 | self.T = T
92 | self.M = M
93 | self.pred_struc = pred_struc
94 | self.max_len = max_len
95 | self.w_max_len = w_len
96 | self.w_identity = w_identity
97 | self.w_ptm = w_ptm
98 | self.w_plddt = w_plddt
99 | self.w_globularity = w_globularity
100 | self.outdir = outdir
101 | self.verbose = verbose
102 | self.constraints = constraints
103 | self.w_sasa = w_sasa
104 | self.w_bb_coord = w_bb_coord
105 | self.w_all_atm = w_all_atm
106 |
107 | # Parameters
108 | self.ref_pdbs = None
109 | self.ref_constraints = None
110 | self.initial_energy = None
111 |
112 | def __str__(self):
113 | lines: list[str] = [
114 | "ProteusAI.MCMC.Design class: \n",
115 | "---------------------------------------\n",
116 | "When Hallucination.run() sequences will be hallucinated using this seed sequence:\n\n",
117 | f"{self.native_seq}\n",
118 | "\nThe following variables were set:\n\n",
119 | "variable\t|value\n",
120 | "----------------+-------------------\n",
121 | f"algorithm: \t|{self.sampler}\n",
122 | f"steps: \t\t|{self.steps}\n",
123 | f"n_traj: \t|{self.n_traj}\n",
124 | f"mut_p: \t\t|{self.mut_p}\n",
125 | f"T: \t\t|{self.T}\n",
126 | f"M: \t\t|{self.M}\n\n",
127 | "The energy function is a linear combination of the following constraints:\n\n",
128 | "constraint\t|value\t|weight\n",
129 | "----------------+-------+------------\n",
130 | f"length \t\t|{self.max_len}\t|{self.w_max_len}\n",
131 | f"identity\t|\t|{self.w_identity}\n",
132 | ]
133 | s = "".join(lines)
134 | if self.pred_struc:
135 | lines = [
136 | s,
137 | f"pTM\t\t|\t|{self.w_ptm}\n",
138 | f"pLDDT\t\t|\t|{self.w_plddt}\n",
139 | f"bb_coord\t\t|{self.w_bb_coord}\n",
140 | f"all_atm\t\t|\t|{self.w_all_atm}\n",
141 | f"sasa\t\t|\t|{self.w_sasa}\n",
142 | ]
143 | s = "".join(lines)
144 | return s
145 |
146 | ### SAMPLERS
147 | def mutate(self, seqs, mut_p: list = None, constraints: list = None):
148 | """
149 | mutates input sequences.
150 |
151 | Parameters:
152 | seqs (list): list of peptide sequences
153 | mut_p (list): mutation probabilities
154 | constraints (list): dictionary of constraints
155 |
156 | Returns:
157 | list: mutated sequences
158 | """
159 |
160 | if mut_p is None:
161 | mut_p = [0.6, 0.2, 0.2]
162 |
163 | AAs = (
164 | "A",
165 | "C",
166 | "D",
167 | "E",
168 | "F",
169 | "G",
170 | "H",
171 | "I",
172 | "K",
173 | "L",
174 | "M",
175 | "N",
176 | "P",
177 | "Q",
178 | "R",
179 | "S",
180 | "T",
181 | "V",
182 | "W",
183 | "Y",
184 | )
185 |
186 | mut_types = ("substitution", "insertion", "deletion")
187 |
188 | mutated_seqs = []
189 | mutated_constraints = []
190 | mutations = []
191 | for i, seq in enumerate(seqs):
192 | mut_constraints = {}
193 |
194 | # loop until allowed mutation has been selected
195 | mutate = True
196 | while mutate:
197 | pos = random.randint(0, len(seq) - 1)
198 | mut_type = random.choices(mut_types, mut_p)[0]
199 | if pos in constraints[i]["no_mut"] or pos in constraints[i]["all_atm"]:
200 | pass
201 | # secondary structure constraint disallows deletion
202 | # insertions between two secondary structure constraints will have the constraint of their neighbors
203 | else:
204 | break
205 |
206 | if mut_type == "substitution":
207 | replacement = random.choice(AAs)
208 | mut_seq = "".join([seq[:pos], replacement, seq[pos + 1 :]])
209 | for const in constraints[i].keys():
210 | positions = constraints[i][const]
211 | mut_constraints[const] = positions
212 | mutations.append(f"sub:{seq[pos]}{pos}{replacement}")
213 |
214 | elif mut_type == "insertion":
215 | insertion = random.choice(AAs)
216 | mut_seq = "".join([seq[:pos], insertion, seq[pos:]])
217 | # shift constraints after insertion
218 | for const in constraints[i].keys():
219 | positions = constraints[i][const]
220 | positions = [i if i < pos else i + 1 for i in positions]
221 | mut_constraints[const] = positions
222 | mutations.append(f"ins:{pos}{insertion}")
223 |
224 | elif mut_type == "deletion" and len(seq) > 1:
225 | lines = list(seq)
226 | del lines[pos]
227 | mut_seq = "".join(lines)
228 | # shift constraints after deletion
229 | for const in constraints[i].keys():
230 | positions = constraints[i][const]
231 | positions = [i if i < pos else i - 1 for i in positions]
232 | mut_constraints[const] = positions
233 | mutations.append(f"del:{seq[pos]}{pos}")
234 |
235 | else:
236 |                 # fall back to insertion if the sequence is too short for a deletion
237 | insertion = random.choice(AAs)
238 | mut_seq = "".join([seq[:pos], insertion, seq[pos:]])
239 | # shift constraints after insertion
240 | for const in constraints[i].keys():
241 | positions = constraints[i][const]
242 | positions = [i if i < pos else i + 1 for i in positions]
243 | mut_constraints[const] = positions
244 | mutations.append(f"ins:{pos}{insertion}")
245 |
246 | mutated_seqs.append(mut_seq)
247 | mutated_constraints.append(mut_constraints)
248 |
249 | return mutated_seqs, mutated_constraints, mutations
250 |
251 | ### ENERGY FUNCTION and ACCEPTANCE CRITERION
252 | def energy_function(self, seqs: list, i: int, constraints: list):
253 | """
254 |         Combines constraints into an energy function. The energy function
255 |         returns the energy values of the mutated sequences and the associated pdb
256 |         files as temporary files. In addition, it returns a dictionary of the
257 |         individual energies.
258 |
259 | Parameters:
260 | seqs (list): list of sequences
261 | i (int): current iteration in sampling
262 | constraints (list): list of constraints
263 |
264 | Returns:
265 | tuple: Energy value, pdbs, energy_log
266 | """
267 | # reinitialize energy
268 | energies = np.zeros(len(seqs))
269 | energy_log = dict()
270 |
271 | e_len = self.w_max_len * Constraints.length_constraint(
272 | seqs=seqs, max_len=self.max_len
273 | )
274 | e_identity = self.w_identity * Constraints.seq_identity(
275 | seqs=seqs, ref=self.native_seq
276 | )
277 |
278 | energies += e_len
279 | energies += e_identity
280 |
281 | energy_log[f"e_len x {self.w_max_len}"] = e_len
282 | energy_log[f"e_identity x {self.w_identity}"] = e_identity
283 |
284 | pdbs = []
285 | if self.pred_struc:
286 | # structure prediction
287 | names = [f"sequence_{j}_cycle_{i}" for j in range(len(seqs))]
288 | headers, sequences, pdbs, pTMs, mean_pLDDTs = (
289 | Constraints.structure_prediction(seqs, names)
290 | )
291 | pTMs = [1 - val for val in pTMs]
292 | mean_pLDDTs = [1 - val / 100 for val in mean_pLDDTs]
293 |
294 | e_pTMs = self.w_ptm * np.array(pTMs)
295 | e_mean_pLDDTs = self.w_plddt * np.array(mean_pLDDTs)
296 | e_globularity = self.w_globularity * Constraints.globularity(pdbs)
297 | e_sasa = self.w_sasa * Constraints.surface_exposed_hydrophobics(pdbs)
298 |
299 | energies += e_pTMs
300 | energies += e_mean_pLDDTs
301 | energies += e_globularity
302 | energies += e_sasa
303 |
304 | energy_log[f"e_pTMs x {self.w_ptm}"] = e_pTMs
305 | energy_log[f"e_mean_pLDDTs x {self.w_plddt}"] = e_mean_pLDDTs
306 | energy_log[f"e_globularity x {self.w_globularity}"] = e_globularity
307 | energy_log[f"e_sasa x {self.w_sasa}"] = e_sasa
308 |
309 |         # there are no ref pdbs before the first calculation
310 | if self.ref_pdbs is not None:
311 | e_bb_coord = self.w_bb_coord * Constraints.backbone_coordination(
312 | pdbs, self.ref_pdbs
313 | )
314 | e_all_atm = self.w_all_atm * Constraints.all_atom_coordination(
315 | pdbs, self.ref_pdbs, constraints, self.ref_constraints
316 | )
317 |
318 | energies += e_bb_coord
319 | energies += e_all_atm
320 |
321 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = e_bb_coord
322 | energy_log[f"e_all_atm x {self.w_all_atm}"] = e_all_atm
323 | else:
324 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = []
325 | energy_log[f"e_all_atm x {self.w_all_atm}"] = []
326 |
327 | energy_log["iteration"] = i + 1
328 |
329 | return energies, pdbs, energy_log
330 |
331 | def p_accept(self, E_x_mut, E_x_i, T, i, M):
332 | """
333 |         Decides to accept or reject changes. Changes which have a lower energy
334 | than the previous state will always be accepted. Changes which have
335 | higher energies will be accepted with a probability p_accept. The
336 | acceptance probability for bad states decreases over time.
337 |
338 | Parameters:
339 | -----------
340 | E_x_mut (np.array): energies of mutated sequences
341 | E_x_i (np.array): energies of initial sequences
342 | T (float): Temperature
343 |         i (int): current iteration
344 | M (float): decay constant
345 |
346 | Returns:
347 | --------
348 | np.array: accept probabilities
349 | """
350 | T = T / (1 + M * i)
351 | dE = E_x_i - E_x_mut
352 |         exp_val = np.exp(dE / T)  # Metropolis criterion: improvements (dE > 0) give p >= 1
353 | p_accept = np.minimum(exp_val, np.ones_like(exp_val))
354 | return p_accept
355 |
356 | ### RUN
357 | def run(self):
358 | """
359 |         Runs MCMC sampling based on user-defined inputs. Returns optimized sequences.
360 | """
361 | native_seq = self.native_seq
362 | constraints = self.constraints
363 | n_traj = self.n_traj
364 | steps = self.steps
365 | sampler = self.sampler
366 | energy_function = self.energy_function
367 | T = self.T
368 | M = self.M
369 | p_accept = self.p_accept
370 | mut_p = self.mut_p
371 | outdir = self.outdir
372 |         # create output directories only if an output directory is given
373 |         if outdir is not None:
374 |             pdb_out = os.path.join(outdir, "pdbs")
375 |             png_out = os.path.join(outdir, "pngs")
376 |             data_out = os.path.join(outdir, "data_tools")
377 |             if not os.path.exists(outdir):
378 |                 os.mkdir(outdir)
379 |             if not os.path.exists(pdb_out):
380 |                 os.mkdir(pdb_out)
381 |             if not os.path.exists(png_out):
382 |                 os.mkdir(png_out)
383 |             if not os.path.exists(data_out):
384 |                 os.mkdir(data_out)
385 |
386 | if sampler == "simulated_annealing":
387 | mutate = self.mutate
388 |
389 | if native_seq is None:
390 |             raise ValueError("The optimizer needs a sequence to run. Define a sequence by calling SequenceOptimizer(native_seq=...)")
391 |
392 | seqs = [native_seq for _ in range(n_traj)]
393 | constraints = [constraints for _ in range(n_traj)]
394 |         self.ref_constraints = constraints.copy()  # keep an unmodified copy as reference
395 |
396 |         # for the initial calculation a single sequence suffices, since all
397 |         # trajectories start from the same state
398 | E_x_i, pdbs, energy_log = energy_function([seqs[0]], -1, [constraints[0]])
399 | E_x_i = [E_x_i[0] for _ in range(n_traj)]
400 | pdbs = [pdbs[0] for _ in range(n_traj)]
401 |
402 | # empty energies dictionary for the first run
403 | for key in energy_log.keys():
404 | energy_log[key] = []
405 | energy_log["T"] = []
406 | energy_log["M"] = []
407 | energy_log["mut"] = []
408 | energy_log["description"] = []
409 |
410 | self.initial_energy = E_x_i.copy()
411 | self.ref_pdbs = pdbs.copy()
412 |
413 | if self.pred_struc and outdir is not None:
414 |             # save the initial structure
415 | num = "{:0{}d}".format(len(energy_log["iteration"]), len(str(self.steps)))
416 | pdbs[0].write(os.path.join(pdb_out, f"{num}_design.pdb"))
417 |
418 | # write energy_log in data_out
419 | if outdir is not None:
420 | df = pd.DataFrame(energy_log)
421 |             df.to_csv(os.path.join(data_out, "energy_log.csv"), index=False)
422 |
423 | for i in range(steps):
424 | mut_seqs, _constraints, mutations = mutate(seqs, mut_p, constraints)
425 | E_x_mut, pdbs_mut, _energy_log = energy_function(mut_seqs, i, _constraints)
426 | # accept or reject change
427 | p = p_accept(E_x_mut, E_x_i, T, i, M)
428 |
429 | new_struc_found = False
430 | accepted_ind = [] # indices of accepted structures
431 | for n in range(n_traj):
432 | if p[n] > random.random():
433 | accepted_ind.append(n)
434 | E_x_i[n] = E_x_mut[n]
435 | seqs[n] = mut_seqs[n]
436 | constraints[n] = _constraints[n]
437 | new_struc_found = True
438 |
439 | if new_struc_found:
440 |                 # get the index of the lowest-energy structure among the newly accepted ones
441 |                 min_E = accepted_ind[0]
442 |                 for a in accepted_ind:
443 |                     if E_x_mut[a] < E_x_mut[min_E]:
444 |                         min_E = a
445 |
446 | # update all to lowest energy structure
447 | E_x_i = [E_x_i[min_E] for _ in range(n_traj)]
448 | seqs = [seqs[min_E] for _ in range(n_traj)]
449 | constraints = [constraints[min_E] for _ in range(n_traj)]
450 | pdbs = [pdbs_mut[min_E] for _ in range(n_traj)]
451 |
452 | for key in energy_log.keys():
453 |                     # skip scalar values in this step
454 | if key not in ["T", "M", "iteration", "mut", "description"]:
455 | e = _energy_log[key]
456 | energy_log[key].append(e[min_E].item())
457 |
458 | energy_log["iteration"].append(i)
459 | energy_log["T"].append(T)
460 | energy_log["M"].append(M)
461 | energy_log["mut"].append(mutations[min_E])
462 |
463 | num = "{:0{}d}".format(
464 | len(energy_log["iteration"]), len(str(self.steps))
465 | )
466 | energy_log["description"].append(f"{num}_design")
467 |
468 | if self.pred_struc and outdir is not None:
469 |                     # save the current best structure
470 | pdbs[0].write(os.path.join(pdb_out, f"{num}_design.pdb"))
471 |
472 | # write energy_log in data_out
473 | if outdir is not None:
474 | df = pd.DataFrame(energy_log)
475 |                 df.to_csv(os.path.join(data_out, "energy_log.csv"), index=False)
476 |
477 | return seqs
478 |
--------------------------------------------------------------------------------
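
A quick numeric sanity check of the Metropolis acceptance rule implemented in `p_accept` above; the energies below are illustrative stand-ins, not real constraint energies:

    import numpy as np

    # energies before (E_x_i) and after (E_x_mut) one mutation step
    E_x_i = np.array([1.0, 1.0, 1.0])
    E_x_mut = np.array([0.8, 1.2, 5.0])  # better, slightly worse, much worse

    T, M, i = 1.0, 0.01, 10   # temperature, decay constant, iteration
    T = T / (1 + M * i)       # annealed temperature

    dE = E_x_i - E_x_mut      # positive when the mutant lowers the energy
    p_accept = np.minimum(np.exp(dE / T), 1.0)
    print(p_accept)           # improvements give p = 1.0, worse states p < 1
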
/src/proteusAI/design_tools/ZeroShot.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import os
8 | import numpy as np
9 | from proteusAI.design_tools import Constraints
10 | import pandas as pd
11 |
12 |
13 | class ZeroShot:
14 | """
15 |     Zero-shot inference of single mutational effects. Mutates every position of the
16 |     sequence to every possible amino acid, then calculates different structure- and
17 |     embedding-based metrics for the mutants to observe which mutations have the greatest effect.
18 |
19 |     Parameters:
20 |     -----------
21 |     seq (str): native protein sequence to be analyzed.
22 |         Every position of this sequence will be mutated to every
23 |         possible amino acid.
24 |         Default None
25 |     name (str): name of the protein, used to name output files.
26 |         Default 'Prot'
27 |     constraints (dict): constraints on the sequence.
28 |         Keys describe the kind of constraint and values the positions
29 |         on which they act.
30 |         Default None
31 |     batch_size (int): number of mutants evaluated per batch.
32 |         Default 20
33 |     pred_struc (bool): if True, predict the structure of every mutant
34 |         and use structure-based constraints in the energy function.
35 |         Default True
36 |     w_ptm (float): weight for pTM. pTM is scored as 1 - pTM,
37 |         because lower energies should be better.
38 |         Default 1
39 |     w_plddt (float): weight for pLDDT. The mean pLDDT is scored as
40 |         1 - mean_pLDDT / 100, because lower energies should be better.
41 |         Default 1
42 |     w_globularity (float): weight of the globularity constraint.
43 |         Default 0.001
44 |     w_bb_coord (float): weight of the backbone coordination constraint.
45 |         Constrains the backbone to the native structure.
46 |         Default 0.02
47 |     w_all_atm (float): weight of the all-atom coordination constraint.
48 |         Acts on all atoms which are constrained.
49 |         Default 0.15
50 |     w_sasa (float): weight of the surface-exposed hydrophobics
51 |         constraint.
52 |         Default 0.02
53 |     outdir (str): path to the output directory. If None, no files
54 |         are written.
55 |         Default None
56 |     verbose (bool): if True, print additional information.
57 |         Default False
58 | 
59 |     """
60 |
61 | def __init__(
62 | self,
63 | seq: str = None,
64 | name="Prot",
65 | constraints=None,
66 | batch_size: int = 20,
67 | pred_struc: bool = True,
68 | w_ptm: float = 1,
69 | w_plddt: float = 1,
70 | w_globularity: float = 0.001,
71 | w_bb_coord: float = 0.02,
72 | w_all_atm: float = 0.15,
73 | w_sasa: float = 0.02,
74 | outdir: str = None,
75 | verbose: bool = False,
76 | ):
77 |
78 | if constraints is None:
79 | constraints = {"all_atm": []}
80 | self.seq = seq
81 | self.name = name
82 | self.constraints = constraints
83 | self.batch_size = batch_size
84 | self.w_ptm = w_ptm
85 | self.w_plddt = w_plddt
86 | self.w_globularity = w_globularity
87 | self.outdir = outdir
88 | self.verbose = verbose
89 | self.w_sasa = w_sasa
90 | self.w_bb_coord = w_bb_coord
91 | self.w_all_atm = w_all_atm
92 |
93 | # Parameters
94 | self.ref_pdbs = None
95 |
96 | def __str__(self):
97 | lines = [
98 | "ProteusAI.MCMC.ZeroShot class: \n",
99 | "---------------------------------------\n",
100 |             "When ZeroShot.run() is called, every position of this sequence will be mutated to every possible amino acid:\n\n",
101 | f"{self.seq}\n",
102 | "\nThe following variables were set:\n\n",
103 | "variable\t|value\n",
104 | "----------------+-------------------\n",
105 |             f"batch_size: \t|{self.batch_size}\n\n",
106 | "The energy function is a linear combination of the following constraints:\n\n",
107 | "constraint\t|value\t|weight\n",
108 | "----------------+-------+------------\n",
109 | ]
110 | s = "".join(lines)
111 | lines = [
112 | s,
113 | f"pTM\t\t|\t|{self.w_ptm}\n",
114 | f"pLDDT\t\t|\t|{self.w_plddt}\n",
115 |             f"bb_coord\t|\t|{self.w_bb_coord}\n",
116 | f"all_atm\t\t|\t|{self.w_all_atm}\n",
117 | f"sasa\t\t|\t|{self.w_sasa}\n",
118 | ]
119 | s = "".join(lines)
120 | return s
121 |
122 | ### SAMPLERS
123 | def mutate(self, seq, pos):
124 |         """
125 |         Mutates the input sequence to every possible amino acid at the given position.
126 | 
127 |         Parameters:
128 |             seq (str): native protein sequence
129 |             pos (int): position to mutate
130 | 
131 |         Returns:
132 |             tuple: mutated sequences and mutant names
133 |         """
134 |
135 | AAs = (
136 | "A",
137 | "C",
138 | "D",
139 | "E",
140 | "F",
141 | "G",
142 | "H",
143 | "I",
144 | "K",
145 | "L",
146 | "M",
147 | "N",
148 | "P",
149 | "Q",
150 | "R",
151 | "S",
152 | "T",
153 | "V",
154 | "W",
155 | "Y",
156 | )
157 |
158 | native_aa = seq[pos]
159 |
160 | mut_seqs = []
161 | names = []
162 | for res in AAs:
163 | if res != native_aa:
164 | mut_seq = "".join([seq[:pos], res, seq[pos + 1 :]])
165 | mut_seqs.append(mut_seq)
166 | names.append(str(seq[pos]) + str(pos) + str(res))
167 | return mut_seqs, names
168 |
169 | ### ENERGY FUNCTION and ACCEPTANCE CRITERION
170 | def energy_function(self, seqs: list, pos: int, names):
171 | """
172 |         Combines constraints into an energy function. The energy function
173 |         returns the energy values of the mutated sequences and the associated pdb
174 |         files as temporary files. In addition, it returns a dictionary of the
175 |         individual energies.
176 | 
177 |         Parameters:
178 |             seqs (list): list of sequences
179 |             pos (int): position in the sequence
180 |             names (list): names of the mutants
181 | 
182 |         Returns:
183 |             tuple: Energy value, pdbs, energy_log
184 |         """
185 | # reinitialize energy
186 | energies = np.zeros(len(seqs))
187 | energy_log = dict()
188 | constraints = self.constraints
189 |
190 |         # drop the mutated position from the all-atom constraints; work on a copy
191 |         # to avoid mutating self.constraints while iterating
192 |         constraints = dict(constraints, all_atm=[c for c in constraints["all_atm"] if c != pos])
193 |         constraints = [constraints for i in range(len(seqs))]
194 |
195 | # structure prediction
196 | names = [f"{name}_{self.name}" for name in names]
197 | headers, sequences, pdbs, pTMs, mean_pLDDTs = Constraints.structure_prediction(
198 | seqs, names
199 | )
200 | pTMs = [1 - val for val in pTMs]
201 | mean_pLDDTs = [1 - val / 100 for val in mean_pLDDTs]
202 |
203 | e_pTMs = self.w_ptm * np.array(pTMs)
204 | e_mean_pLDDTs = self.w_plddt * np.array(mean_pLDDTs)
205 | e_globularity = self.w_globularity * Constraints.globularity(pdbs)
206 | e_sasa = self.w_sasa * Constraints.surface_exposed_hydrophobics(pdbs)
207 |
208 | energies += e_pTMs
209 | energies += e_mean_pLDDTs
210 | energies += e_globularity
211 | energies += e_sasa
212 |
213 | energy_log[f"e_pTMs x {self.w_ptm}"] = e_pTMs
214 | energy_log[f"e_mean_pLDDTs x {self.w_plddt}"] = e_mean_pLDDTs
215 | energy_log[f"e_globularity x {self.w_globularity}"] = e_globularity
216 | energy_log[f"e_sasa x {self.w_sasa}"] = e_sasa
217 |
218 |         # there are no ref pdbs before the first calculation
219 | if self.ref_pdbs is not None:
220 | e_bb_coord = self.w_bb_coord * Constraints.backbone_coordination(
221 | pdbs, self.ref_pdbs
222 | )
223 | e_all_atm = self.w_all_atm * Constraints.all_atom_coordination(
224 | pdbs, self.ref_pdbs, constraints, constraints
225 | )
226 |
227 | energies += e_bb_coord
228 | energies += e_all_atm
229 |
230 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = e_bb_coord
231 | energy_log[f"e_all_atm x {self.w_all_atm}"] = e_all_atm
232 | else:
233 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = [0]
234 | energy_log[f"e_all_atm x {self.w_all_atm}"] = [0]
235 |
236 | energy_log["position"] = [pos]
237 |
238 | return energies, pdbs, energy_log
239 |
240 | ### RUN
241 | def run(self):
242 | """
243 |         Runs a zero-shot scan over every position of the sequence and writes the results to the output directory.
244 | """
245 | seq = self.seq
246 | batch_size = self.batch_size
247 | energy_function = self.energy_function
248 | outdir = self.outdir
249 | mutate = self.mutate
250 |         # create output directories only if an output directory is given
251 |         if outdir is not None:
252 |             pdb_out = os.path.join(outdir, "pdbs")
253 |             png_out = os.path.join(outdir, "pngs")
254 |             data_out = os.path.join(outdir, "data_tools")
255 |             if not os.path.exists(outdir):
256 |                 os.mkdir(outdir)
257 |             if not os.path.exists(pdb_out):
258 |                 os.mkdir(pdb_out)
259 |             if not os.path.exists(png_out):
260 |                 os.mkdir(png_out)
261 |             if not os.path.exists(data_out):
262 |                 os.mkdir(data_out)
263 |
264 | if seq is None:
265 |             raise ValueError("Provide a sequence (seq=...)")
266 |
267 |         # compute the energy of the native sequence once as the initial
268 |         # reference state
269 | _, pdbs, energy_log = energy_function([seq], 0, ["native"])
270 |         pdbs = [pdbs[0] for _ in range(batch_size - 1)]  # one reference per non-native mutant
271 | energy_log["mut"] = ["-"]
272 | energy_log["description"] = ["native"]
273 |
274 |         # convert energies to lists
275 | for key in energy_log.keys():
276 | if not isinstance(energy_log[key], list):
277 | energy_log[key] = energy_log[key].tolist()
278 |
279 | self.ref_pdbs = pdbs.copy()
280 |
281 | if outdir is not None:
282 |             # save the native structure
283 | pdbs[0].write(os.path.join(pdb_out, f"native_{self.name}.pdb"))
284 |
285 | # write energy_log in data_out
286 | if outdir is not None:
287 | df = pd.DataFrame(energy_log)
288 |             df.to_csv(os.path.join(data_out, "energy_log.csv"), index=False)
289 |
290 | for pos in range(len(seq)):
291 | seqs, names = mutate(seq, pos)
292 | E_x_i, pdbs, _energy_log = energy_function(seqs, pos, names)
293 |
294 | for n in range(len(seqs)):
295 | for key in energy_log.keys():
296 |                     # skip scalar values in this step
297 |                     if key not in ["position", "mut", "description"]:
298 |                         e = _energy_log[key].tolist()
299 |                         energy_log[key].append(e[n])
300 | 
301 |                 energy_log["position"].append(pos)
302 |                 energy_log["mut"].append(names[n])
303 | 
304 |                 energy_log["description"].append(f"{names[n]}_{self.name}")
305 | 
306 |                 if outdir is not None:
307 |                     # save the n-th mutant structure
308 |                     pdbs[n].write(os.path.join(pdb_out, f"{names[n]}_{self.name}.pdb"))
309 | 
310 |             # write energy_log in data_out
311 |             if outdir is not None:
312 |                 df = pd.DataFrame(energy_log)
313 |                 df.to_csv(os.path.join(data_out, "energy_log.csv"), index=False)
314 | 
--------------------------------------------------------------------------------
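
The saturation step performed by `ZeroShot.mutate` can be sketched standalone; the toy sequence and the helper name below are purely illustrative:

    AAs = "ACDEFGHIKLMNPQRSTVWY"

    def saturate(seq: str, pos: int):
        """Return the 19 single mutants of seq at pos, with names like T2A."""
        muts, names = [], []
        for res in AAs:
            if res != seq[pos]:
                muts.append(seq[:pos] + res + seq[pos + 1:])
                names.append(f"{seq[pos]}{pos}{res}")
        return muts, names

    muts, names = saturate("MKTAYIAK", 2)
    print(len(muts), names[:3])  # 19 ['T2A', 'T2C', 'T2D']
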
/src/proteusAI/design_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for protein design_tools.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.design_tools.Constraints import * # noqa: F403
12 | from proteusAI.design_tools.MCMC import * # noqa: F403
13 | from proteusAI.design_tools.ZeroShot import * # noqa: F403
14 |
--------------------------------------------------------------------------------
/src/proteusAI/io_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for io_tools.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.io_tools.embeddings import * # noqa: F403
12 | from proteusAI.io_tools.fasta import * # noqa: F403
13 |
--------------------------------------------------------------------------------
/src/proteusAI/io_tools/embeddings.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import os
8 | import torch
9 | from typing import Union
10 |
11 |
12 | def load_embeddings(
13 | path: str, names: Union[list, None] = None, map_location: str = "cpu"
14 | ) -> tuple:
15 | """
16 | Loads all representations files from a directory, returns the names/ids and sequences as lists.
17 |
18 | Parameters:
19 | path (str): path to directory containing representations files
20 | names (list): list of file names in case files should be loaded in a specific order if names not None.
21 | Will go to provided path and load files by name order. Default None
22 |
23 | Returns:
24 | tuple: two lists containing the names and sequences as torch tensors
25 |
26 | Example:
27 |         names, tensors = load_embeddings('/path/to/representations')
28 | """
29 |
30 | tensors = []
31 | if names is None:
32 | files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt")]
33 | names = [f[:-3] for f in os.listdir(path) if f.endswith(".pt")]
34 | for f in files:
35 | t = torch.load(f, map_location=map_location, weights_only=False)
36 | tensors.append(t)
37 |     else:  # names are expected to include the file extension, e.g. 'protein.pt'
38 | for name in names:
39 | t = torch.load(
40 | os.path.join(path, name), map_location=map_location, weights_only=False
41 | )
42 | tensors.append(t)
43 |
44 | return names, tensors
45 |
--------------------------------------------------------------------------------
/src/proteusAI/io_tools/fasta.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import os
8 | from biotite.sequence import ProteinSequence
9 | import numpy as np
10 | import hashlib
11 |
12 |
13 | def load_all_fastas(
14 | path: str, file_type: str = ".fasta", biotite: bool = False
15 | ) -> dict:
16 | """
17 | Loads all fasta files from a directory, returns the names/ids and sequences as lists.
18 |
19 | Parameters:
20 | path (str): path to directory containing fasta files
21 | file_type (str): some fastas are stored with different file endings. Default '.fasta'.
22 | biotite (bool): returns sequences as biotite.sequence.ProteinSequence object
23 |
24 | Returns:
25 | dict: dictionary of file_names and tuple (names, sequences)
26 |
27 | Example:
28 | results = load_fastas('/path/to/fastas')
29 | """
30 | file_names = [f for f in os.listdir(path) if f.endswith(file_type)]
31 | files = [os.path.join(path, f) for f in file_names if f.endswith(file_type)]
32 |
33 | results = {}
34 | for i, file in enumerate(files):
35 | names = []
36 | sequences = []
37 | with open(file, "r") as f:
38 | current_sequence = ""
39 | for line in f:
40 | line = line.strip()
41 | if line.startswith(">"):
42 | if current_sequence:
43 |                         sequences.append(ProteinSequence(current_sequence) if biotite else current_sequence)
44 | names.append(line[1:])
45 | current_sequence = ""
46 | else:
47 | current_sequence += line
48 | if biotite:
49 | sequences.append(ProteinSequence(current_sequence))
50 | else:
51 | sequences.append(current_sequence)
52 |
53 | results[file_names[i]] = (names, sequences)
54 | return results
55 |
56 |
57 | def load_fasta(file: str, biotite: bool = False) -> tuple:
58 | """
59 | Load all sequences in a fasta file. Returns names and sequences
60 |
61 | Parameters:
62 | file (str): path to file
63 | biotite (bool): returns sequences as biotite.sequence.ProteinSequence object
64 |
65 | Returns:
66 | tuple: two lists containing the names and sequences
67 |
68 | Example:
69 | names, sequences = load_fastas('example.fasta')
70 | """
71 |
72 | names = []
73 | sequences = []
74 | with open(file, "r") as f:
75 | current_sequence = ""
76 | for line in f:
77 | line = line.strip()
78 | if line.startswith(">"):
79 | if current_sequence:
80 |                     sequences.append(ProteinSequence(current_sequence) if biotite else current_sequence)
81 | names.append(line[1:])
82 | current_sequence = ""
83 | else:
84 | current_sequence += line
85 | if biotite:
86 | sequences.append(ProteinSequence(current_sequence))
87 | else:
88 | sequences.append(current_sequence)
89 |
90 | return names, sequences
91 |
92 |
93 | def write_fasta(names: list, sequences: list, dest: str = None):
94 | """
95 | Takes a list of names and sequences and writes a single
96 | fasta file containing all the names and sequences. The
97 | files will be saved at the destination
98 |
99 | Parameters:
100 | names (list): list of sequence names
101 | sequences (list): list of sequences
102 | dest (str): path to output file
103 |
104 | Example:
105 | write_fasta(names, sequences, './out.fasta')
106 | """
107 | assert isinstance(names, list) and isinstance(
108 | sequences, list
109 | ), "names and sequences must be type list"
110 | assert len(names) == len(sequences), "names and sequences must have the same length"
111 |
112 | with open(dest, "w") as f:
113 | for i in range(len(names)):
114 | f.writelines(">" + names[i] + "\n")
115 | if i == len(names) - 1:
116 | f.writelines(sequences[i])
117 | else:
118 | f.writelines(sequences[i] + "\n")
119 |
120 |
121 | def one_hot_encoding(sequence: str):
122 | """
123 |     Returns one-hot encoding for an amino acid sequence. Unknown amino acids will be
124 |     encoded with 0.5 across the entire row.
125 |
126 | Parameters:
127 | -----------
128 | sequence (str): Amino acid sequence
129 |
130 | Returns:
131 | --------
132 | numpy.ndarray: One hot encoded sequence
133 | """
134 | # Define amino acid alphabets and create dictionary
135 | amino_acids = "ACDEFGHIKLMNPQRSTVWY"
136 | aa_dict = {aa: i for i, aa in enumerate(amino_acids)}
137 |
138 | # Initialize empty numpy array for one-hot encoding
139 | seq_one_hot = np.zeros((len(sequence), len(amino_acids)))
140 |
141 | # Convert each amino acid in sequence to one-hot encoding
142 | for i, aa in enumerate(sequence):
143 | if aa in aa_dict:
144 | seq_one_hot[i, aa_dict[aa]] = 1.0
145 | else:
146 | # Handle unknown amino acids with a default value of 0.5
147 | seq_one_hot[i, :] = 0.5
148 |
149 | return seq_one_hot
150 |
151 |
152 | def blosum_encoding(sequence, matrix="BLOSUM62", canonical=True):
153 | """
154 |     Returns BLOSUM encoding for an amino acid sequence. Unknown amino acids will be
155 |     encoded with 0.5 across the entire row.
156 |
157 | Parameters:
158 | -----------
159 | sequence (str): Amino acid sequence
160 |     matrix (str): Choice of BLOSUM matrix. Can be 'BLOSUM50' or 'BLOSUM62'
161 | canonical (bool): only use canonical amino acids
162 |
163 | Returns:
164 | --------
165 | numpy.ndarray: BLOSUM encoded sequence
166 | """
167 |
168 | # Get the directory of the current script
169 | script_dir = os.path.dirname(os.path.realpath(__file__))
170 |
171 | ### Amino Acid codes
172 | alphabet_file = os.path.join(script_dir, "matrices/alphabet")
173 | alphabet = np.loadtxt(alphabet_file, dtype=str)
174 |
175 | # Define BLOSUM matrices
176 | _blosum50 = (
177 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM50"), dtype=float)
178 | .reshape((24, -1))
179 | .T
180 | )
181 | _blosum62 = (
182 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM62"), dtype=float)
183 | .reshape((24, -1))
184 | .T
185 | )
186 |
187 | # Choose BLOSUM matrix
188 | if matrix == "BLOSUM50":
189 | matrix = _blosum50
190 | elif matrix == "BLOSUM62":
191 | matrix = _blosum62
192 | else:
193 | raise ValueError(
194 | "Invalid BLOSUM matrix choice. Choose 'BLOSUM50' or 'BLOSUM62'."
195 | )
196 |
197 | blosum_matrix = {}
198 | for i, letter_1 in enumerate(alphabet):
199 | if canonical:
200 | blosum_matrix[letter_1] = matrix[i][:20]
201 | else:
202 | blosum_matrix[letter_1] = matrix[i]
203 |
204 | # create empty encoding vector
205 | encoding = np.zeros((len(sequence), len(blosum_matrix["A"])))
206 |
207 | # Convert each amino acid in sequence to BLOSUM encoding
208 | for i, aa in enumerate(sequence):
209 | if aa in alphabet:
210 | encoding[i, :] = blosum_matrix[aa]
211 | else:
212 | # Handle unknown amino acids with a default value of 0.5
213 | encoding[i, :] = 0.5
214 |
215 | return encoding
216 |
217 |
218 | # hash sequences using sha256
219 | def hash_sequence(sequence: str, length: int = 20) -> str:
220 | """
221 | Hashes a sequence using sha256
222 |
223 | Parameters:
224 | -----------
225 | sequence (str): Amino acid sequence
226 |
227 | Returns:
228 | --------
229 | str: Hashed sequence
230 | """
231 |     if sequence is None:
232 |         return None
233 |     else:
234 |         return hashlib.sha256(sequence.encode()).hexdigest()[0:length]
235 |
--------------------------------------------------------------------------------
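
A minimal usage sketch for the two encoders above (assumes proteusAI is installed so that the bundled matrix files resolve):

    from proteusAI.io_tools.fasta import blosum_encoding, one_hot_encoding

    seq = "ACDEFGHIKLMNPQRSTVWY"
    onehot = one_hot_encoding(seq)                    # one 1.0 per row
    blosum = blosum_encoding(seq, matrix="BLOSUM62")  # row = BLOSUM62 scores of each residue
    print(onehot.shape, blosum.shape)                 # (20, 20) (20, 20)
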
/src/proteusAI/io_tools/matrices/BLOSUM50:
--------------------------------------------------------------------------------
1 | 5 -2 -1 -2 -1 -1 -1 0 -2 -1 -2 -1 -1 -3 -1 1 0 -3 -2 0
2 | -2 7 -1 -2 -4 1 0 -3 0 -4 -3 3 -2 -3 -3 -1 -1 -3 -1 -3
3 | -1 -1 7 2 -2 0 0 0 1 -3 -4 0 -2 -4 -2 1 0 -4 -2 -3
4 | -2 -2 2 8 -4 0 2 -1 -1 -4 -4 -1 -4 -5 -1 0 -1 -5 -3 -4
5 | -1 -4 -2 -4 13 -3 -3 -3 -3 -2 -2 -3 -2 -2 -4 -1 -1 -5 -3 -1
6 | -1 1 0 0 -3 7 2 -2 1 -3 -2 2 0 -4 -1 0 -1 -1 -1 -3
7 | -1 0 0 2 -3 2 6 -3 0 -4 -3 1 -2 -3 -1 -1 -1 -3 -2 -3
8 | 0 -3 0 -1 -3 -2 -3 8 -2 -4 -4 -2 -3 -4 -2 0 -2 -3 -3 -4
9 | -2 0 1 -1 -3 1 0 -2 10 -4 -3 0 -1 -1 -2 -1 -2 -3 2 -4
10 | -1 -4 -3 -4 -2 -3 -4 -4 -4 5 2 -3 2 0 -3 -3 -1 -3 -1 4
11 | -2 -3 -4 -4 -2 -2 -3 -4 -3 2 5 -3 3 1 -4 -3 -1 -2 -1 1
12 | -1 3 0 -1 -3 2 1 -2 0 -3 -3 6 -2 -4 -1 0 -1 -3 -2 -3
13 | -1 -2 -2 -4 -2 0 -2 -3 -1 2 3 -2 7 0 -3 -2 -1 -1 0 1
14 | -3 -3 -4 -5 -2 -4 -3 -4 -1 0 1 -4 0 8 -4 -3 -2 1 4 -1
15 | -1 -3 -2 -1 -4 -1 -1 -2 -2 -3 -4 -1 -3 -4 10 -1 -1 -4 -3 -3
16 | 1 -1 1 0 -1 0 -1 0 -1 -3 -3 0 -2 -3 -1 5 2 -4 -2 -2
17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 2 5 -3 -2 0
18 | -3 -3 -4 -5 -5 -1 -3 -3 -3 -3 -2 -3 -1 1 -4 -4 -3 15 2 -3
19 | -2 -1 -2 -3 -3 -1 -2 -3 2 -1 -1 -2 0 4 -3 -2 -2 2 8 -1
20 | 0 -3 -3 -4 -1 -3 -3 -4 -4 4 1 -3 1 -1 -3 -2 0 -3 -1 5
21 | -2 -1 4 5 -3 0 1 -1 0 -4 -4 0 -3 -4 -2 0 0 -5 -3 -4
22 | -1 0 0 1 -3 4 5 -2 0 -3 -3 1 -1 -4 -1 0 -1 -2 -2 -3
23 | -1 -1 -1 -1 -2 -1 -1 -2 -1 -1 -1 -1 -1 -2 -2 -1 0 -3 -1 -1
24 | -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5
25 |
--------------------------------------------------------------------------------
/src/proteusAI/io_tools/matrices/BLOSUM62:
--------------------------------------------------------------------------------
1 | 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0
2 | -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3
3 | -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3
4 | -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3
5 | 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1
6 | -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2
7 | -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2
8 | 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3
9 | -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3
10 | -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3
11 | -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1
12 | -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2
13 | -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1
14 | -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1
15 | -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2
16 | 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2
17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0
18 | -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3
19 | -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1
20 | 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4
21 | -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3
22 | -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2
23 | 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1
24 | -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4
--------------------------------------------------------------------------------
/src/proteusAI/io_tools/matrices/alphabet:
--------------------------------------------------------------------------------
1 | A
2 | R
3 | N
4 | D
5 | C
6 | Q
7 | E
8 | G
9 | H
10 | I
11 | L
12 | K
13 | M
14 | F
15 | P
16 | S
17 | T
18 | W
19 | Y
20 | V
21 |
--------------------------------------------------------------------------------
/src/proteusAI/mining_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for mining_tools.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.mining_tools.alphafoldDB import * # noqa: F403
12 | from proteusAI.mining_tools.blast import * # noqa: F403
13 | from proteusAI.mining_tools.uniprot import * # noqa: F403
14 |
--------------------------------------------------------------------------------
/src/proteusAI/mining_tools/alphafoldDB.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | import requests
8 | import os
9 |
10 |
11 | def get_AF2_pdb(protein_id: str, out_path: str) -> bool:
12 | """
13 | This function takes in a UniProt ID and an output path and downloads the corresponding AlphaFold model
14 | from the EBI AlphaFold database in PDB format.
15 |
16 | Parameters:
17 | protein_id (str): The UniProt ID of the protein
18 | out_path (str): The path to save the PDB file to. The directory containing the output file will be created if it does not exist.
19 |
20 | Returns:
21 | bool: True if the PDB file was downloaded successfully, False otherwise.
22 | """
23 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
24 |
25 | requestURL = f"https://alphafold.ebi.ac.uk/files/AF-{protein_id}-F1-model_v3.pdb"
26 | r = requests.get(requestURL)
27 |
28 | if r.status_code == 200:
29 | with open(out_path, "wb") as f:
30 | f.write(r.content)
31 | return True
32 | else:
33 | return False
34 |
--------------------------------------------------------------------------------
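
A minimal usage sketch for the downloader above (the UniProt ID and output path are examples only):

    from proteusAI.mining_tools.alphafoldDB import get_AF2_pdb

    ok = get_AF2_pdb("P46597", "structures/P46597.pdb")
    print("downloaded" if ok else "no AlphaFold model found")
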
/src/proteusAI/mining_tools/blast.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | from tempfile import gettempdir
8 | import biotite.sequence.io.fasta as fasta
9 | import biotite.application.blast as blast
10 | import biotite.database.entrez as entrez
11 |
12 |
13 | # TODO: change email
14 | def search_related_sequences(
15 | query: str,
16 | program: str = "blastp",
17 | database: str = "nr",
18 | obey_rules: bool = True,
19 | mail: str = "johnyfunk@gmail.com",
20 | ):
21 | """
22 |     Search for related sequences using the BLAST web app.
23 |
24 | Parameters:
25 | query (str):
26 | Query sequence
27 | program (str, optional):
28 | The specific BLAST program. One of 'blastn', 'megablast', 'blastp', 'blastx', 'tblastn' and 'tblastx'.
29 | database (str, optional):
30 |             The NCBI sequence database to blast against. By default it contains all sequences (`database='nr'`).
31 | obey_rules (bool, optional):
32 | If true, the application raises an :class:`RuleViolationError`, if the server is contacted too often,
33 | based on the NCBI BLAST usage rules. (Default: True)
34 |         mail (str, optional):
35 |             If a mail address is provided, it will be appended in the HTTP request.
36 |             This allows the NCBI to contact you in case your application sends
37 |             too many requests.
38 |
39 | Returns:
40 |         tuple: two lists, the first containing the hits and the second the hit sequences
41 |
42 | Example:
43 |         hits, hit_seqs = search_related_sequences(query=sequence, database='swissprot')
44 | """
45 | # Search only the UniProt/SwissProt database
46 | blast_app = blast.BlastWebApp(
47 | program=program,
48 | query=query,
49 | database=database,
50 | obey_rules=obey_rules,
51 | mail=mail,
52 | )
53 | blast_app.start()
54 | blast_app.join()
55 | alignments = blast_app.get_alignments()
56 | # Get hit IDs for hits with score > 200
57 | hits = []
58 | for ali in alignments:
59 | if ali.score > 200:
60 | hits.append(ali.hit_id)
61 | # Get the sequences from hit IDs
62 | hit_seqs = []
63 | for hit in hits:
64 | file_name = entrez.fetch(hit, gettempdir(), "fa", "protein", "fasta")
65 | fasta_file = fasta.FastaFile.read(file_name)
66 | hit_seqs.append(fasta.get_sequence(fasta_file))
67 |
68 | return hits, hit_seqs
69 |
--------------------------------------------------------------------------------
/src/proteusAI/mining_tools/uniprot.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | from Bio import SeqIO
8 | import requests
9 | from io import StringIO
10 | from hashlib import md5
11 |
12 |
13 | def get_protein_sequence(uniprot_id: str) -> str:
14 | """
15 | This function takes a UniProt ID as input and returns the corresponding sequence record.
16 |
17 | Parameters:
18 | uniprot_id (str): The UniProt identifier of the protein of interest
19 |
20 | Returns:
21 | list: List with results as Bio.SeqRecord.SeqRecord object
22 |
23 | Example:
24 | ASMT_representations = uniprot.get_protein_sequence('P46597')
25 | sequence_string = str(ASMT_representations[0].seq)
26 | """
27 | base_url = "http://www.uniprot.org/uniprot/"
28 | current_url = base_url + uniprot_id + ".fasta"
29 |     response = requests.get(current_url)
30 |     c_data = response.text
31 |
32 | seq = StringIO(c_data)
33 | p_seq = list(SeqIO.parse(seq, "fasta"))
34 | return p_seq
35 |
36 |
37 | def get_uniprot_id(sequence: str) -> str:
38 | """
39 | This function takes in a protein sequence string and returns the corresponding UniProt ID.
40 |
41 | Parameters:
42 | sequence (str): The protein sequence string
43 |
44 | Returns:
45 | str: The UniProt ID of the protein, if it exists in UniProt. None otherwise.
46 | """
47 | h = md5(sequence.encode()).digest().hex()
48 |     requestURL = f"https://www.ebi.ac.uk/proteins/api/proteins?offset=0&size=100&md5={h}"
49 | r = requests.get(requestURL, headers={"Accept": "application/json"})
50 |
51 | if not r.ok:
52 | return None
53 | data = r.json()
54 |
55 | if len(data) == 0:
56 | return None
57 | else:
58 | return data[0]["accession"]
59 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for machine learning
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/bo_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for bayesian optimization tools.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.ml_tools.bo_tools.acq_fn import * # noqa: F403
12 | from proteusAI.ml_tools.bo_tools.genetic_algorithm import * # noqa: F403
13 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/bo_tools/acq_fn.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk and Laura Sofia Machado"
6 |
7 | import numpy as np
8 | from scipy.stats import norm
9 |
10 |
11 | def greedy(mean, std=None, current_best=None, xi=None):
12 | """
13 | Greedy acquisition function.
14 |
15 | Args:
16 | mean (np.array): This is the mean function from the GP over the considered set of points.
17 | std (np.array, optional): This is the standard deviation function from the GP over the considered set of points. Default is None.
18 | current_best (float, optional): This is the current maximum of the unknown function: mu^+. Default is None.
19 | xi (float, optional): Small value added to avoid corner cases. Default is None.
20 |
21 | Returns:
22 | np.array: The mean values for all the points, as greedy acquisition selects the best based on mean.
23 | """
24 | return mean
25 |
26 |
27 | def EI(mean, std, current_best, xi=0.1):
28 | """
29 | Expected Improvement acquisition function.
30 |
31 | It implements the following function:
32 |
33 | | (mu - mu^+ - xi) Phi(Z) + sigma phi(Z) if sigma > 0
34 | EI(x) = |
35 | | 0 if sigma = 0
36 |
37 | where Phi is the CDF and phi the PDF of the normal distribution
38 | and
39 | Z = (mu - mu^+ - xi) / sigma
40 |
41 | Args:
42 | mean (np.array): This is the mean function from the GP over the considered set of points.
43 | std (np.array): This is the standard deviation function from the GP over the considered set of points.
44 | current_best (float): This is the current maximum of the unknown function: mu^+.
45 | xi (float): Small value added to avoid corner cases.
46 |
47 | Returns:
48 | np.array: The value of this acquisition function for all the points.
49 | """
50 |
51 | Z = (mean - current_best - xi) / (std + 1e-9)
52 | EI = (mean - current_best - xi) * norm.cdf(Z) + std * norm.pdf(Z)
53 | EI[std == 0] = 0
54 |
55 | return EI
56 |
57 |
58 | def UCB(mean, std, current_best=None, kappa=1.5):
59 | """
60 | Upper-Confidence Bound acquisition function.
61 |
62 | Args:
63 | mean (np.array): This is the mean function from the GP over the considered set of points.
64 | std (np.array): This is the standard deviation function from the GP over the considered set of points.
65 | current_best (float, optional): This is the current maximum of the unknown function: mu^+. Default is None.
66 |         kappa (float): Exploration-exploitation trade-off parameter. The higher the value, the more exploration. Default is 1.5.
67 |
68 | Returns:
69 | np.array: The value of this acquisition function for all the points.
70 | """
71 | return mean + kappa * std
72 |
73 |
74 | def random_acquisition(mean, std=None, current_best=None, xi=None):
75 | """
76 | Random acquisition function. Assigns random acquisition values to all points in the unobserved set.
77 |
78 | Args:
79 | mean (np.array): This is the mean function from the GP over the considered set of points.
80 | std (np.array, optional): This is the standard deviation function from the GP over the considered set of points. Default is None.
81 | current_best (float, optional): This is the current maximum of the unknown function: mu^+. Default is None.
82 | xi (float, optional): Small value added to avoid corner cases. Default is None.
83 |
84 | Returns:
85 | np.array: Random acquisition values for all points in the unobserved set.
86 | """
87 | n_unobserved = len(mean)
88 | np.random.seed(None)
89 | random_acq_values = np.random.random(n_unobserved)
90 | return random_acq_values
91 |
--------------------------------------------------------------------------------
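
A toy comparison of the acquisition functions above on a five-point posterior (the mean and standard deviation values are made up for illustration):

    import numpy as np
    from proteusAI.ml_tools.bo_tools.acq_fn import EI, UCB

    mean = np.array([0.2, 0.5, 0.9, 0.4, 0.7])
    std = np.array([0.30, 0.10, 0.05, 0.40, 0.20])
    current_best = float(mean.max())

    ei = EI(mean, std, current_best, xi=0.1)
    ucb = UCB(mean, std, kappa=1.5)
    print(int(ei.argmax()), int(ucb.argmax()))  # index each criterion would sample next
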
/src/proteusAI/ml_tools/bo_tools/genetic_algorithm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
5 | #####################################
6 | ### Simulated Annealing Discovery ###
7 | #####################################
8 |
9 |
10 | def precompute_distances(vectors):
11 | """Precompute the pairwise Euclidean distance matrix."""
12 | num_vectors = len(vectors)
13 | distance_matrix = np.zeros((num_vectors, num_vectors))
14 |
15 | for i in range(num_vectors):
16 | for j in range(i + 1, num_vectors):
17 | dist = np.linalg.norm(vectors[i] - vectors[j])
18 | distance_matrix[i, j] = dist
19 | distance_matrix[j, i] = dist
20 |
21 | return distance_matrix
22 |
23 |
24 | def diversity_score_incremental(
25 | current_score, selected_indices, idx_in, idx_out, distance_matrix
26 | ):
27 | """Update the diversity score incrementally when a vector is swapped."""
28 | new_score = current_score
29 |
30 | for idx in selected_indices:
31 | if idx != idx_out:
32 | new_score -= distance_matrix[idx_out, idx]
33 | new_score += distance_matrix[idx_in, idx]
34 |
35 | return new_score
36 |
37 |
38 | def simulated_annealing(
39 | vectors,
40 | N,
41 | initial_temperature=1000.0,
42 | cooling_rate=0.003,
43 | max_iterations=10000,
44 | pbar=None,
45 | ):
46 | """
47 | Simulated Annealing to select N vectors that maximize diversity.
48 |
49 | Args:
50 | vectors (list): List of numpy arrays.
51 | N (int): Number of sequences that should be sampled.
52 | initial_temperature (float): Initial temperature of the simulated annealing algorithm. Default 1000.0.
53 | cooling_rate (float): Cooling rate of the simulated annealing algorithm. Default 0.003.
54 | max_iterations (int): Maximum number of iterations of the simulated annealing algorithm. Default 10000.
55 |
56 | Returns:
57 | list: Indices of diverse vectors.
58 | """
59 |
60 | if pbar:
61 | pbar.set(message="Computing distance matrix", detail="...")
62 |
63 | # Precompute all pairwise distances
64 | distance_matrix = precompute_distances(vectors)
65 |
66 | # Randomly initialize the selection of N vectors
67 | selected_indices = random.sample(range(len(vectors)), N)
68 | current_score = sum(
69 | distance_matrix[i, j]
70 | for i in selected_indices
71 | for j in selected_indices
72 | if i < j
73 | )
74 |
75 | temperature = initial_temperature
76 | best_score = current_score
77 | best_selection = selected_indices[:]
78 |
79 | for iteration in range(max_iterations):
80 |
81 | if pbar:
82 | pbar.set(iteration, message="Minimizing energy", detail="...")
83 |
84 | # Randomly select a vector to swap
85 | idx_out = random.choice(selected_indices)
86 | idx_in = random.choice(
87 | [i for i in range(len(vectors)) if i not in selected_indices]
88 | )
89 |
90 | # Incrementally update the diversity score
91 | new_score = diversity_score_incremental(
92 | current_score, selected_indices, idx_in, idx_out, distance_matrix
93 | )
94 |
95 | # Decide whether to accept the new solution
96 | delta = new_score - current_score
97 | if delta > 0 or np.exp(delta / temperature) > random.random():
98 | selected_indices.remove(idx_out)
99 | selected_indices.append(idx_in)
100 | current_score = new_score
101 |
102 | # Update the best solution found so far
103 | if new_score > best_score:
104 | best_score = new_score
105 | best_selection = selected_indices[:]
106 |
107 | # Cool down the temperature
108 | temperature *= 1 - cooling_rate
109 |
110 | # Early stopping if the temperature is low enough
111 | # if temperature < 1e-8:
112 | # break
113 |
114 | return best_selection, best_score
115 |
116 |
117 | #######################################
118 | ### Genetic Algorithm for Mutations ###
119 | #######################################
120 |
121 |
122 | def find_mutations(sequences):
123 | """
124 | Takes a list of protein sequences and returns a dictionary with mutation positions
125 | as keys and lists of amino acids at those positions as values.
126 |
127 | Parameters:
128 | sequences (list): A list of protein sequences (strings).
129 |
130 | Returns:
131 | dict: A dictionary where keys are positions (1-indexed) and values are lists of amino acids at those positions.
132 | """
133 | # Check if the list is empty
134 | if not sequences:
135 | return {}
136 |
137 | # Initialize a dictionary to store mutations
138 | mutations = {}
139 |
140 | # Get the reference sequence (assuming all sequences have the same length)
141 | reference_seq = sequences[0]
142 |
143 | # Iterate through each position in the sequence
144 | for i in range(len(reference_seq)):
145 | # Initialize a set to store the different amino acids at this position
146 | amino_acids = set()
147 |
148 | # Check the amino acid at this position in each sequence
149 | for seq in sequences:
150 | amino_acids.add(seq[i])
151 |
152 | # If there is more than one unique amino acid at this position, it's a mutation
153 | if len(amino_acids) > 1:
154 | mutations[i + 1] = list(amino_acids) # +1 to make position 1-indexed
155 |
156 | return mutations
157 |
--------------------------------------------------------------------------------
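
A small end-to-end sketch of the two helpers above, on random vectors and toy sequences (illustrative only):

    import numpy as np
    from proteusAI.ml_tools.bo_tools.genetic_algorithm import (
        find_mutations,
        simulated_annealing,
    )

    # pick 5 maximally diverse vectors out of 50 random 16-dimensional embeddings
    vectors = [np.random.rand(16) for _ in range(50)]
    indices, score = simulated_annealing(vectors, N=5, max_iterations=2000)
    print(sorted(indices), round(score, 2))

    # 1-indexed positions that vary across a set of aligned sequences
    print(find_mutations(["MKTAY", "MKTGY", "MRTAY"]))  # e.g. {2: ['K', 'R'], 4: ['A', 'G']}
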
/src/proteusAI/ml_tools/esm_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for protein language models.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.ml_tools.esm_tools.esm_tools import *  # noqa: F403
12 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/esm_tools/alphabet.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/src/proteusAI/ml_tools/esm_tools/alphabet.pt
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/sklearn_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for sklearn-based activity prediction.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.ml_tools.sklearn_tools.grid_search import * # noqa: F403
12 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/sklearn_tools/grid_search.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | __name__ = "proteusAI"
5 | __author__ = "Jonathan Funk"
6 |
7 | from sklearn.model_selection import GridSearchCV
8 | from sklearn.ensemble import RandomForestRegressor
9 | from scipy.stats import pearsonr
10 | from sklearn.neighbors import KNeighborsRegressor
11 | import pandas as pd
12 | from sklearn.svm import SVR
13 | import numpy
14 |
15 |
16 | def knnr_grid_search(
17 | Xs_train: numpy.ndarray,
18 | Xs_test: numpy.ndarray,
19 | ys_train: list,
20 | ys_test: list,
21 | param_grid: dict = None,
22 | verbose: int = 1,
23 | ):
24 | """
25 |     Performs a KNN regressor grid search using 5-fold cross validation.
26 |
27 | Parameters:
28 | -----------
29 | Xs_train, Xs_test (numpy.ndarray): train and test values for training
30 | ys_train, ys_test (list): y values for train and test
31 | param_grid (dict): parameter grid for model
32 | verbose (int): Type of information printing during run
33 |
34 | Returns:
35 | --------
36 | Returns the best performing model of grid search, the test R squared value, correlation coefficient,
37 | p-value of the fit and a dataframe containing the fit information.
38 | """
39 | # Instantiate the model
40 | knnr = KNeighborsRegressor()
41 |
42 | if param_grid is None:
43 | param_grid = {
44 | "n_neighbors": [3, 5, 7],
45 | "weights": ["uniform", "distance"],
46 | "algorithm": ["ball_tree", "kd_tree", "brute"],
47 | "leaf_size": [10, 15, 20],
48 | "p": [1, 2, 3],
49 | }
50 | # Create a GridSearchCV object and fit the model
51 | # grid_search = GridSearchCV(estimator=knnr, param_grid=param_grid, cv=5, n_jobs=-1, verbose=1)
52 | grid_search = GridSearchCV(
53 | estimator=knnr,
54 | param_grid=param_grid,
55 | scoring="r2",
56 | cv=5,
57 | verbose=verbose,
58 | n_jobs=-1, # use all available cores
59 | )
60 |
61 | grid_search.fit(Xs_train, ys_train)
62 | test_r2 = grid_search.score(Xs_test, ys_test)
63 | predictions = grid_search.best_estimator_.predict(Xs_test)
64 | corr_coef, p_value = pearsonr(predictions, ys_test)
65 |
66 | # Print the best hyperparameters and the best score
67 | if verbose is not None:
68 | print("Best hyperparameters: ", grid_search.best_params_)
69 | print("Best score: ", grid_search.best_score_)
70 | print("Test R^2 score: ", test_r2)
71 | print("Correlation coefficient: {:.2f}".format(corr_coef))
72 | print("p-value: {:.4f}".format(p_value))
73 |
74 | return (
75 | grid_search.best_estimator_,
76 | test_r2,
77 | corr_coef,
78 | p_value,
79 | pd.DataFrame.from_dict(grid_search.cv_results_),
80 | )
81 |
82 |
83 | def rfr_grid_search(
84 | Xs_train: numpy.ndarray,
85 | Xs_test: numpy.ndarray,
86 | ys_train: list,
87 | ys_test: list,
88 | param_grid: dict = None,
89 | verbose: int = 1,
90 | ):
91 | """
92 |     Performs a Random Forest regressor grid search using 5-fold cross validation.
93 |
94 | Parameters:
95 | -----------
96 | Xs_train, Xs_test (numpy.ndarray): train and test values for training
97 | ys_train, ys_test (list): y values for train and test
98 | param_grid (dict): parameter grid for model
99 | verbose (int): Type of information printing during run
100 |
101 | Returns:
102 | --------
103 | Returns the best performing model of grid search, the test R squared value, correlation coefficient,
104 | p-value of the fit and a dataframe containing the fit information.
105 | """
106 | # Instantiate the model
107 | rfr = RandomForestRegressor(random_state=42)
108 |
109 | if param_grid is None:
110 | param_grid = {
111 | "n_estimators": [20, 50, 100, 200],
112 | "criterion": ["squared_error", "absolute_error"],
113 | "max_features": ["sqrt", "log2"],
114 | "max_depth": [5, 10, 15],
115 | "min_samples_split": [2, 5, 10],
116 | "min_samples_leaf": [1, 4],
117 | }
118 |
119 | # Create a GridSearchCV object and fit the model
120 | grid_search = GridSearchCV(
121 | estimator=rfr,
122 | param_grid=param_grid,
123 | scoring="r2",
124 | cv=5,
125 | verbose=verbose,
126 | n_jobs=-1, # use all available cores
127 | )
128 | grid_search.fit(Xs_train, ys_train)
129 |
130 | # Evaluate the performance of the model on the test set
131 | test_r2 = grid_search.score(Xs_test, ys_test)
132 | predictions = grid_search.best_estimator_.predict(Xs_test)
133 | corr_coef, p_value = pearsonr(predictions, ys_test)
134 |
135 | if verbose is not None:
136 | # Print the best hyperparameters and the best score
137 | print("Best hyperparameters: ", grid_search.best_params_)
138 | print("Best score: ", grid_search.best_score_)
139 | print("Test R^2 score: ", test_r2)
140 | print("Correlation coefficient: {:.2f}".format(corr_coef))
141 | print("p-value: {:.4f}".format(p_value))
142 |
143 | return (
144 | grid_search.best_estimator_,
145 | test_r2,
146 | corr_coef,
147 | p_value,
148 | pd.DataFrame.from_dict(grid_search.cv_results_),
149 | )
150 |
151 |
152 | def svr_grid_search(
153 | Xs_train: numpy.ndarray,
154 | Xs_test: numpy.ndarray,
155 | ys_train: list,
156 | ys_test: list,
157 | param_grid: dict = None,
158 | verbose: int = 1,
159 | ):
160 | """
161 |     Performs a Support Vector regressor grid search using 5-fold cross validation.
162 |
163 | Parameters:
164 | -----------
165 | Xs_train, Xs_test (numpy.ndarray): train and test values for training
166 | ys_train, ys_test (list): y values for train and test
167 | param_grid (dict): parameter grid for model
168 | verbose (int): Type of information printing during run
169 |
170 | Returns:
171 | --------
172 | Returns the best performing model of grid search, the test R squared value, correlation coefficient,
173 | p-value of the fit and a dataframe containing the fit information.
174 | """
175 | # Instantiate the model
176 | svr = SVR()
177 |
178 | if param_grid is None:
179 | param_grid = {
180 | "C": [0.1, 1, 2, 5, 10, 100, 200, 400],
181 | "gamma": ["scale"],
182 | "kernel": ["linear", "poly", "rbf", "sigmoid"],
183 | "degree": [3],
184 | }
185 |
186 | # Create a GridSearchCV object and fit the model
187 | grid_search = GridSearchCV(
188 | estimator=svr,
189 | param_grid=param_grid,
190 | scoring="r2",
191 | cv=5,
192 | verbose=verbose,
193 | n_jobs=-1, # use all available cores
194 | )
195 | grid_search.fit(Xs_train, ys_train)
196 |
197 | # Evaluate the performance of the model on the test set
198 | test_r2 = grid_search.score(Xs_test, ys_test)
199 | predictions = grid_search.best_estimator_.predict(Xs_test)
200 | corr_coef, p_value = pearsonr(predictions, ys_test)
201 |
202 | if verbose:
203 | # Print the best hyperparameters and the best score
204 | print("Best hyperparameters: ", grid_search.best_params_)
205 | print("Best score: ", grid_search.best_score_)
206 | print("Test R^2 score: ", test_r2)
207 | print("Correlation coefficient: {:.2f}".format(corr_coef))
208 | print("p-value: {:.4f}".format(p_value))
209 |
210 | return (
211 | grid_search.best_estimator_,
212 | test_r2,
213 | corr_coef,
214 | p_value,
215 | pd.DataFrame.from_dict(grid_search.cv_results_),
216 | grid_search.best_params_,
217 | )
218 |
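219 | # A minimal, hypothetical usage sketch on synthetic data (not part of the
220 | # library API; assumes this module's numpy import). A small param_grid keeps
221 | # the search fast; all searches in this module share the documented returns.
222 | if __name__ == "__main__":
223 |     rng = numpy.random.default_rng(0)
224 |     Xs = rng.random((60, 8))
225 |     ys = rng.random(60).tolist()
226 |     best_model, r2, corr, p, cv_df = rfr_grid_search(
227 |         Xs[:48],
228 |         Xs[48:],
229 |         ys[:48],
230 |         ys[48:],
231 |         param_grid={"n_estimators": [10], "max_depth": [3]},
232 |         verbose=0,
233 |     )
234 |     print(f"test R^2: {r2:.3f}, Pearson r: {corr:.3f}")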
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/torch_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for PyTorch tools.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.ml_tools.torch_tools.torch_tools import * # noqa: F403
12 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/torch_tools/matrices/BLOSUM50:
--------------------------------------------------------------------------------
1 | 5 -2 -1 -2 -1 -1 -1 0 -2 -1 -2 -1 -1 -3 -1 1 0 -3 -2 0
2 | -2 7 -1 -2 -4 1 0 -3 0 -4 -3 3 -2 -3 -3 -1 -1 -3 -1 -3
3 | -1 -1 7 2 -2 0 0 0 1 -3 -4 0 -2 -4 -2 1 0 -4 -2 -3
4 | -2 -2 2 8 -4 0 2 -1 -1 -4 -4 -1 -4 -5 -1 0 -1 -5 -3 -4
5 | -1 -4 -2 -4 13 -3 -3 -3 -3 -2 -2 -3 -2 -2 -4 -1 -1 -5 -3 -1
6 | -1 1 0 0 -3 7 2 -2 1 -3 -2 2 0 -4 -1 0 -1 -1 -1 -3
7 | -1 0 0 2 -3 2 6 -3 0 -4 -3 1 -2 -3 -1 -1 -1 -3 -2 -3
8 | 0 -3 0 -1 -3 -2 -3 8 -2 -4 -4 -2 -3 -4 -2 0 -2 -3 -3 -4
9 | -2 0 1 -1 -3 1 0 -2 10 -4 -3 0 -1 -1 -2 -1 -2 -3 2 -4
10 | -1 -4 -3 -4 -2 -3 -4 -4 -4 5 2 -3 2 0 -3 -3 -1 -3 -1 4
11 | -2 -3 -4 -4 -2 -2 -3 -4 -3 2 5 -3 3 1 -4 -3 -1 -2 -1 1
12 | -1 3 0 -1 -3 2 1 -2 0 -3 -3 6 -2 -4 -1 0 -1 -3 -2 -3
13 | -1 -2 -2 -4 -2 0 -2 -3 -1 2 3 -2 7 0 -3 -2 -1 -1 0 1
14 | -3 -3 -4 -5 -2 -4 -3 -4 -1 0 1 -4 0 8 -4 -3 -2 1 4 -1
15 | -1 -3 -2 -1 -4 -1 -1 -2 -2 -3 -4 -1 -3 -4 10 -1 -1 -4 -3 -3
16 | 1 -1 1 0 -1 0 -1 0 -1 -3 -3 0 -2 -3 -1 5 2 -4 -2 -2
17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 2 5 -3 -2 0
18 | -3 -3 -4 -5 -5 -1 -3 -3 -3 -3 -2 -3 -1 1 -4 -4 -3 15 2 -3
19 | -2 -1 -2 -3 -3 -1 -2 -3 2 -1 -1 -2 0 4 -3 -2 -2 2 8 -1
20 | 0 -3 -3 -4 -1 -3 -3 -4 -4 4 1 -3 1 -1 -3 -2 0 -3 -1 5
21 | -2 -1 4 5 -3 0 1 -1 0 -4 -4 0 -3 -4 -2 0 0 -5 -3 -4
22 | -1 0 0 1 -3 4 5 -2 0 -3 -3 1 -1 -4 -1 0 -1 -2 -2 -3
23 | -1 -1 -1 -1 -2 -1 -1 -2 -1 -1 -1 -1 -1 -2 -2 -1 0 -3 -1 -1
24 | -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5
25 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/torch_tools/matrices/BLOSUM62:
--------------------------------------------------------------------------------
1 | 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0
2 | -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3
3 | -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3
4 | -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3
5 | 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1
6 | -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2
7 | -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2
8 | 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3
9 | -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3
10 | -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3
11 | -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1
12 | -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2
13 | -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1
14 | -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1
15 | -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2
16 | 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2
17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0
18 | -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3
19 | -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1
20 | 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4
21 | -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3
22 | -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2
23 | 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1
24 | -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/torch_tools/matrices/alphabet:
--------------------------------------------------------------------------------
1 | A
2 | R
3 | N
4 | D
5 | C
6 | Q
7 | E
8 | G
9 | H
10 | I
11 | L
12 | K
13 | M
14 | F
15 | P
16 | S
17 | T
18 | W
19 | Y
20 | V
21 |
--------------------------------------------------------------------------------
/src/proteusAI/ml_tools/torch_tools/torch_tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import seaborn as sns
6 | from typing import Union
7 | import gpytorch
8 |
9 |
10 | def one_hot_encoder(sequences, alphabet=None, canonical=True, pbar=None, padding=None):
11 | """
12 | Encodes sequences provided an alphabet.
13 |
14 | Parameters:
15 | sequences (list or str): list of amino acid sequences or a single sequence.
16 | alphabet (list or None): list of characters in the alphabet or None to load from file.
17 | canonical (bool): only use canonical amino acids.
18 | padding (int or None): the length to which all sequences should be padded.
19 | If None, no padding beyond the length of the longest sequence.
20 |
21 | Returns:
22 | torch.Tensor: (number of sequences, padding or maximum sequence length, size of the alphabet) for list input
23 | (padding or maximum sequence length, size of the alphabet) for string input
24 | """
25 | # Check if sequences is a string
26 | if isinstance(sequences, str):
27 | singular = True
28 | sequences = [sequences] # Make it a list to use the same code below
29 | else:
30 | singular = False
31 |
32 | # Load the alphabet from a file if it's not provided
33 | if alphabet is None:
34 | # Get the directory of the current script
35 | script_dir = os.path.dirname(os.path.realpath(__file__))
36 |
37 | ### Amino Acid codes
38 | alphabet_file = os.path.join(script_dir, "matrices/alphabet")
39 | alphabet = np.loadtxt(alphabet_file, dtype=str)
40 |
41 | # If canonical is True, only use the first 20 characters of the alphabet
42 | if canonical:
43 | alphabet = alphabet[:20]
44 |
45 | # Create a dictionary to map each character in the alphabet to its index
46 | alphabet_dict = {char: i for i, char in enumerate(alphabet)}
47 |
48 | # Determine the length to which sequences should be padded
49 | max_sequence_length = max(len(sequence) for sequence in sequences)
50 | padded_length = padding if padding is not None else max_sequence_length
51 |
52 | n_sequences = len(sequences)
53 | alphabet_size = len(alphabet)
54 |
55 | # Create an empty tensor of the right size, with padding length
56 | tensor = torch.zeros((n_sequences, padded_length, alphabet_size))
57 |
58 | # Fill the tensor
59 | for i, sequence in enumerate(sequences):
60 | if pbar:
61 | pbar.set(
62 | i, message="Computing", detail=f"{i}/{len(sequences)} completed..."
63 | )
64 | for j, character in enumerate(sequence):
65 | if j >= padded_length:
66 | break # Stop if the sequence length exceeds the padded length
67 | # Get the index of the character in the alphabet
68 | char_index = alphabet_dict.get(
69 | character, -1
70 | ) # Return -1 if character is not in the alphabet
71 | if char_index != -1:
72 | # Set the corresponding element of the tensor to 1
73 | tensor[i, j, char_index] = 1.0
74 |
75 | # If the input was a string, return a tensor of shape (padded_length, alphabet_size)
76 | if singular:
77 | tensor = tensor.squeeze(0)
78 |
79 | return tensor
80 |
81 |
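# Shape sketch for one_hot_encoder (hypothetical inputs): with the default
# canonical alphabet,
#   one_hot_encoder(["ACD", "ACDEF"]).shape == (2, 5, 20)
# and a single string input drops the batch dimension:
#   one_hot_encoder("ACD").shape == (3, 20)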
82 | def blosum_encoding(
83 | sequences, matrix="BLOSUM62", canonical=True, pbar=None, padding=None
84 | ):
85 | """
86 | Returns BLOSUM encoding for amino acid sequences. Unknown amino acids will be
87 | encoded with 0.5 in the entire row.
88 |
89 | Parameters:
90 | -----------
91 | sequences (list or str): List of amino acid sequences or a single sequence.
92 | matrix (str): Choice of BLOSUM matrix. Can be 'BLOSUM50' or 'BLOSUM62'.
93 | canonical (bool): Only use canonical amino acids.
94 | padding (int or None): The length to which all sequences should be padded.
95 | If None, no padding beyond the length of the longest sequence.
96 | pbar: Progress bar for shiny app.
97 |
98 | Returns:
99 | --------
100 | torch.Tensor: BLOSUM encoded sequence.
101 | """
102 |
103 | # Check if sequences is a string
104 | if isinstance(sequences, str):
105 | singular = True
106 | sequences = [sequences] # Make it a list to use the same code below
107 | else:
108 | singular = False
109 |
110 | # Get the directory of the current script
111 | script_dir = os.path.dirname(os.path.realpath(__file__))
112 |
113 | # Load the alphabet
114 | alphabet_file = os.path.join(script_dir, "matrices/alphabet")
115 | alphabet = np.loadtxt(alphabet_file, dtype=str)
116 |
117 | # Define BLOSUM matrices
118 | _blosum50 = (
119 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM50"), dtype=float)
120 | .reshape((24, -1))
121 | .T
122 | )
123 | _blosum62 = (
124 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM62"), dtype=float)
125 | .reshape((24, -1))
126 | .T
127 | )
128 |
129 | # Choose BLOSUM matrix
130 | if matrix == "BLOSUM50":
131 | matrix = _blosum50
132 | elif matrix == "BLOSUM62":
133 | matrix = _blosum62
134 | else:
135 | raise ValueError(
136 | "Invalid BLOSUM matrix choice. Choose 'BLOSUM50' or 'BLOSUM62'."
137 | )
138 |
139 | # Create the BLOSUM encoding dictionary
140 | blosum_matrix = {}
141 | for i, letter_1 in enumerate(alphabet):
142 | if canonical:
143 | blosum_matrix[letter_1] = matrix[i][:20]
144 | else:
145 | blosum_matrix[letter_1] = matrix[i]
146 |
147 | # Determine the length to which sequences should be padded
148 | max_sequence_length = max(len(sequence) for sequence in sequences)
149 | padded_length = padding if padding is not None else max_sequence_length
150 |
151 | n_sequences = len(sequences)
152 | alphabet_size = len(blosum_matrix["A"])
153 |
154 | # Create an empty tensor of the right size, with padding length
155 | tensor = torch.zeros((n_sequences, padded_length, alphabet_size))
156 |
157 | # Convert each amino acid in sequence to BLOSUM encoding
158 | for i, sequence in enumerate(sequences):
159 | if pbar:
160 | pbar.set(
161 | i, message="Computing", detail=f"{i}/{len(sequences)} completed..."
162 | )
163 | for j, aa in enumerate(sequence):
164 | if j >= padded_length:
165 | break # Stop if the sequence length exceeds the padded length
166 | if aa in alphabet:
167 | tensor[i, j, :] = torch.tensor(blosum_matrix[aa])
168 | else:
169 | # Handle unknown amino acids with a default value of 0.5
170 | tensor[i, j, :] = 0.5
171 |
172 | # If the input was a string, return a tensor of shape (padded_length, alphabet_size)
173 | if singular:
174 | tensor = tensor.squeeze(0)
175 |
176 | return tensor
177 |
178 |
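# Encoding sketch for blosum_encoding (hypothetical input): each residue is
# mapped to its row of the chosen substitution matrix, so with canonical=True
#   blosum_encoding("ACD", matrix="BLOSUM62").shape == (3, 20)
# while residues outside the alphabet (e.g. "X") receive a constant 0.5 row.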
179 | # Define the VHSE dictionary
180 | vhse_dict = {
181 | "A": [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
182 | "R": [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
183 | "N": [-0.99, 0.00, -0.37, 0.69, -0.55, 0.85, 0.73, -0.80],
184 | "D": [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
185 | "C": [0.18, -1.67, -0.46, -0.21, 0.00, 1.20, -1.61, -0.19],
186 | "Q": [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
187 | "E": [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.02],
188 | "G": [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, 2.01, -1.34],
189 | "H": [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
190 | "I": [1.27, -0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
191 | "L": [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
192 | "K": [-1.17, 0.70, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13],
193 | "M": [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
194 | "F": [1.52, 0.61, 0.96, -0.16, 0.25, 0.28, -1.33, -0.20],
195 | "P": [0.22, -0.17, -0.50, 0.05, -0.01, -1.34, -0.19, 3.56],
196 | "S": [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
197 | "T": [-0.34, -0.51, -0.55, -1.06, -0.06, -0.01, -0.79, 0.39],
198 | "W": [1.50, 2.06, 1.79, 0.75, 0.75, -0.13, -1.01, -0.85],
199 | "Y": [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
200 | "V": [0.76, -0.92, -0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
201 | }
202 |
203 |
204 | def vhse_encoder(sequences, padding=None, pbar=None):
205 | """
206 | Encodes sequences using VHSE descriptors.
207 |
208 | Parameters:
209 | sequences (list or str): List of amino acid sequences or a single sequence.
210 | padding (int or None): Length to which all sequences should be padded.
211 | If None, no padding beyond the longest sequence.
212 | pbar: Progress bar for tracking (optional).
213 |
214 | Returns:
215 | torch.Tensor: VHSE encoded tensor of shape
216 | (number of sequences, padding or max sequence length, 8)
217 | """
218 | # Ensure input is a list
219 | if isinstance(sequences, str):
220 | singular = True
221 | sequences = [sequences]
222 | else:
223 | singular = False
224 |
225 | # Determine maximum sequence length and apply padding
226 | max_sequence_length = max(len(sequence) for sequence in sequences)
227 | padded_length = padding if padding is not None else max_sequence_length
228 |
229 | n_sequences = len(sequences)
230 | vhse_size = 8 # VHSE descriptors have 8 components
231 |
232 | # Initialize output tensor with zeros
233 | tensor = torch.zeros((n_sequences, padded_length, vhse_size))
234 |
235 | # Encode each sequence
236 | for i, sequence in enumerate(sequences):
237 | if pbar:
238 | pbar.set(i, message="Encoding", detail=f"{i}/{n_sequences} completed...")
239 | for j, aa in enumerate(sequence):
240 | if j >= padded_length:
241 | break
242 | tensor[i, j] = torch.tensor(
243 | vhse_dict.get(aa, [0.5] * vhse_size)
244 | ) # Default for unknown AAs
245 |
246 | # Squeeze output for single sequence input
247 | if singular:
248 | tensor = tensor.squeeze(0)
249 |
250 | return tensor
251 |
252 |
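# Encoding sketch for vhse_encoder (hypothetical input): each residue maps to
# its 8 VHSE descriptors, so
#   vhse_encoder(["ACD", "ACDEF"]).shape == (2, 5, 8)
# with zero rows padding out the shorter sequence.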
253 | def plot_attention(attention: list, layer: int, head: int, seq: Union[str, list]):
254 | """
255 | Plot the attention weights for a specific layer and head.
256 |
257 | Args:
258 | attention (list): List of attention weights from the model
259 | layer (int): Index of the layer to visualize
260 | head (int): Index of the head to visualize
261 | seq (str or list): Input sequence as a string or as a list of tokens
262 | """
263 |
264 | if isinstance(seq, str):
265 | seq = list(seq)
266 |
267 | # Get the attention weights for the specified layer and head
268 | attn_weights = attention[layer][head].detach().cpu().numpy()
269 |
270 | # Create a heatmap using seaborn
271 | plt.figure(figsize=(10, 10))
272 | sns.heatmap(attn_weights, xticklabels=seq, yticklabels=seq, cmap="viridis")
273 |
274 | # Set plot title and labels
275 | plt.title(f"Attention weights - Layer {layer + 1}, Head {head + 1}")
276 | plt.xlabel("Input tokens")
277 | plt.ylabel("Output tokens")
278 |
279 | # Show the plot
280 | plt.show()
281 |
282 |
283 | class GP(gpytorch.models.ExactGP):
284 | def __init__(
285 | self, train_x, train_y, likelihood, fix_mean=False
286 | ):
287 | super(GP, self).__init__(train_x, train_y, likelihood)
288 | self.mean_module = gpytorch.means.ConstantMean()
289 | self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
290 | self.mean_module.constant.data.fill_(1) # Set the mean value to 1
291 | if fix_mean:
292 | self.mean_module.constant.requires_grad_(False)
293 |
294 | def forward(self, x):
295 | mean_x = self.mean_module(x)
296 | covar_x = self.covar_module(x)
297 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
298 |
299 |
300 | def predict_gp(model, likelihood, X):
301 | model.eval()
302 | likelihood.eval()
303 |
304 | with torch.no_grad():
305 | predictions = likelihood(model(X))
306 | y_pred = predictions.mean
307 | y_std = predictions.stddev
308 | # lower, upper = predictions.confidence_region()
309 |
310 | return y_pred, y_std
311 |
312 |
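# A typical fitting sketch for the GP above (hypothetical tensors train_x,
# train_y and test_x; standard gpytorch ExactGP training loop):
#
#   likelihood = gpytorch.likelihoods.GaussianLikelihood()
#   model = GP(train_x, train_y, likelihood)
#   model.train(); likelihood.train()
#   mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
#   optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
#   for _ in range(50):
#       optimizer.zero_grad()
#       loss = -mll(model(train_x), train_y)
#       loss.backward()
#       optimizer.step()
#   y_pred, y_std = predict_gp(model, likelihood, test_x)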
313 | def computeR2(y_true, y_pred):
314 | """
315 | Compute the R² value between two torch tensors.
316 |
317 | Args:
318 | y_true (torch.Tensor): true y-values
319 | y_pred (torch.Tensor): predicted y-values
320 | """
321 | # Ensure the tensors are 1-dimensional
322 | if y_true.dim() != 1 or y_pred.dim() != 1:
323 | raise ValueError("Both y_true and y_pred must be 1-dimensional tensors")
324 |
325 | # Compute the mean of true values
326 | y_mean = torch.mean(y_true)
327 |
328 | # Compute the total sum of squares (SS_tot)
329 | ss_tot = torch.sum((y_true - y_mean) ** 2)
330 |
331 | # Compute the residual sum of squares (SS_res)
332 | ss_res = torch.sum((y_true - y_pred) ** 2)
333 |
334 | # Compute the R² value
335 | r2 = 1 - (ss_res / ss_tot)
336 |
337 | return r2.item() # Convert tensor to float
338 |
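339 | # A small, self-contained sanity check on hypothetical sequences; runs only
340 | # when this module is executed directly.
341 | if __name__ == "__main__":
342 |     seqs = ["ACDEFGHIKL", "ACDEFGHIKV"]
343 |     ohe = one_hot_encoder(seqs)
344 |     blosum = blosum_encoding(seqs, matrix="BLOSUM62")
345 |     vhse = vhse_encoder(seqs)
346 |     print(ohe.shape, blosum.shape, vhse.shape)  # (2, 10, 20), (2, 10, 20), (2, 10, 8)
347 |     y = torch.tensor([1.0, 2.0, 3.0])
348 |     print(computeR2(y, y))  # identical tensors give an R² of 1.0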
--------------------------------------------------------------------------------
/src/proteusAI/struc/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for structures.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.struc.struc import * # noqa: F403
12 |
--------------------------------------------------------------------------------
/src/proteusAI/visual_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # This source code is part of the proteusAI package and is distributed
2 | # under the MIT License.
3 |
4 | """
5 | A subpackage for visualization tools.
6 | """
7 |
8 | __name__ = "proteusAI"
9 | __author__ = "Jonathan Funk"
10 |
11 | from proteusAI.visual_tools.plots import * # noqa: F403
12 |
--------------------------------------------------------------------------------
/src/proteusAI/visual_tools/plots.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | from typing import Union, List
5 | import pandas as pd
6 | from sklearn.manifold import TSNE
7 | from sklearn.decomposition import PCA
8 | import numpy as np
9 | import umap
10 |
11 | matplotlib.use("Agg")
12 | representation_dict = {
13 | "One-hot": "ohe",
14 | "BLOSUM50": "blosum50",
15 | "BLOSUM62": "blosum62",
16 | "ESM-2": "esm2",
17 | "ESM-1v": "esm1v",
18 | }
19 |
20 |
21 | def plot_predictions_vs_groundtruth(
22 | y_true: list,
23 | y_pred: list,
24 | title: Union[str, None] = None,
25 | x_label: Union[str, None] = None,
26 | y_label: Union[str, None] = None,
27 | plot_grid: bool = True,
28 | file: Union[str, None] = None,
29 | show_plot: bool = True,
30 | width: Union[float, None] = None,
31 | ):
32 | # Create the plot
33 | fig, ax = plt.subplots(figsize=(10, 5))
34 | sns.scatterplot(x=y_true, y=y_pred, alpha=0.5, ax=ax)
35 |
36 | # Set plot title and labels
37 | if title is None:
38 | title = "Predicted vs. True y-values"
39 | if y_label is None:
40 | y_label = "predicted y"
41 | if x_label is None:
42 | x_label = "y"
43 |
44 | ax.set_title(title)
45 | ax.set_xlabel(x_label)
46 | ax.set_ylabel(y_label)
47 |
48 | # Add the diagonal line
49 | min_val = min(min(y_true) * 1.05, min(y_pred) * 1.05)
50 | max_val = max(max(y_true) * 1.05, max(y_pred) * 1.05)
51 | ax.plot(
52 | [min_val, max_val],
53 | [min_val, max_val],
54 | color="grey",
55 | linestyle="dotted",
56 | linewidth=2,
57 | )
58 |
59 | # Add vertical error bars and confidence region if width is specified
60 | if width is not None:
61 | # Add error bars with T-shape
62 | for true, pred in zip(y_true, y_pred):
63 | ax.plot(
64 | [true, true], [pred - width, pred + width], color="darkgrey", alpha=0.8
65 | )
66 | ax.plot(
67 | [
68 | true - 0.005 * (max_val - min_val),
69 | true + 0.005 * (max_val - min_val),
70 | ],
71 | [pred - width, pred - width],
72 | color="darkgrey",
73 | alpha=0.8,
74 | )
75 | ax.plot(
76 | [
77 | true - 0.005 * (max_val - min_val),
78 | true + 0.005 * (max_val - min_val),
79 | ],
80 | [pred + width, pred + width],
81 | color="darkgrey",
82 | alpha=0.8,
83 | )
84 |
85 | # Add confidence region
86 | x_range = np.linspace(min_val, max_val, 100)
87 | upper_conf = x_range + width
88 | lower_conf = x_range - width
89 | ax.fill_between(
90 | x_range,
91 | lower_conf,
92 | upper_conf,
93 | color="blue",
94 | alpha=0.08,
95 | label="Confidence Region",
96 | )
97 |
98 | # Adjust layout to remove extra whitespace
99 | fig.tight_layout()
100 |
101 | # Add grid if specified
102 | ax.set_xlim(min_val, max_val)
103 | ax.grid(plot_grid)
104 |
105 | # Save the plot to a file
106 | if file is not None:
107 | fig.savefig(file)
108 |
109 | # Show the plot
110 | if show_plot:
111 | plt.show()
112 |
113 | # Return the figure and axes
114 | return fig, ax
115 |
116 |
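# Usage sketch (hypothetical values): scatter predictions against ground truth
# with a constant +/- 0.5 uncertainty band around the diagonal:
#   fig, ax = plot_predictions_vs_groundtruth(
#       y_true=[1.0, 2.0, 3.0], y_pred=[1.1, 1.9, 3.2],
#       width=0.5, show_plot=False, file="pred_vs_true.png",
#   )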
117 | def plot_tsne(
118 | x: List[np.ndarray],
119 | y: Union[List[Union[float, str]], None] = None,
120 | y_upper: Union[float, None] = None,
121 | y_lower: Union[float, None] = None,
122 | names: Union[List[str], None] = None,
123 | y_type: str = "num",
124 | random_state: int = 42,
125 | rep_type: Union[str, None] = None,
126 | highlight_mask: Union[List[Union[int, float]], None] = None,
127 | highlight_label: str = "Highlighted",
128 | df: Union[pd.DataFrame, None] = None,
129 | ):
130 | """
131 | Create a t-SNE plot and optionally color by y values, with special coloring for points outside given thresholds.
132 | Handles cases where y is None or a list of Nones by not applying hue.
133 | Optionally highlights points based on the highlight_mask.
134 |
135 | Args:
136 | x (List[np.ndarray]): List of sequence representations as numpy arrays.
137 | y (List[Union[float, str]]): List of y values, can be None or contain None.
138 | y_upper (float): Upper threshold for special coloring.
139 | y_lower (float): Lower threshold for special coloring.
140 | names (List[str]): List of names for each point.
141 | y_type (str): 'class' for categorical labels or 'num' for numerical labels.
142 | random_state (int): Random state.
143 | rep_type (str): Representation type used for plotting.
144 | highlight_mask (List[Union[int, float]]): List of mask values, non-zero points will be highlighted.
145 | highlight_label (str): Text for the legend entry of highlighted points.
146 | df (pd.DataFrame): DataFrame containing the data to plot.
147 | """
148 | fig, ax = plt.subplots(figsize=(10, 5))
149 |
150 | x = np.array([t.numpy() if hasattr(t, "numpy") else t for t in x])
151 |
152 | if len(x.shape) == 3: # Flatten if necessary
153 | x = x.reshape(x.shape[0], -1)
154 |
155 | if df is None:
156 | tsne = TSNE(n_components=2, verbose=1, random_state=random_state)
157 | z = tsne.fit_transform(x)
158 |
159 | df = pd.DataFrame(z, columns=["z1", "z2"])
160 | df["y"] = y if y is not None and any(y) else None # Use y if it's informative
161 | if names and len(names) == len(df):
162 | df["names"] = names
163 | else:
164 | df["names"] = [None] * len(df)
165 |
166 | # Handle the palette based on whether y is numerical or categorical
167 | if y is not None and isinstance(y[0], (int, float)): # If y is numerical
168 | cmap = sns.cubehelix_palette(rot=-0.2, as_cmap=True)
169 | else: # If y is categorical
170 | cmap = sns.color_palette("Set2", as_cmap=False)
171 |
172 | hue = (
173 | "y" if df["y"].isnull().sum() != len(df["y"]) else None
174 | ) # Use hue only if y is informative
175 | scatter = sns.scatterplot(
176 | x="z1", y="z2", hue=hue, palette=cmap if hue else None, data=df
177 | )
178 |
179 | # Apply special coloring only if y values are valid and thresholds are provided
180 | if hue and (y_upper is not None or y_lower is not None):
181 | outlier_mask = (
182 | (df["y"] > y_upper)
183 | if y_upper is not None
184 | else np.zeros(len(df), dtype=bool)
185 | )
186 | outlier_mask |= (
187 | (df["y"] < y_lower)
188 | if y_lower is not None
189 | else np.zeros(len(df), dtype=bool)
190 | )
191 | scatter.scatter(
192 | df["z1"][outlier_mask], df["z2"][outlier_mask], color="lightgrey"
193 | )
194 |
195 | # Highlight points based on the highlight_mask
196 | if highlight_mask is not None:
197 | highlight_mask = np.array(highlight_mask)
198 | highlight_points = highlight_mask != 0 # Non-zero entries in the highlight_mask
199 | scatter.scatter(
200 | df["z1"][highlight_points],
201 | df["z2"][highlight_points],
202 | color="red",
203 | marker="x",
204 | s=60,
205 | alpha=0.7,
206 | label=highlight_label,
207 | )
208 |
209 | scatter.set_title(f"t-SNE projection of {rep_type if rep_type else 'data'}")
210 |
211 | # Add the legend, making sure to include highlighted points
212 | handles, labels = scatter.get_legend_handles_labels()
213 | if highlight_label in labels:
214 | ax.legend(handles, labels, title="Legend")
215 | else:
216 | ax.legend(title="Legend")
217 |
218 | return fig, ax, df
219 |
220 |
221 | def plot_umap(
222 | x: List[np.ndarray],
223 | y: Union[List[Union[float, str]], None] = None,
224 | y_upper: Union[float, None] = None,
225 | y_lower: Union[float, None] = None,
226 | names: Union[List[str], None] = None,
227 | y_type: str = "num",
228 | random_state: int = 42,
229 | rep_type: Union[str, None] = None,
230 | highlight_mask: Union[List[Union[int, float]], None] = None,
231 | highlight_label: str = "Highlighted",
232 | df: Union[pd.DataFrame, None] = None,
233 | html: bool = False,
234 | ):
235 | """
236 | Create a UMAP plot and optionally color by y values, with special coloring for points outside given thresholds.
237 | Handles cases where y is None or a list of Nones by not applying hue.
238 | Optionally highlights points based on the highlight_mask.
239 |
240 | Args:
241 | x (List[np.ndarray]): List of sequence representations as numpy arrays.
242 | y (List[Union[float, str]]): List of y values, can be None or contain None.
243 | y_upper (float): Upper threshold for special coloring.
244 | y_lower (float): Lower threshold for special coloring.
245 | names (List[str]): List of names for each point.
246 | y_type (str): 'class' for categorical labels or 'num' for numerical labels.
247 | random_state (int): Random state.
248 | rep_type (str): Representation type used for plotting.
249 | highlight_mask (List[Union[int, float]]): List of mask values, non-zero points will be highlighted.
250 | highlight_label (str): Text for the legend entry of highlighted points.
251 | df (pd.DataFrame): DataFrame containing the data to plot.
252 | """
253 | fig, ax = plt.subplots(figsize=(10, 5))
254 |
255 | x = np.array([t.numpy() if hasattr(t, "numpy") else t for t in x])
256 |
257 | if len(x.shape) == 3: # Flatten if necessary
258 | x = x.reshape(x.shape[0], -1)
259 |
260 | if df is None:
261 | umap_model = umap.UMAP(
262 | n_neighbors=70, min_dist=0.0, n_components=2, random_state=random_state
263 | )
264 |
265 | z = umap_model.fit_transform(x)
266 | df = pd.DataFrame(z, columns=["z1", "z2"])
267 | df["y"] = y if y is not None and any(y) else None # Use y if it's informative
268 | if names and len(names) == len(df):
269 | df["names"] = names
270 | else:
271 | df["names"] = [None] * len(df)
272 |
273 | # Handle the palette based on whether y is numerical or categorical
274 | if y is not None and isinstance(y[0], (int, float)): # If y is numerical
275 | cmap = sns.cubehelix_palette(rot=-0.2, as_cmap=True)
276 | else: # If y is categorical
277 | cmap = sns.color_palette("Set2", as_cmap=False)
278 |
279 | hue = (
280 | "y" if df["y"].isnull().sum() != len(df["y"]) else None
281 | ) # Use hue only if y is informative
282 | scatter = sns.scatterplot(
283 | x="z1", y="z2", hue=hue, palette=cmap if hue else None, data=df
284 | )
285 |
286 | # Apply special coloring only if y values are valid and thresholds are provided
287 | if hue and (y_upper is not None or y_lower is not None):
288 | outlier_mask = (
289 | (df["y"] > y_upper)
290 | if y_upper is not None
291 | else np.zeros(len(df), dtype=bool)
292 | )
293 | outlier_mask |= (
294 | (df["y"] < y_lower)
295 | if y_lower is not None
296 | else np.zeros(len(df), dtype=bool)
297 | )
298 | scatter.scatter(
299 | df["z1"][outlier_mask], df["z2"][outlier_mask], color="lightgrey"
300 | )
301 |
302 | # Highlight points based on the highlight_mask
303 | if highlight_mask is not None:
304 | highlight_mask = np.array(highlight_mask)
305 | highlight_points = highlight_mask != 0 # Non-zero entries in the highlight_mask
306 | scatter.scatter(
307 | df["z1"][highlight_points],
308 | df["z2"][highlight_points],
309 | color="red",
310 | marker="x",
311 | s=60,
312 | alpha=0.7,
313 | label=highlight_label,
314 | )
315 |
316 | scatter.set_title(f"UMAP projection of {rep_type if rep_type else 'data'}")
317 |
318 | # Add the legend, making sure to include highlighted points
319 | handles, labels = scatter.get_legend_handles_labels()
320 | if highlight_label in labels:
321 | ax.legend(handles, labels, title="Legend")
322 | else:
323 | ax.legend(title="Legend")
324 |
325 | return fig, ax, df
326 |
327 |
328 | def plot_pca(
329 | x: List[np.ndarray],
330 | y: Union[List[Union[float, str]], None] = None,
331 | y_upper: Union[float, None] = None,
332 | y_lower: Union[float, None] = None,
333 | names: Union[List[str], None] = None,
334 | y_type: str = "num",
335 | random_state: int = 42,
336 | rep_type: Union[str, None] = None,
337 | highlight_mask: Union[List[Union[int, float]], None] = None,
338 | highlight_label: str = "Highlighted",
339 | df: Union[pd.DataFrame, None] = None,
340 | ):
341 | """
342 | Create a PCA plot and optionally color by y values, with special coloring for points outside given thresholds.
343 | Handles cases where y is None or a list of Nones by not applying hue.
344 | Optionally highlights points based on the highlight_mask.
345 |
346 | Args:
347 | x (List[np.ndarray]): List of sequence representations as numpy arrays.
348 | y (List[Union[float, str]]): List of y values, can be None or contain None.
349 | y_upper (float): Upper threshold for special coloring.
350 | y_lower (float): Lower threshold for special coloring.
351 | names (List[str]): List of names for each point.
352 | y_type (str): 'class' for categorical labels or 'num' for numerical labels.
353 | random_state (int): Random state.
354 | rep_type (str): Representation type used for plotting.
355 | highlight_mask (List[Union[int, float]]): List of mask values, non-zero points will be highlighted.
356 | highlight_label (str): Text for the legend entry of highlighted points.
357 | df (pd.DataFrame): DataFrame containing the data to plot.
358 | """
359 | fig, ax = plt.subplots(figsize=(10, 5))
360 |
361 | x = np.array([t.numpy() if hasattr(t, "numpy") else t for t in x])
362 |
363 | if len(x.shape) == 3: # Flatten if necessary
364 | x = x.reshape(x.shape[0], -1)
365 |
366 | if df is None:
367 | pca = PCA(n_components=2, random_state=random_state)
368 | z = pca.fit_transform(x)
369 | df = pd.DataFrame(z, columns=["z1", "z2"])
370 | df["y"] = y if y is not None and any(y) else None # Use y if it's informative
371 | if names and len(names) == len(df):
372 | df["names"] = names
373 | else:
374 | df["names"] = [None] * len(df)
375 |
376 | # Handle the palette based on whether y is numerical or categorical
377 | if y is not None and isinstance(y[0], (int, float)): # If y is numerical
378 | cmap = sns.cubehelix_palette(rot=-0.2, as_cmap=True)
379 | else: # If y is categorical
380 | cmap = sns.color_palette("Set2", as_cmap=False)
381 |
382 | hue = (
383 | "y" if df["y"].isnull().sum() != len(df["y"]) else None
384 | ) # Use hue only if y is informative
385 | scatter = sns.scatterplot(
386 | x="z1", y="z2", hue=hue, palette=cmap if hue else None, data=df
387 | )
388 |
389 | # Apply special coloring only if y values are valid and thresholds are provided
390 | if hue and (y_upper is not None or y_lower is not None):
391 | outlier_mask = (
392 | (df["y"] > y_upper)
393 | if y_upper is not None
394 | else np.zeros(len(df), dtype=bool)
395 | )
396 | outlier_mask |= (
397 | (df["y"] < y_lower)
398 | if y_lower is not None
399 | else np.zeros(len(df), dtype=bool)
400 | )
401 | scatter.scatter(
402 | df["z1"][outlier_mask], df["z2"][outlier_mask], color="lightgrey"
403 | )
404 |
405 | # Highlight points based on the highlight_mask
406 | if highlight_mask is not None:
407 | highlight_mask = np.array(highlight_mask)
408 | highlight_points = highlight_mask != 0 # Non-zero entries in the highlight_mask
409 | scatter.scatter(
410 | df["z1"][highlight_points],
411 | df["z2"][highlight_points],
412 | color="red",
413 | marker="x",
414 | s=60,
415 | alpha=0.7,
416 | label=highlight_label,
417 | )
418 |
419 | scatter.set_title(f"PCA projection of {rep_type if rep_type else 'data'}")
420 |
421 | # Add the legend, making sure to include highlighted points
422 | handles, labels = scatter.get_legend_handles_labels()
423 | if highlight_label in labels:
424 | ax.legend(handles, labels, title="Legend")
425 | else:
426 | ax.legend(title="Legend")
427 |
428 | return fig, ax, df
429 |
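430 | # A minimal, hypothetical demo on random data; plot_tsne and plot_umap follow
431 | # the same call pattern (runs only when this module is executed directly).
432 | if __name__ == "__main__":
433 |     rng = np.random.default_rng(0)
434 |     reps = [rng.random(16) for _ in range(50)]
435 |     ys = rng.random(50).tolist()
436 |     fig, ax, df = plot_pca(reps, y=ys, rep_type="toy")
437 |     fig.savefig("pca_demo.png")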
--------------------------------------------------------------------------------
/tests/test_module.py:
--------------------------------------------------------------------------------
1 | """test python package related functions"""
2 |
3 |
4 | def test_load_package():
5 | import proteusAI
6 |
7 | assert proteusAI.__version__ is not None
8 |
--------------------------------------------------------------------------------