├── .github └── workflows │ └── cicd.yaml ├── .gitignore ├── .idea └── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── app ├── __init__.py ├── app.py ├── logo.png └── tooltips.py ├── demo ├── MLDE_benchmark.py ├── clustering_demo.py ├── demo_data │ ├── 1zb6.pdb │ ├── GB1.pdb │ ├── Nitric_Oxide_Dioxygenase.csv │ ├── Nitric_Oxide_Dioxygenase_raw.csv │ ├── Nitric_Oxide_Dioxygenase_wt.fasta │ ├── binding_data │ │ ├── dataset_creation.py │ │ ├── replicate_1.csv │ │ ├── replicate_2.csv │ │ └── replicate_3.csv │ └── methyltransfereases.csv ├── design_demo.py ├── discovery_demo.py ├── fold_design_demo.py ├── mlde_demo.py ├── plot_benchmark.py ├── plot_library_demo.py ├── test └── zero_shot_demo.py ├── docs ├── .gitignore ├── README.md ├── conf.py ├── index.rst └── tutorial │ └── tutorial.py ├── environment.yml ├── favicon.ico ├── proteusEnvironment.yml ├── pyproject.toml ├── run_benchmark.sh ├── setup.py ├── src └── proteusAI │ ├── Library │ ├── __init__.py │ └── library.py │ ├── Model │ ├── __init__.py │ └── model.py │ ├── Protein │ ├── __init__.py │ └── protein.py │ ├── __init__.py │ ├── data_tools │ ├── MSA.py │ ├── __init__.py │ └── pdb.py │ ├── design_tools │ ├── Constraints.py │ ├── MCMC.py │ ├── ZeroShot.py │ └── __init__.py │ ├── io_tools │ ├── __init__.py │ ├── embeddings.py │ ├── fasta.py │ └── matrices │ │ ├── BLOSUM50 │ │ ├── BLOSUM62 │ │ └── alphabet │ ├── mining_tools │ ├── __init__.py │ ├── alphafoldDB.py │ ├── blast.py │ └── uniprot.py │ ├── ml_tools │ ├── __init__.py │ ├── bo_tools │ │ ├── __init__.py │ │ ├── acq_fn.py │ │ └── genetic_algorithm.py │ ├── esm_tools │ │ ├── __init__.py │ │ ├── alphabet.pt │ │ └── esm_tools.py │ ├── sklearn_tools │ │ ├── __init__.py │ │ └── grid_search.py │ └── torch_tools │ │ ├── __init__.py │ │ ├── matrices │ │ ├── BLOSUM50 │ │ ├── BLOSUM62 │ │ └── alphabet │ │ └── torch_tools.py │ ├── struc │ ├── __init__.py │ └── struc.py │ └── visual_tools │ ├── __init__.py │ └── plots.py └── tests └── test_module.py /.github/workflows/cicd.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | pull_request: 9 | branches: [ "main", "shiny_server"] 10 | schedule: 11 | - cron: '0 2 * * 3' 12 | 13 | permissions: 14 | contents: read 15 | 16 | 17 | jobs: 18 | format: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: psf/black@stable 23 | lint: 24 | name: Lint with ruff 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.11" 32 | - name: Install ruff 33 | run: | 34 | pip install ruff 35 | - name: Lint with ruff 36 | run: | 37 | # stop the build if there are Python syntax errors or undefined names 38 | ruff check . 39 | test: 40 | name: Test 41 | runs-on: ubuntu-latest 42 | strategy: 43 | matrix: 44 | python-version: ["3.8"] 45 | steps: 46 | - uses: actions/checkout@v4 47 | - name: Set up Python ${{ matrix.python-version }} 48 | uses: actions/setup-python@v5 49 | with: 50 | python-version: ${{ matrix.python-version }} 51 | cache: 'pip' # caching pip dependencies 52 | # cache-dependency-path: '**/pyproject.toml' 53 | - name: Install dependencies 54 | run: | 55 | python -m pip install --upgrade pip 56 | pip install pytest 57 | pip install -e . --find-links https://data.pyg.org/whl/torch-2.4.1+cpu.html 58 | - name: Run tests 59 | run: python -m pytest tests 60 | 61 | build_source_dist: 62 | name: Build source distribution 63 | if: startsWith(github.ref, 'refs/heads/main') || startsWith(github.ref, 'refs/tags') 64 | runs-on: ubuntu-latest 65 | steps: 66 | - uses: actions/checkout@v4 67 | 68 | - uses: actions/setup-python@v5 69 | with: 70 | python-version: "3.10" 71 | 72 | - name: Install build 73 | run: python -m pip install build 74 | 75 | - name: Run build 76 | run: python -m build --sdist 77 | 78 | - uses: actions/upload-artifact@v4 79 | with: 80 | path: ./dist/*.tar.gz 81 | # Needed in case of building packages with external binaries (e.g. Cython, RUst-extensions, etc.) 82 | # build_wheels: 83 | # name: Build wheels on ${{ matrix.os }} 84 | # if: startsWith(github.ref, 'refs/heads/main') || startsWith(github.ref, 'refs/tags') 85 | # runs-on: ${{ matrix.os }} 86 | # strategy: 87 | # matrix: 88 | # os: [ubuntu-20.04, windows-2019, macOS-10.15] 89 | 90 | # steps: 91 | # - uses: actions/checkout@v4 92 | 93 | # - uses: actions/setup-python@v5 94 | # with: 95 | # python-version: "3.10" 96 | 97 | # - name: Install cibuildwheel 98 | # run: python -m pip install cibuildwheel==2.3.1 99 | 100 | # - name: Build wheels 101 | # run: python -m cibuildwheel --output-dir wheels 102 | 103 | # - uses: actions/upload-artifact@v4 104 | # with: 105 | # path: ./wheels/*.whl 106 | 107 | publish: 108 | name: Publish package 109 | if: startsWith(github.ref, 'refs/tags') 110 | needs: 111 | - format 112 | - lint 113 | - test 114 | - build_source_dist 115 | # - build_wheels 116 | runs-on: ubuntu-latest 117 | 118 | steps: 119 | - uses: actions/download-artifact@v4 120 | with: 121 | name: artifact 122 | path: ./dist 123 | 124 | - uses: pypa/gh-action-pypi-publish@release/v1 125 | with: 126 | user: __token__ 127 | password: ${{ secrets.PYPI_API_TOKEN }} 128 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # biodata 132 | *.pdb 133 | *.fasta 134 | *.xml 135 | *.pt 136 | *.png 137 | *.sdf 138 | *.pth 139 | *.zip 140 | *.gif 141 | *.pml 142 | *.dat 143 | *.csv 144 | *.txt 145 | *.pse 146 | .DS_Store 147 | *.json 148 | *.pt 149 | 150 | # from server 151 | *.err 152 | *.out 153 | 4a6d.cif 154 | .idea/ProteusAI.iml 155 | Untitled.ipynb 156 | *.joblib 157 | demo/example_project/models/rf/params.json 158 | demo/example_project/models/rf/results.json 159 | 160 | # For env testing 161 | test.py 162 | demo/mpnn_test.py 163 | demo/mpnn_validation.py 164 | demo/design_demo.py 165 | mpnn.sh 166 | test1.cif 167 | test2.cif 168 | shiny-server-1.5.22.1017-amd64.deb 169 | proteusAI_developer.key 170 | proteusai_developer.key 171 | proteusAI_developer.key.pub 172 | proteusai_developer.key.pub 173 | 174 | google_analytics.html -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: "ubuntu-20.04" 11 | tools: 12 | python: "mambaforge-22.9" 13 | 14 | 15 | # Build documentation in the "docs/" directory with Sphinx 16 | sphinx: 17 | configuration: docs/conf.py 18 | 19 | # Optionally build your docs in additional formats such as PDF and ePub 20 | # formats: 21 | # - pdf 22 | # - epub 23 | 24 | # Optional but recommended, declare the Python requirements required 25 | # to build your documentation 26 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 27 | # python: 28 | # install: 29 | # - method: pip 30 | # path: . 31 | # extra_requirements: 32 | # - docs 33 | 34 | conda: 35 | environment: environment.yml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ProteusAI contributors All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft src 2 | recursive-exclude __pycache__ *.py[cod] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ![App_Overview](https://github.com/user-attachments/assets/6a11e863-a45e-47d3-903d-8fb5bc260cd2) 3 | 4 | 5 | 6 | # ProteusAI 7 | ProteusAI is a library for machine learning-guided protein design and engineering. 8 | The library enables workflows from protein structure prediction, the prediction of 9 | mutational effects-, and zero-shot prediction of mutational effects. 10 | The goal is to provide state-of-the-art machine learning for protein engineering in a central library. 11 | 12 | ProteusAI is primarily powered by [PyTorch](https://pytorch.org/get-started/locally/), 13 | [scikit-learn](https://scikit-learn.org/stable/), 14 | and [ESM](https://github.com/facebookresearch/esm) protein language models. 15 | 16 | Cite our preprint [here](https://www.biorxiv.org/content/10.1101/2024.10.01.616114v1). 17 | 18 | 19 | Test out the ProteusAI web app [proteusai.bio](http://proteusai.bio/) 20 | 21 | ## Getting started 22 | 23 | ---- 24 | The commands used below are tested on Ubuntu 20.04 and IOS. Some tweaks may be needed for other OS. 25 | We recommend using conda environments to install ProteusAI. 26 | 27 | 28 | Clone the repository and cd to ProteusAI: 29 | ```bash 30 | git clone https://github.com/jonfunk21/ProteusAI.git 31 | cd ProteusAI 32 | ``` 33 | 34 | Install the latest version of proteusAI in a new environment 35 | ```bash 36 | conda env create -n proteusAI 37 | conda activate proteusAI 38 | ``` 39 | 40 | This uses the `environment.yml` file to install the dependencies. 41 | 42 | ## GPU support 43 | By default proteus will install torch-scatter using cpu compatible binaries. 44 | If you want to take full advantage of GPUs for the Design module consider 45 | uninstalling the default `torch-scatter`, and replace it with the CUDA 46 | compatible version: 47 | 48 | ```bash 49 | pip uninstall torch-scatter 50 | pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_VERSION}+cuda.html 51 | ``` 52 | 53 | If you have acces to GPUs you can run ESM-Fold. Unfortunately the installation of 54 | openfold can be unstable for some people. Please follow installation instructions 55 | on of the [ESM repository](https://github.com/facebookresearch/esm) and use 56 | the discussions in the issues section for help. 57 | 58 | ### Install using pip locally for developement 59 | 60 | Install a local version which picks up the latest changes using an editable install: 61 | 62 | ```bash 63 | # conda env create -n proteusAI python=3.8 64 | # conda activate proteusAI 65 | pip install -e . --find-links https://data.pyg.org/whl/torch-2.4.1+cpu.html 66 | ``` 67 | 68 | ### Troubleshooting 69 | You can check a working configuration for a Ubuntu machine (VM) 70 | in the `proteusEnvironment.yml` file. The latest versions can be checked by us via 71 | our actions. 72 | 73 | ### Setting shiny server 74 | 75 | Install Shiny Server on Ubuntu 18.04+ (the instructions for other systems are availabe at posit.co, please skip the section about R Shiny packages installation) with the following commands: 76 | ```bash 77 | sudo apt-get install gdebi-core 78 | wget https://download3.rstudio.org/ubuntu-18.04/x86_64/shiny-server-1.5.22.1017-amd64.deb 79 | sudo gdebi shiny-server-1.5.22.1017-amd64.deb 80 | ``` 81 | 82 | Edit the default config file `/etc/shiny-server/shiny-server.conf` for Shiny Server (the `sudo` command or root privileges are required): 83 | ```bash 84 | # Use python from the virtual environment to run Shiny apps 85 | python /home/proteus_developer/miniforge3/envs/proteusAI_depl/bin/python; 86 | 87 | # Instruct Shiny Server to run applications as the user "shiny" 88 | run_as shiny; 89 | 90 | # Never delete logs regardless of the their exit code 91 | preserve_logs true; 92 | 93 | # Do not replace errors with the generic error message, show them as they are 94 | sanitize_errors false; 95 | 96 | # Define a server that listens on port 80 97 | server { 98 | listen 80; 99 | 100 | # Define a location at the base URL 101 | location / { 102 | 103 | # Host the directory of Shiny Apps stored in this directory 104 | site_dir /srv/shiny-server; 105 | 106 | # Log all Shiny output to files in this directory 107 | log_dir /var/log/shiny-server; 108 | 109 | # When a user visits the base URL rather than a particular application, 110 | # an index of the applications available in this directory will be shown. 111 | directory_index on; 112 | } 113 | } 114 | ``` 115 | Restart the shiny server with the following command to apply the server configuration changes: 116 | ```bash 117 | sudo systemctl restart shiny-server 118 | ``` 119 | If you deploy the app on your local machine, be sure that the port 80 is open and not blocked by a firewall. You can check it with `netstat`: 120 | ```bash 121 | nc 80 122 | ``` 123 | If you deploy the app on your Azure Virtual Machine (VM), please add an Inbound Port rule in the Networking - Network Settings section on Azure Portal. Set the following properties: 124 | ```yaml 125 | Source: Any 126 | Source port ranges: * 127 | Destination: Any 128 | Service: HTTP 129 | Destination port ranges: 80 130 | Protocol: TCP 131 | Action: Allow 132 | ``` 133 | Other fields can beleaft as they are by default. 134 | 135 | Finally, create symlinks to your app files in the default Shiny Server folder `/srv/shiny-server/`: 136 | 137 | ```bash 138 | sudo ln -s /home/proteus_developer/ProteusAI/app/app.py /srv/shiny-server/app.py 139 | sudo ln -s /home/proteus_developer/ProteusAI/app/logo.png /srv/shiny-server/logo.png 140 | ``` 141 | If everything has been done correctly, you must see the application index page available at `http://127.0.0.1` (if you deploy your app locally) or at `http://` (if you deploy your app on an Azure VM). Additionally, the remote app can still be available in your local browser (the Shiny extension in Visual Studio must be enabled) if you run the following terminal command on the VM: 142 | ```bash 143 | /home/proteus_developer/miniforge3/envs/proteusAI_depl/bin/python -m shiny run --port 33015 --reload --autoreload-port 43613 /home/proteus_developer/ProteusAI/app/app.py 144 | ``` 145 | If you get warnings, debug or "Disconnected from the server" messages, it is likely due to: 146 | - absent python modules, 147 | - updated versions of the current python modules, 148 | - using relative paths instead of absolute paths (Shiny Server sees relative paths as starting from `/srv/shiny-server/` folder) 149 | or 150 | - logical errors in the code. 151 | 152 | In order to debug the application, see what is written in the server logs under `/var/log/shiny-server` (the log_dir parameter can be reset in the Shiny Server instance config file `/etc/shiny-server/shiny-server.conf`). 153 | 154 | ### Note on permissions: 155 | The app may give some problems due to directories not having permissions to create directories or load files to certain directories. When this happen, a solution found was to use the following: 156 | 157 | ```bash 158 | chmod 777 directory_name 159 | ``` 160 | 161 | 162 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/app/__init__.py -------------------------------------------------------------------------------- /app/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/app/logo.png -------------------------------------------------------------------------------- /app/tooltips.py: -------------------------------------------------------------------------------- 1 | data_tooltips = """ 2 | ProteusAI, a user-friendly and open-source ML platform, streamlines protein engineering and design tasks. 3 | ProteusAI offers modules to support researchers in various stages of the design-build-test-learn (DBTL) cycle, 4 | including protein discovery, structure-based design, zero-shot predictions, and ML-guided directed evolution (MLDE). 5 | Our benchmarking results demonstrate ProteusAI’s efficiency in improving proteins and enzymes within a few DBTL-cycle 6 | iterations. ProteusAI democratizes access to ML-guided protein engineering and is freely available for academic and 7 | commercial use. 8 | You can upload different data types to get started with ProteusAI. Click on the other module tabs to learn about their 9 | functionality and the expected data types. 10 | """ 11 | 12 | zs_file_type = """ 13 | Upload a FASTA file containing the protein sequence, or a PDB file containing the structure of the protein 14 | for which you want to generate zero-shot predictions. 15 | """ 16 | 17 | zs_tooltips = """ 18 | The ProteusAI Zero-Shot Module is designed to create a mutant library with no prior data. 19 | The module uses scores generated by large protein language models, such as ESM-1v, that have 20 | been trained to predict hidden residues in hundreds of millions of protein sequences. 21 | Often, you will find that several residues in your protein sequence have low predicted probabilities. 22 | It has been previously shown that swapping these residues for residues with higher probabilities 23 | has beneficial effects on the candidate protein. In ProteusAI, we provide access to several language 24 | models which have been trained under slightly different conditions. The best models to produce Zero-Shot 25 | scores are ESM-1v and ESM-2 (650M). However, these models will take a long time to compute the results. 26 | Consider using ESM-2 (35M) to get familiar with the module first before moving to the larger models. 27 | """ 28 | 29 | discovery_file_type = """ 30 | Upload a CSV or EXCEL file in the 'Data' tab under 'Library' to proceed with the Discovery module. 31 | The file should contain a column for protein sequences, a column with the protein names, and a column for 32 | annotations, which can also be empty or partially populated with annotations. 33 | """ 34 | 35 | discovery_tooltips = """ 36 | The Discovery Module offers a structured approach to identifying proteins even with little to no 37 | experimental data to start with. The goal of the module is to identify proteins with similar 38 | functions and to propose novel sequences that are likely to have similar functions. 39 | The module relies on representations generated by large protein language models that transform protein 40 | sequences into meaningful vector representations. It has been shown that these vector representations often 41 | cluster based on function. Clustering should be used if all, very few, or no sequences have annotations. 42 | Classification should be used if some or all sequences are annotated. To find out if you have enough 43 | sequences for classification, we recommend using the model statistics on the validation set, which are 44 | automatically generated by the module after training. 45 | """ 46 | 47 | mlde_file_type = """ 48 | Upload a CSV or EXCEL file in the 'Data' tab under 'Library' to proceed with the MLDE module. 49 | The file should contain a column for protein sequences, a column with the protein names (e.g., 50 | mutant descriptions M15V), and a column for experimental values (e.g., enzyme activity, 51 | fluorescence, etc.). 52 | """ 53 | 54 | mlde_tooltips = """ 55 | The Machine Learning Guided Directed Evolution (MLDE) module offers a structured approach to 56 | improve protein function through iterative mutagenesis, inspired by Directed Evolution. 57 | Here, machine learning models are trained on previously generated experimental results. The 58 | 'Search' algorithm will then propose novel sequences that will be evaluated and ranked by the 59 | trained model to predict mutants that are likely to improve function. The Bayesian optimization 60 | algorithms used to search for novel mutants are based on models trained on protein representations 61 | that can either be generated from large language models, which is currently very slow, or from 62 | classical algorithms such as BLOSUM62. For now, we recommend the use of BLOSUM62 representations 63 | combined with Random Forest models for the best trade-off between speed and quality. However, we encourage 64 | experimentation with other models and representations. 65 | """ 66 | 67 | design_file_type = """ 68 | Upload a PDB file containing the structure of the protein 69 | to use the (structure-based) Design module. 70 | """ 71 | 72 | design_tooltips = """ 73 | The Protein Design module is a structure-based approach to predict novel sequences using 'Inverse Folding' 74 | algorithms. The designed sequences are likely to preserve the fold of the protein while improving 75 | the thermal stability and solubility of proteins. To preserve important functions of the protein, we recommend 76 | the preservation of protein-protein, ligand-ligand interfaces, and evolutionarily conserved sites, which can be 77 | entered manually. The temperature factor influences the diversity of designs. We recommend the generation of at 78 | least 1,000 sequences and rigorous filtering before ordering variants for validation. To give an example: Sort 79 | the sequences from the lowest to highest score, predict the structure of the lowest-scoring variants, and proceed 80 | with the designs that preserve the geometry of the active site (in the case of an enzyme). Experiment with a small 81 | sample size to explore temperature values that yield desired levels of diversification before generating large 82 | numbers of sequences. 83 | """ 84 | 85 | representations_tooltips = """ 86 | The Representations module offers methods to compute and visualize vector representations of proteins. These are primarily 87 | used by the MLDE and Discovery modules to make training more data-efficient. The representations are generated from 88 | classical algorithms such as BLOSUM62 or large protein language models that infuse helpful inductive biases into protein 89 | sequence representations. In some cases, the representations can be used to cluster proteins based on function or to 90 | predict protein properties. The module offers several visualization techniques to explore the representations and to 91 | understand the underlying structure of the protein data. Advanced analysis and predictions can be made by using the 92 | MLDE or Discovery modules in combination with the Representations module. 93 | """ 94 | 95 | zs_entropy_tooltips = """ 96 | This plot shows the entropy values across the protein sequence, providing insights into the diversity tolerated at each position. 97 | The higher the entropy, the great the variety of amino acids tolerated at the position. 98 | """ 99 | 100 | zs_heatmap_tooltips = """ 101 | This heatmap visualizes the computed zero-shot scores for the protein. The scores at each position are normalised to the score 102 | of the original amino acid, which is set to zero (white) and highlighted by a black box. A positive score (blue) indicates that mutating the position to that amino acid could 103 | have beneficial effects, while a negative score (red) indicates the mutation would not be favourable. 104 | """ 105 | -------------------------------------------------------------------------------- /demo/MLDE_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append("src/") 5 | import proteusAI as pai 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import json 9 | import argparse 10 | import pandas as pd 11 | 12 | # Initialize the argparse parser 13 | parser = argparse.ArgumentParser(description="Benchmarking ProteusAI MLDE") 14 | 15 | # Add arguments corresponding to the variables 16 | parser.add_argument( 17 | "--user", type=str, default="benchmark", help="User name or identifier." 18 | ) 19 | parser.add_argument( 20 | "--sample-sizes", 21 | type=int, 22 | nargs="+", 23 | default=[5, 10, 20, 100], 24 | help="List of sample sizes.", 25 | ) 26 | parser.add_argument("--model", type=str, default="gp", help="Model name.") 27 | parser.add_argument("--rep", type=str, default="esm2", help="Representation type.") 28 | parser.add_argument( 29 | "--zs-model", type=str, default="esm1v", help="Zero-shot model name." 30 | ) 31 | parser.add_argument( 32 | "--benchmark-folder", 33 | type=str, 34 | default="demo/demo_data/DMS/", 35 | help="Path to the benchmark folder.", 36 | ) 37 | parser.add_argument("--seed", type=int, default=42, help="Random seed.") 38 | parser.add_argument( 39 | "--max-iter", 40 | type=int, 41 | default=100, 42 | help="Maximum number of iterations (None for unlimited).", 43 | ) 44 | parser.add_argument( 45 | "--device", 46 | type=str, 47 | default="cuda", 48 | help="Device to run the model on (e.g., cuda, cpu).", 49 | ) 50 | parser.add_argument( 51 | "--batch-size", type=int, default=1, help="Batch size for processing." 52 | ) 53 | parser.add_argument( 54 | "--improvement", 55 | type=str, 56 | nargs="+", 57 | default=[5, 10, 20, 50, "improved"], 58 | help="List of improvements.", 59 | ) 60 | parser.add_argument( 61 | "--acquisition_fn", type=str, default="ei", help="ProteusAI acquisition functions" 62 | ) 63 | parser.add_argument( 64 | "--k_folds", type=int, default=5, help="K-fold cross validation for sk_learn models" 65 | ) 66 | 67 | 68 | def benchmark(dataset, fasta, model, embedding, name, sample_size, results_df): 69 | iteration = 1 70 | 71 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num 72 | lib = pai.Library( 73 | user=USER, 74 | source=dataset, 75 | seqs_col="mutated_sequence", 76 | y_col="DMS_score", 77 | y_type="num", 78 | names_col="mutant", 79 | ) 80 | 81 | # compute representations for this dataset 82 | lib.compute(method=REP, batch_size=BATCH_SIZE, device=DEVICE) 83 | 84 | # plot destination 85 | # plot_dest = os.path.join(lib.user, name, embedding, 'plots', str(sample_size)) 86 | # os.makedirs(plot_dest, exist_ok=True) 87 | 88 | # wt sequence 89 | protein = pai.Protein(user=USER, source=fasta) 90 | 91 | # zero-shot scores 92 | out = protein.zs_prediction(model=ZS_MODEL, batch_size=BATCH_SIZE, device=DEVICE) 93 | zs_lib = pai.Library(user=USER, source=out) 94 | 95 | # Simulate selection of top N ZS-predictions for the initial librarys 96 | zs_prots = [prot for prot in zs_lib.proteins if prot.name in lib.names] 97 | sorted_zs_prots = sorted(zs_prots, key=lambda prot: prot.y_pred, reverse=True) 98 | 99 | # train on the ZS-selected initial library (assume that they have been assayed now) train with 80:10:10 split 100 | # if the sample size is to small the initial batch has to be large enough to give meaningful statistics. Here 15 101 | if sample_size <= 15: 102 | top_N_zs_names = [prot.name for prot in sorted_zs_prots[:15]] 103 | zs_selected = [prot for prot in lib.proteins if prot.name in top_N_zs_names] 104 | n_train = 11 105 | n_test = 2 106 | _sample_size = 15 107 | else: 108 | top_N_zs_names = [prot.name for prot in sorted_zs_prots[:sample_size]] 109 | zs_selected = [prot for prot in lib.proteins if prot.name in top_N_zs_names] 110 | n_train = int(sample_size * 0.8) 111 | n_test = int(sample_size * 0.1) 112 | _sample_size = sample_size 113 | 114 | # train model 115 | model = pai.Model(model_type=model) 116 | model.train( 117 | library=lib, 118 | x=REP, 119 | split={ 120 | "train": zs_selected[:n_train], 121 | "test": zs_selected[n_train : n_train + n_test], 122 | "val": zs_selected[n_train + n_test : _sample_size], 123 | }, 124 | seed=SEED, 125 | model_type=MODEL, 126 | k_folds=K_FOLD, 127 | ) 128 | 129 | # add to results df 130 | results_df = add_to_data( 131 | data=results_df, 132 | proteins=zs_selected, 133 | iteration=iteration, 134 | sample_size=_sample_size, 135 | dataset=name, 136 | model=model, 137 | ) 138 | 139 | # use the model to make predictions on the remaining search space 140 | search_space = [prot for prot in lib.proteins if prot.name not in top_N_zs_names] 141 | ranked_search_space, _, _, _, _ = model.predict(search_space, acq_fn=ACQ_FN) 142 | 143 | # Prepare the tracking of top N variants, 144 | top_variants_counts = IMPROVEMENT 145 | found_counts = {count: 0 for count in top_variants_counts} 146 | first_discovered = [None] * len(top_variants_counts) 147 | 148 | # Identify the top variants by score 149 | actual_top_variants = sorted(lib.proteins, key=lambda prot: prot.y, reverse=True) 150 | actual_top_variants = { 151 | count: [prot.name for prot in actual_top_variants[:count]] 152 | for count in top_variants_counts 153 | if isinstance(count, int) 154 | } 155 | 156 | # count variants that are improved over wt, 1 standard deviation above wt 157 | y_std = np.std(lib.y, ddof=1) 158 | y_mean = np.mean(lib.y) 159 | actual_top_variants["improved"] = [ 160 | prot.name for prot in lib.proteins if prot.y > y_mean + 1 * y_std 161 | ] 162 | 163 | # add sequences to the new dataset, and continue the loop until the dataset is exhausted 164 | sampled_data = zs_selected 165 | 166 | while len(ranked_search_space) >= sample_size: 167 | # Check if we have found all top variants, including the 1st zero-shot round 168 | sampled_names = [prot.name for prot in sampled_data] # noqa: F841 169 | for c, count in enumerate(found_counts): 170 | found = len( 171 | [ 172 | prot.name 173 | for prot in sampled_data 174 | if prot.name in actual_top_variants[count] 175 | ] 176 | ) 177 | found_counts[count] = found 178 | if found > 0 and first_discovered[c] is None: 179 | first_discovered[c] = iteration 180 | 181 | # Break the loop if all elements in first_discovered are not None 182 | if all(value is not None for value in first_discovered): 183 | break 184 | 185 | # Break if maximum number of iterations have been reached 186 | if iteration == MAX_ITER: 187 | break 188 | 189 | iteration += 1 190 | 191 | # select variants that have now been 'assayed' 192 | sample = ranked_search_space[:sample_size] 193 | 194 | # Remove the selected top N elements from the ranked search space 195 | ranked_search_space = ranked_search_space[sample_size:] 196 | 197 | # add new 'assayed' sample to sampled data 198 | sampled_data += sample 199 | 200 | # split into train, test and val 201 | n_train = int(len(sampled_data) * 0.8) 202 | n_test = int(len(sampled_data) * 0.1) 203 | 204 | # handle the very low data results 205 | if n_test == 0: 206 | n_test = 1 207 | n_train = n_train - n_test 208 | 209 | split = { 210 | "train": sampled_data[:n_train], 211 | "test": sampled_data[n_train : n_train + n_test], 212 | "val": sampled_data[n_train + n_test :], 213 | } 214 | 215 | # train model on new data 216 | model.train( 217 | library=lib, x=REP, split=split, seed=SEED, model_type=MODEL, k_folds=K_FOLD 218 | ) 219 | 220 | # add to results 221 | results_df = add_to_data( 222 | data=results_df, 223 | proteins=sample, 224 | iteration=iteration, 225 | sample_size=sample_size, 226 | dataset=name, 227 | model=model, 228 | ) 229 | 230 | # re-score the new search space 231 | ( 232 | ranked_search_space, 233 | sorted_y_pred, 234 | sorted_sigma_pred, 235 | y_val, 236 | sorted_acq_score, 237 | ) = model.predict(ranked_search_space, acq_fn=ACQ_FN) 238 | 239 | # save when the first datapoints for each dataset and category have been discvered 240 | first_discovered_data[name][sample_size] = first_discovered 241 | 242 | return found_counts, results_df 243 | 244 | 245 | def plot_results(found_counts, name, iter, dest, sample_size): 246 | counts = list(found_counts.keys()) 247 | found = [found_counts[count] for count in counts] 248 | x_positions = range(len(counts)) # Create fixed distance positions for the x-axis 249 | 250 | plt.figure(figsize=(10, 6)) 251 | plt.bar(x_positions, found, color="skyblue") 252 | plt.xlabel("Top N Variants") 253 | plt.ylabel("Number of Variants Found") 254 | plt.title(f"{name}: Number of Top N Variants Found After {iter} Iterations") 255 | plt.xticks(x_positions, counts) # Set custom x-axis positions and labels 256 | plt.ylim(0, max(found) + 1) 257 | 258 | for i, count in enumerate(found): 259 | plt.text(x_positions[i], count + 0.1, str(count), ha="center") 260 | 261 | plt.savefig(os.path.join(dest, f"top_variants_{iter}_iterations_{name}.png")) 262 | 263 | 264 | def add_to_data(data: pd.DataFrame, proteins, iteration, sample_size, dataset, model): 265 | """Add sampling results to dataframe""" 266 | names = [prot.name for prot in proteins] 267 | ys = [prot.y for prot in proteins] 268 | y_preds = [prot.y_pred for prot in proteins] 269 | y_sigmas = [prot.y_sigma for prot in proteins] 270 | models = [MODEL] * len(names) 271 | reps = [REP] * len(names) 272 | acq_fns = [ACQ_FN] * len(names) 273 | rounds = [iteration] * len(names) 274 | datasets = [dataset] * len(names) 275 | sample_sizes = [sample_size] * len(names) 276 | test_r2s = [model.test_r2] * len(names) 277 | val_r2s = [model.val_r2] * len(names) 278 | 279 | new_data = pd.DataFrame( 280 | { 281 | "name": names, 282 | "y": ys, 283 | "y_pred": y_preds, 284 | "y_sigma": y_sigmas, 285 | "test_r2": test_r2s, 286 | "val_r2": val_r2s, 287 | "model": models, 288 | "rep": reps, 289 | "acq_fn": acq_fns, 290 | "round": rounds, 291 | "sample_size": sample_sizes, 292 | "dataset": datasets, 293 | } 294 | ) 295 | 296 | # Append the new data to the existing DataFrame 297 | updated_data = pd.concat([data, new_data], ignore_index=True) 298 | return updated_data 299 | 300 | 301 | # Parse the arguments 302 | args = parser.parse_args() 303 | 304 | # Assign parsed arguments to capitalized variable names 305 | USER = args.user 306 | SAMPLE_SIZES = args.sample_sizes 307 | MODEL = args.model 308 | REP = args.rep 309 | ZS_MODEL = args.zs_model 310 | BENCHMARK_FOLDER = args.benchmark_folder 311 | SEED = args.seed 312 | MAX_ITER = args.max_iter 313 | DEVICE = args.device 314 | BATCH_SIZE = args.batch_size 315 | IMPROVEMENT = args.improvement 316 | ACQ_FN = args.acquisition_fn 317 | K_FOLD = args.k_folds 318 | 319 | # benchmark data 320 | datasets = [f for f in os.listdir(BENCHMARK_FOLDER) if f.endswith(".csv")] 321 | fastas = [f for f in os.listdir(BENCHMARK_FOLDER) if f.endswith(".fasta")] 322 | datasets.sort() 323 | fastas.sort() 324 | 325 | # save sampled data 326 | results_df = pd.DataFrame( 327 | { 328 | "name": [], 329 | "y": [], 330 | "y_pred": [], 331 | "y_sigma": [], 332 | "test_r2": [], 333 | "val_r2": [], 334 | "model": [], 335 | "rep": [], 336 | "acq_fn": [], 337 | "round": [], 338 | "sample_size": [], 339 | "dataset": [], 340 | } 341 | ) 342 | 343 | first_discovered_data = {} 344 | for i in range(len(datasets)): 345 | for N in SAMPLE_SIZES: 346 | print( 347 | f"RUNNING model:{MODEL}, rep:{REP}, acq:{ACQ_FN}, sample_size:{N}", 348 | flush=True, 349 | ) 350 | d = os.path.join(BENCHMARK_FOLDER, datasets[i]) 351 | f = os.path.join(BENCHMARK_FOLDER, fastas[i]) 352 | name = datasets[i][:-4] 353 | 354 | if N == SAMPLE_SIZES[0]: 355 | first_discovered_data[name] = {N: []} 356 | else: 357 | first_discovered_data[name][N] = [] 358 | 359 | found_counts, results_df = benchmark( 360 | d, 361 | f, 362 | model=MODEL, 363 | embedding=REP, 364 | name=name, 365 | sample_size=N, 366 | results_df=results_df, 367 | ) 368 | # save first discovered data 369 | with open( 370 | os.path.join( 371 | "usrs/benchmark/", f"first_discovered_data_{MODEL}_{REP}_{ACQ_FN}.json" 372 | ), 373 | "w", 374 | ) as file: 375 | json.dump(first_discovered_data, file) 376 | 377 | results_df.to_csv( 378 | os.path.join("usrs/benchmark/", f"results_df_{MODEL}_{REP}_{ACQ_FN}.csv"), 379 | index=False, 380 | ) 381 | -------------------------------------------------------------------------------- /demo/clustering_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import matplotlib.pyplot as plt 4 | 5 | import proteusAI as pai 6 | 7 | print(os.getcwd()) 8 | sys.path.append("src/") 9 | 10 | 11 | # will initiate storage space - else in memory 12 | dataset = "demo/demo_data/methyltransfereases.csv" 13 | y_column = "class" 14 | 15 | results_dictionary = {} 16 | 17 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num 18 | library = pai.Library( 19 | source=dataset, 20 | seqs_col="sequence", 21 | y_col=y_column, 22 | y_type="class", 23 | names_col="uid", 24 | ) 25 | 26 | # compute and save ESM-2 representations at example_lib/representations/esm2 27 | library.compute(method="esm2_8M", batch_size=10) 28 | 29 | # dimensionality reduction 30 | dr_method = "tsne" 31 | 32 | # random seed 33 | seed = 42 34 | 35 | # define a model 36 | model = pai.Model( 37 | library=library, 38 | k_folds=5, 39 | model_type="hdbscan", 40 | rep="esm2_8M", 41 | min_cluster_size=30, 42 | min_samples=50, 43 | dr_method=dr_method, 44 | seed=seed, 45 | ) 46 | 47 | # train model 48 | model.train() 49 | 50 | # search predict the classes of unknown sequences 51 | out = model.search() 52 | search_mask = out["mask"] 53 | 54 | # save results 55 | if not os.path.exists("demo/demo_data/out/"): 56 | os.makedirs("demo/demo_data/out/", exist_ok=True) 57 | 58 | out["df"].to_csv("demo/demo_data/out/clustering_search_results.csv") 59 | 60 | model_lib = pai.Library(source=out) 61 | 62 | # plot results 63 | fig, ax, plot_df = model.library.plot( 64 | rep="esm2_8M", 65 | method=dr_method, 66 | use_y_pred=True, 67 | highlight_mask=search_mask, 68 | seed=seed, 69 | ) 70 | plt.savefig(f"demo/demo_data/out/clustering_results_{dr_method}.png") 71 | -------------------------------------------------------------------------------- /demo/demo_data/Nitric_Oxide_Dioxygenase_wt.fasta: -------------------------------------------------------------------------------- 1 | >wt|NOD 2 | MAPTLSEQTRQLVRASVPALQKHSVAISATMYRLLFERYPETRSLCELPERQIHKIASALLAYARSIDNPSALQAAIRRMVLSHARAGVQAVHYPLYWECLRDAIKEVLGPDATETLLQAWKEAYDFLAHLLSTKEAQVYAVLAE -------------------------------------------------------------------------------- /demo/demo_data/binding_data/dataset_creation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | # Load the dataframes 4 | df1 = pd.read_csv("demo/demo_data/binding_data/replicate_1.csv") 5 | df2 = pd.read_csv("demo/demo_data/binding_data/replicate_2.csv") 6 | df3 = pd.read_csv("demo/demo_data/binding_data/replicate_3.csv") 7 | 8 | # Merge the dataframes on CDR1H_AA and CDR3H_AA to find common rows 9 | merged_df = pd.merge(df1, df2, on=["CDR1H_AA", "CDR3H_AA"], suffixes=("_df1", "_df2")) 10 | merged_df = pd.merge(merged_df, df3, on=["CDR1H_AA", "CDR3H_AA"]) 11 | 12 | # Calculate the average of fit_KD from the three dataframes 13 | merged_df["fit_KD_avg"] = merged_df[["fit_KD_df1", "fit_KD_df2", "fit_KD"]].mean(axis=1) 14 | 15 | # Select the relevant columns for the final dataframe 16 | result_df = merged_df[["CDR1H_AA", "CDR3H_AA", "fit_KD_avg"]] 17 | 18 | # Group by CDR1H_AA and CDR3H_AA to remove duplicates and average the fit_KD_avg 19 | final_df = ( 20 | result_df.groupby(["CDR1H_AA", "CDR3H_AA"]) 21 | .agg({"fit_KD_avg": "mean"}) 22 | .reset_index() 23 | ) 24 | 25 | # Display the final result 26 | final_df 27 | 28 | wt_seq = "EVKLDETGGGLVQPGRPMKLSCVASGFTFSDYWMNWVRQSPEKGLEWVAQIRNKPYNYETYYSDSVKGRFTISRDDSKSSVYLQMNNLRVEDMGIYYCTGSYYGMDYWGQGTSVTVSSAKTTAPSVYPLAPVCGDTTGSSVTLGCLVKGYFPEPVTLTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVTSSTWPSQSITCNVAHPASSTKVDKKIEPRG" 29 | 30 | # CDRH1 and CDR3H wild-type sequences 31 | cdr1h_wt = wt_seq[27:37] # TFSDYWMNWV 32 | cdr3h_wt = wt_seq[99:109] # GSYYGMDYWG 33 | 34 | seqs = [] 35 | mutant_names = [] 36 | for i, row in final_df.iterrows(): 37 | seq = list(wt_seq) 38 | CDR1H_AA = row["CDR1H_AA"] 39 | CDR3H_AA = row["CDR3H_AA"] 40 | seq[27:37] = list(CDR1H_AA) 41 | seq[99:109] = list(CDR3H_AA) 42 | mutated_seq = "".join(seq) 43 | seqs.append(mutated_seq) 44 | 45 | # Determine the mutations for naming 46 | mutations = [] 47 | 48 | # Compare CDR1H sequence 49 | for j in range(len(cdr1h_wt)): 50 | if CDR1H_AA[j] != cdr1h_wt[j]: 51 | mutations.append(f"{cdr1h_wt[j]}{27+j+1}{CDR1H_AA[j]}") 52 | 53 | # Compare CDR3H sequence 54 | for j in range(len(cdr3h_wt)): 55 | if CDR3H_AA[j] != cdr3h_wt[j]: 56 | mutations.append(f"{cdr3h_wt[j]}{99+j+1}{CDR3H_AA[j]}") 57 | 58 | # Combine mutations into a mutant name 59 | if mutations: 60 | mutant_name = ":".join(mutations) 61 | else: 62 | mutant_name = "WT" 63 | 64 | mutant_names.append(mutant_name) 65 | 66 | data = { 67 | "mutant": mutant_names, 68 | "mutated_sequence": seqs, 69 | "DMS_score": final_df["fit_KD_avg"].to_list(), 70 | "DMS_score_bin": [None] * len(mutant_names), 71 | } 72 | 73 | results_df = pd.DataFrame(data) 74 | results_df.to_csv("demo/demo_data/DMS/SCFV_HUMAN_Adams_2016_affinity.csv", index=False) 75 | 76 | with open("demo/demo_data/DMS/SCFV_HUMAN_Adams_2016_affinity.fasta", "w") as f: 77 | f.write(">SCFV_HUMAN_Adams_2016_affinity\n") 78 | f.write(f"{wt_seq}") 79 | -------------------------------------------------------------------------------- /demo/design_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import proteusAI as pai 4 | 5 | files = ["demo/demo_data/GB1.pdb"] # input files 6 | temps = [1.0] # sampling temperatures 7 | fixed = {"GB1": [1, 2, 3]} # fixed residues 8 | num_samples = 100 # number of samples 9 | 10 | for f in files: 11 | protein = pai.Protein(source=f) 12 | 13 | fname = f.split("/")[-1][:-4] 14 | 15 | for temp in temps: 16 | for key, fix in fixed.items(): 17 | out = protein.esm_if( 18 | num_samples=num_samples, target_chain="A", temperature=temp, fixed=fix 19 | ) 20 | 21 | # save results 22 | if not os.path.exists("demo/demo_data/out/"): 23 | os.makedirs("demo/demo_data/out/", exist_ok=True) 24 | 25 | out["df"].to_csv( 26 | f"demo/demo_data/out_{fname}_temp_{temp}_{key}_out.csv", index=False 27 | ) 28 | 29 | 30 | ### UNCOMMENT THESE LINES TO FOLD DESIGNS IF YOU HAVE ESM-FOLD INSTALLED ### 31 | 32 | # Designs can be folded to check the structure and confidence of the designs 33 | # import shutil 34 | # design_library = pai.Library(source=out) 35 | 36 | # this will fold the designs 37 | # fold_out = design_library.fold() 38 | 39 | # save the folded designs 40 | # fold_out["df"].to_csv( 41 | # f"demo/demo_data/out{fname}_temp_{temp}_{key}_folded.csv", index=False 42 | # ) 43 | 44 | # save the structures 45 | # os.makedirs(f"demo/demo_data/out/{fname}_temp_{temp}_{key}_pdb/", exist_ok=True) 46 | 47 | # Move all files from source to destination 48 | # for file in os.listdir(out["struc_path"]): 49 | # shutil.move( 50 | # os.path.join(out["struc_path"], file), 51 | # f"demo/demo_data/out/{fname}_temp_{temp}_{key}_pdb/", 52 | # ) 53 | -------------------------------------------------------------------------------- /demo/discovery_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import matplotlib.pyplot as plt 4 | 5 | import proteusAI as pai 6 | 7 | print(os.getcwd()) 8 | sys.path.append("src/") 9 | 10 | 11 | # will initiate storage space - else in memory 12 | dataset = "demo/demo_data/methyltransfereases.csv" 13 | y_column = "coverage_5" 14 | 15 | results_dictionary = {} 16 | 17 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num 18 | library = pai.Library( 19 | source=dataset, 20 | seqs_col="sequence", 21 | y_col=y_column, 22 | y_type="class", 23 | names_col="uid", 24 | ) 25 | 26 | # compute and save ESM-2 representations at example_lib/representations/esm2 27 | library.compute(method="esm2_8M", batch_size=10) 28 | 29 | # define a model 30 | model = pai.Model(library=library, k_folds=5, model_type="rf", rep="esm2_8M") 31 | 32 | # train model 33 | model.train() 34 | 35 | # search predict the classes of unknown sequences 36 | out = model.search() 37 | search_mask = out["mask"] 38 | 39 | # save results 40 | if not os.path.exists("demo/demo_data/out/"): 41 | os.makedirs("demo/demo_data/out/", exist_ok=True) 42 | 43 | out["df"].to_csv("demo/demo_data/out/discovery_search_results.csv") 44 | 45 | model_lib = pai.Library(source=out) 46 | 47 | # plot results 48 | fig, ax, plot_df = model.library.plot( 49 | rep="esm2_8M", use_y_pred=True, highlight_mask=search_mask 50 | ) 51 | plt.savefig("demo/demo_data/out/search_results.png") 52 | -------------------------------------------------------------------------------- /demo/fold_design_demo.py: -------------------------------------------------------------------------------- 1 | ### THIS DEMO ONLY WORKS IF ESM-FOLD IS INSTALLED ### 2 | 3 | import os 4 | import shutil 5 | 6 | import proteusAI as pai 7 | 8 | files = ["demo/demo_data/Nitric_Oxide_Dioxygenase_wt.fasta"] # input files 9 | 10 | for f in files: 11 | protein = pai.Protein(source=f) 12 | 13 | # fold the protein 14 | fold_out = protein.esm_fold(relax=True) 15 | 16 | # design the now folded protein 17 | design_out = protein.esm_if( 18 | num_samples=100, target_chain="A", temperature=1.0, fixed=[] 19 | ) 20 | 21 | # move results to demo folder 22 | if not os.path.exists("demo/demo_data/out/"): 23 | os.makedirs("demo/demo_data/out/", exist_ok=True) 24 | 25 | # save the design 26 | design_out["df"].to_csv("demo/demo_data/fold_design_out.csv", index=False) 27 | 28 | # move the structures 29 | os.makedirs("demo/demo_data/out/fold_design_pdb/", exist_ok=True) 30 | 31 | # Copy all files from the source to the destination directory 32 | for file in os.listdir(fold_out["struc_path"]): 33 | src = os.path.join(fold_out["struc_path"], file) 34 | dst = os.path.join("demo/demo_data/out/fold_design_pdb/", file) 35 | shutil.copy(src, dst) 36 | -------------------------------------------------------------------------------- /demo/mlde_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import proteusAI as pai 5 | 6 | print(os.getcwd()) 7 | sys.path.append("src/") 8 | 9 | 10 | # will initiate storage space - else in memory 11 | dataset = "demo/demo_data/Nitric_Oxide_Dioxygenase.csv" 12 | 13 | # load data from csv or excel: x should be sequences, y should be labels, y_type class or num 14 | library = pai.Library( 15 | source=dataset, 16 | seqs_col="Sequence", 17 | y_col="Data", 18 | y_type="num", 19 | names_col="Description", 20 | ) 21 | 22 | # compute and save ESM-2 representations at example_lib/representations/esm2 23 | library.compute(method="esm2_8M", batch_size=10) 24 | 25 | # define a model 26 | model = pai.Model(library=library, k_folds=5, model_type="rf", x="vhse") 27 | 28 | # train model 29 | model.train() 30 | 31 | # search for new mutants 32 | out = model.search(optim_problem="max") 33 | 34 | # save results 35 | if not os.path.exists("demo/demo_data/out/"): 36 | os.makedirs("demo/demo_data/out/", exist_ok=True) 37 | 38 | out.to_csv("demo/demo_data/out/demo_search_results.csv") 39 | -------------------------------------------------------------------------------- /demo/plot_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import seaborn as sns 7 | 8 | # Initialize the argparse parser 9 | parser = argparse.ArgumentParser(description="Plot the benchmark results.") 10 | 11 | # Add arguments 12 | parser.add_argument("--rep", type=str, default="blosum62", help="Representation type.") 13 | parser.add_argument("--max-sample", type=int, default=100, help="Maximum sample size.") 14 | parser.add_argument("--model", type=str, default="gp", help="Model name.") 15 | parser.add_argument( 16 | "--acquisition-fn", type=str, default="ei", help="Acquisition function name." 17 | ) 18 | 19 | # Parse the arguments 20 | args = parser.parse_args() 21 | 22 | # Assign parsed arguments to variables 23 | REP = args.rep 24 | MAX_SAMPLE = args.max_sample 25 | MODEL = args.model 26 | ACQ_FN = args.acquisition_fn 27 | 28 | # Dictionaries for pretty names 29 | rep_dict = { 30 | "ohe": "One-hot", 31 | "blosum50": "BLOSUM50", 32 | "blosum62": "BLOSUM62", 33 | "esm2": "ESM-2", 34 | "esm1v": "ESM-1v", 35 | "vae": "VAE", 36 | } 37 | model_dict = { 38 | "rf": "Random Forest", 39 | "knn": "KNN", 40 | "svm": "SVM", 41 | "esm2": "ESM-2", 42 | "esm1v": "ESM-1v", 43 | "gp": "Gaussian Process", 44 | "ridge": "Ridge Regression", 45 | } 46 | acq_dict = { 47 | "ei": "Expected Improvement", 48 | "ucb": "Upper Confidence Bound", 49 | "greedy": "Greedy", 50 | "random": "Random", 51 | } 52 | 53 | # Load data from JSON file 54 | with open(f"usrs/benchmark/first_discovered_data_{MODEL}_{REP}_{ACQ_FN}.json") as f: 55 | data = json.load(f) 56 | 57 | # Define the top N variants and sample sizes 58 | datasets = list(data.keys()) 59 | sample_sizes = [int(i) for i in list(data[datasets[0]].keys())] 60 | top_n_variants = ["5", "10", "20", "50", "improved"] 61 | 62 | # Initialize a list to store the processed data 63 | plot_data = [] 64 | 65 | # Process the data 66 | for dataset, results in data.items(): 67 | for sample_size, rounds in results.items(): 68 | if rounds is not None: 69 | for i, round_count in enumerate(rounds): 70 | if round_count is not None: 71 | plot_data.append( 72 | { 73 | "Dataset": dataset, 74 | "Sample Size": int(sample_size), 75 | "Top N Variants": top_n_variants[i], 76 | "Rounds": round_count, 77 | } 78 | ) 79 | 80 | # Create a DataFrame from the processed data 81 | df = pd.DataFrame(plot_data) 82 | df["Top N Variants"] = pd.Categorical( 83 | df["Top N Variants"], categories=top_n_variants, ordered=True 84 | ) 85 | 86 | # Set plot style 87 | sns.set(style="whitegrid") 88 | 89 | # Create a grid of box plots for each sample size with shared y-axis 90 | g = sns.FacetGrid( 91 | df, col="Sample Size", col_wrap=2, height=5.5, aspect=1.8, sharey=True 92 | ) # Adjusted height and aspect 93 | 94 | 95 | # Function to create box plot and strip plot 96 | def plot_box_and_strip(data, x, y, **kwargs): 97 | ax = plt.gca() 98 | sns.boxplot(data=data, x=x, y=y, color="lightgray", width=0.5, ax=ax) # noqa: F841 99 | sns.stripplot(data=data, x=x, y=y, **kwargs) 100 | 101 | # Add median values as text 102 | medians = data.groupby(x)[y].median() 103 | for i, median in enumerate(medians): 104 | ax.text( 105 | i, median + 1, f"{median:.2f}", ha="center", va="bottom", color="black" 106 | ) # Adjusting the position to be just above the median line 107 | 108 | # Set y-axis limit to MAX_SAMPLE 109 | ax.set_ylim(0, MAX_SAMPLE) 110 | 111 | 112 | # Apply the function to each subplot 113 | g.map_dataframe( 114 | plot_box_and_strip, 115 | x="Top N Variants", 116 | y="Rounds", 117 | hue="Dataset", 118 | dodge=True, 119 | jitter=True, 120 | palette="muted", 121 | ) 122 | 123 | # Draw vertical bars between the sampling sizes 124 | for ax in g.axes.flatten(): 125 | for x in range(len(top_n_variants) - 1): 126 | ax.axvline(x + 0.5, color="grey", linestyle="-", linewidth=0.8) 127 | 128 | # Adjust the title and layout 129 | g.set_titles(col_template="Sample Size: {col_name}") 130 | g.set_axis_labels("Top N Variants", "Rounds") 131 | 132 | rep = rep_dict[REP] 133 | model = model_dict[MODEL] 134 | acq_fn = acq_dict[ACQ_FN] 135 | subtitle = f"Representation: {rep}, Model: {model}, Acquisition Function: {acq_fn}" 136 | 137 | # Move legend to the right side 138 | g.add_legend( 139 | title="Dataset", bbox_to_anchor=(0.87, 0.47), loc="center left", borderaxespad=0 140 | ) 141 | 142 | # Set the main title and subtitle 143 | g.fig.suptitle( 144 | "Rounds to Discover Top N Variants Across Different Sample Sizes", y=0.94, x=0.44 145 | ) 146 | plt.text( 147 | 0.44, 148 | 0.9, 149 | subtitle, 150 | ha="center", 151 | va="center", 152 | fontsize=10, 153 | transform=g.fig.transFigure, 154 | ) 155 | 156 | # Adjust the layout to optimize spacing 157 | plt.tight_layout(pad=2.0) # Adjusted padding 158 | g.fig.subplots_adjust(top=0.85, right=0.85, hspace=0.3, wspace=0.3) # Adjusted spacing 159 | 160 | # Save the plot 161 | plt.savefig( 162 | f"usrs/benchmark/first_discovered_{MODEL}_{REP}_{ACQ_FN}.png", 163 | bbox_inches="tight", 164 | dpi=300, 165 | ) 166 | 167 | # Show the plot 168 | plt.show() 169 | -------------------------------------------------------------------------------- /demo/plot_library_demo.py: -------------------------------------------------------------------------------- 1 | import proteusAI as pai 2 | import matplotlib.pyplot as plt 3 | 4 | # Initialize the library 5 | library = pai.Library( 6 | source="demo/demo_data/Nitric_Oxide_Dioxygenase.csv", 7 | seqs_col="Sequence", 8 | y_col="Data", 9 | y_type="num", 10 | names_col="Description", # Column containing the sequence descriptions 11 | ) 12 | print(library.seqs) 13 | # Compute embeddings using the specified method 14 | library.compute(method="esm2_8M") 15 | 16 | # Generate a UMAP plot 17 | fig, ax, df = library.plot_umap(rep="esm2_8M") 18 | 19 | # Customize and display the plot 20 | plt.show() 21 | -------------------------------------------------------------------------------- /demo/test: -------------------------------------------------------------------------------- 1 | >sampled_seq_1 recovery: 0.5272727272727272 2 | ELYTIWWDGTDLRGSETVTAKDAATAEKTFKAAAAENGVTGTWADDTQKTFTVTG 3 | >sampled_seq_2 recovery: 0.4727272727272727 4 | MIYTLKWQGGTLTGEETVTAADSSTAEAQFKSSAKNAGYSGQWTNQFKKTYTVTG 5 | >sampled_seq_3 recovery: 0.45454545454545453 6 | STYTLVKAGAKQKGQSTVEADDANDAEAAFRRGAADKGISGQWEDATGKTFTVTS 7 | >sampled_seq_4 recovery: 0.41818181818181815 8 | SEYTLEWDGALDTGVAEVEAANAKAATKVFASAAKDNGLKGKWTDRIDKTFTAKG 9 | >sampled_seq_5 recovery: 0.36363636363636365 10 | GTYTLLWDGDTVTGERTVSAASETSAEDKFNRAATEKGITGDWGDRHANTFTATG 11 | >sampled_seq_6 recovery: 0.4727272727272727 12 | GTYTLIWDAATLKGTSTVKASDIAAAQKIFNNMAKDLGVTGTWTDRKDRTYTVTG 13 | >sampled_seq_7 recovery: 0.43636363636363634 14 | GTYTLIWAGDKVHGTSTIAAADQATATALFKKGAKEAGMDGEWDDKASNTYTVVG 15 | >sampled_seq_8 recovery: 0.5636363636363636 16 | GQYTLKLSGGDVTGTTTVSADDKATAEKLFESAAKENGLQGEWNDFATKTFTVTG 17 | >sampled_seq_9 recovery: 0.4 18 | MTYTLVWAGEKAKGTQTVKAKNSAAATADFKESAAANGLSGKWSNHKTKTYTVSG 19 | >sampled_seq_10 recovery: 0.4727272727272727 20 | GIYTIKLDGKTYNGTATVTAADAATAEAKFESLAEQNGMTGAWTDATARTYTVTA 21 | -------------------------------------------------------------------------------- /demo/zero_shot_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import proteusAI as pai 4 | 5 | files = ["demo/demo_data/Nitric_Oxide_Dioxygenase_wt.fasta"] # input files 6 | models = [ 7 | "esm2_8M" 8 | ] # Ideally use esm1v. ems2_8M is fast and used for demo purposes. esm1v is more accurate. 9 | 10 | for f in files: 11 | protein = pai.Protein(source=f) 12 | 13 | for model in models: 14 | out = protein.zs_prediction(model=model) 15 | 16 | # save results 17 | if not os.path.exists("demo/demo_data/out/"): 18 | os.makedirs("demo/demo_data/out/", exist_ok=True) 19 | 20 | out["df"].to_csv( 21 | f"demo/demo_data/zs_outdemo_{f.split('/')[-1][:-6]}.csv", index=False 22 | ) 23 | 24 | print(out["df"]) 25 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | jupyter_execute 3 | reference -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Docs creation 2 | 3 | In order to build the docs you need to 4 | 5 | 1. install sphinx and additional support packages 6 | 2. build the package reference files 7 | 3. run sphinx to create a local html version 8 | 9 | The documentation is build using readthedocs automatically. 10 | 11 | Install the docs dependencies of the package (as speciefied in toml): 12 | 13 | ```bash 14 | # in main folder 15 | pip install ".[docs]" 16 | ``` 17 | 18 | ## Build docs using Sphinx command line tools 19 | 20 | Command to be run from `path/to/docs`, i.e. from within the `docs` package folder: 21 | 22 | Options: 23 | - `--separate` to build separate pages for each (sub-)module 24 | 25 | ```bash 26 | # pwd: docs 27 | # apidoc 28 | sphinx-apidoc --force --implicit-namespaces --module-first -o reference ../src/proteusAI 29 | # build docs 30 | sphinx-build -n -W --keep-going -b html ./ ./_build/ 31 | ``` 32 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | from importlib import metadata 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | project = "proteusAI" 20 | copyright = "2024, Jonathan Funk" 21 | author = "Jonathan Funk" 22 | PACKAGE_VERSION = metadata.version("proteusAI") 23 | version = PACKAGE_VERSION 24 | release = PACKAGE_VERSION 25 | 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | "sphinx.ext.autodoc", 34 | "sphinx.ext.autodoc.typehints", 35 | "sphinx.ext.viewcode", 36 | "sphinx.ext.napoleon", 37 | "sphinx.ext.intersphinx", 38 | "sphinx_new_tab_link", 39 | "myst_nb", 40 | ] 41 | 42 | # https://myst-nb.readthedocs.io/en/latest/computation/execute.html 43 | nb_execution_mode = "auto" 44 | 45 | myst_enable_extensions = ["dollarmath", "amsmath"] 46 | 47 | # Plolty support through require javascript library 48 | # https://myst-nb.readthedocs.io/en/latest/render/interactive.html#plotly 49 | html_js_files = [ 50 | "https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js" 51 | ] 52 | 53 | # https://myst-nb.readthedocs.io/en/latest/configuration.html 54 | # Execution 55 | nb_execution_raise_on_error = True 56 | # Rendering 57 | nb_merge_streams = True 58 | 59 | # https://myst-nb.readthedocs.io/en/latest/authoring/custom-formats.html#write-custom-formats 60 | nb_custom_formats = {".py": ["jupytext.reads", {"fmt": "py:percent"}]} 61 | 62 | # Add any paths that contain templates here, relative to this directory. 63 | templates_path = ["_templates"] 64 | 65 | # List of patterns, relative to source directory, that match files and 66 | # directories to ignore when looking for source files. 67 | # This pattern also affects html_static_path and html_extra_path. 68 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "jupyter_execute", "conf.py"] 69 | 70 | 71 | # Intersphinx options 72 | intersphinx_mapping = { 73 | "python": ("https://docs.python.org/3", None), 74 | # "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 75 | # "scikit-learn": ("https://scikit-learn.org/stable/", None), 76 | # "matplotlib": ("https://matplotlib.org/stable/", None), 77 | } 78 | 79 | # -- Options for HTML output ------------------------------------------------- 80 | 81 | # The theme to use for HTML and HTML Help pages. See the documentation for 82 | # a list of builtin themes. 83 | # See: 84 | # https://github.com/executablebooks/MyST-NB/blob/master/docs/conf.py 85 | # html_title = "" 86 | html_theme = "sphinx_book_theme" 87 | # html_logo = "_static/logo-wide.svg" 88 | # html_favicon = "_static/logo-square.svg" 89 | html_theme_options = { 90 | "github_url": "https://github.com/jonfunk21/proteusAI", 91 | "repository_url": "https://github.com/jonfunk21/proteusAI", 92 | "repository_branch": "main", 93 | "home_page_in_toc": True, 94 | "path_to_docs": "docs", 95 | "show_navbar_depth": 1, 96 | "use_edit_page_button": True, 97 | "use_repository_button": True, 98 | "use_download_button": True, 99 | "launch_buttons": { 100 | "colab_url": "https://colab.research.google.com" 101 | # "binderhub_url": "https://mybinder.org", 102 | # "notebook_interface": "jupyterlab", 103 | }, 104 | "navigation_with_keys": False, 105 | } 106 | 107 | # Add any paths that contain custom static files (such as style sheets) here, 108 | # relative to this directory. They are copied after the builtin static files, 109 | # so a file named "default.css" will overwrite the builtin "default.css". 110 | # html_static_path = ["_static"] 111 | 112 | 113 | # -- Setup for sphinx-apidoc ------------------------------------------------- 114 | 115 | # Read the Docs doesn't support running arbitrary commands like tox. 116 | # sphinx-apidoc needs to be called manually if Sphinx is running there. 117 | # https://github.com/readthedocs/readthedocs.org/issues/1139 118 | 119 | if os.environ.get("READTHEDOCS") == "True": 120 | from pathlib import Path 121 | 122 | PROJECT_ROOT = Path(__file__).parent.parent 123 | PACKAGE_ROOT = PROJECT_ROOT / "src" / "proteusAI" 124 | 125 | def run_apidoc(_): 126 | from sphinx.ext import apidoc 127 | 128 | apidoc.main( 129 | [ 130 | "--force", 131 | "--implicit-namespaces", 132 | "--module-first", 133 | "--separate", 134 | "-o", 135 | str(PROJECT_ROOT / "docs" / "reference"), 136 | str(PACKAGE_ROOT), 137 | str(PACKAGE_ROOT / "*.c"), 138 | str(PACKAGE_ROOT / "*.so"), 139 | ] 140 | ) 141 | 142 | def setup(app): 143 | app.connect("builder-inited", run_apidoc) 144 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | The proteusAI package 2 | ===================== 3 | 4 | .. include:: ../README.md 5 | :parser: myst_parser.sphinx_ 6 | :start-line: 1 7 | 8 | 9 | .. toctree:: 10 | :hidden: 11 | :maxdepth: 2 12 | :caption: Tutorial 13 | 14 | tutorial/tutorial 15 | 16 | .. toctree:: 17 | :hidden: 18 | :maxdepth: 2 19 | :caption: Contents: 20 | 21 | reference/modules 22 | 23 | .. toctree:: 24 | :hidden: 25 | :caption: Technical notes 26 | 27 | README.md 28 | 29 | 30 | 31 | Indices and tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | * :ref:`search` 37 | -------------------------------------------------------------------------------- /docs/tutorial/tutorial.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Tutorial 3 | 4 | # %% 5 | import proteusAI as pai 6 | 7 | pai.__version__ 8 | 9 | # # MLDE Tutorial 10 | # %% [markdown] 11 | # ## Loading data from csv or excel 12 | # The data must contain the mutant sequences and y_values for the MLDE workflow. It is recommended to have useful sequence names for later interpretability, and to define the data type. 13 | 14 | # %% 15 | library = pai.Library( 16 | source="demo/demo_data/Nitric_Oxide_Dioxygenase_raw.csv", 17 | seqs_col="Sequence", 18 | y_col="Data", 19 | y_type="num", 20 | names_col="Description", 21 | ) 22 | 23 | 24 | # %% [markdown] 25 | # ## Compute representations (skipped here to save time) 26 | # The available representations are ('esm2', 'esm1v', 'blosum62', 'blosum50', and 'ohe') 27 | 28 | # %% 29 | # library.compute(method='esm2', batch_size=10) # uncomment this line to compute esm2 representations 30 | 31 | 32 | # %% [markdown] 33 | # ## Define a model, using fast to compute BLOSUM62 representations (good for demo purposes and surprisingly competetive with esm2 representations). 34 | 35 | # %% 36 | model = pai.Model(library=library) 37 | 38 | # %% [markdown] 39 | # ## Training the model 40 | 41 | # %% 42 | _ = model.train(k_folds=5, model_type="rf", x="blosum62") 43 | 44 | 45 | # %% [markdown] 46 | # ## Search for new mutants 47 | # Searching new mutants will produce an output dataframe containing the new predictions. Here we are using the expected improvement ('ei') acquisition function. 48 | 49 | # %% 50 | out = model.search(acq_fn="ei") 51 | out 52 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: proteusAI 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - anaconda 8 | dependencies: 9 | - python=3.8 10 | # - pytorch-scatter 11 | - pytorch=2.4.1 12 | # - pyg 13 | # - cpuonly 14 | - pip 15 | - sphinx 16 | - pdbfixer 17 | - pip: 18 | - -e .[docs] -f https://data.pyg.org/whl/torch-2.4.1+cpu.html 19 | -------------------------------------------------------------------------------- /favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/favicon.ico -------------------------------------------------------------------------------- /proteusEnvironment.yml: -------------------------------------------------------------------------------- 1 | name: proteusAI_depl 2 | channels: 3 | - anaconda 4 | - pyg 5 | - pytorch 6 | - nvidia 7 | - conda-forge 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_kmp_llvm 11 | - aiohttp=3.9.3=py38h01eb140_1 12 | - aiosignal=1.3.1=pyhd8ed1ab_0 13 | - anyio=4.4.0=pyhd8ed1ab_0 14 | - appdirs=1.4.4=pyh9f0ad1d_0 15 | - asgiref=3.8.1=pyhd8ed1ab_0 16 | - asttokens=2.4.1=pyhd8ed1ab_0 17 | - async-timeout=4.0.3=pyhd8ed1ab_0 18 | - attrs=23.2.0=pyh71513ae_0 19 | - backcall=0.2.0=pyh9f0ad1d_0 20 | - blas=2.116=mkl 21 | - blas-devel=3.9.0=16_linux64_mkl 22 | - brotli-python=1.1.0=py38h17151c0_1 23 | - bzip2=1.0.8=hd590300_5 24 | - ca-certificates=2024.7.4=hbcca054_0 25 | - certifi=2024.7.4=pyhd8ed1ab_0 26 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 27 | - click=8.1.7=unix_pyh707e725_0 28 | - colorama=0.4.6=pyhd8ed1ab_0 29 | - comm=0.2.2=pyhd8ed1ab_0 30 | - cuda-cudart=12.1.105=0 31 | - cuda-cupti=12.1.105=0 32 | - cuda-libraries=12.1.0=0 33 | - cuda-nvrtc=12.1.105=0 34 | - cuda-nvtx=12.1.105=0 35 | - cuda-opencl=12.4.127=0 36 | - cuda-runtime=12.1.0=0 37 | - cudatoolkit=11.8.0=h4ba93d1_13 38 | - debugpy=1.8.1=py38h17151c0_0 39 | - decorator=5.1.1=pyhd8ed1ab_0 40 | - exceptiongroup=1.2.0=pyhd8ed1ab_2 41 | - executing=2.0.1=pyhd8ed1ab_0 42 | - ffmpeg=4.3=hf484d3e_0 43 | - filelock=3.13.3=pyhd8ed1ab_0 44 | - freetype=2.12.1=h267a509_2 45 | - frozenlist=1.4.1=py38h01eb140_0 46 | - fsspec=2024.3.1=pyhca7485f_0 47 | - gmp=6.3.0=h59595ed_1 48 | - gmpy2=2.1.2=py38h793c122_1 49 | - gnutls=3.6.13=h85f3911_1 50 | - gpytorch=1.12=pyhd8ed1ab_0 51 | - h11=0.14.0=pyhd8ed1ab_0 52 | - htmltools=0.5.2=pyhd8ed1ab_0 53 | - icu=73.2=h59595ed_0 54 | - idna=3.6=pyhd8ed1ab_0 55 | - importlib-metadata=7.1.0=pyha770c72_0 56 | - importlib_metadata=7.1.0=hd8ed1ab_0 57 | - ipykernel=6.29.3=pyhd33586a_0 58 | - ipython=8.12.2=pyh41d4057_0 59 | - jaxtyping=0.2.19=pyhd8ed1ab_0 60 | - jedi=0.19.1=pyhd8ed1ab_0 61 | - jinja2=3.1.3=pyhd8ed1ab_0 62 | - joblib=1.3.2=pyhd8ed1ab_0 63 | - jpeg=9e=h166bdaf_2 64 | - jupyter_client=8.6.2=pyhd8ed1ab_0 65 | - jupyter_core=5.7.2=py38h578d9bd_0 66 | - keyutils=1.6.1=h166bdaf_0 67 | - krb5=1.21.2=h659d440_0 68 | - lame=3.100=h166bdaf_1003 69 | - lcms2=2.15=hfd0df8a_0 70 | - ld_impl_linux-64=2.40=h41732ed_0 71 | - lerc=4.0.0=h27087fc_0 72 | - libblas=3.9.0=16_linux64_mkl 73 | - libcblas=3.9.0=16_linux64_mkl 74 | - libcublas=12.1.0.26=0 75 | - libcufft=11.0.2.4=0 76 | - libcufile=1.9.1.3=0 77 | - libcurand=10.3.5.147=0 78 | - libcusolver=11.4.4.55=0 79 | - libcusparse=12.0.2.55=0 80 | - libdeflate=1.17=h0b41bf4_0 81 | - libedit=3.1.20191231=he28a2e2_2 82 | - libffi=3.4.2=h7f98852_5 83 | - libgcc-ng=13.2.0=h807b86a_5 84 | - libgfortran-ng=13.2.0=h69a702a_5 85 | - libgfortran5=13.2.0=ha4646dd_5 86 | - libgomp=13.2.0=h807b86a_5 87 | - libhwloc=2.9.3=default_h554bfaf_1009 88 | - libiconv=1.17=hd590300_2 89 | - libjpeg-turbo=2.0.0=h9bf148f_0 90 | - liblapack=3.9.0=16_linux64_mkl 91 | - liblapacke=3.9.0=16_linux64_mkl 92 | - libllvm14=14.0.6=hcd5def8_4 93 | - libnpp=12.0.2.50=0 94 | - libnsl=2.0.1=hd590300_0 95 | - libnvjitlink=12.1.105=0 96 | - libnvjpeg=12.1.1.14=0 97 | - libpng=1.6.43=h2797004_0 98 | - libsodium=1.0.18=h36c2ea0_1 99 | - libsqlite=3.45.2=h2797004_0 100 | - libstdcxx-ng=13.2.0=h7e041cc_5 101 | - libtiff=4.5.0=h6adf6a1_2 102 | - libuuid=2.38.1=h0b41bf4_0 103 | - libwebp-base=1.3.2=hd590300_0 104 | - libxcb=1.13=h7f98852_1004 105 | - libxcrypt=4.4.36=hd590300_1 106 | - libxml2=2.12.6=h232c23b_1 107 | - libzlib=1.2.13=hd590300_5 108 | - linear_operator=0.5.2=pyhd8ed1ab_0 109 | - linkify-it-py=2.0.3=pyhd8ed1ab_0 110 | - llvm-openmp=15.0.7=h0cdce71_0 111 | - llvmlite=0.41.1=py38h94a1851_0 112 | - markdown-it-py=3.0.0=pyhd8ed1ab_0 113 | - markupsafe=2.1.5=py38h01eb140_0 114 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 115 | - mdit-py-plugins=0.4.1=pyhd8ed1ab_0 116 | - mdurl=0.1.2=pyhd8ed1ab_0 117 | - mkl=2022.1.0=h84fe81f_915 118 | - mkl-devel=2022.1.0=ha770c72_916 119 | - mkl-include=2022.1.0=h84fe81f_915 120 | - mpc=1.3.1=hfe3b2da_0 121 | - mpfr=4.2.1=h9458935_1 122 | - mpmath=1.3.0=pyhd8ed1ab_0 123 | - multidict=6.0.5=py38h01eb140_0 124 | - ncurses=6.4.20240210=h59595ed_0 125 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 126 | - nettle=3.6=he412f7d_0 127 | - networkx=3.1=pyhd8ed1ab_0 128 | - numba=0.58.1=py38h4144172_0 129 | - numpy=1.23.5=py38h7042d01_0 130 | - ocl-icd=2.3.2=hd590300_1 131 | - ocl-icd-system=1.0.0=1 132 | - openh264=2.1.1=h780b84a_0 133 | - openjpeg=2.5.0=hfec8fc6_2 134 | - openmm=8.1.1=py38h10fed7f_1 135 | - openssl=3.3.1=h4ab18f5_1 136 | - packaging=24.0=pyhd8ed1ab_0 137 | - pandas=2.0.3=py38h01efb38_1 138 | - parso=0.8.4=pyhd8ed1ab_0 139 | - pdbfixer=1.9=pyh1a96a4e_0 140 | - pexpect=4.9.0=pyhd8ed1ab_0 141 | - pickleshare=0.7.5=py_1003 142 | - pillow=9.4.0=py38hde6dc18_1 143 | - pip=24.0=pyhd8ed1ab_0 144 | - platformdirs=4.2.0=pyhd8ed1ab_0 145 | - pooch=1.8.1=pyhd8ed1ab_0 146 | - prompt-toolkit=3.0.42=pyha770c72_0 147 | - prompt_toolkit=3.0.42=hd8ed1ab_0 148 | - psutil=5.9.8=py38h01eb140_0 149 | - pthread-stubs=0.4=h36c2ea0_1001 150 | - ptyprocess=0.7.0=pyhd3deb0d_0 151 | - pure_eval=0.2.2=pyhd8ed1ab_0 152 | - pyg=2.5.2=py38_torch_2.2.0_cu121 153 | - pygments=2.18.0=pyhd8ed1ab_0 154 | - pynndescent=0.5.13=pyhff2d567_0 155 | - pyparsing=3.1.2=pyhd8ed1ab_0 156 | - pysocks=1.7.1=pyha2e5f31_6 157 | - python=3.8.19=hd12c33a_0_cpython 158 | - python-dateutil=2.9.0=pyhd8ed1ab_0 159 | - python-multipart=0.0.9=pyhd8ed1ab_0 160 | - python-tzdata=2024.1=pyhd8ed1ab_0 161 | - python_abi=3.8=4_cp38 162 | - pytorch=2.2.2=py3.8_cuda12.1_cudnn8.9.2_0 163 | - pytorch-cuda=12.1=ha16c6d3_5 164 | - pytorch-mutex=1.0=cuda 165 | - pytz=2024.1=pyhd8ed1ab_0 166 | - pyyaml=6.0.1=py38h01eb140_1 167 | - pyzmq=26.0.3=py38ha44f8e3_0 168 | - questionary=2.0.1=pyhd8ed1ab_0 169 | - readline=8.2=h8228510_1 170 | - requests=2.31.0=pyhd8ed1ab_0 171 | - scikit-learn=1.3.2=py38ha25d942_2 172 | - scipy=1.10.1=py38h59b608b_3 173 | - setuptools=69.2.0=pyhd8ed1ab_0 174 | - shiny=0.10.2=pyhd8ed1ab_0 175 | - six=1.16.0=pyh6c4a22f_0 176 | - sniffio=1.3.1=pyhd8ed1ab_0 177 | - stack_data=0.6.2=pyhd8ed1ab_0 178 | - starlette=0.37.2=pyhd8ed1ab_0 179 | - sympy=1.12=pypyh9d50eac_103 180 | - tbb=2021.11.0=h00ab1b0_1 181 | - threadpoolctl=3.4.0=pyhc1e730c_0 182 | - tk=8.6.13=noxft_h4845f30_101 183 | - torchaudio=2.2.2=py38_cu121 184 | - torchtriton=2.2.0=py38 185 | - torchvision=0.17.2=py38_cu121 186 | - tornado=6.4=py38h01eb140_0 187 | - tqdm=4.66.2=pyhd8ed1ab_0 188 | - traitlets=5.14.3=pyhd8ed1ab_0 189 | - typeguard=2.13.3=pyhd8ed1ab_0 190 | - typing-extensions=4.11.0=hd8ed1ab_0 191 | - typing_extensions=4.11.0=pyha770c72_0 192 | - uc-micro-py=1.0.3=pyhd8ed1ab_0 193 | - umap-learn=0.5.4=py38h06a4308_0 194 | - urllib3=2.2.1=pyhd8ed1ab_0 195 | - uvicorn=0.30.1=py38h578d9bd_0 196 | - watchfiles=0.22.0=py38h31a4407_0 197 | - wcwidth=0.2.13=pyhd8ed1ab_0 198 | - websockets=12.0=py38h01eb140_0 199 | - wheel=0.43.0=pyhd8ed1ab_1 200 | - xorg-libxau=1.0.11=hd590300_0 201 | - xorg-libxdmcp=1.1.3=h7f98852_0 202 | - xz=5.2.6=h166bdaf_0 203 | - yaml=0.2.5=h7f98852_2 204 | - yarl=1.9.4=py38h01eb140_0 205 | - zeromq=4.3.5=h75354e8_4 206 | - zipp=3.17.0=pyhd8ed1ab_0 207 | - zlib=1.2.13=hd590300_5 208 | - zstd=1.5.5=hfc55251_0 209 | - biotite=0.35.0 210 | - seaborn=0.13.2 211 | - pyg::pytorch-scatter=2.1.2 212 | - py3dmol=2.3.0 213 | - pip: 214 | - fair-esm=2.0.0 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "proteusAI" 7 | version = "0.1.1" 8 | requires-python = ">= 3.8" 9 | description = "ProteusAI is a python package designed for AI driven protein engineering." 10 | readme = "README.md" 11 | 12 | dependencies = [ 13 | "torch==2.4.1", 14 | "torch_geometric", 15 | "torch-scatter", 16 | "uvicorn", 17 | "asgiref", 18 | "starlette", 19 | #"pdbfixer", 20 | "shiny", 21 | "pandas", 22 | "numpy", 23 | "requests", 24 | "scipy", 25 | "fair-esm", 26 | "matplotlib", 27 | "biopython", 28 | "biotite", 29 | "scikit-learn", 30 | "optuna", 31 | "seaborn", 32 | "plotly", 33 | "openpyxl", 34 | "py3Dmol", 35 | "gpytorch", 36 | "openmm", 37 | "umap-learn", 38 | "hdbscan", 39 | "proteusAI", 40 | #"pdbfixer @ git+https://github.com/openmm/pdbfixer@1.9" 41 | ] 42 | 43 | [project.optional-dependencies] 44 | docs = [ 45 | "sphinx", 46 | "sphinx-book-theme", 47 | "myst-nb", 48 | "ipywidgets", 49 | "sphinx-new-tab-link!=0.2.2", 50 | "jupytext", 51 | ] 52 | dev = ["black", "ruff", "pytest", "flake8", "flake8-import-order", "flake8-builtins", "flake8-bugbear"] 53 | 54 | # Handle imports with flake8 55 | [tool.flake8] 56 | max-line-length = 88 57 | import-order-style = "google" 58 | application-import-names = ["proteusAI"] 59 | 60 | -------------------------------------------------------------------------------- /run_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the possible combinations of models and embeddings 4 | representations=("esm2" "blosum62" "ohe") # "blosum50" 5 | acq_fns=("ei" "greedy" "ucb") # "random" 6 | models=($1) # "gp" "rf" "ridge" "svm" "knn" 7 | 8 | for acq_fn in "${acq_fns[@]}"; do 9 | for rep in "${representations[@]}"; do 10 | for model in "${models[@]}"; do 11 | # Execute the script with the current combination of model and embedding 12 | python demo/MLDE_benchmark.py --model "$model" --rep "$rep" --acquisition_fn "$acq_fn" --max-iter 100 13 | python demo/plot_benchmark.py --rep "$rep" --max-sample 100 --model "$model" --acquisition-fn "$acq_fn" 14 | done 15 | done 16 | done -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | author="Jonathan Funk", 5 | author_email="funk.jonathan21@gmail.com", 6 | url="https://github.com/jonfunk21/ProteusAI", 7 | packages=find_packages("src"), 8 | package_dir={"": "src"}, 9 | classifiers=[ 10 | "Development Status :: 3 - Alpha", 11 | "Intended Audience :: Developers", 12 | "License :: OSI Approved :: MIT License", 13 | "Programming Language :: Python :: 3", 14 | "Programming Language :: Python :: 3.6", 15 | "Programming Language :: Python :: 3.7", 16 | "Programming Language :: Python :: 3.8", 17 | "Programming Language :: Python :: 3.9", 18 | ], 19 | python_requires=">=3.8", 20 | ) 21 | -------------------------------------------------------------------------------- /src/proteusAI/Library/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage is concerned with the Library object of ProteusAI. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from .library import Library # noqa: F401 12 | -------------------------------------------------------------------------------- /src/proteusAI/Model/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage is concerned with the Model object of ProteusAI. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from .model import Model # noqa: F401 12 | -------------------------------------------------------------------------------- /src/proteusAI/Protein/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage is concerned with the Protein object of ProteusAI. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from .protein import Protein # noqa: F401 12 | -------------------------------------------------------------------------------- /src/proteusAI/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | 5 | from importlib import metadata 6 | 7 | __version__ = metadata.version("proteusAI") 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from .Protein import * # noqa: F403 12 | from .Library import * # noqa: F403 13 | from .Model import * # noqa: F403 14 | -------------------------------------------------------------------------------- /src/proteusAI/data_tools/MSA.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import biotite.sequence.graphics as graphics 8 | import biotite.application.muscle as muscle 9 | import matplotlib.pyplot as plt 10 | from Bio.SeqRecord import SeqRecord 11 | from Bio.Seq import Seq 12 | from Bio.Align.Applications import ClustalwCommandline 13 | from Bio import AlignIO 14 | from Bio import SeqIO 15 | from collections import Counter 16 | from biotite.sequence import ProteinSequence 17 | 18 | 19 | def align_proteins( 20 | names: list, 21 | seqs: list, 22 | plot_results: bool = False, 23 | plt_range: tuple = (0, 200), 24 | muscle_version: str = "5", 25 | save_fig: str = None, 26 | save_fasta: str = None, 27 | figsize: tuple = (10.0, 8.0), 28 | ): 29 | """ 30 | performs multiple sequence alignement given a list of blast names and the corresponding sequences. 31 | 32 | Parameters: 33 | names (list): list of sequence names 34 | seqs (list): list of sequences for MSA 35 | plot_results (bool, optional): plot results. 36 | plt_range (tuple, optional): range of sequence which is plotted. Default first 200 amino acids. 37 | muscle_version (str, optional): which muscle version is installed on your machine. Default 5. 38 | save_fig (str): saves fig if path is provided. Default None - won't save the figure. 39 | save_fasta (str): saves fasta if path is provided. Default None - won't save fasta. 40 | figsize (tuple): dimensions of the figure. 41 | 42 | Returns: 43 | dict: MSA results of sequence names/ids and gapped sequences 44 | """ 45 | 46 | # Convert sequences to ProteinSequence objects if they are strings 47 | seqs = [ProteinSequence(seq) if isinstance(seq, str) else seq for seq in seqs] 48 | 49 | if muscle_version == "5": 50 | app = muscle.Muscle5App(seqs) 51 | elif muscle_version == "3": 52 | app = muscle.MuscleApp(seqs) 53 | else: 54 | raise ValueError("Muscle version must be either 3 or 5") 55 | 56 | app.start() 57 | app.join() 58 | alignment = app.get_alignment() 59 | 60 | # Print the MSA with hit IDs 61 | gapped_seqs = alignment.get_gapped_sequences() 62 | 63 | MSA_results = {} 64 | for i in range(len(gapped_seqs)): 65 | MSA_results[names[i]] = gapped_seqs[i] 66 | 67 | # Reorder alignments to reflect sequence distance 68 | fig = plt.figure(figsize=figsize) 69 | ax = fig.add_subplot(111) 70 | order = app.get_alignment_order() 71 | graphics.plot_alignment_type_based( 72 | ax, 73 | alignment[plt_range[0] : plt_range[1], order.tolist()], 74 | labels=[names[i] for i in order], 75 | show_numbers=True, 76 | color_scheme="clustalx", 77 | ) 78 | fig.tight_layout() 79 | 80 | if save_fig is not None: 81 | plt.savefig(save_fig) 82 | 83 | if plot_results: 84 | plt.show() 85 | else: 86 | plt.close() 87 | 88 | if save_fasta is not None: 89 | with open(save_fasta, "w") as f: 90 | for i, key in enumerate(MSA_results.keys()): 91 | s = MSA_results[key] 92 | if i < len(MSA_results) - 1: 93 | f.writelines(f">{key}\n") 94 | f.writelines(f"{s}\n") 95 | else: 96 | f.writelines(f">{key}\n") 97 | f.writelines(f"{s}") 98 | 99 | return MSA_results 100 | 101 | 102 | def MSA_results_to_fasta(MSA_results: dict, fname: str): 103 | """ 104 | Takes MSA results from the align proteins function and writes then into a fasta format. 105 | 106 | Parameters: 107 | MSA_results (dict): Dictionary of MSA results. 108 | fname (str): file name. 109 | 110 | Returns: 111 | None 112 | """ 113 | with open(fname, "w") as f: 114 | for i, key in enumerate(MSA_results.keys()): 115 | s = MSA_results[key] 116 | if i < len(MSA_results) - 1: 117 | f.writelines(f">{key}\n") 118 | f.writelines(f"{s}\n") 119 | else: 120 | f.writelines(f">{key}\n") 121 | f.writelines(f"{s}") 122 | 123 | 124 | def align_dna(dna_sequences: list, verbose: bool = False): 125 | """ 126 | performs multiple sequence alignement given a list of blast names and the corresponding sequences. 127 | The function uses the clustalw2 app which uses the global needleman-wunsch algorithm. 128 | 129 | Parameters: 130 | seqs (list): list of DNA sequences for MSA 131 | verbose (bool): print std out and std error if True. Default False 132 | 133 | Returns: 134 | list: MSA list of sequences 135 | """ 136 | # Create SeqRecord objects for each DNA sequence 137 | seq_records = [ 138 | SeqRecord(Seq(seq), id=f"seq{i + 1}") for i, seq in enumerate(dna_sequences) 139 | ] 140 | 141 | # Write the DNA sequences to a temporary file in FASTA format 142 | temp_input_file = "temp_input.fasta" 143 | with open(temp_input_file, "w") as handle: 144 | SeqIO.write(seq_records, handle, "fasta") 145 | 146 | # Run ClustalW to perform the multiple sequence alignment 147 | temp_output_file = "temp_output.aln" 148 | clustalw_cline = ClustalwCommandline( 149 | "clustalw2", infile=temp_input_file, outfile=temp_output_file 150 | ) 151 | stdout, stderr = clustalw_cline() 152 | 153 | if verbose: 154 | print("std out:") 155 | print("-------") 156 | print(stdout) 157 | print("std error:") 158 | print("----------") 159 | print(stderr) 160 | 161 | # Read in the aligned sequences from the output file 162 | aligned_seqs = [] 163 | with open(temp_output_file) as handle: 164 | alignment = AlignIO.read(handle, "clustal") 165 | for record in alignment: 166 | aligned_seqs.append(str(record.seq)) 167 | 168 | return aligned_seqs 169 | 170 | 171 | def get_consensus_sequence(dna_sequences: list): 172 | """ 173 | Calculates the consensus sequence of multiple sequence alignements. 174 | It uses the most common character of every sequence in a list. All 175 | sequences need to be the same length. 176 | 177 | Parameters: 178 | dna_sequences (list): list of DNA sequences. 179 | 180 | Returns: 181 | str: consensus sequence. 182 | """ 183 | # Get the length of the sequences 184 | sequence_length = len(dna_sequences[0]) 185 | 186 | # Iterate over each position in the sequences 187 | consensus_sequence = "" 188 | for i in range(sequence_length): 189 | # Create a list of the characters at this position 190 | char_list = [seq[i] for seq in dna_sequences] 191 | 192 | # Count the occurrences of each character 193 | char_counts = Counter(char_list) 194 | 195 | # Get the most common character 196 | most_common_char = char_counts.most_common(1)[0][0] 197 | 198 | # Add the most common character to the consensus sequence 199 | consensus_sequence += most_common_char 200 | 201 | return consensus_sequence 202 | -------------------------------------------------------------------------------- /src/proteusAI/data_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for data_tools engineering, visualization and analysis. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.data_tools.pdb import * # noqa: F403 12 | from proteusAI.data_tools.MSA import * # noqa: F403 13 | -------------------------------------------------------------------------------- /src/proteusAI/data_tools/pdb.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | from biotite.structure.io.mol import MOLFile 8 | from biotite.structure.io.pdb import PDBFile 9 | import numpy as np 10 | import biotite.structure as struc 11 | import biotite.structure.io as strucio 12 | from Bio import SeqIO 13 | import py3Dmol 14 | from string import ascii_uppercase, ascii_lowercase 15 | 16 | 17 | alphabet_list = list(ascii_uppercase + ascii_lowercase) 18 | pymol_color_list = [ 19 | "#33ff33", 20 | "#00ffff", 21 | "#ff33cc", 22 | "#ffff00", 23 | "#ff9999", 24 | "#e5e5e5", 25 | "#7f7fff", 26 | "#ff7f00", 27 | "#7fff7f", 28 | "#199999", 29 | "#ff007f", 30 | "#ffdd5e", 31 | "#8c3f99", 32 | "#b2b2b2", 33 | "#007fff", 34 | "#c4b200", 35 | "#8cb266", 36 | "#00bfbf", 37 | "#b27f7f", 38 | "#fcd1a5", 39 | "#ff7f7f", 40 | "#ffbfdd", 41 | "#7fffff", 42 | "#ffff7f", 43 | "#00ff7f", 44 | "#337fcc", 45 | "#d8337f", 46 | "#bfff3f", 47 | "#ff7fff", 48 | "#d8d8ff", 49 | "#3fffbf", 50 | "#b78c4c", 51 | "#339933", 52 | "#66b2b2", 53 | "#ba8c84", 54 | "#84bf00", 55 | "#b24c66", 56 | "#7f7f7f", 57 | "#3f3fa5", 58 | "#a5512b", 59 | ] 60 | 61 | 62 | def get_atom_array(file_path): 63 | """ 64 | Returns atom array for a pdb file. 65 | 66 | Parameters: 67 | file_path (str): path to pdb file 68 | 69 | Returns: 70 | atom array 71 | """ 72 | if file_path.endswith("pdb"): 73 | atom_mol = PDBFile.read(file_path) 74 | atom_array = atom_mol.get_structure() 75 | else: 76 | try: 77 | atom_mol = MOLFile.read(file_path) 78 | atom_array = atom_mol.get_structure() 79 | except Exception as e: 80 | raise ValueError(f"File: {file_path} has an invalid format. Error: {e}") 81 | 82 | return atom_array 83 | 84 | 85 | def mol_contacts(mols, protein, dist=4.0): 86 | """ 87 | Get residue ids of contacts of small molecule(s) to protein. 88 | 89 | Parameters: 90 | mols (str, list of strings): path to molecule file 91 | protein (str): path to protein file 92 | dist (float): distance to be considered contact 93 | 94 | Returns: 95 | set: res_ids of protein residues which are in contact with molecules 96 | """ 97 | if isinstance(mols, list) or isinstance(mols, tuple): 98 | mols = [get_atom_array(m) for m in mols] 99 | else: 100 | mols = [get_atom_array(mols)] 101 | 102 | protein = get_atom_array(protein) 103 | 104 | res_ids = set() 105 | for mol in mols: 106 | cell_list = struc.CellList(mol, cell_size=dist) 107 | for prot in protein: 108 | contacts = cell_list.get_atoms(prot.coord, radius=dist) 109 | 110 | contact_indices = np.where((contacts != -1).any(axis=1))[0] 111 | 112 | contact_res_ids = prot.res_id[contact_indices] 113 | res_ids.update(contact_res_ids) 114 | 115 | res_ids = sorted(res_ids) 116 | return res_ids 117 | 118 | 119 | def pdb_to_fasta(pdb_path: str): 120 | """ 121 | Returns fasta sequence of pdb file. 122 | 123 | Parameters: 124 | pdb_path (str): path to pdb file 125 | 126 | Returns: 127 | str: fasta sequence string 128 | """ 129 | seq = "" 130 | with open(pdb_path, "r") as pdb_file: 131 | for record in SeqIO.parse(pdb_file, "pdb-atom"): 132 | seq = "".join([">", str(record.id), "\n", str(record.seq)]) 133 | 134 | return seq 135 | 136 | 137 | def get_sse(pdb_path: str): 138 | """ 139 | Returns the secondary structure of a protein given the pdb file. 140 | The secondary structure is infered using the P-SEA algorithm 141 | as imple-mented by the biotite Python package. 142 | 143 | Parameters: 144 | pdb_path (str): path to pdb 145 | """ 146 | array = strucio.load_structure(pdb_path) 147 | sse = struc.annotate_sse(array) 148 | sse = "".join(sse) 149 | return sse 150 | 151 | 152 | ### Visualization 153 | def show_pdb( 154 | pdb_path, 155 | color="confidence", 156 | vmin=50, 157 | vmax=90, 158 | chains=None, 159 | Ls=None, 160 | size=(800, 480), 161 | show_sidechains=False, 162 | show_mainchains=False, 163 | highlight=None, 164 | ): 165 | """ 166 | This function displays the 3D structure of a protein from a given PDB file in a Jupyter notebook. 167 | The protein structure can be colored by chain, rainbow, pLDDT, or confidence value. The size of the 168 | display can be changed. The sidechains and mainchains can be displayed or hidden. 169 | Parameters: 170 | pdb_path (str): The filename of the PDB file that contains the protein structure. 171 | color (str, optional): The color scheme for the protein structure. Can be "chain", "rainbow", "pLDDT", or "confidence". Defaults to "rainbow". 172 | vmin (float, optional): The minimum value of pLDDT or confidence value. Defaults to 50. 173 | vmax (float, optional): The maximum value of pLDDT or confidence value. Defaults to 90. 174 | chains (int, optional): The number of chains to be displayed. Defaults to None. 175 | Ls (list, optional): A list of the chains to be displayed. Defaults to None. 176 | size (tuple, optional): The size of the display window. Defaults to (800, 480). 177 | show_sidechains (bool, optional): Whether to display the sidechains. Defaults to False. 178 | show_mainchains (bool, optional): Whether to display the mainchains. Defaults to False. 179 | Returns: 180 | view: The 3Dmol view object that displays the protein structure. 181 | """ 182 | 183 | with open(pdb_path) as ifile: 184 | system = "".join([x for x in ifile]) 185 | 186 | view = py3Dmol.view( 187 | js="https://3dmol.org/build/3Dmol.js", width=size[0], height=size[1] 188 | ) 189 | 190 | if chains is None: 191 | chains = 1 if Ls is None else len(Ls) 192 | 193 | view.addModelsAsFrames(system) 194 | if color == "pLDDT" or color == "confidence": 195 | view.setStyle( 196 | { 197 | "cartoon": { 198 | "colorscheme": { 199 | "prop": "b", 200 | "gradient": "rwb", 201 | "min": vmin, 202 | "max": vmax, 203 | } 204 | } 205 | } 206 | ) 207 | elif color == "rainbow": 208 | view.setStyle({"cartoon": {"color": "spectrum"}}) 209 | elif color == "chain": 210 | for n, chain, color in zip(range(chains), alphabet_list, pymol_color_list): 211 | view.setStyle({"chain": chain}, {"cartoon": {"color": color}}) 212 | 213 | if show_sidechains: 214 | BB = ["C", "O", "N"] 215 | view.addStyle( 216 | { 217 | "and": [ 218 | {"resn": ["GLY", "PRO"], "invert": True}, 219 | {"atom": BB, "invert": True}, 220 | ] 221 | }, 222 | {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}}, 223 | ) 224 | view.addStyle( 225 | {"and": [{"resn": "GLY"}, {"atom": "CA"}]}, 226 | {"sphere": {"colorscheme": "WhiteCarbon", "radius": 0.3}}, 227 | ) 228 | view.addStyle( 229 | {"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]}, 230 | {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}}, 231 | ) 232 | if show_mainchains: 233 | BB = ["C", "O", "N", "CA"] 234 | view.addStyle( 235 | {"atom": BB}, {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}} 236 | ) 237 | view.zoomTo() 238 | 239 | view.zoomTo() 240 | 241 | return view 242 | -------------------------------------------------------------------------------- /src/proteusAI/design_tools/Constraints.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import numpy as np 8 | import esm 9 | import typing as T 10 | import biotite.sequence as seq 11 | import biotite.sequence.align as align 12 | from biotite.structure.io.pdb import PDBFile 13 | import biotite.structure as struc 14 | from biotite.structure import sasa 15 | import tempfile 16 | 17 | from proteusAI.data_tools import pdb # TODO: double check with Johny! 18 | 19 | 20 | # _____Sequence Constraints_____ 21 | def length_constraint(seqs: list, max_len: int = 200): 22 | """ 23 | Constraint for the length of a seuqences. 24 | 25 | Parameters: 26 | seqs (list): sequences to be scored 27 | max_len (int): maximum length that a sequence is allowed to be. Default = 300 28 | 29 | Returns: 30 | np.array: Energy values 31 | """ 32 | energies = np.zeros(len(seqs)) 33 | 34 | for i, sequence in enumerate(seqs): 35 | if len(sequence) > max_len: 36 | energies[i] = float(len(sequence) - max_len) 37 | else: 38 | energies[i] = 0.0 39 | 40 | return energies 41 | 42 | 43 | def seq_identity(seqs, ref, matrix="BLOSUM62", local=False): 44 | """ 45 | Calculates sequence identity of sequences against a reference sequence based on alignment. 46 | By default a global alignment is performed using the BLOSUM62 matrix. 47 | 48 | Parameters: 49 | seq1 (str): reference sequence 50 | seq2 (str): query sequence 51 | matrix (str): alignement matrix {BLOSUM62, BLOSUM50, BLOSUM30}. Default BLOSUM62 52 | local (bool): Local alignment if True, else global alignment. 53 | 54 | Returns: 55 | numpy.ndarray: identity scores of sequences 56 | """ 57 | alph = seq.ProteinSequence.alphabet 58 | matrix = align.SubstitutionMatrix(alph, alph, matrix) 59 | 60 | seqs = [seq.ProteinSequence(s) for s in seqs] 61 | ref = seq.ProteinSequence(ref) 62 | 63 | scores = np.zeros(len(seqs)) 64 | for i, s in enumerate(seqs): 65 | alignments = align.align_optimal(s, ref, matrix, local=local) 66 | score = align.get_sequence_identity(alignments[0]) 67 | scores[i] = score 68 | 69 | return scores 70 | 71 | 72 | # _____Structure Constraints_____ 73 | def string_to_tempfile(data): 74 | # create a temporary file 75 | with tempfile.NamedTemporaryFile(delete=False) as temp_file: 76 | # write the string to the file 77 | temp_file.write(data.encode("utf-8")) 78 | # flush the file to make sure the data_tools is written 79 | temp_file.flush() 80 | # return the file object 81 | return temp_file 82 | 83 | 84 | def create_batched_sequence_datasest( 85 | sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024 86 | ) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]: 87 | """ 88 | Taken from https://github.com/facebookresearch/esm/blob/main/scripts/esmfold_inference.py 89 | """ 90 | batch_headers, batch_sequences, num_tokens = [], [], 0 91 | for header, sequence in sequences: 92 | if (len(sequence) + num_tokens > max_tokens_per_batch) and num_tokens > 0: 93 | yield batch_headers, batch_sequences 94 | batch_headers, batch_sequences, num_tokens = [], [], 0 95 | batch_headers.append(header) 96 | batch_sequences.append(sequence) 97 | num_tokens += len(sequence) 98 | 99 | yield batch_headers, batch_sequences 100 | 101 | 102 | def structure_prediction( 103 | sequences: list, 104 | names: list, 105 | chunk_size: int = 124, 106 | max_tokens_per_batch: int = 1024, 107 | num_recycles: int = None, 108 | ): 109 | """ 110 | Predict the structure of proteins. 111 | 112 | Parameters: 113 | sequences (list): all sequences for structure prediction 114 | names (list): names of the sequences 115 | chunck_size (int): Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). Recommended values: 128, 64, 32. 116 | max_tokens_per_batch (int): Maximum number of tokens per gpu forward-pass. This will group shorter sequences together. 117 | num_recycles (int): Number of recycles to run. Defaults to number used in training 4. 118 | 119 | Returns: 120 | all_headers, all_sequences, all_pdbs, pTMs, mean_pLDDTs 121 | """ 122 | model = esm.pretrained.esmfold_v1() 123 | model = model.eval().cuda() 124 | model.set_chunk_size(chunk_size) 125 | all_sequences = list(zip(names, sequences)) 126 | 127 | batched_sequences = create_batched_sequence_datasest( 128 | all_sequences, max_tokens_per_batch 129 | ) 130 | all_headers = [] 131 | all_sequences = [] 132 | all_pdbs = [] 133 | pTMs = [] 134 | mean_pLDDTs = [] 135 | for headers, sequences in batched_sequences: 136 | output = model.infer(sequences, num_recycles=num_recycles) 137 | output = {key: value.cpu() for key, value in output.items()} 138 | pdbs = model.output_to_pdb(output) 139 | for header, sequence, pdb_string, mean_plddt, ptm in zip( 140 | headers, sequences, pdbs, output["mean_plddt"], output["ptm"] 141 | ): 142 | all_headers.append(header) 143 | all_sequences.append(sequence) 144 | all_pdbs.append( 145 | PDBFile.read(string_to_tempfile(pdb_string).name) 146 | ) # biotite pdb file name 147 | mean_pLDDTs.append(mean_plddt.item()) 148 | pTMs.append(ptm.item()) 149 | 150 | return all_headers, all_sequences, all_pdbs, pTMs, mean_pLDDTs 151 | 152 | 153 | def globularity(pdbs): 154 | """ 155 | globularity constraint 156 | 157 | Parameters: 158 | pdb (list): list of biotite pdb file 159 | 160 | Returns: 161 | np.array: variances of coordinates for each structure 162 | """ 163 | variances = np.zeros(len(pdbs)) 164 | 165 | for i, pdb_list in enumerate(pdbs): 166 | variance = pdb_list.get_coord().var() 167 | variances[i] = variance.item() 168 | 169 | return variances 170 | 171 | 172 | def surface_exposed_hydrophobics(pdbs): 173 | """ 174 | Calculate the surface exposed hydrophobics using the Shrake-Rupley (“rolling probe”) algorithm. 175 | 176 | Parameters: 177 | pdbs (list): list of biotite pdb file 178 | 179 | Returns: 180 | np.array: average sasa values for each structure 181 | """ 182 | avrg_sasa_values = np.zeros(len(pdbs)) 183 | for i, pdb_list in enumerate(pdbs): 184 | struc = pdb_list.get_structure() 185 | sasa_val = sasa( 186 | struc[0], 187 | probe_radius=1.4, 188 | atom_filter=None, 189 | ignore_ions=True, 190 | point_number=1000, 191 | point_distr="Fibonacci", 192 | vdw_radii="ProtOr", 193 | ) 194 | sasa_mean = sasa_val.mean() 195 | avrg_sasa_values[i] = sasa_mean.item() 196 | 197 | return avrg_sasa_values 198 | 199 | 200 | def backbone_coordination(samples: list, refs: list): 201 | """ 202 | Superimpose structures and calculate the RMSD of the backbones. 203 | sample structures will be aligned against reference structures. 204 | 205 | Parameters: 206 | ----------- 207 | samples (list): list of sample structures 208 | refs (list): list of reference structures 209 | 210 | Returns: 211 | -------- 212 | np.array: RMSD values of alignments 213 | """ 214 | if len(samples) != len(refs): 215 | raise "samples and refs must have the same length" 216 | 217 | rmsds = np.zeros(len(samples)) 218 | 219 | for i in range(len(samples)): 220 | _, _, rmsd = pdb.struc_align(refs[i], samples[i]) 221 | rmsds[i] = rmsd.item() 222 | 223 | return rmsds 224 | 225 | 226 | def all_atom_coordination(samples, refs, sample_consts, ref_consts): 227 | """ 228 | Calculate the RMSD of residues with all atom constraints. 229 | All atomic positions will be taken into consideration. 230 | 231 | Parameters: 232 | ----------- 233 | samples (list): list of biotite pdb files (mutated) 234 | refs (list): list of biotite pdb files (reference) 235 | sample_consts (list): list of constraints on the sample 236 | (potentially shifts over time due to deletions and insertions) 237 | ref_consts (list): list of constraints of the reference 238 | 239 | Returns: 240 | -------- 241 | np.array (len(samples),): calculated RMSD values for each sequence 242 | """ 243 | 244 | rmsds = np.zeros(len(samples)) 245 | 246 | for i in range(len(samples)): 247 | sample = samples[i] 248 | ref = refs[i] 249 | sample_const = sample_consts[i] 250 | ref_const = ref_consts[i] 251 | 252 | # get structures 253 | sample_struc = sample.get_structure()[0] 254 | ref_struc = ref.get_structure()[0] 255 | 256 | # get indices for alignment 257 | sample_indices = np.where( 258 | np.isin(sample_struc.res_id, [i + 1 for i in sample_const["all_atm"]]) 259 | ) 260 | ref_indices = np.where( 261 | np.isin(ref_struc.res_id, [i + 1 for i in ref_const["all_atm"]]) 262 | ) 263 | 264 | sample_struc_common = sample_struc[sample_indices[0]] 265 | ref_struc_common = ref_struc[ref_indices[0]] 266 | 267 | sample_superimposed, transformation = struc.superimpose( 268 | ref_struc_common, sample_struc_common 269 | ) 270 | 271 | sample_superimposed = struc.superimpose_apply(sample_struc, transformation) 272 | 273 | sample_pdb = PDBFile() 274 | PDBFile.set_structure(sample_pdb, sample_superimposed) 275 | sample_coord = sample_pdb.get_coord()[0][sample_indices] 276 | 277 | ref_pdb = PDBFile() 278 | PDBFile.set_structure(ref_pdb, ref_struc) 279 | ref_coord = ref_pdb.get_coord()[0][ref_indices] 280 | 281 | rmsd = struc.rmsd(sample_coord, ref_coord) 282 | 283 | rmsds[i] = rmsd 284 | 285 | return rmsds 286 | -------------------------------------------------------------------------------- /src/proteusAI/design_tools/MCMC.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import os 8 | import random 9 | import numpy as np 10 | from proteusAI.design_tools import Constraints 11 | import pandas as pd 12 | 13 | 14 | class ProteinDesign: 15 | """ 16 | Optimizes a protein sequence based on a custom energy function. Set weights of constraints to 0 which you don't want to use. 17 | 18 | Parameters: 19 | ----------- 20 | native_seq (str): native sequence to be optimized 21 | constraints (dict): constraints on sequence. 22 | Keys describe the kind of constraint and values the position on which they act. 23 | sampler (str): choose between simulated_annealing and substitution design_tools. 24 | Default 'simulated annealing' 25 | n_traj (int): number of independent trajectories per sampling step. Lowest energy mutant will be selected when 26 | multiple are viable. Default 16 27 | steps (int): number of sampling steps per trajectory. 28 | For simulated annealing, the number of iterations is often chosen in the range of [1,000, 10,000]. 29 | T (float): sampling temperature. 30 | For simulated annealing, T0 is often chosen in the range [1, 100]. default 1 31 | M (float): rate of temperature decay. 32 | or simulated annealing, a is often chosen in the range [0.01, 0.1] or [0.001, 0.01]. Default 0.01 33 | mut_p (tuple): probabilities for substitution, insertion and deletion. 34 | Default [0.6, 0.2, 0.2] 35 | pred_struc (bool): if True predict the structure of the protein at every step and use structure 36 | based constraints in the energy function. Default True. 37 | max_len (int): maximum length sequence length for lenght constraint. 38 | Default 300. 39 | w_len (float): weight of length constraint. 40 | Default 0.01 41 | w_identity (float): Weight of sequence identity constraint. Positive values reward low sequence identity to native sequence. 42 | Default 0.04 43 | w_ptm (float): weight for ptm. pTM is calculated as 1-pTM, because lower energies should be better. 44 | Default 1. 45 | w_plddt (float): weight for plddt. The mean pLDDT is calculated as 1-mean_pLDDT, because lower energies should be better. 46 | Default 1. 47 | w_globularity (float): weight of globularity constraint 48 | Default 0.001 49 | w_bb_coord (float): weight on backbone coordination constraint. Constraints backbone to native structure. 50 | Default 0.02 51 | w_all_atm (float): weight on all atom coordination constraint. Acts on all atoms which are constrained 52 | Default 0.15 53 | w_sasa (float): weight of surface exposed hydrophobics constraint 54 | Default 0.02 55 | outdir (str): path to output directory. 56 | Default None 57 | verbose (bool): if verbose print information 58 | """ 59 | 60 | def __init__( 61 | self, 62 | native_seq: str = None, 63 | constraints=None, 64 | sampler: str = "simulated_annealing", 65 | n_traj: int = 16, 66 | steps: int = 1000, 67 | T: float = 10.0, 68 | M: float = 0.01, 69 | mut_p: list = (0.6, 0.2, 0.2), 70 | pred_struc: bool = True, 71 | max_len: int = 300, 72 | w_len: float = 0.01, 73 | w_identity: float = 0.1, 74 | w_ptm: float = 1, 75 | w_plddt: float = 1, 76 | w_globularity: float = 0.001, 77 | w_bb_coord: float = 0.02, 78 | w_all_atm: float = 0.15, 79 | w_sasa: float = 0.02, 80 | outdir: str = None, 81 | verbose: bool = False, 82 | ): 83 | 84 | if constraints is None: 85 | constraints = {"no_mut": [], "all_atm": []} 86 | self.native_seq = native_seq 87 | self.sampler = sampler 88 | self.n_traj = n_traj 89 | self.steps = steps 90 | self.mut_p = mut_p 91 | self.T = T 92 | self.M = M 93 | self.pred_struc = pred_struc 94 | self.max_len = max_len 95 | self.w_max_len = w_len 96 | self.w_identity = w_identity 97 | self.w_ptm = w_ptm 98 | self.w_plddt = w_plddt 99 | self.w_globularity = w_globularity 100 | self.outdir = outdir 101 | self.verbose = verbose 102 | self.constraints = constraints 103 | self.w_sasa = w_sasa 104 | self.w_bb_coord = w_bb_coord 105 | self.w_all_atm = w_all_atm 106 | 107 | # Parameters 108 | self.ref_pdbs = None 109 | self.ref_constraints = None 110 | self.initial_energy = None 111 | 112 | def __str__(self): 113 | lines: list[str] = [ 114 | "ProteusAI.MCMC.Design class: \n", 115 | "---------------------------------------\n", 116 | "When Hallucination.run() sequences will be hallucinated using this seed sequence:\n\n", 117 | f"{self.native_seq}\n", 118 | "\nThe following variables were set:\n\n", 119 | "variable\t|value\n", 120 | "----------------+-------------------\n", 121 | f"algorithm: \t|{self.sampler}\n", 122 | f"steps: \t\t|{self.steps}\n", 123 | f"n_traj: \t|{self.n_traj}\n", 124 | f"mut_p: \t\t|{self.mut_p}\n", 125 | f"T: \t\t|{self.T}\n", 126 | f"M: \t\t|{self.M}\n\n", 127 | "The energy function is a linear combination of the following constraints:\n\n", 128 | "constraint\t|value\t|weight\n", 129 | "----------------+-------+------------\n", 130 | f"length \t\t|{self.max_len}\t|{self.w_max_len}\n", 131 | f"identity\t|\t|{self.w_identity}\n", 132 | ] 133 | s = "".join(lines) 134 | if self.pred_struc: 135 | lines = [ 136 | s, 137 | f"pTM\t\t|\t|{self.w_ptm}\n", 138 | f"pLDDT\t\t|\t|{self.w_plddt}\n", 139 | f"bb_coord\t\t|{self.w_bb_coord}\n", 140 | f"all_atm\t\t|\t|{self.w_all_atm}\n", 141 | f"sasa\t\t|\t|{self.w_sasa}\n", 142 | ] 143 | s = "".join(lines) 144 | return s 145 | 146 | ### SAMPLERS 147 | def mutate(self, seqs, mut_p: list = None, constraints: list = None): 148 | """ 149 | mutates input sequences. 150 | 151 | Parameters: 152 | seqs (list): list of peptide sequences 153 | mut_p (list): mutation probabilities 154 | constraints (list): dictionary of constraints 155 | 156 | Returns: 157 | list: mutated sequences 158 | """ 159 | 160 | if mut_p is None: 161 | mut_p = [0.6, 0.2, 0.2] 162 | 163 | AAs = ( 164 | "A", 165 | "C", 166 | "D", 167 | "E", 168 | "F", 169 | "G", 170 | "H", 171 | "I", 172 | "K", 173 | "L", 174 | "M", 175 | "N", 176 | "P", 177 | "Q", 178 | "R", 179 | "S", 180 | "T", 181 | "V", 182 | "W", 183 | "Y", 184 | ) 185 | 186 | mut_types = ("substitution", "insertion", "deletion") 187 | 188 | mutated_seqs = [] 189 | mutated_constraints = [] 190 | mutations = [] 191 | for i, seq in enumerate(seqs): 192 | mut_constraints = {} 193 | 194 | # loop until allowed mutation has been selected 195 | mutate = True 196 | while mutate: 197 | pos = random.randint(0, len(seq) - 1) 198 | mut_type = random.choices(mut_types, mut_p)[0] 199 | if pos in constraints[i]["no_mut"] or pos in constraints[i]["all_atm"]: 200 | pass 201 | # secondary structure constraint disallows deletion 202 | # insertions between two secondary structure constraints will have the constraint of their neighbors 203 | else: 204 | break 205 | 206 | if mut_type == "substitution": 207 | replacement = random.choice(AAs) 208 | mut_seq = "".join([seq[:pos], replacement, seq[pos + 1 :]]) 209 | for const in constraints[i].keys(): 210 | positions = constraints[i][const] 211 | mut_constraints[const] = positions 212 | mutations.append(f"sub:{seq[pos]}{pos}{replacement}") 213 | 214 | elif mut_type == "insertion": 215 | insertion = random.choice(AAs) 216 | mut_seq = "".join([seq[:pos], insertion, seq[pos:]]) 217 | # shift constraints after insertion 218 | for const in constraints[i].keys(): 219 | positions = constraints[i][const] 220 | positions = [i if i < pos else i + 1 for i in positions] 221 | mut_constraints[const] = positions 222 | mutations.append(f"ins:{pos}{insertion}") 223 | 224 | elif mut_type == "deletion" and len(seq) > 1: 225 | lines = list(seq) 226 | del lines[pos] 227 | mut_seq = "".join(lines) 228 | # shift constraints after deletion 229 | for const in constraints[i].keys(): 230 | positions = constraints[i][const] 231 | positions = [i if i < pos else i - 1 for i in positions] 232 | mut_constraints[const] = positions 233 | mutations.append(f"del:{seq[pos]}{pos}") 234 | 235 | else: 236 | # will perform insertion if length is to small 237 | insertion = random.choice(AAs) 238 | mut_seq = "".join([seq[:pos], insertion, seq[pos:]]) 239 | # shift constraints after insertion 240 | for const in constraints[i].keys(): 241 | positions = constraints[i][const] 242 | positions = [i if i < pos else i + 1 for i in positions] 243 | mut_constraints[const] = positions 244 | mutations.append(f"ins:{pos}{insertion}") 245 | 246 | mutated_seqs.append(mut_seq) 247 | mutated_constraints.append(mut_constraints) 248 | 249 | return mutated_seqs, mutated_constraints, mutations 250 | 251 | ### ENERGY FUNCTION and ACCEPTANCE CRITERION 252 | def energy_function(self, seqs: list, i: int, constraints: list): 253 | """ 254 | Combines constraints into an energy function. The energy function 255 | returns the energy values of the mutated files and the associated pdb 256 | files as temporary files. In addition it returns a dictionary of the different 257 | energies. 258 | 259 | Parameters: 260 | seqs (list): list of sequences 261 | i (int): current iteration in sampling 262 | constraints (list): list of constraints 263 | 264 | Returns: 265 | tuple: Energy value, pdbs, energy_log 266 | """ 267 | # reinitialize energy 268 | energies = np.zeros(len(seqs)) 269 | energy_log = dict() 270 | 271 | e_len = self.w_max_len * Constraints.length_constraint( 272 | seqs=seqs, max_len=self.max_len 273 | ) 274 | e_identity = self.w_identity * Constraints.seq_identity( 275 | seqs=seqs, ref=self.native_seq 276 | ) 277 | 278 | energies += e_len 279 | energies += e_identity 280 | 281 | energy_log[f"e_len x {self.w_max_len}"] = e_len 282 | energy_log[f"e_identity x {self.w_identity}"] = e_identity 283 | 284 | pdbs = [] 285 | if self.pred_struc: 286 | # structure prediction 287 | names = [f"sequence_{j}_cycle_{i}" for j in range(len(seqs))] 288 | headers, sequences, pdbs, pTMs, mean_pLDDTs = ( 289 | Constraints.structure_prediction(seqs, names) 290 | ) 291 | pTMs = [1 - val for val in pTMs] 292 | mean_pLDDTs = [1 - val / 100 for val in mean_pLDDTs] 293 | 294 | e_pTMs = self.w_ptm * np.array(pTMs) 295 | e_mean_pLDDTs = self.w_plddt * np.array(mean_pLDDTs) 296 | e_globularity = self.w_globularity * Constraints.globularity(pdbs) 297 | e_sasa = self.w_sasa * Constraints.surface_exposed_hydrophobics(pdbs) 298 | 299 | energies += e_pTMs 300 | energies += e_mean_pLDDTs 301 | energies += e_globularity 302 | energies += e_sasa 303 | 304 | energy_log[f"e_pTMs x {self.w_ptm}"] = e_pTMs 305 | energy_log[f"e_mean_pLDDTs x {self.w_plddt}"] = e_mean_pLDDTs 306 | energy_log[f"e_globularity x {self.w_globularity}"] = e_globularity 307 | energy_log[f"e_sasa x {self.w_sasa}"] = e_sasa 308 | 309 | # there are now ref pdbs before the first calculation 310 | if self.ref_pdbs is not None: 311 | e_bb_coord = self.w_bb_coord * Constraints.backbone_coordination( 312 | pdbs, self.ref_pdbs 313 | ) 314 | e_all_atm = self.w_all_atm * Constraints.all_atom_coordination( 315 | pdbs, self.ref_pdbs, constraints, self.ref_constraints 316 | ) 317 | 318 | energies += e_bb_coord 319 | energies += e_all_atm 320 | 321 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = e_bb_coord 322 | energy_log[f"e_all_atm x {self.w_all_atm}"] = e_all_atm 323 | else: 324 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = [] 325 | energy_log[f"e_all_atm x {self.w_all_atm}"] = [] 326 | 327 | energy_log["iteration"] = i + 1 328 | 329 | return energies, pdbs, energy_log 330 | 331 | def p_accept(self, E_x_mut, E_x_i, T, i, M): 332 | """ 333 | Decides to accep or reject changes. Changes which have a lower energy 334 | than the previous state will always be accepted. Changes which have 335 | higher energies will be accepted with a probability p_accept. The 336 | acceptance probability for bad states decreases over time. 337 | 338 | Parameters: 339 | ----------- 340 | E_x_mut (np.array): energies of mutated sequences 341 | E_x_i (np.array): energies of initial sequences 342 | T (float): Temperature 343 | i (int): current itteration 344 | M (float): decay constant 345 | 346 | Returns: 347 | -------- 348 | np.array: accept probabilities 349 | """ 350 | T = T / (1 + M * i) 351 | dE = E_x_i - E_x_mut 352 | exp_val = np.exp(1 / (T * dE)) 353 | p_accept = np.minimum(exp_val, np.ones_like(exp_val)) 354 | return p_accept 355 | 356 | ### RUN 357 | def run(self): 358 | """ 359 | Runs MCMC-sampling based on user defined inputs. Returns optimized sequences. 360 | """ 361 | native_seq = self.native_seq 362 | constraints = self.constraints 363 | n_traj = self.n_traj 364 | steps = self.steps 365 | sampler = self.sampler 366 | energy_function = self.energy_function 367 | T = self.T 368 | M = self.M 369 | p_accept = self.p_accept 370 | mut_p = self.mut_p 371 | outdir = self.outdir 372 | pdb_out = os.path.join(outdir, "pdbs") 373 | png_out = os.path.join(outdir, "pngs") 374 | data_out = os.path.join(outdir, "data_tools") 375 | 376 | if outdir is not None: 377 | if not os.path.exists(outdir): 378 | os.mkdir(outdir) 379 | if not os.path.exists(pdb_out): 380 | os.mkdir(pdb_out) 381 | if not os.path.exists(png_out): 382 | os.mkdir(png_out) 383 | if not os.path.exists(data_out): 384 | os.mkdir(data_out) 385 | 386 | if sampler == "simulated_annealing": 387 | mutate = self.mutate 388 | 389 | if native_seq is None: 390 | raise "The optimizer needs a sequence to run. Define a sequence by calling SequenceOptimizer(native_seq = )" 391 | 392 | seqs = [native_seq for _ in range(n_traj)] 393 | constraints = [constraints for _ in range(n_traj)] 394 | self.ref_constraints = constraints.copy() # THESE ARE CORRECT 395 | 396 | # for initial calculation don't use the full sequences, unecessary 397 | # calculation of initial state 398 | E_x_i, pdbs, energy_log = energy_function([seqs[0]], -1, [constraints[0]]) 399 | E_x_i = [E_x_i[0] for _ in range(n_traj)] 400 | pdbs = [pdbs[0] for _ in range(n_traj)] 401 | 402 | # empty energies dictionary for the first run 403 | for key in energy_log.keys(): 404 | energy_log[key] = [] 405 | energy_log["T"] = [] 406 | energy_log["M"] = [] 407 | energy_log["mut"] = [] 408 | energy_log["description"] = [] 409 | 410 | self.initial_energy = E_x_i.copy() 411 | self.ref_pdbs = pdbs.copy() 412 | 413 | if self.pred_struc and outdir is not None: 414 | # saves the n th structure 415 | num = "{:0{}d}".format(len(energy_log["iteration"]), len(str(self.steps))) 416 | pdbs[0].write(os.path.join(pdb_out, f"{num}_design.pdb")) 417 | 418 | # write energy_log in data_out 419 | if outdir is not None: 420 | df = pd.DataFrame(energy_log) 421 | df.to_csv(os.path.join(data_out, "energy_log.pdb"), index=False) 422 | 423 | for i in range(steps): 424 | mut_seqs, _constraints, mutations = mutate(seqs, mut_p, constraints) 425 | E_x_mut, pdbs_mut, _energy_log = energy_function(mut_seqs, i, _constraints) 426 | # accept or reject change 427 | p = p_accept(E_x_mut, E_x_i, T, i, M) 428 | 429 | new_struc_found = False 430 | accepted_ind = [] # indices of accepted structures 431 | for n in range(n_traj): 432 | if p[n] > random.random(): 433 | accepted_ind.append(n) 434 | E_x_i[n] = E_x_mut[n] 435 | seqs[n] = mut_seqs[n] 436 | constraints[n] = _constraints[n] 437 | new_struc_found = True 438 | 439 | if new_struc_found: 440 | # get index of lowest energie sructure out of the newly found structures 441 | min_E = accepted_ind[0] 442 | for a in accepted_ind: 443 | if a < min_E: 444 | min_E = a 445 | 446 | # update all to lowest energy structure 447 | E_x_i = [E_x_i[min_E] for _ in range(n_traj)] 448 | seqs = [seqs[min_E] for _ in range(n_traj)] 449 | constraints = [constraints[min_E] for _ in range(n_traj)] 450 | pdbs = [pdbs_mut[min_E] for _ in range(n_traj)] 451 | 452 | for key in energy_log.keys(): 453 | # skip skalar values in this step 454 | if key not in ["T", "M", "iteration", "mut", "description"]: 455 | e = _energy_log[key] 456 | energy_log[key].append(e[min_E].item()) 457 | 458 | energy_log["iteration"].append(i) 459 | energy_log["T"].append(T) 460 | energy_log["M"].append(M) 461 | energy_log["mut"].append(mutations[min_E]) 462 | 463 | num = "{:0{}d}".format( 464 | len(energy_log["iteration"]), len(str(self.steps)) 465 | ) 466 | energy_log["description"].append(f"{num}_design") 467 | 468 | if self.pred_struc and outdir is not None: 469 | # saves the n th structure 470 | pdbs[0].write(os.path.join(pdb_out, f"{num}_design.pdb")) 471 | 472 | # write energy_log in data_out 473 | if outdir is not None: 474 | df = pd.DataFrame(energy_log) 475 | df.to_csv(os.path.join(data_out, "energy_log.pdb"), index=False) 476 | 477 | return seqs 478 | -------------------------------------------------------------------------------- /src/proteusAI/design_tools/ZeroShot.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import os 8 | import numpy as np 9 | from proteusAI.design_tools import Constraints 10 | import pandas as pd 11 | 12 | 13 | class ZeroShot: 14 | """ 15 | ZeroShot inference for single mutational effects. Mutates every single position to every possible amino acid at 16 | that position. Then calculate differnt metrics of structure and embeddings for the mutants. Observe which mutations 17 | have the greatest effect on the mutants. 18 | 19 | Parameters: 20 | ----------- 21 | seq (str): native sequence to be optimized 22 | constraints (dict): constraints on sequence. 23 | Keys describe the kind of constraint and values the position on which they act. 24 | sampler (str): choose between simulated_annealing and substitution design_tools. 25 | Default 'simulated annealing' 26 | batch_size (int): number of independent trajectories per sampling step. Lowest energy mutant will be selected when 27 | multiple are viable. Default 16 28 | steps (int): number of sampling steps per trajectory. 29 | For simulated annealing, the number of iterations is often chosen in the range of [1,000, 10,000]. 30 | T (float): sampling temperature. 31 | For simulated annealing, T0 is often chosen in the range [1, 100]. default 1 32 | M (float): rate of temperature decay. 33 | or simulated annealing, a is often chosen in the range [0.01, 0.1] or [0.001, 0.01]. Default 0.01 34 | mut_p (tuple): probabilities for substitution, insertion and deletion. 35 | Default [0.6, 0.2, 0.2] 36 | pred_struc (bool): if True predict the structure of the protein at every step and use structure 37 | based constraints in the energy function. Default True. 38 | max_len (int): maximum length sequence length for lenght constraint. 39 | Default 300. 40 | w_len (float): weight of length constraint. 41 | Default 0.01 42 | w_identity (float): Weight of sequence identity constraint. Positive values reward low sequence identity to native sequence. 43 | Default 0.04 44 | w_ptm (float): weight for ptm. pTM is calculated as 1-pTM, because lower energies should be better. 45 | Default 1. 46 | w_plddt (float): weight for plddt. The mean pLDDT is calculated as 1-mean_pLDDT, because lower energies should be better. 47 | Default 1. 48 | w_globularity (float): weight of globularity constraint 49 | Default 0.001 50 | w_bb_coord (float): weight on backbone coordination constraint. Constraints backbone to native structure. 51 | Default 0.02 52 | w_all_atm (float): weight on all atom coordination constraint. Acts on all atoms which are constrained 53 | Default 0.15 54 | w_sasa (float): weight of surface exposed hydrophobics constraint 55 | Default 0.02 56 | outdir (str): path to output directory. 57 | Default None 58 | verbose (bool): if verbose print information 59 | """ 60 | 61 | def __init__( 62 | self, 63 | seq: str = None, 64 | name="Prot", 65 | constraints=None, 66 | batch_size: int = 20, 67 | pred_struc: bool = True, 68 | w_ptm: float = 1, 69 | w_plddt: float = 1, 70 | w_globularity: float = 0.001, 71 | w_bb_coord: float = 0.02, 72 | w_all_atm: float = 0.15, 73 | w_sasa: float = 0.02, 74 | outdir: str = None, 75 | verbose: bool = False, 76 | ): 77 | 78 | if constraints is None: 79 | constraints = {"all_atm": []} 80 | self.seq = seq 81 | self.name = name 82 | self.constraints = constraints 83 | self.batch_size = batch_size 84 | self.w_ptm = w_ptm 85 | self.w_plddt = w_plddt 86 | self.w_globularity = w_globularity 87 | self.outdir = outdir 88 | self.verbose = verbose 89 | self.w_sasa = w_sasa 90 | self.w_bb_coord = w_bb_coord 91 | self.w_all_atm = w_all_atm 92 | 93 | # Parameters 94 | self.ref_pdbs = None 95 | 96 | def __str__(self): 97 | lines = [ 98 | "ProteusAI.MCMC.ZeroShot class: \n", 99 | "---------------------------------------\n", 100 | "When Hallucination.run() sequences will be hallucinated using this seed sequence:\n\n", 101 | f"{self.seq}\n", 102 | "\nThe following variables were set:\n\n", 103 | "variable\t|value\n", 104 | "----------------+-------------------\n", 105 | f"batch_size: \t|{self.batch_size}\n", 106 | "The energy function is a linear combination of the following constraints:\n\n", 107 | "constraint\t|value\t|weight\n", 108 | "----------------+-------+------------\n", 109 | ] 110 | s = "".join(lines) 111 | lines = [ 112 | s, 113 | f"pTM\t\t|\t|{self.w_ptm}\n", 114 | f"pLDDT\t\t|\t|{self.w_plddt}\n", 115 | f"bb_coord\t\t|{self.w_bb_coord}\n", 116 | f"all_atm\t\t|\t|{self.w_all_atm}\n", 117 | f"sasa\t\t|\t|{self.w_sasa}\n", 118 | ] 119 | s = "".join(lines) 120 | return s 121 | 122 | ### SAMPLERS 123 | def mutate(self, seq, pos): 124 | """ 125 | mutates input sequences. to all possible amino acids at position 126 | 127 | Parameters: 128 | seqs (str): native protein sequence 129 | pos (int): position to mutate 130 | 131 | Returns: 132 | list: mutated sequences 133 | """ 134 | 135 | AAs = ( 136 | "A", 137 | "C", 138 | "D", 139 | "E", 140 | "F", 141 | "G", 142 | "H", 143 | "I", 144 | "K", 145 | "L", 146 | "M", 147 | "N", 148 | "P", 149 | "Q", 150 | "R", 151 | "S", 152 | "T", 153 | "V", 154 | "W", 155 | "Y", 156 | ) 157 | 158 | native_aa = seq[pos] 159 | 160 | mut_seqs = [] 161 | names = [] 162 | for res in AAs: 163 | if res != native_aa: 164 | mut_seq = "".join([seq[:pos], res, seq[pos + 1 :]]) 165 | mut_seqs.append(mut_seq) 166 | names.append(str(seq[pos]) + str(pos) + str(res)) 167 | return mut_seqs, names 168 | 169 | ### ENERGY FUNCTION and ACCEPTANCE CRITERION 170 | def energy_function(self, seqs: list, pos: int, names): 171 | """ 172 | Combines constraints into an energy function. The energy function 173 | returns the energy values of the mutated files and the associated pdb 174 | files as temporary files. In addition it returns a dictionary of the different 175 | energies. 176 | 177 | Parameters: 178 | seqs (list): list of sequences 179 | pos (int): position in sequence 180 | constraints (list): list of constraints 181 | 182 | Returns: 183 | tuple: Energy value, pdbs, energy_log 184 | """ 185 | # reinitialize energy 186 | energies = np.zeros(len(seqs)) 187 | energy_log = dict() 188 | constraints = self.constraints 189 | 190 | for i, c in enumerate(constraints["all_atm"]): 191 | if c == pos: 192 | del constraints["all_atm"][i] 193 | constraints = [constraints for i in range(len(seqs))] 194 | 195 | # structure prediction 196 | names = [f"{name}_{self.name}" for name in names] 197 | headers, sequences, pdbs, pTMs, mean_pLDDTs = Constraints.structure_prediction( 198 | seqs, names 199 | ) 200 | pTMs = [1 - val for val in pTMs] 201 | mean_pLDDTs = [1 - val / 100 for val in mean_pLDDTs] 202 | 203 | e_pTMs = self.w_ptm * np.array(pTMs) 204 | e_mean_pLDDTs = self.w_plddt * np.array(mean_pLDDTs) 205 | e_globularity = self.w_globularity * Constraints.globularity(pdbs) 206 | e_sasa = self.w_sasa * Constraints.surface_exposed_hydrophobics(pdbs) 207 | 208 | energies += e_pTMs 209 | energies += e_mean_pLDDTs 210 | energies += e_globularity 211 | energies += e_sasa 212 | 213 | energy_log[f"e_pTMs x {self.w_ptm}"] = e_pTMs 214 | energy_log[f"e_mean_pLDDTs x {self.w_plddt}"] = e_mean_pLDDTs 215 | energy_log[f"e_globularity x {self.w_globularity}"] = e_globularity 216 | energy_log[f"e_sasa x {self.w_sasa}"] = e_sasa 217 | 218 | # there are now ref pdbs before the first calculation 219 | if self.ref_pdbs is not None: 220 | e_bb_coord = self.w_bb_coord * Constraints.backbone_coordination( 221 | pdbs, self.ref_pdbs 222 | ) 223 | e_all_atm = self.w_all_atm * Constraints.all_atom_coordination( 224 | pdbs, self.ref_pdbs, constraints, constraints 225 | ) 226 | 227 | energies += e_bb_coord 228 | energies += e_all_atm 229 | 230 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = e_bb_coord 231 | energy_log[f"e_all_atm x {self.w_all_atm}"] = e_all_atm 232 | else: 233 | energy_log[f"e_bb_coord x {self.w_bb_coord}"] = [0] 234 | energy_log[f"e_all_atm x {self.w_all_atm}"] = [0] 235 | 236 | energy_log["position"] = [pos] 237 | 238 | return energies, pdbs, energy_log 239 | 240 | ### RUN 241 | def run(self): 242 | """ 243 | Runs MCMC-sampling based on user defined inputs. Returns optimized sequences. 244 | """ 245 | seq = self.seq 246 | batch_size = self.batch_size 247 | energy_function = self.energy_function 248 | outdir = self.outdir 249 | mutate = self.mutate 250 | pdb_out = os.path.join(outdir, "pdbs") 251 | png_out = os.path.join(outdir, "pngs") 252 | data_out = os.path.join(outdir, "data_tools") 253 | 254 | if outdir is not None: 255 | if not os.path.exists(outdir): 256 | os.mkdir(outdir) 257 | if not os.path.exists(pdb_out): 258 | os.mkdir(pdb_out) 259 | if not os.path.exists(png_out): 260 | os.mkdir(png_out) 261 | if not os.path.exists(data_out): 262 | os.mkdir(data_out) 263 | 264 | if seq is None: 265 | raise "Provide a sequence(seq = )" 266 | 267 | # for initial calculation don't use the full sequences, unecessary 268 | # calculation of initial state 269 | _, pdbs, energy_log = energy_function([seq], 0, ["native"]) 270 | pdbs = [pdbs[0] for _ in range(batch_size - 1)] 271 | energy_log["mut"] = ["-"] 272 | energy_log["description"] = ["native"] 273 | 274 | # make energies to list 275 | for key in energy_log.keys(): 276 | if not isinstance(energy_log[key], list): 277 | energy_log[key] = energy_log[key].tolist() 278 | 279 | self.ref_pdbs = pdbs.copy() 280 | 281 | if outdir is not None: 282 | # saves the n th structure 283 | pdbs[0].write(os.path.join(pdb_out, f"native_{self.name}.pdb")) 284 | 285 | # write energy_log in data_out 286 | if outdir is not None: 287 | df = pd.DataFrame(energy_log) 288 | df.to_csv(os.path.join(data_out, "energy_log.pdb"), index=False) 289 | 290 | for pos in range(len(seq)): 291 | seqs, names = mutate(seq, pos) 292 | E_x_i, pdbs, _energy_log = energy_function(seqs, pos, names) 293 | 294 | for n in range(len(seqs)): 295 | for key in energy_log.keys(): 296 | # skip skalar values in this step 297 | if key not in ["position", "mut", "description"]: 298 | e = _energy_log[key].tolist() 299 | with open("test", "w") as f: 300 | print("energy_log", file=f) 301 | print(energy_log, file=f) 302 | print("_energy_log", file=f) 303 | print(_energy_log, file=f) 304 | print("key", file=f) 305 | print(key, file=f) 306 | print("e", file=f) 307 | print(e, file=f) 308 | energy_log[key].append(e[n]) 309 | 310 | energy_log["position"].append(pos) 311 | energy_log["mut"].append(names[n]) 312 | 313 | energy_log["description"].append(f"{names[n]}_{self.name}") 314 | 315 | if outdir is not None: 316 | # saves the n th structure 317 | pdbs[n].write(os.path.join(pdb_out, f"{names[n]}_{self.name}.pdb")) 318 | 319 | # write energy_log in data_out 320 | if outdir is not None: 321 | df = pd.DataFrame(energy_log) 322 | df.to_csv(os.path.join(data_out, "energy_log.pdb"), index=False) 323 | -------------------------------------------------------------------------------- /src/proteusAI/design_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for protein design_tools. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.design_tools.Constraints import * # noqa: F403 12 | from proteusAI.design_tools.MCMC import * # noqa: F403 13 | from proteusAI.design_tools.ZeroShot import * # noqa: F403 14 | -------------------------------------------------------------------------------- /src/proteusAI/io_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for mining_tools. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.io_tools.embeddings import * # noqa: F403 12 | from proteusAI.io_tools.fasta import * # noqa: F403 13 | -------------------------------------------------------------------------------- /src/proteusAI/io_tools/embeddings.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import os 8 | import torch 9 | from typing import Union 10 | 11 | 12 | def load_embeddings( 13 | path: str, names: Union[list, None] = None, map_location: str = "cpu" 14 | ) -> tuple: 15 | """ 16 | Loads all representations files from a directory, returns the names/ids and sequences as lists. 17 | 18 | Parameters: 19 | path (str): path to directory containing representations files 20 | names (list): list of file names in case files should be loaded in a specific order if names not None. 21 | Will go to provided path and load files by name order. Default None 22 | 23 | Returns: 24 | tuple: two lists containing the names and sequences as torch tensors 25 | 26 | Example: 27 | names, sequences = load('/path/to/representations') 28 | """ 29 | 30 | tensors = [] 31 | if names is None: 32 | files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt")] 33 | names = [f[:-3] for f in os.listdir(path) if f.endswith(".pt")] 34 | for f in files: 35 | t = torch.load(f, map_location=map_location, weights_only=False) 36 | tensors.append(t) 37 | else: 38 | for name in names: 39 | t = torch.load( 40 | os.path.join(path, name), map_location=map_location, weights_only=False 41 | ) 42 | tensors.append(t) 43 | 44 | return names, tensors 45 | -------------------------------------------------------------------------------- /src/proteusAI/io_tools/fasta.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import os 8 | from biotite.sequence import ProteinSequence 9 | import numpy as np 10 | import hashlib 11 | 12 | 13 | def load_all_fastas( 14 | path: str, file_type: str = ".fasta", biotite: bool = False 15 | ) -> dict: 16 | """ 17 | Loads all fasta files from a directory, returns the names/ids and sequences as lists. 18 | 19 | Parameters: 20 | path (str): path to directory containing fasta files 21 | file_type (str): some fastas are stored with different file endings. Default '.fasta'. 22 | biotite (bool): returns sequences as biotite.sequence.ProteinSequence object 23 | 24 | Returns: 25 | dict: dictionary of file_names and tuple (names, sequences) 26 | 27 | Example: 28 | results = load_fastas('/path/to/fastas') 29 | """ 30 | file_names = [f for f in os.listdir(path) if f.endswith(file_type)] 31 | files = [os.path.join(path, f) for f in file_names if f.endswith(file_type)] 32 | 33 | results = {} 34 | for i, file in enumerate(files): 35 | names = [] 36 | sequences = [] 37 | with open(file, "r") as f: 38 | current_sequence = "" 39 | for line in f: 40 | line = line.strip() 41 | if line.startswith(">"): 42 | if current_sequence: 43 | sequences.append(current_sequence) 44 | names.append(line[1:]) 45 | current_sequence = "" 46 | else: 47 | current_sequence += line 48 | if biotite: 49 | sequences.append(ProteinSequence(current_sequence)) 50 | else: 51 | sequences.append(current_sequence) 52 | 53 | results[file_names[i]] = (names, sequences) 54 | return results 55 | 56 | 57 | def load_fasta(file: str, biotite: bool = False) -> tuple: 58 | """ 59 | Load all sequences in a fasta file. Returns names and sequences 60 | 61 | Parameters: 62 | file (str): path to file 63 | biotite (bool): returns sequences as biotite.sequence.ProteinSequence object 64 | 65 | Returns: 66 | tuple: two lists containing the names and sequences 67 | 68 | Example: 69 | names, sequences = load_fastas('example.fasta') 70 | """ 71 | 72 | names = [] 73 | sequences = [] 74 | with open(file, "r") as f: 75 | current_sequence = "" 76 | for line in f: 77 | line = line.strip() 78 | if line.startswith(">"): 79 | if current_sequence: 80 | sequences.append(current_sequence) 81 | names.append(line[1:]) 82 | current_sequence = "" 83 | else: 84 | current_sequence += line 85 | if biotite: 86 | sequences.append(ProteinSequence(current_sequence)) 87 | else: 88 | sequences.append(current_sequence) 89 | 90 | return names, sequences 91 | 92 | 93 | def write_fasta(names: list, sequences: list, dest: str = None): 94 | """ 95 | Takes a list of names and sequences and writes a single 96 | fasta file containing all the names and sequences. The 97 | files will be saved at the destination 98 | 99 | Parameters: 100 | names (list): list of sequence names 101 | sequences (list): list of sequences 102 | dest (str): path to output file 103 | 104 | Example: 105 | write_fasta(names, sequences, './out.fasta') 106 | """ 107 | assert isinstance(names, list) and isinstance( 108 | sequences, list 109 | ), "names and sequences must be type list" 110 | assert len(names) == len(sequences), "names and sequences must have the same length" 111 | 112 | with open(dest, "w") as f: 113 | for i in range(len(names)): 114 | f.writelines(">" + names[i] + "\n") 115 | if i == len(names) - 1: 116 | f.writelines(sequences[i]) 117 | else: 118 | f.writelines(sequences[i] + "\n") 119 | 120 | 121 | def one_hot_encoding(sequence: str): 122 | """ 123 | Returns one hot encoding for amino acid sequence. Unknown amino acids will be 124 | encoded with 0.5 at in entire row. 125 | 126 | Parameters: 127 | ----------- 128 | sequence (str): Amino acid sequence 129 | 130 | Returns: 131 | -------- 132 | numpy.ndarray: One hot encoded sequence 133 | """ 134 | # Define amino acid alphabets and create dictionary 135 | amino_acids = "ACDEFGHIKLMNPQRSTVWY" 136 | aa_dict = {aa: i for i, aa in enumerate(amino_acids)} 137 | 138 | # Initialize empty numpy array for one-hot encoding 139 | seq_one_hot = np.zeros((len(sequence), len(amino_acids))) 140 | 141 | # Convert each amino acid in sequence to one-hot encoding 142 | for i, aa in enumerate(sequence): 143 | if aa in aa_dict: 144 | seq_one_hot[i, aa_dict[aa]] = 1.0 145 | else: 146 | # Handle unknown amino acids with a default value of 0.5 147 | seq_one_hot[i, :] = 0.5 148 | 149 | return seq_one_hot 150 | 151 | 152 | def blosum_encoding(sequence, matrix="BLOSUM62", canonical=True): 153 | """ 154 | Returns BLOSUM encoding for amino acid sequence. Unknown amino acids will be 155 | encoded with 0.5 at in entire row. 156 | 157 | Parameters: 158 | ----------- 159 | sequence (str): Amino acid sequence 160 | blosum_matrix_choice (str): Choice of BLOSUM matrix. Can be 'BLOSUM50' or 'BLOSUM62' 161 | canonical (bool): only use canonical amino acids 162 | 163 | Returns: 164 | -------- 165 | numpy.ndarray: BLOSUM encoded sequence 166 | """ 167 | 168 | # Get the directory of the current script 169 | script_dir = os.path.dirname(os.path.realpath(__file__)) 170 | 171 | ### Amino Acid codes 172 | alphabet_file = os.path.join(script_dir, "matrices/alphabet") 173 | alphabet = np.loadtxt(alphabet_file, dtype=str) 174 | 175 | # Define BLOSUM matrices 176 | _blosum50 = ( 177 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM50"), dtype=float) 178 | .reshape((24, -1)) 179 | .T 180 | ) 181 | _blosum62 = ( 182 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM62"), dtype=float) 183 | .reshape((24, -1)) 184 | .T 185 | ) 186 | 187 | # Choose BLOSUM matrix 188 | if matrix == "BLOSUM50": 189 | matrix = _blosum50 190 | elif matrix == "BLOSUM62": 191 | matrix = _blosum62 192 | else: 193 | raise ValueError( 194 | "Invalid BLOSUM matrix choice. Choose 'BLOSUM50' or 'BLOSUM62'." 195 | ) 196 | 197 | blosum_matrix = {} 198 | for i, letter_1 in enumerate(alphabet): 199 | if canonical: 200 | blosum_matrix[letter_1] = matrix[i][:20] 201 | else: 202 | blosum_matrix[letter_1] = matrix[i] 203 | 204 | # create empty encoding vector 205 | encoding = np.zeros((len(sequence), len(blosum_matrix["A"]))) 206 | 207 | # Convert each amino acid in sequence to BLOSUM encoding 208 | for i, aa in enumerate(sequence): 209 | if aa in alphabet: 210 | encoding[i, :] = blosum_matrix[aa] 211 | else: 212 | # Handle unknown amino acids with a default value of 0.5 213 | encoding[i, :] = 0.5 214 | 215 | return encoding 216 | 217 | 218 | # hash sequences using sha256 219 | def hash_sequence(sequence: str, length: int = 20) -> str: 220 | """ 221 | Hashes a sequence using sha256 222 | 223 | Parameters: 224 | ----------- 225 | sequence (str): Amino acid sequence 226 | 227 | Returns: 228 | -------- 229 | str: Hashed sequence 230 | """ 231 | if sequence is None: 232 | pass 233 | else: 234 | return hashlib.sha256(sequence.encode()).hexdigest()[0:length] 235 | -------------------------------------------------------------------------------- /src/proteusAI/io_tools/matrices/BLOSUM50: -------------------------------------------------------------------------------- 1 | 5 -2 -1 -2 -1 -1 -1 0 -2 -1 -2 -1 -1 -3 -1 1 0 -3 -2 0 2 | -2 7 -1 -2 -4 1 0 -3 0 -4 -3 3 -2 -3 -3 -1 -1 -3 -1 -3 3 | -1 -1 7 2 -2 0 0 0 1 -3 -4 0 -2 -4 -2 1 0 -4 -2 -3 4 | -2 -2 2 8 -4 0 2 -1 -1 -4 -4 -1 -4 -5 -1 0 -1 -5 -3 -4 5 | -1 -4 -2 -4 13 -3 -3 -3 -3 -2 -2 -3 -2 -2 -4 -1 -1 -5 -3 -1 6 | -1 1 0 0 -3 7 2 -2 1 -3 -2 2 0 -4 -1 0 -1 -1 -1 -3 7 | -1 0 0 2 -3 2 6 -3 0 -4 -3 1 -2 -3 -1 -1 -1 -3 -2 -3 8 | 0 -3 0 -1 -3 -2 -3 8 -2 -4 -4 -2 -3 -4 -2 0 -2 -3 -3 -4 9 | -2 0 1 -1 -3 1 0 -2 10 -4 -3 0 -1 -1 -2 -1 -2 -3 2 -4 10 | -1 -4 -3 -4 -2 -3 -4 -4 -4 5 2 -3 2 0 -3 -3 -1 -3 -1 4 11 | -2 -3 -4 -4 -2 -2 -3 -4 -3 2 5 -3 3 1 -4 -3 -1 -2 -1 1 12 | -1 3 0 -1 -3 2 1 -2 0 -3 -3 6 -2 -4 -1 0 -1 -3 -2 -3 13 | -1 -2 -2 -4 -2 0 -2 -3 -1 2 3 -2 7 0 -3 -2 -1 -1 0 1 14 | -3 -3 -4 -5 -2 -4 -3 -4 -1 0 1 -4 0 8 -4 -3 -2 1 4 -1 15 | -1 -3 -2 -1 -4 -1 -1 -2 -2 -3 -4 -1 -3 -4 10 -1 -1 -4 -3 -3 16 | 1 -1 1 0 -1 0 -1 0 -1 -3 -3 0 -2 -3 -1 5 2 -4 -2 -2 17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 2 5 -3 -2 0 18 | -3 -3 -4 -5 -5 -1 -3 -3 -3 -3 -2 -3 -1 1 -4 -4 -3 15 2 -3 19 | -2 -1 -2 -3 -3 -1 -2 -3 2 -1 -1 -2 0 4 -3 -2 -2 2 8 -1 20 | 0 -3 -3 -4 -1 -3 -3 -4 -4 4 1 -3 1 -1 -3 -2 0 -3 -1 5 21 | -2 -1 4 5 -3 0 1 -1 0 -4 -4 0 -3 -4 -2 0 0 -5 -3 -4 22 | -1 0 0 1 -3 4 5 -2 0 -3 -3 1 -1 -4 -1 0 -1 -2 -2 -3 23 | -1 -1 -1 -1 -2 -1 -1 -2 -1 -1 -1 -1 -1 -2 -2 -1 0 -3 -1 -1 24 | -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 25 | -------------------------------------------------------------------------------- /src/proteusAI/io_tools/matrices/BLOSUM62: -------------------------------------------------------------------------------- 1 | 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 2 | -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 3 | -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 4 | -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 5 | 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 6 | -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 7 | -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 8 | 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 9 | -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 10 | -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 11 | -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 12 | -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 13 | -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 14 | -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 15 | -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 16 | 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 18 | -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 19 | -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 20 | 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 21 | -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 22 | -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 23 | 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 24 | -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -------------------------------------------------------------------------------- /src/proteusAI/io_tools/matrices/alphabet: -------------------------------------------------------------------------------- 1 | A 2 | R 3 | N 4 | D 5 | C 6 | Q 7 | E 8 | G 9 | H 10 | I 11 | L 12 | K 13 | M 14 | F 15 | P 16 | S 17 | T 18 | W 19 | Y 20 | V 21 | -------------------------------------------------------------------------------- /src/proteusAI/mining_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for mining_tools. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.mining_tools.alphafoldDB import * # noqa: F403 12 | from proteusAI.mining_tools.blast import * # noqa: F403 13 | from proteusAI.mining_tools.uniprot import * # noqa: F403 14 | -------------------------------------------------------------------------------- /src/proteusAI/mining_tools/alphafoldDB.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | import requests 8 | import os 9 | 10 | 11 | def get_AF2_pdb(protein_id: str, out_path: str) -> bool: 12 | """ 13 | This function takes in a UniProt ID and an output path and downloads the corresponding AlphaFold model 14 | from the EBI AlphaFold database in PDB format. 15 | 16 | Parameters: 17 | protein_id (str): The UniProt ID of the protein 18 | out_path (str): The path to save the PDB file to. The directory containing the output file will be created if it does not exist. 19 | 20 | Returns: 21 | bool: True if the PDB file was downloaded successfully, False otherwise. 22 | """ 23 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 24 | 25 | requestURL = f"https://alphafold.ebi.ac.uk/files/AF-{protein_id}-F1-model_v3.pdb" 26 | r = requests.get(requestURL) 27 | 28 | if r.status_code == 200: 29 | with open(out_path, "wb") as f: 30 | f.write(r.content) 31 | return True 32 | else: 33 | return False 34 | -------------------------------------------------------------------------------- /src/proteusAI/mining_tools/blast.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | from tempfile import gettempdir 8 | import biotite.sequence.io.fasta as fasta 9 | import biotite.application.blast as blast 10 | import biotite.database.entrez as entrez 11 | 12 | 13 | # TODO: change email 14 | def search_related_sequences( 15 | query: str, 16 | program: str = "blastp", 17 | database: str = "nr", 18 | obey_rules: bool = True, 19 | mail: str = "johnyfunk@gmail.com", 20 | ): 21 | """ 22 | Search for related sequences using the plast web app. 23 | 24 | Parameters: 25 | query (str): 26 | Query sequence 27 | program (str, optional): 28 | The specific BLAST program. One of 'blastn', 'megablast', 'blastp', 'blastx', 'tblastn' and 'tblastx'. 29 | database (str, optional): 30 | The NCBI sequence database to blast against. By default it contains all sequences (`database`='nr'`). 31 | obey_rules (bool, optional): 32 | If true, the application raises an :class:`RuleViolationError`, if the server is contacted too often, 33 | based on the NCBI BLAST usage rules. (Default: True) 34 | mail : str, optional 35 | If a mail address is provided, it will be appended in the 36 | HTTP request. This allows the NCBI to contact you in case 37 | your application sends too many requests. 38 | 39 | Returns: 40 | tuple: two lits, the first containing the hits and the second the hit sequences 41 | 42 | Example: 43 | hits, hit_seqs = blast_related_sequences(query=sequence, database='swissprot') 44 | """ 45 | # Search only the UniProt/SwissProt database 46 | blast_app = blast.BlastWebApp( 47 | program=program, 48 | query=query, 49 | database=database, 50 | obey_rules=obey_rules, 51 | mail=mail, 52 | ) 53 | blast_app.start() 54 | blast_app.join() 55 | alignments = blast_app.get_alignments() 56 | # Get hit IDs for hits with score > 200 57 | hits = [] 58 | for ali in alignments: 59 | if ali.score > 200: 60 | hits.append(ali.hit_id) 61 | # Get the sequences from hit IDs 62 | hit_seqs = [] 63 | for hit in hits: 64 | file_name = entrez.fetch(hit, gettempdir(), "fa", "protein", "fasta") 65 | fasta_file = fasta.FastaFile.read(file_name) 66 | hit_seqs.append(fasta.get_sequence(fasta_file)) 67 | 68 | return hits, hit_seqs 69 | -------------------------------------------------------------------------------- /src/proteusAI/mining_tools/uniprot.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | from Bio import SeqIO 8 | import requests 9 | from io import StringIO 10 | from hashlib import md5 11 | 12 | 13 | def get_protein_sequence(uniprot_id: str) -> str: 14 | """ 15 | This function takes a UniProt ID as input and returns the corresponding sequence record. 16 | 17 | Parameters: 18 | uniprot_id (str): The UniProt identifier of the protein of interest 19 | 20 | Returns: 21 | list: List with results as Bio.SeqRecord.SeqRecord object 22 | 23 | Example: 24 | ASMT_representations = uniprot.get_protein_sequence('P46597') 25 | sequence_string = str(ASMT_representations[0].seq) 26 | """ 27 | base_url = "http://www.uniprot.org/uniprot/" 28 | current_url = base_url + uniprot_id + ".fasta" 29 | response = requests.post(current_url) 30 | c_data = "".join(response.text) 31 | 32 | seq = StringIO(c_data) 33 | p_seq = list(SeqIO.parse(seq, "fasta")) 34 | return p_seq 35 | 36 | 37 | def get_uniprot_id(sequence: str) -> str: 38 | """ 39 | This function takes in a protein sequence string and returns the corresponding UniProt ID. 40 | 41 | Parameters: 42 | sequence (str): The protein sequence string 43 | 44 | Returns: 45 | str: The UniProt ID of the protein, if it exists in UniProt. None otherwise. 46 | """ 47 | h = md5(sequence.encode()).digest().hex() 48 | requestURL = f"https://www.ebi.ac.uk/proteins/api/proteins?offset=0&size=100&md5={h}" # 5e2c446cc1c54ee4406b9f6683b7f98d 49 | r = requests.get(requestURL, headers={"Accept": "application/json"}) 50 | 51 | if not r.ok: 52 | return None 53 | data = r.json() 54 | 55 | if len(data) == 0: 56 | return None 57 | else: 58 | return data[0]["accession"] 59 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for machine learning 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/bo_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for bayesian optimization tools. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.ml_tools.bo_tools.acq_fn import * # noqa: F403 12 | from proteusAI.ml_tools.bo_tools.genetic_algorithm import * # noqa: F403 13 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/bo_tools/acq_fn.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk and Laura Sofia Machado" 6 | 7 | import numpy as np 8 | from scipy.stats import norm 9 | 10 | 11 | def greedy(mean, std=None, current_best=None, xi=None): 12 | """ 13 | Greedy acquisition function. 14 | 15 | Args: 16 | mean (np.array): This is the mean function from the GP over the considered set of points. 17 | std (np.array, optional): This is the standard deviation function from the GP over the considered set of points. Default is None. 18 | current_best (float, optional): This is the current maximum of the unknown function: mu^+. Default is None. 19 | xi (float, optional): Small value added to avoid corner cases. Default is None. 20 | 21 | Returns: 22 | np.array: The mean values for all the points, as greedy acquisition selects the best based on mean. 23 | """ 24 | return mean 25 | 26 | 27 | def EI(mean, std, current_best, xi=0.1): 28 | """ 29 | Expected Improvement acquisition function. 30 | 31 | It implements the following function: 32 | 33 | | (mu - mu^+ - xi) Phi(Z) + sigma phi(Z) if sigma > 0 34 | EI(x) = | 35 | | 0 if sigma = 0 36 | 37 | where Phi is the CDF and phi the PDF of the normal distribution 38 | and 39 | Z = (mu - mu^+ - xi) / sigma 40 | 41 | Args: 42 | mean (np.array): This is the mean function from the GP over the considered set of points. 43 | std (np.array): This is the standard deviation function from the GP over the considered set of points. 44 | current_best (float): This is the current maximum of the unknown function: mu^+. 45 | xi (float): Small value added to avoid corner cases. 46 | 47 | Returns: 48 | np.array: The value of this acquisition function for all the points. 49 | """ 50 | 51 | Z = (mean - current_best - xi) / (std + 1e-9) 52 | EI = (mean - current_best - xi) * norm.cdf(Z) + std * norm.pdf(Z) 53 | EI[std == 0] = 0 54 | 55 | return EI 56 | 57 | 58 | def UCB(mean, std, current_best=None, kappa=1.5): 59 | """ 60 | Upper-Confidence Bound acquisition function. 61 | 62 | Args: 63 | mean (np.array): This is the mean function from the GP over the considered set of points. 64 | std (np.array): This is the standard deviation function from the GP over the considered set of points. 65 | current_best (float, optional): This is the current maximum of the unknown function: mu^+. Default is None. 66 | kappa (float): Exploration-exploitation trade-off parameter. The higher the value, the more exploration. Default is 0. 67 | 68 | Returns: 69 | np.array: The value of this acquisition function for all the points. 70 | """ 71 | return mean + kappa * std 72 | 73 | 74 | def random_acquisition(mean, std=None, current_best=None, xi=None): 75 | """ 76 | Random acquisition function. Assigns random acquisition values to all points in the unobserved set. 77 | 78 | Args: 79 | mean (np.array): This is the mean function from the GP over the considered set of points. 80 | std (np.array, optional): This is the standard deviation function from the GP over the considered set of points. Default is None. 81 | current_best (float, optional): This is the current maximum of the unknown function: mu^+. Default is None. 82 | xi (float, optional): Small value added to avoid corner cases. Default is None. 83 | 84 | Returns: 85 | np.array: Random acquisition values for all points in the unobserved set. 86 | """ 87 | n_unobserved = len(mean) 88 | np.random.seed(None) 89 | random_acq_values = np.random.random(n_unobserved) 90 | return random_acq_values 91 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/bo_tools/genetic_algorithm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | ##################################### 6 | ### Simulated Annealing Discovery ### 7 | ##################################### 8 | 9 | 10 | def precompute_distances(vectors): 11 | """Precompute the pairwise Euclidean distance matrix.""" 12 | num_vectors = len(vectors) 13 | distance_matrix = np.zeros((num_vectors, num_vectors)) 14 | 15 | for i in range(num_vectors): 16 | for j in range(i + 1, num_vectors): 17 | dist = np.linalg.norm(vectors[i] - vectors[j]) 18 | distance_matrix[i, j] = dist 19 | distance_matrix[j, i] = dist 20 | 21 | return distance_matrix 22 | 23 | 24 | def diversity_score_incremental( 25 | current_score, selected_indices, idx_in, idx_out, distance_matrix 26 | ): 27 | """Update the diversity score incrementally when a vector is swapped.""" 28 | new_score = current_score 29 | 30 | for idx in selected_indices: 31 | if idx != idx_out: 32 | new_score -= distance_matrix[idx_out, idx] 33 | new_score += distance_matrix[idx_in, idx] 34 | 35 | return new_score 36 | 37 | 38 | def simulated_annealing( 39 | vectors, 40 | N, 41 | initial_temperature=1000.0, 42 | cooling_rate=0.003, 43 | max_iterations=10000, 44 | pbar=None, 45 | ): 46 | """ 47 | Simulated Annealing to select N vectors that maximize diversity. 48 | 49 | Args: 50 | vectors (list): List of numpy arrays. 51 | N (int): Number of sequences that should be sampled. 52 | initial_temperature (float): Initial temperature of the simulated annealing algorithm. Default 1000.0. 53 | cooling_rate (float): Cooling rate of the simulated annealing algorithm. Default 0.003. 54 | max_iterations (int): Maximum number of iterations of the simulated annealing algorithm. Default 10000. 55 | 56 | Returns: 57 | list: Indices of diverse vectors. 58 | """ 59 | 60 | if pbar: 61 | pbar.set(message="Computing distance matrix", detail="...") 62 | 63 | # Precompute all pairwise distances 64 | distance_matrix = precompute_distances(vectors) 65 | 66 | # Randomly initialize the selection of N vectors 67 | selected_indices = random.sample(range(len(vectors)), N) 68 | current_score = sum( 69 | distance_matrix[i, j] 70 | for i in selected_indices 71 | for j in selected_indices 72 | if i < j 73 | ) 74 | 75 | temperature = initial_temperature 76 | best_score = current_score 77 | best_selection = selected_indices[:] 78 | 79 | for iteration in range(max_iterations): 80 | 81 | if pbar: 82 | pbar.set(iteration, message="Minimizing energy", detail="...") 83 | 84 | # Randomly select a vector to swap 85 | idx_out = random.choice(selected_indices) 86 | idx_in = random.choice( 87 | [i for i in range(len(vectors)) if i not in selected_indices] 88 | ) 89 | 90 | # Incrementally update the diversity score 91 | new_score = diversity_score_incremental( 92 | current_score, selected_indices, idx_in, idx_out, distance_matrix 93 | ) 94 | 95 | # Decide whether to accept the new solution 96 | delta = new_score - current_score 97 | if delta > 0 or np.exp(delta / temperature) > random.random(): 98 | selected_indices.remove(idx_out) 99 | selected_indices.append(idx_in) 100 | current_score = new_score 101 | 102 | # Update the best solution found so far 103 | if new_score > best_score: 104 | best_score = new_score 105 | best_selection = selected_indices[:] 106 | 107 | # Cool down the temperature 108 | temperature *= 1 - cooling_rate 109 | 110 | # Early stopping if the temperature is low enough 111 | # if temperature < 1e-8: 112 | # break 113 | 114 | return best_selection, best_score 115 | 116 | 117 | ####################################### 118 | ### Genetic Algorithm for Mutations ### 119 | ####################################### 120 | 121 | 122 | def find_mutations(sequences): 123 | """ 124 | Takes a list of protein sequences and returns a dictionary with mutation positions 125 | as keys and lists of amino acids at those positions as values. 126 | 127 | Parameters: 128 | sequences (list): A list of protein sequences (strings). 129 | 130 | Returns: 131 | dict: A dictionary where keys are positions (1-indexed) and values are lists of amino acids at those positions. 132 | """ 133 | # Check if the list is empty 134 | if not sequences: 135 | return {} 136 | 137 | # Initialize a dictionary to store mutations 138 | mutations = {} 139 | 140 | # Get the reference sequence (assuming all sequences have the same length) 141 | reference_seq = sequences[0] 142 | 143 | # Iterate through each position in the sequence 144 | for i in range(len(reference_seq)): 145 | # Initialize a set to store the different amino acids at this position 146 | amino_acids = set() 147 | 148 | # Check the amino acid at this position in each sequence 149 | for seq in sequences: 150 | amino_acids.add(seq[i]) 151 | 152 | # If there is more than one unique amino acid at this position, it's a mutation 153 | if len(amino_acids) > 1: 154 | mutations[i + 1] = list(amino_acids) # +1 to make position 1-indexed 155 | 156 | return mutations 157 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/esm_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for protein language activity_prediction 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.ml_tools.esm_tools import * # noqa: F403 12 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/esm_tools/alphabet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonfunk21/ProteusAI/b4584dc9c64334ac0cf5e8f9fee54878aa1fb735/src/proteusAI/ml_tools/esm_tools/alphabet.pt -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/sklearn_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for sklearn_tools activity_prediction. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.ml_tools.sklearn_tools.grid_search import * # noqa: F403 12 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/sklearn_tools/grid_search.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | __name__ = "proteusAI" 5 | __author__ = "Jonathan Funk" 6 | 7 | from sklearn.model_selection import GridSearchCV 8 | from sklearn.ensemble import RandomForestRegressor 9 | from scipy.stats import pearsonr 10 | from sklearn.neighbors import KNeighborsRegressor 11 | import pandas as pd 12 | from sklearn.svm import SVR 13 | import numpy 14 | 15 | 16 | def knnr_grid_search( 17 | Xs_train: numpy.ndarray, 18 | Xs_test: numpy.ndarray, 19 | ys_train: list, 20 | ys_test: list, 21 | param_grid: dict = None, 22 | verbose: int = 1, 23 | ): 24 | """ 25 | Performs a KNN regressor grid search using 5 fold cross validation. 26 | 27 | Parameters: 28 | ----------- 29 | Xs_train, Xs_test (numpy.ndarray): train and test values for training 30 | ys_train, ys_test (list): y values for train and test 31 | param_grid (dict): parameter grid for model 32 | verbose (int): Type of information printing during run 33 | 34 | Returns: 35 | -------- 36 | Returns the best performing model of grid search, the test R squared value, correlation coefficient, 37 | p-value of the fit and a dataframe containing the fit information. 38 | """ 39 | # Instantiate the model 40 | knnr = KNeighborsRegressor() 41 | 42 | if param_grid is None: 43 | param_grid = { 44 | "n_neighbors": [3, 5, 7], 45 | "weights": ["uniform", "distance"], 46 | "algorithm": ["ball_tree", "kd_tree", "brute"], 47 | "leaf_size": [10, 15, 20], 48 | "p": [1, 2, 3], 49 | } 50 | # Create a GridSearchCV object and fit the model 51 | # grid_search = GridSearchCV(estimator=knnr, param_grid=param_grid, cv=5, n_jobs=-1, verbose=1) 52 | grid_search = GridSearchCV( 53 | estimator=knnr, 54 | param_grid=param_grid, 55 | scoring="r2", 56 | cv=5, 57 | verbose=verbose, 58 | n_jobs=-1, # use all available cores 59 | ) 60 | 61 | grid_search.fit(Xs_train, ys_train) 62 | test_r2 = grid_search.score(Xs_test, ys_test) 63 | predictions = grid_search.best_estimator_.predict(Xs_test) 64 | corr_coef, p_value = pearsonr(predictions, ys_test) 65 | 66 | # Print the best hyperparameters and the best score 67 | if verbose is not None: 68 | print("Best hyperparameters: ", grid_search.best_params_) 69 | print("Best score: ", grid_search.best_score_) 70 | print("Test R^2 score: ", test_r2) 71 | print("Correlation coefficient: {:.2f}".format(corr_coef)) 72 | print("p-value: {:.4f}".format(p_value)) 73 | 74 | return ( 75 | grid_search.best_estimator_, 76 | test_r2, 77 | corr_coef, 78 | p_value, 79 | pd.DataFrame.from_dict(grid_search.cv_results_), 80 | ) 81 | 82 | 83 | def rfr_grid_search( 84 | Xs_train: numpy.ndarray, 85 | Xs_test: numpy.ndarray, 86 | ys_train: list, 87 | ys_test: list, 88 | param_grid: dict = None, 89 | verbose: int = 1, 90 | ): 91 | """ 92 | Performs a Random Forrest regressor grid search using 5 fold cross validation. 93 | 94 | Parameters: 95 | ----------- 96 | Xs_train, Xs_test (numpy.ndarray): train and test values for training 97 | ys_train, ys_test (list): y values for train and test 98 | param_grid (dict): parameter grid for model 99 | verbose (int): Type of information printing during run 100 | 101 | Returns: 102 | -------- 103 | Returns the best performing model of grid search, the test R squared value, correlation coefficient, 104 | p-value of the fit and a dataframe containing the fit information. 105 | """ 106 | # Instantiate the model 107 | rfr = RandomForestRegressor(random_state=42) 108 | 109 | if param_grid is None: 110 | param_grid = { 111 | "n_estimators": [20, 50, 100, 200], 112 | "criterion": ["squared_error", "absolute_error"], 113 | "max_features": ["sqrt", "log2"], 114 | "max_depth": [5, 10, 15], 115 | "min_samples_split": [2, 5, 10], 116 | "min_samples_leaf": [1, 4], 117 | } 118 | 119 | # Create a GridSearchCV object and fit the model 120 | grid_search = GridSearchCV( 121 | estimator=rfr, 122 | param_grid=param_grid, 123 | scoring="r2", 124 | cv=5, 125 | verbose=verbose, 126 | n_jobs=-1, # use all available cores 127 | ) 128 | grid_search.fit(Xs_train, ys_train) 129 | 130 | # Evaluate the performance of the model on the test set 131 | test_r2 = grid_search.score(Xs_test, ys_test) 132 | predictions = grid_search.best_estimator_.predict(Xs_test) 133 | corr_coef, p_value = pearsonr(predictions, ys_test) 134 | 135 | if verbose is not None: 136 | # Print the best hyperparameters and the best score 137 | print("Best hyperparameters: ", grid_search.best_params_) 138 | print("Best score: ", grid_search.best_score_) 139 | print("Test R^2 score: ", test_r2) 140 | print("Correlation coefficient: {:.2f}".format(corr_coef)) 141 | print("p-value: {:.4f}".format(p_value)) 142 | 143 | return ( 144 | grid_search.best_estimator_, 145 | test_r2, 146 | corr_coef, 147 | p_value, 148 | pd.DataFrame.from_dict(grid_search.cv_results_), 149 | ) 150 | 151 | 152 | def svr_grid_search( 153 | Xs_train: numpy.ndarray, 154 | Xs_test: numpy.ndarray, 155 | ys_train: list, 156 | ys_test: list, 157 | param_grid: dict = None, 158 | verbose: int = 1, 159 | ): 160 | """ 161 | Performs a Support Vector regressor grid search using 5 fold cross validation. 162 | 163 | Parameters: 164 | ----------- 165 | Xs_train, Xs_test (numpy.ndarray): train and test values for training 166 | ys_train, ys_test (list): y values for train and test 167 | param_grid (dict): parameter grid for model 168 | verbose (int): Type of information printing during run 169 | 170 | Returns: 171 | -------- 172 | Returns the best performing model of grid search, the test R squared value, correlation coefficient, 173 | p-value of the fit and a dataframe containing the fit information. 174 | """ 175 | # Instantiate the model 176 | svr = SVR() 177 | 178 | if param_grid is None: 179 | param_grid = { 180 | "C": [0.1, 1, 2, 5, 10, 100, 200, 400], 181 | "gamma": ["scale"], 182 | "kernel": ["linear", "poly", "rbf", "sigmoid"], 183 | "degree": [3], 184 | } 185 | 186 | # Create a GridSearchCV object and fit the model 187 | grid_search = GridSearchCV( 188 | estimator=svr, 189 | param_grid=param_grid, 190 | scoring="r2", 191 | cv=5, 192 | verbose=verbose, 193 | n_jobs=-1, # use all available cores 194 | ) 195 | grid_search.fit(Xs_train, ys_train) 196 | 197 | # Evaluate the performance of the model on the test set 198 | test_r2 = grid_search.score(Xs_test, ys_test) 199 | predictions = grid_search.best_estimator_.predict(Xs_test) 200 | corr_coef, p_value = pearsonr(predictions, ys_test) 201 | 202 | if verbose is not None: 203 | # Print the best hyperparameters and the best score 204 | print("Best hyperparameters: ", grid_search.best_params_) 205 | print("Best score: ", grid_search.best_score_) 206 | print("Test R^2 score: ", test_r2) 207 | print("Correlation coefficient: {:.2f}".format(corr_coef)) 208 | print("p-value: {:.4f}".format(p_value)) 209 | 210 | return ( 211 | grid_search.best_estimator_, 212 | test_r2, 213 | corr_coef, 214 | p_value, 215 | pd.DataFrame.from_dict(grid_search.cv_results_), 216 | grid_search.best_params_, 217 | ) 218 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/torch_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for pytorch tools. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.ml_tools.torch_tools.torch_tools import * # noqa: F403 12 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/torch_tools/matrices/BLOSUM50: -------------------------------------------------------------------------------- 1 | 5 -2 -1 -2 -1 -1 -1 0 -2 -1 -2 -1 -1 -3 -1 1 0 -3 -2 0 2 | -2 7 -1 -2 -4 1 0 -3 0 -4 -3 3 -2 -3 -3 -1 -1 -3 -1 -3 3 | -1 -1 7 2 -2 0 0 0 1 -3 -4 0 -2 -4 -2 1 0 -4 -2 -3 4 | -2 -2 2 8 -4 0 2 -1 -1 -4 -4 -1 -4 -5 -1 0 -1 -5 -3 -4 5 | -1 -4 -2 -4 13 -3 -3 -3 -3 -2 -2 -3 -2 -2 -4 -1 -1 -5 -3 -1 6 | -1 1 0 0 -3 7 2 -2 1 -3 -2 2 0 -4 -1 0 -1 -1 -1 -3 7 | -1 0 0 2 -3 2 6 -3 0 -4 -3 1 -2 -3 -1 -1 -1 -3 -2 -3 8 | 0 -3 0 -1 -3 -2 -3 8 -2 -4 -4 -2 -3 -4 -2 0 -2 -3 -3 -4 9 | -2 0 1 -1 -3 1 0 -2 10 -4 -3 0 -1 -1 -2 -1 -2 -3 2 -4 10 | -1 -4 -3 -4 -2 -3 -4 -4 -4 5 2 -3 2 0 -3 -3 -1 -3 -1 4 11 | -2 -3 -4 -4 -2 -2 -3 -4 -3 2 5 -3 3 1 -4 -3 -1 -2 -1 1 12 | -1 3 0 -1 -3 2 1 -2 0 -3 -3 6 -2 -4 -1 0 -1 -3 -2 -3 13 | -1 -2 -2 -4 -2 0 -2 -3 -1 2 3 -2 7 0 -3 -2 -1 -1 0 1 14 | -3 -3 -4 -5 -2 -4 -3 -4 -1 0 1 -4 0 8 -4 -3 -2 1 4 -1 15 | -1 -3 -2 -1 -4 -1 -1 -2 -2 -3 -4 -1 -3 -4 10 -1 -1 -4 -3 -3 16 | 1 -1 1 0 -1 0 -1 0 -1 -3 -3 0 -2 -3 -1 5 2 -4 -2 -2 17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 2 5 -3 -2 0 18 | -3 -3 -4 -5 -5 -1 -3 -3 -3 -3 -2 -3 -1 1 -4 -4 -3 15 2 -3 19 | -2 -1 -2 -3 -3 -1 -2 -3 2 -1 -1 -2 0 4 -3 -2 -2 2 8 -1 20 | 0 -3 -3 -4 -1 -3 -3 -4 -4 4 1 -3 1 -1 -3 -2 0 -3 -1 5 21 | -2 -1 4 5 -3 0 1 -1 0 -4 -4 0 -3 -4 -2 0 0 -5 -3 -4 22 | -1 0 0 1 -3 4 5 -2 0 -3 -3 1 -1 -4 -1 0 -1 -2 -2 -3 23 | -1 -1 -1 -1 -2 -1 -1 -2 -1 -1 -1 -1 -1 -2 -2 -1 0 -3 -1 -1 24 | -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 -5 25 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/torch_tools/matrices/BLOSUM62: -------------------------------------------------------------------------------- 1 | 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 2 | -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 3 | -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 4 | -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 5 | 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 6 | -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 7 | -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 8 | 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 9 | -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 10 | -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 11 | -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 12 | -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 13 | -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 14 | -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 15 | -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 16 | 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 17 | 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 18 | -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 19 | -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 20 | 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 21 | -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 22 | -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 23 | 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 24 | -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/torch_tools/matrices/alphabet: -------------------------------------------------------------------------------- 1 | A 2 | R 3 | N 4 | D 5 | C 6 | Q 7 | E 8 | G 9 | H 10 | I 11 | L 12 | K 13 | M 14 | F 15 | P 16 | S 17 | T 18 | W 19 | Y 20 | V 21 | -------------------------------------------------------------------------------- /src/proteusAI/ml_tools/torch_tools/torch_tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | from typing import Union 7 | import gpytorch 8 | 9 | 10 | def one_hot_encoder(sequences, alphabet=None, canonical=True, pbar=None, padding=None): 11 | """ 12 | Encodes sequences provided an alphabet. 13 | 14 | Parameters: 15 | sequences (list or str): list of amino acid sequences or a single sequence. 16 | alphabet (list or None): list of characters in the alphabet or None to load from file. 17 | canonical (bool): only use canonical amino acids. 18 | padding (int or None): the length to which all sequences should be padded. 19 | If None, no padding beyond the length of the longest sequence. 20 | 21 | Returns: 22 | torch.Tensor: (number of sequences, padding or maximum sequence length, size of the alphabet) for list input 23 | (padding or maximum sequence length, size of the alphabet) for string input 24 | """ 25 | # Check if sequences is a string 26 | if isinstance(sequences, str): 27 | singular = True 28 | sequences = [sequences] # Make it a list to use the same code below 29 | else: 30 | singular = False 31 | 32 | # Load the alphabet from a file if it's not provided 33 | if alphabet is None: 34 | # Get the directory of the current script 35 | script_dir = os.path.dirname(os.path.realpath(__file__)) 36 | 37 | ### Amino Acid codes 38 | alphabet_file = os.path.join(script_dir, "matrices/alphabet") 39 | alphabet = np.loadtxt(alphabet_file, dtype=str) 40 | 41 | # If canonical is True, only use the first 20 characters of the alphabet 42 | if canonical: 43 | alphabet = alphabet[:20] 44 | 45 | # Create a dictionary to map each character in the alphabet to its index 46 | alphabet_dict = {char: i for i, char in enumerate(alphabet)} 47 | 48 | # Determine the length to which sequences should be padded 49 | max_sequence_length = max(len(sequence) for sequence in sequences) 50 | padded_length = padding if padding is not None else max_sequence_length 51 | 52 | n_sequences = len(sequences) 53 | alphabet_size = len(alphabet) 54 | 55 | # Create an empty tensor of the right size, with padding length 56 | tensor = torch.zeros((n_sequences, padded_length, alphabet_size)) 57 | 58 | # Fill the tensor 59 | for i, sequence in enumerate(sequences): 60 | if pbar: 61 | pbar.set( 62 | i, message="Computing", detail=f"{i}/{len(sequences)} remaining..." 63 | ) 64 | for j, character in enumerate(sequence): 65 | if j >= padded_length: 66 | break # Stop if the sequence length exceeds the padded length 67 | # Get the index of the character in the alphabet 68 | char_index = alphabet_dict.get( 69 | character, -1 70 | ) # Return -1 if character is not in the alphabet 71 | if char_index != -1: 72 | # Set the corresponding element of the tensor to 1 73 | tensor[i, j, char_index] = 1.0 74 | 75 | # If the input was a string, return a tensor of shape (padded_length, alphabet_size) 76 | if singular: 77 | tensor = tensor.squeeze(0) 78 | 79 | return tensor 80 | 81 | 82 | def blosum_encoding( 83 | sequences, matrix="BLOSUM62", canonical=True, pbar=None, padding=None 84 | ): 85 | """ 86 | Returns BLOSUM encoding for amino acid sequences. Unknown amino acids will be 87 | encoded with 0.5 in the entire row. 88 | 89 | Parameters: 90 | ----------- 91 | sequences (list or str): List of amino acid sequences or a single sequence. 92 | matrix (str): Choice of BLOSUM matrix. Can be 'BLOSUM50' or 'BLOSUM62'. 93 | canonical (bool): Only use canonical amino acids. 94 | padding (int or None): The length to which all sequences should be padded. 95 | If None, no padding beyond the length of the longest sequence. 96 | pbar: Progress bar for shiny app. 97 | 98 | Returns: 99 | -------- 100 | torch.Tensor: BLOSUM encoded sequence. 101 | """ 102 | 103 | # Check if sequences is a string 104 | if isinstance(sequences, str): 105 | singular = True 106 | sequences = [sequences] # Make it a list to use the same code below 107 | else: 108 | singular = False 109 | 110 | # Get the directory of the current script 111 | script_dir = os.path.dirname(os.path.realpath(__file__)) 112 | 113 | # Load the alphabet 114 | alphabet_file = os.path.join(script_dir, "matrices/alphabet") 115 | alphabet = np.loadtxt(alphabet_file, dtype=str) 116 | 117 | # Define BLOSUM matrices 118 | _blosum50 = ( 119 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM50"), dtype=float) 120 | .reshape((24, -1)) 121 | .T 122 | ) 123 | _blosum62 = ( 124 | np.loadtxt(os.path.join(script_dir, "matrices/BLOSUM62"), dtype=float) 125 | .reshape((24, -1)) 126 | .T 127 | ) 128 | 129 | # Choose BLOSUM matrix 130 | if matrix == "BLOSUM50": 131 | matrix = _blosum50 132 | elif matrix == "BLOSUM62": 133 | matrix = _blosum62 134 | else: 135 | raise ValueError( 136 | "Invalid BLOSUM matrix choice. Choose 'BLOSUM50' or 'BLOSUM62'." 137 | ) 138 | 139 | # Create the BLOSUM encoding dictionary 140 | blosum_matrix = {} 141 | for i, letter_1 in enumerate(alphabet): 142 | if canonical: 143 | blosum_matrix[letter_1] = matrix[i][:20] 144 | else: 145 | blosum_matrix[letter_1] = matrix[i] 146 | 147 | # Determine the length to which sequences should be padded 148 | max_sequence_length = max(len(sequence) for sequence in sequences) 149 | padded_length = padding if padding is not None else max_sequence_length 150 | 151 | n_sequences = len(sequences) 152 | alphabet_size = len(blosum_matrix["A"]) 153 | 154 | # Create an empty tensor of the right size, with padding length 155 | tensor = torch.zeros((n_sequences, padded_length, alphabet_size)) 156 | 157 | # Convert each amino acid in sequence to BLOSUM encoding 158 | for i, sequence in enumerate(sequences): 159 | if pbar: 160 | pbar.set( 161 | i, message="Computing", detail=f"{i}/{len(sequences)} remaining..." 162 | ) 163 | for j, aa in enumerate(sequence): 164 | if j >= padded_length: 165 | break # Stop if the sequence length exceeds the padded length 166 | if aa in alphabet: 167 | tensor[i, j, :] = torch.tensor(blosum_matrix[aa]) 168 | else: 169 | # Handle unknown amino acids with a default value of 0.5 170 | tensor[i, j, :] = 0.5 171 | 172 | # If the input was a string, return a tensor of shape (padded_length, alphabet_size) 173 | if singular: 174 | tensor = tensor.squeeze(0) 175 | 176 | return tensor 177 | 178 | 179 | # Define the VHSE dictionary 180 | vhse_dict = { 181 | "A": [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48], 182 | "R": [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83], 183 | "N": [-0.99, 0.00, -0.37, 0.69, -0.55, 0.85, 0.73, -0.80], 184 | "D": [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56], 185 | "C": [0.18, -1.67, -0.46, -0.21, 0.00, 1.20, -1.61, -0.19], 186 | "Q": [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41], 187 | "E": [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.02], 188 | "G": [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, 2.01, -1.34], 189 | "H": [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65], 190 | "I": [1.27, -0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13], 191 | "L": [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62], 192 | "K": [-1.17, 0.70, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13], 193 | "M": [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68], 194 | "F": [1.52, 0.61, 0.96, -0.16, 0.25, 0.28, -1.33, -0.20], 195 | "P": [0.22, -0.17, -0.50, 0.05, -0.01, -1.34, -0.19, 3.56], 196 | "S": [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11], 197 | "T": [-0.34, -0.51, -0.55, -1.06, -0.06, -0.01, -0.79, 0.39], 198 | "W": [1.50, 2.06, 1.79, 0.75, 0.75, -0.13, -1.01, -0.85], 199 | "Y": [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52], 200 | "V": [0.76, -0.92, -0.17, -1.91, 0.22, -1.40, -0.24, -0.03], 201 | } 202 | 203 | 204 | def vhse_encoder(sequences, padding=None, pbar=None): 205 | """ 206 | Encodes sequences using VHSE descriptors. 207 | 208 | Parameters: 209 | sequences (list or str): List of amino acid sequences or a single sequence. 210 | padding (int or None): Length to which all sequences should be padded. 211 | If None, no padding beyond the longest sequence. 212 | pbar: Progress bar for tracking (optional). 213 | 214 | Returns: 215 | torch.Tensor: VHSE encoded tensor of shape 216 | (number of sequences, padding or max sequence length, 8) 217 | """ 218 | # Ensure input is a list 219 | if isinstance(sequences, str): 220 | singular = True 221 | sequences = [sequences] 222 | else: 223 | singular = False 224 | 225 | # Determine maximum sequence length and apply padding 226 | max_sequence_length = max(len(sequence) for sequence in sequences) 227 | padded_length = padding if padding is not None else max_sequence_length 228 | 229 | n_sequences = len(sequences) 230 | vhse_size = 8 # VHSE descriptors have 8 components 231 | 232 | # Initialize output tensor with zeros 233 | tensor = torch.zeros((n_sequences, padded_length, vhse_size)) 234 | 235 | # Encode each sequence 236 | for i, sequence in enumerate(sequences): 237 | if pbar: 238 | pbar.set(i, message="Encoding", detail=f"{i}/{n_sequences} completed...") 239 | for j, aa in enumerate(sequence): 240 | if j >= padded_length: 241 | break 242 | tensor[i, j] = torch.tensor( 243 | vhse_dict.get(aa, [0.5] * vhse_size) 244 | ) # Default for unknown AAs 245 | 246 | # Squeeze output for single sequence input 247 | if singular: 248 | tensor = tensor.squeeze(0) 249 | 250 | return tensor 251 | 252 | 253 | def plot_attention(attention: list, layer: int, head: int, seq: Union[str, list]): 254 | """ 255 | Plot the attention weights for a specific layer and head. 256 | 257 | Args: 258 | attention (list): List of attention weights from the model 259 | layer (int): Index of the layer to visualize 260 | head (int): Index of the head to visualize 261 | seq (str): Input sequence as a list of tokens 262 | """ 263 | 264 | if isinstance(seq, str): 265 | seq = [char for char in seq] 266 | 267 | # Get the attention weights for the specified layer and head 268 | attn_weights = attention[layer][head].detach().cpu().numpy() 269 | 270 | # Create a heatmap using seaborn 271 | plt.figure(figsize=(10, 10)) 272 | sns.heatmap(attn_weights, xticklabels=seq, yticklabels=seq, cmap="viridis") 273 | 274 | # Set plot title and labels 275 | plt.title(f"Attention weights - Layer {layer + 1}, Head {head + 1}") 276 | plt.xlabel("Input tokens") 277 | plt.ylabel("Output tokens") 278 | 279 | # Show the plot 280 | plt.show() 281 | 282 | 283 | class GP(gpytorch.models.ExactGP): 284 | def __init__( 285 | self, train_x, train_y, likelihood, fix_mean=False 286 | ): # special method: instantiate object 287 | super(GP, self).__init__(train_x, train_y, likelihood) 288 | self.mean_module = gpytorch.means.ConstantMean() # attribute 289 | self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) 290 | self.mean_module.constant.data.fill_(1) # Set the mean value to 1 291 | if fix_mean: 292 | self.mean_module.constant.requires_grad_(False) 293 | 294 | def forward(self, x): 295 | mean_x = self.mean_module(x) 296 | covar_x = self.covar_module(x) 297 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 298 | 299 | 300 | def predict_gp(model, likelihood, X): 301 | model.eval() 302 | likelihood.eval() 303 | 304 | with torch.no_grad(): 305 | predictions = likelihood(model(X)) 306 | y_pred = predictions.mean 307 | y_std = predictions.stddev 308 | # lower, upper = predictions.confidence_region() 309 | 310 | return y_pred, y_std 311 | 312 | 313 | def computeR2(y_true, y_pred): 314 | """ 315 | Compute R2-values for to torch tensors. 316 | 317 | Args: 318 | y_true (torch.Tensor): true y-values 319 | y_pred (torch.Tensor): predicted y-values 320 | """ 321 | # Ensure the tensors are 1-dimensional 322 | if y_true.dim() != 1 or y_pred.dim() != 1: 323 | raise ValueError("Both y_true and y_pred must be 1-dimensional tensors") 324 | 325 | # Compute the mean of true values 326 | y_mean = torch.mean(y_true) 327 | 328 | # Compute the total sum of squares (SS_tot) 329 | ss_tot = torch.sum((y_true - y_mean) ** 2) 330 | 331 | # Compute the residual sum of squares (SS_res) 332 | ss_res = torch.sum((y_true - y_pred) ** 2) 333 | 334 | # Compute the R² value 335 | r2 = 1 - (ss_res / ss_tot) 336 | 337 | return r2.item() # Convert tensor to float 338 | -------------------------------------------------------------------------------- /src/proteusAI/struc/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for structures. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.struc.struc import * # noqa: F403 12 | -------------------------------------------------------------------------------- /src/proteusAI/visual_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is part of the proteusAI package and is distributed 2 | # under the MIT License. 3 | 4 | """ 5 | A subpackage for sklearn_tools activity_prediction. 6 | """ 7 | 8 | __name__ = "proteusAI" 9 | __author__ = "Jonathan Funk" 10 | 11 | from proteusAI.visual_tools.plots import * # noqa: F403 12 | -------------------------------------------------------------------------------- /src/proteusAI/visual_tools/plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from typing import Union, List 5 | import pandas as pd 6 | from sklearn.manifold import TSNE 7 | from sklearn.decomposition import PCA 8 | import numpy as np 9 | import umap 10 | 11 | matplotlib.use("Agg") 12 | representation_dict = { 13 | "One-hot": "ohe", 14 | "BLOSUM50": "blosum50", 15 | "BLOSUM62": "blosum62", 16 | "ESM-2": "esm2", 17 | "ESM-1v": "esm1v", 18 | } 19 | 20 | 21 | def plot_predictions_vs_groundtruth( 22 | y_true: list, 23 | y_pred: list, 24 | title: Union[str, None] = None, 25 | x_label: Union[str, None] = None, 26 | y_label: Union[str, None] = None, 27 | plot_grid: bool = True, 28 | file: Union[str, None] = None, 29 | show_plot: bool = True, 30 | width: Union[float, None] = None, 31 | ): 32 | # Create the plot 33 | fig, ax = plt.subplots(figsize=(10, 5)) 34 | sns.scatterplot(x=y_true, y=y_pred, alpha=0.5, ax=ax) 35 | 36 | # Set plot title and labels 37 | if title is None: 38 | title = "Predicted vs. True y-values" 39 | if y_label is None: 40 | y_label = "predicted y" 41 | if x_label is None: 42 | x_label = "y" 43 | 44 | ax.set_title(title) 45 | ax.set_xlabel(x_label) 46 | ax.set_ylabel(y_label) 47 | 48 | # Add the diagonal line 49 | min_val = min(min(y_true) * 1.05, min(y_pred) * 1.05) 50 | max_val = max(max(y_true) * 1.05, max(y_pred) * 1.05) 51 | ax.plot( 52 | [min_val, max_val], 53 | [min_val, max_val], 54 | color="grey", 55 | linestyle="dotted", 56 | linewidth=2, 57 | ) 58 | 59 | # Add vertical error bars and confidence region if width is specified 60 | if width is not None: 61 | # Add error bars with T-shape 62 | for true, pred in zip(y_true, y_pred): 63 | ax.plot( 64 | [true, true], [pred - width, pred + width], color="darkgrey", alpha=0.8 65 | ) 66 | ax.plot( 67 | [ 68 | true - 0.005 * (max_val - min_val), 69 | true + 0.005 * (max_val - min_val), 70 | ], 71 | [pred - width, pred - width], 72 | color="darkgrey", 73 | alpha=0.8, 74 | ) 75 | ax.plot( 76 | [ 77 | true - 0.005 * (max_val - min_val), 78 | true + 0.005 * (max_val - min_val), 79 | ], 80 | [pred + width, pred + width], 81 | color="darkgrey", 82 | alpha=0.8, 83 | ) 84 | 85 | # Add confidence region 86 | x_range = np.linspace(min_val, max_val, 100) 87 | upper_conf = x_range + width 88 | lower_conf = x_range - width 89 | ax.fill_between( 90 | x_range, 91 | lower_conf, 92 | upper_conf, 93 | color="blue", 94 | alpha=0.08, 95 | label="Confidence Region", 96 | ) 97 | 98 | # Adjust layout to remove extra whitespace 99 | fig.tight_layout() 100 | 101 | # Add grid if specified 102 | ax.set_xlim(min_val, max_val) 103 | ax.grid(plot_grid) 104 | 105 | # Save the plot to a file 106 | if file is not None: 107 | fig.savefig(file) 108 | 109 | # Show the plot 110 | if show_plot: 111 | plt.show() 112 | 113 | # Return the figure and axes 114 | return fig, ax 115 | 116 | 117 | def plot_tsne( 118 | x: List[np.ndarray], 119 | y: Union[List[Union[float, str]], None] = None, 120 | y_upper: Union[float, None] = None, 121 | y_lower: Union[float, None] = None, 122 | names: Union[List[str], None] = None, 123 | y_type: str = "num", 124 | random_state: int = 42, 125 | rep_type: Union[str, None] = None, 126 | highlight_mask: Union[List[Union[int, float]], None] = None, 127 | highlight_label: str = "Highlighted", 128 | df: Union[pd.DataFrame, None] = None, 129 | ): 130 | """ 131 | Create a t-SNE plot and optionally color by y values, with special coloring for points outside given thresholds. 132 | Handles cases where y is None or a list of Nones by not applying hue. 133 | Optionally highlights points based on the highlight_mask. 134 | 135 | Args: 136 | x (List[np.ndarray]): List of sequence representations as numpy arrays. 137 | y (List[Union[float, str]]): List of y values, can be None or contain None. 138 | y_upper (float): Upper threshold for special coloring. 139 | y_lower (float): Lower threshold for special coloring. 140 | names (List[str]): List of names for each point. 141 | y_type (str): 'class' for categorical labels or 'num' for numerical labels. 142 | random_state (int): Random state. 143 | rep_type (str): Representation type used for plotting. 144 | highlight_mask (List[Union[int, float]]): List of mask values, non-zero points will be highlighted. 145 | highlight_label (str): Text for the legend entry of highlighted points. 146 | df (pd.DataFrame): DataFrame containing the data to plot. 147 | """ 148 | fig, ax = plt.subplots(figsize=(10, 5)) 149 | 150 | x = np.array([t.numpy() if hasattr(t, "numpy") else t for t in x]) 151 | 152 | if len(x.shape) == 3: # Flatten if necessary 153 | x = x.reshape(x.shape[0], -1) 154 | 155 | if df is None: 156 | tsne = TSNE(n_components=2, verbose=1, random_state=random_state) 157 | z = tsne.fit_transform(x) 158 | 159 | df = pd.DataFrame(z, columns=["z1", "z2"]) 160 | df["y"] = y if y is not None and any(y) else None # Use y if it's informative 161 | if names and len(names) == len(y): 162 | df["names"] = names 163 | else: 164 | df["names"] = [None] * len(y) 165 | 166 | # Handle the palette based on whether y is numerical or categorical 167 | if isinstance(y[0], (int, float)): # If y is numerical 168 | cmap = sns.cubehelix_palette(rot=-0.2, as_cmap=True) 169 | else: # If y is categorical 170 | cmap = sns.color_palette("Set2", as_cmap=False) 171 | 172 | hue = ( 173 | "y" if df["y"].isnull().sum() != len(df["y"]) else None 174 | ) # Use hue only if y is informative 175 | scatter = sns.scatterplot( 176 | x="z1", y="z2", hue=hue, palette=cmap if hue else None, data=df 177 | ) 178 | 179 | # Apply special coloring only if y values are valid and thresholds are provided 180 | if hue and (y_upper is not None or y_lower is not None): 181 | outlier_mask = ( 182 | (df["y"] > y_upper) 183 | if y_upper is not None 184 | else np.zeros(len(df), dtype=bool) 185 | ) 186 | outlier_mask |= ( 187 | (df["y"] < y_lower) 188 | if y_lower is not None 189 | else np.zeros(len(df), dtype=bool) 190 | ) 191 | scatter.scatter( 192 | df["z1"][outlier_mask], df["z2"][outlier_mask], color="lightgrey" 193 | ) 194 | 195 | # Highlight points based on the highlight_mask 196 | if highlight_mask is not None: 197 | highlight_mask = np.array(highlight_mask) 198 | highlight_points = highlight_mask != 0 # Non-zero entries in the highlight_mask 199 | scatter.scatter( 200 | df["z1"][highlight_points], 201 | df["z2"][highlight_points], 202 | color="red", 203 | marker="x", 204 | s=60, 205 | alpha=0.7, 206 | label=highlight_label, 207 | ) 208 | 209 | scatter.set_title(f"t-SNE projection of {rep_type if rep_type else 'data'}") 210 | 211 | # Add the legend, making sure to include highlighted points 212 | handles, labels = scatter.get_legend_handles_labels() 213 | if highlight_label in labels: 214 | ax.legend(handles, labels, title="Legend") 215 | else: 216 | ax.legend(title="Legend") 217 | 218 | return fig, ax, df 219 | 220 | 221 | def plot_umap( 222 | x: List[np.ndarray], 223 | y: Union[List[Union[float, str]], None] = None, 224 | y_upper: Union[float, None] = None, 225 | y_lower: Union[float, None] = None, 226 | names: Union[List[str], None] = None, 227 | y_type: str = "num", 228 | random_state: int = 42, 229 | rep_type: Union[str, None] = None, 230 | highlight_mask: Union[List[Union[int, float]], None] = None, 231 | highlight_label: str = "Highlighted", 232 | df: Union[pd.DataFrame, None] = None, 233 | html: bool = False, 234 | ): 235 | """ 236 | Create a UMAP plot and optionally color by y values, with special coloring for points outside given thresholds. 237 | Handles cases where y is None or a list of Nones by not applying hue. 238 | Optionally highlights points based on the highlight_mask. 239 | 240 | Args: 241 | x (List[np.ndarray]): List of sequence representations as numpy arrays. 242 | y (List[Union[float, str]]): List of y values, can be None or contain None. 243 | y_upper (float): Upper threshold for special coloring. 244 | y_lower (float): Lower threshold for special coloring. 245 | names (List[str]): List of names for each point. 246 | y_type (str): 'class' for categorical labels or 'num' for numerical labels. 247 | random_state (int): Random state. 248 | rep_type (str): Representation type used for plotting. 249 | highlight_mask (List[Union[int, float]]): List of mask values, non-zero points will be highlighted. 250 | highlight_label (str): Text for the legend entry of highlighted points. 251 | df (pd.DataFrame): DataFrame containing the data to plot. 252 | """ 253 | fig, ax = plt.subplots(figsize=(10, 5)) 254 | 255 | x = np.array([t.numpy() if hasattr(t, "numpy") else t for t in x]) 256 | 257 | if len(x.shape) == 3: # Flatten if necessary 258 | x = x.reshape(x.shape[0], -1) 259 | 260 | if df is None: 261 | umap_model = umap.UMAP( 262 | n_neighbors=70, min_dist=0.0, n_components=2, random_state=random_state 263 | ) 264 | 265 | z = umap_model.fit_transform(x) 266 | df = pd.DataFrame(z, columns=["z1", "z2"]) 267 | df["y"] = y if y is not None and any(y) else None # Use y if it's informative 268 | if names and len(names) is not None and len(names) == len(y): 269 | df["names"] = names 270 | else: 271 | df["names"] = [None] * len(y) 272 | 273 | # Handle the palette based on whether y is numerical or categorical 274 | if isinstance(y[0], (int, float)): # If y is numerical 275 | cmap = sns.cubehelix_palette(rot=-0.2, as_cmap=True) 276 | else: # If y is categorical 277 | cmap = sns.color_palette("Set2", as_cmap=False) 278 | 279 | hue = ( 280 | "y" if df["y"].isnull().sum() != len(df["y"]) else None 281 | ) # Use hue only if y is informative 282 | scatter = sns.scatterplot( 283 | x="z1", y="z2", hue=hue, palette=cmap if hue else None, data=df 284 | ) 285 | 286 | # Apply special coloring only if y values are valid and thresholds are provided 287 | if hue and (y_upper is not None or y_lower is not None): 288 | outlier_mask = ( 289 | (df["y"] > y_upper) 290 | if y_upper is not None 291 | else np.zeros(len(df), dtype=bool) 292 | ) 293 | outlier_mask |= ( 294 | (df["y"] < y_lower) 295 | if y_lower is not None 296 | else np.zeros(len(df), dtype=bool) 297 | ) 298 | scatter.scatter( 299 | df["z1"][outlier_mask], df["z2"][outlier_mask], color="lightgrey" 300 | ) 301 | 302 | # Highlight points based on the highlight_mask 303 | if highlight_mask is not None: 304 | highlight_mask = np.array(highlight_mask) 305 | highlight_points = highlight_mask != 0 # Non-zero entries in the highlight_mask 306 | scatter.scatter( 307 | df["z1"][highlight_points], 308 | df["z2"][highlight_points], 309 | color="red", 310 | marker="x", 311 | s=60, 312 | alpha=0.7, 313 | label=highlight_label, 314 | ) 315 | 316 | scatter.set_title(f"UMAP projection of {rep_type if rep_type else 'data'}") 317 | 318 | # Add the legend, making sure to include highlighted points 319 | handles, labels = scatter.get_legend_handles_labels() 320 | if highlight_label in labels: 321 | ax.legend(handles, labels, title="Legend") 322 | else: 323 | ax.legend(title="Legend") 324 | 325 | return fig, ax, df 326 | 327 | 328 | def plot_pca( 329 | x: List[np.ndarray], 330 | y: Union[List[Union[float, str]], None] = None, 331 | y_upper: Union[float, None] = None, 332 | y_lower: Union[float, None] = None, 333 | names: Union[List[str], None] = None, 334 | y_type: str = "num", 335 | random_state: int = 42, 336 | rep_type: Union[str, None] = None, 337 | highlight_mask: Union[List[Union[int, float]], None] = None, 338 | highlight_label: str = "Highlighted", 339 | df: Union[pd.DataFrame, None] = None, 340 | ): 341 | """ 342 | Create a PCA plot and optionally color by y values, with special coloring for points outside given thresholds. 343 | Handles cases where y is None or a list of Nones by not applying hue. 344 | Optionally highlights points based on the highlight_mask. 345 | 346 | Args: 347 | x (List[np.ndarray]): List of sequence representations as numpy arrays. 348 | y (List[Union[float, str]]): List of y values, can be None or contain None. 349 | y_upper (float): Upper threshold for special coloring. 350 | y_lower (float): Lower threshold for special coloring. 351 | names (List[str]): List of names for each point. 352 | y_type (str): 'class' for categorical labels or 'num' for numerical labels. 353 | random_state (int): Random state. 354 | rep_type (str): Representation type used for plotting. 355 | highlight_mask (List[Union[int, float]]): List of mask values, non-zero points will be highlighted. 356 | highlight_label (str): Text for the legend entry of highlighted points. 357 | df (pd.DataFrame): DataFrame containing the data to plot. 358 | """ 359 | fig, ax = plt.subplots(figsize=(10, 5)) 360 | 361 | x = np.array([t.numpy() if hasattr(t, "numpy") else t for t in x]) 362 | 363 | if len(x.shape) == 3: # Flatten if necessary 364 | x = x.reshape(x.shape[0], -1) 365 | 366 | if df is None: 367 | pca = PCA(n_components=2, random_state=random_state) 368 | z = pca.fit_transform(x) 369 | df = pd.DataFrame(z, columns=["z1", "z2"]) 370 | df["y"] = y if y is not None and any(y) else None # Use y if it's informative 371 | if names and len(names) == len(y): 372 | df["names"] = names 373 | else: 374 | df["names"] = [None] * len(y) 375 | 376 | # Handle the palette based on whether y is numerical or categorical 377 | if isinstance(y[0], (int, float)): # If y is numerical 378 | cmap = sns.cubehelix_palette(rot=-0.2, as_cmap=True) 379 | else: # If y is categorical 380 | cmap = sns.color_palette("Set2", as_cmap=False) 381 | 382 | hue = ( 383 | "y" if df["y"].isnull().sum() != len(df["y"]) else None 384 | ) # Use hue only if y is informative 385 | scatter = sns.scatterplot( 386 | x="z1", y="z2", hue=hue, palette=cmap if hue else None, data=df 387 | ) 388 | 389 | # Apply special coloring only if y values are valid and thresholds are provided 390 | if hue and (y_upper is not None or y_lower is not None): 391 | outlier_mask = ( 392 | (df["y"] > y_upper) 393 | if y_upper is not None 394 | else np.zeros(len(df), dtype=bool) 395 | ) 396 | outlier_mask |= ( 397 | (df["y"] < y_lower) 398 | if y_lower is not None 399 | else np.zeros(len(df), dtype=bool) 400 | ) 401 | scatter.scatter( 402 | df["z1"][outlier_mask], df["z2"][outlier_mask], color="lightgrey" 403 | ) 404 | 405 | # Highlight points based on the highlight_mask 406 | if highlight_mask is not None: 407 | highlight_mask = np.array(highlight_mask) 408 | highlight_points = highlight_mask != 0 # Non-zero entries in the highlight_mask 409 | scatter.scatter( 410 | df["z1"][highlight_points], 411 | df["z2"][highlight_points], 412 | color="red", 413 | marker="x", 414 | s=60, 415 | alpha=0.7, 416 | label=highlight_label, 417 | ) 418 | 419 | scatter.set_title(f"PCA projection of {rep_type if rep_type else 'data'}") 420 | 421 | # Add the legend, making sure to include highlighted points 422 | handles, labels = scatter.get_legend_handles_labels() 423 | if highlight_label in labels: 424 | ax.legend(handles, labels, title="Legend") 425 | else: 426 | ax.legend(title="Legend") 427 | 428 | return fig, ax, df 429 | -------------------------------------------------------------------------------- /tests/test_module.py: -------------------------------------------------------------------------------- 1 | """test python package related functions""" 2 | 3 | 4 | def test_load_package(): 5 | import proteusAI 6 | 7 | proteusAI.__version__ 8 | --------------------------------------------------------------------------------