├── .amltignore ├── .github └── workflows │ └── codeql.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── amlt_config.yml ├── amlt_config_ablations.yml ├── bin ├── calculate_supervised_metrics.py ├── create_test_sets.py ├── download_EC_annotations.py ├── download_GO_annotations.py ├── download_and_test_multiple_proteinfer_seeds.sh ├── download_and_test_proteinfer_seeds.py ├── download_swissprot.py ├── export_proteinfer.py ├── generate_label_embeddings.py ├── main.py ├── make_dataset_from_swissprot.py ├── make_proteinfer_dataset.py ├── make_zero_shot_datasets_from_proteinfer.py ├── run_baseline.py ├── run_blast.py ├── test_ablation.sh ├── test_models.py ├── test_proteinfer.py ├── umap_plots.py ├── update_go_annotations.py └── upload_to_zenodo.py ├── configs └── base_config.yaml ├── docs └── .gitkeep ├── environment.yml ├── hyperdrive_seed_replicates.yml ├── img ├── main_fig.jpg └── main_fig.pdf ├── model_card.md ├── notebooks ├── GOAnnotationsPerYear.ipynb ├── Gene Ontology Data Archive.html └── Results.ipynb ├── proteinfer_conda_requirements.yml ├── protnote ├── .gitkeep ├── __init__.py ├── data │ ├── __init__.py │ ├── collators.py │ ├── datasets.py │ └── samplers.py ├── models │ ├── ProtNote.py │ ├── ProtNoteTrainer.py │ ├── __init__.py │ ├── blast.py │ └── protein_encoders.py └── utils │ ├── __init__.py │ ├── configs.py │ ├── data.py │ ├── evaluation.py │ ├── losses.py │ ├── main_utils.py │ ├── models.py │ ├── notebooks.py │ └── proteinfer.py ├── pyproject.toml ├── setup.py └── singularity_image ├── Dockerfile └── build_image.sh /.amltignore: -------------------------------------------------------------------------------- 1 | # Ignore all files and folders 2 | /* 3 | 4 | # But do not ignore the src folder and its contents, the config folder, and main.py 5 | !/src/ 6 | !*.py 7 | !/configs/ 8 | !/main.py -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: 
-------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL Advanced" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | branches: [ "main" ] 19 | schedule: 20 | - cron: '39 1 * * 6' 21 | 22 | jobs: 23 | analyze: 24 | name: Analyze (${{ matrix.language }}) 25 | # Runner size impacts CodeQL analysis time. To learn more, please see: 26 | # - https://gh.io/recommended-hardware-resources-for-running-codeql 27 | # - https://gh.io/supported-runners-and-hardware-resources 28 | # - https://gh.io/using-larger-runners (GitHub.com only) 29 | # Consider using larger runners or machines with greater resources for possible analysis time improvements. 
30 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 31 | permissions: 32 | # required for all workflows 33 | security-events: write 34 | 35 | # required to fetch internal or private CodeQL packs 36 | packages: read 37 | 38 | # only required for workflows in private repositories 39 | actions: read 40 | contents: read 41 | 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | include: 46 | - language: python 47 | build-mode: none 48 | # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' 49 | # Use `c-cpp` to analyze code written in C, C++ or both 50 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both 51 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both 52 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, 53 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. 54 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how 55 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages 56 | steps: 57 | - name: Checkout repository 58 | uses: actions/checkout@v4 59 | 60 | # Initializes the CodeQL tools for scanning. 61 | - name: Initialize CodeQL 62 | uses: github/codeql-action/init@v3 63 | with: 64 | languages: ${{ matrix.language }} 65 | build-mode: ${{ matrix.build-mode }} 66 | # If you wish to specify custom queries, you can do so here or in a config file. 67 | # By default, queries listed here will override any specified in a config file. 68 | # Prefix the list here with "+" to use these queries and those in the config file. 
69 | 70 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 71 | # queries: security-extended,security-and-quality 72 | 73 | # If the analyze step fails for one of the languages you are analyzing with 74 | # "We were unable to automatically build your code", modify the matrix above 75 | # to set the build mode to "manual" for that language. Then modify this step 76 | # to build your code. 77 | # ℹ️ Command-line programs to run using the OS shell. 78 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 79 | - if: matrix.build-mode == 'manual' 80 | shell: bash 81 | run: | 82 | echo 'If you are using a "manual" build mode for one or more of the' \ 83 | 'languages you are analyzing, replace this with the commands to build' \ 84 | 'your code, for example:' 85 | echo ' make bootstrap' 86 | echo ' make release' 87 | exit 1 88 | 89 | - name: Perform CodeQL Analysis 90 | uses: github/codeql-action/analyze@v3 91 | with: 92 | category: "/language:${{matrix.language}}" 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | /data/ 80 | /models/ 81 | /proteinfer/export/ 82 | /proteinfer/cached_models/ 83 | 84 | # Mac OS-specific storage files 85 | .DS_Store 86 | 87 | # vim 88 | *.swp 89 | *.swo 90 | 91 | # Mypy cache 92 | .mypy_cache/ 93 | 94 | # wandb 95 | /wandb/ 96 | 97 | # results 98 | /outputs/ 99 | 100 | # amulet 101 | /amlt/ 102 | .amltconfig 103 | 104 | # runs 105 | /runs/ 106 | 107 | ``` -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. 
Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /bin/calculate_supervised_metrics.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import obonet 3 | import gc 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | import argparse 7 | import matplotlib.pyplot as plt 8 | from protnote.utils.notebooks import * 9 | from protnote.utils.configs import load_config, construct_absolute_paths 10 | 11 | 12 | 13 | def main(): 14 | seeds = [12, 22, 32, 42, 52] 15 | pinf_model_ids = [13703706, 13703742, 13703997, 13704131, 13705631] 16 | 17 | # Load the configuration and project root 18 | config, project_root = load_config() 19 | results_dir = config["paths"]["output_paths"]["RESULTS_DIR"] 20 | 21 | # Argument parser setup 22 | parser = argparse.ArgumentParser( 23 | description=""" 24 | Calculate the metrics for protnote,proteinfer and blast on the same test set given predictions for each. 25 | It's design for multiple predictions per model (except blast), one for each seed. 
26 | """ 27 | ) 28 | parser.add_argument( 29 | "--proteinfer-prediction-files", 30 | nargs="+", 31 | default=[f'test_logits_GO_TEST_DATA_PATH_proteinfer{id}.h5' for id in pinf_model_ids], 32 | required=False, 33 | help="List of Proteinfer prediction files, potentially from different seeds", 34 | ) 35 | parser.add_argument( 36 | "--protnote-prediction-files", 37 | nargs="+", 38 | default=[f'test_1_logits_TEST_DATA_PATH_seed_replicates_v9_{seed}_sum_last_epoch.h5' for seed in seeds], 39 | required=False, 40 | help="List of ProtNote prediction files, potentially from different seeds", 41 | ) 42 | 43 | parser.add_argument( 44 | "--blast-predictions-file", 45 | type=str, 46 | default="blast_pivot_parsed_test_GO_train_GO_results_fixed.parquet", 47 | required=False, 48 | help="file with blast predictions", 49 | ) 50 | 51 | parser.add_argument( 52 | "--output-file", 53 | type=str, 54 | default='supervised_metrics_df.parquet', 55 | required=False, 56 | help="the file that will store all the metrics", 57 | ) 58 | 59 | parser.add_argument( 60 | "--go-graph-file", 61 | type=str, 62 | default="go_2019-07-01.obo", 63 | required=False, 64 | help="the file of the appropriate go graph. 
The version/date of the go graph should match the date of the test set", 65 | ) 66 | 67 | parser.add_argument( 68 | "--test-set-name", 69 | type=str, 70 | default="GO 2019 Supervised", 71 | required=False, 72 | help="The name of the test set used to evaluate the predictions", 73 | ) 74 | 75 | 76 | args = parser.parse_args() 77 | 78 | args.proteinfer_prediction_files = construct_absolute_paths(results_dir,args.proteinfer_prediction_files) 79 | args.protnote_prediction_files = construct_absolute_paths(results_dir,args.protnote_prediction_files) 80 | args.blast_predictions_file = construct_absolute_paths(results_dir,[args.blast_predictions_file]) 81 | args.output_file = results_dir / args.output_file 82 | graph_2019 = obonet.read_obo(project_root / "data" / "annotations" / args.go_graph_file) 83 | threshold = 0.5 84 | device = "cpu" 85 | 86 | ontology2alias = { 87 | "molecular_function": "MF", 88 | "biological_process": "BP", 89 | "cellular_component": "CC", 90 | "All": "All", 91 | } 92 | 93 | models_logits = { 94 | "Proteinfer": args.proteinfer_prediction_files, 95 | "ProtNote": args.protnote_prediction_files, 96 | "baseline_blast": args.blast_predictions_file 97 | } 98 | 99 | models_labels = pd.read_hdf(str(args.protnote_prediction_files[0]).replace('logits','labels'), key="labels_df", mode="r").astype("float32") 100 | 101 | metrics_df = [] 102 | models = list(models_logits.keys()) 103 | for model in tqdm(models): 104 | print(model) 105 | for file in models_logits[model]: 106 | print(file) 107 | if str(file).endswith(".parquet"): 108 | logits_df = pd.read_parquet(file) 109 | elif str(file).endswith(".h5"): 110 | logits_df = pd.read_hdf(file, key="logits_df", mode="r").astype( 111 | "float32" 112 | ) 113 | print("read logits of shape: ", logits_df.shape) 114 | 115 | metrics = metrics_by_go_ontology( 116 | logits_df, models_labels, graph_2019, device, threshold 117 | ) 118 | 119 | metrics = pd.DataFrame(metrics) 120 | metrics["model"] = model 121 | 
metrics["test_name"] = args.test_set_name 122 | metrics.index.name = "metric" 123 | metrics = metrics.set_index(["model", "test_name"], append=True) 124 | metrics_df.append(metrics) 125 | 126 | del logits_df 127 | gc.collect() 128 | print('-----------------') 129 | 130 | metrics_df = pd.concat(metrics_df) 131 | metrics_df.columns = metrics_df.columns.map(ontology2alias) 132 | 133 | metrics_df.to_parquet(args.output_file) 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /bin/create_test_sets.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from protnote.utils.notebooks import get_data_distributions 3 | from protnote.utils.data import read_fasta, save_to_fasta 4 | import pandas as pd 5 | from protnote.utils.configs import load_config, get_logger 6 | 7 | logger = get_logger() 8 | 9 | 10 | 11 | def create_smaller_test_sets(df,output_dir): 12 | # Create smaller datasets to measure BLAST runtime 13 | for n in [1, 10, 100, 1000, 5000, 10_000, 20_000]: 14 | temp = df.query('split=="test"').sample(n=n, random_state=42) 15 | temp = [ 16 | (row["sequence"], row["id"], row["labels"].split()) 17 | for _, row in temp.iterrows() 18 | ] 19 | save_to_fasta( 20 | temp, 21 | output_file=str(output_dir / f"test_{n}_GO.fasta"), 22 | ) 23 | 24 | 25 | 26 | def create_top_labels_test_set(df, output_path, go_term_distribution, k=3000): 27 | # Create top labels df for embeddings analysis 28 | go_term_distribution_p = go_term_distribution / len(df[df["split"] == "train"]) 29 | top_k_labels = set(go_term_distribution_p.iloc[:k].index) 30 | test_df = df[df["split"] == "test"] 31 | test_df_top_k = test_df[ 32 | test_df["labels"].str.split().apply(lambda x: set(x).issubset(top_k_labels)) 33 | ] 34 | print( 35 | len(test_df_top_k) / len(test_df), 36 | go_term_distribution_p.iloc[:k].max(), 37 | go_term_distribution_p.iloc[:k].min(), 38 | ) 39 | 
test_df_top_k = [ 40 | (row["sequence"], row["id"], row["labels"].split()) 41 | for _, row in test_df_top_k.drop_duplicates(subset=["sequence"]) 42 | .sample(frac=0.1, random_state=42) 43 | .iterrows() 44 | ] 45 | save_to_fasta(test_df_top_k,str(output_path)) 46 | 47 | 48 | if __name__ == "__main__": 49 | 50 | 51 | # Load the configuration and project root 52 | config, project_root = load_config() 53 | results_dir = config["paths"]["output_paths"]["RESULTS_DIR"] 54 | 55 | #Create fake train,val,test sets for hparam tuning 56 | logger.info("Creating fake train, val, test sets for hyperparameter tuning...") 57 | subprocess.run(["python", "bin/make_zero_shot_datasets_from_proteinfer.py"],check=True,shell=False) 58 | logger.info("Done.") 59 | 60 | # Updated Supervised Test Set. pinf test seqs + new labels + new vocab 61 | logger.info("Updated Supervised Test Set. ProteInfer test seqs with new labels & new vocab...") 62 | subprocess.run( 63 | [ 64 | "python", 65 | "bin/make_dataset_from_swissprot.py", 66 | "--latest-swissprot-file", 67 | config["paths"]['data_paths']['LATEST_SWISSPROT_DATA_PATH'], 68 | "--output-file-path", 69 | config["paths"]['data_paths']['TEST_2024_DATA_PATH'], 70 | "--sequence-vocabulary", 71 | "proteinfer_test", 72 | "--label-vocabulary", 73 | "all", 74 | "--parenthood-file", 75 | "parenthood_jul_2024.json" 76 | ], 77 | check=True, 78 | shell=False, 79 | ) 80 | logger.info("Done.") 81 | 82 | # Updated Supervised Test Set. pinf test seqs + new labels + pinf/old vocab 83 | logger.info("Updated Supervised Test Set. 
ProteInfer test seqs with new labels & ProteInfer/old vocab...") 84 | subprocess.run( 85 | [ 86 | "python", 87 | "bin/make_dataset_from_swissprot.py", 88 | "--latest-swissprot-file", 89 | config["paths"]['data_paths']['LATEST_SWISSPROT_DATA_PATH'], 90 | "--output-file-path", 91 | config["paths"]['data_paths']['TEST_2024_PINF_VOCAB_DATA_PATH'], 92 | "--sequence-vocabulary", 93 | "proteinfer_test", 94 | "--label-vocabulary", 95 | "proteinfer", 96 | "--parenthood-file", 97 | "parenthood_jul_2024.json" 98 | ], 99 | check=True, 100 | shell=False, 101 | ) 102 | logger.info("Done.") 103 | 104 | # GO Zero Shot 2024 Leaf Nodes. new seqs + new labels only leaf nodes + only added vocab terms 105 | logger.info("GO Zero Shot 2024 Leaf Nodes. New seqs with new labels only leaf nodes & only added vocab terms") 106 | subprocess.run( 107 | [ 108 | "python", 109 | "bin/make_dataset_from_swissprot.py", 110 | "--latest-swissprot-file", 111 | config["paths"]['data_paths']['LATEST_SWISSPROT_DATA_PATH'], 112 | "--output-file-path", 113 | config["paths"]['data_paths']['TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES'], 114 | "--sequence-vocabulary", 115 | "new", 116 | "--only-leaf-nodes", 117 | "--label-vocabulary", 118 | "new", 119 | "--parenthood-file", 120 | "parenthood_jul_2024.json" 121 | ], 122 | check=True, 123 | shell=False, 124 | ) 125 | logger.info("Done.") 126 | 127 | # GO Zero Shot 2024 new seqs + new labels + only added vocab terms 128 | logger.info("GO Zero Shot 2024 new seqs with new labels & only added vocab terms") 129 | subprocess.run( 130 | [ 131 | "python", 132 | "bin/make_dataset_from_swissprot.py", 133 | "--latest-swissprot-file", 134 | config["paths"]['data_paths']['LATEST_SWISSPROT_DATA_PATH'], 135 | "--output-file-path", 136 | config["paths"]['data_paths']['TEST_DATA_PATH_ZERO_SHOT'], 137 | "--sequence-vocabulary", 138 | "new", 139 | "--label-vocabulary", 140 | "new", 141 | "--parenthood-file", 142 | "parenthood_jul_2024.json" 143 | ], 144 | check=True, 145 | shell=False, 146 
| ) 147 | logger.info("Done.") 148 | 149 | # Updated Supervised Train Set. pinf train seqs + all new & old labels. 150 | logger.info("Updated Supervised Train Set. ProteInfer train seqs with all new & old labels.") 151 | subprocess.run( 152 | [ 153 | "python", 154 | "bin/make_dataset_from_swissprot.py", 155 | "--latest-swissprot-file", 156 | config["paths"]['data_paths']['LATEST_SWISSPROT_DATA_PATH'], 157 | "--output-file-path", 158 | config["paths"]['data_paths']['TRAIN_2024_DATA_PATH'], 159 | "--sequence-vocabulary", 160 | "proteinfer_train", 161 | "--label-vocabulary", 162 | "all", 163 | "--parenthood-file", 164 | "parenthood_jul_2024.json" 165 | ], 166 | check=True, 167 | shell=False, 168 | ) 169 | logger.info("Done.") 170 | 171 | train = read_fasta(config["paths"]['data_paths']['TRAIN_DATA_PATH']) 172 | test = read_fasta(config["paths"]['data_paths']['TEST_DATA_PATH']) 173 | 174 | for dataset in ["train", "test"]: 175 | exec( 176 | f'{dataset} = [(seq,id," ".join(labels),"{dataset}") for seq,id,labels in {dataset}]' 177 | ) 178 | 179 | df = train + test 180 | df = pd.DataFrame(df, columns=["sequence", "id", "labels", "split"]) 181 | 182 | vocab, amino_freq, labels = get_data_distributions(df[df["split"] == "train"]) 183 | go_term_distribution = pd.Series(labels.values(), index=labels.keys()).sort_values( 184 | ascending=False 185 | ) 186 | 187 | logger.info("Create smaller test sets for BLAST runtime calculation...") 188 | create_smaller_test_sets(df = df, 189 | output_dir = project_root / "data" / "swissprot" / "proteinfer_splits" / "random" 190 | ) 191 | logger.info("Done.") 192 | 193 | logger.info("Create top labels test set for embeddings analysis.") 194 | create_top_labels_test_set(df = df, 195 | output_path= config["paths"]['data_paths']['TEST_TOP_LABELS_DATA_PATH'], 196 | go_term_distribution=go_term_distribution 197 | ) 198 | logger.info("Done.") -------------------------------------------------------------------------------- 
/bin/download_EC_annotations.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import wget 3 | import os 4 | from protnote.utils.data import ( 5 | save_to_pickle, 6 | get_ec_class_descriptions, 7 | get_ec_number_description, 8 | ) 9 | from protnote.utils.configs import get_project_root 10 | 11 | def main(): 12 | output_dir = get_project_root() / 'data' / 'annotations' 13 | ec_classes_data = "https://ftp.expasy.org/databases/enzyme/enzclass.txt" 14 | ec_numbers_data = "https://ftp.expasy.org/databases/enzyme/enzyme.dat" 15 | enzclass_output_path = output_dir / "enzclass.txt" 16 | enzyme_dat_output_path = output_dir / "enzyme.dat" 17 | annotations_output_path = output_dir / "ec_annotations.pkl" 18 | if not os.path.exists(output_dir): 19 | os.makedirs(output_dir) 20 | 21 | wget.download(ec_classes_data, out=str(enzclass_output_path)) 22 | wget.download(ec_numbers_data, out=str(enzyme_dat_output_path)) 23 | 24 | ec_classes = get_ec_class_descriptions(enzclass_output_path) 25 | ec_numbers = get_ec_number_description(enzyme_dat_output_path, ec_classes) 26 | 27 | ec_annotations = pd.concat( 28 | [ 29 | pd.DataFrame.from_records(list(ec_classes.values())), 30 | pd.DataFrame.from_records(ec_numbers), 31 | ], 32 | axis=0, 33 | )[["ec_number", "label"]] 34 | 35 | ec_annotations["ec_number"] = "EC:" + ec_annotations["ec_number"] 36 | ec_annotations.set_index("ec_number", inplace=True) 37 | 38 | ec_annotations.index.name = None 39 | ec_annotations["name"] = ec_annotations["synonym_exact"] = ec_annotations["label"] 40 | ec_annotations["synonym_exact"] = ec_annotations["synonym_exact"].apply( 41 | lambda x: [x] 42 | ) 43 | 44 | save_to_pickle(ec_annotations, annotations_output_path) 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /bin/download_GO_annotations.py: 
-------------------------------------------------------------------------------- 1 | import obonet 2 | import pandas as pd 3 | import argparse 4 | import logging 5 | import re 6 | import os 7 | import numpy as np 8 | from protnote.utils.configs import get_project_root 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | 13 | def calculate_label(row): 14 | """ 15 | Helper function to calculate the label for a given row. 16 | Returns the definition of the row with any text between brackets removed. 17 | """ 18 | definition = row.get("def", None) 19 | 20 | # Remove any text between brackets, e.g., PubMed citations 21 | # Remove leading and trailing quotation marks 22 | if definition is not None: 23 | definition = re.sub(r"\s*\[.*?\]\s*", "", definition) 24 | definition = definition.strip('"') 25 | 26 | return definition 27 | 28 | 29 | def process_synonyms(row) -> dict: 30 | """extracts the synonyms of a GO Annotation 31 | 32 | :param row: Row of GO annotation dataset 33 | :type row: _type_ 34 | :return: dict 35 | :rtype: lists of synonyms for relevant scopes 36 | """ 37 | if row is np.nan or not row: 38 | return { 39 | "synonym_exact": [], 40 | "synonym_narrow": [], 41 | "synonym_related": [], 42 | "synonym_broad": [], 43 | } 44 | 45 | scopes = {"EXACT": [], "NARROW": [], "RELATED": [], "BROAD": []} 46 | for synonym in row: 47 | match = re.search(r"\"(.+?)\"\s+(EXACT|NARROW|RELATED|BROAD)\s+\[", synonym) 48 | if match: 49 | text, scope = match.groups() 50 | scopes[scope].append(text) 51 | 52 | return { 53 | "synonym_exact": scopes["EXACT"], 54 | "synonym_narrow": scopes["NARROW"], 55 | "synonym_related": scopes["RELATED"], 56 | "synonym_broad": scopes["BROAD"], 57 | } 58 | 59 | 60 | def main(url: str, output_file: str): 61 | """ 62 | Download the OBO file from the specified URL and save the GO ID and label to a pickle. 
63 | """ 64 | logging.info("Downloading and processing OBO file...") 65 | 66 | output_file = get_project_root() / 'data' / 'annotations' / output_file 67 | os.makedirs(get_project_root() / 'data' / 'annotations', exist_ok=True) 68 | # Load the .obo file directly from the URL into a networkx graph using obonet 69 | 70 | # Download from url if latest, otherwise download file with wget and read file with obonet 71 | if url == "latest": 72 | file = "https://purl.obolibrary.org/obo/go.obo" 73 | else: 74 | # Download the file, extract the date from the url and use date as suffix for temp .obo file 75 | date = re.search(r"\d{4}-\d{2}-\d{2}", url).group() 76 | file = get_project_root() / 'data' / 'annotations' / f"go_{date}.obo" 77 | os.system(f"wget {url} -O {file}") 78 | 79 | graph = obonet.read_obo(file, ignore_obsolete=False) 80 | 81 | # Convert the graph nodes (terms) into a pandas dataframe 82 | df = pd.DataFrame.from_dict(dict(graph.nodes(data=True)), orient="index") 83 | 84 | logging.info("Calculating labels...") 85 | # Create a new column called "label" 86 | df["label"] = df.apply(calculate_label, axis=1) 87 | 88 | # Extract synonyms to augment dataset 89 | df_synonyms = df["synonym"].apply(process_synonyms) 90 | df_synonyms = pd.DataFrame(df_synonyms.tolist(), index=df.index) 91 | 92 | # Merge the new columns back to the original DataFrame with the same index 93 | df = pd.concat([df, df_synonyms], axis=1) 94 | 95 | # Filter the dataframe to retain only 'label', 'name' and 'synonym' columns, with the 'id' column as the index 96 | df_filtered = df[["label", "name"] + list(df_synonyms.columns) + ["is_obsolete"]] 97 | 98 | # Save the filtered dataframe as a pickle 99 | df_filtered.to_pickle(output_file) 100 | 101 | logging.info(f"Saved filtered dataframe as a pickle to {output_file}") 102 | 103 | 104 | if __name__ == "__main__": 105 | """ 106 | 2019 Option: python download_GO_annotations.py http://release.geneontology.org/2019-07-01/ontology/go.obo 
go_annotations_2019_07_01.pkl 107 | 2023 Option: python download_GO_annotations.py http://release.geneontology.org/2023-07-27/ontology/go.obo go_annotations_2023_07_27.pkl 108 | """ 109 | # TODO: Filename can be figured out from URL 110 | parser = argparse.ArgumentParser( 111 | description="Download OBO file and save GO ID and label to a pickle." 112 | ) 113 | parser.add_argument( 114 | "--url", 115 | type=str, 116 | default="http://release.geneontology.org/2023-07-27/ontology/go.obo", 117 | help="URL to the OBO file.", 118 | ) 119 | parser.add_argument( 120 | "--output-file", type=str 121 | ) 122 | args = parser.parse_args() 123 | 124 | main(args.url, args.output_file) 125 | -------------------------------------------------------------------------------- /bin/download_and_test_multiple_proteinfer_seeds.sh: -------------------------------------------------------------------------------- 1 | ids=(13703706 13703742 13703997 13704131 13705631) # 13705668 13705677 13705689 13705708 13705728 ) 2 | eval "$(conda shell.bash hook)" 3 | conda activate proteinfer 4 | for id in "${ids[@]}"; do 5 | if [ ! 
-e "data/models/proteinfer/GO_model_weights${id}.pkl" ]; then 6 | wget https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/models/zipped_models/noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-${id}.tar.gz 7 | tar -xvzf noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-${id}.tar.gz 8 | mv -f noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-${id} proteinfer/cached_models/ 9 | python proteinfer/export_proteinfer.py --model-path proteinfer/cached_models/noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-${id} --model-name GO --add-model-id 10 | mv proteinfer/export/GO_model_weights${id}.pkl data/models/proteinfer/GO_model_weights${id}.pkl 11 | else 12 | echo "${id} weights already exist" 13 | fi 14 | done 15 | conda deactivate 16 | conda activate protein_functions_310 17 | 18 | for id in "${ids[@]}"; do 19 | python test_proteinfer.py --test-paths-names TEST_DATA_PATH --only-inference --only-represented-labels --save-prediction-results --name TEST_DATA_PATH_proteinfer --model-weights-id ${id} 20 | done 21 | 22 | conda deactivate -------------------------------------------------------------------------------- /bin/download_and_test_proteinfer_seeds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | from pathlib import Path 5 | 6 | # Define the list of IDs 7 | def run_command(command): 8 | """Utility function to run a shell command""" 9 | result = subprocess.run(command, shell=True, check=True, text=True) 10 | return result 11 | 12 | def download_and_extract_model(id): 13 | """Download, extract, and export the model if it doesn't exist""" 14 | model_dir = Path(__file__).resolve().parent.parent / "data" / "models" / "proteinfer" 15 | model_file = model_dir / f"GO_model_weights{id}.pkl" 16 | if not os.path.exists(model_file): 17 | tar_file = 
f"noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-{id}.tar.gz" 18 | download_url = f"https://storage.googleapis.com/brain-genomics-public/research/proteins/proteinfer/models/zipped_models/{tar_file}" 19 | 20 | # Download the model 21 | run_command(f"wget {download_url}") 22 | 23 | # Extract the tar file 24 | run_command(f"tar -xvzf {tar_file}") 25 | 26 | # Move extracted model to cached models directory 27 | model_id_dir = f"noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-{id}" 28 | if not os.path.exists(model_dir): 29 | os.makedirs(model_dir) 30 | 31 | run_command(f"mv -f {model_id_dir} {model_dir}") 32 | 33 | # Run the export script 34 | run_command(f"conda run -n proteinfer python bin/export_proteinfer.py --model-path {model_dir / model_id_dir} --output-dir {model_dir} --model-name GO --add-model-id") 35 | 36 | #Clean up. Orginal weights are not needed. 37 | 38 | run_command(f"rm -rf {model_dir / model_id_dir}") 39 | 40 | else: 41 | print(f"{id} weights already exist") 42 | 43 | def run_inference(id): 44 | """Run inference with the specified model weights""" 45 | run_command(f"conda run -n protnote python bin/test_proteinfer.py --test-paths-names TEST_DATA_PATH --only-inference --only-represented-labels --save-prediction-results --name TEST_DATA_PATH_proteinfer --model-weights-id {id}") 46 | 47 | def main(ids:list,get_predictions:bool): 48 | 49 | # Loop over the IDs and handle the models 50 | for id in ids: 51 | download_and_extract_model(id) 52 | 53 | # Run inference on the models 54 | if get_predictions: 55 | for id in ids: 56 | run_inference(id) 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser(description="Run protein inference on multiple models") 61 | parser.add_argument( 62 | "--ids", 63 | nargs="+", 64 | default=[13703706, 13703742, 13703997, 13704131, 13705631], 65 | required=False, 66 | help="List of proteinfer model ids", 67 | ) 68 | parser.add_argument( 69 | "--get-predictions", 70 | 
action="store_true", 71 | default=False, 72 | help="Whether to run inference and store the predictions of the ProteInfer models", 73 | ) 74 | 75 | args = parser.parse_args() 76 | main(ids = args.ids, get_predictions = args.get_predictions) 77 | 78 | -------------------------------------------------------------------------------- /bin/download_swissprot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from protnote.utils.data import download_and_unzip 3 | from protnote.utils.configs import get_project_root 4 | if __name__ == "__main__": 5 | """ 6 | Example usage: 7 | python download_swissprot.py --url "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz" --output-file "uniprot_sprot_latest.dat" 8 | """ 9 | parser = argparse.ArgumentParser( 10 | description="Download and unzip swissprot file from a given link." 11 | ) 12 | parser.add_argument( 13 | "--url", 14 | type=str, 15 | help="The URL to download the file from." 
16 | ) 17 | 18 | parser.add_argument( 19 | "--output-file", 20 | type=str, 21 | help="The relative or absolute path to save the downloaded, unzipped file (not .gz file).", 22 | ) 23 | 24 | args = parser.parse_args() 25 | args.output_file = str(get_project_root() / 'data' / 'swissprot' / args.output_file) 26 | 27 | print("Downloading and unzipping file...") 28 | download_and_unzip(args.url, args.output_file) 29 | -------------------------------------------------------------------------------- /bin/export_proteinfer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | # Construct the path to the proteinfer directory 6 | proteinfer_dir = os.path.join(root_dir, 'proteinfer') 7 | sys.path.append(proteinfer_dir) 8 | import inference 9 | import argparse 10 | import tensorflow as tf 11 | import tensorflow_hub as hub 12 | import pickle 13 | 14 | def export_model_weights( 15 | model_path: str, 16 | model_name: str, 17 | output_dir: str, 18 | add_model_id: bool = False 19 | ): 20 | if not os.path.exists(output_dir): 21 | os.makedirs(output_dir) 22 | 23 | suffix = model_path.split("-")[-1] if add_model_id else "" 24 | output_path = os.path.join( 25 | output_dir, f"{model_name}_model_weights" + suffix + ".pkl" 26 | ) 27 | module_spec = hub.saved_model_module.create_module_spec_from_saved_model(model_path) 28 | 29 | tags = [tf.saved_model.tag_constants.SERVING] 30 | name_scope = "inferrer" 31 | module = hub.Module(module_spec, trainable=True, tags=tags, name=name_scope) 32 | 33 | with tf.Session() as sess: 34 | sess.run([tf.global_variables_initializer(), tf.tables_initializer()]) 35 | 36 | # Fetch the trainable (weights & biases) and non trainable variables (batch norm stats) 37 | all_vars = tf.global_variables() 38 | all_var_values = sess.run(all_vars) 39 | weights_dict = {var.name: value for var, value in 
zip(all_vars, all_var_values)} 40 | 41 | with open(output_path, "wb") as f: 42 | pickle.dump(weights_dict, f, protocol=pickle.HIGHEST_PROTOCOL) 43 | 44 | 45 | def export_proteinfer_vocab( 46 | model_path: str, 47 | model_name: str, 48 | output_dir: str, 49 | vocab_variable_name: str = "label_vocab:0", 50 | add_model_id: bool = False, 51 | ): 52 | suffix = model_path.split("-")[-1] if add_model_id else "" 53 | inferrer = inference.Inferrer( 54 | savedmodel_dir_path=model_path, 55 | use_tqdm=True, 56 | batch_size=16, 57 | activation_type="pooled_representation", 58 | ) 59 | output_path = os.path.join( 60 | output_dir, f"proteinfer_{model_name}_label_vocab" + suffix + ".json" 61 | ) 62 | label_vocab = inferrer.get_variable(vocab_variable_name).astype(str) 63 | with open(output_path, "w") as output_file: 64 | json.dump(label_vocab.tolist(), output_file) 65 | 66 | 67 | if __name__ == "__main__": 68 | """ 69 | example 70 | 71 | python bin/export_proteinfer.py --model-path 'data/models/proteinfer/noxpd2_cnn_swissprot_go_random_swiss-cnn_for_swissprot_go_random-13703706' --model-name GO 72 | 73 | """ 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument( 76 | "--model-path", 77 | required=True, 78 | help="originally stored in cached_models after running install_models.py", 79 | ) 80 | parser.add_argument( 81 | "--output-dir", 82 | required=True, 83 | help="The directory to store the new ProteInfer weights", 84 | ) 85 | parser.add_argument( 86 | "--model-name", 87 | required=True, 88 | help="GO or EC" 89 | ) 90 | parser.add_argument( 91 | "--add-model-id", 92 | action="store_true", 93 | default=False, 94 | required=False 95 | ) 96 | args = parser.parse_args() 97 | 98 | # if os.path.exists('export') 99 | export_proteinfer_vocab( 100 | model_path=args.model_path, 101 | model_name=args.model_name, 102 | add_model_id=args.add_model_id, 103 | output_dir=args.output_dir 104 | ) 105 | export_model_weights( 106 | model_path=args.model_path, 107 | 
model_name=args.model_name, 108 | add_model_id=args.add_model_id, 109 | output_dir=args.output_dir 110 | ) 111 | -------------------------------------------------------------------------------- /bin/generate_label_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import logging 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | from transformers import AutoTokenizer, AutoModel 9 | from protnote.utils.models import generate_label_embeddings_from_text 10 | from protnote.utils.configs import generate_label_embedding_path 11 | from protnote.utils.data import ( 12 | read_yaml, 13 | read_pickle, 14 | remove_obsolete_from_string, 15 | ensure_list, 16 | ) 17 | from protnote.utils.configs import get_project_root 18 | 19 | logging.basicConfig(level=logging.INFO) 20 | 21 | ### SETUP ### 22 | torch.cuda.empty_cache() 23 | 24 | 25 | def main(): 26 | # ---------------------- HANDLE ARGUMENTS ----------------------# 27 | parser = argparse.ArgumentParser(description="Train and/or Test the ProtNote model.") 28 | 29 | parser.add_argument( 30 | "--pooling-method", 31 | type=str, 32 | default="mean", 33 | help="How to pool embeddings. mean, last_token, or all", 34 | ) 35 | 36 | parser.add_argument( 37 | "--base-label-embedding-path", 38 | type=str, 39 | default="GO_BASE_LABEL_EMBEDDING_PATH", 40 | help="the base label embedding path name from config. value must exist in config.", 41 | ) 42 | parser.add_argument( 43 | "--annotations-path-name", 44 | type=str, 45 | default="GO_ANNOTATIONS_2019_UPDATED_PATH", 46 | help="Name of the annotation path. Defaults to GO.", 47 | ) 48 | 49 | parser.add_argument( 50 | "--account-for-sos", 51 | action="store_true", 52 | help="Whether to ignore the SOS token. Set to True for BioGPT and False for E5. 
Doesn't make any difference if pooling method = last_token.", 53 | ) 54 | 55 | parser.add_argument( 56 | "--add-instruction", 57 | action="store_true", 58 | help="Whether to format for instruction tuned model", 59 | ) 60 | 61 | parser.add_argument( 62 | "--label-encoder-checkpoint", 63 | type=str, 64 | default="intfloat/multilingual-e5-large-instruct", 65 | help="The huggingface LLM to use. Others include microsoft/biogpt.", 66 | ) 67 | 68 | def get_detailed_instruct(task_description: str, query: str) -> str: 69 | return f"Instruct: {task_description}\nQuery: {query}" 70 | 71 | args = parser.parse_args() 72 | 73 | ROOT_PATH = get_project_root() 74 | CONFIG = read_yaml(ROOT_PATH / "configs" / "base_config.yaml") 75 | TASK = "Identify the main categories, themes, or topics described in the following Gene Ontology (GO) term, which is used to detail a protein's function" 76 | 77 | # Overwrite config pooling method 78 | CONFIG["params"]["LABEL_EMBEDDING_POOLING_METHOD"] = args.pooling_method 79 | CONFIG["params"]["LABEL_ENCODER_CHECKPOINT"] = args.label_encoder_checkpoint 80 | 81 | DATA_PATH = ROOT_PATH / "data" 82 | OUTPUT_PATH = os.path.join( 83 | DATA_PATH, 84 | generate_label_embedding_path( 85 | params=CONFIG["params"], 86 | base_label_embedding_path=CONFIG["paths"]["data_paths"][ 87 | args.base_label_embedding_path 88 | ], 89 | ), 90 | ) 91 | #Create output path dir if it doesn't exist 92 | os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True) 93 | 94 | INDEX_OUTPUT_PATH = OUTPUT_PATH.split(".") 95 | INDEX_OUTPUT_PATH = ( 96 | "_".join([INDEX_OUTPUT_PATH[0], "index"]) + "." 
+ INDEX_OUTPUT_PATH[1] 97 | ) 98 | 99 | ANNOTATIONS_PATH = os.path.join( 100 | DATA_PATH, CONFIG["paths"]["data_paths"][args.annotations_path_name] 101 | ) 102 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 103 | 104 | logging.info( 105 | f"Pooled embeddings will be saved to {OUTPUT_PATH}\n Pooled embeddings index will be saved to {INDEX_OUTPUT_PATH} \n Using pooling method: {args.pooling_method}" 106 | ) 107 | 108 | descriptions_file = read_pickle(ANNOTATIONS_PATH) 109 | 110 | # Initialize label tokenizer 111 | label_tokenizer = AutoTokenizer.from_pretrained( 112 | args.label_encoder_checkpoint, 113 | ) 114 | # Initialize label encoder 115 | label_encoder = AutoModel.from_pretrained( 116 | args.label_encoder_checkpoint, 117 | ).to(DEVICE) 118 | 119 | logging.info( 120 | "Flattening descriptions for batch processing and calculating sequence token lengths..." 121 | ) 122 | print("SOS = ", args.account_for_sos, args.account_for_sos == False) 123 | embeddings_idx = { 124 | "id": [], 125 | "description_type": [], 126 | "description": [], 127 | "token_count": [], 128 | } 129 | for go_term, desriptions in tqdm( 130 | descriptions_file[["name", "label", "synonym_exact"]].iterrows(), 131 | total=len(descriptions_file), 132 | ): 133 | for desription_type, desription_set in desriptions.items(): 134 | for description in ensure_list(desription_set): 135 | description = remove_obsolete_from_string(description) 136 | 137 | if args.add_instruction: 138 | description = get_detailed_instruct(TASK, description) 139 | 140 | embeddings_idx["description"].append(description) 141 | embeddings_idx["id"].append(go_term) 142 | embeddings_idx["description_type"].append(desription_type) 143 | embeddings_idx["token_count"].append( 144 | len(label_tokenizer.tokenize(description)) 145 | ) # We need the token count for embedding normalization (longer descriptions will have more feature-rich embeddings) 146 | # Remove Obsolete/Deprecated texts 147 | 
logging.info("Extracting embeddings...") 148 | 149 | embeddings = generate_label_embeddings_from_text( 150 | label_annotations=embeddings_idx["description"], 151 | label_tokenizer=label_tokenizer, 152 | label_encoder=label_encoder, 153 | pooling_method=args.pooling_method, 154 | batch_size_limit=CONFIG["params"]["LABEL_BATCH_SIZE_LIMIT_NO_GRAD"], 155 | append_in_cpu=False, 156 | account_for_sos=args.account_for_sos, 157 | ).to("cpu") 158 | 159 | # Convert to indexed pandas df 160 | embeddings_idx = pd.DataFrame(embeddings_idx) 161 | 162 | logging.info("Saving to a torch .pt...") 163 | torch.save(embeddings, OUTPUT_PATH) 164 | torch.save(embeddings_idx, INDEX_OUTPUT_PATH) 165 | logging.info(f"Embeddings saved to {OUTPUT_PATH}") 166 | logging.info(f"Index saved to {INDEX_OUTPUT_PATH}") 167 | 168 | 169 | if __name__ == "__main__": 170 | ''' 171 | Example usage: 172 | 173 | python bin/generate_label_embeddings.py --add-instruction --account-for-sos 174 | python bin/generate_label_embeddings.py --label-encoder-checkpoint microsoft/biogpt --account-for-sos 175 | 176 | python bin/generate_label_embeddings.py --base-label-embedding-path EC_BASE_LABEL_EMBEDDING_PATH --annotations-path-name EC_ANNOTATIONS_PATH --label-encoder-checkpoint microsoft/biogpt --account-for-sos 177 | python bin/generate_label_embeddings.py --base-label-embedding-path EC_BASE_LABEL_EMBEDDING_PATH --annotations-path-name EC_ANNOTATIONS_PATH --add-instruction --account-for-sos 178 | 179 | python bin/generate_label_embeddings.py --base-label-embedding-path GO_2024_BASE_LABEL_EMBEDDING_PATH --annotations-path-name GO_ANNOTATIONS_PATH --add-instruction --account-for-sos 180 | python bin/generate_label_embeddings.py --base-label-embedding-path GO_2024_BASE_LABEL_EMBEDDING_PATH --annotations-path-name GO_ANNOTATIONS_PATH --label-encoder-checkpoint microsoft/biogpt --account-for-sos 181 | 182 | 183 | ''' 184 | main() 185 | -------------------------------------------------------------------------------- 
/bin/make_dataset_from_swissprot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from Bio import SwissProt 4 | from Bio.Seq import Seq 5 | from Bio.SeqRecord import SeqRecord 6 | from Bio import SeqIO 7 | import os 8 | from tqdm import tqdm 9 | from typing import Literal 10 | import collections 11 | from protnote.utils.configs import get_project_root,construct_absolute_paths, load_config, get_logger 12 | from protnote.utils.data import read_json, read_fasta, generate_vocabularies, COMMON_AMINOACIDS 13 | 14 | logger = get_logger() 15 | 16 | def reverse_map(applicable_label_dict, label_vocab=None): 17 | """Flip parenthood dict to map parents to children. 18 | 19 | Args: 20 | applicable_label_dict: e.g. output of get_applicable_label_dict. 21 | label_vocab: e.g. output of inference_lib.vocab_from_model_base_path 22 | 23 | Returns: 24 | collections.defaultdict of k, v where: 25 | k: originally the values in applicable_label_dict 26 | v: originally the keys in applicable_label_dict. 27 | The defaultdict returns an empty frozenset for keys that are not found. 28 | This behavior is desirable for lifted clan label normalizers, where 29 | keys may not imply themselves. 30 | 31 | DIRECTLY TAKEN FROM PROTEINFER SORUCE CODE 32 | """ 33 | # This is technically the entire transitive closure, so it is safe for DAGs 34 | # (e.g. GO labels). 35 | 36 | children = collections.defaultdict(set) 37 | for child, parents in applicable_label_dict.items(): 38 | # Avoid adding children which don't appear in the vocab. 
39 | if label_vocab is None or child in label_vocab: 40 | for parent in parents: 41 | children[parent].add(child) 42 | children = {k: frozenset(v) for k, v in children.items()} 43 | return collections.defaultdict(frozenset, children.items()) 44 | 45 | 46 | def main( 47 | latest_swissprot_file: str, 48 | output_file_path: str, 49 | parenthood_file: str, 50 | label_vocabulary: Literal["proteinfer", "new", "all"], 51 | sequence_vocabulary: Literal["proteinfer_test", "proteinfer_test", "new", "all"], 52 | only_leaf_nodes: bool = False, 53 | cache=True, 54 | ): 55 | # Load the configuration and project root 56 | 57 | # Load the configuration and project root 58 | config, project_root = load_config() 59 | results_dir = config["paths"]["output_paths"]["RESULTS_DIR"] 60 | swissprot_dir = project_root / 'data' / 'swissprot' 61 | annotations_dir = project_root / 'data' / 'annotations' 62 | latest_swissprot_file = swissprot_dir / latest_swissprot_file 63 | 64 | #Create the output directory from output_file_path if it does not exist 65 | os.makedirs(os.path.dirname(output_file_path), exist_ok=True) 66 | 67 | # Extract data from SwissProt records 68 | 69 | # See https://biopython.org/docs/1.75/api/Bio.SwissProt.html and https://web.expasy.org/docs/userman.html 70 | 71 | if not (os.path.exists(swissprot_dir / args.parsed_latest_swissprot_file) & cache): 72 | with open(latest_swissprot_file, "r") as f: 73 | data = [] 74 | 75 | records = SwissProt.parse(f) 76 | 77 | logger.info("Extracting data from SwissProt records... 
This may take a while...") 78 | for record in tqdm(records, total=571609): 79 | # Extract sequence ID 80 | seq_id = record.accessions[0] 81 | 82 | # Extract sequence 83 | sequence = record.sequence 84 | 85 | # Extract GO ids 86 | go_ids = [ 87 | ref[1] 88 | for ref in record.cross_references 89 | if ref[0] == "GO" and len(ref) > 0 90 | ] 91 | 92 | # Extract free-text description 93 | description = record.description 94 | 95 | # Extract organism and organism classification 96 | organism = record.organism 97 | organism_classification = record.organism_classification 98 | 99 | # Extract organelle 100 | organelle = record.organelle 101 | 102 | # Extract CC line as a dictionary 103 | cc = {} 104 | for comment in record.comments: 105 | key, value = comment.split(": ", 1) 106 | cc[key] = value 107 | 108 | data.append( 109 | [ 110 | seq_id, 111 | sequence, 112 | go_ids, 113 | description, 114 | organism, 115 | organism_classification, 116 | organelle, 117 | cc, 118 | ] 119 | ) 120 | 121 | logger.info("Finished extracting data from SwissProt records.") 122 | 123 | # Convert data into a pandas DataFrame and create a new column with the subcellular location 124 | df_latest = pd.DataFrame( 125 | data, 126 | columns=[ 127 | "seq_id", 128 | "sequence", 129 | "go_ids", 130 | "description", 131 | "organism", 132 | "organism_classification", 133 | "organelle", 134 | "cc", 135 | ], 136 | ) 137 | df_latest["subcellular_location"] = df_latest.cc.apply( 138 | lambda x: x.get("SUBCELLULAR LOCATION") 139 | ) 140 | 141 | # Save df_latest to a file 142 | df_latest.to_pickle(swissprot_dir / args.parsed_latest_swissprot_file) 143 | else: 144 | df_latest = pd.read_pickle(swissprot_dir / args.parsed_latest_swissprot_file) 145 | 146 | # Make a set of the GO labels from the label embeddings 147 | label_ids_2019 = set(pd.read_pickle(annotations_dir / "go_annotations_2019_07_01.pkl").index) 148 | annotations_latest = pd.read_pickle(annotations_dir / args.latest_go_annotations_file) 149 | 
pinf_train = read_fasta(config['paths']['data_paths']['TRAIN_DATA_PATH']) 150 | pinf_val = read_fasta(config['paths']['data_paths']['VAL_DATA_PATH']) 151 | pinf_test = read_fasta(config['paths']['data_paths']['TEST_DATA_PATH']) 152 | 153 | label_ids_2024 = set(annotations_latest.index) 154 | 155 | # Find added labels 156 | new_go_labels = label_ids_2024 - label_ids_2019 157 | 158 | parenthood = read_json(project_root / "data" / "vocabularies" / parenthood_file) 159 | 160 | reverse_parenthood = reverse_map(parenthood) 161 | leaf_nodes = [] 162 | for parent, children in reverse_parenthood.items(): 163 | leaf_node = list(children)[0] 164 | if ( 165 | "GO" in parent 166 | and len(children) == 1 167 | and leaf_node in annotations_latest.index 168 | ): 169 | if "obsolete" not in annotations_latest.loc[leaf_node, "name"]: 170 | leaf_nodes.append(leaf_node) 171 | leaf_nodes = set(leaf_nodes) 172 | 173 | def add_go_parents(go_terms: list): 174 | all_terms = set() 175 | for term in go_terms: 176 | all_terms.update( 177 | parenthood[term] 178 | ) # Note that parents of term contain term itself 179 | return list(all_terms) 180 | 181 | # Update go terms to include all parents 182 | df_latest["go_ids"] = df_latest["go_ids"].apply(add_go_parents) 183 | 184 | if sequence_vocabulary == "new": 185 | sequence_ids_2019 = set([id for _, id, _ in pinf_train + pinf_val]) 186 | in_proteinfer_train_val = df_latest.seq_id.apply(lambda x: x in sequence_ids_2019) 187 | df_latest = df_latest[(in_proteinfer_train_val == False)] 188 | elif sequence_vocabulary == "proteinfer_test": 189 | proteinfer_test_set_seqs = set([id for _, id, _ in pinf_test]) 190 | in_proteinfer_test = df_latest.seq_id.apply( 191 | lambda x: x in proteinfer_test_set_seqs 192 | ) 193 | df_latest = df_latest[(in_proteinfer_test == True)] 194 | elif sequence_vocabulary == "proteinfer_train": 195 | proteinfer_train_set_seqs = set([id for _, id, _ in pinf_train]) 196 | in_proteinfer_train = df_latest.seq_id.apply( 197 | lambda 
x: x in proteinfer_train_set_seqs 198 | ) 199 | df_latest = df_latest[(in_proteinfer_train == True)] 200 | elif sequence_vocabulary == "all": 201 | pass 202 | else: 203 | raise ValueError(f"{sequence_vocabulary} not recognized") 204 | 205 | if label_vocabulary == "proteinfer": 206 | vocab = set( 207 | generate_vocabularies( 208 | str(config['paths']['data_paths']['FULL_DATA_PATH']) 209 | )["label_vocab"] 210 | ) 211 | elif label_vocabulary == "new": 212 | vocab = new_go_labels 213 | elif label_vocabulary == "all": 214 | vocab = set([j for i in df_latest.go_ids for j in i]) 215 | 216 | if only_leaf_nodes: 217 | vocab &= leaf_nodes 218 | 219 | logger.info("filtering labels") 220 | # Find protein sequences with added labels 221 | df_latest["go_ids"] = df_latest.go_ids.apply(lambda x: (set(x) & vocab)) 222 | 223 | # Remove sequences with no applicable labels 224 | df_latest = df_latest[(df_latest.go_ids != set())] 225 | 226 | filtered_df = df_latest[["seq_id", "sequence", "go_ids"]] 227 | 228 | # Set of 20 common amino acids 229 | common_amino_acids = set(COMMON_AMINOACIDS) 230 | 231 | # Create a column of the non-common amino acids 232 | filtered_df["non_common_amino_acids"] = filtered_df.sequence.apply( 233 | lambda x: set(x) - common_amino_acids 234 | ) 235 | 236 | # Filter to only contain rows that contain common amino acids and rename 237 | SwissProt_latest = filtered_df[filtered_df.non_common_amino_acids == set()] 238 | 239 | # Rename columns 240 | final_labels = set([j for i in SwissProt_latest["go_ids"] for j in i]) 241 | logger.info( 242 | "Number of sequences in dataframe: " 243 | + str(len(SwissProt_latest)) 244 | + f" Number of labels in dataframe: {str(len(final_labels))}" 245 | ) 246 | 247 | logger.info("Writting to FASTA...") 248 | # Convert dataframe to FASTA format and save to a file 249 | records = [ 250 | SeqRecord( 251 | Seq(row["sequence"]), id=row["seq_id"], description=" ".join(row["go_ids"]) 252 | ) 253 | for _, row in 
SwissProt_latest.iterrows() 254 | ] 255 | SeqIO.write(records, output_file_path, "fasta") 256 | logger.info("Saved FASTA file to " + output_file_path) 257 | 258 | 259 | if __name__ == "__main__": 260 | """ 261 | Examples usage: 262 | # python make_dataset_from_swissprot.py --latest-swissprot-file uniprot_sprot.dat --output-file-path unseen_swissprot_jul_2024.fasta --only-unseen-seqs --label-vocabulary=new --parenthood-file parenthood_jul_2024.json 263 | # updated test set: python make_dataset_from_swissprot.py --latest-swissprot-file uniprot_sprot.dat --output-file-path test_jul_2024.fasta --only-unseen-seqs --label-vocabulary=new --parenthood-file parenthood_jul_2024.json 264 | """ 265 | parser = argparse.ArgumentParser( 266 | description="Process SwissProt data and generate a FASTA file." 267 | ) 268 | parser.add_argument( 269 | "--latest-swissprot-file", 270 | type=str, 271 | help="Path to the SwissProt data file." 272 | ) 273 | parser.add_argument( 274 | "--output-file-path", 275 | type=str, 276 | help="Path to the output file. Should be a FASTA file. ", 277 | ) 278 | 279 | parser.add_argument( 280 | "--parsed-latest-swissprot-file", 281 | type=str, 282 | default='swissprot_2024_full.pkl', 283 | help="Path to the latest parsed SwissProt file if exists. Otherwise will be created for caching purposes. Date should match latest-go-annotations-file", 284 | ) 285 | 286 | parser.add_argument( 287 | "--latest-go-annotations-file", 288 | type=str, 289 | default="go_annotations_jul_2024.pkl", 290 | help="Path to the latest go annotations file available. Date should match parsed-latest-swissprot-file" 291 | ) 292 | 293 | parser.add_argument( 294 | "--parenthood-file", 295 | type=str, 296 | help="path to the parenthood json containing go term children to parent mapping", 297 | ) 298 | parser.add_argument( 299 | "--sequence-vocabulary", 300 | help="The sequences to use. 
Can be proteinfer_test, all, or new.", 301 | type=str, 302 | ) 303 | parser.add_argument( 304 | "--only-leaf-nodes", 305 | help="wether to only consider leaf nodes of the hierarchy", 306 | action="store_true", 307 | default=False, 308 | ) 309 | parser.add_argument( 310 | "--label-vocabulary", 311 | help="The label vocabulary to use: proteinfer, all, new. all = all observed terms in dataset. New = all observed and new since 2019", 312 | required=True, 313 | ) 314 | parser.add_argument( 315 | "--no-cache", 316 | action="store_true", 317 | default=False, 318 | help="whether to download data from scratch or read file if exists", 319 | ) 320 | args = parser.parse_args() 321 | 322 | main( 323 | latest_swissprot_file=args.latest_swissprot_file, 324 | output_file_path=args.output_file_path, 325 | parenthood_file=args.parenthood_file, 326 | label_vocabulary=args.label_vocabulary, 327 | sequence_vocabulary=args.sequence_vocabulary, 328 | only_leaf_nodes=args.only_leaf_nodes, 329 | cache=not args.no_cache, 330 | ) 331 | -------------------------------------------------------------------------------- /bin/make_proteinfer_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from torchdata.datapipes.iter import FileLister, FileOpener 4 | import argparse 5 | from Bio.Seq import Seq 6 | from Bio.SeqRecord import SeqRecord 7 | from Bio import SeqIO 8 | from tqdm import tqdm 9 | from protnote.utils.configs import get_project_root, load_config 10 | 11 | def process_sequence_tfrecord(record: dict, annotation_types: list): 12 | sequence = record["sequence"][0].decode() 13 | id = record["id"][0].decode() 14 | 15 | labels = set() 16 | 17 | # Some rows have no lavel column 18 | if "label" not in record: 19 | return None 20 | 21 | # Add all labels from desired annotation types 22 | for l in record["label"]: 23 | label = l.decode() 24 | label_type = label.split(":")[0] 25 | 26 | if label_type in annotation_types: 
27 | labels.add(label) 28 | 29 | # Sequence with no annotation from selected types 30 | if not labels: 31 | return None 32 | 33 | return id, (sequence, list(labels)) 34 | 35 | 36 | def process_tfrecords( 37 | input_dir: str, 38 | annotation_types: list, 39 | pattern: str, 40 | pattern_name: str, 41 | ): 42 | # Load all tfrecords from desired data split 43 | datapipe1 = FileLister(str(input_dir), pattern) 44 | datapipe2 = FileOpener(datapipe1, mode="b") 45 | tfrecord_loader_dp = datapipe2.load_from_tfrecord() 46 | 47 | records = [] 48 | # Iterate over records, process and write to a fasta file 49 | for _, record in tqdm(enumerate(tfrecord_loader_dp)): 50 | processed_sequence = process_sequence_tfrecord(record, annotation_types) 51 | 52 | # Skipping sequence with no labels from desired annotations 53 | if processed_sequence is None: 54 | continue 55 | 56 | id, (sequence, labels) = processed_sequence 57 | 58 | description = " ".join(labels) 59 | record = SeqRecord(Seq(sequence), id=f"{id}", description=description) 60 | records.append(record) 61 | 62 | with open( 63 | input_dir / f"{pattern_name}_{'_'.join(annotation_types)}.fasta", 64 | "w", 65 | ) as output_handle: 66 | SeqIO.write(records, output_handle, "fasta") 67 | 68 | 69 | if __name__ == "__main__": 70 | """ 71 | Example usage: python bin/make_proteinfer_dataset.py --dataset-type random --annotation-types GO 72 | """ 73 | logging.basicConfig( 74 | format="%(asctime)s | %(levelname)s: %(message)s", level=logging.NOTSET 75 | ) 76 | parser = argparse.ArgumentParser() 77 | 78 | config, project_root = load_config() 79 | 80 | # TODO: I/O paths could be not required and default to some env var 81 | parser.add_argument( 82 | "--dataset-type", 83 | required=True, 84 | help='The specific type of ProteInfer dataset. 
Two choices: random, clustered', 85 | default = 'random' 86 | ) 87 | 88 | parser.add_argument( 89 | "--annotation-types", 90 | nargs="+", 91 | required=True 92 | ) 93 | 94 | args = parser.parse_args() 95 | input_dir = project_root / "data" / "swissprot" / "proteinfer_splits" / f"{args.dataset_type}" 96 | 97 | dirname = os.path.dirname(__file__) 98 | 99 | patterns = { 100 | "train": "train*.tfrecord", 101 | "dev": "dev*.tfrecord", 102 | "test": "test*.tfrecord", 103 | "full": "*.tfrecord", 104 | } 105 | 106 | for pattern_name, pattern in patterns.items(): 107 | logging.info(f"Processing {pattern_name}") 108 | process_tfrecords( 109 | input_dir=input_dir, 110 | annotation_types=args.annotation_types, 111 | pattern=pattern, 112 | pattern_name=pattern_name, 113 | ) 114 | -------------------------------------------------------------------------------- /bin/make_zero_shot_datasets_from_proteinfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from protnote.utils.data import read_fasta, save_to_fasta, generate_vocabularies 6 | from datetime import datetime 7 | from protnote.utils.configs import load_config 8 | 9 | 10 | def split_labels(label_vocabulary): 11 | """ 12 | Split the label vocabulary into training, validation, and test sets. 13 | 14 | Parameters: 15 | - label_vocabulary: list, a list of unique labels. 16 | 17 | Returns: 18 | - tuple of lists, containing train_labels, val_labels, and test_labels. 
19 | """ 20 | #set seed as 42 for random.shuffle 21 | random.seed(42) 22 | random.shuffle(label_vocabulary) 23 | train_size = len(label_vocabulary) * 80 // 100 24 | val_size = len(label_vocabulary) * 10 // 100 25 | train_labels = label_vocabulary[:train_size] 26 | val_labels = label_vocabulary[train_size : train_size + val_size] 27 | test_labels = label_vocabulary[train_size + val_size :] 28 | return train_labels, val_labels, test_labels 29 | 30 | 31 | def filter_dataset(dataset, labels, desc="Filtering dataset"): 32 | """ 33 | Filter the dataset to include only sequences with specified labels. 34 | 35 | Parameters: 36 | - dataset: list of tuples, where each tuple is (sequence, labels) to be filtered. 37 | - labels: list, labels to retain in the dataset. 38 | 39 | Returns: 40 | - list, filtered dataset with only specified labels. 41 | """ 42 | labels_set = set(labels) # Convert list to set for O(1) lookup 43 | filtered = [ 44 | ( 45 | sequence, 46 | sequence_id, 47 | [label for label in sequence_labels if label in labels_set], 48 | ) 49 | for sequence, sequence_id, sequence_labels in tqdm(dataset, desc=desc) 50 | ] 51 | return filtered 52 | 53 | 54 | def main(): 55 | """ 56 | Main function to filter fasta files based on label vocabulary and save the filtered datasets. 57 | 58 | Parameters: 59 | - train_path: str, path to the training set fasta file. 60 | - val_path: str, path to the validation set fasta file. 61 | - test_path: str, path to the test set fasta file. 62 | - label_vocab_path: str, path to the label vocabulary JSON file. 63 | - output_path: str, directory where the filtered fasta files will be saved. 64 | """ 65 | 66 | config, project_root = load_config() 67 | 68 | parser = ArgumentParser( 69 | description="Filter and save datasets for zero-shot learning." 
70 | ) 71 | parser.add_argument( 72 | "--train-path", 73 | required=False, 74 | help="Path to the training set fasta file.", 75 | default=config["paths"]['data_paths']['TRAIN_DATA_PATH'], 76 | ) 77 | parser.add_argument( 78 | "--val-path", 79 | required=False, 80 | help="Path to the validation set fasta file.", 81 | default=config["paths"]['data_paths']['VAL_DATA_PATH'], 82 | ) 83 | parser.add_argument( 84 | "--test-path", 85 | required=False, 86 | help="Path to the test set fasta file.", 87 | default=config["paths"]['data_paths']['TEST_DATA_PATH'], 88 | ) 89 | args = parser.parse_args() 90 | 91 | output_path = project_root / "data" / "swissprot" / "proteinfer_splits" / "random" 92 | train = read_fasta(args.train_path) 93 | val = read_fasta(args.val_path) 94 | test = read_fasta(args.test_path) 95 | 96 | label_vocabulary = generate_vocabularies(file_path=str(config["paths"]['data_paths']['FULL_DATA_PATH']))["label_vocab"] 97 | 98 | train_labels, val_labels, test_labels = split_labels(label_vocabulary) 99 | 100 | filtered_datasets = { 101 | "train_GO_zero_shot.fasta": filter_dataset( 102 | train, train_labels, desc="Filtering training set" 103 | ), 104 | "dev_GO_zero_shot.fasta": filter_dataset( 105 | val, val_labels, desc="Filtering validation set" 106 | ), 107 | "test_GO_zero_shot.fasta": filter_dataset( 108 | test, test_labels, desc="Filtering test set" 109 | ), 110 | } 111 | 112 | for filename, dataset in filtered_datasets.items(): 113 | save_to_fasta(dataset, os.path.join(output_path, 'fake_' + filename)) 114 | 115 | 116 | if __name__ == "__main__": 117 | """ 118 | Example usage: 119 | python bin/make_zero_shot_datasets_from_proteinfer.py 120 | """ 121 | main() 122 | -------------------------------------------------------------------------------- /bin/run_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import os 5 | import sys 6 | import logging 7 | from typing 
import Literal 8 | from pprint import pprint 9 | from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC 10 | import pandas as pd 11 | import subprocess 12 | from protnote.utils.evaluation import EvalMetrics 13 | from protnote.utils.data import generate_vocabularies, read_yaml 14 | from protnote.utils.configs import load_config, construct_absolute_paths 15 | from test_models import TEST_COMMANDS, MODEL_PATH_TOKEN, MODEL_NAME_TOKEN 16 | 17 | #TODO: This could be more general by eliminating unnecessary dependencies and passing the embedding file names as arguments 18 | def main( 19 | annotation_type: str, 20 | label_embedding_model: Literal["E5", "BioGPT"], 21 | test_name: str, 22 | model_name: str, 23 | cache: bool, 24 | ): 25 | 26 | # Load the configuration and project root 27 | config, project_root = load_config() 28 | results_dir = config["paths"]["output_paths"]["RESULTS_DIR"] 29 | full_data_path = config["paths"]['data_paths']['FULL_DATA_PATH'] 30 | proteinfer_predictions_path = results_dir / f"test_logits_{annotation_type}_{test_name}_proteinfer.h5" 31 | protnote_labels_path = results_dir / f"test_1_labels_{test_name}_{model_name}.h5" 32 | output_file = results_dir / f"test_logits_{annotation_type}_{test_name}_{label_embedding_model}_baseline.h5" 33 | 34 | 35 | logger = logging.getLogger() 36 | logger.setLevel(logging.INFO) 37 | # Create a formatter 38 | formatter = logging.Formatter( 39 | fmt="%(asctime)s %(levelname)-4s %(message)s", datefmt="%Y-%m-%d %H:%M:%S %Z" 40 | ) 41 | stream_handler = logging.StreamHandler(sys.stdout) 42 | stream_handler.setFormatter(formatter) 43 | logger.addHandler(stream_handler) 44 | 45 | assert test_name in TEST_COMMANDS.keys(), f"{test_name} does not exist" 46 | 47 | if label_embedding_model == "E5": 48 | label_embeddings = "2024_E5_multiling_inst_frozen_label_embeddings_mean" 49 | 50 | elif label_embedding_model == "BioGPT": 51 | label_embeddings = "2024_BioGPT_frozen_label_embeddings_mean" 52 | 53 | base_embeddings_path = 
project_root / "data" / "embeddings" / f"{label_embeddings}.pt" 54 | base_embeddings_idx_path = project_root / "data" / "embeddings" / f"{label_embeddings}_index.pt" 55 | 56 | if (not os.path.exists(proteinfer_predictions_path)) | (not cache): 57 | if annotation_type == "GO": 58 | logger.info("Running inference with Proteinfer...") 59 | subprocess.run( 60 | f"python bin/test_proteinfer.py --test-paths-names {test_name} --only-inference --override EXTRACT_VOCABULARIES_FROM null --save-prediction-results --name {test_name}_proteinfer --base-label-embedding-name GO_2024_BASE_LABEL_EMBEDDING_PATH", 61 | shell=True, 62 | ) 63 | elif annotation_type == "EC": 64 | subprocess.run( 65 | f"python bin/test_proteinfer.py --test-paths-names {test_name} --only-inference --override EXTRACT_VOCABULARIES_FROM null --save-prediction-results --name {test_name}_proteinfer --base-label-embedding-name EC_BASE_LABEL_EMBEDDING_PATH --annotations-path-name EC_ANNOTATIONS_PATH", 66 | shell=True, 67 | ) 68 | else: 69 | logger.info("Found existing proteinfer predictions... 
caching") 70 | 71 | if ( 72 | not os.path.exists(protnote_labels_path) 73 | ) | (not cache): 74 | logger.info("Getting label dataframe in a hacky way...") 75 | #TODO: This can be avoided by creating the label dataframe and performing the preprocessing 76 | # instead of running the model 77 | subprocess.run( 78 | TEST_COMMANDS[test_name] 79 | .replace(MODEL_PATH_TOKEN, f"{model_name}.pt") 80 | .replace(MODEL_NAME_TOKEN, model_name) + " --save-prediction-results", 81 | shell=True, 82 | ) 83 | 84 | zero_shot_pinf_logits = pd.read_hdf(proteinfer_predictions_path,key="logits_df").astype("float32") 85 | zero_shot_labels = pd.read_hdf(protnote_labels_path, key="labels_df") 86 | vocabularies = generate_vocabularies(file_path=str(full_data_path)) 87 | zero_shot_pinf_logits.columns = vocabularies["label_vocab"] 88 | 89 | embeddings = torch.load(base_embeddings_path) 90 | embeddings_idx = torch.load(base_embeddings_idx_path) 91 | 92 | # Select embeddings based on name / short definition 93 | embedding_mask = embeddings_idx["description_type"] == "name" 94 | embeddings_idx = embeddings_idx[embedding_mask].reset_index(drop=True) 95 | embeddings = embeddings[embedding_mask] 96 | 97 | # Create embeddings matrix of known proteinfer GO Term definitions 98 | train_embeddings_mask = embeddings_idx["id"].isin(vocabularies["label_vocab"]) 99 | train_embeddings_idx = embeddings_idx[train_embeddings_mask].reset_index(drop=True) 100 | train_embeddings = embeddings[train_embeddings_mask] 101 | 102 | if annotation_type == "GO": 103 | # Create embedding matrix of the new/unknown GO Term definitions 104 | zero_shot_embeddings_mask = embeddings_idx["id"].isin(zero_shot_labels.columns) 105 | zero_shot_embeddings_idx = embeddings_idx[ 106 | zero_shot_embeddings_mask 107 | ].reset_index(drop=True) 108 | zero_shot_embeddings = embeddings[zero_shot_embeddings_mask] 109 | 110 | if annotation_type == "EC": 111 | if label_embedding_model == "BioGPT": 112 | label_embeddings_new = 
"ecv1_BioGPT_frozen_label_embeddings_mean" 113 | elif label_embedding_model == "E5": 114 | label_embeddings_new = "ecv1_E5_multiling_inst_frozen_label_embeddings_mean" 115 | 116 | new_embeddings_path = project_root / "data" / "embeddings" / f"{label_embeddings_new}.pt" 117 | new_embeddings_idx_path = project_root / "data" / "embeddings" / f"{label_embeddings_new}_index.pt" 118 | 119 | embeddings_new = torch.load(new_embeddings_path) 120 | embeddings_idx_new = torch.load(new_embeddings_idx_path) 121 | 122 | # Create embedding matrix of the new/unknown GO Term definitions 123 | embedding_mask_new = embeddings_idx_new["description_type"] == "name" 124 | embeddings_idx_new = embeddings_idx_new[embedding_mask_new].reset_index( 125 | drop=True 126 | ) 127 | embeddings_new = embeddings_new[embedding_mask_new] 128 | 129 | zero_shot_embeddings_mask = embeddings_idx_new["id"].isin( 130 | zero_shot_labels.columns 131 | ) 132 | zero_shot_embeddings_idx = embeddings_idx_new[ 133 | zero_shot_embeddings_mask 134 | ].reset_index(drop=True) 135 | zero_shot_embeddings = embeddings_new[zero_shot_embeddings_mask] 136 | 137 | # Create description mapping from seen to new 138 | label_train_2_zero_shot_similarities = ( 139 | torch.nn.functional.normalize(zero_shot_embeddings) 140 | @ torch.nn.functional.normalize(train_embeddings).T 141 | ) 142 | zero_shot_label_mapping = { 143 | zero_shot_embeddings_idx["id"] 144 | .iloc[zero_shot_label_idx]: train_embeddings_idx["id"] 145 | .iloc[train_label_idx.item()] 146 | for zero_shot_label_idx, train_label_idx in enumerate( 147 | label_train_2_zero_shot_similarities.max(dim=-1).indices 148 | ) 149 | } 150 | 151 | # Create baseline predictions by replicating logits of most similar labels 152 | zero_shot_pinf_baseline_logits = zero_shot_pinf_logits[ 153 | [zero_shot_label_mapping[i] for i in zero_shot_labels.columns] 154 | ] 155 | zero_shot_pinf_baseline_logits.columns = zero_shot_labels.columns 156 | 157 | 
zero_shot_pinf_baseline_logits.to_hdf(output_file,key='logits_df') 158 | 159 | # Evaluate performance of baseline 160 | eval_metrics = EvalMetrics(device="cuda") 161 | mAP_micro = BinaryAUPRC(device="cpu") 162 | mAP_macro = MultilabelAUPRC(device="cpu", num_labels=zero_shot_labels.shape[-1]) 163 | metrics = eval_metrics.get_metric_collection_with_regex( 164 | pattern="f1_m.*", threshold=0.5, num_labels=zero_shot_labels.shape[-1] 165 | ) 166 | 167 | metrics( 168 | torch.sigmoid( 169 | torch.tensor(zero_shot_pinf_baseline_logits.values, device="cuda") 170 | ), 171 | torch.tensor(zero_shot_labels.values, device="cuda"), 172 | ) 173 | mAP_micro.update( 174 | torch.sigmoid(torch.tensor(zero_shot_pinf_baseline_logits.values)).flatten(), 175 | torch.tensor(zero_shot_labels.values).flatten(), 176 | ) 177 | mAP_macro.update( 178 | torch.sigmoid(torch.tensor(zero_shot_pinf_baseline_logits.values)), 179 | torch.tensor(zero_shot_labels.values), 180 | ) 181 | 182 | metrics = metrics.compute() 183 | metrics.update({"map_micro": mAP_micro.compute(), "map_macro": mAP_macro.compute()}) 184 | metrics = {k: v.item() for k, v in metrics.items()} 185 | pprint(metrics) 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = argparse.ArgumentParser(description="Run baseline inference") 190 | parser.add_argument("--annotation-type", type=str, default="GO") 191 | parser.add_argument( 192 | "--test-name", 193 | type=str, 194 | required=True, 195 | help="The name of the test set to run baseline on", 196 | ) 197 | parser.add_argument( 198 | "--model-name", 199 | type=str, 200 | required=True, 201 | help="The name of the latest zero shot model without the .pt. 
Its predictions wont be used, it is a hack to obtain processed/clean label df", 202 | ) 203 | parser.add_argument( 204 | "--label-embedding-model", 205 | type=str, 206 | required=True, 207 | help="The name of LLM used for text embeddings: E5 or BioGPT", 208 | ) 209 | parser.add_argument( 210 | "--cache", 211 | help="Whether to cache proteinfer predictions if available", 212 | action="store_true", 213 | ) 214 | 215 | args = parser.parse_args() 216 | 217 | main( 218 | annotation_type=args.annotation_type, 219 | test_name=args.test_name, 220 | model_name=args.model_name, 221 | label_embedding_model=args.label_embedding_model, 222 | cache=args.cache, 223 | ) 224 | 225 | 226 | -------------------------------------------------------------------------------- /bin/run_blast.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | warnings.simplefilter("ignore") 4 | from tqdm import tqdm 5 | import argparse 6 | from protnote.utils.data import read_fasta, generate_vocabularies 7 | from protnote.utils.configs import load_config,construct_absolute_paths, get_project_root 8 | from protnote.models.blast import BlastTopHits 9 | from protnote.utils.data import tqdm_joblib 10 | import pandas as pd 11 | import multiprocessing 12 | from protnote.utils.configs import get_logger 13 | from joblib import Parallel, delayed 14 | import numpy as np 15 | 16 | 17 | # Load the configuration and project root 18 | config, project_root = load_config() 19 | results_dir = config["paths"]["output_paths"]["RESULTS_DIR"] 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description="Run BLAST") 24 | parser.add_argument( 25 | "--test-data-path", 26 | type=str, 27 | required=True, 28 | help="The test databse of query sequences", 29 | ) 30 | parser.add_argument( 31 | "--train-data-path", 32 | type=str, 33 | required=False, 34 | default=config["paths"]["data_paths"]["TRAIN_DATA_PATH"], 35 | help="The train databse of 
sequences", 36 | ) 37 | 38 | parser.add_argument( 39 | "--top-k-hits", 40 | type=int, 41 | required=False, 42 | default=1, 43 | help="The number of top hits to return per query in decreasing hit_expect order", 44 | ) 45 | parser.add_argument( 46 | "--max-evalue", 47 | type=float, 48 | required=False, 49 | default=0.05, 50 | help="The evalue threshold. Any this with higher evalue than this threshold is omitted from results", 51 | ) 52 | parser.add_argument( 53 | "--cache", 54 | action="store_true", 55 | default=False, 56 | help="Whether to cache results if available", 57 | ) 58 | 59 | parser.add_argument( 60 | "--save-runtime-info", 61 | action="store_true", 62 | default=False, 63 | help="Whether to save runtime information" 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | # Suppress all Biopython warnings 69 | logger = get_logger() 70 | 71 | test_name = args.test_data_path.split("/")[-1].split(".")[0] 72 | train_name = args.train_data_path.split("/")[-1].split(".")[0] 73 | 74 | raw_results_output_path = results_dir / f"blast_raw_{test_name}_{train_name}_results.tsv" 75 | parsed_results_output_path = results_dir / f"blast_parsed_{test_name}_{train_name}_results.parquet" 76 | pivot_parsed_results_output_path = results_dir / f"blast_pivot_parsed_{test_name}_{train_name}_results.parquet" 77 | runtime_info_output_path = results_dir / f"blast_runtime_info_{test_name}_{train_name}.csv" 78 | 79 | bth = BlastTopHits( 80 | db_fasta_path=args.train_data_path, queries_fasta_path=args.test_data_path 81 | ) 82 | 83 | if not (os.path.exists(raw_results_output_path) & args.cache): 84 | bth.run_blast(output_path=raw_results_output_path, top_k_hits=args.top_k_hits) 85 | 86 | # Parse and save processed results 87 | if not (os.path.exists(parsed_results_output_path) & args.cache): 88 | parsed_results = bth.parse_results( 89 | blast_results_path=raw_results_output_path, 90 | flatten_labels=False, 91 | transfer_labels=True, 92 | ) 93 | 
parsed_results.to_parquet(parsed_results_output_path, index=False) 94 | else: 95 | parsed_results = pd.read_parquet(parsed_results_output_path) 96 | 97 | # Format as pivoted dataframe 98 | logger.info("Pivoting data") 99 | db_vocab = generate_vocabularies(file_path=args.train_data_path)["label_vocab"] 100 | label2int = {label: idx for idx, label in enumerate(db_vocab)} 101 | 102 | def record_to_pivot(idx_row): 103 | _, row = idx_row 104 | record = [-15.0] * len(db_vocab) 105 | for l in row["transferred_labels"]: 106 | record[label2int[l]] = 15.0 107 | record.insert(0, row["sequence_name"]) 108 | return record 109 | 110 | simplified_results = parsed_results[ 111 | ["sequence_name", "bit_score", "transferred_labels"] 112 | ] 113 | 114 | pivoting_batch_size = 10_000 115 | num_pivoting_baches = int(np.ceil(len(simplified_results) / 10_000)) 116 | simplified_results.iterrows() 117 | 118 | for batch in range(num_pivoting_baches): 119 | logger.info(f"Pivoting batch {batch+1} / {num_pivoting_baches}") 120 | 121 | batch_size = min( 122 | pivoting_batch_size, len(simplified_results) - (batch) * pivoting_batch_size 123 | ) 124 | 125 | with tqdm_joblib(tqdm(total=batch_size)) as pbar: 126 | records = Parallel(n_jobs=multiprocessing.cpu_count())( 127 | delayed(record_to_pivot)(idx_row) 128 | for idx_row in simplified_results[ 129 | batch * pivoting_batch_size : (batch + 1) * pivoting_batch_size 130 | ].iterrows() 131 | ) 132 | 133 | result = pd.DataFrame(records, columns=["sequence_name"] + db_vocab) 134 | result.set_index("sequence_name", inplace=True) 135 | result.index.name = None 136 | result.to_parquet( 137 | str(pivot_parsed_results_output_path).replace('.parquet','') + f"_batch_{batch}.parquet", index=True 138 | ) 139 | 140 | logger.info(f"Merging batched results.") 141 | batch_results = [] 142 | for batch in tqdm(range(num_pivoting_baches)): 143 | batch_results.append( 144 | pd.read_parquet(str(pivot_parsed_results_output_path).replace('.parquet','') + 
f"_batch_{batch}.parquet") 145 | ) 146 | pd.concat(batch_results).to_parquet(pivot_parsed_results_output_path, index=True) 147 | 148 | logger.info(f"Results saved in {pivot_parsed_results_output_path}") 149 | logger.info(f"Search Duration: {bth.run_duration_seconds}") 150 | logger.info(f"Parse Duration: {bth.parse_results_duration_seconds}") 151 | 152 | # Save the search and parse duration in a csv file, along with the size of query set 153 | 154 | if args.save_runtime_info: 155 | search_parse_duration = pd.DataFrame( 156 | { 157 | "search_duration": [bth.run_duration_seconds], 158 | "parse_duration": [bth.parse_results_duration_seconds], 159 | "query_size": [len(read_fasta(args.test_data_path))] 160 | } 161 | ) 162 | search_parse_duration.to_csv(runtime_info_output_path, index=False) 163 | 164 | 165 | 166 | if __name__ == "__main__": 167 | """ 168 | sample usage: 169 | 170 | python run_blast.py --test-data-path data/swissprot/proteinfer_splits/random/test_GO.fasta --train-data-path data/swissprot/proteinfer_splits/random/train_GO.fasta 171 | 172 | 173 | # List of numbers to iterate over 174 | numbers=(1 10 100 1000 5000 10000 20000) # Modify this list with your actual numbers 175 | 176 | # Loop through each number in the list 177 | for num in "${numbers[@]}"; do 178 | # Run the python script with the current number in the file path 179 | python bin/run_blast.py --test-data-path data/swissprot/proteinfer_splits/random/test_${num}_GO.fasta --train-data-path data/swissprot/proteinfer_splits/random/train_GO.fasta --save-runtime-info; 180 | done 181 | 182 | python bin/run_blast.py --test-data-path data/swissprot/proteinfer_splits/random/test_GO.fasta --train-data-path data/swissprot/proteinfer_splits/random/train_GO.fasta --save-runtime-info; 183 | 184 | 185 | 186 | 187 | 188 | 189 | numbers=(20000) # Modify this list with your actual numbers 190 | 191 | # Loop through each number in the list 192 | for num in "${numbers[@]}"; do 193 | # Run the python script with the 
current number in the file path 194 | python bin/run_blast.py --test-data-path data/swissprot/proteinfer_splits/random/test_${num}_GO.fasta --train-data-path data/swissprot/proteinfer_splits/random/train_GO.fasta --save-runtime-info; 195 | done 196 | 197 | python bin/run_blast.py --test-data-path data/swissprot/proteinfer_splits/random/test_GO.fasta --train-data-path data/swissprot/proteinfer_splits/random/train_GO.fasta --save-runtime-info; 198 | """ 199 | main() 200 | 201 | #!/bin/bash 202 | 203 | -------------------------------------------------------------------------------- /bin/test_ablation.sh: -------------------------------------------------------------------------------- 1 | names=( 2 | "2024-08-21_06-00-55_seed_42_ablations_OVERRIDE_WEIGHTED_SAMPLING_False_last_epoch.pt" \ 3 | "2024-08-21_06-05-49_seed_42_ablations_OVERRIDE_LABEL_EMBEDDING_NOISING_ALPHA_0_last_epoch.pt" \ 4 | "2024-08-21_06-12-36_seed_42_ablations_OVERRIDE_LOSS_FN_BCE_last_epoch.pt" \ 5 | "2024-08-21_06-33-42_seed_42_ablations_OVERRIDE_LABEL_AUGMENTATION_DESCRIPTIONS_name_last_epoch.pt" \ 6 | "2024-08-21_06-43-50_seed_42_ablations_OVERRIDE_LABEL_ENCODER_CHECKPOINT_biogpt_last_epoch.pt" \ 7 | "2024-08-21_06-00-52_seed_42_ablations_OVERRIDE_AUGMENT_RESIDUE_PROBABILITY_0_last_epoch.pt" \ 8 | "2024-09-09_19-34-44_seed_12_ablations_OVERRIDE_LABEL_AUGMENTATION_DESCRIPTIONS_name_last_epoch.pt" \ 9 | "2024-09-09_19-48-03_seed_12_ablations_OVERRIDE_LABEL_ENCODER_CHECKPOINT_biogpt_last_epoch.pt" \ 10 | "2024-09-09_19-47-36_seed_12_ablations_OVERRIDE_WEIGHTED_SAMPLING_False_last_epoch.pt" \ 11 | "2024-09-09_19-48-10_seed_12_ablations_OVERRIDE_AUGMENT_RESIDUE_PROBABILITY_0_last_epoch.pt" \ 12 | "2024-09-09_19-48-17_seed_12_ablations_OVERRIDE_LABEL_EMBEDDING_NOISING_ALPHA_0_last_epoch.pt" \ 13 | "2024-09-09_19-43-24_seed_12_ablations_OVERRIDE_LOSS_FN_BCE_last_epoch.pt" \ 14 | "2024-09-09_19-48-10_seed_22_ablations_OVERRIDE_LOSS_FN_BCE_last_epoch.pt" \ 15 | 
"2024-09-09_19-44-01_seed_22_ablations_OVERRIDE_WEIGHTED_SAMPLING_False_last_epoch.pt" \ 16 | "2024-09-09_19-43-36_seed_22_ablations_OVERRIDE_LABEL_ENCODER_CHECKPOINT_biogpt_last_epoch.pt" \ 17 | "2024-09-09_19-48-59_seed_22_ablations_OVERRIDE_AUGMENT_RESIDUE_PROBABILITY_0_last_epoch.pt" \ 18 | "2024-09-09_19-44-49_seed_22_ablations_OVERRIDE_LABEL_AUGMENTATION_DESCRIPTIONS_name_last_epoch.pt" \ 19 | "2024-09-09_19-47-49_seed_22_ablations_OVERRIDE_LABEL_EMBEDDING_NOISING_ALPHA_0_last_epoch.pt" 20 | ) 21 | 22 | output_file='final_model_ablations.json' 23 | for name in "${names[@]}"; do 24 | model_file="${name}" 25 | model_name="${name%.pt}" 26 | 27 | if [[ "$model_name" == *"LABEL_AUGMENTATION_DESCRIPTIONS_name"* ]] 28 | then 29 | description_augmentation="name" 30 | else 31 | description_augmentation="name+label" 32 | fi 33 | 34 | if [[ "$model_name" == *"biogpt"* ]] 35 | then 36 | label_encoder_checkpoint="microsoft/biogpt" 37 | else 38 | label_encoder_checkpoint="intfloat/multilingual-e5-large-instruct" 39 | fi 40 | 41 | python bin/main.py --test-paths-names TEST_DATA_PATH_ZERO_SHOT --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.3 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT "$label_encoder_checkpoint" AUGMENT_RESIDUE_PROBABILITY 0.1 LABEL_EMBEDDING_NOISING_ALPHA 20 TEST_BATCH_SIZE 8 LABEL_AUGMENTATION_DESCRIPTIONS "$description_augmentation" INFERENCE_GO_DESCRIPTIONS "$description_augmentation" --model-file "$model_file" --base-label-embedding-name GO_2024_BASE_LABEL_EMBEDDING_PATH --name TEST_DATA_PATH_ZERO_SHOT_"$model_name" --save-val-test-metrics --save-val-test-metrics-file "$output_file" 42 | python bin/main.py --test-paths-names TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.3 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT "$label_encoder_checkpoint" AUGMENT_RESIDUE_PROBABILITY 0.1 LABEL_EMBEDDING_NOISING_ALPHA 20 TEST_BATCH_SIZE 8 
LABEL_AUGMENTATION_DESCRIPTIONS "$description_augmentation" INFERENCE_GO_DESCRIPTIONS "$description_augmentation" --model-file "$model_file" --base-label-embedding-name GO_2024_BASE_LABEL_EMBEDDING_PATH --name TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES_"$model_name" --save-val-test-metrics --save-val-test-metrics-file "$output_file" 43 | python bin/main.py --test-paths-names VAL_DATA_PATH --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.5 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT "$label_encoder_checkpoint" AUGMENT_RESIDUE_PROBABILITY 0.1 LABEL_EMBEDDING_NOISING_ALPHA 20 TEST_BATCH_SIZE 8 LABEL_AUGMENTATION_DESCRIPTIONS "$description_augmentation" INFERENCE_GO_DESCRIPTIONS "$description_augmentation" --model-file "$model_file" --name VAL_DATA_PATH_"$model_name" --save-val-test-metrics --save-val-test-metrics-file "$output_file" 44 | done -------------------------------------------------------------------------------- /bin/test_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from itertools import product 4 | from protnote.utils.configs import load_config,construct_absolute_paths, get_project_root 5 | 6 | # Load the configuration and project root 7 | config, project_root = load_config() 8 | results_dir = config["paths"]["output_paths"]["RESULTS_DIR"] 9 | 10 | MODEL_PATH_TOKEN = "" 11 | MODEL_NAME_TOKEN = "" 12 | 13 | #TODO: Simplify the overrides to only the needed ones 14 | TEST_COMMANDS = { 15 | "TEST_DATA_PATH_ZERO_SHOT": f"python bin/main.py --test-paths-names TEST_DATA_PATH_ZERO_SHOT --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.3 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT intfloat/multilingual-e5-large-instruct AUGMENT_RESIDUE_PROBABILITY 0.1 LABEL_EMBEDDING_NOISING_ALPHA 20 TEST_BATCH_SIZE 8 LABEL_AUGMENTATION_DESCRIPTIONS name+label INFERENCE_GO_DESCRIPTIONS name+label --model-file 
{MODEL_PATH_TOKEN} --base-label-embedding-name GO_2024_BASE_LABEL_EMBEDDING_PATH --name TEST_DATA_PATH_ZERO_SHOT_{MODEL_NAME_TOKEN}", 16 | "TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES": f"python bin/main.py --test-paths-names TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.3 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT intfloat/multilingual-e5-large-instruct AUGMENT_RESIDUE_PROBABILITY 0.1 LABEL_EMBEDDING_NOISING_ALPHA 20 TEST_BATCH_SIZE 8 LABEL_AUGMENTATION_DESCRIPTIONS name+label INFERENCE_GO_DESCRIPTIONS name+label --model-file {MODEL_PATH_TOKEN} --base-label-embedding-name GO_2024_BASE_LABEL_EMBEDDING_PATH --name TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES_{MODEL_NAME_TOKEN}", 17 | "TEST_EC_DATA_PATH_ZERO_SHOT": f"python bin/main.py --test-paths-names TEST_EC_DATA_PATH_ZERO_SHOT --override EXTRACT_VOCABULARIES_FROM null ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT intfloat/multilingual-e5-large-instruct DECISION_TH .3 --model-file {MODEL_PATH_TOKEN} --annotations-path-name EC_ANNOTATIONS_PATH --base-label-embedding-name EC_BASE_LABEL_EMBEDDING_PATH --name TEST_EC_DATA_PATH_ZERO_SHOT_{MODEL_NAME_TOKEN}", 18 | "TEST_2024_DATA_PATH": f"python bin/main.py --test-paths-names TEST_2024_DATA_PATH --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.5 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME f1_macro LABEL_ENCODER_CHECKPOINT intfloat/multilingual-e5-large-instruct AUGMENT_RESIDUE_PROBABILITY 0.1 LABEL_EMBEDDING_NOISING_ALPHA 20 TEST_BATCH_SIZE 8 LABEL_AUGMENTATION_DESCRIPTIONS name+label INFERENCE_GO_DESCRIPTIONS name+label --model-file {MODEL_PATH_TOKEN} --base-label-embedding-name GO_2024_BASE_LABEL_EMBEDDING_PATH --name TEST_2024_DATA_PATH_{MODEL_NAME_TOKEN}", 19 | "TEST_2024_PINF_VOCAB_DATA_PATH": f"python bin/main.py --test-paths-names TEST_2024_PINF_VOCAB_DATA_PATH --override EXTRACT_VOCABULARIES_FROM null DECISION_TH 0.5 ESTIMATE_MAP False OPTIMIZATION_METRIC_NAME 
#TODO: Change print for logger
def main():
    """Run baseline computations and/or model evaluations over configured test sets.

    Command lines are taken from the module-level TEST_COMMANDS dict; the
    MODEL_PATH_TOKEN / MODEL_NAME_TOKEN placeholders are substituted once per
    model file passed via --model-files. Each resulting command is executed as
    a subprocess.
    """
    parser = argparse.ArgumentParser(description="Run specified test paths.")
    parser.add_argument(
        "--test-paths-names",
        nargs="+",
        default=[],
        required=False,
        help="List of test path names to run",
    )
    parser.add_argument(
        "--test-type", type=str, default="all", help="One of all, baseline, model"
    )
    parser.add_argument(
        "--save-prediction-results", action="store_true", default=False, required=False
    )
    parser.add_argument(
        "--save-embeddings", action="store_true", default=False, required=False
    )
    parser.add_argument(
        "--save-val-test-metrics", action="store_true", default=False, required=False
    )
    parser.add_argument(
        "--save-val-test-metrics-file",
        help="json file name to append val/test metrics",
        type=str,
        default="val_test_metrics.json",
    )
    parser.add_argument(
        "--model-files",
        nargs="+",
        required=True,
        help="List of .pt models to test.",
    )
    args = parser.parse_args()

    # Baselines only need a single reference model name; use the first file.
    aux_model_name = args.model_files[0].split(".pt")[0]

    # The set of runnable tests is loop-invariant: UPDATED_TEST_COMMANDS always
    # has the same keys as TEST_COMMANDS. Resolve the selection once up front
    # instead of mutating args.test_paths_names inside the per-model loop
    # (the original in-loop assignment produced the same result but obscured it).
    selected_test_paths = args.test_paths_names or list(TEST_COMMANDS.keys())

    # Calculate baseline outputs if specified. Only runs once.
    if args.test_type in ["all", "baseline"]:
        print("Running baselines")
        # GO zero-shot baselines for every (label encoder, test set) pair.
        for label_encoder, test_set in product(
            ["BioGPT", "E5"],
            ["TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES", "TEST_DATA_PATH_ZERO_SHOT"],
        ):
            command = f"python bin/run_baseline.py --test-name {test_set} --cache --model-name {aux_model_name} --label-embedding-model {label_encoder} --annotation-type GO"
            print(f"Running command: {command}")
            # NOTE(review): shell=True on a command built from CLI args; inputs
            # are operator-controlled here, but keep an eye on this if the
            # script is ever driven by untrusted input.
            subprocess.run(command, shell=True)

        # EC zero-shot baselines.
        for label_encoder, test_set in product(
            ["BioGPT", "E5"], ["TEST_EC_DATA_PATH_ZERO_SHOT"]
        ):
            command = f"python bin/run_baseline.py --test-name {test_set} --cache --model-name {aux_model_name} --label-embedding-model {label_encoder} --annotation-type EC"
            print(f"Running command: {command}")
            subprocess.run(command, shell=True)

    for model_file in args.model_files:
        model_name = model_file.split(".pt")[0]

        print(f"Testing model {model_name} in all specified datasets")

        # Substitute this model's path/name into every command template and
        # append the optional output flags requested on the CLI.
        UPDATED_TEST_COMMANDS = {
            k: v.replace(MODEL_PATH_TOKEN, model_file).replace(
                MODEL_NAME_TOKEN, model_name
            )
            + (" --save-prediction-results" if args.save_prediction_results else "")
            + (" --save-embeddings" if args.save_embeddings else "")
            + (" --save-val-test-metrics" if args.save_val_test_metrics else "")
            + " --save-val-test-metrics-file "
            + args.save_val_test_metrics_file
            for k, v in TEST_COMMANDS.items()
        }

        if args.test_type in ["all", "model"]:
            print("testing models...")
            for test_path in selected_test_paths:
                if test_path in UPDATED_TEST_COMMANDS:
                    command = UPDATED_TEST_COMMANDS[test_path]
                    print(f"Running command for {test_path}: {command}")
                    subprocess.run(command, shell=True)
                else:
                    print(
                        f"Test path name {test_path} not found in TEST_COMMANDS dictionary."
                    )

if __name__ == "__main__":
    main()
If not provided, model will not be tested.", 41 | ) 42 | 43 | parser.add_argument( 44 | "--name", 45 | type=str, 46 | default="ProteInfer", 47 | help="Name of the W&B run. If not provided, a name will be generated.", 48 | ) 49 | 50 | parser.add_argument( 51 | "--threshold", 52 | type=float, 53 | default=0.5 54 | ) 55 | 56 | parser.add_argument( 57 | "--proteinfer-weights", 58 | type=str, 59 | default="GO", 60 | help="Which model weights to use: GO or EC", 61 | ) 62 | 63 | parser.add_argument( 64 | "--model-weights-id", 65 | type=int, 66 | default=None, 67 | help="The model id if any. If not specified will pick the same one always.", 68 | ) 69 | 70 | parser.add_argument( 71 | "--override", 72 | nargs="*", 73 | help="Override config parameters in key-value pairs." 74 | ) 75 | 76 | parser.add_argument( 77 | "--save-prediction-results", 78 | action="store_true", 79 | default=False, 80 | help="Save predictions and ground truth dataframe for validation and/or test", 81 | ) 82 | 83 | parser.add_argument( 84 | "--only-inference", 85 | action="store_true", 86 | default=False, 87 | help="Whether to only predict without testing and computing metrics", 88 | ) 89 | 90 | parser.add_argument( 91 | "--only-represented-labels", 92 | action="store_true", 93 | default=False, 94 | help="Whether to only predict labels that are represented in the dataset", 95 | ) 96 | 97 | parser.add_argument( 98 | "--annotations-path-name", 99 | type=str, 100 | default="GO_ANNOTATIONS_PATH", 101 | help="Name of the annotation path. Defaults to GO.", 102 | ) 103 | 104 | parser.add_argument( 105 | "--base-label-embedding-name", 106 | type=str, 107 | default="GO_BASE_LABEL_EMBEDDING_PATH", 108 | help="Name of the base label embedding path. 
def to_device(device, *args):
    """Move each tensor argument to `device`; map every non-tensor to None.

    Returns a list with one entry per positional argument, preserving order.
    Note that non-tensor inputs are intentionally dropped (replaced by None),
    matching how callers unpack optional batch fields.
    """
    moved = []
    for item in args:
        if isinstance(item, torch.Tensor):
            moved.append(item.to(device))
        else:
            moved.append(None)
    return moved
171 | label_tokenizer=None, 172 | ) 173 | if args.validation_path_name is not None 174 | else None 175 | ) 176 | 177 | test_dataset = ( 178 | ProteinDataset( 179 | data_paths=config["dataset_paths"]["test"][0], 180 | config=config, 181 | logger=logger, 182 | require_label_idxs=False, # Label indices are not required for testing 183 | label_tokenizer=None, 184 | ) 185 | if args.test_paths_names is not None 186 | else None 187 | ) 188 | 189 | # Add datasets to a dictionary 190 | # TODO: This does not support multiple datasets. But I think we should remove that support anyway. Too complicated. 191 | datasets = { 192 | "train": [train_dataset], 193 | "validation": [validation_dataset], 194 | "test": [test_dataset], 195 | } 196 | 197 | # Remove empty datasets. May happen in cases like only validating a model. 198 | datasets = {k: v for k, v in datasets.items() if v[0] is not None} 199 | 200 | # Define label sample sizes for train, validation, and test loaders 201 | label_sample_sizes = { 202 | "train": params["TRAIN_LABEL_SAMPLE_SIZE"], 203 | "validation": params["VALIDATION_LABEL_SAMPLE_SIZE"], 204 | "test": None, # No sampling for the test set 205 | } 206 | 207 | # Initialize new run 208 | logger.info(f"################## {timestamp} RUNNING train.py ##################") 209 | 210 | 211 | # Define data loaders 212 | loaders = create_multiple_loaders( 213 | datasets=datasets, params=params, num_workers=params["NUM_WORKERS"], pin_memory=True 214 | ) 215 | 216 | model_weights = paths[f"PROTEINFER_{args.proteinfer_weights}_WEIGHTS_PATH"] 217 | if args.model_weights_id is not None: 218 | model_weights = re.sub(r'(\d+)\.pkl$', str(args.model_weights_id) + '.pkl', model_weights) 219 | 220 | 221 | model = ProteInfer.from_pretrained( 222 | weights_path=model_weights, 223 | num_labels=config["embed_sequences_params"][ 224 | f"PROTEINFER_NUM_{args.proteinfer_weights}_LABELS" 225 | ], 226 | input_channels=config["embed_sequences_params"]["INPUT_CHANNELS"], 227 | 
output_channels=config["embed_sequences_params"]["OUTPUT_CHANNELS"], 228 | kernel_size=config["embed_sequences_params"]["KERNEL_SIZE"], 229 | activation=torch.nn.ReLU, 230 | dilation_base=config["embed_sequences_params"]["DILATION_BASE"], 231 | num_resnet_blocks=config["embed_sequences_params"]["NUM_RESNET_BLOCKS"], 232 | bottleneck_factor=config["embed_sequences_params"]["BOTTLENECK_FACTOR"], 233 | ) 234 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 235 | model.to(device) 236 | model = model.eval() 237 | 238 | label_normalizer = read_json(paths["PARENTHOOD_LIB_PATH"]) 239 | 240 | 241 | # Initialize EvalMetrics 242 | eval_metrics = EvalMetrics(device=device) 243 | label_sample_sizes = { 244 | k: (v if v is not None else len(datasets[k][0].label_vocabulary)) 245 | for k, v in label_sample_sizes.items() 246 | if k in datasets.keys() 247 | } 248 | 249 | full_data_path = ( 250 | "FULL_DATA_PATH" if args.proteinfer_weights == "GO" else "FULL_EC_DATA_PATH" 251 | ) 252 | PROTEINFER_VOCABULARY = generate_vocabularies( 253 | file_path=config["paths"][full_data_path] 254 | )["label_vocab"] 255 | 256 | for loader_name, loader in loaders.items(): 257 | print(loader_name, len(loader[0].dataset.label_vocabulary)) 258 | 259 | represented_labels = [ 260 | label in loader[0].dataset.label_vocabulary for label in PROTEINFER_VOCABULARY 261 | ] 262 | 263 | test_metrics = eval_metrics.get_metric_collection_with_regex( 264 | pattern="f1_m.*", 265 | threshold=args.threshold, 266 | num_labels=label_sample_sizes["test"] 267 | if (params["IN_BATCH_SAMPLING"] or params["GRID_SAMPLER"]) is False 268 | else None, 269 | ) 270 | 271 | bce_loss = torch.nn.BCEWithLogitsLoss(reduction="mean") 272 | focal_loss = FocalLoss( 273 | gamma=config["params"]["FOCAL_LOSS_GAMMA"], 274 | alpha=config["params"]["FOCAL_LOSS_ALPHA"], 275 | ) 276 | total_bce_loss = 0 277 | total_focal_loss = 0 278 | test_results = defaultdict(list) 279 | 280 | mAP_micro = BinaryAUPRC(device="cpu") 281 
| mAP_macro = MultilabelAUPRC(device="cpu", num_labels=label_sample_sizes["test"]) 282 | 283 | with torch.no_grad(), autocast(enabled=True): 284 | for batch_idx, batch in tqdm(enumerate(loader[0]), total=len(loader[0])): 285 | # Unpack the validation or testing batch 286 | ( 287 | sequence_onehots, 288 | sequence_lengths, 289 | sequence_ids, 290 | label_multihots, 291 | label_embeddings, 292 | ) = ( 293 | batch["sequence_onehots"], 294 | batch["sequence_lengths"], 295 | batch["sequence_ids"], 296 | batch["label_multihots"], 297 | batch["label_embeddings"], 298 | ) 299 | sequence_onehots, sequence_lengths, label_multihots = to_device( 300 | device, sequence_onehots, sequence_lengths, label_multihots 301 | ) 302 | 303 | logits = model(sequence_onehots, sequence_lengths) 304 | if args.only_represented_labels: 305 | logits = logits[:, represented_labels] 306 | 307 | probabilities = torch.sigmoid(logits) 308 | 309 | if not args.only_inference: 310 | test_metrics(probabilities, label_multihots) 311 | 312 | if loader_name in ["validation", "test"]: 313 | mAP_micro.update( 314 | probabilities.cpu().flatten(), label_multihots.cpu().flatten() 315 | ) 316 | mAP_macro.update(probabilities.cpu(), label_multihots.cpu()) 317 | 318 | total_bce_loss += bce_loss(logits, label_multihots.float()) 319 | total_focal_loss += focal_loss(logits, label_multihots.float()) 320 | 321 | if args.save_prediction_results: 322 | test_results["sequence_ids"].append(sequence_ids) 323 | test_results["logits"].append(logits.cpu()) 324 | test_results["labels"].append(label_multihots.cpu()) 325 | 326 | if not args.only_inference: 327 | test_metrics = test_metrics.compute() 328 | test_metrics.update({"bce_loss": total_bce_loss / len(loader[0])}) 329 | test_metrics.update({"focal_loss": total_focal_loss / len(loader[0])}) 330 | 331 | if loader_name in ["validation", "test"]: 332 | test_metrics.update( 333 | {"map_micro": mAP_micro.compute(), "map_macro": mAP_macro.compute()} 334 | ) 335 | 336 | 
def save_fig(name, fmt="pdf", dpi=1200):
    """Save the current matplotlib figure as `<name>.<fmt>`.

    Defaults preserve the original behavior: PDF output at 1200 dpi with a
    tight bounding box. `fmt` and `dpi` are exposed so callers can request
    e.g. PNG previews without duplicating this helper.
    """
    plt.savefig(f"{name}.{fmt}", format=fmt, dpi=dpi, bbox_inches="tight")
| parser.add_argument( 34 | "--min-dist-vals", 35 | nargs="+", 36 | type=float, 37 | required=False, 38 | default=[0.5, 0.3], 39 | help="min dist values to try", 40 | ) 41 | parser.add_argument( 42 | "--paired-hparams", 43 | action="store_true", 44 | default=False 45 | ) 46 | 47 | parser.add_argument( 48 | "--num-seqs", 49 | type=int, 50 | required=False, 51 | default=100, 52 | help="Number of sequences to consider", 53 | ) 54 | 55 | parser.add_argument( 56 | "--embeddings-path", 57 | type=str, 58 | required=False, 59 | default= results_dir / "test_1_embeddings_TEST_TOP_LABELS_DATA_PATH_seed_replicates_v9_42_sum_last_epoch/batches_0_99.pt", 60 | help="The .pt path of the embeddings to perform umap on.", 61 | ) 62 | 63 | parser.add_argument( 64 | "--go-graph-file", 65 | type=str, 66 | default="go_2019-07-01.obo", #TODO: change to the appropriate go graph file 67 | required=False, 68 | help="the file of the appropriate go graph. The version/date of the go graph should match the date of the test set used to generate --embeddings-path", 69 | ) 70 | 71 | parser.add_argument( 72 | "--test-data-path", 73 | type=str, 74 | default='TEST_TOP_LABELS_DATA_PATH', 75 | required=False, 76 | help="The path name used to generate embeddings-path", 77 | ) 78 | 79 | args = parser.parse_args() 80 | figures_dir = results_dir.parents[0] / "figures" 81 | plt.rcParams["font.size"] = 14 82 | 83 | if not os.path.exists(figures_dir): 84 | os.mkdir(figures_dir) 85 | 86 | print("reading data...") 87 | embeddings = torch.load(args.embeddings_path,map_location="cpu") 88 | joint_embedding_dim = embeddings["joint_embeddings"].shape[-1] 89 | num_labels = embeddings["labels"].shape[-1] 90 | vocab = generate_vocabularies(str(config['paths']['data_paths'][args.test_data_path]))["label_vocab"] 91 | graph = obonet.read_obo(project_root / 'data' / 'annotations' / args.go_graph_file) 92 | vocab_parents = [ 93 | (graph.nodes[go_term]["namespace"] if go_term in graph.nodes else "missing") 94 | for go_term in 
vocab 95 | ] 96 | 97 | print("pre processing...") 98 | X = embeddings["output_layer_embeddings"][: num_labels * args.num_seqs, :] 99 | sc = StandardScaler() 100 | X_s = sc.fit_transform(X) 101 | 102 | hparams = [args.n_neighbors_vals, args.min_dist_vals] 103 | num_combinations = 1 104 | 105 | if args.paired_hparams: 106 | assert len(hparams[0]) == len( 107 | hparams[1] 108 | ), "hparams must be same lenght with paired_hparams = true" 109 | num_combinations = len(args.n_neighbors_vals) 110 | combos = list(zip(*hparams)) 111 | else: 112 | for hparam in hparams: 113 | num_combinations *= len(hparam) 114 | combos = product(*hparams) 115 | print(f"Testing {num_combinations} hparam combinations") 116 | 117 | hue = vocab_parents * (args.num_seqs) 118 | match_binary_mask = embeddings["labels"][: args.num_seqs, :].flatten() 119 | 120 | palette = sns.color_palette("tab10") 121 | 122 | print("running umap plots...") 123 | for n_neighbors, min_dist in tqdm(combos, total=num_combinations): 124 | 125 | X_r = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist).fit(X_s).embedding_ 126 | 127 | fig = plt.figure(figsize=(7, 7)) 128 | title = f"match vs unmatch n_neighbors={n_neighbors}, min_dist={min_dist}, n = {len(X_r)}" 129 | # output layer showing separation between matching and un-matching protein-function pairs 130 | palette_ = palette[7:8] + [(227/255,179/255,51/255)] #palette[6:7] 131 | print(X_r.shape,num_labels) 132 | sns.scatterplot( 133 | x=X_r[:, 0], 134 | y=X_r[:, 1], 135 | marker=".", 136 | s=2, 137 | hue=match_binary_mask, 138 | edgecolor=None, 139 | palette=palette_, 140 | ) 141 | plt.legend( 142 | markerscale=10, 143 | title="Protein-Function Label", 144 | bbox_to_anchor=(0.5, -0.2), 145 | loc="upper center", 146 | ) 147 | sns.despine() 148 | plt.title(title) 149 | save_fig(os.path.join(figures_dir, title)) 150 | plt.show() 151 | 152 | # Output layer colored by GO Top hierarchy 153 | fig = plt.figure(figsize=(7, 7)) 154 | 155 | palette_ = palette[4:5] + 
def update_annotations(old_file_path, new_file_path, output_file_path):
    """Update old annotations with terms that only exist in the new annotations file.

    Terms present in the new annotations but absent from the old ones are
    added to the old annotations; the merged result is sorted by index and
    pickled to `output_file_path`.
    """
    # Load new and old annotations (pickled pandas objects indexed by term id).
    new_annotations = read_pickle(new_file_path)
    old_annotations = read_pickle(old_file_path)

    # Find added terms (terms in new but not in old)
    added_terms = set(new_annotations.index) - set(old_annotations.index)

    # .loc does not accept a raw set as an indexer in modern pandas
    # (TypeError: "Passing a set as an indexer is not supported"); use a
    # sorted list so the selection is both valid and deterministic.
    old_updated = pd.concat([new_annotations.loc[sorted(added_terms)], old_annotations])

    # Sort the merged annotations by index for stable downstream lookups.
    old_updated = old_updated.sort_index()

    # Save the updated annotations to the output file path
    save_to_pickle(old_updated, output_file_path)

    print(f"Updated annotations saved to {output_file_path}")


def main():
    """CLI entry point: parse the three file paths and run the update."""
    parser = argparse.ArgumentParser(description="Update old GO annotations file using a new one.")

    parser.add_argument('--old-annotations-file-path', required=True, help='Path to the old annotations file')
    parser.add_argument('--new-annotations-file-path', required=True, help='Path to the new annotations file')
    parser.add_argument('--output-file-path', required=True, help='Path to save the updated annotations file')

    args = parser.parse_args()

    update_annotations(args.old_annotations_file_path, args.new_annotations_file_path, args.output_file_path)


if __name__ == "__main__":
    main()
"--access-token", 39 | type=str, 40 | required=True, 41 | help="The access token", 42 | ) 43 | 44 | args = parser.parse_args() 45 | 46 | r = requests.get(f'https://zenodo.org/api/deposit/depositions/{args.deposition_id}', 47 | params={'access_token': args.access_token}) 48 | r.status_code 49 | data = r.json() 50 | 51 | response = upload(args.file, args.path, data['links']['bucket'], args.access_token) 52 | print(response) -------------------------------------------------------------------------------- /configs/base_config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | params: 3 | ############## PARAMS THAT NEED TO BE CHANGED FOR AMULET V100 VS. SANDBOX A100 ############## 4 | 5 | ### V100 ### 6 | # Batch sizes 7 | TRAIN_BATCH_SIZE: 8 # 8 per GPU for V100 (64 if using 8 GPUs) 8 | VALIDATION_BATCH_SIZE: 8 # 32 per GPU for V100 (256 if using 8 GPUs) 9 | TEST_BATCH_SIZE: 8 # 32 per GPU for V100 (256 if using 8 GPUs) 10 | 11 | # Sampling 12 | GRID_SAMPLER: False 13 | IN_BATCH_SAMPLING: False # Use either in batch sampling or lable samples, not both. 14 | WEIGHTED_SAMPLING: True #Sample rare sequences more often as dictaded by inverse frequency of labels 15 | INV_FREQUENCY_POWER: 0.5 #Used for weighted sampling AND relevant losses with weighting 16 | SEQUENCE_WEIGHT_AGG: 'sum' #The aggregation to calculate sequence weight from applicable label weights. 17 | SAMPLING_LOWER_CLAMP_BOUND: null # Lower bound for when clamping sampling weights 18 | SAMPLING_UPPER_CLAMP_BOUND: null # Upper bound for when clamping sampling weights 19 | TRAIN_LABEL_SAMPLE_SIZE: null # 15K for both. 
IS ALSO USED WITH GRID SAMPLER 20 | VALIDATION_LABEL_SAMPLE_SIZE: null # use all labels in validation 21 | 22 | # Inference 23 | LABEL_BATCH_SIZE_LIMIT_NO_GRAD: 50 # 1K for V100 24 | SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD: 32 # 64 for V100 25 | 26 | # General 27 | LEARNING_RATE: 0.0003 28 | OPTIMIZER: Adam 29 | WEIGHT_DECAY: 0.001 # Only used for ADAM W or SGD 30 | PROTEIN_EMBEDDING_DIM: 1100 31 | LABEL_EMBEDDING_DIM: 1024 32 | LATENT_EMBEDDING_DIM: 1024 33 | OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR: 3 # Scale MLP hidden state with respect to LATENT_EMBEDDING_DIM 34 | OUTPUT_MLP_NUM_LAYERS: 3 35 | OUTPUT_NEURON_PROBABILITY_BIAS: null 36 | OUTPUT_MLP_BATCHNORM: True 37 | RESIDUAL_CONNECTION: False 38 | SYNC_BN: False 39 | OUTPUT_MLP_DROPOUT: 0.0 40 | LABEL_EMBEDDING_DROPOUT: 0.0 41 | SEQUENCE_EMBEDDING_DROPOUT: 0.0 42 | PROJECTION_HEAD_NUM_LAYERS: 4 43 | PROJECTION_HEAD_HIDDEN_DIM_SCALE_FACTOR: 3 44 | FEATURE_FUSION: concatenation # Select from concatenation, concatenation_diff, concatenation_prod, similarity 45 | LABEL_EMBEDDING_POOLING_METHOD: mean # Select from mean, last_token, all 46 | EXTRACT_VOCABULARIES_FROM: FULL_DATA_PATH # Can be any predefined path (e.g., FULL_DATA_PATH) or null. Null means generate vocab from scratch with dataset. Must set to null for zero-shot 47 | OPTIMIZATION_METRIC_NAME: f1_macro # Only micro metrics are supported if sampling labels in validation 48 | DECISION_TH_METRIC_NAME: f1_macro 49 | ESTIMATE_MAP: False 50 | NUM_EPOCHS: 46 51 | 52 | # Memory optimizations 53 | GRADIENT_ACCUMULATION_STEPS: 1 # 1 = no gradient accumulation 54 | GRADIENT_CHECKPOINTING: False # True = gradient checkpointing 55 | LORA: True # True = use LORA 56 | LORA_RANK: 4 57 | LORA_ALPHA: 8 # Typically = 2 * Rank 58 | CLIP_VALUE: 1 # Gradient clipping, set to "null" to turn off 59 | 60 | # Losses. 
Only the parameters for selected loss will be used 61 | LOSS_FN: FocalLoss # Currently supported: BCE, FocalLoss, BatchWeightedBCE, RGDBCE, WeightedBCE 62 | FOCAL_LOSS_GAMMA: 2 63 | FOCAL_LOSS_ALPHA: -1 64 | BCE_POS_WEIGHT: 1 # 671.7130737304688 65 | SUPCON_TEMP: 0.07 66 | RGDBCE_TEMP: 0.12 # search from [1,1/3,1/5,1/7,1/9] according to paper 67 | LABEL_SMOOTHING: 0.0 # Currently only implemented for FocalLoss 68 | 69 | # Sequence and label encoders 70 | PRETRAINED_SEQUENCE_ENCODER: True 71 | TRAIN_SEQUENCE_ENCODER: False 72 | LABEL_ENCODER_NUM_TRAINABLE_LAYERS: 0 73 | DISTRIBUTE_LABELS: False 74 | TRAIN_PROJECTION_HEAD: True 75 | LABEL_ENCODER_CHECKPOINT: intfloat/multilingual-e5-large-instruct #microsoft/biogpt, intfloat/e5-large-v2, intfloat/multilingual-e5-large-instruct 76 | 77 | # Data processing and metrics 78 | DEDUPLICATE: True 79 | MAX_SEQUENCE_LENGTH: 10000 # Only affects training set. Just 8 seqs with 10K+ length 80 | REMOVE_UNREPRESENTED_LABELS: False # Removes labels that never apply to train sequences 81 | NORMALIZE_PROBABILITIES: False 82 | INFERENCE_GO_DESCRIPTIONS: name+label # Ensembling during inference (validation, test). Only label, name, or name+label are supported. 83 | 84 | # Augmentation 85 | AUGMENT_RESIDUE_PROBABILITY: 0.1 # 0.0 = no augmentation. The probability that any given residue will be augmented. AlphaFold uses 0.15 (15% of residues are augmented) 86 | USE_RESIDUE_MASKING: False # If True, also use masking to augment sequences. Must then also set AUGMENT_SEQUENCE_PROBABILITY > 0, and TRAIN_SEQUENCE_ENCODER = True 87 | LABEL_AUGMENTATION_DESCRIPTIONS: name+label # Specifies what descriptions to use during training. # Options are any combination of label, name, and synonym_exact. Must include at least one. 88 | LABEL_EMBEDDING_NOISING_ALPHA: 20.0 # 0.0 = no noising. Typical ranges are <20. 
Noising scalar from https://arxiv.org/pdf/2310.05914.pdf 89 | 90 | # Constants 91 | SEED: 42 92 | EPOCHS_PER_VALIDATION: 1 # Must be >= 1 93 | NUM_WORKERS: 3 94 | DECISION_TH: 0.5 # Set to null if you want to use the best threshold from validation 95 | 96 | # Subset fractions (for rapid prototyping; set to 1 for final model)[s] 97 | TRAIN_SUBSET_FRACTION: 1 # Set to 1.0 if you want to use all data 98 | VALIDATION_SUBSET_FRACTION: 1 # Set to 1.0 if you want to use all data 99 | TEST_SUBSET_FRACTION: 1 # Set to 1.0 if you want to use all data 100 | SHUFFLE_LABELS: False # Only has an effect if sample sizes are set and not using grid sampler. If False, simply select same labels in order 101 | 102 | # Constants for protein encoder model (e.g. ProteInfer). Not really params since 103 | # we are not going to change these. 104 | embed_sequences_params: 105 | INPUT_CHANNELS: 20 106 | OUTPUT_CHANNELS: 1100 107 | KERNEL_SIZE: 9 108 | DILATION_BASE: 3 109 | NUM_RESNET_BLOCKS: 5 110 | BOTTLENECK_FACTOR: 0.5 111 | PROTEINFER_NUM_GO_LABELS: 32102 112 | PROTEINFER_NUM_EC_LABELS: 5134 113 | 114 | # Paths to data, vocabularies, embeddings, and models 115 | paths: 116 | # Paths referenced relative to DATA_PATH (will have "data" prepended) 117 | data_paths: 118 | # ProteInfer data paths 119 | TRAIN_DATA_PATH: swissprot/proteinfer_splits/random/train_GO.fasta 120 | TRAIN_2024_DATA_PATH: swissprot/proteinfer_splits/random/train_GO_jul_2024.fasta 121 | VAL_DATA_PATH: swissprot/proteinfer_splits/random/dev_GO.fasta 122 | TEST_DATA_PATH: swissprot/proteinfer_splits/random/test_GO.fasta 123 | TEST_TOP_LABELS_DATA_PATH: swissprot/proteinfer_splits/random/test_top_labels_GO.fasta 124 | 125 | TRAIN_CLUSTERED_DATA_PATH: swissprot/proteinfer_splits/clustered/train_GO.fasta 126 | VAL_CLUSTERED_DATA_PATH: swissprot/proteinfer_splits/clustered/dev_GO.fasta 127 | TEST_CLUSTERED_DATA_PATH: swissprot/proteinfer_splits/clustered/test_GO.fasta 128 | 129 | FULL_DATA_PATH: 
swissprot/proteinfer_splits/random/full_GO.fasta # used to generate vocabularies 130 | FULL_EC_DATA_PATH: swissprot/proteinfer_splits/random/full_EC.fasta # used to generate vocabularies 131 | 132 | TEST_2024_DATA_PATH: swissprot/proteinfer_splits/random/test_GO_jul_2024.fasta 133 | TEST_2024_PINF_VOCAB_DATA_PATH: swissprot/proteinfer_splits/random/test_GO_jul_2024_pinf_vocab.fasta 134 | 135 | # Zero-shot data paths 136 | TRAIN_DATA_PATH_ZERO_SHOT: swissprot/proteinfer_splits/random/fake_train_GO_zero_shot.fasta 137 | VAL_DATA_PATH_ZERO_SHOT: swissprot/proteinfer_splits/random/fake_dev_GO_zero_shot.fasta 138 | TEST_DATA_PATH_ZERO_SHOT: zero_shot/GO_swissprot_jul_2024.fasta 139 | TEST_DATA_PATH_ZERO_SHOT_LEAF_NODES: zero_shot/GO_swissprot_leaf_nodes_jul_2024.fasta 140 | 141 | #Other non-GO datasets to test zero shot 142 | VAL_EC_DATA_PATH_ZERO_SHOT: zero_shot/dev_EC.fasta 143 | TEST_EC_DATA_PATH_ZERO_SHOT: zero_shot/test_EC.fasta 144 | FULL_EC_DATA_PATH: swissprot/proteinfer_splits/random/full_EC.fasta 145 | 146 | # Vocabulary paths 147 | VOCABULARIES_DIR: vocabularies/proteinfer 148 | GO_ANNOTATIONS_PATH: annotations/go_annotations_jul_2024.pkl 149 | GO_ANNOTATIONS_2019_UPDATED_PATH: annotations/go_annotations_2019_07_01_updated.pkl 150 | EC_ANNOTATIONS_PATH: annotations/ec_annotations.pkl 151 | 152 | # Embeddings paths (if using frozen pre-trained models) 153 | GO_BASE_LABEL_EMBEDDING_PATH: embeddings/frozen_label_embeddings.pt # this is the base path which is modified by POOLING_METHOD 154 | GO_2024_BASE_LABEL_EMBEDDING_PATH: embeddings/2024_frozen_label_embeddings.pt 155 | EC_BASE_LABEL_EMBEDDING_PATH: embeddings/ecv1_frozen_label_embeddings.pt # this is the base path which is modified by POOLING_METHOD 156 | 157 | # ProteInfer paths 158 | PARENTHOOD_LIB_PATH: vocabularies/parenthood_jul_2024.json 159 | PROTEINFER_GO_WEIGHTS_PATH: models/proteinfer/GO_model_weights13703706.pkl 160 | PROTEINFER_EC_WEIGHTS_PATH: models/proteinfer/EC_model_weights13703966.pkl 161 | 
162 | #SwissProt 163 | LATEST_SWISSPROT_DATA_PATH: swissprot/uniprot_sprot_jul_2024.dat 164 | 165 | # Paths referenced relative to OUTPUT_PATH (will have "outputs" prepended) 166 | output_paths: 167 | # Where to save the model 168 | OUTPUT_MODEL_DIR: checkpoints/ 169 | 170 | # Where to save results 171 | RESULTS_DIR: results/ 172 | 173 | # Where to log 174 | LOG_DIR: logs/ -------------------------------------------------------------------------------- /docs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/protnote/883ec27cc502cbd155c62ba49b3da7b18dd78177/docs/.gitkeep -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: protnote 2 | channels: 3 | - biobuilds 4 | - bioconda 5 | - pytorch 6 | - huggingface 7 | - nvidia 8 | - anaconda 9 | - conda-forge 10 | - defaults 11 | dependencies: 12 | - python=3.10.13 #3.12.4 13 | - pytorch=2.0.1 14 | - torchvision=0.15.2 15 | - pytorch-cuda=11.8 16 | - pandas=1.5.2 17 | - joblib=1.1.1 18 | - transformers=4.32.1 19 | - torchmetrics=1.2.0 20 | - torchdata=0.7.1 21 | - wandb=0.15.11 22 | - sacremoses=0.0.53 23 | - pynvml=11.5.0 24 | - torchmetrics=1.2.0 25 | - protobuf=3.20.3 26 | - scipy=1.13.1 27 | - seaborn=0.13.2 28 | - scikit-learn=1.3.0 29 | - matplotlib=3.9.2 30 | - umap-learn=0.5.4 31 | - blast=2.12.0 32 | - openpyxl=3.1.5 33 | - pip: 34 | - pip 35 | - torcheval==0.0.7 36 | - wget==3.2 37 | - azureml-mlflow==1.53.0 38 | - loralib==0.1.2 39 | - tensorboard==2.15.1 40 | - obonet==1.0.0 41 | - blosum==2.0.2 42 | - biopython==1.84 43 | - ipykernel==6.29.5 44 | prefix: /anaconda/envs/protnote 45 | -------------------------------------------------------------------------------- /hyperdrive_seed_replicates.yml: -------------------------------------------------------------------------------- 1 | description: 
Protein Functions replicate experiment with 5 different seeds and 3 different weight sampling schemes. 2 | Up to 46 epochs. 3 | 4 | target: 5 | service: sing 6 | name: barlow01 7 | workspace_name: bio0-ext 8 | 9 | environment: 10 | image: amlt-sing/acpt-2.0.1-py3.10-cuda11.8 #azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04:latest 11 | conda_yaml_file: $CONFIG_DIR/environment.yml 12 | skip_conda_packages_on_sing: 13 | - tensorflow 14 | - cudatoolkit 15 | - deepspeed 16 | - pip 17 | - python\b 18 | 19 | code: 20 | local_dir: $CONFIG_DIR/ 21 | 22 | data: 23 | local_dir: $CONFIG_DIR/data 24 | remote_dir: data 25 | 26 | # SKU usage: G1 (single GPU), G4 (quad GPU), G4-V100 (1 machine, 4 V100 gpus), etc... 27 | search: 28 | job_template: 29 | name: "{experiment_name}_SEED_{SEED}_SEQUENCE_WEIGHT_AGG_{SEQUENCE_WEIGHT_AGG}" 30 | sku: G8-V100 31 | priority: high 32 | preemptible: False 33 | command: 34 | - python main.py 35 | --nodes 1 36 | --gpus 8 37 | --train-path-name TRAIN_DATA_PATH 38 | --validation-path-name VAL_DATA_PATH 39 | --name "{experiment_name}_{SEED}_{SEQUENCE_WEIGHT_AGG}" 40 | --override ESTIMATE_MAP True 41 | --amlt 42 | --use-wandb 43 | submit_args: 44 | env: 45 | WANDB_BASE_URL: "https://microsoft-research.wandb.io" 46 | WANDB_API_KEY: "$WANDB_API_KEY" 47 | type: hyperdrive 48 | sampling: grid 49 | max_trials: 10 50 | parallel_trials: 10 51 | params: 52 | - name: SEED 53 | values: [12,22,32,42,52] 54 | - name: SEQUENCE_WEIGHT_AGG 55 | values: ['sum','mean'] 56 | 57 | -------------------------------------------------------------------------------- /img/main_fig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/protnote/883ec27cc502cbd155c62ba49b3da7b18dd78177/img/main_fig.jpg -------------------------------------------------------------------------------- /img/main_fig.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/protnote/883ec27cc502cbd155c62ba49b3da7b18dd78177/img/main_fig.pdf -------------------------------------------------------------------------------- /model_card.md: -------------------------------------------------------------------------------- 1 | 2 | # Model Card for ProtNote 3 | 4 | 5 | 6 | ProtNote is a multimodal deep learning model that leverages free-form text to enable both supervised and zero-shot protein function prediction 7 | 8 | ## Model Details 9 | 10 | ### Model Description 11 | 12 | Understanding the protein sequence-function relationship is essential for advancing protein biology and engineering. However fewer than 1% of known protein sequences have human-verified functions. While deep learning methods have demonstrated promise for protein function prediction current models are limited to predicting only those functions on which they were trained. 13 | 14 | ProtNote is a multimodal deep learning model that leverages free-form text to enable both supervised and zero-shot protein function prediction. ProtNote not only maintains near state-of-the-art performance for annotations in its train set, but also generalizes to unseen and novel functions in zero-shot test settings. 15 | 16 | - **Developed by:** Samir Char, Nathaniel Corley, Sarah Alamdari, Kevin K. Yang, Ava P. Amini 17 | - **Shared by:** Microsoft Research New England 18 | - **Model type:** ProtNote is two tower Deep Learning model for protein funciton prediction. The first tower is a Convolutional Neural Network that encodes protein sequences, while the second is a Transformer used to encode protein function text descriptions. These two representations are combined through and output Multi Layer Perceptron that computes the final predictions. 
19 | - **Language(s) (NLP):** English 20 | - **License:** [MIT License](https://opensource.org/licenses/MIT) 21 | 22 | ### Model Sources [optional] 23 | 24 | 25 | 26 | - **Repository:** https://github.com/microsoft/protnote/tree/main 27 | - **Paper [optional]:** TODO 28 | 29 | ## Uses 30 | 31 | 32 | ### Direct Use 33 | 34 | This model is intended for research use. It can be used to predict the probability that the given protein sequences carry out a function described via free-from text. We provide model weights under five different seeds. Detailed use instructions are available in the model repo. 35 | 36 | 37 | ## Bias, Risks, and Limitations 38 | 39 | - This model is intended for use on protein sequences. It is not meant for other biological sequences, such as DNA sequences. 40 | - Even though the model could take any type of text as input, we recommend providing only functional descriptions. Non-functional descriptions may exhibit reduced performance. 41 | - The model may be biased toward the content, vocabulary, and language style of the functional descriptions in the Gene Ontology, which is the dataset used to train the model. 42 | 43 | 44 | ## How to Get Started with the Model 45 | 46 | Use the code below to install our model: 47 | 48 | ``` 49 | git clone https://github.com/microsoft/protnote.git 50 | cd protnote 51 | conda env create -f environment.yml 52 | conda activate protnote 53 | pip install -e ./ # make sure ./ is the dir including setup.py 54 | ``` 55 | 56 | For detailed instructions on package usage, please refer to the README in model repo: https://github.com/microsoft/protnote.git 57 | 58 | 59 | ## Training Details 60 | 61 | 62 | ### Training Data 63 | 64 | 65 | 66 | ProtNote is trained with protein sequences from the SwissProt section of the [UniProt](http://www.uniprot.org/) database, the world’s largest repository of manually curated, high-quality protein sequence and function information. 
80% of the sequences are randomly assigned to train, 10% to validation, and 10% to test; we use the assignments reported by ProteInfer. We remove duplicated sequences and long sequences with more than 10,000 amino acids. 67 | 68 | Further, the model is trained on descriptions from [Gene Ontology](https://geneontology.org/) (GO) annotations. GO annotations capture our knowledge of three aspects of protein biology: molecular function, biological process, and cellular component. Consequently, GO terms are phrases describing the molecular actions of gene products, the biological processes in which those actions occur, and the cellular locations where they are present. We train our model with the GO terms descriptions from the [01 July 2019 GO Annotations release](https://release.geneontology.org/2019-07-01/ontology/go.obo). 69 | 70 | 71 | 72 | ### Training Procedure 73 | 74 | To build a model capable of both supervised and zero-shot 75 | protein-function annotation, we formulate Gene Ontology (GO) term prediction from protein sequence as a binary classification problem: given a protein-GO term pair, the model predicts if the protein has the presented GO term. To combat class imbalance ProtNote leverages Focal Loss and weighted sampling of the sequences according to the inverse frequency of their functional annotations. 76 | 77 | ProtNote uses ProteInfer as a protein sequence encoder and Multilingual E5 Text Embedding model as a text encoder – both encoders are frozen throughout training. To transform the embeddings to a more suitable space and reshape them to a common size d = 1024, they are each passed through independent multi-layer perceptrons (MLP) with three hidden layers of 3d = 3072 units and an output layer of size d = 1024. These embeddings are concatenated and fed through an MLP with three hidden layers of 3d = 3072 units, followed by the output neuron. The MLPs have no bias parameters because we use batch normalization and ReLU activations. 
78 | 79 | ProtNote is trained for 46 epochs on 8x 32GB V100 NVIDIA GPUs using an effective batch size of 256 (32 x 8) with dynamic padding. The model trains using Adam optimizer with learning rate of 0.0003 and employs gradient clipping set to 1 and mixed precision [26]. We select the model checkpoint based on best validation performance 80 | 81 | 82 | ## Evaluation 83 | 84 | 85 | 86 | ### Testing Data & Metrics 87 | 88 | #### Testing Data 89 | 90 | We evaluate ProtNote in both the supervised and zero-shot settings, benchmarking against the state-of-the-art deep-learning method ProteInfer and the gold-standard, homology-based method BLAST in the supervised setting and against custom embedding-based baselines in the zero-shot setting. For completeness, we break down the models’ performance across the three GO Ontologies (biological process, cellular component, molecular function) and the seven top-level EC classes (oxidoreductases, transferases hydrolases, lyases, isomerases, ligases, and translocases). 91 | 92 | To asses ProtNote's ability to predict unseen/novel funcitonal text descriptions, we test it on unseen sequences *and* annotations from the July 2024 releases (data 5 years older than the training data) of the Gene Ontology ([Jul 2024 GO Release](https://release.geneontology.org/2024-06-17/ontology/go.obo)) and the Enzyme Commission (a data the model was never trained on) ([EC descriptions](https://ftp.expasy.org/databases/enzyme/enzclass.txt), [EC annotations](https://ftp.expasy.org/databases/enzyme/enzyme.dat)). 93 | 94 | 95 | #### Metrics 96 | 97 | 98 | 99 | To evaluate the performance of different models, we use both macro- and micro-averaged mean Average Precision (mAP), also known as the Area Under the Precision-Recall Curve (AUPRC). 
The mAP metric summarizes model performance across all possible prediction thresholds, eliminating threshold tuning and providing a more robust assessment of model performance than threshold-dependent metrics such as F1 or Fmax scores. We report both mAP Macro and mAP Micro because of their different virtues, although we use mAP Macro for model selection. 100 | 101 | ### Results 102 | 103 | In this work, we introduce ProtNote, a multimodal deep learning model capable of supervised and zero-shot protein function annotation. We demonstrate how ProtNote leverages unstructured text to predict the likelihood that proteins perform arbitrary functions. Importantly, we showcase that ProtNote generalizes to both unseen sequences and functions by evaluating the model on newly-added GO annotations and on enzyme annotations, which were not used to train the model. We observe that this generality does not come at the cost of supervised performance; our model is also performant at predicting known GO terms, performing on par with the state-of-the-art model ProteInfer. 
104 | 105 | ## Environmental Impact 106 | 107 | 108 | 109 | - **Hardware Type:** 8 x 32GB NVIDIA V100 GPUs 110 | - **Hours used:** 960 = 5 days x 24h x 8 GPUs 111 | - **Cloud Provider:** Azure 112 | - **Compute Region:** East US 2 113 | - **Carbon Emitted:** 106.5 kg CO2 114 | 115 | ## Citation [optional] 116 | 117 | TODO: 118 | 119 | **BibTeX:** 120 | 121 | {{ citation_bibtex | default("[More Information Needed]", true)}} 122 | 123 | **APA:** 124 | 125 | {{ citation_apa | default("[More Information Needed]", true)}} 126 | -------------------------------------------------------------------------------- /proteinfer_conda_requirements.yml: -------------------------------------------------------------------------------- 1 | name: proteinfer 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - aiohttp=3.8.3=py37h5eee18b_0 10 | - aiosignal=1.2.0=pyhd3eb1b0_0 11 | - async-timeout=4.0.2=py37h06a4308_0 12 | - asynctest=0.13.0=py_0 13 | - attrs=22.1.0=py37h06a4308_0 14 | - backcall=0.2.0=pyhd3eb1b0_0 15 | - blas=1.0=mkl 16 | - brotlipy=0.7.0=py37h27cfd23_1003 17 | - c-ares=1.19.0=h5eee18b_0 18 | - ca-certificates=2023.7.22=hbcca054_0 19 | - cachetools=4.2.2=pyhd3eb1b0_0 20 | - certifi=2023.7.22=pyhd8ed1ab_0 21 | - cffi=1.15.1=py37h5eee18b_3 22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 23 | - cryptography=39.0.1=py37h9ce1e76_0 24 | - decorator=5.1.1=pyhd3eb1b0_0 25 | - entrypoints=0.4=py37h06a4308_0 26 | - frozenlist=1.3.3=py37h5eee18b_0 27 | - google-api-core=2.10.1=py37h06a4308_0 28 | - google-api-python-client=2.95.0=pyhd8ed1ab_0 29 | - google-auth=2.6.0=pyhd3eb1b0_0 30 | - google-auth-httplib2=0.1.0=pyhd8ed1ab_1 31 | - google-cloud-core=2.3.2=py37h06a4308_0 32 | - google-cloud-storage=2.10.0=pyh1a96a4e_0 33 | - google-crc32c=1.5.0=py37h5eee18b_0 34 | - google-resumable-media=2.4.0=py37h06a4308_0 35 | - googleapis-common-protos=1.56.4=py37h06a4308_0 36 | - 
httplib2=0.22.0=pyhd8ed1ab_0 37 | - idna=3.4=py37h06a4308_0 38 | - intel-openmp=2023.1.0=hdb19cb5_46305 39 | - ipython=7.33.0=py37h89c1867_0 40 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 41 | - jedi=0.18.1=py37h06a4308_1 42 | - ld_impl_linux-64=2.38=h1181459_1 43 | - libcrc32c=1.1.2=h6a678d5_0 44 | - libffi=3.4.4=h6a678d5_0 45 | - libgcc-ng=11.2.0=h1234567_1 46 | - libgomp=11.2.0=h1234567_1 47 | - libprotobuf=3.20.3=he621ea3_0 48 | - libsodium=1.0.18=h7b6447c_0 49 | - libstdcxx-ng=11.2.0=h1234567_1 50 | - matplotlib-inline=0.1.6=py37h06a4308_0 51 | - mkl=2023.1.0=h6d00ec8_46342 52 | - multidict=6.0.2=py37h5eee18b_0 53 | - ncurses=6.4=h6a678d5_0 54 | - openssl=1.1.1v=h7f8727e_0 55 | - parso=0.8.3=pyhd3eb1b0_0 56 | - pexpect=4.8.0=pyhd3eb1b0_3 57 | - pickleshare=0.7.5=pyhd3eb1b0_1003 58 | - pip=22.3.1=py37h06a4308_0 59 | - portalocker=2.3.0=py37h06a4308_0 60 | - prompt-toolkit=3.0.36=py37h06a4308_0 61 | - ptyprocess=0.7.0=pyhd3eb1b0_2 62 | - pyasn1=0.4.8=pyhd3eb1b0_0 63 | - pyasn1-modules=0.2.8=py_0 64 | - pycparser=2.21=pyhd3eb1b0_0 65 | - pygments=2.11.2=pyhd3eb1b0_0 66 | - pyopenssl=23.0.0=py37h06a4308_0 67 | - pyparsing=3.0.9=py37h06a4308_0 68 | - pysocks=1.7.1=py37_1 69 | - python=3.7.16=h7a1cb2a_0 70 | - python_abi=3.7=2_cp37m 71 | - pytorch=1.13.1=py3.7_cpu_0 72 | - pytorch-mutex=1.0=cpu 73 | - pyzmq=23.2.0=py37h6a678d5_0 74 | - readline=8.2=h5eee18b_0 75 | - requests=2.28.1=py37h06a4308_0 76 | - rsa=4.7.2=pyhd3eb1b0_1 77 | - setuptools=65.6.3=py37h06a4308_0 78 | - sqlite=3.41.2=h5eee18b_0 79 | - tbb=2021.8.0=hdb19cb5_0 80 | - tk=8.6.12=h1ccaba5_0 81 | - torchdata=0.5.1=py37 82 | - tornado=6.2=py37h5eee18b_0 83 | - traitlets=5.7.1=py37h06a4308_0 84 | - typing-extensions=4.1.1=hd3eb1b0_0 85 | - typing_extensions=4.1.1=pyh06a4308_0 86 | - uritemplate=4.1.1=pyhd8ed1ab_0 87 | - urllib3=1.26.8=pyhd3eb1b0_0 88 | - wcwidth=0.2.5=pyhd3eb1b0_0 89 | - wheel=0.38.4=py37h06a4308_0 90 | - xz=5.4.2=h5eee18b_0 91 | - yarl=1.8.1=py37h5eee18b_0 92 | - zeromq=4.3.4=h2531618_0 93 | 
- zlib=1.2.13=h5eee18b_0 94 | - pip: 95 | - absl-py==0.7.1 96 | - astor==0.7.1 97 | - backports-weakref==1.0.post1 98 | - biopython==1.78 99 | - cycler==0.11.0 100 | - debugpy==1.6.7 101 | - descartes==1.1.0 102 | - enum34==1.1.6 103 | - funcsigs==1.0.2 104 | - gast==0.2.2 105 | - google-pasta==0.2.0 106 | - grpcio==1.19.0 107 | - h5py==2.9.0 108 | - ipykernel==6.16.2 109 | - joblib==1.3.1 110 | - jupyter-client==7.4.9 111 | - jupyter-core==4.12.0 112 | - keras-applications==1.0.8 113 | - keras-preprocessing==1.0.9 114 | - kiwisolver==1.4.4 115 | - markdown==3.0.1 116 | - matplotlib==3.4.3 117 | - mizani==0.7.3 118 | - mock==2.0.0 119 | - nest-asyncio==1.5.7 120 | - numpy==1.16.5 121 | - opt-einsum==3.3.0 122 | - packaging==23.1 123 | - palettable==3.3.3 124 | - pandas==1.1.2 125 | - patsy==0.5.3 126 | - pbr==5.1.3 127 | - pillow==9.5.0 128 | - plotly==4.11.0 129 | - plotnine==0.7.0 130 | - protobuf==3.20.3 131 | - psutil==5.9.5 132 | - python-dateutil 133 | - pytz==2018.9 134 | - retrying==1.3.4 135 | - scikit-learn==0.23.2 136 | - scipy==1.7.3 137 | - six==1.12.0 138 | - statsmodels==0.12.2 139 | - tensorboard==1.15.0 140 | - tensorflow-estimator==1.15.1 141 | - tensorflow-gpu==1.15.4 142 | - tensorflow-hub==0.9.0 143 | - termcolor==1.1.0 144 | - threadpoolctl==3.1.0 145 | - tqdm==4.62.2 146 | - typing==3.7.4.3 147 | - werkzeug==0.15.1 148 | - wrapt==1.15.0 149 | prefix: /anaconda/envs/proteinfer_env 150 | -------------------------------------------------------------------------------- /protnote/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/protnote/883ec27cc502cbd155c62ba49b3da7b18dd78177/protnote/.gitkeep -------------------------------------------------------------------------------- /protnote/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
import torch
from typing import List, Tuple


def collate_variable_sequence_length(
    batch: List[Tuple],
    label_sample_size=None,
    distribute_labels=False,
    shuffle_labels=False,
    in_batch_sampling=False,
    grid_sampler=False,
    return_label_multihots=True,
    world_size=1,
    rank=0,
):
    """
    Collate a batch of variable-length sequences, right-padding every sequence
    with zeros up to the maximum length observed within the batch.

    Args:
        batch (List[Tuple]): One dict per data point with keys
            'sequence_onehots', 'sequence_id', 'sequence_length',
            'label_multihots', 'label_embeddings', 'label_token_counts'
            (and 'label_idxs' when the grid sampler is used).
        label_sample_size (int, optional): Number of labels to sample.
            Used with grid_sampler or distributed label sampling.
        distribute_labels (bool, optional): Sample labels from this rank's
            contiguous partition instead of the full label set.
        shuffle_labels (bool, optional): Randomize which labels are sampled.
        in_batch_sampling (bool, optional): Keep only labels that are positive
            for at least one sequence in the batch.
        grid_sampler (bool, optional): Use the label indices pre-selected by
            the grid sampler (read from batch[0]['label_idxs']).
        return_label_multihots (bool, optional): Include the (possibly
            sampled) stacked multihot label matrix in the output.
        world_size (int, optional): Total number of distributed processes.
        rank (int, optional): Rank of the current process.

    Returns:
        Dict: A dictionary containing the processed batch data. Keys include:
            - 'sequence_onehots': Tensor, padded one-hot encoded sequences.
            - 'sequence_ids': List, sequence IDs.
            - 'sequence_lengths': Tensor, original (unpadded) lengths.
            - 'label_embeddings': Tensor, label embeddings (possibly sampled).
            - 'label_token_counts': Tensor, token counts for each label.
            - 'label_multihots': Tensor, only when return_label_multihots.
    """
    # Pad everything to the longest sequence in this batch.
    max_length = max(item["sequence_length"] for item in batch)

    processed_sequence_onehots = []
    processed_sequence_ids = []
    processed_sequence_lengths = []
    processed_label_multihots = []
    processed_label_embeddings = None

    # Validate mutually-exclusive sampling options up front.
    if grid_sampler:
        assert (
            label_sample_size is not None
        ), "Must provide label_sample_size if using grid sampler"
        assert not in_batch_sampling, "Can't use in batch sampling with grid sampler"
    else:
        # Fix: corrected typos in the original message
        # ("Cant use both in_batch_sampling with lable_sample_size").
        assert not (
            in_batch_sampling and (label_sample_size is not None)
        ), "Can't use both in_batch_sampling and label_sample_size"

    sampled_label_indices = None
    num_labels = batch[0]["label_multihots"].shape[0]

    if label_sample_size:
        if grid_sampler:
            # The grid sampler pre-selects the label subset; it is identical
            # for every row, so read it from the first element.
            sampled_label_indices = batch[0]["label_idxs"]
        else:
            if not distribute_labels:
                # Sample from the entire label vocabulary.
                sampled_label_indices = (
                    torch.randperm(num_labels)[:label_sample_size]
                    if shuffle_labels
                    else torch.arange(label_sample_size)
                )
            else:
                # Sample only from this rank's contiguous label partition.
                labels_per_partition = num_labels // world_size
                start_idx = rank * labels_per_partition
                end_idx = start_idx + labels_per_partition
                partition_indices = torch.arange(start_idx, end_idx)
                sampled_label_indices = partition_indices[
                    torch.randperm(len(partition_indices))[
                        : label_sample_size // world_size
                    ]
                ]
    elif in_batch_sampling:
        # Keep labels that are positive for at least one sequence in the batch.
        sampled_label_indices = torch.where(
            sum(i["label_multihots"] for i in batch) > 0
        )[0]

    # Label embeddings and token counts are shared across the batch, so we
    # read them from the first element only to minimize complexity.
    label_embeddings = batch[0]["label_embeddings"]
    label_token_counts = batch[0]["label_token_counts"]

    if sampled_label_indices is not None:
        # Keep only the embeddings of the sampled labels.
        processed_label_embeddings = label_embeddings[sampled_label_indices]
    else:
        processed_label_embeddings = label_embeddings

    for row in batch:
        sequence_onehots = row["sequence_onehots"]
        sequence_id = row["sequence_id"]
        sequence_length = row["sequence_length"]
        label_multihots = row["label_multihots"]

        padding_length = max_length - sequence_length

        # First dim is the alphabet size (e.g., 20 for amino acids).
        sequence_dim = sequence_onehots.shape[0]

        # Right-pad with zero columns up to max_length.
        processed_sequence_onehots.append(
            torch.cat(
                (sequence_onehots, torch.zeros((sequence_dim, padding_length))), dim=1
            )
        )

        # Apply the same label subset to every element in the batch.
        if sampled_label_indices is not None:
            label_multihots = label_multihots[sampled_label_indices]

        processed_sequence_ids.append(sequence_id)
        processed_sequence_lengths.append(sequence_length)
        processed_label_multihots.append(label_multihots)

    processed_batch = {
        "sequence_onehots": torch.stack(processed_sequence_onehots),
        "sequence_ids": processed_sequence_ids,
        "sequence_lengths": torch.stack(processed_sequence_lengths),
        "label_embeddings": processed_label_embeddings,
        "label_token_counts": label_token_counts,
    }

    if return_label_multihots:
        processed_batch["label_multihots"] = torch.stack(processed_label_multihots)

    return processed_batch
class GeneralDistributedSampler(DistributedSampler):
    """
    Shard the indices produced by an arbitrary sampler across processes.

    DistributedSampler shards a dataset; this subclass instead wraps another
    Sampler so each replica receives a disjoint, evenly sized slice of the
    wrapped sampler's output.
    """

    def __init__(
        self,
        sampler: Sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
    ):
        # Delegate rank/world-size resolution and num_samples/total_size math
        # to DistributedSampler. shuffle=False: ordering is the wrapped
        # sampler's responsibility.
        super().__init__(
            dataset=sampler,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=False,
            seed=seed,
            drop_last=drop_last,
        )

        # Fix: compare against the resolved replica count (self.num_replicas).
        # The original compared against the raw `num_replicas` argument, which
        # is None when auto-detected from the process group, so the comparison
        # raised a TypeError instead of the intended AssertionError.
        assert len(sampler) > self.num_replicas, "Total samples must be > num replicas"

    def __iter__(self):
        # Deterministic per-epoch seeding for any torch RNG the wrapped
        # sampler consumes while being materialized below.
        torch.manual_seed(self.epoch + self.seed)
        indices = list(self.dataset)  # self.dataset holds the wrapped sampler

        if not self.drop_last:
            # Pad by cycling the index list until it divides evenly.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # Drop the tail so every replica gets the same sample count.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Round-robin shard: rank r takes positions r, r+R, r+2R, ...
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)


class DistributedWeightedSampler(Sampler):
    """
    Weighted random sampling coordinated across distributed replicas.

    Every replica draws the same epoch-seeded global multinomial sample and
    then keeps a rank-specific round-robin slice of it, so the replicas'
    index sets are disjoint.
    """

    def __init__(self, weights, world_size=None, rank=None, replacement=True):
        # Resolve world size / rank from the process group when not supplied.
        if world_size is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

        self.weights = weights
        self.world_size = world_size
        self.rank = rank
        self.epoch = 0
        self.replacement = replacement

        # torch.multinomial needs a floating-point tensor of weights.
        if not isinstance(self.weights, torch.Tensor):
            self.weights = torch.tensor(self.weights, dtype=torch.double)

        # Per-replica sample count, rounded down so it divides evenly.
        self.num_samples = int(math.floor(len(self.weights) * 1.0 / self.world_size))

        # Total number of samples drawn globally (across all replicas).
        self.total_size = self.num_samples * self.world_size

    def __iter__(self):
        # Epoch-seeded generator: every replica draws the identical global
        # sample, which is what makes the per-rank slices disjoint.
        g = torch.Generator()
        g.manual_seed(self.epoch)

        if self.replacement:
            indices = torch.multinomial(
                self.weights, self.total_size, replacement=True, generator=g
            )
        else:
            # Fix: multinomial without replacement only requires
            # total_size <= len(weights); the original assert used a strict
            # ">" and rejected the valid equality case.
            assert (
                len(self.weights) >= self.total_size
            ), "When sampling without replacement, number of samples to draw must be at most the number of elements in the dataset"
            indices = torch.multinomial(
                self.weights, self.total_size, replacement=False, generator=g
            )

        # Keep this replica's round-robin slice of the global sample.
        indices_for_one_gpu = indices[self.rank : self.total_size : self.world_size]

        # Shuffle the slice so batch composition changes every epoch.
        indices_for_one_gpu = indices_for_one_gpu[
            torch.randperm(len(indices_for_one_gpu), generator=g)
        ].tolist()

        return iter(indices_for_one_gpu)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Called by the training loop so each epoch draws a fresh sample.
        self.epoch = epoch


class GridBatchSampler(BatchSampler):
    """
    Yield batches over the (observation, label) grid: every observation batch
    is paired with every label batch, and each yielded batch is the list of
    (observation_idx, label_batch) pairs for one such pairing.
    """

    def __init__(
        self,
        observation_sampler,
        observations_batch_size,
        drop_last_observation_batch,
        num_labels,
        labels_batch_size,
        shuffle_grid=True,
    ):
        self.observation_sampler = observation_sampler
        self.observations_batch_size = observations_batch_size
        self.drop_last_observation_batch = drop_last_observation_batch

        self.num_labels = num_labels
        self.labels_batch_size = labels_batch_size
        self.shuffle_grid = shuffle_grid
        self.labels_idxs = list(range(num_labels))
        self.calculate_num_batches()

    def __iter__(self):
        # Re-shuffle label order every epoch so label batches vary.
        random.shuffle(self.labels_idxs)

        # Fix: the two progress messages were swapped in the original (it
        # printed "Getting label batches..." while building observation
        # batches and vice versa).
        print("Getting observation batches...")
        observation_batches = self.get_observation_batches()
        print("Done...")

        print("Getting label batches...")
        label_batches = self.get_label_batches()
        print("Done...")

        print("Getting combinations...")
        obs_labels_batch_combinations = list(
            product(observation_batches, label_batches)
        )
        print("Done...")

        if self.shuffle_grid:
            print("Shuffling...")
            random.shuffle(obs_labels_batch_combinations)
            print("Done...")

        for observation_batch, label_batch in obs_labels_batch_combinations:
            # Pair every observation index with the (shared) label batch.
            yield list(product(observation_batch, [label_batch]))

    def calculate_num_batches(self):
        """Precompute the total number of (observation, label) grid batches."""
        num_label_batches = np.ceil(self.num_labels / self.labels_batch_size)
        num_observation_batches = (
            np.ceil(len(self.observation_sampler) / self.observations_batch_size)
            if not self.drop_last_observation_batch
            else len(self.observation_sampler) // self.observations_batch_size
        )
        # Note: a stray debug print("Done...") was removed here; it logged
        # completion of work this method does not perform.

        self.total_num_batches = int(num_label_batches * num_observation_batches)
        print(
            f"num label batches = {num_label_batches}, num observation batches = {num_observation_batches}"
        )
        print(f"total batches = {self.total_num_batches}")

    def __len__(self):
        return self.total_num_batches

    def get_label_batches(self):
        # Chunk the (shuffled) label indices into fixed-size batches.
        return [
            self.labels_idxs[i : i + self.labels_batch_size]
            for i in range(0, self.num_labels, self.labels_batch_size)
        ]

    def get_observation_batches(self):
        batches = []

        if self.drop_last_observation_batch:
            # Pull exactly observations_batch_size items at a time; a partial
            # final batch triggers StopIteration mid-comprehension and is
            # discarded, which implements drop-last.
            observation_sampler_iter = iter(self.observation_sampler)
            while True:
                try:
                    batch = [
                        next(observation_sampler_iter)
                        for _ in range(self.observations_batch_size)
                    ]
                    batches.append(batch)
                except StopIteration:
                    break
        else:
            # Accumulate into a fixed-size buffer, keeping the partial tail.
            batch = [0] * self.observations_batch_size
            idx_in_batch = 0
            for idx in self.observation_sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.observations_batch_size:
                    batches.append(batch)
                    idx_in_batch = 0
                    batch = [0] * self.observations_batch_size
            if idx_in_batch > 0:
                batches.append(batch[:idx_in_batch])
        return batches


def observation_sampler_factory(
    distribute_labels: bool,
    weighted_sampling: bool,
    shuffle: bool,
    dataset: Dataset = None,
    world_size: int = 1,
    rank: int = 0,
    sequence_weights: torch.Tensor = None,
):
    """
    Build the observation sampler matching the run configuration.

    Args:
        distribute_labels: Whether labels (not sequences) are distributed.
        weighted_sampling: Whether to sample sequences by weight.
        shuffle: Shuffle flag, used only by the plain DistributedSampler path.
        dataset: Required for the plain DistributedSampler path.
        world_size: Number of distributed processes.
        rank: Rank of the current process.
        sequence_weights: Per-sequence weights; required for weighted paths.

    Returns:
        A Sampler instance, or None when labels are distributed without
        weighted sampling (that mode handles sharding itself).

    Raises:
        ValueError: For unsupported flag combinations.
    """
    if distribute_labels and not weighted_sampling:
        print("WARNING: No Sampler used for distribute labels")
        sampler = None
    elif not distribute_labels and world_size == 1 and weighted_sampling:
        # Single process: plain weighted sampling with replacement.
        assert sequence_weights is not None, "Weighted RandomSampler requires weights"

        sampler = WeightedRandomSampler(
            sequence_weights, len(sequence_weights), replacement=True
        )
    elif not distribute_labels and world_size > 1 and weighted_sampling:
        # Multiple processes: coordinate the weighted draw across replicas.
        sampler = DistributedWeightedSampler(
            sequence_weights,
            world_size=world_size,
            rank=rank,
            replacement=True,
        )
    elif not distribute_labels and not weighted_sampling:
        # Plain distributed sharding of the dataset without weighting.
        assert dataset is not None, "DistributeSampler requires dataset"

        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
        )
    else:
        raise ValueError(
            "Invalid combination of WEIGHTED_SAMPLING, WORLD_SIZE, and DISTRIBUTE_LABELS parameters"
        )

    return sampler
read_fasta 9 | from protnote.utils.data import tqdm_joblib 10 | from protnote.utils.configs import get_logger 11 | import pandas as pd 12 | import multiprocessing 13 | from joblib import Parallel, delayed 14 | 15 | 16 | class BlastTopHits: 17 | def __init__(self, db_fasta_path: str, queries_fasta_path: str): 18 | self.db_fasta_path = db_fasta_path 19 | self.db_path = ".".join(db_fasta_path.split(".")[:-1]) 20 | self.queries_fasta_path = queries_fasta_path 21 | 22 | # test_dict = {uid:go_terms for _,uid,go_terms in test} 23 | # self.queries = read_fasta(queries_fasta_path) 24 | self.num_threads = multiprocessing.cpu_count() 25 | self.logger = get_logger() 26 | self.db_seq_2_labels = None 27 | # TODO: Columns of interest should be an argument and a list instead of dict internaly 28 | # using a mapping from friendly name to actual name in the outfmt 29 | self.columns = { 30 | 0: "sequence_name", 31 | 1: "closest_sequence", 32 | 2: "percent_seq_identity", 33 | 3: "e_value", 34 | 4: "bit_score", 35 | } 36 | 37 | def make_db(self): 38 | makeblastdb_cline = NcbimakeblastdbCommandline( 39 | dbtype="prot", input_file=self.db_fasta_path, out=self.db_path 40 | ) 41 | _ = makeblastdb_cline() 42 | 43 | def blast_db_exists(self): 44 | required_extensions = [".pin", ".phr", ".psq"] 45 | for ext in required_extensions: 46 | if not os.path.exists(f"{self.db_path}{ext}"): 47 | return False 48 | return True 49 | 50 | def run_blast(self, output_path: str, top_k_hits: int): 51 | if not self.blast_db_exists(): 52 | self.logger.info("Creating a BLAST database for optimal performance") 53 | self.make_db() 54 | self.logger.info("Finished creating database") 55 | else: 56 | self.logger.info("Found existing database") 57 | 58 | self.logger.info("Initiating BLAST Search") 59 | get_results_start_time = time.time() 60 | blastp_cline = NcbiblastpCommandline( 61 | query=self.queries_fasta_path, 62 | # subject=self.db_fasta_path, 63 | db=self.db_path, 64 | outfmt="6 qseqid sseqid pident evalue 
bitscore", 65 | out=output_path, 66 | num_threads=self.num_threads, 67 | max_target_seqs=top_k_hits, 68 | ) 69 | stdout, stderr = blastp_cline() 70 | get_results_end_time = time.time() 71 | self.run_duration_seconds = get_results_end_time - get_results_start_time 72 | 73 | results = pd.read_csv( 74 | output_path, sep="\t", header=None, names=list(self.columns.values()) 75 | ) 76 | results = results.loc[results.groupby("sequence_name")["bit_score"].idxmax()] 77 | results.to_csv(output_path, sep="\t", header=None, index=False) 78 | 79 | self.logger.info( 80 | f"BLAST search completed in {self.run_duration_seconds:.2f} seconds." 81 | ) 82 | 83 | def __parse_blast_line(self, line: str) -> dict: 84 | data = line.strip().split("\t") 85 | parsed_line = {colname: data[idx] for idx, colname in self.columns.items()} 86 | return parsed_line 87 | 88 | def __transfer_hit_labels(self, parsed_line: dict): 89 | if self.db_seq_2_labels is None: 90 | self.db_seq_2_labels = { 91 | uid: go_terms for _, uid, go_terms in read_fasta(self.db_fasta_path) 92 | } 93 | parsed_line["transferred_labels"] = self.db_seq_2_labels[ 94 | parsed_line["closest_sequence"] 95 | ] 96 | 97 | def parse_blast_line( 98 | self, line, transfer_labels: bool, flatten_labels: bool = True 99 | ) -> list: 100 | fully_parsed_line = [] 101 | parsed_line = self.__parse_blast_line(line=line) 102 | if transfer_labels: 103 | self.__transfer_hit_labels(parsed_line=parsed_line) 104 | if flatten_labels: 105 | # Replicate each parsed_line result once for each label 106 | transferred_labels = parsed_line.pop("transferred_labels") 107 | fully_parsed_line.extend( 108 | [ 109 | {"transferred_labels": i, **parsed_line} 110 | for i in transferred_labels 111 | ] 112 | ) 113 | else: 114 | fully_parsed_line.append(parsed_line) 115 | else: 116 | fully_parsed_line.append(parsed_line) 117 | 118 | return fully_parsed_line 119 | 120 | def parse_results( 121 | self, 122 | blast_results_path: str, 123 | transfer_labels: bool, 124 | 
flatten_labels: bool = True, 125 | ) -> pd.DataFrame: 126 | self.logger.info("Parsing BLAST results.") 127 | 128 | parse_results_start_time = time.time() 129 | 130 | with open(blast_results_path, "r") as handle: 131 | lines = handle.readlines() 132 | wrapper = partial( 133 | self.parse_blast_line, 134 | transfer_labels=transfer_labels, 135 | flatten_labels=flatten_labels, 136 | ) 137 | 138 | # parsed_results = process_map(wrapper,handle,max_workers = self.num_threads) 139 | with tqdm_joblib(tqdm(total=len(lines))) as pbar: 140 | parsed_results = Parallel(n_jobs=self.num_threads)( 141 | delayed(wrapper)(line) for line in lines 142 | ) 143 | 144 | # Flatten the list of lists into a single list 145 | flattened_parsed_results = [ 146 | item for sublist in parsed_results for item in sublist 147 | ] 148 | 149 | parse_results_end_time = time.time() 150 | 151 | self.parse_results_duration_seconds = ( 152 | parse_results_end_time - parse_results_start_time 153 | ) 154 | self.logger.info( 155 | f"BLAST parsing completed in {self.parse_results_duration_seconds:.2f} seconds." 156 | ) 157 | 158 | return pd.DataFrame.from_records(flattened_parsed_results) -------------------------------------------------------------------------------- /protnote/models/protein_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Literal, Text, Optional 3 | import numpy as np 4 | from ..data.datasets import set_padding_to_sentinel 5 | from protnote.utils.proteinfer import transfer_tf_weights_to_torch 6 | 7 | 8 | class MaskedConv1D(torch.nn.Conv1d): 9 | def forward(self, x, sequence_lengths): 10 | """ 11 | Correct for padding before and after. Can be redundant 12 | but reduces overhead of setting padding to sentiel in other contexts. 
13 | """ 14 | x = set_padding_to_sentinel(x, sequence_lengths, 0) 15 | x = super().forward(x) 16 | x = set_padding_to_sentinel(x, sequence_lengths, 0) 17 | return x 18 | 19 | 20 | # ResNet-V2 https://arxiv.org/pdf/1602.07261v2.pdf 21 | 22 | 23 | class Residual(torch.nn.Module): 24 | def __init__( 25 | self, 26 | input_channels: int, 27 | kernel_size: int, 28 | dilation: int, 29 | bottleneck_factor: float, 30 | activation=torch.nn.ReLU, 31 | ): 32 | super().__init__() 33 | 34 | bottleneck_out_channels = int(np.floor(input_channels * bottleneck_factor)) 35 | self.bn_activation_1 = torch.nn.Sequential( 36 | torch.nn.BatchNorm1d(input_channels, eps=0.001, momentum=0.01), activation() 37 | ) 38 | 39 | self.masked_conv1 = MaskedConv1D( 40 | in_channels=input_channels, 41 | out_channels=bottleneck_out_channels, 42 | padding="same", 43 | kernel_size=kernel_size, 44 | stride=1, 45 | dilation=dilation, 46 | ) 47 | self.bn_activation_2 = torch.nn.Sequential( 48 | torch.nn.BatchNorm1d(bottleneck_out_channels, eps=0.001, momentum=0.01), 49 | activation(), 50 | ) 51 | 52 | self.masked_conv2 = MaskedConv1D( 53 | in_channels=bottleneck_out_channels, 54 | out_channels=input_channels, 55 | padding="same", 56 | kernel_size=1, 57 | stride=1, 58 | dilation=1, 59 | ) 60 | 61 | def forward(self, x, sequence_lengths): 62 | out = self.bn_activation_1(x) 63 | out = self.masked_conv1(out, sequence_lengths) 64 | out = self.bn_activation_2(out) 65 | out = self.masked_conv2(out, sequence_lengths) 66 | out = out + x 67 | return out 68 | 69 | 70 | class ProteInfer(torch.nn.Module): 71 | def __init__( 72 | self, 73 | num_labels: int, 74 | input_channels: int, 75 | output_channels: int, 76 | kernel_size: int, 77 | activation, 78 | dilation_base: int, 79 | num_resnet_blocks: int, 80 | bottleneck_factor: float, 81 | ): 82 | super().__init__() 83 | 84 | self.conv1 = MaskedConv1D( 85 | in_channels=input_channels, 86 | out_channels=output_channels, 87 | padding="same", 88 | kernel_size=kernel_size, 89 | 
stride=1, 90 | dilation=1, 91 | ) 92 | self.resnet_blocks = torch.nn.ModuleList() 93 | 94 | for i in range(num_resnet_blocks): 95 | self.resnet_blocks.append( 96 | Residual( 97 | input_channels=output_channels, 98 | kernel_size=kernel_size, 99 | dilation=dilation_base**i, 100 | bottleneck_factor=bottleneck_factor, 101 | activation=activation, 102 | ) 103 | ) 104 | 105 | self.output_layer = torch.nn.Linear( 106 | in_features=output_channels, out_features=num_labels 107 | ) 108 | 109 | def get_embeddings(self, x, sequence_lengths): 110 | features = self.conv1(x, sequence_lengths) 111 | # Sequential doesn't work here because of multiple inputs 112 | for idx, resnet_block in enumerate(self.resnet_blocks): 113 | features = resnet_block(features, sequence_lengths) 114 | features = set_padding_to_sentinel(features, sequence_lengths, 0) 115 | features = torch.sum(features, dim=-1) / sequence_lengths.unsqueeze( 116 | -1 117 | ) # Average pooling 118 | return features 119 | 120 | def forward(self, x, sequence_lengths): 121 | features = self.get_embeddings(x, sequence_lengths) 122 | logits = self.output_layer(features) 123 | return logits 124 | 125 | @classmethod 126 | def from_pretrained( 127 | cls, 128 | weights_path: str, 129 | num_labels: int, 130 | input_channels: int, 131 | output_channels: int, 132 | kernel_size: int, 133 | activation, 134 | dilation_base: int, 135 | num_resnet_blocks: int, 136 | bottleneck_factor: float, 137 | ): 138 | """ 139 | Load a pretrained model from a path or url. 
140 | """ 141 | model = cls( 142 | num_labels, 143 | input_channels, 144 | output_channels, 145 | kernel_size, 146 | activation, 147 | dilation_base, 148 | num_resnet_blocks, 149 | bottleneck_factor, 150 | ) 151 | transfer_tf_weights_to_torch(model, weights_path) 152 | 153 | return model 154 | -------------------------------------------------------------------------------- /protnote/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /protnote/utils/configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | import logging 5 | from protnote.utils.data import read_yaml 6 | import sys 7 | from ast import literal_eval 8 | from pathlib import Path 9 | 10 | 11 | def get_logger(): 12 | # Create a custom logger 13 | logger = logging.getLogger(__name__) 14 | 15 | # Set the logging level to INFO 16 | logger.setLevel(logging.INFO) 17 | 18 | # Create a console handler and set its level to INFO 19 | console_handler = logging.StreamHandler() 20 | console_handler.setLevel(logging.INFO) 21 | 22 | # Create a formatter that includes the current date and time 23 | formatter = logging.Formatter( 24 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 25 | ) 26 | 27 | # Set the formatter for the console handler 28 | console_handler.setFormatter(formatter) 29 | 30 | # Add the console handler to the logger 31 | logger.addHandler(console_handler) 32 | 33 | # Example usage 34 | logger.info("This is an info message.") 35 | return logger 36 | 37 | 38 | def try_literal_eval(val): 39 | try: 40 | # Attempt to evaluate as a literal 41 | return literal_eval(val) 42 | except (ValueError, SyntaxError): 43 | # If evaluation fails means input is actually a string 44 | if val == "null": 45 | return None 46 | if (val == "false") | (val == "true"): 47 | return 
literal_eval(val.title()) 48 | return val 49 | 50 | 51 | def override_config(config: dict, overrides: list): 52 | # Process the overrides if provided 53 | if overrides: 54 | if len(overrides) % 2 != 0: 55 | raise ValueError("Overrides must be provided as key-value pairs.") 56 | 57 | # Convert the list to a dictionary 58 | overrides = { 59 | overrides[i]: overrides[i + 1] for i in range(0, len(overrides), 2) 60 | } 61 | 62 | # Update the config with the overrides 63 | for key, value in overrides.items(): 64 | # Convert value to appropriate type if necessary (e.g., float, int) 65 | # Here, we're assuming that the provided keys exist in the 'params' section of the config 66 | if key in config["params"]: 67 | config["params"][key] = try_literal_eval(value) 68 | else: 69 | raise KeyError( 70 | f"Key '{key}' not found in the 'params' section of the config." 71 | ) 72 | 73 | 74 | def generate_label_embedding_path(params: dict, base_label_embedding_path: str): 75 | """ 76 | Generates the name of the file that caches label embeddings. Needed due to different 77 | ways of pooling embeddings, different types of go descriptions and other paramters. 
78 | This way we can store different versions/types of label embeddings for caching 79 | """ 80 | assert params["LABEL_ENCODER_CHECKPOINT"] in [ 81 | "microsoft/biogpt", 82 | "intfloat/e5-large-v2", 83 | "intfloat/multilingual-e5-large-instruct", 84 | ], "Model not supported" 85 | 86 | MODEL_NAME_2_NICKNAME = { 87 | "microsoft/biogpt": "BioGPT", 88 | "intfloat/e5-large-v2": "E5", 89 | "intfloat/multilingual-e5-large-instruct": "E5_multiling_inst", 90 | } 91 | 92 | label_embedding_path = base_label_embedding_path.split("/") 93 | temp = label_embedding_path[-1].split(".") 94 | 95 | base_model = temp[0].split("_") 96 | base_model = "_".join( 97 | [base_model[0]] 98 | + [MODEL_NAME_2_NICKNAME[params["LABEL_ENCODER_CHECKPOINT"]]] 99 | + base_model[1:] 100 | ) 101 | 102 | label_embedding_path[-1] = ( 103 | base_model + "_" + params["LABEL_EMBEDDING_POOLING_METHOD"] + "." + temp[1] 104 | ) 105 | 106 | label_embedding_path = "/".join(label_embedding_path) 107 | return label_embedding_path 108 | 109 | 110 | def get_setup( 111 | config_path: str, 112 | run_name: str, 113 | overrides: list, 114 | train_path_name: str = None, 115 | val_path_name: str = None, 116 | test_paths_names: list = None, 117 | annotations_path_name: str = None, 118 | base_label_embedding_name: str = None, 119 | amlt: bool = False, 120 | is_master: bool = True, 121 | ): 122 | # Get the root path from the environment variable; default to current directory if ROOT_PATH is not set 123 | if amlt: 124 | ROOT_PATH = os.getcwd() # Set ROOT_PATH to working directory 125 | DATA_PATH = os.environ["AMLT_DATA_DIR"] 126 | OUTPUT_PATH = os.environ["AMLT_OUTPUT_DIR"] 127 | else: 128 | ROOT_PATH = str(Path(os.path.dirname(__file__)).parents[1]) 129 | print(ROOT_PATH) 130 | DATA_PATH = os.path.join(ROOT_PATH, "data") 131 | OUTPUT_PATH = os.path.join(ROOT_PATH, "outputs") 132 | if not os.path.exists(OUTPUT_PATH) and is_master: 133 | os.makedirs(OUTPUT_PATH) 134 | 135 | # Load the configuration file and override the 
parameters if provided 136 | config = read_yaml(os.path.join(ROOT_PATH, config_path)) 137 | if overrides: 138 | override_config(config, overrides) 139 | 140 | # Extract the model parameters from the (possibly overidden) config file 141 | params = config["params"] 142 | 143 | # Extract the fixed ProteInfer params from the config file 144 | embed_sequences_params = config["embed_sequences_params"] 145 | 146 | # Prepend the correct path roots 147 | # Define root paths for each section 148 | section_paths = { 149 | "data_paths": DATA_PATH, 150 | "output_paths": OUTPUT_PATH, 151 | } 152 | paths = { 153 | key: os.path.join(section_paths[section], value) 154 | for section, section_values in config["paths"].items() 155 | for key, value in section_values.items() 156 | } 157 | 158 | train_paths_list = ( 159 | [ 160 | { 161 | "data_path": paths[train_path_name], 162 | "dataset_type": "train", 163 | "annotations_path": paths[annotations_path_name], 164 | "vocabularies_dir": paths["VOCABULARIES_DIR"], 165 | } 166 | ] 167 | if train_path_name is not None 168 | else [] 169 | ) 170 | 171 | val_paths_list = ( 172 | [ 173 | { 174 | "data_path": paths[val_path_name], 175 | "dataset_type": "validation", 176 | "annotations_path": paths[annotations_path_name], 177 | "vocabularies_dir": paths["VOCABULARIES_DIR"], 178 | } 179 | ] 180 | if val_path_name is not None 181 | else [] 182 | ) 183 | 184 | test_paths_list = ( 185 | [ 186 | { 187 | "data_path": paths[key], 188 | "dataset_type": "test", 189 | "annotations_path": paths[annotations_path_name], 190 | "vocabularies_dir": paths["VOCABULARIES_DIR"], 191 | } 192 | for key in test_paths_names 193 | ] 194 | if test_paths_names is not None 195 | else [] 196 | ) 197 | 198 | dataset_paths = { 199 | "train": train_paths_list, 200 | "validation": val_paths_list, 201 | "test": test_paths_list, 202 | } 203 | 204 | # Set the timezone for the entire Python environment 205 | os.environ["TZ"] = "US/Pacific" 206 | time.tzset() 207 | timestamp = 
datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S %Z").strip() 208 | 209 | # Initialize logging 210 | log_dir = paths["LOG_DIR"] 211 | if not os.path.exists(log_dir) and is_master: 212 | try: 213 | os.makedirs(log_dir) 214 | except FileExistsError: 215 | print(f"Log directory {log_dir} already exists. is_master={is_master}") 216 | pass 217 | full_log_path = os.path.join(log_dir, f"{timestamp}_{run_name}.log") 218 | 219 | # Initialize the logger for all processes 220 | logger = logging.getLogger() 221 | 222 | # Only log to file and console if this is the main process 223 | if is_master: 224 | # Set the logging level (default for other processes is WARNING) 225 | logger.setLevel(logging.INFO) 226 | 227 | # Create a formatter 228 | formatter = logging.Formatter( 229 | fmt="%(asctime)s %(levelname)-4s %(message)s", 230 | datefmt="%Y-%m-%d %H:%M:%S %Z", 231 | ) 232 | 233 | # Create a file handler and add it to the logger 234 | file_handler = logging.FileHandler(full_log_path, mode="w") 235 | file_handler.setFormatter(formatter) 236 | logger.addHandler(file_handler) 237 | 238 | # Create a stream handler (for stdout) and add it to the logger 239 | stream_handler = logging.StreamHandler(sys.stdout) 240 | stream_handler.setFormatter(formatter) 241 | logger.addHandler(stream_handler) 242 | 243 | logger.info(f"Logging to {full_log_path} and console...") 244 | else: 245 | # Set the logger level to an unreachable level, effectively disabling it 246 | logger.setLevel(logging.CRITICAL + 1) 247 | 248 | # Generate embeddings 249 | label_embedding_path = generate_label_embedding_path( 250 | params=params, base_label_embedding_path=paths[base_label_embedding_name] 251 | ) 252 | 253 | # Return a dictionary 254 | return { 255 | "params": params, 256 | "embed_sequences_params": embed_sequences_params, 257 | "paths": paths, 258 | "dataset_paths": dataset_paths, 259 | "timestamp": timestamp, 260 | "logger": logger, 261 | "ROOT_PATH": ROOT_PATH, 262 | "DATA_PATH": DATA_PATH, 263 | 
"OUTPUT_PATH": OUTPUT_PATH, 264 | "LABEL_EMBEDDING_PATH": label_embedding_path, 265 | } 266 | 267 | 268 | def get_project_root(): 269 | """Dynamically determine the project root.""" 270 | return Path(__file__).resolve().parent.parent.parent # Adjust based on the folder structure 271 | 272 | def update_config_paths(config, project_root): 273 | # Prepend project_root / 'data' to all paths in the 'data' section 274 | for key, value in config['paths'].get('data_paths', {}).items(): 275 | config['paths']['data_paths'][key] = project_root / 'data' / value 276 | 277 | for key, value in config['paths'].get('output_paths', {}).items(): 278 | config['paths']['output_paths'][key] = project_root / 'outputs' / value 279 | 280 | return config 281 | 282 | def load_config(config_file:str = 'base_config.yaml'): 283 | """Load the environment variables and YAML configuration file.""" 284 | project_root = get_project_root() 285 | 286 | # Load the YAML configuration file from the project root 287 | config_file = project_root / 'configs' / config_file 288 | config = update_config_paths(read_yaml(config_file),project_root) 289 | 290 | return config, project_root 291 | 292 | def construct_absolute_paths(dir:str,files:list)->list: 293 | return [Path(dir) / file for file in files] 294 | -------------------------------------------------------------------------------- /protnote/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | # NOT CURRENTLY USING THIS 7 | class SupCon(torch.nn.Module): 8 | def __init__(self, temperature: float, base_temperature=0.07): 9 | super().__init__() 10 | self.temperature = temperature 11 | self.base_temperature = base_temperature 12 | assert temperature is not None, "temperature must be provided and not None" 13 | 14 | def forward(self, input, target): 15 | """ 16 | Computes the symmetric contrastive loss between protein and label 
embeddings, given a temperature parameter and a ground truth annotation matrix. 17 | P_e: Batch of projected protein embeddings [batch_size, joint_embedding_dim] 18 | L_e: Batch of projected GO label embeddings [batch_size, joint_embedding_dim] 19 | t: Temperature parameter 20 | target: Batch of "target" GO annotations [batch_size, num_labels] 21 | """ 22 | # Compute loss for each direction (protein to label and label to protein) 23 | overall_loss_p = one_way_supcon(logits=input, labels_multihot=target, dim=1) 24 | """ 25 | overall_loss_l = one_way_supcon(logits=input, 26 | labels_multihot=target, 27 | dim=0)""" 28 | 29 | # Return the average of the two losses 30 | # 31 | # return (overall_loss_p + overall_loss_l) / 2 32 | return overall_loss_p 33 | 34 | 35 | def one_way_supcon(logits, labels_multihot, dim): 36 | """ 37 | dim=1 is traditional entropy of predicting labels 38 | """ 39 | # for numerical stability 40 | logits_max, _ = torch.max(logits, dim=dim, keepdim=True) 41 | logits = logits - logits_max.detach() 42 | 43 | # compute log_prob 44 | exp_logits = torch.exp(logits) 45 | log_prob = logits - torch.log(exp_logits.sum(dim, keepdim=True)) 46 | 47 | # compute mean of log-likelihood over positive 48 | norm = labels_multihot.sum(dim) 49 | mean_log_prob_pos = (labels_multihot * log_prob).sum(dim) / norm 50 | mean_log_prob_pos = torch.nan_to_num( 51 | mean_log_prob_pos, 0 52 | ) # In case dim = 0 and there are labels always negative 53 | # loss 54 | loss = -mean_log_prob_pos.mean() 55 | return loss 56 | 57 | 58 | class RGDBCE(torch.nn.Module): 59 | def __init__(self, temperature: float): 60 | super().__init__() 61 | self.temperature = temperature 62 | assert temperature is not None, "temperature must be provided and not None" 63 | 64 | def forward(self, input, target): 65 | loss = torch.nn.functional.binary_cross_entropy_with_logits( 66 | input, target, reduce="none" 67 | ) 68 | return ( 69 | loss 70 | * torch.exp( 71 | torch.clamp(loss.detach(), 
max=self.temperature) 72 | / (self.temperature + 1) 73 | ) 74 | ).mean() 75 | 76 | 77 | class CBLoss(torch.nn.Module): 78 | def __init__(self, label_weights: torch.Tensor, beta=0.9999): 79 | super().__init__() 80 | 81 | self.label_weights = label_weights 82 | self.beta = beta 83 | assert label_weights is not None, "label_weights must be provided and not None" 84 | 85 | def forward(self, input, target): 86 | no_of_classes = len(self.label_weights) 87 | effective_num = 1.0 - torch.pow(self.beta, self.label_weights) 88 | 89 | # Replace zeros in effective_num with 'inf' (infinity) to avoid division by zero 90 | effective_num = torch.where( 91 | effective_num == 0, torch.tensor(float("inf")), effective_num 92 | ) 93 | 94 | weights = (1.0 - self.beta) / effective_num 95 | weights = weights / torch.sum(weights) * no_of_classes 96 | 97 | weights = get_batch_weights_v2(weights, target) 98 | cb_loss = F.binary_cross_entropy_with_logits( 99 | input=input, target=target, weight=weights 100 | ) 101 | 102 | return cb_loss 103 | 104 | 105 | class WeightedBCE(torch.nn.Module): 106 | def __init__(self, label_weights: torch.Tensor): 107 | super().__init__() 108 | assert label_weights is not None, "label_weights must be provided and not None" 109 | self.label_weights = label_weights 110 | 111 | def forward(self, input, target): 112 | batch_weights = get_batch_weights_v2( 113 | label_weights=self.label_weights, target=target 114 | ) 115 | return torch.nn.functional.binary_cross_entropy_with_logits( 116 | input, target, weight=batch_weights 117 | ) 118 | 119 | 120 | class BatchWeightedBCE(torch.nn.Module): 121 | def __init__(self, epsilon=1e-10): 122 | super().__init__() 123 | 124 | self.epsilon = epsilon 125 | 126 | def forward(self, input, target): 127 | # Count the number of positives and negatives in the batch 128 | num_positives = target.sum() + self.epsilon 129 | num_negatives = target.numel() - num_positives + self.epsilon 130 | 131 | # Calculate the weights for positives and 
negatives 132 | total = num_positives + num_negatives 133 | weight_positives = (1.0 / num_positives) * (total / 2.0) 134 | weight_negatives = (1.0 / num_negatives) * (total / 2.0) 135 | 136 | # Create a weight tensor with the same shape as target 137 | weight_tensor = target * weight_positives + (1 - target) * weight_negatives 138 | 139 | # Compute weighted binary cross-entropy loss 140 | return torch.nn.functional.binary_cross_entropy_with_logits( 141 | input, target, weight=weight_tensor 142 | ) 143 | 144 | 145 | class BatchLabelWeightedBCE(torch.nn.Module): 146 | def __init__(self, epsilon=1e-10): 147 | super().__init__() 148 | 149 | self.epsilon = epsilon 150 | 151 | def forward(self, input, target): 152 | total_labels = ( 153 | target.sum() + self.epsilon 154 | ) # epsilon in case all labels are 0 due to high imbalance 155 | label_frequencies = target.sum(axis=0) / total_labels 156 | 157 | label_frequencies = torch.where( 158 | label_frequencies == 0, 159 | torch.ones_like(label_frequencies), 160 | 1 / label_frequencies, 161 | ) 162 | 163 | weights = label_frequencies / label_frequencies.sum() 164 | 165 | # Compute weighted binary cross-entropy loss 166 | return torch.nn.functional.binary_cross_entropy_with_logits( 167 | input, target, weight=weights.unsqueeze(0) 168 | ) 169 | 170 | 171 | class FocalLoss(torch.nn.Module): 172 | def __init__( 173 | self, alpha: float, gamma: float, reduction="mean", label_smoothing=0.0 174 | ): 175 | super().__init__() 176 | self.alpha = alpha 177 | self.gamma = gamma 178 | self.reduction = reduction 179 | self.label_smoothing = label_smoothing 180 | 181 | assert (alpha is not None) & ( 182 | gamma is not None 183 | ), "Both gamma and alpha must be provided and neither should be None" 184 | print( 185 | "Focal Loss with alpha: {}, gamma: {}, reduction: {}, label_smoothing: {}".format( 186 | alpha, gamma, reduction, label_smoothing 187 | ) 188 | ) 189 | 190 | def forward(self, input, target): 191 | # Apply label smoothing, if 
import torch
import pandas as pd
import warnings
from functools import partial
from torch.utils.data import ConcatDataset, DataLoader


def get_batch_weights_v2(label_weights, target):
    """
    Compute per-sample weights from per-label weights using broadcasting.

    Args:
        label_weights: torch.Tensor of size [no_of_classes] with the weight of each label.
        target: torch.Tensor of size [batch, no_of_classes] with multi-hot labels.

    Returns:
        weights_for_samples: torch.Tensor of size [batch, no_of_classes] in which
        every column of a row holds that sample's summed label weight.
    """
    # Float cast so integer weight tensors broadcast and multiply correctly.
    label_weights = label_weights.float()

    # Apply each class weight to the corresponding positive labels.
    weighted_targets = label_weights * target

    # Collapse to a single scalar weight per sample.
    weights_for_samples = weighted_targets.sum(dim=1, keepdim=True)

    # Broadcast the per-sample weight across the class dimension without copying.
    weights_for_samples = weights_for_samples.expand_as(target)

    return weights_for_samples


def get_batch_weights_v1(label_weights, target=None):
    """
    Compute per-sample weights from per-label weights (legacy implementation).

    Kept for backward compatibility; get_batch_weights_v2 computes the same
    quantity via broadcasting.

    Args:
        label_weights: torch.Tensor of size [no_of_classes] with the weight of each label.
        target: torch.Tensor of size [batch, no_of_classes] with multi-hot labels.
            Required despite the None default (kept for interface compatibility).

    Returns:
        weights_for_samples: torch.Tensor of size [batch, no_of_classes].
    """
    no_of_classes = len(label_weights)

    # Stay entirely in torch: the previous numpy round-trip
    # (torch.tensor(np.array(...))) crashed on CUDA tensors and silently
    # moved the result to CPU.
    weights_for_samples = label_weights.unsqueeze(0)
    weights_for_samples = weights_for_samples.repeat(target.shape[0], 1) * target
    weights_for_samples = weights_for_samples.sum(1)
    weights_for_samples = weights_for_samples.unsqueeze(1)
    weights_for_samples = weights_for_samples.repeat(1, no_of_classes)
    return weights_for_samples


def get_loss(
    config: dict,
    label_weights: torch.Tensor = None,
    bce_pos_weight: torch.Tensor = None,
):
    """
    Instantiate the loss named by config["params"]["LOSS_FN"].

    Args:
        config: Run configuration. Depending on the loss, also reads
            FOCAL_LOSS_GAMMA / FOCAL_LOSS_ALPHA / LABEL_SMOOTHING /
            RGDBCE_TEMP / SUPCON_TEMP from config["params"].
        label_weights: Per-label weights used by the weighted losses.
        bce_pos_weight: pos_weight forwarded to BCEWithLogitsLoss.

    Returns:
        A loss module / callable.

    Raises:
        ValueError: If LOSS_FN names an unknown loss.
    """
    loss_fn = config["params"]["LOSS_FN"]
    if loss_fn == "BCE":
        return torch.nn.BCEWithLogitsLoss(reduction="mean", pos_weight=bce_pos_weight)
    elif loss_fn == "WeightedBCE":
        return WeightedBCE(label_weights=label_weights)
    elif loss_fn == "CBLoss":
        return CBLoss(label_weights=label_weights)
    elif loss_fn == "BatchWeightedBCE":
        return BatchWeightedBCE()
    elif loss_fn == "FocalLoss":
        return FocalLoss(
            gamma=config["params"]["FOCAL_LOSS_GAMMA"],
            alpha=config["params"]["FOCAL_LOSS_ALPHA"],
            label_smoothing=config["params"]["LABEL_SMOOTHING"],
        )
    elif loss_fn == "RGDBCE":
        return RGDBCE(temperature=config["params"]["RGDBCE_TEMP"])
    elif loss_fn == "SupCon":
        return SupCon(temperature=config["params"]["SUPCON_TEMP"])
    else:
        raise ValueError(f"Unknown loss function {loss_fn}")


def validate_arguments(args, parser):
    """
    Validate CLI argument combinations, warning or erroring on invalid ones.

    Args:
        args: Parsed argparse namespace.
        parser: The argparse parser (used to raise errors with usage info).
    """
    # Ensure the full data path is provided, or we are using the zero shot model
    if args.full_path_name is None and "zero" not in str(args.train_path_name).lower():
        warnings.warn(
            "The full path name is not provided and the train path name does not contain the word 'zero'. Please ensure this is intentional."
        )

    # Training requires a validation set.
    if args.train_path_name is not None:
        if args.validation_path_name is None:
            parser.error(
                "If providing --train-path-name you must provide --val-path-name."
            )

    # Without a training set we must load a pre-trained model.
    if (args.train_path_name is None) and (args.model_file is None):
        parser.error(
            "You must provide --load-model if no --train-path-names is provided"
        )

    # At least one of train/validation/test paths must be provided.
    # (`and` instead of bitwise `&`: these are plain booleans.)
    if (
        (args.test_paths_names is None)
        and (args.train_path_name is None)
        and (args.validation_path_name is None)
    ):
        parser.error(
            "You must provide one of the following options:\n"
            "--test-path-names --load-model\n"
            "--val-path-names --load-model\n"
            "--train-path-name and --validation-path-name (optional load model)\n"
            "--train-path-name and --validation-path-name --test-path-names (optional load model)\n"
            "All cases with including --full-path-name. Please provide the required option(s) and try again."
        )

    # Saving predictions only makes sense with a validation and/or test set.
    if (args.save_prediction_results) and (
        (args.test_paths_names is None) and (args.validation_path_name is None)
    ):
        parser.error(
            "You must provide --test-path-names and/or --val-path-names to save the results of the validation and/or test sets."
        )


def generate_sequence_embeddings(device, sequence_encoder, datasets, params):
    """Generate sequence embeddings for the given datasets.

    Args:
        device: torch device the encoder and batches are moved to.
        sequence_encoder: Model exposing get_embeddings(onehots, lengths).
        datasets: Mapping of split name -> list of datasets.
        params: Reads SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD and NUM_WORKERS.

    Returns:
        pd.DataFrame indexed by sequence ID with an "Embedding" column
        (one numpy vector per sequence).
    """
    from tqdm import tqdm
    from protnote.data.collators import collate_variable_sequence_length

    sequence_encoder = sequence_encoder.to(device)
    # Remember the current mode so we can restore it on exit; the previous
    # code unconditionally switched the encoder to train mode afterwards.
    was_training = sequence_encoder.training
    sequence_encoder.eval()

    all_datasets = [
        dataset for dataset_list in datasets.values() for dataset in dataset_list
    ]
    combined_dataset = ConcatDataset(all_datasets)
    combined_loader = DataLoader(
        combined_dataset,
        batch_size=params["SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD"],
        shuffle=False,
        # return_label_multihots=False so label multihots are not concatenated
        # (required for zero-shot datasets).
        collate_fn=partial(
            collate_variable_sequence_length, return_label_multihots=False
        ),
        num_workers=params["NUM_WORKERS"],
        pin_memory=True,
    )

    data_list = []
    for batch in tqdm(combined_loader):
        sequence_onehots, sequence_ids, sequence_lengths = (
            batch["sequence_onehots"].to(device),
            batch["sequence_ids"],
            batch["sequence_lengths"].to(device),
        )
        with torch.no_grad():
            embeddings = sequence_encoder.get_embeddings(
                sequence_onehots, sequence_lengths
            )
        for i, original_id in enumerate(sequence_ids):
            data_list.append((original_id, embeddings[i].cpu().numpy()))

    if was_training:
        sequence_encoder.train()

    # Convert the accumulated (id, embedding) pairs to a DataFrame
    df = pd.DataFrame(data_list, columns=["ID", "Embedding"]).set_index("ID")
    return df
from collections import Counter


def complete_blast_preds(blast_df: pd.DataFrame, labels: list, seqs: list):
    """
    Pad BLAST predictions so they cover exactly `labels` x `seqs`.

    Labels/sequences BLAST never predicted are filled with a large negative
    logit (-15.0), i.e. probability ~0 after a sigmoid.

    Args:
        blast_df: DataFrame of BLAST scores indexed by sequence id, one column per label.
        labels: Full label vocabulary to cover.
        seqs: Full list of sequence ids to cover.

    Returns:
        DataFrame with rows `seqs` and columns restricted to (and ordered by) `labels`.
    """
    blast_cols = set(blast_df.columns)

    # Add labels that blast missed
    blast_df[list(set(labels) - blast_cols)] = -15.0

    # Keep only the requested labels, in the requested order
    blast_cols = set(blast_df.columns)
    blast_cols = [label for label in labels if label in blast_cols]

    # Consider only the sequences in seqs, adding sequences that blast missed
    blast_df = blast_df[blast_cols].reindex(seqs).fillna(-15.0)

    return blast_df


def get_metrics(logits_df, labels_df, device, threshold):
    """
    Compute F1 metrics (at `threshold`) plus micro/macro mAP for a logits/labels pair.

    Args:
        logits_df: DataFrame [sequences x labels] of raw logits.
        labels_df: DataFrame [sequences x labels] of binary targets, aligned with logits_df.
        device: torch device for metric computation.
        threshold: Decision threshold for the F1 metrics.

    Returns:
        dict mapping metric name -> float.
    """
    from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC
    from protnote.utils.evaluation import EvalMetrics

    logits = torch.tensor(logits_df.values, device=device)
    labels = torch.tensor(labels_df.values, device=device)
    probabilities = torch.sigmoid(logits)

    eval_metrics = EvalMetrics(device)
    selected_eval_metrics = eval_metrics.get_metric_collection_with_regex(
        pattern="f1_m.*", threshold=threshold, num_labels=labels.shape[-1]
    )
    mAP_micro = BinaryAUPRC(device=device)
    mAP_macro = MultilabelAUPRC(device=device, num_labels=labels.shape[-1])

    selected_eval_metrics(probabilities, labels)
    mAP_micro.update(probabilities.flatten(), labels.flatten())
    mAP_macro.update(probabilities, labels)

    metrics = {
        "mAP Macro": mAP_macro.compute(),
        "mAP Micro": mAP_micro.compute(),
        **selected_eval_metrics.compute(),
    }
    return {k: v.item() for k, v in metrics.items()}


def get_ontology_from_parenthood(go_term, parenthood):
    """
    Resolve a GO term's ontology from its (transitive) parents.

    Args:
        go_term: GO term id, e.g. "GO:0008150".
        parenthood: Mapping from GO term -> iterable of ancestor GO terms.

    Returns:
        One of "biological_process", "molecular_function", "cellular_component",
        or "missing" when no ontology root is among the parents.
    """
    go_term_to_ontology = {
        "GO:0008150": "biological_process",
        "GO:0003674": "molecular_function",
        # Fixed typo: was "celular_component", which never matched the
        # "cellular_component" spelling used by filter_by_go_ontology.
        "GO:0005575": "cellular_component",
    }
    # Default avoids UnboundLocalError when no root parent is found, mirroring
    # the "missing" namespace used in the graph-based path.
    ontology = "missing"
    for parent in parenthood[go_term]:
        if parent in go_term_to_ontology:
            ontology = go_term_to_ontology[parent]
            break
    return ontology


def filter_by_go_ontology(ontology: str, df: pd.DataFrame, graph=None, parenthood=None):
    """
    Keep only the columns of `df` whose GO terms belong to `ontology`.

    Exactly one of `graph` (a GO graph with node "namespace" attributes) or
    `parenthood` (term -> ancestor list mapping) must be given.

    Args:
        ontology: One of "All", "biological_process", "cellular_component",
            "molecular_function".
        df: DataFrame whose columns are GO term ids.

    Returns:
        `df` unchanged for "All", otherwise the column-filtered DataFrame.
    """
    assert (graph is not None) ^ (parenthood is not None)

    if graph is not None:
        col_mask = [
            (graph.nodes[go_term]["namespace"] if go_term in graph.nodes else "missing")
            for go_term in df.columns
        ]
    if parenthood is not None:
        col_mask = [
            get_ontology_from_parenthood(go_term, parenthood) for go_term in df.columns
        ]

    assert ontology in [
        "All",
        "biological_process",
        "cellular_component",
        "molecular_function",
    ]
    if ontology == "All":
        return df
    filtered_df = df.iloc[:, [ontology == i for i in col_mask]]
    return filtered_df


def filter_by_ec_level_1(ontology: str, df: pd.DataFrame, ec_class_descriptions: dict):
    """
    Keep only the columns of `df` whose EC numbers fall under the given
    top-level EC class label, or return `df` unchanged for "All".

    Args:
        ontology: Top-level EC class label (e.g. the label of class (1, 0, 0)) or "All".
        df: DataFrame whose columns are EC numbers.
        ec_class_descriptions: Mapping (class, 0, 0) -> {"label": ...}.
    """
    from protnote.utils.data import ec_number_to_code

    ec_number_to_ec_level_1 = lambda ec_number: (ec_number_to_code(ec_number)[0], 0, 0)
    col_mask = [
        ec_class_descriptions[ec_number_to_ec_level_1(ec_number)]["label"]
        for ec_number in df.columns
    ]
    if ontology == "All":
        return df
    filtered_df = df.iloc[:, [i == ontology for i in col_mask]]
    return filtered_df


def metrics_by_go_ontology(df_logits, df_labels, graph, device, threshold):
    """Compute get_metrics for "All" plus each of the three GO ontologies."""
    results = {}
    for ontology in [
        "All",
        "biological_process",
        "cellular_component",
        "molecular_function",
    ]:
        filtered_df_logits = filter_by_go_ontology(ontology, df_logits, graph)
        filtered_df_labels = filter_by_go_ontology(ontology, df_labels, graph)
        results[ontology] = get_metrics(
            filtered_df_logits, filtered_df_labels, device=device, threshold=threshold
        )
    return results


def metrics_by_ec_level_1(
    df_logits, df_labels, ec_class_descriptions, device, threshold
):
    """Compute get_metrics for "All" plus each of the seven top-level EC classes."""
    results = {}
    ec_level_1s = [ec_class_descriptions[(i, 0, 0)]["label"] for i in range(1, 8)]
    for ec_level_1 in ["All"] + ec_level_1s:
        filtered_df_logits = filter_by_ec_level_1(
            ontology=ec_level_1,
            df=df_logits,
            ec_class_descriptions=ec_class_descriptions,
        )
        filtered_df_labels = filter_by_ec_level_1(
            ontology=ec_level_1,
            df=df_labels,
            ec_class_descriptions=ec_class_descriptions,
        )
        results[ec_level_1] = get_metrics(
            filtered_df_logits, filtered_df_labels, device=device, threshold=threshold
        )
    return results


def save_fig(name):
    """Save the current matplotlib figure as a tight, high-resolution PDF."""
    import matplotlib.pyplot as plt

    plt.savefig(f"{name}.pdf", format="pdf", dpi=1200, bbox_inches="tight")


def plot_axes_add_letter_index(axes, loc=(0.1, 0.9)):
    """Label each subplot with a bold letter index: a), b), c), ..."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    for i, ax in enumerate(axes.flat):
        ax.text(
            *loc,
            f"{letters[i]})",
            transform=ax.transAxes,
            fontsize=12,
            fontweight="bold",
        )


