├── .gitattributes ├── .github └── workflows │ └── CI.yml ├── .gitignore ├── LICENSE ├── NEWS.md ├── README.md ├── bin └── progres ├── data ├── d3oxpa1.pdb ├── d3urra1.pdb ├── filepaths.txt ├── query.pdb └── searchdb.pt ├── docker ├── Dockerfile └── docker_runner.sh ├── progres ├── __init__.py ├── chainsaw │ ├── __init__.py │ ├── get_predictions.py │ └── src │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── domain_assignment │ │ ├── __init__.py │ │ ├── assigners.py │ │ ├── sparse_lowrank.py │ │ └── util.py │ │ ├── domain_chop.py │ │ ├── errors.py │ │ ├── factories.py │ │ ├── featurisers.py │ │ ├── loggers.py │ │ ├── models │ │ ├── __init__.py │ │ ├── results.py │ │ └── rosetta.py │ │ └── utils │ │ ├── __init__.py │ │ ├── cif2pdb.py │ │ ├── common.py │ │ ├── domain_boundary_distance_score.py │ │ ├── ndo_score.py │ │ ├── pdb_reres.py │ │ └── secondary_structure.py ├── databases │ └── README.md ├── progres.py └── trained_models │ └── README.md ├── scripts ├── dataset │ ├── README.md │ ├── overlap_sfam.txt │ ├── overlap_sid.txt │ ├── search_all.txt │ ├── test_seen.txt │ ├── test_unseen_sfam.txt │ ├── test_unseen_sid.txt │ ├── train_sfam.txt │ ├── train_sid.txt │ ├── val_seen.txt │ ├── val_unseen_sfam.txt │ └── val_unseen_sid.txt ├── faiss_index.py ├── other_methods │ ├── 3dsurfer │ │ ├── run_nn_allatom.txt.gz │ │ ├── searching.py │ │ └── searching.txt │ ├── README.md │ ├── astral_40_upper.fa │ ├── contact_order.txt │ ├── dali │ │ ├── domids_imported.txt │ │ ├── domids_pdbids.txt │ │ ├── import.sh │ │ ├── pdbids_imported.txt │ │ ├── run.sh │ │ ├── searching.py │ │ └── searching.txt │ ├── eat │ │ ├── run.sh │ │ ├── scope40.fasta │ │ ├── searching.py │ │ └── searching.txt │ ├── esm │ │ ├── embed.sh │ │ ├── searching.py │ │ └── searching.txt │ ├── extract_model.jl │ ├── foldseek │ │ ├── run.sh │ │ ├── run_tm.sh │ │ ├── searching.py │ │ ├── searching.txt │ │ ├── searching_tm.py │ │ └── searching_tm.txt │ ├── mmseqs2 │ │ ├── run.sh │ │ ├── searching.py │ │ └── searching.txt │ └── tmalign │ │ ├── domids.txt │ │ ├── parse.sh │ │ ├── run.sh │ │ ├── searching.py │ │ └── searching.txt ├── scope_val.py ├── scope_val.txt ├── train.py └── training_coords.tgz └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | Dockerfile text eol=lf 2 | *.sh text eol=lf 3 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | push: 7 | branches: 8 | - main 9 | tags: '*' 10 | schedule: 11 | - cron: '00 04 * * 1' # 4am every Monday 12 | workflow_dispatch: 13 | jobs: 14 | test_repo: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: 19 | - 3.9 20 | defaults: 21 | run: 22 | shell: bash -el {0} 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up conda with Python ${{ matrix.python-version }} 26 | uses: conda-incubator/setup-miniconda@v2 27 | with: 28 | auto-update-conda: true 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install PyTorch and FAISS 31 | run: conda install pytorch==1.11 faiss-cpu -c pytorch 32 | - name: Install PyTorch Scatter and PyTorch Geometric 33 | run: conda install pytorch-scatter pyg -c pyg 34 | - name: Install STRIDE 35 | run: conda install kimlab::stride 36 | - name: Test install 37 | run: pip install -e . 
38 | - name: Test help 39 | run: time python bin/progres -h 40 | - name: Test import 41 | run: time python -c "import progres" 42 | - name: Download structures 43 | run: | 44 | wget https://files.rcsb.org/view/1CRN.pdb 45 | wget https://files.rcsb.org/view/1SSU.cif 46 | wget https://alphafold.ebi.ac.uk/files/AF-P31434-F1-model_v4.pdb 47 | - name: Test search 48 | run: time python bin/progres search -q 1CRN.pdb -t scope95 49 | - name: Test domain split 50 | run: time python bin/progres search -q AF-P31434-F1-model_v4.pdb -t cath40 -c 51 | - name: Test score 52 | run: time python bin/progres score 1CRN.pdb 1SSU.cif > score.txt 53 | - name: Check score 54 | run: | 55 | sc=$(cat score.txt) 56 | if [ ${sc:0:7} == "0.72652" ]; then echo "Correct score"; else echo "Wrong score, score is $sc"; exit 1; fi 57 | - name: Test database embedding 58 | run: | 59 | cd data 60 | time python ../bin/progres embed -l filepaths.txt -o out.pt 61 | time python ../bin/progres search -q query.pdb -t out.pt 62 | test_pypi: 63 | runs-on: ubuntu-latest 64 | strategy: 65 | matrix: 66 | python-version: 67 | - 3.9 68 | defaults: 69 | run: 70 | shell: bash -el {0} 71 | steps: 72 | - uses: actions/checkout@v2 73 | - name: Set up conda with Python ${{ matrix.python-version }} 74 | uses: conda-incubator/setup-miniconda@v2 75 | with: 76 | auto-update-conda: true 77 | python-version: ${{ matrix.python-version }} 78 | - name: Install PyTorch and FAISS 79 | run: conda install pytorch==1.11 faiss-cpu -c pytorch 80 | - name: Install PyTorch Scatter and PyTorch Geometric 81 | run: conda install pytorch-scatter pyg -c pyg 82 | - name: Install STRIDE 83 | run: conda install kimlab::stride 84 | - name: Test install 85 | run: pip install progres 86 | - name: Test help 87 | run: time progres -h 88 | - name: Test import 89 | run: time python -c "import progres" 90 | - name: Download structures 91 | run: | 92 | wget https://files.rcsb.org/view/1CRN.pdb 93 | wget https://files.rcsb.org/view/1SSU.cif 94 | wget https://alphafold.ebi.ac.uk/files/AF-P31434-F1-model_v4.pdb 95 | - name: Test search 96 | run: time progres search -q 1CRN.pdb -t scope95 97 | - name: Test domain split 98 | run: time progres search -q AF-P31434-F1-model_v4.pdb -t cath40 -c 99 | - name: Test score 100 | run: time progres score 1CRN.pdb 1SSU.cif > score.txt 101 | - name: Check score 102 | run: | 103 | sc=$(cat score.txt) 104 | if [ ${sc:0:7} == "0.72652" ]; then echo "Correct score"; else echo "Wrong score, score is $sc"; exit 1; fi 105 | - name: Test database embedding 106 | run: | 107 | cd data 108 | time progres embed -l filepaths.txt -o out.pt 109 | time progres search -q query.pdb -t out.pt 110 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | build 4 | dist 5 | progres/trained_models/v* 6 | progres/databases/v* 7 | progres/chainsaw/model_v3/weights* 8 | .vs 9 | .vscode 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Joe Greener and other contributors (https://github.com/greener-group/progres/graphs/contributors) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | Code is included from EGNN PyTorch (https://github.com/lucidrains/egnn-pytorch): 24 | 25 | MIT License 26 | 27 | Copyright (c) 2021 Phil Wang, Eric Alcaide 28 | 29 | Permission is hereby granted, free of charge, to any person obtaining a copy 30 | of this software and associated documentation files (the "Software"), to deal 31 | in the Software without restriction, including without limitation the rights 32 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | copies of the Software, and to permit persons to whom the Software is 34 | furnished to do so, subject to the following conditions: 35 | 36 | The above copyright notice and this permission notice shall be included in all 37 | copies or substantial portions of the Software. 38 | 39 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | SOFTWARE. 46 | 47 | Code is included from SupContrast (https://github.com/HobbitLong/SupContrast): 48 | 49 | BSD 2-Clause License 50 | 51 | Copyright (c) 2020, Yonglong Tian 52 | All rights reserved. 53 | 54 | Redistribution and use in source and binary forms, with or without 55 | modification, are permitted provided that the following conditions are met: 56 | 57 | 1. Redistributions of source code must retain the above copyright notice, this 58 | list of conditions and the following disclaimer. 59 | 60 | 2. Redistributions in binary form must reproduce the above copyright notice, 61 | this list of conditions and the following disclaimer in the documentation 62 | and/or other materials provided with the distribution. 63 | 64 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 65 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 66 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 67 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 68 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 69 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 70 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 71 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 72 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 73 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 74 | 75 | Code is included from Chainsaw (https://github.com/JudeWells/chainsaw): 76 | 77 | MIT License 78 | 79 | Copyright (c) 2024 Jude Wells 80 | 81 | Permission is hereby granted, free of charge, to any person obtaining a copy 82 | of this software and associated documentation files (the "Software"), to deal 83 | in the Software without restriction, including without limitation the rights 84 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 85 | copies of the Software, and to permit persons to whom the Software is 86 | furnished to do so, subject to the following conditions: 87 | 88 | The above copyright notice and this permission notice shall be included in all 89 | copies or substantial portions of the Software. 90 | 91 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 92 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 93 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 94 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 95 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 96 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 97 | SOFTWARE. 98 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # Progres release notes 2 | 3 | ## v1.0.0 - Apr 2025 4 | 5 | Stable version of Progres to accompany the published paper (https://academic.oup.com/bioinformaticsadvances/article/5/1/vbaf042/8107707). 6 | 7 | - A bug in loading the trained model when the data has not been downloaded is fixed. 8 | 9 | ## v0.2.7 - Sep 2024 10 | 11 | - Bugs in domain splitting with a blank chain ID and Mac multithreading are fixed. 12 | 13 | ## v0.2.6 - Aug 2024 14 | 15 | - Python packaging issues are fixed. 16 | 17 | ## v0.2.5 - Aug 2024 18 | 19 | - Structures can now be split into domains with Chainsaw before searching, with each domain searched separately. This makes Progres suitable for use with multi-domain structures. 20 | - The whole PDB split into domains with Chainsaw is made available to search against. 21 | - Hetero atoms are now ignored during file reading. 22 | - Example files are added for searching and database embedding. 23 | 24 | ## v0.2.4 - Jul 2024 25 | 26 | - The `score` mode is added to calculate the Progres score between two structures. 27 | 28 | ## v0.2.3 - May 2024 29 | 30 | - Incomplete downloads are handled during setup. 31 | 32 | ## v0.2.2 - Apr 2024 33 | 34 | - The environmental variable `PROGRES_DATA_DIR` can be used to change where the downloaded data is stored. 35 | - A Docker file is added. 36 | - Searching on GPU is made more memory efficient. 37 | - Bugs when running on Windows are fixed. 
38 | 39 | ## v0.2.1 - Apr 2024 40 | 41 | - The AlphaFold database TED domains are made available to search against, with FAISS used for fast searching. 42 | - Pre-embedded databases are stored as Float16 to reduce disk usage. 43 | - Datasets and scripts for benchmarking (including for other methods), FAISS index generation and training are made available. 44 | 45 | ## v0.2.0 - Mar 2023 46 | 47 | - Change model architecture to use 6 EGNN layers and tau torsion angles, making it faster and SE(3)-invariant rather than E(3)-invariant. 48 | - The AlphaFold models for 21 model organisms are made available to search against. 49 | - The trained model and pre-embedded databases are downloaded from Zenodo rather than GitHub when first running the software. 50 | 51 | ## v0.1.3 - Nov 2022 52 | 53 | - Fix data download. 54 | 55 | ## v0.1.2 - Nov 2022 56 | 57 | - Add ECOD database. 58 | - Use versioned model directory. 59 | 60 | ## v0.1.1 - Nov 2022 61 | 62 | - Add einops dependency. 63 | - Add code for ECOD database. 64 | 65 | ## v0.1.0 - Nov 2022 66 | 67 | Initial release of the `progres` Python package for fast protein structure searching using structure graph embeddings. 68 | -------------------------------------------------------------------------------- /bin/progres: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Argument handling 4 | 5 | import argparse 6 | import importlib.metadata 7 | import sys 8 | 9 | parser = argparse.ArgumentParser(description=( 10 | "Fast protein structure searching using structure graph embeddings. " 11 | "See https://github.com/greener-group/progres for documentation and citation information. " 12 | f"This is version {importlib.metadata.version('progres')} of the software." 13 | )) 14 | subparsers = parser.add_subparsers(dest="mode", 15 | help="the mode to run progres in, run \"progres {mode} -h\" to see help for each") 16 | 17 | parser_search = subparsers.add_parser("search", 18 | description=( 19 | "Search one or more protein structures against a pre-embedded database. " 20 | "See https://github.com/greener-group/progres for documentation and citation information. " 21 | f"This is version {importlib.metadata.version('progres')} of the software." 
22 | ), 23 | help="search one or more protein structures against a pre-embedded database") 24 | parser_search.add_argument("-q", "--querystructure", 25 | help="query structure file in PDB/mmCIF/MMTF/coordinate format") 26 | parser_search.add_argument("-l", "--querylist", 27 | help="text file with one query file path per line") 28 | parser_search.add_argument("-t", "--targetdb", required=True, 29 | help=("pre-embedded database to search against, either \"scope95\", \"scope40\", " 30 | "\"cath40\", \"ecod70\", \"pdb100\", \"af21org\", \"afted\" or a file path")) 31 | parser_search.add_argument("-f", "--fileformat", 32 | choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess", 33 | help="file format of the query structure(s), by default guessed from the file extension") 34 | parser_search.add_argument("-s", "--minsimilarity", type=float, default=0.8, 35 | help="Progres score (0 -> 1) above which to return hits, default 0.8") 36 | parser_search.add_argument("-m", "--maxhits", type=int, default=100, 37 | help="maximum number of hits per domain to return, default 100") 38 | parser_search.add_argument("-c", "--chainsaw", default=False, action="store_true", 39 | help=("split the query structure(s) into domains with Chainsaw and search with " 40 | "each domain separately")) 41 | parser_search.add_argument("-d", "--device", default="cpu", 42 | help="device to run on, default is \"cpu\"") 43 | 44 | parser_score = subparsers.add_parser("score", 45 | description=( 46 | "Calculate the Progres score between two protein domains. " 47 | "The order of the domains does not affect the score. " 48 | "A score of 0.8 or higher indicates the same fold. " 49 | "See https://github.com/greener-group/progres for documentation and citation information. " 50 | f"This is version {importlib.metadata.version('progres')} of the software." 51 | ), 52 | help="calculate the Progres score between two protein domains") 53 | parser_score.add_argument("structure1", 54 | help="first structure file in PDB/mmCIF/MMTF/coordinate format") 55 | parser_score.add_argument("structure2", 56 | help="second structure file in PDB/mmCIF/MMTF/coordinate format") 57 | parser_score.add_argument("-f", "--fileformat1", 58 | choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess", 59 | help="file format of the first structure, by default guessed from the file extension") 60 | parser_score.add_argument("-g", "--fileformat2", 61 | choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess", 62 | help="file format of the second structure, by default guessed from the file extension") 63 | parser_score.add_argument("-d", "--device", default="cpu", 64 | help="device to run on, default is \"cpu\"") 65 | 66 | parser_embed = subparsers.add_parser("embed", 67 | description=( 68 | "Embed a dataset of structures to allow it to be searched against. " 69 | "See https://github.com/greener-group/progres for documentation and citation information. " 70 | f"This is version {importlib.metadata.version('progres')} of the software." 
71 | ), 72 | help="embed a dataset of structures to allow it to be searched against") 73 | parser_embed.add_argument("-l", "--structurelist", required=True, 74 | help="text file with file path, domain name and optional note per line") 75 | parser_embed.add_argument("-o", "--outputfile", required=True, 76 | help="output file path for the PyTorch file containing the embeddings") 77 | parser_embed.add_argument("-f", "--fileformat", 78 | choices=["guess", "pdb", "mmcif", "mmtf", "coords"], default="guess", 79 | help="file format of the structures, by default guessed from the file extension") 80 | parser_embed.add_argument("-d", "--device", default="cpu", 81 | help="device to run on, default is \"cpu\"") 82 | 83 | args = parser.parse_args() 84 | 85 | def main(): 86 | if args.mode == "search": 87 | from progres import progres_search_print 88 | if args.minsimilarity < 0 or args.minsimilarity > 1: 89 | raise argparse.ArgumentTypeError("minsimilarity must be between 0 and 1") 90 | if args.maxhits < 1: 91 | raise argparse.ArgumentTypeError("maxhits must be a positive integer") 92 | if args.querystructure: 93 | progres_search_print(querystructure=args.querystructure, targetdb=args.targetdb, 94 | fileformat=args.fileformat, minsimilarity=args.minsimilarity, 95 | maxhits=args.maxhits, chainsaw=args.chainsaw, device=args.device) 96 | elif args.querylist: 97 | progres_search_print(querylist=args.querylist, targetdb=args.targetdb, 98 | fileformat=args.fileformat, minsimilarity=args.minsimilarity, 99 | maxhits=args.maxhits, chainsaw=args.chainsaw, device=args.device) 100 | else: 101 | print("One of -q and -l must be given for structure searching", file=sys.stderr) 102 | elif args.mode == "score": 103 | from progres import progres_score_print 104 | progres_score_print(structure1=args.structure1, structure2=args.structure2, 105 | fileformat1=args.fileformat1, fileformat2=args.fileformat2, 106 | device=args.device) 107 | elif args.mode == "embed": 108 | from progres import progres_embed 109 | progres_embed(structurelist=args.structurelist, outputfile=args.outputfile, 110 | fileformat=args.fileformat, device=args.device) 111 | else: 112 | print("No mode selected, run \"progres -h\" to see help", file=sys.stderr) 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /data/filepaths.txt: -------------------------------------------------------------------------------- 1 | d3urra1.pdb d3urra1 d.112.1.0 - automated matches {Burkholderia thailandensis [TaxId: 271848]} 2 | d3oxpa1.pdb d3oxpa1 d.112.1.0 - automated matches {Yersinia pestis [TaxId: 214092]} 3 | -------------------------------------------------------------------------------- /data/searchdb.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/data/searchdb.pt -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # can build dockerfile with --build-arg PERSISTENCE_DIR="/persist/progres" and run it with -v /host/path/to/progres/data:/persist/progres 2 | # so that the files can be downloaded once but used many times. 
In particular useful in an HPC context, for the default dbs that can be pre-downloaded 3 | # to a shared read-only directory 4 | 5 | ARG MINICONDA_VERSION="Miniconda3-py311_24.1.2-0-Linux-x86_64" 6 | ARG CONDA_INSTALL_PATH="/pub/conda" 7 | ARG CONDA_ENV_NAME="progres_env" 8 | ARG PERSISTENCE_DIR="/persist/progres" 9 | 10 | #progres is the official repo on pypi, can override with an alternative for testing 11 | # e.g. git+https://github.com/xeniorn/progres.git@main 12 | ARG PROGRES_REPO='progres' 13 | 14 | FROM debian:12 as build 15 | 16 | ARG CONDA_INSTALL_PATH 17 | ARG MINICONDA_VERSION 18 | 19 | RUN apt-get -y update \ 20 | && apt-get -y install wget git \ 21 | && apt-get -y autoclean 22 | 23 | RUN wget --quiet \ 24 | https://repo.anaconda.com/miniconda/${MINICONDA_VERSION}.sh \ 25 | && bash ${MINICONDA_VERSION}.sh -bfp ${CONDA_INSTALL_PATH} \ 26 | && rm -f ${MINICONDA_VERSION}.sh 27 | 28 | ENV PATH="${CONDA_INSTALL_PATH}/bin:${PATH}" 29 | 30 | ARG CONDA_ENV_NAME 31 | ENV CONDA_ENV_NAME=${CONDA_ENV_NAME} 32 | 33 | RUN . ${CONDA_INSTALL_PATH}/etc/profile.d/conda.sh \ 34 | && conda create --yes -n ${CONDA_ENV_NAME} python=3.9 \ 35 | && conda activate ${CONDA_ENV_NAME} \ 36 | && conda install --yes pytorch=1.11 faiss-cpu -c pytorch \ 37 | && conda install --yes pytorch-scatter pyg -c pyg \ 38 | && conda install --yes kimlab::stride \ 39 | && conda clean --yes --all 40 | 41 | ARG PROGRES_REPO 42 | 43 | RUN . ${CONDA_INSTALL_PATH}/etc/profile.d/conda.sh \ 44 | && conda activate ${CONDA_ENV_NAME} \ 45 | && yes | pip install "${PROGRES_REPO}" \ 46 | && pip cache purge 47 | 48 | ARG PERSISTENCE_DIR 49 | ENV PROGRES_DATA_DIR=${PERSISTENCE_DIR} 50 | 51 | ARG RUN_SCRIPT="/pub/run.sh" 52 | ENV PROGRES_RUN_SCRIPT=${RUN_SCRIPT} 53 | 54 | COPY "docker/docker_runner.sh" ${RUN_SCRIPT} 55 | RUN chmod +rx ${RUN_SCRIPT} 56 | 57 | ENTRYPOINT [ "/pub/run.sh" ] 58 | -------------------------------------------------------------------------------- /docker/docker_runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | source activate ${CONDA_ENV_NAME} 6 | 7 | $@ 8 | -------------------------------------------------------------------------------- /progres/__init__.py: -------------------------------------------------------------------------------- 1 | from .progres import * 2 | from .chainsaw.get_predictions import predict_domains 3 | -------------------------------------------------------------------------------- /progres/chainsaw/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/progres/chainsaw/__init__.py -------------------------------------------------------------------------------- /progres/chainsaw/get_predictions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for running Chainsaw 3 | 4 | Created by: Jude Wells 2023-04-19 5 | 6 | Modified for Progres 7 | """ 8 | 9 | import hashlib 10 | import json 11 | import logging 12 | import os 13 | import sys 14 | import time 15 | from typing import List 16 | 17 | from progres.chainsaw.src import constants, featurisers 18 | from progres.chainsaw.src.domain_assignment.util import convert_domain_dict_strings 19 | from progres.chainsaw.src.factories import pairwise_predictor 20 | from progres.chainsaw.src.models.results import PredictionResult 21 | 22 | LOG = logging.getLogger(__name__) 23 | 24 | 
config_str = """ 25 | { 26 | "experiment_name": "C_mul65_do30", 27 | "experiment_group": "cath_new", 28 | "lr": 0.0002, 29 | "weight_decay": 0.001, 30 | "val_freq": 1, 31 | "epochs": 15, 32 | "lr_scheduler": { 33 | "type": "exponential", 34 | "gamma": 0.9 35 | }, 36 | "accumulation_steps": 16, 37 | "data": { 38 | "splits_file": "splits_new_cath_featurized.json", 39 | "validation_splits": [ 40 | "validation", 41 | "test" 42 | ], 43 | "crop_size": null, 44 | "crop_type": null, 45 | "batch_size": 1, 46 | "feature_dir": "../features/new_cath/2d_features", 47 | "label_dir": "../features/new_cath/pairwise", 48 | "chains_csv": null, 49 | "evaluate_test": false, 50 | "eval_casp_10_plus": false, 51 | "remove_missing_residues": false, 52 | "using_alphafold_features": false, 53 | "recycling": false, 54 | "add_padding_mask": false, 55 | "training_exclusion_json": null, 56 | "multi_proportion": 0.65, 57 | "train_ids": "splits_new_cath_featurized.json", 58 | "exclude_test_topology": false, 59 | "cluster_sampling_training": true, 60 | "dist_transform": "unidoc_exponent", 61 | "distance_denominator": 10, 62 | "merizo_train_data": false, 63 | "redundancy_level": "S60_comb" 64 | }, 65 | "learner": { 66 | "uncertainty_model": false, 67 | "save_every_epoch": true, 68 | "model": { 69 | "type": "trrosetta", 70 | "kwargs": { 71 | "filters": 32, 72 | "kernel": 3, 73 | "num_layers": 31, 74 | "in_channels": 5, 75 | "dropout": 0.3, 76 | "symmetrise_output": true 77 | } 78 | }, 79 | "assignment": { 80 | "type": "sparse_lowrank", 81 | "kwargs": { 82 | "N_iters": 3, 83 | "K_init": 4, 84 | "linker_threshold": 30 85 | } 86 | }, 87 | "max_recycles": 0, 88 | "save_val_best": true, 89 | "x_has_padding_mask": false 90 | }, 91 | "num_trainable_params": 577889 92 | } 93 | """ 94 | 95 | feature_config_str = """ 96 | { 97 | "description": "alpha distance only, keep the start and end boundaries in one channel but use -1", 98 | "alpha_distances": true, 99 | "beta_distances": false, 100 | "ss_bounds": true, 101 | "negative_ss_end": true, 102 | "separate_channel_ss_start_end": false, 103 | "same_channel_boundaries_and_ss": false 104 | } 105 | """ 106 | 107 | def setup_logging(loglevel): 108 | # log all messages to stderr so results can be sent to stdout 109 | logging.basicConfig(level=loglevel, 110 | stream=sys.stderr, 111 | format='%(asctime)s | %(levelname)s | %(message)s', 112 | datefmt='%m/%d/%Y %I:%M:%S %p') 113 | 114 | def load_model(*, 115 | model_dir: str, 116 | remove_disordered_domain_threshold: float = 0.35, 117 | min_ss_components: int = 2, 118 | min_domain_length: int = 30, 119 | post_process_domains: bool = True, 120 | device: str = "cpu"): 121 | config = json.loads(config_str) 122 | feature_config = json.loads(feature_config_str) 123 | config["learner"]["remove_disordered_domain_threshold"] = remove_disordered_domain_threshold 124 | config["learner"]["post_process_domains"] = post_process_domains 125 | config["learner"]["min_ss_components"] = min_ss_components 126 | config["learner"]["min_domain_length"] = min_domain_length 127 | config["learner"]["dist_transform_type"] = config["data"].get("dist_transform", 'min_replace_inverse') 128 | config["learner"]["distance_denominator"] = config["data"].get("distance_denominator", None) 129 | learner = pairwise_predictor(config["learner"], output_dir=model_dir, device=device) 130 | learner.feature_config = feature_config 131 | learner.load_checkpoints() 132 | learner.eval() 133 | return learner 134 | 135 | def predict(model, pdb_path, pdbchain=None, fileformat="pdb") -> 
List[PredictionResult]: 136 | """ 137 | Makes the prediction and returns a list of PredictionResult objects 138 | """ 139 | start = time.time() 140 | 141 | # get model structure metadata 142 | model_structure = featurisers.get_model_structure(pdb_path, fileformat) 143 | 144 | if pdbchain is None: 145 | LOG.warning(f"No chain specified for {pdb_path}, using first chain") 146 | # get all the chain ids from the model structure 147 | all_chain_ids = [c.id for c in model_structure.get_chains()] 148 | # take the first chain id 149 | pdbchain = all_chain_ids[0] 150 | 151 | model_residues = featurisers.get_model_structure_residues(model_structure, chain=pdbchain) 152 | model_res_label_by_index = { int(r.index): str(r.res_label) for r in model_residues} 153 | model_structure_seq = "".join([r.aa for r in model_residues]) 154 | model_structure_md5 = hashlib.md5(model_structure_seq.encode('utf-8')).hexdigest() 155 | 156 | x = featurisers.inference_time_create_features(pdb_path, 157 | feature_config=model.feature_config, 158 | chain=pdbchain, 159 | model_structure=model_structure, 160 | fileformat=fileformat, 161 | ) 162 | 163 | A_hat, domain_dict, confidence = model.predict(x) 164 | # Convert 0-indexed to 1-indexed to match AlphaFold indexing: 165 | domain_dict = [{k: [r + 1 for r in v] for k, v in d.items()} for d in domain_dict] 166 | names_str, bounds_str = convert_domain_dict_strings(domain_dict[0]) 167 | confidence = confidence[0] 168 | 169 | if names_str == "": 170 | names = bounds = () 171 | else: 172 | names = names_str.split('|') 173 | bounds = bounds_str.split('|') 174 | 175 | assert len(names) == len(bounds) 176 | 177 | class Seg: 178 | def __init__(self, domain_id: str, start_index: int, end_index: int): 179 | self.domain_id = domain_id 180 | self.start_index = int(start_index) 181 | self.end_index = int(end_index) 182 | 183 | def res_label_of_index(self, index: int): 184 | if index not in model_res_label_by_index: 185 | raise ValueError(f"Index {index} not in model_res_label_by_index ({model_res_label_by_index})") 186 | return model_res_label_by_index[int(index)] 187 | 188 | @property 189 | def start_label(self): 190 | return self.res_label_of_index(self.start_index) 191 | 192 | @property 193 | def end_label(self): 194 | return self.res_label_of_index(self.end_index) 195 | 196 | class Dom: 197 | def __init__(self, domain_id, segs: List[Seg] = None): 198 | self.domain_id = domain_id 199 | if segs is None: 200 | segs = [] 201 | self.segs = segs 202 | 203 | def add_seg(self, seg: Seg): 204 | self.segs.append(seg) 205 | 206 | # gather choppings into segments in domains 207 | domains_by_domain_id = {} 208 | for domain_id, chopping_by_index in zip(names, bounds): 209 | if domain_id not in domains_by_domain_id: 210 | domains_by_domain_id[domain_id] = Dom(domain_id) 211 | start_index, end_index = chopping_by_index.split('-') 212 | seg = Seg(domain_id, start_index, end_index) 213 | domains_by_domain_id[domain_id].add_seg(seg) 214 | 215 | # sort domain choppings by the start residue in first segment 216 | domains = sorted(domains_by_domain_id.values(), key=lambda dom: dom.segs[0].start_index) 217 | 218 | # collect domain choppings as strings 219 | domain_choppings = [] 220 | for dom in domains: 221 | # convert segments to strings 222 | segs_str = [f"{seg.start_label}-{seg.end_label}" for seg in dom.segs] 223 | segs_index_str = [f"{seg.start_index}-{seg.end_index}" for seg in dom.segs] 224 | LOG.info(f"Segments (index to label): {segs_index_str} -> {segs_str}") 225 | # join discontinuous segs with 
'_' 226 | domain_choppings.append('_'.join(segs_str)) 227 | 228 | # join domains with ',' 229 | chopping_str = ','.join(domain_choppings) 230 | 231 | num_domains = len(domain_choppings) 232 | if num_domains == 0: 233 | chopping_str = None 234 | runtime = round(time.time() - start, 3) 235 | result = PredictionResult( 236 | pdb_path=pdb_path, 237 | chain_id="A", # Placeholder 238 | sequence_md5=model_structure_md5, 239 | nres=len(model_structure_seq), 240 | ndom=num_domains, 241 | chopping=chopping_str, 242 | confidence=confidence, 243 | time_sec=runtime, 244 | ) 245 | 246 | LOG.info(f"Runtime: {round(runtime, 3)}s") 247 | return result 248 | 249 | def predict_domains(structure_file, fileformat=None, device="cpu", pdbchain=None): 250 | loglevel = os.environ.get("LOGLEVEL", "ERROR").upper() # Change to "INFO" to see more 251 | setup_logging(loglevel) 252 | if fileformat is None: 253 | fileformat = "pdb" 254 | file_ext = os.path.splitext(structure_file)[1].lower() 255 | if file_ext == ".cif" or file_ext == ".mmcif": 256 | fileformat = "mmcif" 257 | elif file_ext == ".mmtf": 258 | fileformat = "mmtf" 259 | model = load_model( 260 | model_dir=os.path.join(constants.REPO_ROOT, "model_v3"), 261 | device=device, 262 | ) 263 | result = predict(model, structure_file, pdbchain=pdbchain, fileformat=fileformat) 264 | return result.chopping 265 | -------------------------------------------------------------------------------- /progres/chainsaw/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/progres/chainsaw/src/__init__.py -------------------------------------------------------------------------------- /progres/chainsaw/src/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | REPO_ROOT = Path(__file__).parent.parent.resolve() 5 | STRIDE_EXE = os.environ.get("STRIDE_EXE", "stride") 6 | 7 | _3to1 = { 8 | 'ALA': 'A', 9 | 'CYS': 'C', 10 | 'ASP': 'D', 11 | 'GLU': 'E', 12 | 'PHE': 'F', 13 | 'GLY': 'G', 14 | 'HIS': 'H', 15 | 'ILE': 'I', 16 | 'LYS': 'K', 17 | 'LEU': 'L', 18 | 'MET': 'M', 19 | 'ASN': 'N', 20 | 'PRO': 'P', 21 | 'GLN': 'Q', 22 | 'ARG': 'R', 23 | 'SER': 'S', 24 | 'THR': 'T', 25 | 'VAL': 'V', 26 | 'TRP': 'W', 27 | 'TYR': 'Y', 28 | } 29 | -------------------------------------------------------------------------------- /progres/chainsaw/src/domain_assignment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/progres/chainsaw/src/domain_assignment/__init__.py -------------------------------------------------------------------------------- /progres/chainsaw/src/domain_assignment/assigners.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from progres.chainsaw.src.domain_assignment.sparse_lowrank import greedy_V 4 | 5 | 6 | class BaseAssigner: 7 | 8 | def __call__(self, y_pred, uncertainty=False): 9 | assert y_pred.ndim == 2 10 | if torch.is_tensor(y_pred): 11 | y_pred = y_pred.detach().cpu().numpy() 12 | return self.assign_domains(y_pred) 13 | 14 | 15 | class SparseLowRank(BaseAssigner): 16 | 17 | """Use Brooks' method to generate a matrix of cluster assignments, 18 | and post-process to remove any excessively small clusters. 
19 | """ 20 | 21 | def __init__(self, N_iters=3, K_init=4, linker_threshold=5, cost_type="mse"): 22 | self.N_iters = N_iters 23 | self.K_init = K_init 24 | self.linker_threshold = linker_threshold 25 | self.cost_type = cost_type 26 | 27 | def get_entropy(self, y_pred): 28 | """calculate the entropy of the upper triangle of the matrix 29 | used for uncertainty calculation but does not correlate with NDO as well 30 | as likelihood of assignment under y_pred 31 | """ 32 | return np.triu(-1 * y_pred * np.log(y_pred) - (1 - y_pred) * np.log(1 - y_pred), 1).sum() 33 | 34 | 35 | def assign_domains(self, y_pred): 36 | # N x K, columns are then indicator vectors 37 | epsilon = 1e-6 38 | y_pred = np.clip(y_pred, epsilon, 1-epsilon) 39 | V, loss = greedy_V(y_pred, N_iters=self.N_iters, K_init=self.K_init, cost_type=self.cost_type) 40 | K = V.shape[-1] 41 | A = V@V.T 42 | average_likelihood = np.exp((A * np.log(y_pred) + (1-(A))*np.log(1-y_pred)).mean()) 43 | # throw away small clusters 44 | V = np.where( 45 | V.sum(0, keepdims=True) < self.linker_threshold, # 1, K 46 | np.zeros_like(V), 47 | V, 48 | ) 49 | 50 | assignments = { 51 | "linker": np.argwhere((V == 0).sum(-1) == K).reshape(-1) 52 | } 53 | 54 | domain_ix = 1 55 | for col_ix in range(K): 56 | cluster_inds = np.argwhere(V[:, col_ix] == 1).reshape(-1) 57 | if cluster_inds.size > 0: 58 | assignments[f"domain_{domain_ix}"] = cluster_inds 59 | domain_ix += 1 60 | return assignments, round(average_likelihood, 4) 61 | 62 | 63 | class SpectralClustering(BaseAssigner): 64 | """Challenge is how to determine n_clusters. 65 | 66 | We can look at eigengaps 67 | Or use MSE, running clustering multiple times. 68 | Or use Brooks' method and then refine with SC. 69 | """ 70 | 71 | def __init__(self, n_cluster_method="mse", max_domains=10): 72 | self.n_cluster_method = n_cluster_method 73 | self.max_domains = max_domains # TODO - make this flexible or length dependent. 74 | 75 | def assign_domains(self, y_pred): 76 | raise NotImplementedError() 77 | -------------------------------------------------------------------------------- /progres/chainsaw/src/domain_assignment/sparse_lowrank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def mse_loss_at_residue(V_hat, Y, residue_index): 5 | # loss contribution of residue given current assignment 6 | # assumes Y is symmetric 7 | d = residue_index 8 | # avoid double-counting the diagonal 9 | return 2*((V_hat[d]@V_hat.T - Y[d])**2).sum() - (V_hat[d]@V_hat[d] - Y[d,d])**2 10 | 11 | 12 | def null_mse_loss_at_residue(Y, residue_index): 13 | # loss contribution of residue when unassigned 14 | d = residue_index 15 | return 2*(Y[d]**2).sum() - Y[d,d]**2 16 | 17 | 18 | def mse_loss(V_hat, Y): 19 | # only use to compute initial loss 20 | return ((V_hat@V_hat.T - Y)**2).sum() 21 | 22 | 23 | def nll_loss(V_hat, Y): 24 | """We now interpret ys as probabilities. 25 | 26 | We want to minimise the negative log probability. 27 | 28 | NB Y must be symmetric and [0-1] 29 | """ 30 | # only use to compute initial loss 31 | log_pij = (V_hat@V_hat.T) * np.log(Y) + (1-(V_hat@V_hat.T))*np.log(1-Y) 32 | # simple sum double counts off diagonals but single counts on diagonals, so need 33 | # to add diagonal again before dividing by 2 34 | return -0.5*(log_pij.sum() + np.diag(log_pij).sum()) 35 | 36 | 37 | def nll_loss_at_residue(V_hat, Y, residue_index): 38 | # d is residue_index 39 | d = residue_index # V_hat[d]@V_hat.T is just Adj without instantiating A 40 | # n.b. 
whereas in case of mse we double-count off-diagonals, here we don't, hence the difference 41 | log_pdj = (V_hat[d]@V_hat.T) * np.log(Y[d]) + (1-(V_hat[d]@V_hat.T))*np.log(1-Y[d]) 42 | return -log_pdj.sum() 43 | 44 | 45 | def null_nll_loss_at_residue(Y, residue_index): 46 | return - np.log(1-Y[residue_index]).sum() 47 | 48 | 49 | def greedy_V(Y, N_iters=3, K_init=4, cost_type="mse"): 50 | """ 51 | Learn a binary matrix V, with at most one nonzero entry per row, to minimize 52 | 53 | || VV' - Y ||_2 54 | 55 | This is done by initializing a V of all zeros, and then doing a greedy optimization. 56 | V is initially D x K_init, where D is the number of residues (i.e., where Y is DxD). 57 | K is learned automatically; different values of K_init will not change the result but 58 | might be more or less efficient in terms of memory usage. The code keeps track of the 59 | current number of nonzero columns of V, and when all the columns are full it adds extra 60 | columns of zeros at the end. 61 | 62 | Each iteration sweeps through all residues (i.e. rows of V) once. 63 | 64 | The implementation relies on additivity of loss to reduce computational cost 65 | by only computing required increments to the loss at each iteration. 66 | 67 | 68 | INPUTS: 69 | 70 | Y: matrix of predictions with entries in [0, 1] 71 | N_iters: number of iterations 72 | K_init: initial number of columns of V, adjust this for tweaking performance 73 | """ 74 | Y = (Y + Y.T) / 2 # required for consistency with model_v1 75 | V_hat = np.zeros((Y.shape[0], K_init), dtype=np.uint8) 76 | if cost_type == "mse": 77 | loss = mse_loss(V_hat, Y) # initial loss for the zero matrix V_hat = 0 78 | elif cost_type == "nll": 79 | loss = nll_loss(V_hat, Y) 80 | else: 81 | raise ValueError(cost_type) 82 | K_max = K_init # track number of columns in K_max 83 | 84 | for it in range(N_iters): 85 | for d in range(V_hat.shape[0]): 86 | # d is a residue index 87 | # compute the loss, excluding contribution of this residue 88 | # TODO check minus signs are consistent 89 | loss_minus_d = loss - mse_loss_at_residue(V_hat, Y, d) 90 | 91 | # sweep through all K+1 options and compute what the contribution to the loss would be 92 | V_hat[d] *= 0 93 | 94 | # loss with no assignment 95 | L0 = loss_minus_d + null_mse_loss_at_residue(Y, d) 96 | L_opt = np.zeros(K_max) 97 | for k in range(K_max): 98 | # note this could be vectorized with a bit of work 99 | V_hat[d,k] = 1 100 | L_opt[k] = loss_minus_d + mse_loss_at_residue(V_hat, Y, d) 101 | V_hat[d,k] = 0 102 | 103 | # select the option which minimizes the squared error 104 | z = np.argmin(L_opt) 105 | if L_opt[z] < L0: 106 | V_hat[d,z] = 1 107 | 108 | # update loss 109 | loss = loss_minus_d + mse_loss_at_residue(V_hat, Y, d) 110 | 111 | if z == K_max-1: 112 | # Expand V_hat to make room for extra potential clusters (no nonzero columns remain) 113 | # TODO: note that if a cluster is "removed" on a later iteration we will miss this. 114 | # probably doesn't matter, just means it is occasionally possible that the V 115 | # that is returned could sometimes have a few nonzero columns. 116 | V_hat = np.concatenate((V_hat, np.zeros_like(V_hat)), -1) 117 | K_max = V_hat.shape[1] 118 | 119 | # drop columns that are zeros (i.e. 
unused clusters) before returning 120 | empty = V_hat.sum(0) == 0 121 | return V_hat[:,~empty], loss 122 | -------------------------------------------------------------------------------- /progres/chainsaw/src/domain_assignment/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilties for converting to/from a dictionary representation of domain assignments. 3 | """ 4 | 5 | import logging 6 | import os 7 | from itertools import product 8 | 9 | import numpy as np 10 | 11 | import warnings 12 | from Bio.PDB.PDBExceptions import PDBConstructionWarning 13 | warnings.simplefilter('ignore', PDBConstructionWarning) 14 | 15 | LOG = logging.getLogger(__name__) 16 | 17 | 18 | def make_pair_labels(n_res, domain_dict, id_string=None, save_dir=None, non_aligned_residues=[]): 19 | """n_res: number of residues in the non-trimmed sequence 20 | 21 | non_aligned_residues: these will be used to trim down from n_res 22 | 23 | domain_dict: eg. {'D1': [0,1,2,3], 'D2': [4,5,6]} 24 | """ 25 | pair_labels = np.zeros([n_res, n_res]) 26 | for domain, res_ix in domain_dict.items(): 27 | if domain == 'linker': 28 | continue 29 | coords_tuples = list(product(res_ix, res_ix)) 30 | x_ix = [i[0] for i in coords_tuples] 31 | y_ix = [i[1] for i in coords_tuples] 32 | pair_labels[x_ix, y_ix] = 1 33 | if len(non_aligned_residues): 34 | aligned_residues = [i for i in range(n_res) if i not in non_aligned_residues] 35 | pair_labels = pair_labels[aligned_residues,:][:,aligned_residues] 36 | if save_dir is not None: 37 | save_path = os.path.join(save_dir, id_string) 38 | np.savez_compressed(save_path, pair_labels) 39 | 40 | return pair_labels 41 | 42 | 43 | def sort_domain_limits(limits, dom_names): 44 | start_positions = [x[0] for x in limits] 45 | end_positions = [x[1] for x in limits] 46 | sorted_index = np.argsort(start_positions) 47 | assert (sorted_index == np.argsort(end_positions)).all() 48 | return np.array(limits)[sorted_index], list(np.array(dom_names)[sorted_index]) 49 | 50 | 51 | def resolve_residue_in_multiple_domain(mapping, shared_res): 52 | """ 53 | This is a stupid slow recursive solution: but I think it only applies to one 54 | case so going to leave it for now 55 | """ 56 | for one_shared in shared_res: 57 | for domain, res in mapping.items(): 58 | if one_shared in res: 59 | mapping[domain].remove(one_shared) 60 | return check_no_residue_in_multiple_domains(mapping) 61 | 62 | 63 | def check_no_residue_in_multiple_domains(mapping, resolve_conflics=True): 64 | # ensures no residue index is associated with more than one domain 65 | for dom, res in mapping.items(): 66 | for dom2, res2 in mapping.items(): 67 | if dom == dom2: 68 | continue 69 | shared_res = set(res).intersection(set(res2)) 70 | if len(shared_res): 71 | print(f'Found {len(shared_res)} shared residues') 72 | if resolve_conflics: 73 | mapping = resolve_residue_in_multiple_domain(mapping, shared_res) 74 | else: 75 | raise ValueError("SAME RESIDUE NUMBER FOUND IN MULTIPLE DOMAINS") 76 | return mapping 77 | 78 | 79 | def make_domain_mapping_dict(row): 80 | dom_limit_list = row.dom_bounds_pdb_ix.split('|') 81 | dom_names = row.dom_names.split('|') 82 | dom_limit_list = convert_limits_to_numbers(dom_limit_list) 83 | dom_limit_array, dom_names = sort_domain_limits(dom_limit_list, dom_names) 84 | mapping = {} 85 | 86 | for i, d_lims in enumerate(dom_limit_array): 87 | dom_name = dom_names[i] 88 | pdb_start, pdb_end = d_lims 89 | if dom_name not in mapping: 90 | mapping[dom_name] = [] 91 | mapping[dom_name] 
+= list(range(pdb_start, pdb_end)) 92 | check_no_residue_in_multiple_domains(mapping) 93 | return mapping 94 | 95 | 96 | def convert_limits_to_numbers(dom_limit_list): 97 | processed_dom_limit_list = [] 98 | for lim in dom_limit_list: 99 | dash_idx = [i for i, char in enumerate(lim) if char == '-'] 100 | if len(dash_idx) == 1: 101 | start_index = int(lim.split('-')[0]) -1 102 | end_index = int(lim.split('-')[1]) 103 | else: 104 | raise ValueError('Invalid format for domain limits', str(dom_limit_list)) 105 | processed_dom_limit_list.append((start_index, end_index)) 106 | return processed_dom_limit_list 107 | 108 | 109 | def convert_domain_dict_strings(domain_dict): 110 | """ 111 | Converts the domain dictionary into domain_name string and domain_bounds string 112 | eg. domain names D1|D2|D1 113 | eg. domain bounds 0-100|100-200|200-300 114 | """ 115 | domain_names = [] 116 | domain_bounds = [] 117 | for k,v in domain_dict.items(): 118 | if k=='linker': 119 | continue 120 | residues = sorted(v) 121 | for i, res in enumerate(residues): 122 | if i==0: 123 | start = res 124 | elif residues[i-1] != res - 1: 125 | domain_bounds.append(f'{start}-{residues[i-1]}') 126 | domain_names.append(k) 127 | start = res 128 | if i == len(residues)-1: 129 | domain_bounds.append(f'{start}-{res}') 130 | domain_names.append(k) 131 | 132 | return '|'.join(domain_names), '|'.join(domain_bounds) 133 | -------------------------------------------------------------------------------- /progres/chainsaw/src/domain_chop.py: -------------------------------------------------------------------------------- 1 | """Domain predictor classes. 2 | """ 3 | import os 4 | import glob 5 | import numpy as np 6 | from pathlib import Path 7 | import torch 8 | from torch import nn 9 | from progres.chainsaw.src.domain_assignment.util import make_pair_labels, make_domain_mapping_dict 10 | 11 | import logging 12 | LOG = logging.getLogger(__name__) 13 | 14 | 15 | def get_checkpoint_epoch(checkpoint_file): 16 | return int(os.path.splitext(checkpoint_file)[0].split(".")[-1]) 17 | 18 | 19 | class PairwiseDomainPredictor(nn.Module): 20 | 21 | """Wrapper for a pairwise domain co-membership predictor, adding in domain prediction post-processing.""" 22 | 23 | def __init__( 24 | self, 25 | model, 26 | domain_caller, 27 | device, 28 | loss="bce", 29 | x_has_padding_mask=True, 30 | mask_padding=True, 31 | n_checkpoints_to_average=1, 32 | checkpoint_dir=None, 33 | load_checkpoint_if_exists=False, 34 | save_val_best=True, 35 | max_recycles=0, 36 | post_process_domains=True, 37 | min_ss_components=2, 38 | min_domain_length=30, 39 | remove_disordered_domain_threshold=0, 40 | trim_each_domain=True, 41 | dist_transform_type="min_replace_inverse", 42 | distance_denominator=10, 43 | ): 44 | super().__init__() 45 | self._train_model = model # we want to keep this hanging around so that optimizer references dont break 46 | self.model = self._train_model 47 | self.domain_caller = domain_caller 48 | self.device = device 49 | self.x_has_padding_mask = x_has_padding_mask 50 | self.mask_padding = mask_padding # if True use padding mask to mask loss 51 | self.n_checkpoints_to_average = n_checkpoints_to_average 52 | self.checkpoint_dir = checkpoint_dir 53 | self._epoch = 0 54 | self.save_val_best = save_val_best 55 | self.best_val_metrics = {} 56 | self.max_recycles = max_recycles 57 | self.post_process_domains = post_process_domains 58 | self.remove_disordered_domain_threshold = remove_disordered_domain_threshold 59 | self.trim_each_domain = trim_each_domain 60 | 
self.min_domain_length = min_domain_length 61 | self.min_ss_components = min_ss_components 62 | self.dist_transform_type = dist_transform_type 63 | self.distance_denominator = distance_denominator 64 | if load_checkpoint_if_exists: 65 | checkpoint_files = sorted( 66 | glob.glob(os.path.join(self.checkpoint_dir, "weights*")), 67 | key=get_checkpoint_epoch, 68 | reverse=True, 69 | ) 70 | if len(checkpoint_files) > 0: 71 | self._epoch = get_checkpoint_epoch(checkpoint_files[0]) 72 | LOG.info(f"Loading saved checkpoint(s) ending at epoch {self._epoch}") 73 | self.load_checkpoints(average=True) 74 | self.load_checkpoints() 75 | else: 76 | LOG.info("No checkpoints found to load") 77 | 78 | if loss == "bce": 79 | self.loss_function = nn.BCELoss(reduction="none") 80 | elif loss == "mse": 81 | self.loss_function = nn.MSELoss(reduction="none") 82 | 83 | def load_checkpoints(self, average=False, old_style=False): 84 | if self.n_checkpoints_to_average == 1: 85 | data_dir = os.getenv( 86 | "PROGRES_DATA_DIR", 87 | default=Path(__file__).parent.parent.parent.resolve(), 88 | ) 89 | weights_file = os.path.join(data_dir, "chainsaw", "model_v3", "weights.pt") 90 | else: 91 | # for e.g. resetting training weights for next training epoch after testing with avg 92 | LOG.info(f"Loading last checkpoint (epoch {self._epoch})") 93 | weights_file = os.path.join(self.checkpoint_dir, f"weights.{self._epoch}.pt") 94 | LOG.info(f"Loading weights from: {weights_file}") 95 | state_dict = torch.load(weights_file, map_location=self.device) 96 | if old_style: 97 | self.load_state_dict(state_dict, strict=False) 98 | else: 99 | self.model.load_state_dict(state_dict) 100 | 101 | def predict_pairwise(self, x): 102 | x = x.to(self.device) 103 | if np.isnan(x.cpu().numpy()).any(): 104 | raise Exception('NAN values in data') 105 | y_pred = self.model(x).squeeze(1) # b, L, L 106 | assert y_pred.ndim == 3 107 | return y_pred 108 | 109 | def get_mask(self, x): 110 | """Binary mask 1 for observed, 0 for padding.""" 111 | x = x.to(self.device) 112 | if self.x_has_padding_mask: 113 | mask = 1 - x[:, -1] # b, L, L 114 | else: 115 | mask = None 116 | return mask 117 | 118 | def epoch_start(self): 119 | self.model = self._train_model 120 | self.model.train() 121 | self._epoch += 1 122 | 123 | def test_begin(self): 124 | if self.n_checkpoints_to_average > 1: 125 | torch.save(self.model.state_dict(), os.path.join(self.checkpoint_dir, f'weights.{self._epoch}.pt')) 126 | start_idx = self._epoch - self.n_checkpoints_to_average 127 | if start_idx >= 2: 128 | os.remove(os.path.join(self.checkpoint_dir, f"weights.{start_idx-1}.pt")) 129 | else: 130 | torch.save(self.model.state_dict(), os.path.join(self.checkpoint_dir, 'weights.pt')) 131 | 132 | if self.n_checkpoints_to_average > 1: 133 | # self.model.to("cpu") # free up gpu memory for average model 134 | self.load_checkpoints(average=True) 135 | 136 | self.model.eval() 137 | 138 | def forward(self, x, y, batch_average=True): 139 | """A training step.""" 140 | x, y = x.to(self.device), y.to(self.device) 141 | y_pred = self.predict_pairwise(x) 142 | mask = self.get_mask(x) 143 | return self.compute_loss(y_pred, y, mask=mask) 144 | 145 | def compute_loss(self, y_pred, y, mask=None, batch_average=True): 146 | y_pred, y = y_pred.to(self.device), y.to(self.device) 147 | if mask is None or not self.mask_padding: 148 | mask = torch.ones_like(y) 149 | # mask is b, L, L. 
To normalise correctly, we need to divide by number of observations 150 | loss = (self.loss_function(y_pred, y)*mask).sum((-1,-2)) / mask.sum((-1,-2)) 151 | 152 | # metrics characterising inputs: how many residues, how many with domain assignments. 153 | labelled_residues = ((y*mask).sum(-1) > 0).sum(-1) # b 154 | non_padding_residues = (mask.sum(-1) > 0).sum(-1) # b 155 | labelled_frac = labelled_residues / non_padding_residues 156 | metrics = { 157 | "labelled_residues": labelled_residues.detach().cpu().numpy(), 158 | "residues": non_padding_residues.detach().cpu().numpy(), 159 | "labelled_frac": labelled_frac.detach().cpu().numpy(), 160 | "loss": loss.detach().cpu().numpy(), 161 | } 162 | if batch_average: 163 | loss = loss.mean(0) 164 | metrics = {k: np.mean(v) for k, v in metrics.items()} 165 | 166 | return loss, metrics 167 | 168 | def domains_from_pairwise(self, y_pred): 169 | assert y_pred.ndim == 3 170 | domain_preds = [] 171 | confidence_list = [] 172 | for pred_single in y_pred.cpu().numpy(): 173 | single_domains, confidence = self.domain_caller(pred_single) 174 | domain_preds.append(single_domains) 175 | confidence_list.append(confidence) 176 | return domain_preds, confidence_list 177 | 178 | def distance_transform(self, x): 179 | dist_chan = x[0, 0] 180 | # Find the minimum non-zero value in the channel 181 | min_nonzero = dist_chan[dist_chan > 0].min() 182 | # Replace zero values in the channel with the minimum non-zero value 183 | dist_chan[dist_chan == 0] = min_nonzero 184 | if self.dist_transform_type == "min_replace_inverse": 185 | # replace zero values and then invert. 186 | dist_chan = dist_chan ** (-1) 187 | x[0, 0] = dist_chan 188 | return x 189 | 190 | elif self.dist_transform_type == "unidoc_exponent": # replace zero values in pae / distance 191 | spread = self.distance_denominator 192 | dist_chan = (1 + np.exp((dist_chan - 8) / spread)) ** -1 193 | x[0,0] = dist_chan 194 | return x 195 | 196 | @torch.no_grad() 197 | def predict(self, x, return_pairwise=True): 198 | x = self.distance_transform(x) 199 | if self.max_recycles > 0: 200 | for i in range(self.max_recycles): 201 | # add recycling channels 202 | n_res = x.shape[-1] 203 | recycle_channels = torch.zeros(1, 2, n_res, n_res) 204 | # Concatenate the original tensor and the zeros tensor along the second dimension 205 | x = torch.cat((x, recycle_channels), dim=1) 206 | x = self.recycle_predict(x) 207 | y_pred = self.predict_pairwise(x) 208 | domain_dicts, confidence = self.domains_from_pairwise(y_pred) 209 | if self.post_process_domains: 210 | domain_dicts = self.post_process(domain_dicts, x) # todo move this to domains from pairwise function 211 | if return_pairwise: 212 | return y_pred, domain_dicts, confidence 213 | else: 214 | return domain_dicts, confidence 215 | 216 | @torch.no_grad() 217 | def recycle_predict(self, x): 218 | x = x.to(self.device) 219 | y_pred = self.predict_pairwise(x) 220 | domain_dicts, confidence = self.domains_from_pairwise(y_pred) 221 | y_pred_from_domains = np.array( 222 | [make_pair_labels(n_res=x.shape[-1], domain_dict=d_dict) for d_dict in domain_dicts]) 223 | y_pred_from_domains = torch.tensor(y_pred_from_domains).to(self.device) 224 | if self.x_has_padding_mask: 225 | x[:, -2, :, :] = y_pred # assumes that last dimension is padding mask 226 | x[:, -3, :, :] = y_pred_from_domains 227 | else: 228 | x[:, -1, :, :] = y_pred 229 | x[:, -2, :, :] = y_pred_from_domains 230 | return x 231 | 232 | 233 | def post_process(self, domain_dicts, x_batch): 234 | new_domain_dicts = [] 235 | for 
domain_dict, x in zip(domain_dicts, x_batch): 236 | x = x.cpu().numpy() 237 | domain_dict = {k: list(v) for k, v in domain_dict.items()} 238 | helix, sheet = x[1], x[2] 239 | diag_helix = np.diagonal(helix) 240 | diag_sheet = np.diagonal(sheet) 241 | ss_residues = list(np.where(diag_helix > 0)[0]) + list(np.where(diag_sheet > 0)[0]) 242 | 243 | domain_dict = self.trim_disordered_boundaries(domain_dict, ss_residues) 244 | 245 | if self.remove_disordered_domain_threshold > 0: 246 | domain_dict = self.remove_disordered_domains(domain_dict, ss_residues) 247 | 248 | if self.min_ss_components > 0: 249 | domain_dict = self.remove_domains_with_few_ss_components(domain_dict, x) 250 | 251 | if self.min_domain_length > 0: 252 | domain_dict = self.remove_domains_with_short_length(domain_dict) 253 | new_domain_dicts.append(domain_dict) 254 | return new_domain_dicts 255 | 256 | def trim_disordered_boundaries(self, domain_dict, ss_residues): 257 | if not self.trim_each_domain: 258 | start = min(ss_residues) 259 | end = max(ss_residues) 260 | for dname, res in domain_dict.items(): 261 | if dname == "linker": 262 | continue 263 | if self.trim_each_domain: 264 | domain_specific_ss = set(ss_residues).intersection(set(res)) 265 | if len(domain_specific_ss) == 0: 266 | continue 267 | start = min(domain_specific_ss) 268 | end = max(domain_specific_ss) 269 | domain_dict["linker"] += [r for r in res if r < start or r > end] 270 | domain_dict[dname] = [r for r in res if r >= start and r <= end] 271 | return domain_dict 272 | 273 | def remove_disordered_domains(self, domain_dict, ss_residues): 274 | new_domain_dict = {} 275 | for dname, res in domain_dict.items(): 276 | if dname == "linker": 277 | continue 278 | if len(res) == 0: 279 | continue 280 | if len(set(res).intersection(set(ss_residues))) / len(res) < self.remove_disordered_domain_threshold: 281 | domain_dict["linker"] += res 282 | else: 283 | new_domain_dict[dname] = res 284 | new_domain_dict["linker"] = domain_dict["linker"] 285 | return new_domain_dict 286 | 287 | 288 | def remove_domains_with_few_ss_components(self, domain_dict, x): 289 | """ 290 | Remove domains where number of ss components is less than minimum 291 | eg if self.min_ss_components=2 domains made of only a single helix or sheet are removed 292 | achieve this by counting the number of unique string hashes in domain rows of x 293 | """ 294 | new_domain_dict = {} 295 | for dname, res in domain_dict.items(): 296 | if dname == "linker": 297 | continue 298 | res = sorted(res) 299 | helix = x[1][res, :][:, res] 300 | strand = x[2][res, :][:, res] 301 | helix = helix[np.any(helix, axis=1)] 302 | strand = strand[np.any(strand, axis=1)] 303 | # residues in the same secondary structure component have the same representation in the helix or strand matrix 304 | n_helix = len(set(["".join([str(int(i)) for i in row]) for row in helix])) 305 | n_sheet = len(set(["".join([str(int(i)) for i in row]) for row in strand])) 306 | if len(res) == 0: 307 | continue 308 | if n_helix + n_sheet < self.min_ss_components: 309 | domain_dict["linker"] += res 310 | else: 311 | new_domain_dict[dname] = res 312 | new_domain_dict["linker"] = domain_dict["linker"] 313 | return new_domain_dict 314 | 315 | def remove_domains_with_short_length(self, domain_dict): 316 | """ 317 | Remove domains where length is less than minimum 318 | """ 319 | new_domain_dict = {} 320 | for dname, res in domain_dict.items(): 321 | if dname == "linker": 322 | continue 323 | 324 | 325 | if len(res) < self.min_domain_length: 326 | 
domain_dict["linker"] += res 327 | else: 328 | new_domain_dict[dname] = res 329 | new_domain_dict["linker"] = domain_dict["linker"] 330 | return new_domain_dict 331 | -------------------------------------------------------------------------------- /progres/chainsaw/src/errors.py: -------------------------------------------------------------------------------- 1 | class BaseError(Exception): 2 | """ 3 | Error class for all local exceptions in this code base 4 | """ 5 | pass 6 | 7 | class PredictionResultExistsError(BaseError): 8 | pass 9 | 10 | class FileExistsError(BaseError): 11 | pass -------------------------------------------------------------------------------- /progres/chainsaw/src/factories.py: -------------------------------------------------------------------------------- 1 | """There are three configurable things: predictors/models, data, training/evaluation. 2 | 3 | Only first two require factories. 4 | """ 5 | import os 6 | 7 | import torch 8 | 9 | import logging 10 | 11 | from progres.chainsaw.src.domain_chop import PairwiseDomainPredictor 12 | from progres.chainsaw.src.models.rosetta import trRosettaNetwork 13 | from progres.chainsaw.src.domain_assignment.assigners import SparseLowRank 14 | from progres.chainsaw.src.utils import common as common_utils 15 | 16 | 17 | LOG = logging.getLogger(__name__) 18 | 19 | 20 | def get_assigner(config): 21 | assigner_type = config["type"] 22 | if assigner_type == "sparse_lowrank": 23 | assigner = SparseLowRank(**config["kwargs"]) 24 | else: 25 | return ValueError() 26 | return assigner 27 | 28 | 29 | def get_model(config): 30 | model_type = config["type"] 31 | if model_type == "trrosetta": 32 | model = trRosettaNetwork(**config["kwargs"]) 33 | else: 34 | return ValueError() 35 | return model 36 | 37 | 38 | def pairwise_predictor(learner_config, force_cpu=False, output_dir=None, device="cpu"): 39 | model = get_model(learner_config["model"]) 40 | assigner = get_assigner(learner_config["assignment"]) 41 | device = torch.device(device) 42 | model.to(device) 43 | kwargs = {k: v for k, v in learner_config.items() if k not in ["model", 44 | "assignment", 45 | "save_every_epoch", 46 | "uncertainty_model"]} 47 | LOG.info(f"Learner kwargs: {kwargs}") 48 | return PairwiseDomainPredictor(model, assigner, device, checkpoint_dir=output_dir, **kwargs) 49 | -------------------------------------------------------------------------------- /progres/chainsaw/src/featurisers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List 4 | from tempfile import NamedTemporaryFile 5 | 6 | import Bio.PDB 7 | from Bio.PDB.mmtf import MMTFParser 8 | import numpy as np 9 | import torch 10 | from scipy.spatial import distance_matrix 11 | 12 | from progres.chainsaw.src.constants import _3to1, STRIDE_EXE 13 | from progres.chainsaw.src.utils.cif2pdb import cif2pdb 14 | from progres.chainsaw.src.utils.secondary_structure import calculate_ss, make_ss_matrix 15 | 16 | LOG = logging.getLogger(__name__) 17 | 18 | 19 | def get_model_structure(structure_path, fileformat="pdb") -> Bio.PDB.Structure: 20 | """ 21 | Returns the Bio.PDB.Structure object for a given PDB, mmCIF or MMTF file 22 | """ 23 | structure_id = os.path.split(structure_path)[-1].split('.')[0] 24 | if fileformat == "pdb": 25 | structure = Bio.PDB.PDBParser().get_structure(structure_id, structure_path) 26 | elif fileformat == "mmcif": 27 | structure = Bio.PDB.MMCIFParser().get_structure(structure_id, structure_path) 28 | elif 
fileformat == "mmtf": 29 | structure = MMTFParser.get_structure(structure_path) 30 | elif fileformat == "coords": 31 | raise ValueError("Coordinate file format not compatible with Chainsaw") 32 | else: 33 | raise ValueError(f"Unrecognized file format: {fileformat}") 34 | model = structure[0] 35 | return model 36 | 37 | 38 | class Residue: 39 | def __init__(self, index: int, res_label: str, aa: str): 40 | self.index = int(index) 41 | self.res_label = str(res_label) 42 | self.aa = str(aa) 43 | 44 | def get_model_structure_residues(structure_model: Bio.PDB.Structure, chain='A') -> List[Residue]: 45 | """ 46 | Returns a list of residues from a given PDB or MMCIF structure 47 | """ 48 | residues = [] 49 | res_index = 1 50 | for biores in structure_model[chain].child_list: 51 | res_num = biores.id[1] 52 | res_ins = biores.id[2] 53 | res_label = str(res_num) 54 | if res_ins != ' ': 55 | res_label += str(res_ins) 56 | 57 | aa3 = biores.get_resname() 58 | if aa3 not in _3to1: 59 | continue 60 | 61 | aa = _3to1[aa3] 62 | res = Residue(res_index, res_label, aa) 63 | residues.append(res) 64 | 65 | # increment the residue index after we have filtered out non-standard amino acids 66 | res_index += 1 67 | 68 | return residues 69 | 70 | 71 | def inference_time_create_features(file_path, feature_config, chain="A", *, 72 | model_structure: Bio.PDB.Structure=None, 73 | stride_path=STRIDE_EXE, fileformat="pdb", 74 | ): 75 | if fileformat == "pdb": 76 | pdb_path = file_path 77 | else: 78 | temp_pdb_file = NamedTemporaryFile() 79 | pdb_path = temp_pdb_file.name 80 | cif2pdb(file_path, pdb_path, fileformat) 81 | 82 | if not model_structure: 83 | model_structure = get_model_structure(pdb_path) 84 | 85 | dist_matrix = get_distance(model_structure, chain=chain) 86 | temp_ss_file = NamedTemporaryFile() 87 | ss_filepath = temp_ss_file.name 88 | calculate_ss(pdb_path, chain, stride_path, ssfile=ss_filepath) 89 | helix, strand = make_ss_matrix(ss_filepath, nres=dist_matrix.shape[-1]) 90 | if feature_config['ss_bounds']: 91 | end_res_val = -1 if feature_config['negative_ss_end'] else 1 92 | helix_boundaries = make_boundary_matrix(helix, end_res_val=end_res_val) 93 | strand_boundaries = make_boundary_matrix(strand, end_res_val=end_res_val) 94 | temp_ss_file.close() 95 | if fileformat != "pdb": 96 | temp_pdb_file.close() 97 | LOG.info(f"Distance matrix shape: {dist_matrix.shape}, SS matrix shape: {helix.shape}") 98 | if feature_config['ss_bounds']: 99 | if feature_config['same_channel_boundaries_and_ss']: 100 | helix_boundaries[helix == 1] = 1 101 | strand_boundaries[strand == 1] = 1 102 | stacked_features = np.stack((dist_matrix, helix_boundaries, strand_boundaries), axis=0) 103 | else: 104 | stacked_features = np.stack((dist_matrix, helix, strand, helix_boundaries, strand_boundaries), axis=0) 105 | else: 106 | stacked_features = np.stack((dist_matrix, helix, strand), axis=0) 107 | stacked_features = stacked_features[None] # add batch dimension 108 | return torch.Tensor(stacked_features) 109 | 110 | 111 | def distance_matrix(x): 112 | """Compute the distance matrix. 113 | 114 | Returns the matrix of all pair-wise distances. 115 | 116 | Parameters 117 | ---------- 118 | x : (M, K) array_like 119 | Matrix of M vectors in K dimensions. 120 | p : float, 1 <= p <= infinity 121 | Which Minkowski p-norm to use. 122 | Returns 123 | ------- 124 | result : (M, M) ndarray 125 | Matrix containing the distance from every vector in `x` to every vector 126 | in `x`. 
127 | """ 128 | x1 = x[:, np.newaxis, :] # Expand x to 3D for broadcasting, shape (M, 1, K) 129 | x2 = x[np.newaxis, :, :] # Expand x to 3D for broadcasting, shape (1, M, K) 130 | 131 | distance_matrix = np.sum(np.abs(x2 - x1) ** 2, axis=-1) ** 0.5 132 | 133 | return distance_matrix.astype(np.float16) 134 | 135 | 136 | def get_distance(structure_model: Bio.PDB.Structure, chain='A'): 137 | alpha_coords = np.array([residue['CA'].get_coord() for residue in \ 138 | structure_model[chain].get_residues() if Bio.PDB.is_aa(residue) and \ 139 | 'CA' in residue and residue.get_resname() in _3to1], dtype=np.float16) 140 | x = distance_matrix(alpha_coords) 141 | return x 142 | 143 | 144 | def make_boundary_matrix(ss, end_res_val=1): 145 | """ 146 | makes a matrix where the boundary residues 147 | of the sec struct component are 1 148 | """ 149 | ss_lines = np.zeros_like(ss) 150 | diag = np.diag(ss) 151 | if max(diag) == 0: 152 | return ss_lines 153 | padded_diag = np.zeros(len(diag) + 2) 154 | padded_diag[1:-1] = diag 155 | diff_before = diag - padded_diag[:-2] 156 | diff_after = diag - padded_diag[2:] 157 | start_res = np.where(diff_before == 1)[0] 158 | end_res = np.where(diff_after == 1)[0] 159 | ss_lines[start_res, :] = 1 160 | ss_lines[:, start_res] = 1 161 | ss_lines[end_res, :] = end_res_val 162 | ss_lines[:, end_res] = end_res_val 163 | return ss_lines 164 | -------------------------------------------------------------------------------- /progres/chainsaw/src/loggers.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import logging 5 | LOG = logging.getLogger(__name__) 6 | 7 | def get_versioned_dir(output_dir, version=None, resume=False): 8 | """version gets dir for specific version, resume gets dir for last version.""" 9 | if version is None: 10 | current_versions = glob.glob(os.path.join(output_dir, "version*")) 11 | if current_versions: 12 | last_version = max([int(os.path.basename(v).split("_")[1]) for v in current_versions]) 13 | version = last_version if resume else last_version + 1 14 | else: 15 | assert not resume, f"Passed resume True but no matching directories in {output_dir}" 16 | version = 1 17 | 18 | version_dir = os.path.join(output_dir, f"version_{version}") 19 | return version_dir, version 20 | 21 | 22 | def log_epoch_metrics( 23 | epoch, 24 | metrics, 25 | output_file, 26 | extra_keys=None, 27 | start_epoch=0, 28 | new_file=False 29 | ): 30 | """ 31 | New file gets created if epoch == 1 32 | 33 | We are going for a hierarchical structure /experiment_group/model_name/train_metrics.csv etc 34 | because this works best with tensorboard and avoids file clutter in a single 35 | experiment_group directory 36 | 37 | tensorboard refs: 38 | https://pytorch.org/docs/stable/tensorboard.html 39 | https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html 40 | """ 41 | # output_filename = (model_name + f"_{msa_name}" + f"_vae" + 42 | # ("_posembed{args.pos_embed_dim}" if args.embed_pos else "")) 43 | 44 | metrics.pop("epoch", None) 45 | metric_names = list(metrics.keys()) 46 | extra_keys = extra_keys or [] 47 | assert all([m not in metric_names for m in extra_keys]), f"{metric_names} {extra_keys}" 48 | metric_names += list(extra_keys) 49 | 50 | if new_file: # c.f. training/core epoch 0 is for validation. 
51 | with open(output_file, "w") as csvf: 52 | csvf.write(",".join(["epoch"] + metric_names) + "\n") 53 | 54 | with open(output_file, "a") as csvf: 55 | csvf.write(",".join([str(epoch + start_epoch)] + [str(metrics.get(m, "")) for m in metric_names])+"\n") 56 | 57 | 58 | class StdOutLogger: 59 | 60 | def __init__(self, log_freq, start_epoch=0): 61 | self.start_epoch = start_epoch 62 | self.log_freq = log_freq 63 | 64 | def log(self, epoch, metrics, batch=None): 65 | if self.log_freq is not None and epoch % self.log_freq == 0: 66 | if batch is None: 67 | header = f"Epoch {epoch + self.start_epoch}: " 68 | else: 69 | header = f"[{epoch:d}, {batch:5d}]: " 70 | 71 | train_metric_components = [f"{m}: {v:.3f} " for m, v in metrics.items() if not m.startswith("val_")] 72 | if train_metric_components: 73 | LOG.info( 74 | header 75 | + " ".join(train_metric_components), 76 | ) 77 | val_metric_components = [f"{m}: {v:.3f} " for m, v in metrics.items() if m.startswith("val_")] 78 | if val_metric_components: 79 | LOG.info( 80 | " ".join(val_metric_components), 81 | ) 82 | if batch is None: 83 | LOG.info("--------------------------------------\n") 84 | 85 | 86 | class CSVLogger: 87 | def __init__(self, output_dir, start_epoch=0): 88 | self.output_dir = output_dir 89 | self.start_epoch = start_epoch 90 | self.val_keys = None 91 | self.filename = f"train_log.{'' if start_epoch == 0 else (str(start_epoch) + '.')}csv" 92 | self.logged = 0 93 | 94 | @property 95 | def filepath(self): 96 | return str(os.path.join(self.output_dir, self.filename)) 97 | 98 | def log(self, epoch, metrics, batch=None): 99 | metrics["batch"] = batch 100 | 101 | if epoch == 0: 102 | self.val_keys = metrics.keys() 103 | elif self.output_dir is not None and epoch > 0: 104 | # LOG.info([k for k in metrics.keys() if k not in self._prev_keys]) 105 | extra_keys = [k for k in self.val_keys if k not in metrics and k != "epoch"] 106 | os.makedirs(self.output_dir, exist_ok=True) 107 | log_epoch_metrics( 108 | epoch, 109 | metrics, 110 | self.filepath, 111 | extra_keys=extra_keys, 112 | start_epoch=self.start_epoch, 113 | new_file=self.logged == 0 114 | ) 115 | self.logged += 1 116 | # self._prev_keys = metrics.keys() 117 | 118 | 119 | class LoggerContainer: 120 | 121 | def __init__(self, loggers, start_epoch=0): 122 | self.train_log = [] 123 | self.loggers = loggers 124 | self.start_epoch = start_epoch 125 | 126 | def log(self, epoch, metrics, batch=None): 127 | for logger in self.loggers: 128 | logger.log(epoch, metrics, batch=batch) 129 | metrics["epoch"] = epoch + self.start_epoch 130 | self.train_log.append(metrics) 131 | -------------------------------------------------------------------------------- /progres/chainsaw/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/progres/chainsaw/src/models/__init__.py -------------------------------------------------------------------------------- /progres/chainsaw/src/models/results.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | from pydantic import BaseModel, validator 4 | 5 | def chain_id_from_pdb_path(pdb_path: Path): 6 | return pdb_path.stem 7 | 8 | class PredictionResult(BaseModel): 9 | 10 | pdb_path: Path 11 | confidence: float 12 | chain_id: Optional[str] 13 | sequence_md5: str 14 | ndom: int 15 | nres: int 16 | chopping: Optional[str] 17 | 
time_sec: Optional[float] 18 | 19 | @validator('chain_id', always=True, pre=True, allow_reuse=True) 20 | def set_chain_id(cls, v, values): 21 | if v is None: 22 | return chain_id_from_pdb_path(values['pdb_path']) 23 | return v 24 | -------------------------------------------------------------------------------- /progres/chainsaw/src/models/rosetta.py: -------------------------------------------------------------------------------- 1 | from torch import nn, transpose 2 | 3 | 4 | def elu(): 5 | return nn.ELU(inplace=True) 6 | 7 | 8 | def instance_norm(filters, eps=1e-6, **kwargs): 9 | return nn.InstanceNorm2d(filters, affine=True, eps=eps, **kwargs) 10 | 11 | 12 | def conv2d(in_chan, out_chan, kernel_size, dilation=1, **kwargs): 13 | padding = dilation * (kernel_size - 1) // 2 14 | return nn.Conv2d(in_chan, out_chan, kernel_size, padding=padding, dilation=dilation, **kwargs) 15 | 16 | 17 | class trRosettaNetwork(nn.Module): 18 | def __init__(self, filters=64, kernel=3, num_layers=61, in_channels=3, symmetrise_output=False, dropout=0.15): 19 | super().__init__() 20 | self.filters = filters 21 | self.kernel = kernel 22 | self.num_layers = num_layers 23 | self.in_channels = in_channels 24 | self.symmetrise_output = symmetrise_output 25 | self.first_block = nn.Sequential( 26 | conv2d(self.in_channels, filters, 1), 27 | instance_norm(filters), 28 | elu() 29 | ) 30 | self.output_layer = nn.Sequential( 31 | conv2d(filters, 1, kernel, dilation=1), 32 | nn.Sigmoid()) 33 | 34 | 35 | # stack of residual blocks with dilations 36 | cycle_dilations = [1, 2, 4, 8, 16] 37 | dilations = [cycle_dilations[i % len(cycle_dilations)] for i in range(num_layers)] 38 | if dropout > 0: 39 | self.layers = nn.ModuleList([nn.Sequential( 40 | conv2d(filters, filters, kernel, dilation=dilation), 41 | instance_norm(filters), 42 | elu(), 43 | nn.Dropout(p=dropout), 44 | conv2d(filters, filters, kernel, dilation=dilation), 45 | instance_norm(filters) 46 | ) for dilation in dilations]) 47 | else: 48 | self.layers = nn.ModuleList([nn.Sequential( 49 | conv2d(filters, filters, kernel, dilation=dilation), 50 | instance_norm(filters), 51 | elu(), 52 | conv2d(filters, filters, kernel, dilation=dilation), 53 | instance_norm(filters) 54 | ) for dilation in dilations]) 55 | 56 | self.activate = elu() 57 | 58 | 59 | def forward(self, x): 60 | x = self.first_block(x) 61 | 62 | for layer in self.layers: 63 | x = self.activate(x + layer(x)) 64 | y_hat = self.output_layer(x) 65 | if self.symmetrise_output: 66 | return (y_hat + transpose(y_hat, -1, -2)) * 0.5 67 | else: 68 | return y_hat 69 | -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/progres/chainsaw/src/utils/__init__.py -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/cif2pdb.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | """ 3 | Script to convert mmCIF files to PDB format. 4 | usage: python cif2pdb.py ciffile [pdbfile] 5 | Requires python BioPython (`pip install biopython`). It should work with recent version of python 2 or 3. 
6 | @author Spencer Bliven 7 | """ 8 | import os 9 | import sys 10 | import argparse 11 | import logging 12 | from Bio.PDB.MMCIFParser import MMCIFParser 13 | from Bio.PDB.mmtf import MMTFParser 14 | from Bio.PDB import PDBIO 15 | 16 | LOG = logging.getLogger(__name__) 17 | 18 | 19 | def int_to_chain(i, base=62): 20 | """ 21 | int_to_chain(int,int) -> str 22 | Converts a positive integer to a chain ID. Chain IDs include uppercase 23 | characters, numbers, and optionally lowercase letters. 24 | i = a positive integer to convert 25 | base = the alphabet size to include. Typically 36 or 62. 26 | """ 27 | if i < 0: 28 | raise ValueError("positive integers only") 29 | if base < 0 or 62 < base: 30 | raise ValueError("Invalid base") 31 | 32 | quot = int(i) // base 33 | rem = i % base 34 | if rem < 26: 35 | letter = chr(ord("A") + rem) 36 | elif rem < 36: 37 | letter = str(rem - 26) 38 | else: 39 | letter = chr(ord("a") + rem - 36) 40 | if quot == 0: 41 | return letter 42 | else: 43 | return int_to_chain(quot - 1, base) + letter 44 | 45 | 46 | class OutOfChainsError(Exception): 47 | pass 48 | 49 | 50 | def rename_chains(structure): 51 | """Renames chains to be one-letter chains 52 | 53 | Existing one-letter chains will be kept. Multi-letter chains will be truncated 54 | or renamed to the next available letter of the alphabet. 55 | 56 | If more than 62 chains are present in the structure, raises an OutOfChainsError 57 | 58 | Returns a map between new and old chain IDs, as well as modifying the input structure 59 | """ 60 | next_chain = 0 # 61 | # single-letters stay the same 62 | chainmap = {c.id: c.id for c in structure.get_chains() if len(c.id) == 1} 63 | for o in structure.get_chains(): 64 | if len(o.id) != 1: 65 | if o.id[0] not in chainmap: 66 | chainmap[o.id[0]] = o.id 67 | o.id = o.id[0] 68 | else: 69 | c = int_to_chain(next_chain) 70 | while c in chainmap: 71 | next_chain += 1 72 | c = int_to_chain(next_chain) 73 | if next_chain >= 62: 74 | raise OutOfChainsError() 75 | chainmap[c] = o.id 76 | o.id = c 77 | return chainmap 78 | 79 | def cif2pdb(ciffile, pdbfile, fileformat): 80 | strucid = os.path.split(ciffile)[-1].split(".")[0] 81 | # Read file 82 | if fileformat == "mmcif": 83 | parser = MMCIFParser() 84 | structure = parser.get_structure(strucid, ciffile) 85 | elif fileformat == "mmtf": 86 | structure = MMTFParser.get_structure(ciffile) 87 | else: 88 | raise ValueError(f"Unrecognized file format: {fileformat}") 89 | # rename long chains 90 | try: 91 | chainmap = rename_chains(structure) 92 | except OutOfChainsError: 93 | logging.error("Too many chains to represent in PDB format") 94 | sys.exit(1) 95 | 96 | for new, old in chainmap.items(): 97 | if new != old: 98 | logging.info("Renaming chain {0} to {1}".format(old, new)) 99 | 100 | for atom in structure.get_atoms(): 101 | if atom.get_serial_number() > 99990: 102 | atom.set_serial_number(99990) # Leave space for TER 103 | 104 | # Write PDB 105 | io = PDBIO() 106 | io.set_structure(structure) 107 | io.save(pdbfile, preserve_atom_numbering=True) 108 | -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import subprocess 4 | import torch 5 | 6 | import logging 7 | LOG = logging.getLogger(__name__) 8 | 9 | 10 | def load_json(jsonfile): 11 | with open(jsonfile, "r") as jf: 12 | res = json.load(jf) 13 | return res 14 | 15 | 16 | def save_config(config, 
filepath): 17 | with open(filepath, "w") as outfile: 18 | json.dump(config, outfile, indent=4) 19 | 20 | 21 | def apply_diff(base_cfg, diff): 22 | cfg = copy.deepcopy(base_cfg) 23 | for k, v in diff.items(): 24 | if isinstance(v, dict): 25 | cfg[k] = apply_diff(base_cfg[k], v) 26 | else: 27 | cfg[k] = v 28 | return cfg 29 | 30 | 31 | def execute_bash_command(bash_command_string): 32 | bash_return = subprocess.run(bash_command_string.split(), timeout=20) 33 | return bash_return 34 | 35 | 36 | def get_torch_device(force_cpu=False): 37 | try: 38 | if torch.cuda.is_available() and not force_cpu: 39 | device_string = "cuda" 40 | # elif torch.backends.mps.is_available() and not force_cpu: 41 | # device_string = "mps" 42 | else: 43 | device_string = "cpu" 44 | except Exception as exc: 45 | LOG.error(f'Exception: {exc}') 46 | device_string = "cpu" 47 | LOG.info(f'Using device: {device_string}') 48 | device = torch.device(device_string) 49 | return device 50 | -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/domain_boundary_distance_score.py: -------------------------------------------------------------------------------- 1 | """ 2 | Domain boundary distance score, as defined in CASP 7 paper: 3 | Assessment of predictions submitted for the CASP7 domain prediction category (2007) 4 | https://onlinelibrary.wiley.com/doi/10.1002/prot.21675 5 | 6 | Under our scoring scheme all predictions within 8 residues of the correct boundary 7 | will score, but predictions that are closer to the correct domain boundary would score more. 8 | All distances between the predicted and correct domain boundaries are calculated. If the 9 | domain boundary has a linker, the whole linker is regarded as the domain boundary. 10 | predictions are given one point for being within 1 residue of each correct boundary, 11 | another point if they are within two residues, a further point if they are within three, 12 | and so on up to eight residues. A prediction two residues away from the correct boundary 13 | would therefore have 7 points. 14 | The total score for each domain prediction is then calculated as the sum of all predicted 15 | boundary scores divided by eight and the total number of domain boundaries. The number of 16 | domain boundaries comes from either the target or the prediction, whichever is higher. In 17 | this way over-prediction is penalized. 
18 | 19 | """ 20 | import numpy as np 21 | 22 | 23 | def pred_domains_to_bounds(pred_domains, optimize_for_linkers=False): 24 | """ 25 | Converts domain dictionary to list of boundary residues 26 | """ 27 | pred_bounds = np.array([]) 28 | for name, res in pred_domains.items(): 29 | if name=='linker' or len(res) ==0: 30 | continue 31 | res = np.array(sorted(set(res))) 32 | gaps = res[1:] - res[:-1] 33 | if any(gaps > 1): 34 | gaps_start_indexes = np.where(gaps > 1)[0] 35 | pred_bounds = np.append(pred_bounds, res[gaps_start_indexes] + 1) 36 | pred_bounds = np.append(pred_bounds, res[gaps_start_indexes + 1]) 37 | pred_bounds = np.append(pred_bounds, [res[0], res[-1] +1]) 38 | if optimize_for_linkers: 39 | # given that linkers all count as boundaries probs possible 40 | # to score more points by predicting boundaries in the middle of linker 41 | # regions rather than the edges 42 | raise NotImplementedError # todo 43 | return np.array(sorted(list(set(pred_bounds.astype(int))))) 44 | 45 | 46 | def get_true_boundary_res(domain_dict): 47 | """ 48 | In the case where there are multiple non-domain 49 | residues between two domains, all of these NDRs 50 | are counted as domain boundaries. This adjustment 51 | is applied to the true boundaries but not the 52 | predicted boundaries. 53 | """ 54 | bounds = pred_domains_to_bounds(domain_dict) 55 | # c.f. get_boundary_res below 56 | boundaries = { 57 | "boundary_res": list(bounds), 58 | "n_boundaries": len(list(bounds)) 59 | } 60 | boundaries["boundary_res"] += list(domain_dict["linker"]) 61 | return boundaries 62 | 63 | 64 | def boundary_distance_score(domains, boundaries): 65 | pred_bounds = pred_domains_to_bounds(domains) 66 | # distance score as specified in CASP7 paper {distance_to_true_bound:score} 67 | dist_to_score = {0:8, 1:7, 2:6, 3:5, 4:4, 5:3, 6:2, 7:1} 68 | score = 0 69 | # the final score is divided by the number of segments and the maximum score for each boundary 70 | normalizing_term = 8 * max(len(pred_bounds), boundaries['n_boundaries']) 71 | scores = [] 72 | for b in pred_bounds: 73 | distance = min(abs(boundaries['boundary_res'] - b)) 74 | if distance < 8: 75 | scores.append(dist_to_score[distance]) 76 | # JW adjustment: additional boundaries that can't be mapped to a real bound 77 | # should not be added to the un-normalized score 78 | score = sum(sorted(scores)[-boundaries['n_boundaries']:]) 79 | return score / normalizing_term 80 | -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/ndo_score.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import copy 3 | import numpy as np 4 | 5 | 6 | def check_unique_assignments(assn): 7 | res_counts = defaultdict(int) # check for repeated residues 8 | for res_ids in assn.values(): 9 | for res_id in res_ids: 10 | res_counts[res_id] += 1 11 | assert all([v == 1 for v in res_counts.values()]), f"Non-unique domain assignment {assn}" 12 | 13 | 14 | def make_domain_dict(domain_str, n_res): 15 | """ 16 | Converts string representation of domain boundaries to a dictionary 17 | Delimit separate domains with commas , and discontinuous domains 18 | with underscores _. Residue ranges separated by hyphens -, e.g. 1-100,101-200_300-340 19 | :param domain_str: domain boundaries expressed zero-indexed sequential e.g. 
1-100,101-200_300-340 20 | :param n_res: number of residues in the sequence 21 | :return: 22 | """ 23 | domain_dict = {} 24 | bounds = domain_str.split(',') 25 | assigned_res = set() 26 | for i, bound in enumerate(bounds): 27 | if len(bound): 28 | 29 | for segment in bound.split('_'): 30 | if '-' in segment: 31 | start, end = segment.split('-') 32 | else: 33 | start = end = segment 34 | segment_res = set(range(int(start), int(end) + 1)) 35 | assert len(segment_res.intersection(assigned_res)) == 0, f"Overlapping domain assignments {domain_str}" 36 | assigned_res.update(segment_res) 37 | domain_dict[f"D{i + 1}"] = domain_dict.get(f"D{i + 1}", []) + list(segment_res) 38 | domain_dict['linker'] = list(set(range(n_res)).difference(assigned_res)) 39 | return domain_dict 40 | 41 | 42 | def ndo_score(true: dict, pred: dict): 43 | """ 44 | Normalized Domain Overlap Score 45 | Approximately corresponds to the fraction of residues that are assigned to the correct domain 46 | (Tai, C.H., et al. Evaluation of domain prediction in CASP6. Proteins 2005;61:183-192.) 47 | Full description of algorithm at https://ccrod.cancer.gov/confluence/display/CCRLEE/NDO] 48 | https://ccrod.cancer.gov/confluence/display/CCRLEE/NDO 49 | :param true: dict of domain assignments, with keys "linker" and "dX" where X is the domain number 50 | :param pred: dict of domain assignments, with keys "linker" and "dX" where X is the domain number 51 | example domain dictionary: {'linker':[0,1], 'D1':[2,3,4,8,9], 'D2':[5,6,7]} 52 | """ 53 | 54 | # domains definitions must be mutually exclusive. 55 | check_unique_assignments(true) 56 | check_unique_assignments(pred) 57 | # alternative data structure would be a list of res ids for linker, plus a list of lists of res ids for domains. 58 | true = copy.deepcopy(true) 59 | pred = copy.deepcopy(pred) 60 | 61 | n_dom_pred = len([k for k in pred.keys() if k != "linker"]) 62 | n_dom_gt = len([k for k in true.keys() if k != "linker"]) 63 | 64 | # linkers are treated specially, so put them at row/col 0 to make this easy 65 | 66 | pred_linker = pred.pop("linker", []) 67 | gt_linker = true.pop("linker", []) 68 | 69 | gt_res_ids = [gt_linker] + list(true.values()) 70 | pred_res_ids = [pred_linker] + list(pred.values()) 71 | 72 | overlap = np.zeros((n_dom_pred + 1, n_dom_gt + 1)) 73 | assert len(gt_res_ids) == overlap.shape[1] and len(pred_res_ids) == overlap.shape[0] 74 | 75 | for i, p_res in enumerate(pred_res_ids): 76 | for j, gt_res in enumerate(gt_res_ids): 77 | overlap[i, j] = len(set(p_res).intersection(set(gt_res))) 78 | 79 | # modified from v0: max overlap for a domain cannot be with a linker (hence max slices start at 1) 80 | if overlap.shape[0] == 1 or overlap.shape[1] == 1: 81 | # either the predictions or the true labels (or both) are all linker 82 | print('NDO score undefined for case where all linker') 83 | return 0 84 | row_scores = overlap[1:, 1:].max(axis=1) 85 | row_scores -= (overlap[1:].sum(axis=1) - row_scores) 86 | 87 | col_scores = overlap[1:, 1:].max(axis=0) 88 | col_scores -= (overlap[:, 1:].sum(axis=0) - col_scores) 89 | 90 | total_score = (row_scores.sum() + col_scores.sum()) / 2 91 | # count number of domain (non-linker) residues to normalize 92 | max_score = sum([len(gt_res) for gt_res in gt_res_ids[1:]]) 93 | 94 | return total_score / max_score 95 | -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/pdb_reres.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright 2018 João Pedro Rodrigues 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """ 19 | Renumbers the residues of the PDB file starting from a given number (default 1). 20 | 21 | Usage: 22 | python pdb_reres.py - 23 | 24 | Example: 25 | python pdb_reres.py -10 1CTF.pdb # renumbers from 10 26 | python pdb_reres.py --1 1CTF.pdb # renumbers from -1 27 | 28 | This program is part of the `pdb-tools` suite of utilities and should not be 29 | distributed isolatedly. The `pdb-tools` were created to quickly manipulate PDB 30 | files using the terminal, and can be used sequentially, with one tool streaming 31 | data to another. They are based on old FORTRAN77 code that was taking too much 32 | effort to maintain and compile. RIP. 33 | """ 34 | 35 | import os 36 | import sys 37 | 38 | import logging 39 | LOG = logging.getLogger(__name__) 40 | 41 | 42 | __author__ = "Joao Rodrigues" 43 | __email__ = "j.p.g.l.m.rodrigues@gmail.com" 44 | 45 | 46 | def check_input(args): 47 | """Checks whether to read from stdin/file and validates user input/options. 48 | """ 49 | 50 | # Defaults 51 | option = 1 52 | fh = sys.stdin # file handle 53 | 54 | if not len(args): 55 | # Reading from pipe with default option 56 | if sys.stdin.isatty(): 57 | sys.stderr.write(__doc__) 58 | sys.exit(1) 59 | 60 | elif len(args) == 1: 61 | # One of two options: option & Pipe OR file & default option 62 | if args[0].startswith('-'): 63 | option = args[0][1:] 64 | if sys.stdin.isatty(): # ensure the PDB data is streamed in 65 | emsg = 'ERROR!! No data to process!\n' 66 | sys.stderr.write(emsg) 67 | sys.stderr.write(__doc__) 68 | sys.exit(1) 69 | 70 | else: 71 | if not os.path.isfile(args[0]): 72 | emsg = 'ERROR!! File not found or not readable: \'{}\'\n' 73 | sys.stderr.write(emsg.format(args[0])) 74 | sys.stderr.write(__doc__) 75 | sys.exit(1) 76 | 77 | fh = open(args[0], 'r') 78 | 79 | elif len(args) == 2: 80 | # Two options: option & File 81 | if not args[0].startswith('-'): 82 | emsg = 'ERROR! First argument is not an option: \'{}\'\n' 83 | sys.stderr.write(emsg.format(args[0])) 84 | sys.stderr.write(__doc__) 85 | sys.exit(1) 86 | 87 | if not os.path.isfile(args[1]): 88 | emsg = 'ERROR!! File not found or not readable: \'{}\'\n' 89 | sys.stderr.write(emsg.format(args[1])) 90 | sys.stderr.write(__doc__) 91 | sys.exit(1) 92 | 93 | option = args[0][1:] 94 | fh = open(args[1], 'r') 95 | 96 | else: # Whatever ... 97 | sys.stderr.write(__doc__) 98 | sys.exit(1) 99 | 100 | # Validate option 101 | try: 102 | option = int(option) 103 | except ValueError: 104 | emsg = 'ERROR!! 
You provided an invalid residue number: \'{}\'' 105 | sys.stderr.write(emsg.format(option)) 106 | sys.exit(1) 107 | 108 | return (fh, option) 109 | 110 | 111 | def pad_line(line): 112 | """Helper function to pad line to 80 characters in case it is shorter""" 113 | size_of_line = len(line) 114 | if size_of_line < 80: 115 | padding = 80 - size_of_line + 1 116 | line = line.strip('\n') + ' ' * padding + '\n' 117 | return line[:81] # 80 + newline character 118 | 119 | 120 | def run(fhandle, starting_resid): 121 | """ 122 | Reset the residue number column to start from a specific number. 123 | 124 | This function is a generator. 125 | 126 | Parameters 127 | ---------- 128 | fhandle : a line-by-line iterator of the original PDB file. 129 | 130 | starting_resid : int 131 | The starting residue number. 132 | 133 | Yields 134 | ------ 135 | str (line-by-line) 136 | The modified (or not) PDB line. 137 | """ 138 | _pad_line = pad_line 139 | prev_resid = None # tracks chain and resid 140 | resid = starting_resid - 1 # account for first residue 141 | records = ('ATOM', 'HETATM', 'TER', 'ANISOU') 142 | for line in fhandle: 143 | line = _pad_line(line) 144 | if line.startswith('MODEL'): 145 | resid = starting_resid - 1 # account for first residue 146 | prev_resid = None # tracks chain and resid 147 | yield line 148 | 149 | elif line.startswith(records): 150 | line_resuid = line[17:27] 151 | if line_resuid != prev_resid: 152 | prev_resid = line_resuid 153 | resid += 1 154 | if resid > 9999: 155 | emsg = 'Cannot set residue number above 9999.\n' 156 | sys.stderr.write(emsg) 157 | sys.exit(1) 158 | 159 | yield line[:22] + str(resid).rjust(4) + line[26:] 160 | 161 | else: 162 | yield line 163 | 164 | 165 | renumber_residues = run 166 | 167 | 168 | def main(): 169 | # Check Input 170 | pdbfh, starting_resid = check_input(sys.argv[1:]) 171 | 172 | # Do the job 173 | new_pdb = run(pdbfh, starting_resid) 174 | 175 | # Output results 176 | try: 177 | _buffer = [] 178 | _buffer_size = 5000 # write N lines at a time 179 | for lineno, line in enumerate(new_pdb): 180 | if not (lineno % _buffer_size): 181 | sys.stdout.write(''.join(_buffer)) 182 | _buffer = [] 183 | _buffer.append(line) 184 | 185 | sys.stdout.write(''.join(_buffer)) 186 | sys.stdout.flush() 187 | except IOError: 188 | # This is here to catch Broken Pipes 189 | # for example to use 'head' or 'tail' without 190 | # the error message showing up 191 | pass 192 | 193 | # last line of the script 194 | # Close file handle even if it is sys.stdin, no problem here. 
195 | pdbfh.close() 196 | sys.exit(0) 197 | 198 | 199 | if __name__ == '__main__': 200 | main() 201 | -------------------------------------------------------------------------------- /progres/chainsaw/src/utils/secondary_structure.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created by Jude Wells 2023-04-20 3 | Objective is to create a secondary structure matrix for each protein 4 | 1) renumbers the pdb file (1-indexed) 5 | 2) runs stride to get secondary structure file 6 | 3) parses the secondary structure file to create a matrix of secondary structure 7 | 8 | # should take around 3 minutes for 1000 structures 9 | """ 10 | 11 | import os 12 | import re 13 | import subprocess 14 | import sys 15 | 16 | import numpy as np 17 | 18 | from progres.chainsaw.src.constants import REPO_ROOT 19 | import logging 20 | LOG = logging.getLogger(__name__) 21 | 22 | 23 | def calculate_ss(pdbfile, chain, stride_path, ssfile='pdb_ss'): 24 | assert os.path.exists(pdbfile) 25 | with open(ssfile, 'w') as ssout_file: 26 | chain_arg = '-r' + chain 27 | args = [stride_path, pdbfile, chain_arg.rstrip()] 28 | LOG.info(f"Running command: {' '.join(args)}") 29 | try: 30 | subprocess.run(args, 31 | stdout=ssout_file, 32 | stderr=subprocess.DEVNULL, 33 | check=True) 34 | except subprocess.CalledProcessError as e: 35 | LOG.warning(f"Stride failed on {pdbfile}, creating empty file") 36 | pass 37 | 38 | 39 | def make_ss_matrix(ss_path, nres): 40 | # create matrices for helix and strad residues where entry ij = 1 if i and j are in the same helix or strand 41 | with open(ss_path) as f: 42 | lines = f.readlines() 43 | type_set = set() 44 | helix = np.zeros([nres, nres], dtype=np.float32) 45 | strand = np.zeros([nres, nres], dtype=np.float32) 46 | for line in lines: 47 | if line.startswith('LOC'): 48 | start = int(re.sub('\D', '',line[22:28].strip())) 49 | end = int(re.sub('\D', '',line[40:46].strip())) 50 | type = line[5:17].strip() 51 | type_set.add(type) 52 | if type in ['AlphaHelix', '310Helix']: 53 | helix[start-1:end, start-1:end] = 1 54 | elif type == 'Strand': 55 | strand[start-1:end, start-1:end] = 1 56 | elif line.startswith('ASG'): 57 | break 58 | return helix, strand 59 | 60 | 61 | def renum_pdb_file(pdb_path, output_pdb_path): 62 | pdb_reres_path = REPO_ROOT / 'src/utils/pdb_reres.py' 63 | with open(output_pdb_path, "w") as output_file: 64 | subprocess.run([sys.executable, str(pdb_reres_path), pdb_path], 65 | stdout=output_file, 66 | check=True, 67 | text=True) 68 | 69 | 70 | def main(chain_ids, pdb_dir, feature_dir, stride_path, reres_path, savedir, job_index=0): 71 | os.makedirs(savedir, exist_ok=True) 72 | os.makedirs(os.path.join(savedir, '2d_features'), exist_ok=True) 73 | for chain_id in chain_ids: 74 | try: 75 | pdb_path = os.path.join(pdb_dir, chain_id + '.pdb') 76 | if os.path.exists(pdb_path): 77 | features = np.load(os.path.join(feature_dir, chain_id + '.npz'))['arr_0'] 78 | nres = features.shape[-1] 79 | LOG.info("Processing", pdb_path) 80 | chain = chain_id[4] 81 | output_pdb_path = os.path.join(savedir, f"{job_index}.pdb") # this gets overwritten to save memory 82 | file_nres = renum_pdb_file(pdb_path, reres_path, output_pdb_path) 83 | if nres != file_nres: 84 | with open(os.path.join(savedir, 'error.txt'), 'a') as f: 85 | msg = f' residue number mismatch (from features) {nres}, (from pdb file) {file_nres}' 86 | f.write(chain_id + msg + '\n') 87 | ss_filepath = os.path.join(savedir, f'pdb_ss{job_index}.txt') # this gets overwritten to save memory 88 
| calculate_ss(output_pdb_path, chain, stride_path, ssfile=ss_filepath) 89 | helix, strand = make_ss_matrix(ss_filepath, nres=nres) 90 | np.savez_compressed( 91 | os.path.join(*[savedir, '2d_features', chain_id + '.npz']), 92 | np.stack((features, helix, strand), axis=0)) 93 | except Exception as e: 94 | with open(os.path.join(savedir, 'error.txt'), 'a') as f: 95 | f.write(chain_id + str(e) + '\n') 96 | -------------------------------------------------------------------------------- /progres/databases/README.md: -------------------------------------------------------------------------------- 1 | As of v0.2.0 the embedded databases are downloaded from https://zenodo.org/record/7782088 during installation. 2 | -------------------------------------------------------------------------------- /progres/trained_models/README.md: -------------------------------------------------------------------------------- 1 | As of v0.2.0 the trained model is downloaded from https://zenodo.org/record/7782088 during installation. 2 | -------------------------------------------------------------------------------- /scripts/dataset/README.md: -------------------------------------------------------------------------------- 1 | The training, validation and testing datasets used in the results are `train_sid.txt`, `val_unseen_sid.txt` and `test_unseen_sid.txt` respectively. 2 | The `sfam` variants are for the model with superfamily holdout. 3 | -------------------------------------------------------------------------------- /scripts/dataset/test_seen.txt: -------------------------------------------------------------------------------- 1 | d4u60a_ b.121.6.1 2 | d5y4ta1 c.47.1.0 3 | d3crna1 c.23.1.0 4 | d6t76a_ d.58.5.0 5 | d3u0oa1 d.79.4.0 6 | d1yw4a1 c.56.5.7 7 | d1n26a1 b.1.1.4 8 | d5lypa1 a.118.8.0 9 | d2beca_ a.39.1.0 10 | d2fh1a2 d.109.1.1 11 | d1hc1a3 b.1.18.3 12 | d2f23a1 a.2.1.1 13 | d3bpva_ a.4.5.0 14 | d3i35a_ b.34.2.1 15 | d1c48a_ b.40.2.1 16 | d6atla1 g.3.7.2 17 | d3tc2a_ b.42.4.0 18 | d4weea1 b.7.1.2 19 | d1r5ba2 b.44.1.1 20 | d3rgaa1 d.17.4.0 21 | d5zfga1 c.47.1.0 22 | d2al6a1 a.11.2.1 23 | d1wfqa1 b.40.4.5 24 | d1p28a1 a.39.2.1 25 | d3zzza_ d.58.7.1 26 | d3s5oa_ c.1.10.0 27 | d2wqpa1 c.1.10.6 28 | d4lcta1 a.118.8.9 29 | d3pl2a_ c.72.1.0 30 | d1fltx_ b.1.1.4 31 | d4lzla_ c.23.1.0 32 | d3ffra1 c.67.1.0 33 | d2dixa1 d.50.1.1 34 | d2anua1 c.6.3.1 35 | d5u4na_ c.1.10.0 36 | d1nowa2 d.92.2.1 37 | d1dx5i2 g.3.11.1 38 | d3b0ga1 d.58.36.0 39 | d2vzsa3 b.1.4.1 40 | d3n79a2 d.58.56.0 41 | d1wdda2 d.58.9.1 42 | d5ui3a_ c.1.10.0 43 | d3mr7a1 d.58.29.0 44 | d1t6sa2 a.4.5.60 45 | d2gz6a_ a.102.1.0 46 | d3sk3a1 c.55.1.0 47 | d3f7pc1 b.1.2.1 48 | d1egua2 b.24.1.1 49 | d2basa2 d.110.6.2 50 | d2fona3 a.29.3.0 51 | d2vuaa2 b.42.4.2 52 | d3ojmb2 b.1.1.0 53 | d5kbxb_ c.23.1.0 54 | d1w3ba_ a.118.8.1 55 | d1nr0a1 b.69.4.1 56 | d2b5xa1 c.47.1.10 57 | d3sk3a2 c.55.1.0 58 | d2ltka1 c.47.1.0 59 | d2ps2a2 c.1.11.0 60 | d6c3ma1 d.58.36.0 61 | d2czca1 d.81.1.1 62 | d4kbqa_ a.118.8.0 63 | d1ueba3 b.40.4.5 64 | d1sfla_ c.1.10.1 65 | d2hoea3 c.55.1.10 66 | d2cxha1 c.51.1.2 67 | d1o94a1 c.1.4.1 68 | d6nyta1 b.7.1.0 69 | d1i1ga2 d.58.4.2 70 | d4mf5a1 c.47.1.0 71 | d1biha3 b.1.1.4 72 | d6vk6b_ a.25.1.2 73 | d2gfna1 a.4.1.9 74 | d1jf8a_ c.44.1.1 75 | d3ui4a1 d.26.1.1 76 | d2bmxa1 c.47.1.10 77 | d3cq5a_ c.67.1.0 78 | d1b9wa2 g.3.11.4 79 | d5e4xa_ b.34.13.2 80 | d3q0ha1 b.1.1.0 81 | d3tg7a2 b.121.2.2 82 | d6id0e_ b.69.4.0 83 | d1x5xa1 b.1.2.1 84 | d4f53a_ a.118.8.0 85 | d2fo1e1 d.211.1.1 86 | d3w1ya_ d.15.1.0 87 | d3ugua2 b.1.18.11 88 | 
d1nowa1 c.1.8.6 89 | d1rq5a1 a.102.1.2 90 | d2h1va_ c.92.1.1 91 | d2r58a2 b.34.9.0 92 | d2w2ne1 g.3.11.0 93 | d3vypa1 b.1.18.0 94 | d7ahsa1 b.1.1.0 95 | d1fs0g_ c.49.2.1 96 | d5le5g_ d.153.1.4 97 | d2fh5a1 d.110.4.4 98 | d1vdza3 a.69.1.2 99 | d1ydya_ c.1.18.3 100 | d2ccya_ a.24.3.2 101 | d1xvsa_ b.1.23.1 102 | d2ch9a1 d.17.1.0 103 | d2edya1 b.1.2.0 104 | d2xmwa_ d.58.17.0 105 | d4e2ba_ c.1.4.0 106 | d4kala_ c.72.1.0 107 | d3t9ja_ a.25.1.1 108 | d2w40a1 c.55.1.0 109 | d1ofcx1 a.4.1.3 110 | d1am9a_ a.38.1.1 111 | d5o2va_ d.58.7.0 112 | d3vzxa_ c.1.4.1 113 | d2m8ua1 a.39.1.0 114 | d1hqz1_ d.109.1.2 115 | d2yuqa1 b.34.2.1 116 | d3lo3a1 d.58.4.0 117 | d1yc61_ b.121.4.5 118 | d3iara1 c.1.9.1 119 | d1edqa1 b.1.18.2 120 | d4a25a_ a.25.1.0 121 | d2mysb_ a.39.1.5 122 | d3m8oh2 b.1.1.0 123 | d3td3a1 d.79.7.0 124 | d6jmia_ a.4.5.0 125 | d1wwca_ b.1.1.4 126 | d1ygya3 d.58.18.1 127 | d1xb4a1 a.4.5.54 128 | d1h3na3 c.26.1.1 129 | d1x52a1 d.79.3.2 130 | d1yvua1 b.34.14.1 131 | d4qmea3 b.1.30.1 132 | d2egca1 b.34.2.0 133 | d2cu7a1 a.4.1.3 134 | d1hc7a1 c.51.1.1 135 | d1xwdc1 c.10.2.7 136 | d5nl9a1 a.4.5.0 137 | d3cu3a1 d.17.4.28 138 | d4ggfa_ a.39.1.2 139 | d4offa_ b.82.3.0 140 | d2fc8a1 d.58.7.0 141 | d1wl8a1 c.23.16.1 142 | d1qzga_ b.40.4.3 143 | d1z1sa1 d.17.4.10 144 | d2hoxa_ c.67.1.1 145 | d3v4da_ d.79.1.0 146 | d2hnba_ c.23.5.0 147 | d2w70a3 b.84.2.1 148 | d2alea1 d.79.3.1 149 | d1zx5a1 b.82.1.3 150 | d2mj5b1 a.5.2.1 151 | d1f20a1 b.43.4.1 152 | d5umsa_ b.55.1.0 153 | d7rjja_ d.79.7.0 154 | d3hzra_ c.26.1.0 155 | d3p0ya1 c.10.2.5 156 | d3cjkb1 d.58.17.1 157 | d1za0a1 a.25.1.2 158 | d1clca2 b.1.18.2 159 | d1x54a1 b.40.4.0 160 | d5hjfa_ a.25.1.0 161 | d1f13a1 b.1.18.9 162 | d7nnfa2 c.78.1.0 163 | d2daka1 a.5.2.0 164 | d6ncra_ c.26.1.0 165 | d5r4oa_ a.29.2.0 166 | d2ktea1 d.129.3.5 167 | d2pa7a1 b.82.1.1 168 | d4g1ma3 b.1.15.0 169 | d2wm1a_ c.1.9.0 170 | d3iq1a_ a.25.1.0 171 | d4utua1 c.1.2.0 172 | d1iknd_ d.211.1.1 173 | d3ckca_ a.118.8.6 174 | d4eo7a1 c.23.2.0 175 | d5tqja_ c.23.1.0 176 | d1wa5c_ a.118.1.1 177 | d3qe2a1 c.23.5.0 178 | d5xega_ b.82.2.0 179 | d1aa6a1 b.52.2.2 180 | d5j90a_ a.118.8.0 181 | d1rwaa2 b.24.1.1 182 | d5k2ia1 c.47.1.0 183 | d6hhea_ a.39.2.0 184 | d2piaa1 b.43.4.2 185 | d2z3ha_ c.97.1.1 186 | d4c5ka1 c.72.1.0 187 | d6kp3a_ a.118.8.0 188 | d3kzoa1 c.78.1.0 189 | d1r1ta_ a.4.5.5 190 | d1nfva_ a.25.1.1 191 | d2fnua_ c.67.1.4 192 | d2pnwa1 b.52.1.4 193 | d2v09a_ b.82.1.2 194 | d2yuma1 a.4.1.0 195 | d1n57a_ c.23.16.2 196 | d2j9ub1 g.41.11.1 197 | d1nbwb_ c.51.3.2 198 | d1mrza1 b.43.5.1 199 | d3giua_ c.56.4.0 200 | d2oz4a2 b.1.1.0 201 | d3dg3a2 c.1.11.0 202 | d3teoa_ c.53.2.0 203 | d1m5q1_ b.38.1.1 204 | d1wema1 g.50.1.2 205 | d6gsca1 a.2.11.0 206 | d1rm6a1 d.41.1.1 207 | d1rqga2 c.26.1.1 208 | d6fgga_ a.29.2.0 209 | d4b8ja_ a.118.1.0 210 | d3mdfa_ d.58.7.1 211 | d2ez6a2 d.50.1.1 212 | d2sqca2 a.102.4.2 213 | d1z1va_ a.60.1.2 214 | d2diba1 b.1.18.10 215 | d2b1xa2 d.129.3.3 216 | d2q9ua2 c.23.5.1 217 | d5khlb1 c.92.2.0 218 | d5fxda2 d.58.32.0 219 | d2ao3a_ b.42.2.1 220 | d3f9ma1 c.55.1.0 221 | d1moxc_ g.3.11.1 222 | d4kb1a_ c.55.3.5 223 | d2cw9a1 d.17.4.13 224 | d2j01s1 c.55.4.1 225 | d3zuda_ b.1.18.0 226 | d4mjea_ c.47.1.0 227 | d1k5nb1 b.1.1.2 228 | d3tvqa_ d.129.3.6 229 | d6lxua_ c.67.1.0 230 | d2nn6h2 b.84.4.2 231 | d2h6fa1 a.118.6.1 232 | d2rg7a_ c.92.2.0 233 | d2p8ba2 c.1.11.0 234 | d2i4aa_ c.47.1.1 235 | d2zcta_ c.47.1.10 236 | d3g9qa1 c.92.2.0 237 | d1suua_ b.68.10.1 238 | d4h8wc1 b.1.1.1 239 | d5nrha2 d.142.1.0 240 | d5gmkn1 b.69.4.1 241 | d4b3la_ c.1.8.0 242 | d1c0da_ 
c.1.8.3 243 | d2foka2 a.4.5.12 244 | d1saza2 c.55.1.2 245 | d3ovpa_ c.1.2.0 246 | d2vbua_ b.43.5.2 247 | d1rjta_ d.9.1.1 248 | d2cqqa1 a.4.1.3 249 | d4du5a_ c.72.1.0 250 | d1o1za1 c.1.18.3 251 | d4nhja_ a.4.6.0 252 | d2q4ea2 d.81.1.5 253 | d4lxoa2 b.1.2.1 254 | d3cu7ae a.102.4.4 255 | d2yc1c_ g.3.7.1 256 | d3c7ma_ c.47.1.0 257 | d1xi9a1 c.67.1.1 258 | d1s70b_ d.211.1.1 259 | d1wx8a1 d.15.1.1 260 | d1otra_ a.5.2.4 261 | d2ooda2 c.1.9.9 262 | d3up8a1 c.1.7.0 263 | d1wgna1 a.5.2.1 264 | d4wbja_ c.47.1.0 265 | d3efma_ f.4.3.0 266 | d1b0ba_ a.1.1.2 267 | d2knia1 g.3.6.2 268 | d1iuja_ d.58.4.5 269 | d4ur7a_ c.1.10.0 270 | d1wfja1 b.7.1.2 271 | d3hola2 f.4.1.3 272 | d2r85a2 d.142.1.9 273 | d3n9ra_ c.1.10.0 274 | d1jtaa_ b.80.1.1 275 | d6a77a1 b.1.1.0 276 | d1zunb1 b.43.3.1 277 | d4nf7a1 c.1.8.0 278 | d2kgta_ b.34.2.0 279 | d1zvfa1 b.82.1.20 280 | d1a8la1 c.47.1.2 281 | d7d5va1 b.6.1.0 282 | d3drna_ c.47.1.0 283 | d3ndna_ c.67.1.0 284 | d3r5ta1 c.92.2.0 285 | d1q5ya_ d.58.18.4 286 | d4gosa1 b.1.1.0 287 | d1kf6b2 d.15.4.2 288 | d2k5ga1 d.129.3.5 289 | d3cxga1 c.47.1.0 290 | d1a4pa_ a.39.1.2 291 | d2ezia_ a.4.1.2 292 | d2rnrb_ b.55.1.9 293 | d3ef4a_ b.6.1.0 294 | d4etna1 c.44.1.0 295 | d1mija_ a.4.1.1 296 | d3l3ba_ c.23.16.0 297 | d5tfma2 b.1.6.0 298 | d1vema2 c.1.8.1 299 | d1s6ia_ a.39.1.5 300 | d6r2wt1 b.1.2.1 301 | d2i10a1 a.4.1.9 302 | d6swhb_ b.2.3.0 303 | d6zgxa2 d.58.26.7 304 | d1st6a3 a.24.9.1 305 | d2bhua1 b.1.18.2 306 | d2dkxa1 a.60.1.0 307 | d4uqta_ d.58.7.0 308 | d1ipaa2 d.79.3.3 309 | d2v44a2 b.1.1.0 310 | d2zdha2 d.142.1.0 311 | d3iwza1 b.82.3.0 312 | d1ft9a1 a.4.5.4 313 | d2o90a_ d.96.1.0 314 | d7o6ea_ a.25.1.0 315 | d4q9aa_ c.23.10.0 316 | d1g12a_ d.92.1.12 317 | d4rega_ d.58.53.0 318 | d2absa1 c.72.1.1 319 | d3dhxa1 d.58.18.13 320 | d2heoa_ a.4.5.19 321 | d6i57a1 a.118.8.0 322 | d1jb7a2 b.40.4.3 323 | d6d5xa_ a.25.2.0 324 | d1iarb2 b.1.2.1 325 | d1h54a1 a.102.1.4 326 | d1sgma1 a.4.1.9 327 | d6j33a3 b.1.18.0 328 | d1ew4a_ d.82.2.1 329 | d1pqsa_ d.15.2.2 330 | d1skza1 g.3.15.1 331 | d5xlyb_ b.45.2.1 332 | d3qrya_ a.102.1.0 333 | d5jowa1 b.67.2.0 334 | d1pn0a2 c.47.1.10 335 | d1zrra_ b.82.1.6 336 | d3eh8a2 d.95.2.1 337 | d3prna_ f.4.3.1 338 | d1pgl11 b.121.4.2 339 | d1wgra1 d.15.1.5 340 | d3w07a_ c.1.2.3 341 | d6znva1 a.29.2.0 342 | d4r9oa_ c.1.7.1 343 | d1kw4a_ a.60.1.2 344 | d2ovla2 c.1.11.0 345 | d1wlja_ c.55.3.5 346 | d1iqza_ d.58.1.4 347 | d2pkfa1 c.72.1.0 348 | d1xkwa_ f.4.3.0 349 | d4r3na2 d.81.1.1 350 | d2a1ia1 c.52.1.20 351 | d1opca_ a.4.6.1 352 | d4ybra1 c.26.1.0 353 | d2cdqa3 d.58.18.10 354 | d3hcna_ c.92.1.1 355 | d4cu7a2 b.1.4.1 356 | d1vkua1 a.28.1.1 357 | d2fc9a1 d.58.7.0 358 | d4rcab_ c.10.2.0 359 | d1a5ea_ d.211.1.1 360 | d6gg9a1 d.110.3.0 361 | d2nw0a_ c.1.8.0 362 | d3dyda_ c.67.1.0 363 | d3hz4a_ c.47.1.0 364 | d7buga2 d.129.3.3 365 | d1z6om1 a.25.1.1 366 | d1cz9a_ c.55.3.2 367 | d6q00a_ d.15.1.1 368 | d2a73b2 b.1.29.1 369 | d1ilr1_ b.42.1.2 370 | d1x58a1 a.4.1.1 371 | d1bqua1 b.1.2.1 372 | d1ihga1 a.118.8.1 373 | d3cjia_ b.121.4.0 374 | d5buva1 b.82.1.0 375 | d6izha_ d.79.1.0 376 | d3lyea_ c.1.12.0 377 | d2c9aa1 b.1.1.4 378 | d2cq2a1 d.58.7.1 379 | d1iarb1 b.1.2.1 380 | d6pz7a_ c.1.8.0 381 | d3epqa2 c.55.1.10 382 | d3bq7a1 a.60.1.0 383 | d2cuha1 b.1.2.1 384 | d3op1a2 b.43.5.0 385 | d4zm3a_ c.67.1.0 386 | d4hqaa_ d.110.3.6 387 | d2f8ja2 c.67.1.1 388 | d2w59a2 b.1.1.0 389 | d2m80a_ c.47.1.0 390 | d2gm6a_ b.82.1.19 391 | d3lo8a1 b.43.4.0 392 | d5o9zg_ a.118.8.1 393 | d3fdxa1 c.26.2.0 394 | d2cqka1 a.4.5.46 395 | d4g6za1 c.26.1.0 396 | d6m8oa_ c.23.1.0 397 | d1bw5a_ 
a.4.1.1 398 | d4fkea4 a.118.1.0 399 | d1t3ta4 d.79.4.1 400 | d4hgma_ b.1.1.1 401 | -------------------------------------------------------------------------------- /scripts/dataset/test_unseen_sfam.txt: -------------------------------------------------------------------------------- 1 | d1wgya1 d.15.1.5 2 | d1sqja1 b.69.13.1 3 | d3gd6a2 c.1.11.0 4 | d2fhqa1 b.45.1.1 5 | d1v5da_ a.102.1.2 6 | d1o5la1 b.82.3.2 7 | d3ftba_ c.67.1.0 8 | d1jpdx1 c.1.11.2 9 | d2eo1a1 b.1.1.0 10 | d6w2ga_ a.5.2.1 11 | d1j6oa1 c.1.9.12 12 | d1z96a1 a.5.2.1 13 | d1jgsa_ a.4.5.28 14 | d1t0kb_ d.79.3.1 15 | d3fava_ a.25.3.1 16 | d1wyja1 b.1.6.3 17 | d4l0ma1 c.56.2.0 18 | d7b1sb1 d.58.31.0 19 | d6d4ra_ c.1.5.0 20 | d3edha_ d.92.1.0 21 | d1u4ga_ d.92.1.2 22 | d5i90a_ c.67.1.0 23 | d1jpma1 c.1.11.2 24 | d4q97a1 b.1.1.0 25 | d3k9oa2 a.5.2.1 26 | d6ewja_ c.67.1.0 27 | d1e7ua3 d.15.1.5 28 | d3k7ya_ c.67.1.0 29 | d4esfa1 a.4.5.0 30 | d3fvsa_ c.67.1.1 31 | d4dzbb1 b.1.1.0 32 | d1q3ea_ b.82.3.2 33 | d2ap2a1 b.1.1.1 34 | d3gw2a_ a.4.5.0 35 | d3cafa_ b.1.1.4 36 | d1cf7a_ a.4.5.17 37 | d1w23a_ c.67.1.4 38 | d5ko6a_ c.56.2.0 39 | d4lc3a_ c.67.1.0 40 | d3dgba2 c.1.11.0 41 | d1wm3a_ d.15.1.1 42 | d1tvia_ d.92.1.15 43 | d3cjna_ a.4.5.0 44 | d2pn6a1 a.4.5.0 45 | d4raya1 a.4.5.0 46 | d1s3ja_ a.4.5.28 47 | d2al3a1 d.15.1.2 48 | d2xaua3 a.4.5.0 49 | d1veka1 a.5.2.1 50 | d3ij6a1 c.1.9.0 51 | d7k63u_ a.4.5.0 52 | d1zxqa2 b.1.1.4 53 | d3jw4a_ a.4.5.0 54 | d5dd8a_ a.4.5.28 55 | d2nrac2 a.4.5.10 56 | d1ixca1 a.4.5.37 57 | d4zwva1 c.67.1.0 58 | d1i42a_ d.15.1.2 59 | d1rhfa1 b.1.1.1 60 | d3bn3b2 b.1.1.0 61 | d6mdha1 d.15.1.0 62 | d1r1ua_ a.4.5.5 63 | d1gl4b_ b.1.1.4 64 | d1zkha1 d.15.1.1 65 | d2co5a1 a.4.5.48 66 | d2fpqa1 d.92.1.0 67 | d3keoa1 a.4.5.0 68 | d4wk7a_ d.92.1.0 69 | d5x14a1 a.4.5.0 70 | d2wyqa_ d.15.1.1 71 | d6fyqa_ c.67.1.0 72 | d1hxmb1 b.1.1.1 73 | d6ztbi2 b.1.6.0 74 | d3bpka1 b.45.1.0 75 | d6nnfu_ b.1.1.0 76 | d2a8aa_ d.92.1.7 77 | d3q4da2 c.1.11.0 78 | d4lfhd2 b.1.1.0 79 | d3rita2 c.1.11.0 80 | d4g8ta2 c.1.11.2 81 | d2v44a1 b.1.1.0 82 | d4lfhg1 b.1.1.0 83 | d5n9va1 d.15.1.0 84 | d5nbca_ a.4.5.0 85 | d1cs6a2 b.1.1.4 86 | d1koaa1 b.1.1.4 87 | d1hkfa_ b.1.1.1 88 | d1zoda_ c.67.1.4 89 | d4icva_ d.15.1.1 90 | d1nkoa_ b.1.1.1 91 | d1i43a_ c.67.1.3 92 | d2qs8a2 c.1.9.18 93 | d1yg2a_ a.4.5.61 94 | d5jdda2 b.1.1.0 95 | d1repc1 a.4.5.10 96 | d1ncwl1 b.1.1.1 97 | d4h83a2 c.1.11.0 98 | d1hkqa_ a.4.5.10 99 | d6mfba_ c.67.1.3 100 | d4adba_ c.67.1.0 101 | d1pgya_ a.5.2.1 102 | d5szra4 b.1.6.0 103 | d5ryoa1 d.15.1.0 104 | d4aiha_ a.4.5.28 105 | d2fxaa1 a.4.5.28 106 | d4a5na_ a.4.5.0 107 | d1we6a1 d.15.1.1 108 | d5a0yb1 d.58.31.2 109 | d1wiaa1 d.15.1.1 110 | d3fdba1 c.67.1.0 111 | d2gdqa1 c.1.11.2 112 | d5e16a_ b.82.3.0 113 | d1ibja_ c.67.1.3 114 | d6p2la2 b.69.13.0 115 | d3k2za1 a.4.5.0 116 | d1p1ma2 c.1.9.9 117 | d3hoaa_ d.92.1.0 118 | d6ftfb2 b.82.3.0 119 | d6l0aa_ c.1.9.0 120 | d4el4a_ d.92.1.7 121 | d5szra3 b.1.6.0 122 | d6ijna1 c.1.9.0 123 | d7bxpa_ c.67.1.0 124 | d1j7na1 d.92.1.14 125 | d2p6ra1 a.4.5.43 126 | d1vl7a_ b.45.1.1 127 | d4lfhg2 b.1.1.2 128 | d1uvqb1 b.1.1.2 129 | d4ejoa1 a.4.5.0 130 | d1y0ua1 a.4.5.5 131 | d4a6ra_ c.67.1.0 132 | d2w57a_ a.4.5.0 133 | d3nfia2 a.4.5.86 134 | d4hpna2 c.1.11.0 135 | d1gkpa2 c.1.9.6 136 | d4k8ga2 c.1.11.0 137 | d1ncua1 b.1.1.4 138 | d4jvua1 b.1.1.0 139 | d5gxxa1 a.102.1.0 140 | d4un2b1 a.5.2.1 141 | d1k5na1 b.1.1.2 142 | d1wjua1 d.15.1.1 143 | d3p1ta_ c.67.1.0 144 | d1ogad1 b.1.1.1 145 | d3wkga_ a.102.1.0 146 | d1on2a1 a.4.5.24 147 | d1u5tb1 a.4.5.54 148 | d1p4xa2 a.4.5.28 149 | d2x5da_ c.67.1.0 
150 | d1wgpa1 b.82.3.2 151 | d1vg5a1 a.5.2.1 152 | d6l1oa_ c.67.1.0 153 | d6mraa1 b.1.1.0 154 | d1l3wa5 b.1.6.1 155 | d4mcha1 c.56.2.0 156 | d6p2la1 b.69.13.0 157 | d1ufma1 a.4.5.47 158 | d1ybfa_ c.56.2.1 159 | d2va4a2 b.1.1.0 160 | d1repc2 a.4.5.10 161 | d6qp1a1 c.67.1.0 162 | d5m2ta_ c.56.2.1 163 | d4oc9a_ c.67.1.0 164 | d2yz1a_ b.1.1.0 165 | d3nu8a_ c.67.1.0 166 | d3l7wa_ a.4.5.0 167 | d2cqva1 b.1.1.4 168 | d6a2ba2 b.1.1.0 169 | d1tjgh1 b.1.1.1 170 | d1i1ip_ d.92.1.5 171 | d1wxva1 d.15.1.1 172 | d3eufa_ c.56.2.0 173 | d1iama2 b.1.1.4 174 | d1ypzf1 b.1.1.1 175 | d4f5la1 c.67.1.1 176 | d1o7fa2 b.82.3.2 177 | d1cs6a3 b.1.1.4 178 | d2v9va1 a.4.5.35 179 | d2or7a1 b.1.1.0 180 | d4x57b_ d.15.1.0 181 | d1k8rb_ d.15.1.5 182 | d3e6ma_ a.4.5.0 183 | d1ldda_ a.4.5.34 184 | d2cr6a1 b.1.1.0 185 | d1wgla1 a.5.2.4 186 | d2fmya2 a.4.5.0 187 | d2paja2 c.1.9.9 188 | d1wgha1 d.15.1.1 189 | d2di0a1 a.5.2.4 190 | d4jjha_ b.1.1.0 191 | d6isbc_ b.1.1.0 192 | d2f6ka1 c.1.9.15 193 | d2foka3 a.4.5.12 194 | d1xb2b1 a.5.2.2 195 | d1vcaa1 b.1.1.3 196 | d2nw2a2 b.1.1.2 197 | d2m4na1 d.15.1.0 198 | d3ec6a_ b.45.1.0 199 | d2iepa1 b.1.1.0 200 | d1f5wa_ b.1.1.1 201 | -------------------------------------------------------------------------------- /scripts/dataset/test_unseen_sid.txt: -------------------------------------------------------------------------------- 1 | d3bb6a1 b.82.2.13 2 | d1vqta1 c.1.2.3 3 | d3nara_ a.4.1.1 4 | d3euca_ c.67.1.0 5 | d3aowa_ c.67.1.0 6 | d3g46a_ a.1.1.2 7 | d6i9sa2 b.1.1.0 8 | d6rw7a1 b.1.18.0 9 | d3eyxa_ c.53.2.0 10 | d4uypa1 b.2.2.0 11 | d4rz4a1 c.1.10.1 12 | d1knva_ c.52.1.7 13 | d2h6ra_ c.1.1.1 14 | d4ipca1 b.40.4.3 15 | d6r3za_ c.92.2.0 16 | d4n4ee_ d.92.1.2 17 | d3neva_ c.1.10.0 18 | d1l3wa3 b.1.6.1 19 | d2iw0a1 c.6.2.3 20 | d2g7ga1 a.4.1.9 21 | d1p5ub_ b.2.3.2 22 | d3b1qa_ c.72.1.0 23 | d5i8da3 b.1.6.0 24 | d1cb8a2 b.24.1.1 25 | d2v82a1 c.1.10.0 26 | d1q15a1 c.26.2.1 27 | d1wuza_ b.1.6.3 28 | d2hs5a1 a.4.5.6 29 | d3zq7a_ a.4.6.0 30 | d2abwa1 c.23.16.1 31 | d1x1ia3 b.30.5.0 32 | d4ay0a2 b.7.2.0 33 | d4j25a1 b.82.2.0 34 | d6bmaa1 c.1.2.0 35 | d1y4wa2 b.67.2.3 36 | d1lc5a_ c.67.1.1 37 | d3cwna_ c.1.10.1 38 | d3pmea2 b.42.4.0 39 | d1kgsa2 c.23.1.1 40 | d3djca1 c.55.1.0 41 | d5jgya_ c.1.7.0 42 | d2xpwa1 a.4.1.9 43 | d1p0za1 d.110.6.1 44 | d2kiva2 a.60.1.0 45 | d6cw0a_ a.29.2.0 46 | d1ta3a_ c.1.8.5 47 | d2dy8a1 b.34.13.2 48 | d2vnud2 b.40.4.5 49 | d7dhfa1 c.44.1.0 50 | d5vrka_ c.1.9.0 51 | d4cr2r1 a.118.8.9 52 | d2d69a1 c.1.14.1 53 | d2p3ra3 c.55.1.4 54 | d2wy4a_ a.1.1.0 55 | d3dxea_ b.55.1.0 56 | d1wela1 d.58.7.1 57 | d1t9ha1 b.40.4.5 58 | d3grfa2 c.78.1.0 59 | d6n0ka_ b.82.1.12 60 | d1enfa1 b.40.2.2 61 | d1u5ha_ c.1.12.5 62 | d1x3da1 b.1.2.1 63 | d3hhtb_ b.34.4.4 64 | d2ghpa2 d.58.7.1 65 | d1qzya1 a.4.5.25 66 | d5h20a_ a.4.5.0 67 | d4e70a1 a.4.5.0 68 | d4xria_ a.118.1.0 69 | d1k61a_ a.4.1.1 70 | d2csba1 a.60.2.4 71 | d1knla_ b.42.2.1 72 | d5y5qa_ b.85.4.0 73 | d2iiza_ d.58.4.14 74 | d3heba1 c.23.1.0 75 | d2m6sa_ c.23.5.0 76 | d2yrza1 b.1.2.1 77 | d1xyza_ c.1.8.3 78 | d6nqia_ c.55.3.0 79 | d2p39a_ b.42.1.0 80 | d1woca_ b.40.4.3 81 | d2bvya1 b.1.18.2 82 | d7coia_ c.53.2.0 83 | d1fp3a_ a.102.1.3 84 | d4r3va_ d.92.1.0 85 | d3vrna2 a.39.1.0 86 | d1q16a1 b.52.2.2 87 | d2v5ca1 d.92.2.3 88 | d4g9ya_ a.4.5.0 89 | d1i5pa2 b.77.2.1 90 | d2kyra_ c.44.2.0 91 | d4q9ba_ b.1.1.0 92 | d2pn6a2 d.58.4.0 93 | d1d5ra1 b.7.1.1 94 | d5wida_ c.23.5.0 95 | d1ty0a2 d.15.6.1 96 | d2hjsa2 d.81.1.1 97 | d5afwa2 b.34.13.2 98 | d1yjra1 d.58.17.0 99 | d2xfwa1 c.1.10.1 100 | d1yn8a1 b.34.2.0 101 | d2wp4a_ d.41.5.0 102 
| d1bvsa3 b.40.4.2 103 | d2dkua1 b.1.1.0 104 | d1sqja2 b.69.13.1 105 | d2nwha_ c.72.1.0 106 | d2p5ka_ a.4.5.3 107 | d3ea6a2 d.15.6.1 108 | d1ga2a2 b.2.2.2 109 | d2jk3a1 a.4.1.9 110 | d1w32a_ c.1.8.3 111 | d1otja_ b.82.2.5 112 | d2znra1 c.97.3.1 113 | d4uqva_ c.67.1.0 114 | d4xmra1 d.110.6.4 115 | d4r82a_ b.45.1.0 116 | d2cm5a_ b.7.1.2 117 | d2f8aa1 c.47.1.10 118 | d1w99a1 f.1.3.1 119 | d2h3na_ b.1.1.0 120 | d1szna2 c.1.8.1 121 | d1rq5a2 b.1.18.2 122 | d1d4ba1 d.15.2.1 123 | d1rlla1 b.1.18.9 124 | d1o12a2 c.1.9.10 125 | d2dlta1 b.1.1.0 126 | d2xvla1 b.30.5.11 127 | d1b9la_ d.96.1.3 128 | d3anua2 c.1.6.1 129 | d3df8a1 a.4.5.0 130 | d2gu3a1 d.17.1.6 131 | d3dbxa2 b.1.1.0 132 | d1w5fa2 d.79.2.1 133 | d6jt6a_ b.1.9.0 134 | d1s05a_ a.24.3.2 135 | d6luja1 a.60.1.0 136 | d1deca_ g.3.15.2 137 | d6evla_ a.118.8.1 138 | d6wsha1 c.23.1.0 139 | d3or5a_ c.47.1.0 140 | d5z0qa_ c.67.1.0 141 | d3keba_ c.47.1.0 142 | d2dc3a_ a.1.1.2 143 | d3n6wa2 b.82.2.5 144 | d2nv0a_ c.23.16.1 145 | d3hd5a1 c.47.1.0 146 | d2gkga_ c.23.1.0 147 | d1mc0a2 d.110.2.1 148 | d1gp6a_ b.82.2.1 149 | d4f0ba1 c.47.1.0 150 | d1t3ga_ c.23.2.0 151 | d4ieua_ b.82.1.19 152 | d6q6na2 c.1.8.3 153 | d2jgqa_ c.1.1.0 154 | d1v86a1 d.15.1.1 155 | d1pkoa1 b.1.1.1 156 | d3tsma_ c.1.2.0 157 | d1vjpa4 d.81.1.3 158 | d1x46a_ a.1.1.0 159 | d1n52a3 a.118.1.14 160 | d1tfia_ g.41.3.1 161 | d3bn3b1 b.1.1.0 162 | d5i4da1 b.40.2.0 163 | d1uz5a1 b.85.6.1 164 | d1x48a1 d.50.1.1 165 | d2dnaa1 a.5.2.1 166 | d1bhea_ b.80.1.3 167 | d7jgwa1 f.1.4.1 168 | d2dy1a5 d.58.11.1 169 | d1o1ya1 c.23.16.1 170 | d3en8a1 d.17.4.20 171 | d1y3ta1 b.82.1.5 172 | d2a1ka_ b.40.4.7 173 | d2bpa2_ b.121.5.1 174 | d3cu7a1 b.1.29.2 175 | d2csba3 a.60.2.4 176 | d3c5ra2 d.211.1.1 177 | d1elwa_ a.118.8.1 178 | d1reqa1 c.1.19.1 179 | d1h97a_ a.1.1.2 180 | d1v7va2 b.30.5.3 181 | d2bgca1 a.4.5.4 182 | d1ebua2 d.81.1.2 183 | d3ma2b_ b.40.3.1 184 | d1qjta_ a.39.1.6 185 | d3eu9a1 d.211.1.0 186 | d5ycza_ b.42.4.0 187 | d2hoea1 a.4.5.63 188 | d1vkra_ c.44.2.1 189 | d4ay7a_ c.1.22.0 190 | d1vlpa1 d.41.2.2 191 | d1j8mf1 a.24.13.1 192 | d1fhoa_ b.55.1.1 193 | d2i61a_ g.3.7.1 194 | d5xopa1 a.39.1.5 195 | d2axwa1 b.2.3.2 196 | d2zaya_ c.23.1.0 197 | d1a8da2 b.42.4.2 198 | d2frha2 a.4.5.28 199 | d4xdaa_ c.72.1.0 200 | d1y1xa_ a.39.1.8 201 | d3ri6a_ c.67.1.0 202 | d2fe3a_ a.4.5.0 203 | d2qcva_ c.72.1.0 204 | d7m3xa_ b.69.4.1 205 | d1zzma1 c.1.9.12 206 | d1dhna_ d.96.1.3 207 | d6r2na2 c.55.1.0 208 | d1khca_ b.34.9.2 209 | d6blga_ c.67.1.0 210 | d1m45a_ a.39.1.5 211 | d1kzla2 b.43.4.3 212 | d1mvha_ b.85.7.1 213 | d3h3ba1 c.10.2.5 214 | d2rnja_ a.4.6.0 215 | d4bhxa1 a.28.3.0 216 | d1ckma1 b.40.4.6 217 | d5svva_ d.110.3.0 218 | d3jb9u3 b.69.4.1 219 | d4yvda_ b.69.4.0 220 | d3krua_ c.1.4.0 221 | d1bdoa_ b.84.1.1 222 | d2kspa1 a.39.1.0 223 | d1ofda3 d.153.1.1 224 | d2cjja1 a.4.1.3 225 | d3ppea2 b.1.6.0 226 | d1kqka_ d.58.17.1 227 | d2wfba1 c.55.5.0 228 | d1xuva1 d.129.3.5 229 | d2coca1 b.55.1.1 230 | d3gt5a1 a.102.1.0 231 | d2p9ra_ b.1.29.3 232 | d3bexa1 c.55.1.13 233 | d3pnza_ c.1.9.0 234 | d1bw0a_ c.67.1.1 235 | d2xfaa_ d.109.1.0 236 | d4ch0s1 d.58.7.1 237 | d2w2db1 d.92.1.0 238 | d4e9sa3 b.6.1.3 239 | d1x4ea1 d.58.7.1 240 | d2q5we1 d.41.5.0 241 | d2d6ya1 a.4.1.9 242 | d1fu3a_ g.3.6.1 243 | d5y9ea_ b.121.6.1 244 | d1y0ga_ b.61.6.1 245 | d3c8ca1 d.110.6.4 246 | d2fexa1 c.23.16.2 247 | d1geqa_ c.1.2.4 248 | d2q9oa1 b.6.1.3 249 | d1smoa_ b.1.1.1 250 | d2naca2 c.23.12.1 251 | d4unua_ b.1.1.0 252 | d3lx3a_ d.96.1.2 253 | d2iyga_ d.58.10.2 254 | d3kcca1 b.82.3.2 255 | d2a28a1 b.34.2.0 256 | d2a1xa1 b.82.2.9 257 
| d2ckxa1 a.4.1.3 258 | d1gqia2 d.92.2.2 259 | d5tkma_ c.97.1.6 260 | d3i33a2 c.55.1.1 261 | d1wgga1 d.15.1.1 262 | d3o0ma1 d.13.1.0 263 | d1bvyf_ c.23.5.1 264 | d2zfga_ f.4.3.1 265 | d1t33a1 a.4.1.9 266 | d3gg7a1 c.1.9.0 267 | d3d01a1 d.79.1.0 268 | d1fi6a_ a.39.1.6 269 | d2ih3c_ f.14.1.1 270 | d1w9aa_ b.45.1.1 271 | d2f42a1 a.118.8.0 272 | d3fiua_ c.26.2.0 273 | d3w2wa3 d.58.53.7 274 | d6mwdb1 f.14.1.2 275 | d1ng0a_ b.121.4.7 276 | d1n0za1 g.41.11.1 277 | d3b0fa_ a.5.2.1 278 | d3jtea1 c.23.1.0 279 | d3na8a_ c.1.10.0 280 | d2hxva2 c.97.1.2 281 | d5u1ma_ b.55.1.2 282 | d1dbha2 b.55.1.1 283 | d2bs2b2 d.15.4.2 284 | d1wu7a1 c.51.1.1 285 | d6t40a_ b.121.4.1 286 | d2qqra2 b.34.9.1 287 | d2ra2a1 b.38.1.6 288 | d1ty0a1 b.40.2.2 289 | d3gkxa_ c.47.1.0 290 | d1x1ia1 a.102.3.2 291 | d4l3ma_ c.1.8.0 292 | d1bf2a3 c.1.8.1 293 | d2il5a1 d.129.3.5 294 | d1jvna2 c.23.16.1 295 | d2jbaa_ c.23.1.1 296 | d5wl1a2 b.1.1.0 297 | d1m9sa3 b.34.11.1 298 | d1lvka1 b.34.3.1 299 | d5fjqa_ b.1.18.0 300 | d5cm7a1 d.79.4.0 301 | d2z48a1 b.42.2.1 302 | d3cu7a8 b.1.29.1 303 | d5dl5a1 f.4.7.0 304 | d1x9na2 b.40.4.6 305 | d1kqfb1 d.58.1.5 306 | d2pe8a1 d.58.7.0 307 | d1wjka1 c.47.1.1 308 | d3ceda1 d.58.18.13 309 | d1wisa1 b.1.2.1 310 | d1vyra_ c.1.4.1 311 | d1rmga_ b.80.1.3 312 | d6wk3a_ a.1.1.0 313 | d1pk3a1 a.60.1.2 314 | d6gf1a1 d.15.1.0 315 | d2q3qa2 d.129.3.1 316 | d3n0pb1 b.1.2.1 317 | d1uw4b_ a.118.1.14 318 | d1u5da1 b.55.1.1 319 | d6onna1 c.67.1.0 320 | d2edna1 b.1.1.0 321 | d1gr0a2 d.81.1.3 322 | d1ejea_ b.45.1.2 323 | d6jc0b_ d.41.5.0 324 | d4wnya1 c.26.2.0 325 | d2f3la1 b.80.8.1 326 | d4ohja1 b.40.2.2 327 | d2imla1 b.45.1.4 328 | d6fd1a_ d.58.1.2 329 | d3r1za2 c.1.11.0 330 | d1sixa1 b.85.4.1 331 | d1clca1 a.102.1.2 332 | d2jnfa_ a.39.1.0 333 | d3bb9a1 d.17.4.16 334 | d2hr7a2 g.3.9.0 335 | d2axla1 a.4.5.43 336 | d2ayta1 c.47.1.0 337 | d1ix9a1 a.2.11.1 338 | d3fsoa1 b.1.27.0 339 | d4dpoa1 d.58.4.0 340 | d2fdbm1 b.42.1.1 341 | d7ahsa2 b.1.1.0 342 | d2rsda1 g.50.1.0 343 | d2uy6b1 b.2.3.2 344 | d1rypd_ d.153.1.4 345 | d4mzya2 c.1.17.1 346 | d4kjfa_ c.47.1.0 347 | d2wu2c_ f.21.2.2 348 | d3ibva_ a.118.1.19 349 | d5ysca_ c.92.2.2 350 | d2z61a_ c.67.1.0 351 | d2nw2a1 b.1.1.0 352 | d2nlza1 d.153.1.6 353 | d1tiga_ d.68.1.1 354 | d1kl9a2 b.40.4.5 355 | d2c1ia1 c.6.2.3 356 | d3hcza1 c.47.1.0 357 | d5gxxa2 b.2.2.0 358 | d3hpea_ b.61.6.0 359 | d3md3a1 d.58.7.0 360 | d1jmrb4 d.110.2.2 361 | d3tqva1 d.41.2.0 362 | d1kfwa1 c.1.8.5 363 | d2bz7a_ b.6.1.1 364 | d4cvqa_ c.67.1.0 365 | d1v3wa_ b.81.1.5 366 | d1gdha2 c.23.12.1 367 | d3frxa_ b.69.4.0 368 | d1i7qb_ c.23.16.1 369 | d4u5ia_ c.1.8.0 370 | d2vyoa_ c.6.2.0 371 | d1xxxa1 c.1.10.1 372 | d2fbwc_ f.21.2.2 373 | d6hbsa_ c.67.1.0 374 | d4m5ra1 c.52.1.34 375 | d1pbyb_ b.69.2.2 376 | d5uida_ c.67.1.0 377 | d2nq3a1 b.7.1.1 378 | d1xkua_ c.10.2.7 379 | d4c4ko1 b.1.1.0 380 | d3zdmc_ d.15.1.0 381 | d1vq8b1 b.43.3.2 382 | d2xzga2 a.118.1.0 383 | d1rd5a_ c.1.2.4 384 | d2cuaa_ b.6.1.2 385 | d1usca_ b.45.1.2 386 | d4jw0a_ b.40.15.1 387 | d1ug2a1 a.4.1.3 388 | d1m3ya2 b.121.2.3 389 | d3mgka_ c.23.16.0 390 | d2m87a_ a.4.6.0 391 | d5bnza1 c.26.1.0 392 | d1b9ha_ c.67.1.4 393 | d3czba_ b.52.1.0 394 | d3bbba_ d.58.6.1 395 | d1lw7a1 c.26.1.3 396 | d2fjca1 a.25.1.1 397 | d2mska_ c.23.1.1 398 | d3eura1 c.47.1.0 399 | d1p4ca_ c.1.4.1 400 | d6xr5a2 d.58.26.0 401 | -------------------------------------------------------------------------------- /scripts/dataset/val_seen.txt: -------------------------------------------------------------------------------- 1 | d4e6ua_ b.81.1.0 2 | d3dvwa_ c.47.1.0 3 | 
d3wctd_ a.1.1.0 4 | d1b24a1 d.95.2.1 5 | d4n68a1 b.1.2.0 6 | d4m1ba_ c.6.2.0 7 | d1us0a_ c.1.7.1 8 | d1r8na_ b.42.4.1 9 | d4e38a1 c.1.10.0 10 | d5suza1 a.4.5.31 11 | d2nn5a1 d.129.3.5 12 | d4cc4b1 b.34.2.0 13 | d3iam9_ d.58.1.5 14 | d2ckra_ c.1.8.0 15 | d2zbla1 a.102.1.3 16 | d4izea_ b.42.1.0 17 | d4bd8a_ f.1.4.0 18 | d3r27a_ d.58.7.0 19 | d1jnrb_ d.58.1.5 20 | d1x4fa1 d.58.7.1 21 | d1xjva1 b.40.4.3 22 | d1szba1 b.23.1.1 23 | d6m4ja_ c.67.1.0 24 | d1u5tb2 a.4.5.54 25 | d2z69a1 b.82.3.0 26 | d4u0qb1 b.1.1.0 27 | d5c71a2 b.1.4.0 28 | d2mq0a1 b.1.1.0 29 | d4rn7a1 c.56.5.0 30 | d5gxub1 b.43.4.0 31 | d1b4ka_ c.1.10.3 32 | d6gmpa_ d.26.1.0 33 | d1cg5b_ a.1.1.2 34 | d1vr6a2 c.1.10.4 35 | d3p1ga_ c.55.3.0 36 | d3wfwa1 a.1.1.0 37 | d1j5va1 c.72.1.1 38 | d2qk9a1 c.55.3.0 39 | d6r2wl2 g.3.11.1 40 | d1u08a_ c.67.1.1 41 | d4hl7a2 c.1.17.0 42 | d1ht6a2 c.1.8.1 43 | d2a3la1 c.1.9.1 44 | d1kf6c_ f.21.2.2 45 | d6rsxa2 b.82.3.0 46 | d3pmsa1 b.121.1.1 47 | d4gnia2 c.55.1.0 48 | d3ldaa1 a.60.4.1 49 | d3p73a2 b.1.1.0 50 | d5le5a_ d.153.1.4 51 | d4yb6a2 d.58.5.0 52 | d6vana_ a.39.1.0 53 | d1hxma1 b.1.1.1 54 | d5i97a_ d.17.4.0 55 | d6ipna1 a.60.6.0 56 | d2hdma_ d.9.1.1 57 | d4ooza1 c.1.8.0 58 | d3loqa2 c.26.2.0 59 | d3b0ga3 d.58.36.0 60 | d1ub0a_ c.72.1.2 61 | d1fvka_ c.47.1.13 62 | d1yw6a1 c.56.5.7 63 | d6tbnb1 b.69.4.0 64 | d1yara_ d.153.1.4 65 | d1bvsa2 a.60.2.1 66 | d2a73a3 b.1.29.4 67 | d2jepa_ c.1.8.0 68 | d3iu5a1 a.29.2.0 69 | d6zcwa_ b.70.1.0 70 | d5d66a2 d.211.1.0 71 | d3mexa_ a.4.5.28 72 | d2d69a2 d.58.9.1 73 | d5k28a1 b.34.2.0 74 | d1zc3b_ b.55.1.1 75 | d2slia2 b.68.1.1 76 | d5dvha_ b.42.4.0 77 | d2j5ca1 a.102.4.0 78 | d1s35a1 a.7.1.1 79 | d1sddb1 b.6.1.3 80 | d4ggca_ b.69.4.0 81 | d2coha3 b.82.3.2 82 | d1ri9a_ b.34.2.1 83 | d2yyda1 d.79.4.1 84 | d1wfma1 b.7.1.2 85 | d4xo9a2 b.2.3.2 86 | d1thfd_ c.1.2.1 87 | d7diea_ a.25.1.0 88 | d5r7xa1 b.82.2.14 89 | d6fzvd1 b.23.1.0 90 | d2zkmx3 b.55.1.1 91 | d1zj9a2 d.58.36.1 92 | d6ag8a_ b.81.1.0 93 | d3uama_ b.1.18.0 94 | d1bf2a1 b.1.18.2 95 | d2osva_ c.92.2.0 96 | d5fmol1 b.121.4.2 97 | d1nkla_ a.64.1.1 98 | d2csba4 a.60.2.4 99 | d1t17a_ d.129.3.6 100 | d5n17a1 a.29.2.0 101 | d2qi2a2 c.55.4.2 102 | d1wtya_ a.24.16.2 103 | d4wu0a_ a.102.1.0 104 | d1moja_ a.25.1.1 105 | d2q7sa1 c.56.5.9 106 | d1v5ua1 b.55.1.1 107 | d2lj0a_ b.34.2.0 108 | d1yeya1 c.1.11.2 109 | d5n0ka2 b.6.1.0 110 | d6vzda_ a.64.1.0 111 | d1x43a1 b.34.2.0 112 | d3ffha1 c.67.1.0 113 | d3pxla3 b.6.1.0 114 | d1gvea_ c.1.7.1 115 | d4ac9a2 b.43.3.1 116 | d1xhna1 b.45.1.1 117 | d4grfa1 c.47.1.0 118 | d1ws8a_ b.6.1.1 119 | d5qu8a_ b.34.2.0 120 | d2d62a3 b.40.6.0 121 | d1dceb_ a.102.4.3 122 | d1tkja_ c.56.5.4 123 | d1mn3a_ a.5.2.4 124 | d1wwha1 d.58.7.1 125 | d3pzsa_ c.72.1.5 126 | d3mkva2 c.1.9.18 127 | d1ff9a2 d.81.1.2 128 | d5k0ua_ b.121.4.0 129 | d1v4ra1 a.4.5.6 130 | d3ea5b_ a.118.1.1 131 | d2hhpa3 d.58.16.1 132 | d1a6ca3 b.121.4.2 133 | d1oh1a_ b.61.2.2 134 | d1vefa1 c.67.1.4 135 | d2zkmx1 a.39.1.7 136 | d1qp8a2 c.23.12.1 137 | d2q62a_ c.23.5.4 138 | d1o97c_ c.26.2.3 139 | d3w2wa1 d.58.53.7 140 | d2gysa1 b.1.2.1 141 | d2vwsa_ c.1.12.0 142 | d1u04a4 c.55.3.15 143 | d3pnua1 c.1.9.0 144 | d3seia2 a.60.1.0 145 | d6yira2 b.40.6.0 146 | d1kwga2 c.1.8.1 147 | d1v5oa1 d.15.1.1 148 | d2qiya_ d.17.4.2 149 | d1fjgk_ c.55.4.1 150 | d2cz9a2 d.58.26.0 151 | d1ucta2 b.1.1.4 152 | d3d85d3 b.1.2.1 153 | d4pmza_ c.1.8.0 154 | d1jqba1 b.35.1.2 155 | d1cr5a1 b.52.2.3 156 | d3rgha1 b.1.18.0 157 | d3bzca4 b.40.4.5 158 | d2a73a2 b.1.29.3 159 | d3kuza1 d.15.1.0 160 | d1i2ha1 b.55.1.4 161 | d5z48a1 c.56.4.0 162 | d1u58a1 
b.1.1.2 163 | d1b74a1 c.78.2.1 164 | d1xhba1 b.42.2.1 165 | d5ig6a1 a.29.2.0 166 | d1cvja1 d.58.7.1 167 | d5hqta2 c.78.2.0 168 | d1ml9a_ b.85.7.1 169 | d7a3ha_ c.1.8.3 170 | d1tqga1 a.24.10.3 171 | d2pida_ c.26.1.0 172 | d1nbua_ d.96.1.3 173 | d2v3ga_ c.1.8.3 174 | d1x6oa2 b.40.4.5 175 | d2b1xb1 d.17.4.4 176 | d5w0ha1 d.58.7.1 177 | d3cu7a4 b.1.29.5 178 | d1jmxb_ b.69.2.2 179 | d1wjoa1 a.40.1.1 180 | d1szba2 g.3.11.1 181 | d1tbxa1 a.4.5.48 182 | d3u7qa_ c.92.2.3 183 | d2chca1 d.17.4.25 184 | d4hcha2 c.1.11.2 185 | d1wbha1 c.1.10.1 186 | d1sj1a_ d.58.1.4 187 | d2m0ra_ a.39.1.0 188 | d3v9oa_ d.96.1.0 189 | d4h60a1 c.23.1.0 190 | d3fdwa_ b.7.1.0 191 | d1ue9a1 b.34.2.1 192 | d1vcta1 a.7.12.1 193 | d1gjwa2 c.1.8.1 194 | d1y51a_ d.94.1.1 195 | d3fsta_ c.1.23.1 196 | d2hwva1 a.4.6.0 197 | d1jmca1 b.40.4.3 198 | d2pbdp_ d.110.1.1 199 | d4ggfc_ a.39.1.2 200 | d1x9fd_ a.1.1.2 201 | -------------------------------------------------------------------------------- /scripts/dataset/val_unseen_sfam.txt: -------------------------------------------------------------------------------- 1 | d1hbka_ a.11.1.1 2 | d4dvca1 c.47.1.13 3 | d3m62b_ d.15.1.0 4 | d1ueba1 b.34.5.2 5 | d3pwfa1 a.25.1.1 6 | d5vaaa2 b.1.1.0 7 | d3loia_ b.82.1.0 8 | d6wu7a_ a.39.1.0 9 | d1udla1 b.34.2.1 10 | d1xg0c_ a.1.1.3 11 | d3rhba_ c.47.1.0 12 | d2a2pa1 c.47.1.23 13 | d3l0fa_ a.1.1.3 14 | d3us3a3 c.47.1.3 15 | d3d1ka_ a.1.1.2 16 | d3tcoa_ c.47.1.0 17 | d2rqta_ b.34.2.0 18 | d4hdea1 c.47.1.0 19 | d4jrra1 c.47.1.0 20 | d1oqpa1 a.39.1.5 21 | d2m1ua1 a.39.1.0 22 | d1i1ja_ b.34.2.1 23 | d5b7xa1 a.39.1.0 24 | d3ak8a_ a.25.1.1 25 | d1r26a1 c.47.1.1 26 | d1abaa_ c.47.1.1 27 | d2egea1 b.34.2.0 28 | d2hf5a1 a.39.1.5 29 | d2bjxa_ c.47.1.2 30 | d6g62a1 c.47.1.0 31 | d2lqoa1 c.47.1.0 32 | d2hpsa_ a.39.1.0 33 | d2qamc1 b.34.5.3 34 | d4hswa_ a.1.1.2 35 | d1z6na1 c.47.1.1 36 | d2m0ya_ b.34.2.0 37 | d4cy9a1 a.25.1.1 38 | d2zs0d_ a.1.1.0 39 | d1q1fa_ a.1.1.2 40 | d1zgma1 c.47.1.5 41 | d2ekha1 b.34.2.0 42 | d2ayta2 c.47.1.0 43 | d2b5ea3 c.47.1.2 44 | d1yuza1 a.25.1.1 45 | d1i07a_ b.34.2.1 46 | d3tvta1 b.34.2.0 47 | d1fi2a_ b.82.1.2 48 | d1ng2a2 b.34.2.1 49 | d1yhfa1 b.82.1.9 50 | d3ewla1 c.47.1.0 51 | d2dl5a1 b.34.2.0 52 | d1gcvb_ a.1.1.2 53 | d4o32a1 c.47.1.0 54 | d3fw2a1 c.47.1.0 55 | d4zl8a_ c.47.1.0 56 | d2mioa1 b.34.2.0 57 | d1k94a_ a.39.1.8 58 | d4in0a_ c.47.1.0 59 | d5wsfa_ b.82.1.10 60 | d2ixka_ b.82.1.1 61 | d2cvba1 c.47.1.10 62 | d3ubca_ a.1.1.0 63 | d3zita_ c.47.1.0 64 | d1j3ta1 b.34.2.1 65 | d1n8ja_ c.47.1.10 66 | d3cmia_ c.47.1.0 67 | d4zbda1 c.47.1.0 68 | d3i5ra_ b.34.2.1 69 | d4i2ua1 c.47.1.0 70 | d1j7qa_ a.39.1.5 71 | d3hvva_ c.47.1.10 72 | d3qqqa_ a.1.1.2 73 | d2yuoa1 b.34.2.0 74 | d4ocia1 a.39.1.0 75 | d1vjxa1 a.25.1.1 76 | d5o99a1 b.34.2.0 77 | d3d1kb_ a.1.1.2 78 | d5vf5a2 b.82.1.0 79 | d2kuca1 c.47.1.0 80 | d2scpa_ a.39.1.5 81 | d4j9fa_ b.34.2.1 82 | d2j6ka_ b.34.2.0 83 | d1hlba_ a.1.1.2 84 | d7cfza_ b.34.2.0 85 | d3fiaa_ a.39.1.0 86 | d1hyua4 c.47.1.2 87 | d4iwka_ a.25.1.1 88 | d3rnja1 b.34.2.1 89 | d3b3ha_ a.25.1.0 90 | d2gkma_ a.1.1.1 91 | d1vq8q1 b.34.5.1 92 | d1tq5a1 b.82.1.12 93 | d7bzkb_ c.47.1.0 94 | d3us3a2 c.47.1.3 95 | d2ig3a_ a.1.1.0 96 | d1mbaa_ a.1.1.2 97 | d2myga1 c.47.1.0 98 | d1j9ba_ c.47.1.12 99 | d2fa8a1 c.47.1.23 100 | d1fhga_ b.1.1.4 101 | -------------------------------------------------------------------------------- /scripts/dataset/val_unseen_sid.txt: -------------------------------------------------------------------------------- 1 | d1xiwb_ b.1.1.4 2 | d2ykza_ a.24.3.2 3 | d2y8ya1 d.58.53.0 4 | d1g5ca_ 
c.53.2.1 5 | d6td7a_ a.1.1.0 6 | d1w1ha_ b.55.1.1 7 | d1v58a1 c.47.1.9 8 | d4i7ga2 c.55.3.0 9 | d2mr7a1 a.28.1.0 10 | d1a0ia1 b.40.4.6 11 | d1js1x1 c.78.1.1 12 | d1sr9a2 c.1.10.5 13 | d2je8a3 b.1.4.1 14 | d2huga2 b.34.13.2 15 | d1oxxk1 b.40.6.3 16 | d2edfa1 b.1.1.0 17 | d3doda_ c.67.1.0 18 | d1zcza2 c.97.1.4 19 | d2cyya2 d.58.4.2 20 | d4h2da_ c.23.5.0 21 | d4gdja_ b.68.1.0 22 | d4fgla_ c.23.5.3 23 | d3fwba_ a.39.1.5 24 | d5llta1 c.26.1.0 25 | d1thxa_ c.47.1.1 26 | d6r2wl3 g.3.11.1 27 | d2rfra1 d.17.4.28 28 | d3kcma1 c.47.1.0 29 | d1pa4a_ d.52.7.1 30 | d5e5ub1 b.1.1.0 31 | d1xoda1 b.55.1.4 32 | d3k6ia1 b.1.6.0 33 | d3bp8a2 c.55.1.10 34 | d3viia_ c.1.8.0 35 | d1qlsa_ a.39.1.2 36 | d2zgya1 c.55.1.1 37 | d5vf5a1 b.82.1.0 38 | d3cfya1 c.23.1.0 39 | d1yixa1 c.1.9.12 40 | d4fc3e_ b.1.28.0 41 | d3iaca_ c.1.9.0 42 | d1rtqa_ c.56.5.4 43 | d5woza1 a.118.9.0 44 | d5ik2g_ c.49.2.0 45 | d3a4ja_ c.1.9.3 46 | d6uxea1 c.67.1.3 47 | d3d22a1 c.47.1.0 48 | d4hjla2 d.129.3.3 49 | d2h88b2 a.1.2.1 50 | d3zq5a3 d.110.2.0 51 | d1guua_ a.4.1.3 52 | d2q4ha2 d.13.1.2 53 | d3eika1 d.129.1.0 54 | d3ef8a1 d.17.4.28 55 | d1oo0b_ d.58.7.1 56 | d4e9sa1 b.6.1.3 57 | d2qr3a1 c.23.1.0 58 | d1fp0a1 g.50.1.2 59 | d6ufva_ b.2.2.0 60 | d3d06a_ b.2.5.2 61 | d2disa1 d.58.7.1 62 | d2y24a2 c.1.8.3 63 | d1c8na_ b.121.4.7 64 | d3ulla_ b.40.4.3 65 | d6jqba1 c.1.8.1 66 | d1oaia_ a.5.2.3 67 | d2bz2a1 d.58.7.1 68 | d1aoaa1 a.40.1.1 69 | d5le5l_ d.153.1.4 70 | d1cb8a3 b.30.5.2 71 | d2cwpa_ b.40.4.0 72 | d6vrra1 b.1.1.0 73 | d1wura_ d.96.1.1 74 | d3ju7a_ c.67.1.0 75 | d1s5ja1 c.55.3.5 76 | d5huoa1 d.41.2.0 77 | d2vgna2 c.55.4.2 78 | d3rnsa1 b.82.1.0 79 | d1ml4a1 c.78.1.1 80 | d1jbea_ c.23.1.1 81 | d2id0a2 b.40.4.5 82 | d3cu7a7 b.1.29.8 83 | d3lmea1 d.79.1.0 84 | d1twfk_ d.74.3.2 85 | d2foka1 a.4.5.12 86 | d3q0wa1 a.4.1.0 87 | d4phja_ a.39.1.8 88 | d4wp9a1 d.58.29.0 89 | d1nzaa_ d.58.5.2 90 | d1ug3a2 a.118.1.14 91 | d5yuoa_ b.40.4.3 92 | d1lqaa_ c.1.7.1 93 | d1vr3a1 b.82.1.6 94 | d1rl2a2 b.40.4.5 95 | d3ul4a_ b.2.2.0 96 | d2adza1 b.55.1.1 97 | d1xppa_ d.74.3.2 98 | d1k3la2 c.47.1.5 99 | d1f5ma_ d.110.2.1 100 | d2a7ba_ b.1.10.2 101 | d1ueba2 b.40.4.5 102 | d2fg9a1 b.45.1.1 103 | d2v0ha2 b.81.1.0 104 | d2plca_ c.1.18.2 105 | d2fq4a1 a.4.1.9 106 | d1fgya_ b.55.1.1 107 | d3cj1a2 c.55.1.15 108 | d6jixa_ c.67.1.0 109 | d1t0ia_ c.23.5.4 110 | d1yqha1 d.58.48.1 111 | d1kkha2 d.58.26.3 112 | d1aoza3 b.6.1.3 113 | d3eska1 a.118.8.1 114 | d4hjlb_ d.17.4.4 115 | d1fyhb2 b.1.2.1 116 | d6zshb_ a.40.1.0 117 | d6tt2a_ b.55.1.1 118 | d1btna_ b.55.1.1 119 | d1v7va1 a.102.1.4 120 | d1m5ta_ c.23.1.1 121 | d5wtpa1 d.79.7.0 122 | d2cxya_ a.4.3.1 123 | d3rtla1 b.1.28.0 124 | d2kp2a1 c.47.1.2 125 | d1d2pa1 b.3.5.1 126 | d3w9sa_ c.23.1.0 127 | d1g6za1 b.34.13.2 128 | d3gdla_ c.1.2.3 129 | d6mgla2 b.69.13.0 130 | d1zuya_ b.34.2.0 131 | d2wsda3 b.6.1.0 132 | d1y55x1 b.61.1.1 133 | d1r3ra_ c.1.22.1 134 | d2illa1 b.1.1.0 135 | d1nq4a_ a.28.1.1 136 | d3ukna_ b.82.3.0 137 | d3pfia2 a.4.5.0 138 | d1uufa1 b.35.1.2 139 | d1hzfa_ a.102.4.4 140 | d2lsta1 c.47.1.0 141 | d2z08a_ c.26.2.4 142 | d1st6a5 a.24.9.1 143 | d2q4za_ c.56.5.7 144 | d2nlia_ c.1.4.0 145 | d3f0ha1 c.67.1.0 146 | d2e0qa_ c.47.1.0 147 | d2fm8a1 d.198.1.1 148 | d1x36a_ b.121.4.7 149 | d1xeaa2 d.81.1.5 150 | d1yoaa1 b.45.1.2 151 | d4f7uf_ b.38.1.0 152 | d2f02a2 c.72.1.0 153 | d1ka9h_ c.23.16.1 154 | d1wp5a1 b.68.10.1 155 | d1z2la1 c.56.5.4 156 | d4jhta_ b.82.2.10 157 | d6d8ha_ g.3.7.2 158 | d4rqra1 c.47.1.1 159 | d5le5h_ d.153.1.4 160 | d4ybna_ b.45.1.0 161 | d1dxya2 c.23.12.1 162 | d1xwva_ b.1.18.7 163 | d2e6qa1 
b.1.1.0 164 | d3osga2 a.4.1.0 165 | d5gpga1 d.26.1.1 166 | d6s8za1 b.34.5.0 167 | d2b1ua_ a.39.1.0 168 | d1z3xa1 a.118.1.22 169 | d2kgja1 d.26.1.0 170 | d2da1a1 a.4.1.0 171 | d6xhha_ d.110.2.0 172 | d2zada2 c.1.11.0 173 | d1neua_ b.1.1.1 174 | d1vlia2 c.1.10.6 175 | d1u6za3 c.55.1.8 176 | d1wmha_ d.15.2.2 177 | d1nmra1 a.144.1.1 178 | d3hzha_ c.23.1.0 179 | d4dh2a1 b.2.2.0 180 | d1gxua_ d.58.10.1 181 | d1j5pa3 d.81.1.3 182 | d2id6a1 a.4.1.9 183 | d1p9la2 d.81.1.3 184 | d2dcla_ d.58.5.0 185 | d1vcva1 c.1.10.1 186 | d3vzha1 d.58.53.2 187 | d1tdja3 d.58.18.2 188 | d6nsia_ c.92.2.0 189 | d1kgsa1 a.4.6.1 190 | d3l4na_ c.47.1.0 191 | d2byea1 d.15.1.5 192 | d1qwra_ b.82.1.3 193 | d2apja1 c.23.10.7 194 | d2crga1 a.4.1.3 195 | d6uy1a1 a.29.2.0 196 | d3u7qb_ c.92.2.3 197 | d2iuwa1 b.82.2.10 198 | d5kmya1 c.1.2.0 199 | d3eula_ c.23.1.0 200 | d6hnma_ d.17.4.0 201 | -------------------------------------------------------------------------------- /scripts/faiss_index.py: -------------------------------------------------------------------------------- 1 | # Generate FAISS index from embeddings 2 | 3 | import torch 4 | import faiss 5 | 6 | d = torch.load("afted.pt") 7 | embs = d["embeddings"].numpy() 8 | 9 | index = faiss.IndexFlatIP(128) 10 | index.add(embs) 11 | print(index.ntotal) 12 | 13 | D, I = index.search(embs[:32], 100) # Example of searching 14 | 15 | faiss.write_index(index, "afted.index") 16 | -------------------------------------------------------------------------------- /scripts/other_methods/3dsurfer/run_nn_allatom.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/scripts/other_methods/3dsurfer/run_nn_allatom.txt.gz -------------------------------------------------------------------------------- /scripts/other_methods/3dsurfer/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | # Run `gunzip run_nn_allatom.txt.gz` first 3 | 4 | import numpy as np 5 | from collections import defaultdict 6 | 7 | results_file = "run_nn_allatom.txt" 8 | dataset_dir = "../../dataset" 9 | known_problems = ["d6a1ia1", "d1qdma1", "d1jqga2", "d1ayea2", "d1kwma2", "d1ldtl_"] 10 | 11 | domid_to_fam = {} 12 | fams_search = [] 13 | with open(dataset_dir + "/search_all.txt") as f: 14 | for line in f.readlines(): 15 | cols = line.rstrip().split() 16 | domid, fam = cols[0], cols[1] 17 | if domid not in known_problems: 18 | domid_to_fam[domid] = fam 19 | fams_search.append(fam) 20 | 21 | domid_to_nres, domid_to_contact_order = {}, {} 22 | with open("../contact_order.txt") as f: 23 | for line in f.readlines(): 24 | cols = line.rstrip().split() 25 | domid, nres, contact_order = cols[0], cols[1], cols[2] 26 | if domid not in known_problems: 27 | domid_to_nres[domid] = int(nres) 28 | domid_to_contact_order[domid] = float(contact_order) 29 | 30 | domid_to_fam_matches = defaultdict(list) 31 | with open(results_file) as f: 32 | for line in f.readlines(): 33 | cols = line.split() 34 | domid_query = cols[0] 35 | domid_match = cols[1] 36 | if domid_query in domid_to_fam and domid_match in domid_to_fam: 37 | fam_match = domid_to_fam[domid_match] 38 | domid_to_fam_matches[domid_query].append(fam_match) 39 | 40 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 41 | domids_unseen = [l.split()[0] for l in f.readlines()] 42 | 43 | def scop_fold(fam): 44 | return fam.rsplit(".", 2)[0] 45 | 46 | def 
scop_sfam(fam): 47 | return fam.rsplit(".", 1)[0] 48 | 49 | def searching_accuracy(domid_query, fam_matches): 50 | fam_query = domid_to_fam[domid_query] 51 | sfam_query = scop_sfam(fam_query) 52 | fold_query = scop_fold(fam_query) 53 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 54 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 55 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 56 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 57 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 58 | 59 | for fi, fam_match in enumerate(fam_matches): 60 | sfam_match = scop_sfam(fam_match) 61 | fold_match = scop_fold(fam_match) 62 | # Don't count self match for top 1 accuracy 63 | if fi == 1: 64 | if fam_match == fam_query: 65 | top1_fam = 1 66 | if sfam_match == sfam_query: 67 | top1_sfam = 1 68 | if fold_match == fold_query: 69 | top1_fold = 1 70 | if 0 < fi < 6 and fam_match == fam_query: 71 | top5_fam = 1 72 | if fam_match == fam_query: 73 | count_tp_fam += 1 74 | elif sfam_match == sfam_query: 75 | count_tp_sfam += 1 76 | elif fold_match == fold_query: 77 | count_tp_fold += 1 78 | else: 79 | break 80 | 81 | sens_fam = count_tp_fam / total_fam 82 | sens_sfam = count_tp_sfam / total_sfam 83 | sens_fold = count_tp_fold / total_fold 84 | 85 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 86 | 87 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 88 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 89 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 90 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 91 | 92 | for di, domid in enumerate(domids_unseen): 93 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, domid_to_fam_matches[domid]) 94 | 95 | senses_fam.append(sens_fam) 96 | senses_sfam.append(sens_sfam) 97 | senses_fold.append(sens_fold) 98 | top1s_fam.append(top1_fam) 99 | top1s_sfam.append(top1_sfam) 100 | top1s_fold.append(top1_fold) 101 | top5s_fam.append(top5_fam) 102 | 103 | cl = domid_to_fam[domid][0] 104 | senses_fam_class[cl].append(sens_fam) 105 | senses_sfam_class[cl].append(sens_sfam) 106 | senses_fold_class[cl].append(sens_fold) 107 | 108 | nres = domid_to_nres[domid] 109 | if 0 <= nres < 100: 110 | nres_group = "0-99" 111 | elif 100 <= nres < 200: 112 | nres_group = "100-199" 113 | elif 200 <= nres < 300: 114 | nres_group = "200-299" 115 | else: 116 | nres_group = "300+" 117 | 118 | senses_fam_nres[nres_group].append(sens_fam) 119 | senses_sfam_nres[nres_group].append(sens_sfam) 120 | senses_fold_nres[nres_group].append(sens_fold) 121 | 122 | co = domid_to_contact_order[domid] 123 | if 0 <= co < 0.075: 124 | co_group = "0-0.075" 125 | elif 0.075 <= co < 0.125: 126 | co_group = "0.075-0.125" 127 | elif 0.125 <= co < 0.175: 128 | co_group = "0.125-0.175" 129 | else: 130 | co_group = "0.175+" 131 | 132 | senses_fam_co[co_group].append(sens_fam) 133 | senses_sfam_co[co_group].append(sens_sfam) 134 | senses_fold_co[co_group].append(sens_fold) 135 | 136 | print(np.mean(senses_fam)) 137 | print(np.mean(senses_sfam)) 138 | print(np.mean(senses_fold)) 139 | print() 140 | 141 | for cl in 
sorted(list(senses_fam_class.keys())): 142 | print(cl, "class") 143 | print(np.mean(senses_fam_class[cl])) 144 | print(np.mean(senses_sfam_class[cl])) 145 | print(np.mean(senses_fold_class[cl])) 146 | print() 147 | 148 | for nres_group in sorted(list(senses_fam_nres.keys())): 149 | print(nres_group, "nres") 150 | print(np.mean(senses_fam_nres[nres_group])) 151 | print(np.mean(senses_sfam_nres[nres_group])) 152 | print(np.mean(senses_fold_nres[nres_group])) 153 | print() 154 | 155 | for co_group in sorted(list(senses_fam_co.keys())): 156 | print(co_group, "contact_order") 157 | print(np.mean(senses_fam_co[co_group])) 158 | print(np.mean(senses_sfam_co[co_group])) 159 | print(np.mean(senses_fold_co[co_group])) 160 | print() 161 | -------------------------------------------------------------------------------- /scripts/other_methods/3dsurfer/searching.txt: -------------------------------------------------------------------------------- 1 | 0.3494996223427079 2 | 0.1403989168196148 3 | 0.04583099252270184 4 | 5 | a class 6 | 0.2624496149211868 7 | 0.0847521088566503 8 | 0.05200595789183216 9 | 10 | b class 11 | 0.3805873585445227 12 | 0.13172877670346617 13 | 0.0629701658565192 14 | 15 | c class 16 | 0.3577589269973005 17 | 0.22541060027745288 18 | 0.04415639763666987 19 | 20 | d class 21 | 0.3610731438324318 22 | 0.07736511972986493 23 | 0.011583818548069492 24 | 25 | f class 26 | 0.4386904761904762 27 | 0.11904761904761904 28 | 0.06918505942275043 29 | 30 | g class 31 | 0.21564625850340133 32 | 0.0 33 | 0.0 34 | 35 | 0-99 nres 36 | 0.2577241513964981 37 | 0.0668585691824269 38 | 0.029858329225067967 39 | 40 | 100-199 nres 41 | 0.3322507961171768 42 | 0.11848044659990008 43 | 0.04760817635202358 44 | 45 | 200-299 nres 46 | 0.4343640352942583 47 | 0.18534554416514457 48 | 0.056646748038103184 49 | 50 | 300+ nres 51 | 0.46175419674531315 52 | 0.27743027030319306 53 | 0.055141106954239086 54 | 55 | 0-0.075 contact_order 56 | 0.2486424720759674 57 | 0.05342698882122208 58 | 0.032883335302690136 59 | 60 | 0.075-0.125 contact_order 61 | 0.36025509190940447 62 | 0.19215549745547408 63 | 0.050194155892636226 64 | 65 | 0.125-0.175 contact_order 66 | 0.35498245148418156 67 | 0.09365837587603372 68 | 0.02661522405012336 69 | 70 | 0.175+ contact_order 71 | 0.35801390648292714 72 | 0.14159817218153928 73 | 0.07338759010450258 74 | 75 | -------------------------------------------------------------------------------- /scripts/other_methods/README.md: -------------------------------------------------------------------------------- 1 | Scripts for benchmarking other methods. 2 | 3 | Set up the `pdbstyle-2.08` directory of PDB files with something like: 4 | ```bash 5 | wget https://scop.berkeley.edu/downloads/pdbstyle/pdbstyle-sel-gs-bib-40-2.08.tgz 6 | tar -xvf pdbstyle-sel-gs-bib-40-2.08.tgz 7 | cd pdbstyle-2.08 8 | mv */*.ent . 9 | rmdir ?? 10 | cd .. 
11 | mv pdbstyle-2.08 pdbstyle-2.08_models 12 | mkdir pdbstyle-2.08 13 | julia extract_model.jl 14 | rm -r pdbstyle-2.08_models pdbstyle-sel-gs-bib-40-2.08.tgz 15 | ``` 16 | -------------------------------------------------------------------------------- /scripts/other_methods/dali/domids_imported.txt: -------------------------------------------------------------------------------- 1 | d3bb6a1 2 | d1vqta1 3 | d3nara_ 4 | d3euca_ 5 | d3aowa_ 6 | d3g46a_ 7 | d6i9sa2 8 | d6rw7a1 9 | d3eyxa_ 10 | d4uypa1 11 | d4rz4a1 12 | d1knva_ 13 | d2h6ra_ 14 | d4ipca1 15 | d6r3za_ 16 | d4n4ee_ 17 | d3neva_ 18 | d1l3wa3 19 | d2iw0a1 20 | d2g7ga1 21 | d1p5ub_ 22 | d3b1qa_ 23 | d5i8da3 24 | d1cb8a2 25 | d2v82a1 26 | d1q15a1 27 | d1wuza_ 28 | d2hs5a1 29 | d3zq7a_ 30 | d2abwa1 31 | d1x1ia3 32 | d4ay0a2 33 | d4j25a1 34 | d6bmaa1 35 | d1y4wa2 36 | d1lc5a_ 37 | d3cwna_ 38 | d3pmea2 39 | d1kgsa2 40 | d3djca1 41 | d5jgya_ 42 | d2xpwa1 43 | d1p0za1 44 | d2kiva2 45 | d6cw0a_ 46 | d1ta3a_ 47 | d2dy8a1 48 | d2vnud2 49 | d7dhfa1 50 | d5vrka_ 51 | d4cr2r1 52 | d2d69a1 53 | d2p3ra3 54 | d2wy4a_ 55 | d3dxea_ 56 | d1wela1 57 | d1t9ha1 58 | d3grfa2 59 | d6n0ka_ 60 | d1enfa1 61 | d1u5ha_ 62 | d1x3da1 63 | d3hhtb_ 64 | d2ghpa2 65 | d1qzya1 66 | d5h20a_ 67 | d4e70a1 68 | d4xria_ 69 | d1k61a_ 70 | d2csba1 71 | d1knla_ 72 | d5y5qa_ 73 | d2iiza_ 74 | d3heba1 75 | d2m6sa_ 76 | d2yrza1 77 | d1xyza_ 78 | d6nqia_ 79 | d2p39a_ 80 | d1woca_ 81 | d2bvya1 82 | d7coia_ 83 | d1fp3a_ 84 | d4r3va_ 85 | d3vrna2 86 | d1q16a1 87 | d2v5ca1 88 | d4g9ya_ 89 | d1i5pa2 90 | d2kyra_ 91 | d4q9ba_ 92 | d2pn6a2 93 | d1d5ra1 94 | d5wida_ 95 | d1ty0a2 96 | d2hjsa2 97 | d5afwa2 98 | d1yjra1 99 | d2xfwa1 100 | d1yn8a1 101 | d2wp4a_ 102 | d1bvsa3 103 | d2dkua1 104 | d1sqja2 105 | d2nwha_ 106 | d2p5ka_ 107 | d3ea6a2 108 | d1ga2a2 109 | d2jk3a1 110 | d1w32a_ 111 | d1otja_ 112 | d2znra1 113 | d4uqva_ 114 | d4xmra1 115 | d4r82a_ 116 | d2cm5a_ 117 | d2f8aa1 118 | d1w99a1 119 | d2h3na_ 120 | d1szna2 121 | d1rq5a2 122 | d1d4ba1 123 | d1rlla1 124 | d1o12a2 125 | d2dlta1 126 | d2xvla1 127 | d1b9la_ 128 | d3anua2 129 | d3df8a1 130 | d2gu3a1 131 | d3dbxa2 132 | d1w5fa2 133 | d6jt6a_ 134 | d1s05a_ 135 | d6luja1 136 | d1deca_ 137 | d6evla_ 138 | d6wsha1 139 | d3or5a_ 140 | d5z0qa_ 141 | d3keba_ 142 | d2dc3a_ 143 | d3n6wa2 144 | d2nv0a_ 145 | d3hd5a1 146 | d2gkga_ 147 | d1mc0a2 148 | d1gp6a_ 149 | d4f0ba1 150 | d1t3ga_ 151 | d4ieua_ 152 | d6q6na2 153 | d2jgqa_ 154 | d1v86a1 155 | d1pkoa1 156 | d3tsma_ 157 | d1vjpa4 158 | d1x46a_ 159 | d1n52a3 160 | d1tfia_ 161 | d3bn3b1 162 | d5i4da1 163 | d1uz5a1 164 | d1x48a1 165 | d2dnaa1 166 | d1bhea_ 167 | d7jgwa1 168 | d2dy1a5 169 | d1o1ya1 170 | d3en8a1 171 | d1y3ta1 172 | d2a1ka_ 173 | d2bpa2_ 174 | d3cu7a1 175 | d2csba3 176 | d3c5ra2 177 | d1elwa_ 178 | d1reqa1 179 | d1h97a_ 180 | d1v7va2 181 | d2bgca1 182 | d1ebua2 183 | d3ma2b_ 184 | d1qjta_ 185 | d3eu9a1 186 | d5ycza_ 187 | d2hoea1 188 | d1vkra_ 189 | d4ay7a_ 190 | d1vlpa1 191 | d1j8mf1 192 | d1fhoa_ 193 | d2i61a_ 194 | d5xopa1 195 | d2axwa1 196 | d2zaya_ 197 | d1a8da2 198 | d2frha2 199 | d4xdaa_ 200 | d1y1xa_ 201 | d3ri6a_ 202 | d2fe3a_ 203 | d2qcva_ 204 | d7m3xa_ 205 | d1zzma1 206 | d1dhna_ 207 | d6r2na2 208 | d1khca_ 209 | d6blga_ 210 | d1m45a_ 211 | d1kzla2 212 | d1mvha_ 213 | d3h3ba1 214 | d2rnja_ 215 | d4bhxa1 216 | d1ckma1 217 | d5svva_ 218 | d3jb9u3 219 | d4yvda_ 220 | d3krua_ 221 | d1bdoa_ 222 | d2kspa1 223 | d1ofda3 224 | d2cjja1 225 | d3ppea2 226 | d1kqka_ 227 | d2wfba1 228 | d1xuva1 229 | d2coca1 230 | d3gt5a1 231 | d2p9ra_ 232 | d3bexa1 233 | d3pnza_ 234 | d1bw0a_ 
235 | d2xfaa_ 236 | d4ch0s1 237 | d2w2db1 238 | d4e9sa3 239 | d1x4ea1 240 | d2q5we1 241 | d2d6ya1 242 | d5y9ea_ 243 | d1y0ga_ 244 | d3c8ca1 245 | d2fexa1 246 | d1geqa_ 247 | d2q9oa1 248 | d1smoa_ 249 | d2naca2 250 | d4unua_ 251 | d3lx3a_ 252 | d2iyga_ 253 | d3kcca1 254 | d2a28a1 255 | d2a1xa1 256 | d2ckxa1 257 | d1gqia2 258 | d5tkma_ 259 | d3i33a2 260 | d1wgga1 261 | d3o0ma1 262 | d1bvyf_ 263 | d2zfga_ 264 | d1t33a1 265 | d3gg7a1 266 | d3d01a1 267 | d1fi6a_ 268 | d2ih3c_ 269 | d1w9aa_ 270 | d2f42a1 271 | d3fiua_ 272 | d3w2wa3 273 | d6mwdb1 274 | d1ng0a_ 275 | d1n0za1 276 | d3b0fa_ 277 | d3jtea1 278 | d3na8a_ 279 | d2hxva2 280 | d5u1ma_ 281 | d1dbha2 282 | d2bs2b2 283 | d1wu7a1 284 | d6t40a_ 285 | d2qqra2 286 | d2ra2a1 287 | d1ty0a1 288 | d3gkxa_ 289 | d1x1ia1 290 | d4l3ma_ 291 | d1bf2a3 292 | d2il5a1 293 | d1jvna2 294 | d2jbaa_ 295 | d5wl1a2 296 | d1m9sa3 297 | d1lvka1 298 | d5fjqa_ 299 | d5cm7a1 300 | d2z48a1 301 | d3cu7a8 302 | d5dl5a1 303 | d1x9na2 304 | d1kqfb1 305 | d2pe8a1 306 | d1wjka1 307 | d3ceda1 308 | d1wisa1 309 | d1vyra_ 310 | d1rmga_ 311 | d6wk3a_ 312 | d1pk3a1 313 | d6gf1a1 314 | d2q3qa2 315 | d3n0pb1 316 | d1uw4b_ 317 | d1u5da1 318 | d6onna1 319 | d2edna1 320 | d1gr0a2 321 | d1ejea_ 322 | d6jc0b_ 323 | d4wnya1 324 | d2f3la1 325 | d4ohja1 326 | d2imla1 327 | d6fd1a_ 328 | d3r1za2 329 | d1sixa1 330 | d1clca1 331 | d2jnfa_ 332 | d3bb9a1 333 | d2hr7a2 334 | d2axla1 335 | d2ayta1 336 | d1ix9a1 337 | d3fsoa1 338 | d4dpoa1 339 | d2fdbm1 340 | d7ahsa2 341 | d2rsda1 342 | d2uy6b1 343 | d1rypd_ 344 | d4mzya2 345 | d4kjfa_ 346 | d2wu2c_ 347 | d3ibva_ 348 | d5ysca_ 349 | d2z61a_ 350 | d2nw2a1 351 | d2nlza1 352 | d1tiga_ 353 | d1kl9a2 354 | d2c1ia1 355 | d3hcza1 356 | d5gxxa2 357 | d3hpea_ 358 | d3md3a1 359 | d1jmrb4 360 | d3tqva1 361 | d1kfwa1 362 | d2bz7a_ 363 | d4cvqa_ 364 | d1v3wa_ 365 | d1gdha2 366 | d3frxa_ 367 | d1i7qb_ 368 | d4u5ia_ 369 | d2vyoa_ 370 | d1xxxa1 371 | d2fbwc_ 372 | d6hbsa_ 373 | d4m5ra1 374 | d1pbyb_ 375 | d5uida_ 376 | d2nq3a1 377 | d1xkua_ 378 | d4c4ko1 379 | d3zdmc_ 380 | d1vq8b1 381 | d1rd5a_ 382 | d2cuaa_ 383 | d1usca_ 384 | d4jw0a_ 385 | d1ug2a1 386 | d1m3ya2 387 | d3mgka_ 388 | d2m87a_ 389 | d5bnza1 390 | d1b9ha_ 391 | d3czba_ 392 | d3bbba_ 393 | d1lw7a1 394 | d2fjca1 395 | d2mska_ 396 | d3eura1 397 | d1p4ca_ 398 | d6xr5a2 399 | -------------------------------------------------------------------------------- /scripts/other_methods/dali/import.sh: -------------------------------------------------------------------------------- 1 | mkdir imported 2 | while read domid_pdbid; do 3 | domid=$(echo $domid_pdbid | cut -f1 -d " ") 4 | pdbid=$(echo $domid_pdbid | cut -f2 -d " ") 5 | echo $domid 6 | perl ~/soft/DaliLite.v5/bin/import.pl --pdbfile ../pdbstyle-2.08/$domid.ent --pdbid mol1 --dat ./ --clean 7 | mv mol1?.dat imported/${pdbid}A.dat 8 | done < domids_pdbids.txt 9 | -------------------------------------------------------------------------------- /scripts/other_methods/dali/run.sh: -------------------------------------------------------------------------------- 1 | mkdir run 2 | while read domid; do 3 | echo $domid 4 | mkdir $domid 5 | cd $domid 6 | perl ~/soft/DaliLite.v5/bin/dali.pl --pdbfile1 ../../pdbstyle-2.08/$domid.ent --db ../pdbids_imported.txt --dat1 ./ --dat2 ../imported --clean 7 | mv mol1?.txt ../run/$domid.txt 8 | cd .. 
9 | rm -r $domid 10 | done < domids_imported.txt 11 | -------------------------------------------------------------------------------- /scripts/other_methods/dali/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import numpy as np 4 | import os 5 | from collections import defaultdict 6 | 7 | dataset_dir = "../../dataset" 8 | known_problems = ["d6a1ia1"] 9 | 10 | domid_to_fam = {} 11 | fams_search = [] 12 | with open(dataset_dir + "/search_all.txt") as f: 13 | for line in f.readlines(): 14 | cols = line.rstrip().split() 15 | domid, fam = cols[0], cols[1] 16 | if domid not in known_problems: 17 | domid_to_fam[domid] = fam 18 | fams_search.append(fam) 19 | 20 | domid_to_nres, domid_to_contact_order = {}, {} 21 | with open("../contact_order.txt") as f: 22 | for line in f.readlines(): 23 | cols = line.rstrip().split() 24 | domid, nres, contact_order = cols[0], cols[1], cols[2] 25 | if domid not in known_problems: 26 | domid_to_nres[domid] = int(nres) 27 | domid_to_contact_order[domid] = float(contact_order) 28 | 29 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 30 | domids_unseen = [l.split()[0] for l in f.readlines()] 31 | 32 | pdbid_to_domid = {} 33 | with open("domids_pdbids.txt") as f: 34 | for line in f.readlines(): 35 | domid, pdbid = line.rstrip().split() 36 | pdbid_to_domid[pdbid] = domid 37 | 38 | domid_to_fam_matches = defaultdict(list) 39 | for domid_query in domids_unseen: 40 | fp = f"run/{domid_query}.txt" 41 | if domid_query in domid_to_fam and os.path.isfile(fp): 42 | with open(fp) as f: 43 | for line in f.readlines(): 44 | if len(line.rstrip()) > 0 and not line.startswith("#"): 45 | pdbid = line.strip().split()[1].split("-")[0] 46 | domid_match = pdbid_to_domid[pdbid] 47 | if domid_match in domid_to_fam: 48 | fam_match = domid_to_fam[domid_match] 49 | domid_to_fam_matches[domid_query].append(fam_match) 50 | 51 | def scop_fold(fam): 52 | return fam.rsplit(".", 2)[0] 53 | 54 | def scop_sfam(fam): 55 | return fam.rsplit(".", 1)[0] 56 | 57 | def searching_accuracy(domid_query, fam_matches): 58 | fam_query = domid_to_fam[domid_query] 59 | sfam_query = scop_sfam(fam_query) 60 | fold_query = scop_fold(fam_query) 61 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 62 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 63 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 64 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 65 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 66 | 67 | for fi, fam_match in enumerate(fam_matches): 68 | sfam_match = scop_sfam(fam_match) 69 | fold_match = scop_fold(fam_match) 70 | # Don't count self match for top 1 accuracy 71 | if fi == 1: 72 | if fam_match == fam_query: 73 | top1_fam = 1 74 | if sfam_match == sfam_query: 75 | top1_sfam = 1 76 | if fold_match == fold_query: 77 | top1_fold = 1 78 | if 0 < fi < 6 and fam_match == fam_query: 79 | top5_fam = 1 80 | if fam_match == fam_query: 81 | count_tp_fam += 1 82 | elif sfam_match == sfam_query: 83 | count_tp_sfam += 1 84 | elif fold_match == fold_query: 85 | count_tp_fold += 1 86 | else: 87 | break 88 | 89 | sens_fam = count_tp_fam / total_fam 90 | sens_sfam = count_tp_sfam / total_sfam 91 | sens_fold = count_tp_fold / total_fold 92 | 93 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 94 | 95 | senses_fam, senses_sfam, 
senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 96 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 97 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 98 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 99 | 100 | for di, domid in enumerate(domids_unseen): 101 | if domid not in domid_to_fam_matches: 102 | continue 103 | 104 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, domid_to_fam_matches[domid]) 105 | 106 | senses_fam.append(sens_fam) 107 | senses_sfam.append(sens_sfam) 108 | senses_fold.append(sens_fold) 109 | top1s_fam.append(top1_fam) 110 | top1s_sfam.append(top1_sfam) 111 | top1s_fold.append(top1_fold) 112 | top5s_fam.append(top5_fam) 113 | 114 | cl = domid_to_fam[domid][0] 115 | senses_fam_class[cl].append(sens_fam) 116 | senses_sfam_class[cl].append(sens_sfam) 117 | senses_fold_class[cl].append(sens_fold) 118 | 119 | nres = domid_to_nres[domid] 120 | if 0 <= nres < 100: 121 | nres_group = "0-99" 122 | elif 100 <= nres < 200: 123 | nres_group = "100-199" 124 | elif 200 <= nres < 300: 125 | nres_group = "200-299" 126 | else: 127 | nres_group = "300+" 128 | 129 | senses_fam_nres[nres_group].append(sens_fam) 130 | senses_sfam_nres[nres_group].append(sens_sfam) 131 | senses_fold_nres[nres_group].append(sens_fold) 132 | 133 | co = domid_to_contact_order[domid] 134 | if 0 <= co < 0.075: 135 | co_group = "0-0.075" 136 | elif 0.075 <= co < 0.125: 137 | co_group = "0.075-0.125" 138 | elif 0.125 <= co < 0.175: 139 | co_group = "0.125-0.175" 140 | else: 141 | co_group = "0.175+" 142 | 143 | senses_fam_co[co_group].append(sens_fam) 144 | senses_sfam_co[co_group].append(sens_sfam) 145 | senses_fold_co[co_group].append(sens_fold) 146 | 147 | print(np.mean(senses_fam)) 148 | print(np.mean(senses_sfam)) 149 | print(np.mean(senses_fold)) 150 | print() 151 | 152 | for cl in sorted(list(senses_fam_class.keys())): 153 | print(cl, "class") 154 | print(np.mean(senses_fam_class[cl])) 155 | print(np.mean(senses_sfam_class[cl])) 156 | print(np.mean(senses_fold_class[cl])) 157 | print() 158 | 159 | for nres_group in sorted(list(senses_fam_nres.keys())): 160 | print(nres_group, "nres") 161 | print(np.mean(senses_fam_nres[nres_group])) 162 | print(np.mean(senses_sfam_nres[nres_group])) 163 | print(np.mean(senses_fold_nres[nres_group])) 164 | print() 165 | 166 | for co_group in sorted(list(senses_fam_co.keys())): 167 | print(co_group, "contact_order") 168 | print(np.mean(senses_fam_co[co_group])) 169 | print(np.mean(senses_sfam_co[co_group])) 170 | print(np.mean(senses_fold_co[co_group])) 171 | print() 172 | -------------------------------------------------------------------------------- /scripts/other_methods/dali/searching.txt: -------------------------------------------------------------------------------- 1 | 0.8845702523451638 2 | 0.7090848588018847 3 | 0.16814762755821105 4 | 5 | a class 6 | 0.8041048157760115 7 | 0.4778730511737781 8 | 0.08868434333527841 9 | 10 | b class 11 | 0.8787218609110097 12 | 0.7302754073044503 13 | 0.23786066983373672 14 | 15 | c class 16 | 0.9412446902540419 17 | 0.8028941990581118 18 | 0.17338668114394687 19 | 20 | d class 21 | 0.9120338553775086 22 | 0.7623486187990098 23 | 0.11605653050026907 24 | 25 | f class 26 | 0.8497023809523809 27 | 0.7083333333333334 28 | 0.14388794567062818 29 | 30 | g 
class 31 | 0.47328042328042325 32 | 0.2303113553113553 33 | 0.0 34 | 35 | 0-99 nres 36 | 0.7661906980962084 37 | 0.4940065916843778 38 | 0.0829641346712276 39 | 40 | 100-199 nres 41 | 0.8988886101524557 42 | 0.7503392028935789 43 | 0.1485664543361721 44 | 45 | 200-299 nres 46 | 0.9297315478336284 47 | 0.7819632804978676 48 | 0.25965694758354485 49 | 50 | 300+ nres 51 | 0.981284938668682 52 | 0.84875396326273 53 | 0.26740263831078676 54 | 55 | 0-0.075 contact_order 56 | 0.7462564069700616 57 | 0.5446318839907001 58 | 0.1613996812050454 59 | 60 | 0.075-0.125 contact_order 61 | 0.932075820395151 62 | 0.7518390111305397 63 | 0.1772013976841998 64 | 65 | 0.125-0.175 contact_order 66 | 0.8913457498717022 67 | 0.7122520692433876 68 | 0.1962162161612364 69 | 70 | 0.175+ contact_order 71 | 0.8266175916877491 72 | 0.6780919452399018 73 | 0.10470910258687956 74 | 75 | -------------------------------------------------------------------------------- /scripts/other_methods/eat/run.sh: -------------------------------------------------------------------------------- 1 | # May need to use the fix to eat.py in https://github.com/Rostlab/EAT/issues/7 2 | python ~/soft/EAT/eat.py --lookup scope40.fasta --queries scope40.fasta --output eat_results/ --use_tucker 1 --num_NN 15000 3 | -------------------------------------------------------------------------------- /scripts/other_methods/eat/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | results_dir = "eat_results" 7 | dataset_dir = "../../dataset" 8 | 9 | domid_to_fam = {} 10 | fams_search = [] 11 | with open(dataset_dir + "/search_all.txt") as f: 12 | for line in f.readlines(): 13 | cols = line.rstrip().split() 14 | fam = cols[1] 15 | domid_to_fam[cols[0]] = fam 16 | fams_search.append(fam) 17 | 18 | domid_to_nres, domid_to_contact_order = {}, {} 19 | with open("../contact_order.txt") as f: 20 | for line in f.readlines(): 21 | cols = line.rstrip().split() 22 | domid, nres, contact_order = cols[0], cols[1], cols[2] 23 | domid_to_nres[domid] = int(nres) 24 | domid_to_contact_order[domid] = float(contact_order) 25 | 26 | domid_to_fam_matches = defaultdict(list) 27 | with open(results_dir + "/eat_result.txt") as f: 28 | for l in f.readlines(): 29 | if not l.startswith("Query-ID"): 30 | cols = l.split() 31 | domid_query, domid_match = cols[0], cols[2] 32 | if domid_query in domid_to_fam and domid_match in domid_to_fam: 33 | domid_to_fam_matches[domid_query].append(domid_to_fam[domid_match]) 34 | 35 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 36 | domids_unseen = [l.split()[0] for l in f.readlines()] 37 | 38 | def scop_fold(fam): 39 | return fam.rsplit(".", 2)[0] 40 | 41 | def scop_sfam(fam): 42 | return fam.rsplit(".", 1)[0] 43 | 44 | def searching_accuracy(domid_query, fam_matches): 45 | fam_query = domid_to_fam[domid_query] 46 | sfam_query = scop_sfam(fam_query) 47 | fold_query = scop_fold(fam_query) 48 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 49 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 50 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 51 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 52 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 53 | 54 | for fi, fam_match in enumerate(fam_matches): 55 | sfam_match = scop_sfam(fam_match) 56
| fold_match = scop_fold(fam_match) 57 | # Don't count self match for top 1 accuracy 58 | if fi == 1: 59 | if fam_match == fam_query: 60 | top1_fam = 1 61 | if sfam_match == sfam_query: 62 | top1_sfam = 1 63 | if fold_match == fold_query: 64 | top1_fold = 1 65 | if 0 < fi < 6 and fam_match == fam_query: 66 | top5_fam = 1 67 | if fam_match == fam_query: 68 | count_tp_fam += 1 69 | elif sfam_match == sfam_query: 70 | count_tp_sfam += 1 71 | elif fold_match == fold_query: 72 | count_tp_fold += 1 73 | else: 74 | break 75 | 76 | sens_fam = count_tp_fam / total_fam 77 | sens_sfam = count_tp_sfam / total_sfam 78 | sens_fold = count_tp_fold / total_fold 79 | 80 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 81 | 82 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 83 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 84 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 85 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 86 | 87 | for di, domid in enumerate(domids_unseen): 88 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, domid_to_fam_matches[domid]) 89 | 90 | senses_fam.append(sens_fam) 91 | senses_sfam.append(sens_sfam) 92 | senses_fold.append(sens_fold) 93 | top1s_fam.append(top1_fam) 94 | top1s_sfam.append(top1_sfam) 95 | top1s_fold.append(top1_fold) 96 | top5s_fam.append(top5_fam) 97 | 98 | cl = domid_to_fam[domid][0] 99 | senses_fam_class[cl].append(sens_fam) 100 | senses_sfam_class[cl].append(sens_sfam) 101 | senses_fold_class[cl].append(sens_fold) 102 | 103 | nres = domid_to_nres[domid] 104 | if 0 <= nres < 100: 105 | nres_group = "0-99" 106 | elif 100 <= nres < 200: 107 | nres_group = "100-199" 108 | elif 200 <= nres < 300: 109 | nres_group = "200-299" 110 | else: 111 | nres_group = "300+" 112 | 113 | senses_fam_nres[nres_group].append(sens_fam) 114 | senses_sfam_nres[nres_group].append(sens_sfam) 115 | senses_fold_nres[nres_group].append(sens_fold) 116 | 117 | co = domid_to_contact_order[domid] 118 | if 0 <= co < 0.075: 119 | co_group = "0-0.075" 120 | elif 0.075 <= co < 0.125: 121 | co_group = "0.075-0.125" 122 | elif 0.125 <= co < 0.175: 123 | co_group = "0.125-0.175" 124 | else: 125 | co_group = "0.175+" 126 | 127 | senses_fam_co[co_group].append(sens_fam) 128 | senses_sfam_co[co_group].append(sens_sfam) 129 | senses_fold_co[co_group].append(sens_fold) 130 | 131 | print(np.mean(senses_fam)) 132 | print(np.mean(senses_sfam)) 133 | print(np.mean(senses_fold)) 134 | print() 135 | 136 | for cl in sorted(list(senses_fam_class.keys())): 137 | print(cl, "class") 138 | print(np.mean(senses_fam_class[cl])) 139 | print(np.mean(senses_sfam_class[cl])) 140 | print(np.mean(senses_fold_class[cl])) 141 | print() 142 | 143 | for nres_group in sorted(list(senses_fam_nres.keys())): 144 | print(nres_group, "nres") 145 | print(np.mean(senses_fam_nres[nres_group])) 146 | print(np.mean(senses_sfam_nres[nres_group])) 147 | print(np.mean(senses_fold_nres[nres_group])) 148 | print() 149 | 150 | for co_group in sorted(list(senses_fam_co.keys())): 151 | print(co_group, "contact_order") 152 | print(np.mean(senses_fam_co[co_group])) 153 | print(np.mean(senses_sfam_co[co_group])) 154 | print(np.mean(senses_fold_co[co_group])) 155 | print() 156 | 
-------------------------------------------------------------------------------- /scripts/other_methods/eat/searching.txt: -------------------------------------------------------------------------------- 1 | 0.8425490208812232 2 | 0.615026869524143 3 | 0.10136891357128358 4 | 5 | a class 6 | 0.8208160342073386 7 | 0.48863501276769233 8 | 0.0783166709327955 9 | 10 | b class 11 | 0.8420905901021315 12 | 0.6200705585021667 13 | 0.16543843997392027 14 | 15 | c class 16 | 0.8437367054136923 17 | 0.6848032338888271 18 | 0.0793278928234952 19 | 20 | d class 21 | 0.8865006596430745 22 | 0.6087192556610456 23 | 0.036459026011296546 24 | 25 | f class 26 | 0.8705357142857143 27 | 0.7447916666666666 28 | 0.18771222410865873 29 | 30 | g class 31 | 0.5736961451247166 32 | 0.4103610675039246 33 | 0.016978434409627072 34 | 35 | 0-99 nres 36 | 0.7718405287836919 37 | 0.5350137969926312 38 | 0.050655445625553575 39 | 40 | 100-199 nres 41 | 0.8537218802525449 42 | 0.6297209222430927 43 | 0.08112222020349766 44 | 45 | 200-299 nres 46 | 0.8676610724576131 47 | 0.5764137811680339 48 | 0.14631167297036154 49 | 50 | 300+ nres 51 | 0.8967228084401001 52 | 0.7372734741659361 53 | 0.1971451359804853 54 | 55 | 0-0.075 contact_order 56 | 0.7902914123241124 57 | 0.5200137444029397 58 | 0.1094149385513591 59 | 60 | 0.075-0.125 contact_order 61 | 0.8663639991873403 62 | 0.6517059598374546 63 | 0.10434490421352989 64 | 65 | 0.125-0.175 contact_order 66 | 0.8282153231717085 67 | 0.5672163472075178 68 | 0.08546891914629183 69 | 70 | 0.175+ contact_order 71 | 0.8361969956068 72 | 0.6535371722679388 73 | 0.11797714813778958 74 | 75 | -------------------------------------------------------------------------------- /scripts/other_methods/esm/embed.sh: -------------------------------------------------------------------------------- 1 | # Download extract.py from https://github.com/facebookresearch/esm/raw/main/scripts/extract.py 2 | python extract.py esm2_t36_3B_UR50D ../astral_40_upper.fa embed --include mean --truncation_seq_length 10000 3 | -------------------------------------------------------------------------------- /scripts/other_methods/esm/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import torch 4 | from torch.nn.functional import normalize 5 | import numpy as np 6 | import os 7 | from random import sample 8 | 9 | dataset_dir = "../../dataset" 10 | esm_embedding_size = 2560 11 | 12 | def scop_fold(fam): 13 | return fam.rsplit(".", 2)[0] 14 | 15 | def scop_sfam(fam): 16 | return fam.rsplit(".", 1)[0] 17 | 18 | def read_dataset(fp): 19 | domids, fams = [], [] 20 | with open(fp) as f: 21 | for line in f.readlines(): 22 | domid, fam = line.rstrip().split() 23 | domids.append(domid) 24 | fams.append(fam) 25 | return domids, fams 26 | 27 | domids_all, fams_all = read_dataset(dataset_dir + "/search_all.txt") 28 | domids_unseen, fams_unseen = read_dataset(dataset_dir + "/test_unseen_sid.txt") 29 | 30 | inds_unseen = [domids_all.index(d) for d in domids_unseen] 31 | 32 | search_embeddings = torch.zeros(len(domids_all), esm_embedding_size) 33 | for di, domid in enumerate(domids_all): 34 | l = torch.load(os.path.join("embed", domid + ".pt"), map_location="cpu") 35 | emb = l["mean_representations"][36] 36 | search_embeddings[di] = normalize(emb, dim=0) 37 | 38 | def embedding_distance(emb_1, emb_2): 39 | cosine_dist = (emb_1 * emb_2).sum(dim=-1) # Normalised in the model 40 | return (1 - cosine_dist) / 2 # Runs 0 (close) to 1 
(far) 41 | 42 | def validate_searching_fam(search_embeddings, search_fams, search_inds, n_samples=None, log=False): 43 | sensitivities, top1_accuracies, top5_accuracies = [], [], [] 44 | if n_samples is None or n_samples >= len(search_inds): 45 | sampled_is = search_inds[:] 46 | else: 47 | sampled_is = sample(search_inds, n_samples) 48 | for i in sampled_is: 49 | fam = search_fams[i] 50 | total_pos = sum(1 if h == fam else 0 for h in search_fams) 51 | dists = embedding_distance(search_embeddings[i:(i + 1)], search_embeddings) 52 | count_tp = 0 53 | top5_fam = 0 54 | for ji, j in enumerate(dists.argsort()): 55 | matched_fam = search_fams[j] 56 | if ji == 1: 57 | top1_accuracies.append(1 if matched_fam == fam else 0) 58 | if 0 < ji < 6 and matched_fam == fam: 59 | top5_fam = 1 60 | if matched_fam == fam: 61 | count_tp += 1 62 | elif scop_fold(matched_fam) != scop_fold(fam): 63 | break 64 | sensitivities.append(count_tp / total_pos) 65 | top5_accuracies.append(top5_fam) 66 | if log: 67 | print(domids_all[i], fam, count_tp / total_pos) 68 | return sensitivities, top1_accuracies, top5_accuracies 69 | 70 | def validate_searching_sfam(search_embeddings, search_fams, search_inds, n_samples=None, log=False): 71 | sensitivities, top1_accuracies = [], [] 72 | if n_samples is None or n_samples >= len(search_inds): 73 | sampled_is = search_inds[:] 74 | else: 75 | sampled_is = sample(search_inds, n_samples) 76 | for i in sampled_is: 77 | fam = search_fams[i] 78 | sfam = scop_sfam(fam) 79 | fold = scop_fold(fam) 80 | total_pos = sum(1 if f != fam and scop_sfam(f) == sfam else 0 for f in search_fams) 81 | dists = embedding_distance(search_embeddings[i:(i + 1)], search_embeddings) 82 | count_tp = 0 83 | for ji, j in enumerate(dists.argsort()): 84 | matched_fam = search_fams[j] 85 | if ji == 1: 86 | top1_accuracies.append(1 if scop_sfam(matched_fam) == sfam else 0) 87 | if matched_fam != fam and scop_sfam(matched_fam) == sfam: 88 | count_tp += 1 89 | elif scop_fold(matched_fam) != fold: 90 | break 91 | sensitivities.append(count_tp / total_pos) 92 | if log: 93 | print(domids_all[i], fam, count_tp / total_pos) 94 | return sensitivities, top1_accuracies 95 | 96 | def validate_searching_fold(search_embeddings, search_fams, search_inds, n_samples=None, log=False): 97 | sensitivities, top1_accuracies = [], [] 98 | if n_samples is None or n_samples >= len(search_inds): 99 | sampled_is = search_inds[:] 100 | else: 101 | sampled_is = sample(search_inds, n_samples) 102 | for i in sampled_is: 103 | fam = search_fams[i] 104 | sfam = scop_sfam(fam) 105 | fold = scop_fold(fam) 106 | total_pos = sum(1 if scop_sfam(f) != sfam and scop_fold(f) == fold else 0 for f in search_fams) 107 | dists = embedding_distance(search_embeddings[i:(i + 1)], search_embeddings) 108 | count_tp = 0 109 | for ji, j in enumerate(dists.argsort()): 110 | matched_fam = search_fams[j] 111 | if ji == 1: 112 | top1_accuracies.append(1 if scop_fold(matched_fam) == fold else 0) 113 | if scop_sfam(matched_fam) != sfam and scop_fold(matched_fam) == fold: 114 | count_tp += 1 115 | elif scop_fold(matched_fam) != fold: 116 | break 117 | sensitivities.append(count_tp / total_pos) 118 | if log: 119 | print(domids_all[i], fam, count_tp / total_pos) 120 | return sensitivities, top1_accuracies 121 | 122 | sens_unseen_fam, top1s_fam, top5s_fam = validate_searching_fam(search_embeddings, fams_all, inds_unseen, None, True) 123 | sens_unseen_sfam, top1s_sfam = validate_searching_sfam(search_embeddings, fams_all, inds_unseen, None, True) 124 | sens_unseen_fold, 
top1s_fold = validate_searching_fold(search_embeddings, fams_all, inds_unseen, None, True) 125 | 126 | print(np.mean(sens_unseen_fam )) 127 | print(np.mean(sens_unseen_sfam)) 128 | print(np.mean(sens_unseen_fold)) 129 | -------------------------------------------------------------------------------- /scripts/other_methods/esm/searching.txt: -------------------------------------------------------------------------------- 1 | 0.47697878237395175 2 | 0.22080590097369382 3 | 0.014305441667733208 4 | -------------------------------------------------------------------------------- /scripts/other_methods/extract_model.jl: -------------------------------------------------------------------------------- 1 | # Extract first model from ASTRAL PDB files 2 | 3 | using BioStructures 4 | 5 | in_dir = "pdbstyle-2.08_models" 6 | out_dir = "pdbstyle-2.08" 7 | known_problems = ["d6a1ia1.ent"] 8 | 9 | fs = readdir(in_dir) 10 | 11 | for (fi, f) in enumerate(fs) 12 | println(fi, " / ", length(fs)) 13 | f in known_problems && continue 14 | out_fp = joinpath(out_dir, f) 15 | isfile(out_fp) && continue 16 | s = read(joinpath(in_dir, f), PDBFormat) 17 | writepdb(out_fp, s[1]) 18 | end 19 | -------------------------------------------------------------------------------- /scripts/other_methods/foldseek/run.sh: -------------------------------------------------------------------------------- 1 | foldseek easy-search ../pdbstyle-2.08 ../pdbstyle-2.08 run.out tmp --threads 16 -s 9.5 --max-seqs 2000 -e 10 2 | -------------------------------------------------------------------------------- /scripts/other_methods/foldseek/run_tm.sh: -------------------------------------------------------------------------------- 1 | foldseek easy-search ../pdbstyle-2.08 ../pdbstyle-2.08 run_tm.out tmp --threads 16 -s 9.5 --max-seqs 2000 -e 10 --alignment-type 1 2 | -------------------------------------------------------------------------------- /scripts/other_methods/foldseek/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | run_file = "run.out" 7 | dataset_dir = "../../dataset" 8 | known_problems = ["d6a1ia1"] 9 | 10 | domid_to_fam = {} 11 | fams_search = [] 12 | with open(dataset_dir + "/search_all.txt") as f: 13 | for line in f.readlines(): 14 | cols = line.rstrip().split() 15 | domid, fam = cols[0], cols[1] 16 | if domid not in known_problems: 17 | domid_to_fam[domid] = fam 18 | fams_search.append(fam) 19 | 20 | domid_to_nres, domid_to_contact_order = {}, {} 21 | with open("../contact_order.txt") as f: 22 | for line in f.readlines(): 23 | cols = line.rstrip().split() 24 | domid, nres, contact_order = cols[0], cols[1], cols[2] 25 | if domid not in known_problems: 26 | domid_to_nres[domid] = int(nres) 27 | domid_to_contact_order[domid] = float(contact_order) 28 | 29 | domid_to_fam_matches = defaultdict(list) 30 | with open(run_file) as f: 31 | for line in f.readlines(): 32 | cols = line.split() 33 | domid_query = cols[0].split(".")[0] 34 | domid_match = cols[1].split(".")[0] 35 | if domid_query in domid_to_fam and domid_match in domid_to_fam: 36 | fam_match = domid_to_fam[domid_match] 37 | domid_to_fam_matches[domid_query].append(fam_match) 38 | 39 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 40 | domids_unseen = [l.split()[0] for l in f.readlines()] 41 | 42 | def scop_fold(fam): 43 | return fam.rsplit(".", 2)[0] 44 | 45 | def scop_sfam(fam): 46 | return 
fam.rsplit(".", 1)[0] 47 | 48 | def searching_accuracy(domid_query, fam_matches): 49 | fam_query = domid_to_fam[domid_query] 50 | sfam_query = scop_sfam(fam_query) 51 | fold_query = scop_fold(fam_query) 52 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 53 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 54 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 55 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 56 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 57 | 58 | for fi, fam_match in enumerate(fam_matches): 59 | sfam_match = scop_sfam(fam_match) 60 | fold_match = scop_fold(fam_match) 61 | # Don't count self match for top 1 accuracy 62 | if fi == 1: 63 | if fam_match == fam_query: 64 | top1_fam = 1 65 | if sfam_match == sfam_query: 66 | top1_sfam = 1 67 | if fold_match == fold_query: 68 | top1_fold = 1 69 | if 0 < fi < 6 and fam_match == fam_query: 70 | top5_fam = 1 71 | if fam_match == fam_query: 72 | count_tp_fam += 1 73 | elif sfam_match == sfam_query: 74 | count_tp_sfam += 1 75 | elif fold_match == fold_query: 76 | count_tp_fold += 1 77 | else: 78 | break 79 | 80 | sens_fam = count_tp_fam / total_fam 81 | sens_sfam = count_tp_sfam / total_sfam 82 | sens_fold = count_tp_fold / total_fold 83 | 84 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 85 | 86 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 87 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 88 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 89 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 90 | 91 | for di, domid in enumerate(domids_unseen): 92 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, domid_to_fam_matches[domid]) 93 | 94 | senses_fam.append(sens_fam) 95 | senses_sfam.append(sens_sfam) 96 | senses_fold.append(sens_fold) 97 | top1s_fam.append(top1_fam) 98 | top1s_sfam.append(top1_sfam) 99 | top1s_fold.append(top1_fold) 100 | top5s_fam.append(top5_fam) 101 | 102 | cl = domid_to_fam[domid][0] 103 | senses_fam_class[cl].append(sens_fam) 104 | senses_sfam_class[cl].append(sens_sfam) 105 | senses_fold_class[cl].append(sens_fold) 106 | 107 | nres = domid_to_nres[domid] 108 | if 0 <= nres < 100: 109 | nres_group = "0-99" 110 | elif 100 <= nres < 200: 111 | nres_group = "100-199" 112 | elif 200 <= nres < 300: 113 | nres_group = "200-299" 114 | else: 115 | nres_group = "300+" 116 | 117 | senses_fam_nres[nres_group].append(sens_fam) 118 | senses_sfam_nres[nres_group].append(sens_sfam) 119 | senses_fold_nres[nres_group].append(sens_fold) 120 | 121 | co = domid_to_contact_order[domid] 122 | if 0 <= co < 0.075: 123 | co_group = "0-0.075" 124 | elif 0.075 <= co < 0.125: 125 | co_group = "0.075-0.125" 126 | elif 0.125 <= co < 0.175: 127 | co_group = "0.125-0.175" 128 | else: 129 | co_group = "0.175+" 130 | 131 | senses_fam_co[co_group].append(sens_fam) 132 | senses_sfam_co[co_group].append(sens_sfam) 133 | senses_fold_co[co_group].append(sens_fold) 134 | 135 | print(np.mean(senses_fam)) 136 | print(np.mean(senses_sfam)) 137 | print(np.mean(senses_fold)) 138 | print() 139 | 140 | for cl in sorted(list(senses_fam_class.keys())): 141 | print(cl, "class") 
142 | print(np.mean(senses_fam_class[cl])) 143 | print(np.mean(senses_sfam_class[cl])) 144 | print(np.mean(senses_fold_class[cl])) 145 | print() 146 | 147 | for nres_group in sorted(list(senses_fam_nres.keys())): 148 | print(nres_group, "nres") 149 | print(np.mean(senses_fam_nres[nres_group])) 150 | print(np.mean(senses_sfam_nres[nres_group])) 151 | print(np.mean(senses_fold_nres[nres_group])) 152 | print() 153 | 154 | for co_group in sorted(list(senses_fam_co.keys())): 155 | print(co_group, "contact_order") 156 | print(np.mean(senses_fam_co[co_group])) 157 | print(np.mean(senses_sfam_co[co_group])) 158 | print(np.mean(senses_fold_co[co_group])) 159 | print() 160 | -------------------------------------------------------------------------------- /scripts/other_methods/foldseek/searching.txt: -------------------------------------------------------------------------------- 1 | 0.850267428683668 2 | 0.6442376403285707 3 | 0.11067573671036186 4 | 5 | a class 6 | 0.760447592541238 7 | 0.451518781403812 8 | 0.05739740573289434 9 | 10 | b class 11 | 0.8722939453077839 12 | 0.6712009686959499 13 | 0.14713273171842722 14 | 15 | c class 16 | 0.863048773622922 17 | 0.6907334480012806 18 | 0.10632444709494805 19 | 20 | d class 21 | 0.9027570339876841 22 | 0.7012626777680858 23 | 0.08583708929586616 24 | 25 | f class 26 | 0.8913690476190477 27 | 0.8784722222222222 28 | 0.3098471986417657 29 | 30 | g class 31 | 0.4929705215419501 32 | 0.30659340659340656 33 | 0.006187058861060925 34 | 35 | 0-99 nres 36 | 0.7434338211535859 37 | 0.4869781069348598 38 | 0.0650663151800564 39 | 40 | 100-199 nres 41 | 0.8736602296844398 42 | 0.6810396239647801 43 | 0.08836418156082962 44 | 45 | 200-299 nres 46 | 0.8745493799353734 47 | 0.643963518041649 48 | 0.18091067390794652 49 | 50 | 300+ nres 51 | 0.9264570011906222 52 | 0.7853808724685427 53 | 0.1791703957505634 54 | 55 | 0-0.075 contact_order 56 | 0.7759587689768865 57 | 0.5735574968770739 58 | 0.11117453698391733 59 | 60 | 0.075-0.125 contact_order 61 | 0.8797431036559257 62 | 0.6616243802790973 63 | 0.13168231803998898 64 | 65 | 0.125-0.175 contact_order 66 | 0.8433625227652246 67 | 0.6313306010284506 68 | 0.09677343036240812 69 | 70 | 0.175+ contact_order 71 | 0.828410352477564 72 | 0.6567242162542168 73 | 0.08839059616908168 74 | 75 | -------------------------------------------------------------------------------- /scripts/other_methods/foldseek/searching_tm.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | run_file = "run_tm.out" 7 | dataset_dir = "../../dataset" 8 | known_problems = ["d6a1ia1"] 9 | 10 | domid_to_fam = {} 11 | fams_search = [] 12 | with open(dataset_dir + "/search_all.txt") as f: 13 | for line in f.readlines(): 14 | cols = line.rstrip().split() 15 | domid, fam = cols[0], cols[1] 16 | if domid not in known_problems: 17 | domid_to_fam[domid] = fam 18 | fams_search.append(fam) 19 | 20 | domid_to_nres, domid_to_contact_order = {}, {} 21 | with open("../contact_order.txt") as f: 22 | for line in f.readlines(): 23 | cols = line.rstrip().split() 24 | domid, nres, contact_order = cols[0], cols[1], cols[2] 25 | if domid not in known_problems: 26 | domid_to_nres[domid] = int(nres) 27 | domid_to_contact_order[domid] = float(contact_order) 28 | 29 | domid_to_fam_matches = defaultdict(list) 30 | with open(run_file) as f: 31 | for line in f.readlines(): 32 | cols = line.split() 33 | domid_query = 
cols[0].split(".")[0] 34 | domid_match = cols[1].split(".")[0] 35 | if domid_query in domid_to_fam and domid_match in domid_to_fam: 36 | fam_match = domid_to_fam[domid_match] 37 | domid_to_fam_matches[domid_query].append(fam_match) 38 | 39 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 40 | domids_unseen = [l.split()[0] for l in f.readlines()] 41 | 42 | def scop_fold(fam): 43 | return fam.rsplit(".", 2)[0] 44 | 45 | def scop_sfam(fam): 46 | return fam.rsplit(".", 1)[0] 47 | 48 | def searching_accuracy(domid_query, fam_matches): 49 | fam_query = domid_to_fam[domid_query] 50 | sfam_query = scop_sfam(fam_query) 51 | fold_query = scop_fold(fam_query) 52 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 53 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 54 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 55 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 56 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 57 | 58 | for fi, fam_match in enumerate(fam_matches): 59 | sfam_match = scop_sfam(fam_match) 60 | fold_match = scop_fold(fam_match) 61 | # Don't count self match for top 1 accuracy 62 | if fi == 1: 63 | if fam_match == fam_query: 64 | top1_fam = 1 65 | if sfam_match == sfam_query: 66 | top1_sfam = 1 67 | if fold_match == fold_query: 68 | top1_fold = 1 69 | if 0 < fi < 6 and fam_match == fam_query: 70 | top5_fam = 1 71 | if fam_match == fam_query: 72 | count_tp_fam += 1 73 | elif sfam_match == sfam_query: 74 | count_tp_sfam += 1 75 | elif fold_match == fold_query: 76 | count_tp_fold += 1 77 | else: 78 | break 79 | 80 | sens_fam = count_tp_fam / total_fam 81 | sens_sfam = count_tp_sfam / total_sfam 82 | sens_fold = count_tp_fold / total_fold 83 | 84 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 85 | 86 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 87 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 88 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 89 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 90 | 91 | for di, domid in enumerate(domids_unseen): 92 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, domid_to_fam_matches[domid]) 93 | 94 | senses_fam.append(sens_fam) 95 | senses_sfam.append(sens_sfam) 96 | senses_fold.append(sens_fold) 97 | top1s_fam.append(top1_fam) 98 | top1s_sfam.append(top1_sfam) 99 | top1s_fold.append(top1_fold) 100 | top5s_fam.append(top5_fam) 101 | 102 | cl = domid_to_fam[domid][0] 103 | senses_fam_class[cl].append(sens_fam) 104 | senses_sfam_class[cl].append(sens_sfam) 105 | senses_fold_class[cl].append(sens_fold) 106 | 107 | nres = domid_to_nres[domid] 108 | if 0 <= nres < 100: 109 | nres_group = "0-99" 110 | elif 100 <= nres < 200: 111 | nres_group = "100-199" 112 | elif 200 <= nres < 300: 113 | nres_group = "200-299" 114 | else: 115 | nres_group = "300+" 116 | 117 | senses_fam_nres[nres_group].append(sens_fam) 118 | senses_sfam_nres[nres_group].append(sens_sfam) 119 | senses_fold_nres[nres_group].append(sens_fold) 120 | 121 | co = domid_to_contact_order[domid] 122 | if 0 <= co < 0.075: 123 | co_group = "0-0.075" 124 | elif 0.075 <= co < 0.125: 125 | co_group = "0.075-0.125" 
126 | elif 0.125 <= co < 0.175: 127 | co_group = "0.125-0.175" 128 | else: 129 | co_group = "0.175+" 130 | 131 | senses_fam_co[co_group].append(sens_fam) 132 | senses_sfam_co[co_group].append(sens_sfam) 133 | senses_fold_co[co_group].append(sens_fold) 134 | 135 | print(np.mean(senses_fam)) 136 | print(np.mean(senses_sfam)) 137 | print(np.mean(senses_fold)) 138 | print() 139 | 140 | for cl in sorted(list(senses_fam_class.keys())): 141 | print(cl, "class") 142 | print(np.mean(senses_fam_class[cl])) 143 | print(np.mean(senses_sfam_class[cl])) 144 | print(np.mean(senses_fold_class[cl])) 145 | print() 146 | 147 | for nres_group in sorted(list(senses_fam_nres.keys())): 148 | print(nres_group, "nres") 149 | print(np.mean(senses_fam_nres[nres_group])) 150 | print(np.mean(senses_sfam_nres[nres_group])) 151 | print(np.mean(senses_fold_nres[nres_group])) 152 | print() 153 | 154 | for co_group in sorted(list(senses_fam_co.keys())): 155 | print(co_group, "contact_order") 156 | print(np.mean(senses_fam_co[co_group])) 157 | print(np.mean(senses_sfam_co[co_group])) 158 | print(np.mean(senses_fold_co[co_group])) 159 | print() 160 | -------------------------------------------------------------------------------- /scripts/other_methods/foldseek/searching_tm.txt: -------------------------------------------------------------------------------- 1 | 0.8589513208488821 2 | 0.6663583849239579 3 | 0.15786481539284036 4 | 5 | a class 6 | 0.7339287201093221 7 | 0.43224379227006254 8 | 0.08200502400513911 9 | 10 | b class 11 | 0.8780962197818142 12 | 0.6910882469680515 13 | 0.20452879516173716 14 | 15 | c class 16 | 0.8931928336713214 17 | 0.7394199242869329 18 | 0.16617896541967545 19 | 20 | d class 21 | 0.9106440686936043 22 | 0.7515186248913654 23 | 0.14360726266558602 24 | 25 | f class 26 | 0.9122023809523809 27 | 0.6733630952380952 28 | 0.13582342954159593 29 | 30 | g class 31 | 0.5088435374149659 32 | 0.28618524332810047 33 | 0.0034916410713035123 34 | 35 | 0-99 nres 36 | 0.7408212339112289 37 | 0.47516737759963906 38 | 0.10728699069720012 39 | 40 | 100-199 nres 41 | 0.8917969754686862 42 | 0.7120452665316483 43 | 0.1355858832325859 44 | 45 | 200-299 nres 46 | 0.891334866782085 47 | 0.7066235581546966 48 | 0.25272711844318446 49 | 50 | 300+ nres 51 | 0.9169539215840051 52 | 0.7945555418484148 53 | 0.20953588474385942 54 | 55 | 0-0.075 contact_order 56 | 0.7271857251110455 57 | 0.5197194338034287 58 | 0.13799055847313035 59 | 60 | 0.075-0.125 contact_order 61 | 0.8945635599039466 62 | 0.6923383671530292 63 | 0.17140554798596622 64 | 65 | 0.125-0.175 contact_order 66 | 0.8636030959233312 67 | 0.672242598651514 68 | 0.16180973083509312 69 | 70 | 0.175+ contact_order 71 | 0.8280154583726785 72 | 0.66001832354136 73 | 0.13034373520425782 74 | 75 | -------------------------------------------------------------------------------- /scripts/other_methods/mmseqs2/run.sh: -------------------------------------------------------------------------------- 1 | time mmseqs easy-search ../astral_40_upper.fa ../astral_40_upper.fa out.m8 tmp -a --threads 16 -s 7.5 -e 10000 --max-seqs 2000 2 | -------------------------------------------------------------------------------- /scripts/other_methods/mmseqs2/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import numpy as np 4 | 5 | dataset_dir = "../../dataset" 6 | 7 | with open("out.m8") as f: 8 | mmseqs_lines = [l.rstrip() for l in f.readlines()] 9 | 10 | domid_to_fam = {} 11 | 
fams_search = [] 12 | with open(dataset_dir + "/search_all.txt") as f: 13 | for line in f.readlines(): 14 | cols = line.rstrip().split() 15 | fam = cols[1] 16 | domid_to_fam[cols[0]] = fam 17 | fams_search.append(fam) 18 | 19 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 20 | domids_unseen = [l.split()[0] for l in f.readlines()] 21 | 22 | def scop_fold(fam): 23 | return fam.rsplit(".", 2)[0] 24 | 25 | def scop_sfam(fam): 26 | return fam.rsplit(".", 1)[0] 27 | 28 | def searching_accuracy(domid_query, mmseqs_lines): 29 | fam_matches = [] 30 | for line in mmseqs_lines: 31 | cols = line.split() 32 | if cols[0] == domid_query: 33 | domid_match = cols[1] 34 | if domid_match in domid_to_fam: 35 | fam_matches.append(domid_to_fam[domid_match]) 36 | 37 | fam_query = domid_to_fam[domid_query] 38 | sfam_query = scop_sfam(fam_query) 39 | fold_query = scop_fold(fam_query) 40 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 41 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 42 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 43 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 44 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 45 | 46 | for fi, fam_match in enumerate(fam_matches): 47 | sfam_match = scop_sfam(fam_match) 48 | fold_match = scop_fold(fam_match) 49 | # Don't count self match for top 1 accuracy 50 | if fi == 1: 51 | if fam_match == fam_query: 52 | top1_fam = 1 53 | if sfam_match == sfam_query: 54 | top1_sfam = 1 55 | if fold_match == fold_query: 56 | top1_fold = 1 57 | if 0 < fi < 6 and fam_match == fam_query: 58 | top5_fam = 1 59 | if fam_match == fam_query: 60 | count_tp_fam += 1 61 | elif sfam_match == sfam_query: 62 | count_tp_sfam += 1 63 | elif fold_match == fold_query: 64 | count_tp_fold += 1 65 | else: 66 | break 67 | 68 | sens_fam = count_tp_fam / total_fam 69 | sens_sfam = count_tp_sfam / total_sfam 70 | sens_fold = count_tp_fold / total_fold 71 | 72 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 73 | 74 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 75 | 76 | for di, domid in enumerate(domids_unseen): 77 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, mmseqs_lines) 78 | senses_fam.append(sens_fam) 79 | senses_sfam.append(sens_sfam) 80 | senses_fold.append(sens_fold) 81 | top1s_fam.append(top1_fam) 82 | top1s_sfam.append(top1_sfam) 83 | top1s_fold.append(top1_fold) 84 | top5s_fam.append(top5_fam) 85 | 86 | print(np.mean(senses_fam)) 87 | print(np.mean(senses_sfam)) 88 | print(np.mean(senses_fold)) 89 | -------------------------------------------------------------------------------- /scripts/other_methods/mmseqs2/searching.txt: -------------------------------------------------------------------------------- 1 | 0.43268166604567254 2 | 0.16539123596729066 3 | 0.00061057902803455 4 | -------------------------------------------------------------------------------- /scripts/other_methods/tmalign/domids.txt: -------------------------------------------------------------------------------- 1 | d3bb6a1 2 | d1vqta1 3 | d3nara_ 4 | d3euca_ 5 | d3aowa_ 6 | d3g46a_ 7 | d6i9sa2 8 | d6rw7a1 9 | d3eyxa_ 10 | d4uypa1 11 | d4rz4a1 12 | d1knva_ 13 | d2h6ra_ 14 | d4ipca1 15 | d6r3za_ 16 | d4n4ee_ 17 | d3neva_ 18 | d1l3wa3 19 | d2iw0a1 20 | d2g7ga1 21 | d1p5ub_ 22 | d3b1qa_ 23 | d5i8da3 24 | d1cb8a2 
25 | d2v82a1 26 | d1q15a1 27 | d1wuza_ 28 | d2hs5a1 29 | d3zq7a_ 30 | d2abwa1 31 | d1x1ia3 32 | d4ay0a2 33 | d4j25a1 34 | d6bmaa1 35 | d1y4wa2 36 | d1lc5a_ 37 | d3cwna_ 38 | d3pmea2 39 | d1kgsa2 40 | d3djca1 41 | d5jgya_ 42 | d2xpwa1 43 | d1p0za1 44 | d2kiva2 45 | d6cw0a_ 46 | d1ta3a_ 47 | d2dy8a1 48 | d2vnud2 49 | d7dhfa1 50 | d5vrka_ 51 | d4cr2r1 52 | d2d69a1 53 | d2p3ra3 54 | d2wy4a_ 55 | d3dxea_ 56 | d1wela1 57 | d1t9ha1 58 | d3grfa2 59 | d6n0ka_ 60 | d1enfa1 61 | d1u5ha_ 62 | d1x3da1 63 | d3hhtb_ 64 | d2ghpa2 65 | d1qzya1 66 | d5h20a_ 67 | d4e70a1 68 | d4xria_ 69 | d1k61a_ 70 | d2csba1 71 | d1knla_ 72 | d5y5qa_ 73 | d2iiza_ 74 | d3heba1 75 | d2m6sa_ 76 | d2yrza1 77 | d1xyza_ 78 | d6nqia_ 79 | d2p39a_ 80 | d1woca_ 81 | d2bvya1 82 | d7coia_ 83 | d1fp3a_ 84 | d4r3va_ 85 | d3vrna2 86 | d1q16a1 87 | d2v5ca1 88 | d4g9ya_ 89 | d1i5pa2 90 | d2kyra_ 91 | d4q9ba_ 92 | d2pn6a2 93 | d1d5ra1 94 | d5wida_ 95 | d1ty0a2 96 | d2hjsa2 97 | d5afwa2 98 | d1yjra1 99 | d2xfwa1 100 | d1yn8a1 101 | d2wp4a_ 102 | d1bvsa3 103 | d2dkua1 104 | d1sqja2 105 | d2nwha_ 106 | d2p5ka_ 107 | d3ea6a2 108 | d1ga2a2 109 | d2jk3a1 110 | d1w32a_ 111 | d1otja_ 112 | d2znra1 113 | d4uqva_ 114 | d4xmra1 115 | d4r82a_ 116 | d2cm5a_ 117 | d2f8aa1 118 | d1w99a1 119 | d2h3na_ 120 | d1szna2 121 | d1rq5a2 122 | d1d4ba1 123 | d1rlla1 124 | d1o12a2 125 | d2dlta1 126 | d2xvla1 127 | d1b9la_ 128 | d3anua2 129 | d3df8a1 130 | d2gu3a1 131 | d3dbxa2 132 | d1w5fa2 133 | d6jt6a_ 134 | d1s05a_ 135 | d6luja1 136 | d1deca_ 137 | d6evla_ 138 | d6wsha1 139 | d3or5a_ 140 | d5z0qa_ 141 | d3keba_ 142 | d2dc3a_ 143 | d3n6wa2 144 | d2nv0a_ 145 | d3hd5a1 146 | d2gkga_ 147 | d1mc0a2 148 | d1gp6a_ 149 | d4f0ba1 150 | d1t3ga_ 151 | d4ieua_ 152 | d6q6na2 153 | d2jgqa_ 154 | d1v86a1 155 | d1pkoa1 156 | d3tsma_ 157 | d1vjpa4 158 | d1x46a_ 159 | d1n52a3 160 | d1tfia_ 161 | d3bn3b1 162 | d5i4da1 163 | d1uz5a1 164 | d1x48a1 165 | d2dnaa1 166 | d1bhea_ 167 | d7jgwa1 168 | d2dy1a5 169 | d1o1ya1 170 | d3en8a1 171 | d1y3ta1 172 | d2a1ka_ 173 | d2bpa2_ 174 | d3cu7a1 175 | d2csba3 176 | d3c5ra2 177 | d1elwa_ 178 | d1reqa1 179 | d1h97a_ 180 | d1v7va2 181 | d2bgca1 182 | d1ebua2 183 | d3ma2b_ 184 | d1qjta_ 185 | d3eu9a1 186 | d5ycza_ 187 | d2hoea1 188 | d1vkra_ 189 | d4ay7a_ 190 | d1vlpa1 191 | d1j8mf1 192 | d1fhoa_ 193 | d2i61a_ 194 | d5xopa1 195 | d2axwa1 196 | d2zaya_ 197 | d1a8da2 198 | d2frha2 199 | d4xdaa_ 200 | d1y1xa_ 201 | d3ri6a_ 202 | d2fe3a_ 203 | d2qcva_ 204 | d7m3xa_ 205 | d1zzma1 206 | d1dhna_ 207 | d6r2na2 208 | d1khca_ 209 | d6blga_ 210 | d1m45a_ 211 | d1kzla2 212 | d1mvha_ 213 | d3h3ba1 214 | d2rnja_ 215 | d4bhxa1 216 | d1ckma1 217 | d5svva_ 218 | d3jb9u3 219 | d4yvda_ 220 | d3krua_ 221 | d1bdoa_ 222 | d2kspa1 223 | d1ofda3 224 | d2cjja1 225 | d3ppea2 226 | d1kqka_ 227 | d2wfba1 228 | d1xuva1 229 | d2coca1 230 | d3gt5a1 231 | d2p9ra_ 232 | d3bexa1 233 | d3pnza_ 234 | d1bw0a_ 235 | d2xfaa_ 236 | d4ch0s1 237 | d2w2db1 238 | d4e9sa3 239 | d1x4ea1 240 | d2q5we1 241 | d2d6ya1 242 | d1fu3a_ 243 | d5y9ea_ 244 | d1y0ga_ 245 | d3c8ca1 246 | d2fexa1 247 | d1geqa_ 248 | d2q9oa1 249 | d1smoa_ 250 | d2naca2 251 | d4unua_ 252 | d3lx3a_ 253 | d2iyga_ 254 | d3kcca1 255 | d2a28a1 256 | d2a1xa1 257 | d2ckxa1 258 | d1gqia2 259 | d5tkma_ 260 | d3i33a2 261 | d1wgga1 262 | d3o0ma1 263 | d1bvyf_ 264 | d2zfga_ 265 | d1t33a1 266 | d3gg7a1 267 | d3d01a1 268 | d1fi6a_ 269 | d2ih3c_ 270 | d1w9aa_ 271 | d2f42a1 272 | d3fiua_ 273 | d3w2wa3 274 | d6mwdb1 275 | d1ng0a_ 276 | d1n0za1 277 | d3b0fa_ 278 | d3jtea1 279 | d3na8a_ 280 | d2hxva2 281 | d5u1ma_ 282 | d1dbha2 283 | d2bs2b2 284 
| d1wu7a1 285 | d6t40a_ 286 | d2qqra2 287 | d2ra2a1 288 | d1ty0a1 289 | d3gkxa_ 290 | d1x1ia1 291 | d4l3ma_ 292 | d1bf2a3 293 | d2il5a1 294 | d1jvna2 295 | d2jbaa_ 296 | d5wl1a2 297 | d1m9sa3 298 | d1lvka1 299 | d5fjqa_ 300 | d5cm7a1 301 | d2z48a1 302 | d3cu7a8 303 | d5dl5a1 304 | d1x9na2 305 | d1kqfb1 306 | d2pe8a1 307 | d1wjka1 308 | d3ceda1 309 | d1wisa1 310 | d1vyra_ 311 | d1rmga_ 312 | d6wk3a_ 313 | d1pk3a1 314 | d6gf1a1 315 | d2q3qa2 316 | d3n0pb1 317 | d1uw4b_ 318 | d1u5da1 319 | d6onna1 320 | d2edna1 321 | d1gr0a2 322 | d1ejea_ 323 | d6jc0b_ 324 | d4wnya1 325 | d2f3la1 326 | d4ohja1 327 | d2imla1 328 | d6fd1a_ 329 | d3r1za2 330 | d1sixa1 331 | d1clca1 332 | d2jnfa_ 333 | d3bb9a1 334 | d2hr7a2 335 | d2axla1 336 | d2ayta1 337 | d1ix9a1 338 | d3fsoa1 339 | d4dpoa1 340 | d2fdbm1 341 | d7ahsa2 342 | d2rsda1 343 | d2uy6b1 344 | d1rypd_ 345 | d4mzya2 346 | d4kjfa_ 347 | d2wu2c_ 348 | d3ibva_ 349 | d5ysca_ 350 | d2z61a_ 351 | d2nw2a1 352 | d2nlza1 353 | d1tiga_ 354 | d1kl9a2 355 | d2c1ia1 356 | d3hcza1 357 | d5gxxa2 358 | d3hpea_ 359 | d3md3a1 360 | d1jmrb4 361 | d3tqva1 362 | d1kfwa1 363 | d2bz7a_ 364 | d4cvqa_ 365 | d1v3wa_ 366 | d1gdha2 367 | d3frxa_ 368 | d1i7qb_ 369 | d4u5ia_ 370 | d2vyoa_ 371 | d1xxxa1 372 | d2fbwc_ 373 | d6hbsa_ 374 | d4m5ra1 375 | d1pbyb_ 376 | d5uida_ 377 | d2nq3a1 378 | d1xkua_ 379 | d4c4ko1 380 | d3zdmc_ 381 | d1vq8b1 382 | d2xzga2 383 | d1rd5a_ 384 | d2cuaa_ 385 | d1usca_ 386 | d4jw0a_ 387 | d1ug2a1 388 | d1m3ya2 389 | d3mgka_ 390 | d2m87a_ 391 | d5bnza1 392 | d1b9ha_ 393 | d3czba_ 394 | d3bbba_ 395 | d1lw7a1 396 | d2fjca1 397 | d2mska_ 398 | d3eura1 399 | d1p4ca_ 400 | d6xr5a2 401 | -------------------------------------------------------------------------------- /scripts/other_methods/tmalign/parse.sh: -------------------------------------------------------------------------------- 1 | rm run_unsorted.out run.out 2 | for file in out/*.out; do 3 | echo $file 4 | grep -a -n "^Name of Chain_1:\|^Name of Chain_2:\|TM-score=" $file | sed 's|=|:|g' | awk -F ": " '{print $2}' | awk '{print $1}' | sed 's|../pdbstyle-2.08/||g' | xargs -n 4 | awk '{print $1, $2, $3, $4}' >> run_unsorted.out 5 | done 6 | sort -k1b,1 -nrk3,3 run_unsorted.out > run.out 7 | -------------------------------------------------------------------------------- /scripts/other_methods/tmalign/run.sh: -------------------------------------------------------------------------------- 1 | mkdir out 2 | while read query; do 3 | echo $query 4 | for target in ../pdbstyle-2.08/*; do 5 | TMalign ../pdbstyle-2.08/$query.ent $target -fast >> out/$query.out 6 | done 7 | done < domids.txt 8 | -------------------------------------------------------------------------------- /scripts/other_methods/tmalign/searching.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | dataset_dir = "../../dataset" 7 | known_problems = ["d6a1ia1"] 8 | 9 | domid_to_fam = {} 10 | fams_search = [] 11 | with open(dataset_dir + "/search_all.txt") as f: 12 | for line in f.readlines(): 13 | cols = line.rstrip().split() 14 | domid, fam = cols[0], cols[1] 15 | if domid not in known_problems: 16 | domid_to_fam[domid] = fam 17 | fams_search.append(fam) 18 | 19 | domid_to_nres, domid_to_contact_order = {}, {} 20 | with open("../contact_order.txt") as f: 21 | for line in f.readlines(): 22 | cols = line.rstrip().split() 23 | domid, nres, contact_order = cols[0], cols[1], cols[2] 24 | if domid not 
in known_problems: 25 | domid_to_nres[domid] = int(nres) 26 | domid_to_contact_order[domid] = float(contact_order) 27 | 28 | domid_to_fam_matches = defaultdict(list) 29 | with open("run.out") as f: 30 | for line in f.readlines(): 31 | cols = line.split() 32 | domid_query = cols[0].split(".")[0] 33 | domid_match = cols[1].split(".")[0] 34 | if domid_query in domid_to_fam and domid_match in domid_to_fam: 35 | fam_match = domid_to_fam[domid_match] 36 | domid_to_fam_matches[domid_query].append(fam_match) 37 | 38 | with open(dataset_dir + "/test_unseen_sid.txt") as f: 39 | domids_unseen = [l.split()[0] for l in f.readlines()] 40 | 41 | def scop_fold(fam): 42 | return fam.rsplit(".", 2)[0] 43 | 44 | def scop_sfam(fam): 45 | return fam.rsplit(".", 1)[0] 46 | 47 | def searching_accuracy(domid_query, fam_matches): 48 | fam_query = domid_to_fam[domid_query] 49 | sfam_query = scop_sfam(fam_query) 50 | fold_query = scop_fold(fam_query) 51 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 52 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 53 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 54 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 55 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 56 | 57 | for fi, fam_match in enumerate(fam_matches): 58 | sfam_match = scop_sfam(fam_match) 59 | fold_match = scop_fold(fam_match) 60 | # Don't count self match for top 1 accuracy 61 | if fi == 1: 62 | if fam_match == fam_query: 63 | top1_fam = 1 64 | if sfam_match == sfam_query: 65 | top1_sfam = 1 66 | if fold_match == fold_query: 67 | top1_fold = 1 68 | if 0 < fi < 6 and fam_match == fam_query: 69 | top5_fam = 1 70 | if fam_match == fam_query: 71 | count_tp_fam += 1 72 | elif sfam_match == sfam_query: 73 | count_tp_sfam += 1 74 | elif fold_match == fold_query: 75 | count_tp_fold += 1 76 | else: 77 | break 78 | 79 | sens_fam = count_tp_fam / total_fam 80 | sens_sfam = count_tp_sfam / total_sfam 81 | sens_fold = count_tp_fold / total_fold 82 | 83 | return sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 84 | 85 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 86 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 87 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 88 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 89 | 90 | for di, domid in enumerate(domids_unseen): 91 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(domid, domid_to_fam_matches[domid]) 92 | 93 | senses_fam.append(sens_fam) 94 | senses_sfam.append(sens_sfam) 95 | senses_fold.append(sens_fold) 96 | top1s_fam.append(top1_fam) 97 | top1s_sfam.append(top1_sfam) 98 | top1s_fold.append(top1_fold) 99 | top5s_fam.append(top5_fam) 100 | 101 | cl = domid_to_fam[domid][0] 102 | senses_fam_class[cl].append(sens_fam) 103 | senses_sfam_class[cl].append(sens_sfam) 104 | senses_fold_class[cl].append(sens_fold) 105 | 106 | nres = domid_to_nres[domid] 107 | if 0 <= nres < 100: 108 | nres_group = "0-99" 109 | elif 100 <= nres < 200: 110 | nres_group = "100-199" 111 | elif 200 <= nres < 300: 112 | nres_group = "200-299" 113 | else: 114 | nres_group = "300+" 115 | 116 | 
senses_fam_nres[nres_group].append(sens_fam) 117 | senses_sfam_nres[nres_group].append(sens_sfam) 118 | senses_fold_nres[nres_group].append(sens_fold) 119 | 120 | co = domid_to_contact_order[domid] 121 | if 0 <= co < 0.075: 122 | co_group = "0-0.075" 123 | elif 0.075 <= co < 0.125: 124 | co_group = "0.075-0.125" 125 | elif 0.125 <= co < 0.175: 126 | co_group = "0.125-0.175" 127 | else: 128 | co_group = "0.175+" 129 | 130 | senses_fam_co[co_group].append(sens_fam) 131 | senses_sfam_co[co_group].append(sens_sfam) 132 | senses_fold_co[co_group].append(sens_fold) 133 | 134 | print(np.mean(senses_fam)) 135 | print(np.mean(senses_sfam)) 136 | print(np.mean(senses_fold)) 137 | print() 138 | 139 | for cl in sorted(list(senses_fam_class.keys())): 140 | print(cl, "class") 141 | print(np.mean(senses_fam_class[cl])) 142 | print(np.mean(senses_sfam_class[cl])) 143 | print(np.mean(senses_fold_class[cl])) 144 | print() 145 | 146 | for nres_group in sorted(list(senses_fam_nres.keys())): 147 | print(nres_group, "nres") 148 | print(np.mean(senses_fam_nres[nres_group])) 149 | print(np.mean(senses_sfam_nres[nres_group])) 150 | print(np.mean(senses_fold_nres[nres_group])) 151 | print() 152 | 153 | for co_group in sorted(list(senses_fam_co.keys())): 154 | print(co_group, "contact_order") 155 | print(np.mean(senses_fam_co[co_group])) 156 | print(np.mean(senses_sfam_co[co_group])) 157 | print(np.mean(senses_fold_co[co_group])) 158 | print() 159 | -------------------------------------------------------------------------------- /scripts/other_methods/tmalign/searching.txt: -------------------------------------------------------------------------------- 1 | 0.8055506710177635 2 | 0.5937810129779497 3 | 0.10010031032505434 4 | 5 | a class 6 | 0.6564888749303465 7 | 0.32697791412822313 8 | 0.06059293844666792 9 | 10 | b class 11 | 0.843119631353269 12 | 0.639022472213631 13 | 0.19114119139777655 14 | 15 | c class 16 | 0.818459536544941 17 | 0.6453969397255832 18 | 0.052793749684319956 19 | 20 | d class 21 | 0.8594113016326638 22 | 0.6627089695227107 23 | 0.051899551856906215 24 | 25 | f class 26 | 0.7946428571428572 27 | 0.8611111111111112 28 | 0.11205432937181664 29 | 30 | g class 31 | 0.7492063492063492 32 | 0.35816326530612247 33 | 0.015727391874180867 34 | 35 | 0-99 nres 36 | 0.688148151271463 37 | 0.4104696187359091 38 | 0.05656127928902577 39 | 40 | 100-199 nres 41 | 0.8338644370588928 42 | 0.6382623882378401 43 | 0.0991826585854389 44 | 45 | 200-299 nres 46 | 0.82257042682675 47 | 0.599159560988082 48 | 0.14601364946564366 49 | 50 | 300+ nres 51 | 0.8912071899947893 52 | 0.747911912827453 53 | 0.1261440363543858 54 | 55 | 0-0.075 contact_order 56 | 0.5481631004631446 57 | 0.373156410221532 58 | 0.07841039293420524 59 | 60 | 0.075-0.125 contact_order 61 | 0.8417835554219008 62 | 0.6116938861823428 63 | 0.10447328845486557 64 | 65 | 0.125-0.175 contact_order 66 | 0.8364214704355866 67 | 0.6209262756101411 68 | 0.1277842827132178 69 | 70 | 0.175+ contact_order 71 | 0.7806155701664256 72 | 0.5994471201524403 73 | 0.0538013694731156 74 | 75 | -------------------------------------------------------------------------------- /scripts/scope_val.py: -------------------------------------------------------------------------------- 1 | # Calculate structure searching performance 2 | # Unpack training_coords first with `tar -xvf training_coords.tgz` 3 | 4 | import progres as pg 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torch_geometric.loader import DataLoader 8 | import numpy as np 9 | from 
collections import defaultdict 10 | 11 | dataset_dir = "dataset" 12 | coord_dir = "training_coords" 13 | device = "cpu" 14 | model_batch_search = 8 15 | embedding_size = 128 16 | 17 | domids_search, fams_search = [], [] 18 | with open(dataset_dir + "/search_all.txt") as f: 19 | for line in f.readlines(): 20 | cols = line.rstrip().split() 21 | domid, fam = cols[0], cols[1] 22 | domids_search.append(domid) 23 | fams_search.append(fam) 24 | 25 | domid_to_nres, domid_to_contact_order = {}, {} 26 | with open("other_methods/contact_order.txt") as f: 27 | for line in f.readlines(): 28 | cols = line.rstrip().split() 29 | domid, nres, contact_order = cols[0], cols[1], cols[2] 30 | domid_to_nres[domid] = int(nres) 31 | domid_to_contact_order[domid] = float(contact_order) 32 | 33 | class CoordinateDataset(Dataset): 34 | def __init__(self, domids): 35 | self.domids = domids 36 | self.model = pg.load_trained_model(device) 37 | 38 | def __len__(self): 39 | return len(self.domids) 40 | 41 | def __getitem__(self, idx): 42 | fp = f"{coord_dir}/{self.domids[idx]}" 43 | return pg.embed_structure(fp, fileformat="coords", device=device, model=self.model) 44 | 45 | query_set = CoordinateDataset(domids_search) 46 | num_workers = torch.get_num_threads() if device == torch.device("cpu") else 0 47 | query_loader = DataLoader(query_set, batch_size=model_batch_search, shuffle=False, 48 | num_workers=num_workers) 49 | 50 | with torch.no_grad(): 51 | query_embeddings = torch.zeros(len(query_set), embedding_size, device=device) 52 | for bi, out in enumerate(query_loader): 53 | query_embeddings[(bi * model_batch_search):(bi * model_batch_search + out.size(0))] = out 54 | 55 | query_embeddings = query_embeddings.to(torch.float16) 56 | dists = torch.zeros(len(query_set), len(query_set), device=device) 57 | for bi in range(len(query_set)): 58 | dists[bi:(bi + 1)] = pg.embedding_distance(query_embeddings[bi:(bi + 1)], query_embeddings) 59 | 60 | with open(f"{dataset_dir}/test_unseen_sid.txt") as f: 61 | domids_unseen = [l.split()[0] for l in f.readlines()] 62 | 63 | def scop_fold(fam): 64 | return fam.rsplit(".", 2)[0] 65 | 66 | def scop_sfam(fam): 67 | return fam.rsplit(".", 1)[0] 68 | 69 | def searching_accuracy(di, dists): 70 | fam_query = fams_search[di] 71 | sfam_query = scop_sfam(fam_query) 72 | fold_query = scop_fold(fam_query) 73 | count_tp_fam, count_tp_sfam, count_tp_fold = 0, 0, 0 74 | total_fam = sum(1 if f == fam_query else 0 for f in fams_search) 75 | total_sfam = sum(1 if f != fam_query and scop_sfam(f) == sfam_query else 0 for f in fams_search) 76 | total_fold = sum(1 if scop_sfam(f) != sfam_query and scop_fold(f) == fold_query else 0 for f in fams_search) 77 | top1_fam, top1_sfam, top1_fold, top5_fam = 0, 0, 0, 0 78 | 79 | for ji, j in enumerate(dists[di].argsort()): 80 | fam_match = fams_search[j] 81 | sfam_match = scop_sfam(fam_match) 82 | fold_match = scop_fold(fam_match) 83 | # Don't count self match for top 1 accuracy 84 | if ji == 1: 85 | if fam_match == fam_query: 86 | top1_fam = 1 87 | if sfam_match == sfam_query: 88 | top1_sfam = 1 89 | if fold_match == fold_query: 90 | top1_fold = 1 91 | if 0 < ji < 6 and fam_match == fam_query: 92 | top5_fam = 1 93 | if fam_match == fam_query: 94 | count_tp_fam += 1 95 | elif sfam_match == sfam_query: 96 | count_tp_sfam += 1 97 | elif fold_match == fold_query: 98 | count_tp_fold += 1 99 | else: 100 | break 101 | 102 | sens_fam = count_tp_fam / total_fam 103 | sens_sfam = count_tp_sfam / total_sfam 104 | sens_fold = count_tp_fold / total_fold 105 | 106 | return 
sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam 107 | 108 | senses_fam, senses_sfam, senses_fold, top1s_fam, top1s_sfam, top1s_fold, top5s_fam = [], [], [], [], [], [], [] 109 | senses_fam_class, senses_sfam_class, senses_fold_class = defaultdict(list), defaultdict(list), defaultdict(list) 110 | senses_fam_nres , senses_sfam_nres , senses_fold_nres = defaultdict(list), defaultdict(list), defaultdict(list) 111 | senses_fam_co , senses_sfam_co , senses_fold_co = defaultdict(list), defaultdict(list), defaultdict(list) 112 | 113 | for dc, domid in enumerate(domids_unseen): 114 | di = domids_search.index(domid) 115 | sens_fam, sens_sfam, sens_fold, top1_fam, top1_sfam, top1_fold, top5_fam = searching_accuracy(di, dists) 116 | 117 | senses_fam.append(sens_fam) 118 | senses_sfam.append(sens_sfam) 119 | senses_fold.append(sens_fold) 120 | top1s_fam.append(top1_fam) 121 | top1s_sfam.append(top1_sfam) 122 | top1s_fold.append(top1_fold) 123 | top5s_fam.append(top5_fam) 124 | 125 | cl = fams_search[di][0] 126 | senses_fam_class[cl].append(sens_fam) 127 | senses_sfam_class[cl].append(sens_sfam) 128 | senses_fold_class[cl].append(sens_fold) 129 | 130 | nres = domid_to_nres[domid] 131 | if 0 <= nres < 100: 132 | nres_group = "0-99" 133 | elif 100 <= nres < 200: 134 | nres_group = "100-199" 135 | elif 200 <= nres < 300: 136 | nres_group = "200-299" 137 | else: 138 | nres_group = "300+" 139 | 140 | senses_fam_nres[nres_group].append(sens_fam) 141 | senses_sfam_nres[nres_group].append(sens_sfam) 142 | senses_fold_nres[nres_group].append(sens_fold) 143 | 144 | co = domid_to_contact_order[domid] 145 | if 0 <= co < 0.075: 146 | co_group = "0-0.075" 147 | elif 0.075 <= co < 0.125: 148 | co_group = "0.075-0.125" 149 | elif 0.125 <= co < 0.175: 150 | co_group = "0.125-0.175" 151 | else: 152 | co_group = "0.175+" 153 | 154 | senses_fam_co[co_group].append(sens_fam) 155 | senses_sfam_co[co_group].append(sens_sfam) 156 | senses_fold_co[co_group].append(sens_fold) 157 | 158 | print(np.mean(senses_fam)) 159 | print(np.mean(senses_sfam)) 160 | print(np.mean(senses_fold)) 161 | print() 162 | 163 | for cl in sorted(list(senses_fam_class.keys())): 164 | print(cl, "class") 165 | print(np.mean(senses_fam_class[cl])) 166 | print(np.mean(senses_sfam_class[cl])) 167 | print(np.mean(senses_fold_class[cl])) 168 | print() 169 | 170 | for nres_group in sorted(list(senses_fam_nres.keys())): 171 | print(nres_group, "nres") 172 | print(np.mean(senses_fam_nres[nres_group])) 173 | print(np.mean(senses_sfam_nres[nres_group])) 174 | print(np.mean(senses_fold_nres[nres_group])) 175 | print() 176 | 177 | for co_group in sorted(list(senses_fam_co.keys())): 178 | print(co_group, "contact_order") 179 | print(np.mean(senses_fam_co[co_group])) 180 | print(np.mean(senses_sfam_co[co_group])) 181 | print(np.mean(senses_fold_co[co_group])) 182 | print() 183 | -------------------------------------------------------------------------------- /scripts/scope_val.txt: -------------------------------------------------------------------------------- 1 | 0.8771379760486869 2 | 0.7060833588310607 3 | 0.1772748013642059 4 | 5 | a class 6 | 0.8316866885127755 7 | 0.6036463789159128 8 | 0.05570900470444102 9 | 10 | b class 11 | 0.8951768493735566 12 | 0.7553443175144247 13 | 0.2834996333823747 14 | 15 | c class 16 | 0.924649641685099 17 | 0.8162678024274738 18 | 0.16435424430908718 19 | 20 | d class 21 | 0.8471587386641566 22 | 0.6220508380206372 23 | 0.14529620105850644 24 | 25 | f class 26 | 0.6116071428571428 27 | 0.14955357142857142 
28 | 0.022707979626485568 29 | 30 | g class 31 | 0.7392290249433107 32 | 0.29183673469387755 33 | 0.01179554390563565 34 | 35 | 0-99 nres 36 | 0.8284328665107914 37 | 0.6261483041932864 38 | 0.19579171843737445 39 | 40 | 100-199 nres 41 | 0.8878113130193207 42 | 0.728631928526574 43 | 0.15734643156382713 44 | 45 | 200-299 nres 46 | 0.8954304339903548 47 | 0.7148946410446976 48 | 0.25413009619644594 49 | 50 | 300+ nres 51 | 0.9046245612790047 52 | 0.7574755058793283 53 | 0.13010629305280533 54 | 55 | 0-0.075 contact_order 56 | 0.7314523476719677 57 | 0.5062438653236814 58 | 0.0841403620299361 59 | 60 | 0.075-0.125 contact_order 61 | 0.9079190796417461 62 | 0.7562586325822404 63 | 0.14654076463984148 64 | 65 | 0.125-0.175 contact_order 66 | 0.8470850296671235 67 | 0.6668573458629742 68 | 0.19538490492630925 69 | 70 | 0.175+ contact_order 71 | 0.9193989705588428 72 | 0.743717226061335 73 | 0.25075922378622373 74 | 75 | -------------------------------------------------------------------------------- /scripts/training_coords.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greener-group/progres/39d24c5a431983049f3149f93720508ef97133df/scripts/training_coords.tgz -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | setuptools.setup( 7 | name="progres", 8 | version="1.0.0", 9 | author="Joe G Greener", 10 | author_email="jgreener@mrc-lmb.cam.ac.uk", 11 | description="Fast protein structure searching using structure graph embeddings", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/greener-group/progres", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | "Topic :: Scientific/Engineering :: Bio-Informatics", 21 | ], 22 | license="MIT", 23 | keywords="protein structure search graph embedding", 24 | scripts=["bin/progres"], 25 | install_requires=["biopython", "mmtf-python", "einops", "pydantic"], 26 | include_package_data=True, 27 | ) 28 | --------------------------------------------------------------------------------