def plot_category_performance(
    metrics_df: pd.DataFrame,
    test_name: str,
    metric: str,
    category_name: str,
    ylim: tuple = (0, 0.5),
    rotate_x_ticks: bool = False,
    pltshow=True,
    savefig=True,
    name: str = None,
    ax=None,
    figsize=None,
    palette=None,
):
    """
    Bar + strip plot of `metric` per `category_name`, grouped by model, for one test set.

    Args:
        metrics_df: Long-format metrics with "metric" and "test_name" columns plus
            one column per category value.
        name: Output file stem; required when savefig is True.
        ax: Axes to draw on; a new figure is created when figsize is given.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    plot_df = (
        metrics_df.query("metric == @metric and test_name == @test_name")
        .melt(ignore_index=False, var_name=[category_name], value_name=metric)
        .reset_index()
    )
    if figsize is not None:
        fig, ax = plt.subplots(figsize=figsize)

    sns.barplot(
        data=plot_df,
        x=category_name,
        y=metric,
        alpha=0.8,
        errorbar="sd",
        capsize=0.1,
        # err_kws={'linewidth':1.3},
        hue="model",
        ax=ax,
        palette=palette,
    )
    sns.stripplot(
        data=plot_df,
        x=category_name,
        y=metric,
        hue="model",
        edgecolor="black",
        # size=3,
        linewidth=0.5,
        alpha=0.8,
        palette=palette,
        dodge=True,
        # Draw on the same axes as the barplot; previously this was omitted,
        # so the strip plot could land on a different (current) axes.
        ax=ax,
    )

    sns.despine()
    plt.ylim(*ylim)
    plt.title(f"{metric} - {test_name}")
    plt.legend(title="Model", bbox_to_anchor=(0.5, -0.2), loc="upper center")
    if rotate_x_ticks:
        plt.xticks(rotation=30)
    if savefig:
        assert name is not None, "Must add a name if saving the figure"
        save_fig(name)
    if pltshow:
        plt.show()


# Define threshold to calculate threshold-dependent metrics otherwise only
# threshold-agnostic metrics are calculated
def _get_metrics_by_label_and_freq(
    logits_df, labels_df, train_go_term_distribution, quantiles, threshold, device
):
    """Per-label metrics annotated with training-set frequency and frequency bin."""
    from protnote.utils.evaluation import metrics_per_label_df

    res = metrics_per_label_df(logits_df, labels_df, device=device, threshold=threshold)
    res["Frequency"] = res.index.map(train_go_term_distribution)

    # Bin frequencies into (at most) `quantiles` quantile bins
    freq_bins, freq_bin_edges = pd.qcut(
        train_go_term_distribution,
        q=quantiles,
        duplicates="drop",
        precision=0,
        retbins=True,
        labels=None,
    )

    res["Frequency Bin"] = res.index.map(freq_bins)
    # res['Frequency Edges']=res.index.map(freq_bin_edges)
    return res, freq_bins, freq_bin_edges


def get_metrics_by_label_and_freq(
    models, train_go_term_distribution, quantiles, threshold, device
):
    """
    Run _get_metrics_by_label_and_freq for every model and stack the results.

    Args:
        models: Mapping model name -> {"logits_df": ..., "labels_df": ...}.

    Returns:
        (stacked per-label metrics DataFrame with a "model" column,
         frequency bins, frequency bin edges)
    """
    res_dfs = []
    for model, data in models.items():
        print(model)
        logits_df = data["logits_df"]
        labels_df = data["labels_df"]

        res, freq_bins, freq_bin_edges = _get_metrics_by_label_and_freq(
            logits_df=logits_df,
            labels_df=labels_df,
            train_go_term_distribution=train_go_term_distribution,
            quantiles=quantiles,
            threshold=threshold,
            device=device,
        )

        # Combine both dataframes
        res["model"] = model
        res_dfs.append(res)

    res_df = pd.concat(res_dfs, axis=0)

    # res_pivot = res_df.pivot(columns=['model'],values=[metric,'Frequency'])
    # res_pivot.columns = [i[0]+'_'+i[1] for i in res_pivot.columns]

    return res_df, freq_bins, freq_bin_edges
import collections
from collections import Counter
import numpy as np
import pandas as pd
import torch


def plot_metric_by_label_freq(
    models, train_go_term_distribution, metric, quantiles, threshold, device
):
    """
    Bar plot of per-label `metric` grouped into label-frequency quantile bins,
    with each bin annotated by its share of labels.

    Args:
        models: Mapping model name -> {"logits_df": ..., "labels_df": ...}.
        train_go_term_distribution: Series of per-label training frequencies.
        metric: Name of the per-label metric column to plot.
        quantiles: Number of quantile bins requested.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    res_df, freq_bins, freq_bin_edges = get_metrics_by_label_and_freq(
        models=models,
        train_go_term_distribution=train_go_term_distribution,
        quantiles=quantiles,
        threshold=threshold,
        device=device,
    )
    res_df.dropna(subset=[metric], inplace=True)
    freq_bins_pct = freq_bins.value_counts() * 100 / len(train_go_term_distribution)
    fig, ax = plt.subplots(figsize=(15, 6))

    # Annotate bars with the percentage of observations in each bin
    for index, value in enumerate(freq_bins_pct.sort_index().values):
        ax.text(
            index,
            ax.get_ylim()[1] * 0.01 + max(res_df[metric]) * 0.01,
            f"{value:.2f}%",
            ha="center",
        )

    sns.barplot(
        data=res_df.reset_index(drop=True),
        x="Frequency Bin",
        y=metric,
        alpha=0.8,
        errorbar=("ci", 95),
        hue="model",
    )
    ax.set(
        title=f"Individual label performance ({metric}) by label frequency quantiles",
        xlabel="Frequency of GO Function Annotation",
        ylabel=metric,
    )
    sns.despine()
    plt.ylim(0, 1)
    plt.show()


def get_data_distributions(data: pd.DataFrame):
    """
    Collect basic dataset statistics.

    Args:
        data: DataFrame with "id", "sequence" and "labels" columns, where
            "labels" is a space-separated string of label ids.

    Returns:
        (vocab, amino_freq, labels):
            vocab: set of amino-acid characters seen.
            amino_freq: Counter of amino-acid character frequencies.
            labels: Counter of label frequencies.
    """
    labels = Counter()
    vocab = set()
    amino_freq = Counter()
    for idx, row in data.iterrows():
        sequence = row["sequence"]
        row_labels = row["labels"]
        aa_list = list(sequence)
        vocab.update(aa_list)
        amino_freq.update(aa_list)
        if row_labels == "":
            # Log rows without labels and skip them. Previously the empty
            # string was only printed and then still counted as a bogus
            # "" label via "".split(" ").
            print(row["id"], row["labels"])
            continue
        labels.update(row_labels.split(" "))
    return vocab, amino_freq, labels


def transfer_tf_weights_to_torch(torch_model: torch.nn.Module, tf_weights_path: str):
    """
    Copy pickled TensorFlow ProteInfer weights into a PyTorch model in place.

    Relies on the state-dict order of `torch_model` matching the order of the
    pickled TF variables; shapes are asserted for every pair.

    Args:
        torch_model: Target PyTorch model (mutated via load_state_dict).
        tf_weights_path: Path to a pickle of {tf_variable_name: ndarray}.
    """
    from protnote.utils.data import read_pickle

    # Load tensorflow variables. Remove global step variable and add it as
    # num_batches variable for each batchnorm
    tf_weights = read_pickle(tf_weights_path)
    # total training steps from the paper. Used for batch norm running statistics.
    num_batches = tf_weights["inferrer/global_step:0"]
    tf_weights.pop("inferrer/global_step:0")
    temp = {}
    for tf_name, tf_param in tf_weights.items():
        temp[tf_name] = tf_param
        if ("batch_normalization" in tf_name) & ("moving_variance" in tf_name):
            num_batches_name = "/".join(
                tf_name.split("/")[:-1] + ["num_batches_tracked:0"]
            )
            temp[num_batches_name] = np.array(num_batches)
    tf_weights = temp

    # Get pytorch model variables
    state_dict = torch_model.state_dict()
    state_dict_list = [(k, v) for k, v in state_dict.items()]

    with torch.no_grad():
        for (name, param), (tf_name, tf_param) in zip(
            state_dict_list, tf_weights.items()
        ):
            if tf_param.ndim >= 2:
                # Reverse all axes: TF stores kernels transposed relative to torch.
                tf_param = np.transpose(
                    tf_param, tuple(sorted(range(tf_param.ndim), reverse=True))
                )

            assert (
                tf_param.shape == param.detach().numpy().shape
            ), f"{name} and {tf_name} don't have the same shape"
            state_dict[name] = torch.from_numpy(tf_param)

    torch_model.load_state_dict(state_dict)


def reverse_map(applicable_label_dict, label_vocab=None):
    """Flip parenthood dict to map parents to children.

    Args:
        applicable_label_dict: e.g. output of get_applicable_label_dict.
        label_vocab: e.g. output of inference_lib.vocab_from_model_base_path

    Returns:
        collections.defaultdict of k, v where:
        k: originally the values in applicable_label_dict
        v: originally the keys in applicable_label_dict.
        The defaultdict returns an empty frozenset for keys that are not found.
        This behavior is desirable for lifted clan label normalizers, where
        keys may not imply themselves.
    """
    # This is technically the entire transitive closure, so it is safe for DAGs
    # (e.g. GO labels).

    children = collections.defaultdict(set)
    for child, parents in applicable_label_dict.items():
        # Avoid adding children which don't appear in the vocab.
        if label_vocab is None or child in label_vocab:
            for parent in parents:
                children[parent].add(child)
    children = {k: frozenset(v) for k, v in children.items()}
    return collections.defaultdict(frozenset, children.items())


def normalize_confidences(predictions, label_vocab, applicable_label_dict):
    """Set confidences of parent labels to the max of their children.

    Args:
        predictions: [num_sequences, num_labels] ndarray.
        label_vocab: list of vocab strings in an order that corresponds to
            `predictions`.
        applicable_label_dict: Mapping from labels to their parents (including
            indirect parents).

    Returns:
        A numpy array [num_sequences, num_labels] with confidences where:
        if label_vocab[k] in applicable_label_dict[label_vocab[j]],
        then arr[i, j] >= arr[i, k] for all i.
    """
    vocab_indices = {v: i for i, v in enumerate(label_vocab)}
    children = reverse_map(applicable_label_dict, set(vocab_indices.keys()))

    # Only vectorize this along the sequences dimension as the number of children
    # varies between labels.
    label_confidences = []
    for label in label_vocab:
        child_indices = np.array([vocab_indices[child] for child in children[label]])
        if child_indices.size > 1:
            confidences = np.max(predictions[:, child_indices], axis=1)
            label_confidences.append(confidences)
        else:
            label_confidences.append(predictions[:, vocab_indices[label]])

    return np.stack(label_confidences, axis=1)
"loralib==0.1.2", # LoRA library 32 | "tensorboard==2.15.1", # Tensorboard 33 | "obonet==1.0.0", # OBO ontologies 34 | "blosum==2.0.2", # BLOSUM scoring matrices 35 | "biopython==1.84", # Biopython for bioinformatics 36 | "ipykernel==6.29.5", #Jupyter notebook 37 | "scipy==1.13.1", # Viz 38 | "seaborn==0.13.2", # Viz 39 | "scikit-learn==1.5.0", 40 | "matplotlib==3.9.2", # Viz 41 | "umap-learn==0.5.4" # Viz 42 | ], 43 | classifiers=[ 44 | "Programming Language :: Python :: 3", 45 | "Operating System :: OS Independent", 46 | ], 47 | python_requires=">=3.10", 48 | ) -------------------------------------------------------------------------------- /singularity_image/Dockerfile: -------------------------------------------------------------------------------- 1 | # vim: ft=mako 2 | 3 | FROM singularitybase.azurecr.io/base/job/pytorch/acpt-2.0.1-py3.10-cuda11.8:20240416T115414238 as base 4 | 5 | FROM base 6 | 7 | RUN which mamba || conda install -yq -c conda-forge mamba 8 | RUN mamba install -yq -c biobuilds -c pytorch -c huggingface -c nvidia -c anaconda -c conda-forge -c defaults pytorch=2.0.1 torchvision=0.15.2 pytorch-cuda=11.8 pandas=1.5.2 joblib=1.1.1 transformers=4.32.1 torchmetrics=1.2.0 torchdata=0.7.1 wandb=0.15.11 sacremoses=0.0.53 pynvml=11.5.0 torchmetrics=1.2.0 && conda clean -yqa 9 | RUN python -mpip install -q torcheval==0.0.7 wget==3.2 azureml-mlflow==1.53.0 loralib==0.1.2 tensorboard==2.15.1 obonet==1.0.0 blosum==2.0.2 biopython==1.84 10 | 11 | # set validation arguments for expected use of this image 12 | ENV SINGULARITY_IMAGE_FRAMEWORK=pytorch 13 | ENV SINGULARITY_IMAGE_ACCELERATOR=NVIDIA 14 | 15 | -------------------------------------------------------------------------------- /singularity_image/build_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set your variables here 4 | registry_name="bio0extcr" 5 | repository_name="torch2.0.1cuda11.8" 6 | image_tag="v1.0" # Change this tag as 
needed 7 | 8 | # Build the Docker image 9 | echo "Building the Docker image..." 10 | docker build . -f Dockerfile --progress=plain 11 | 12 | # Get the Image ID of the most recently built image 13 | image_id=$(docker images -q | head -n 1) 14 | 15 | if [ -z "$image_id" ]; then 16 | echo "Error: Failed to get the image ID. The build might have failed." 17 | exit 1 18 | fi 19 | 20 | echo "Image built with ID: $image_id" 21 | 22 | # Tag the image for pushing to ACR 23 | echo "Tagging the Docker image..." 24 | docker tag $image_id $registry_name.azurecr.io/$repository_name:$image_tag 25 | 26 | # Push the image to ACR 27 | echo "Pushing the image to Azure Container Registry (ACR)..." 28 | docker push $registry_name.azurecr.io/$repository_name:$image_tag 29 | 30 | echo "Image successfully pushed to ACR: $registry_name.azurecr.io/$repository_name:$image_tag" 31 | --------------------------------------------------------------------------------