├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── README.md ├── amberlib │ ├── protein.ff14SB.xml │ └── tip3p_standard.xml ├── create_custom_set.py ├── environment.yml ├── main.py ├── posex │ ├── align.py │ ├── ccd.py │ ├── data.py │ ├── mmcif.py │ ├── preprocess.py │ └── utils.py ├── template │ ├── ccd_query.txt │ ├── cross_dock.txt │ ├── self_dock.txt │ └── vs_query.txt └── utils │ ├── __init__.py │ ├── common_helper.py │ ├── mol_correct_helper.py │ ├── openmm_helper.py │ ├── pdb_helper.py │ ├── pdb_process.py │ └── repair_pdb.py ├── environments ├── base.yaml ├── boltz-1.txt ├── boltz-1x.txt ├── chai-1.txt ├── relax.yaml └── rfaa.yaml ├── figures ├── logo.png ├── posex_cross_dock.png └── posex_self_dock.png ├── scripts ├── calculate_benchmark_result.py ├── calculate_benchmark_result.sh ├── complex_structure_alignment.py ├── complex_structure_alignment.sh ├── convert_to_model_input.py ├── convert_to_model_input.sh ├── extract_model_output.py ├── extract_model_output.sh ├── generate_docking_benchmark.py ├── generate_docking_benchmark.sh ├── relax_model_outputs.py ├── run_alphafold3 │ └── run_alphafold3.sh ├── run_boltz │ └── run_boltz.sh ├── run_boltz1x │ └── run_boltz1x.sh ├── run_chai │ ├── run_chai.py │ └── run_chai.sh ├── run_deepdock │ ├── evaluate.py │ ├── prepare.py │ ├── run_deepdock.py │ └── run_deepdock.sh ├── run_diffdock │ ├── run_diffdock.py │ └── run_diffdock.sh ├── run_diffdock_l │ ├── run_diffdock_l.py │ └── run_diffdock_l.sh ├── run_diffdock_pocket │ └── run_diffdock_pocket.sh ├── run_dynamicbind │ ├── run_dynamicbind.py │ └── run_dynamicbind.sh ├── run_equibind │ ├── run_equibind.py │ └── run_equibind.sh ├── run_fabind │ ├── run_fabind.py │ └── run_fabind.sh ├── run_gnina │ ├── run_gnina.sh │ └── run_gnina_help.sh ├── run_interformer │ ├── run_interformer.py │ └── run_interformer.sh ├── run_neuralplexer │ ├── run_neuralplexer.py │ └── run_neuralplexer.sh ├── run_protenix │ └── run_protenix.sh ├── run_rfaa │ └── run_rfaa.sh ├── run_surfdock │ ├── run_surfdock.sh │ └── run_surfdock_help.sh ├── run_tankbind │ ├── run_tankbind.py │ └── run_tankbind.sh └── run_unimol │ ├── run_unimol.py │ └── run_unimol.sh └── tests ├── s1_download_mmcif.py └── s2_prepare_and_post_relax.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | env/ 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | # Virtual environments 25 | venv/ 26 | ENV/ 27 | 28 | # IDE 29 | .idea/ 30 | .vscode/ 31 | *.swp 32 | *.swo 33 | 34 | # MacOS 35 | .DS_Store 36 | Thumbs.db 37 | 38 | data/ 39 | notebooks/ 40 | /relax_debug 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 CataAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # PoseX 2 | 3 | 4 | ## Setup Environment 5 | 6 | conda env create -f environment.yml 7 | 8 | 9 | ## Generate Benchmark Dataset 10 | **Step 1:** Download the “Tabular Report - Entry IDs” for the entries you are interested in from [RCSB PDB](https://www.rcsb.org/search?request=%7B%22query%22%3A%7B%22type%22%3A%22group%22%2C%22logical_operator%22%3A%22and%22%2C%22nodes%22%3A%5B%7B%22type%22%3A%22group%22%2C%22logical_operator%22%3A%22and%22%2C%22nodes%22%3A%5B%7B%22type%22%3A%22group%22%2C%22nodes%22%3A%5B%7B%22type%22%3A%22terminal%22%2C%22service%22%3A%22text%22%2C%22parameters%22%3A%7B%22attribute%22%3A%22rcsb_accession_info.initial_release_date%22%2C%22operator%22%3A%22greater_or_equal%22%2C%22negation%22%3Afalse%2C%22value%22%3A%222022-01-01%22%7D%7D%2C%7B%22type%22%3A%22terminal%22%2C%22service%22%3A%22text%22%2C%22parameters%22%3A%7B%22attribute%22%3A%22rcsb_accession_info.initial_release_date%22%2C%22operator%22%3A%22less_or_equal%22%2C%22negation%22%3Afalse%2C%22value%22%3A%222025-01-01%22%7D%7D%5D%2C%22logical_operator%22%3A%22and%22%7D%2C%7B%22type%22%3A%22group%22%2C%22nodes%22%3A%5B%7B%22type%22%3A%22terminal%22%2C%22service%22%3A%22text%22%2C%22parameters%22%3A%7B%22attribute%22%3A%22rcsb_entry_info.selected_polymer_entity_types%22%2C%22operator%22%3A%22exact_match%22%2C%22negation%22%3Afalse%2C%22value%22%3A%22Protein%20(only)%22%7D%7D%5D%2C%22logical_operator%22%3A%22and%22%7D%2C%7B%22type%22%3A%22group%22%2C%22nodes%22%3A%5B%7B%22type%22%3A%22terminal%22%2C%22service%22%3A%22text%22%2C%22parameters%22%3A%7B%22attribute%22%3A%22rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id%22%2C%22operator%22%3A%22exists%22%2C%22negation%22%3Afalse%7D%7D%5D%2C%22logical_operator%22%3A%22and%22%7D%2C%7B%22type%22%3A%22group%22%2C%22nodes%22%3A%5B%7B%22type%22%3A%22terminal%22%2C%22service%22%3A%22text%22%2C%22parameters%22%3A%7B%22attribute%22%3A%22rcsb_entry_info.resolution_combined%22%2C%22operator%22%3A%22less_or_equal%22%2C%22negation%22%3Afalse%2C%22value%22%3A2%7D%7D%5D%2C%22logical_operator%22%3A%22and%22%7D%5D%2C%22label%22%3A%22text%22%7D%5D%7D%2C%22return_type%22%3A%22entry%22%2C%22request_options%22%3A%7B%22paginate%22%3A%7B%22start%22%3A0%2C%22rows%22%3A25%7D%2C%22results_content_type%22%3A%5B%22experimental%22%5D%2C%22sort%22%3A%5B%7B%22sort_by%22%3A%22score%22%2C%22direction%22%3A%22desc%22%7D%5D%2C%22scoring_strategy%22%3A%22combined%22%7D%2C%22request_info%22%3A%7B%22query_id%22%3A%2223a56d461e7e7e96f4065e59843158fe%22%7D%7D). If the number of entries exceeds 10,000, merge the resulting *.txt files into one. 11 | 12 | **Step 2:** Generate benchmark dataset 13 | 14 | python main.py 15 | 16 | *Inputs*: 17 | - `--mode`: benchmark mode (self_dock or cross_dock). 18 | - `--pdbid_path`: Path to the txt file containing the Entry IDs downloaded in "Step 1". 19 | - `--download_dir`: Folder to save the downloaded files. 20 | - `--mmseqs_exec`: Path to the [MMseqs2](https://github.com/soedinglab/MMseqs2) binary (for protein clustering). 21 | 22 | *Outputs*: 23 | - dataset in a folder named `${mode}` 24 | 25 | 
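For example, a cross-dock run with illustrative paths (these mirror the argparse defaults in `main.py`; the `--pdbid_path` file is the merged "Step 1" download):

    python main.py --mode cross_dock --pdbid_path /data/dataset/posex/pdbid_cross_2.0.txt --download_dir /data/dataset/posex --mmseqs_exec /data/software/mmseqs2/mmseqs/bin/mmseqs
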
26 | ## Generate Custom Dataset 27 | 28 | python create_custom_set.py 29 | 30 | *Inputs*: 31 | - `--name`: dataset name. 32 | - `--pdb_ccd_path`: Path to the txt file with one `{PDBID}_{CCDID}` item per line. 33 | - `--download_dir`: Folder to save the downloaded files. 34 | 35 | *Outputs*: 36 | - dataset in a folder named `${name}` 37 | 38 | 39 | ## Directory Structure 40 | ```plaintext 41 | . 42 | ├── create_custom_set.py # Create custom dataset 43 | ├── main.py # Entry point 44 | ├── posex 45 | │   ├── align.py # Cross alignment module 46 | │   ├── ccd.py # CCD utils module 47 | │   ├── data.py # Dataset generator module 48 | │   ├── mmcif.py # MMCIF parser module 49 | │   ├── preprocess.py # Dataset preprocessor module 50 | │   └── utils.py # Utils module 51 | └── template 52 | ├── ccd_query.txt # CCD query JSON template 53 | ├── cross_dock.txt # Cross-dock table Jinja2 template 54 | ├── self_dock.txt # Self-dock table Jinja2 template 55 | └── vs_query.txt # Validation score query JSON template 56 | 57 | ``` -------------------------------------------------------------------------------- /dataset/create_custom_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from dataclasses import asdict 5 | from collections import defaultdict 6 | from posex.utils import DownloadConfig 7 | from posex.data import DatasetGenerator 8 | from posex.preprocess import DataPreprocessor 9 | 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--name", type=str, required=True, help="dataset name") 15 | parser.add_argument("--pdb_ccd_path", type=str, default="/home/lixinze/github/collab/protein_ligand_docking_benchmark/data/dataset/astex_diverse_set_ids.txt") 16 | parser.add_argument("--download_dir", type=str, default="/data/dataset/posex", help="folder to save the downloaded files") 17 | args = parser.parse_args() 18 | np.random.seed(42) 19 | output_dir = os.path.abspath(args.name) 20 | assert not os.path.exists(output_dir), f"The {args.name} dataset already exists" 21 | download_config = asdict(DownloadConfig(args.download_dir)) 22 | pdb_ccd_dict = defaultdict(set) 23 | with open(args.pdb_ccd_path, "r") as f: 24 | for pdb_ccd in f.readlines(): 25 | pdb, ccd = pdb_ccd.strip().split("_") 26 | pdb_ccd_dict[pdb].add(ccd) 27 | pdbid_list = list(pdb_ccd_dict.keys()) 28 | data_preprocessor = DataPreprocessor(pdbid_list, **download_config) 29 | pdb_ccd_instance_map = data_preprocessor.run() 30 | dataset_generator = DatasetGenerator(mode="self_dock", 31 | pdb_ccd_instance_map=pdb_ccd_instance_map, 32 | output_dir=output_dir, 33 | mmseqs_exec=None, 34 | **download_config) 35 | dataset_generator.select_single_conformation() 36 | dataset_generator.set_pdb_ccd_dict(pdb_ccd_dict) 37 | dataset_generator.save_self_dock_res() 38 | -------------------------------------------------------------------------------- /dataset/environment.yml: -------------------------------------------------------------------------------- 1 | name: posex 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - https://repo.anaconda.com/pkgs/main 6 | - 
https://repo.anaconda.com/pkgs/r 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - alsa-lib=1.2.13=hb9d3cd8_0 11 | - asttokens=3.0.0=pyhd8ed1ab_1 12 | - attr=2.5.1=h166bdaf_1 13 | - blosc=1.21.6=he440d0b_1 14 | - brotli=1.1.0=hb9d3cd8_2 15 | - brotli-bin=1.1.0=hb9d3cd8_2 16 | - bzip2=1.0.8=h4bc722e_7 17 | - c-ares=1.34.4=hb9d3cd8_0 18 | - ca-certificates=2024.12.14=hbcca054_0 19 | - cairo=1.18.2=h3394656_1 20 | - certifi=2024.12.14=pyhd8ed1ab_0 21 | - chardet=5.2.0=py310hff52083_2 22 | - comm=0.2.2=pyhd8ed1ab_1 23 | - contourpy=1.3.1=py310h3788b33_0 24 | - cycler=0.12.1=pyhd8ed1ab_1 25 | - cyrus-sasl=2.1.27=h54b06d7_7 26 | - dbus=1.13.6=h5008d03_3 27 | - debugpy=1.8.11=py310hf71b8c6_0 28 | - decorator=5.1.1=pyhd8ed1ab_1 29 | - exceptiongroup=1.2.2=pyhd8ed1ab_1 30 | - executing=2.1.0=pyhd8ed1ab_1 31 | - expat=2.6.4=h5888daf_0 32 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 33 | - font-ttf-inconsolata=3.000=h77eed37_0 34 | - font-ttf-source-code-pro=2.038=h77eed37_0 35 | - font-ttf-ubuntu=0.83=h77eed37_3 36 | - fontconfig=2.15.0=h7e30c49_1 37 | - fonts-conda-ecosystem=1=0 38 | - fonts-conda-forge=1=0 39 | - fonttools=4.55.3=py310h89163eb_0 40 | - freetype=2.12.1=h267a509_2 41 | - freetype-py=2.3.0=pyhd8ed1ab_0 42 | - gettext=0.22.5=he02047a_3 43 | - gettext-tools=0.22.5=he02047a_3 44 | - glew=2.1.0=h9c3ff4c_2 45 | - glib=2.82.2=h44428e9_0 46 | - glib-tools=2.82.2=h4833e2c_0 47 | - glm=0.9.9.8=h00ab1b0_0 48 | - graphite2=1.3.13=h59595ed_1003 49 | - greenlet=3.1.1=py310hf71b8c6_1 50 | - gst-plugins-base=1.24.7=h0a52356_0 51 | - gstreamer=1.24.7=hf3bb09a_0 52 | - harfbuzz=10.1.0=h0b3b770_0 53 | - hdf4=4.2.15=h2a13503_7 54 | - hdf5=1.14.4=nompi_h2d575fe_105 55 | - icu=75.1=he02047a_0 56 | - importlib-metadata=8.5.0=pyha770c72_1 57 | - ipykernel=6.29.5=pyh3099207_0 58 | - ipython=8.18.1=pyh707e725_3 59 | - jedi=0.19.2=pyhd8ed1ab_1 60 | - jupyter_client=8.6.3=pyhd8ed1ab_1 61 | - jupyter_core=5.7.2=pyh31011fe_1 62 | - keyutils=1.6.1=h166bdaf_0 63 | - kiwisolver=1.4.7=py310h3788b33_0 64 | - krb5=1.21.3=h659f571_0 65 | - lame=3.100=h166bdaf_1003 66 | - lcms2=2.16=hb7c19ff_0 67 | - ld_impl_linux-64=2.40=h12ee557_0 68 | - lerc=4.0.0=h27087fc_0 69 | - libaec=1.1.3=h59595ed_0 70 | - libasprintf=0.22.5=he8f35ee_3 71 | - libasprintf-devel=0.22.5=he8f35ee_3 72 | - libblas=3.9.0=26_linux64_openblas 73 | - libboost=1.86.0=h6c02f8c_3 74 | - libboost-python=1.86.0=py310ha2bacc8_3 75 | - libbrotlicommon=1.1.0=hb9d3cd8_2 76 | - libbrotlidec=1.1.0=hb9d3cd8_2 77 | - libbrotlienc=1.1.0=hb9d3cd8_2 78 | - libcap=2.71=h39aace5_0 79 | - libcblas=3.9.0=26_linux64_openblas 80 | - libclang-cpp15=15.0.7=default_h127d8a8_5 81 | - libclang-cpp19.1=19.1.6=default_hb5137d0_0 82 | - libclang13=19.1.6=default_h9c6a7e4_0 83 | - libcups=2.3.3=h4637d8d_4 84 | - libcurl=8.11.1=h332b0f4_0 85 | - libdeflate=1.23=h4ddbbb0_0 86 | - libdrm=2.4.124=hb9d3cd8_0 87 | - libedit=3.1.20191231=he28a2e2_2 88 | - libegl=1.7.0=ha4b6fd6_2 89 | - libev=4.33=hd590300_2 90 | - libevent=2.1.12=hf998b51_1 91 | - libexpat=2.6.4=h5888daf_0 92 | - libffi=3.4.2=h7f98852_5 93 | - libflac=1.4.3=h59595ed_0 94 | - libgcc=14.2.0=h77fa898_1 95 | - libgcc-ng=14.2.0=h69a702a_1 96 | - libgcrypt-lib=1.11.0=hb9d3cd8_2 97 | - libgettextpo=0.22.5=he02047a_3 98 | - libgettextpo-devel=0.22.5=he02047a_3 99 | - libgfortran=14.2.0=h69a702a_1 100 | - libgfortran5=14.2.0=hd5240d6_1 101 | - libgl=1.7.0=ha4b6fd6_2 102 | - libglib=2.82.2=h2ff4ddf_0 103 | - libglu=9.0.3=h03adeef_0 104 | - libglvnd=1.7.0=ha4b6fd6_2 105 | - 
libglx=1.7.0=ha4b6fd6_2 106 | - libgomp=14.2.0=h77fa898_1 107 | - libgpg-error=1.51=hbd13f7d_1 108 | - libiconv=1.17=hd590300_2 109 | - libjpeg-turbo=3.0.0=hd590300_1 110 | - liblapack=3.9.0=26_linux64_openblas 111 | - libllvm15=15.0.7=hb3ce162_4 112 | - libllvm19=19.1.6=ha7bfdaf_0 113 | - liblzma=5.6.3=hb9d3cd8_1 114 | - libnetcdf=4.9.2=nompi_h2564987_115 115 | - libnghttp2=1.64.0=h161d5f1_0 116 | - libnsl=2.0.1=hd590300_0 117 | - libntlm=1.8=hb9d3cd8_0 118 | - libogg=1.3.5=h4ab18f5_0 119 | - libopenblas=0.3.28=pthreads_h94d23a6_1 120 | - libopus=1.3.1=h7f98852_1 121 | - libpciaccess=0.18=hd590300_0 122 | - libpng=1.6.44=hadc24fc_0 123 | - libpq=17.2=h3b95a9b_1 124 | - librdkit=2024.09.3=h84b0b3c_0 125 | - libsndfile=1.2.2=hc60ed4a_1 126 | - libsodium=1.0.20=h4ab18f5_0 127 | - libsqlite=3.47.2=hee588c1_0 128 | - libssh2=1.11.1=hf672d98_0 129 | - libstdcxx=14.2.0=hc0a3c3a_1 130 | - libstdcxx-ng=14.2.0=h4852527_1 131 | - libsystemd0=256.9=h0b6a36f_2 132 | - libtiff=4.7.0=hd9ff511_3 133 | - libuuid=2.38.1=h0b41bf4_0 134 | - libvorbis=1.3.7=h9c3ff4c_0 135 | - libwebp-base=1.5.0=h851e524_0 136 | - libxcb=1.17.0=h8a09558_0 137 | - libxcrypt=4.4.36=hd590300_1 138 | - libxkbcommon=1.7.0=h2c5496b_1 139 | - libxml2=2.13.5=h8d12d68_1 140 | - libzip=1.11.2=h6991a6a_0 141 | - libzlib=1.3.1=hb9d3cd8_2 142 | - lz4-c=1.10.0=h5888daf_1 143 | - matplotlib-base=3.10.0=py310h68603db_0 144 | - matplotlib-inline=0.1.7=pyhd8ed1ab_1 145 | - mpg123=1.32.9=hc50e24c_0 146 | - munkres=1.1.4=pyh9f0ad1d_0 147 | - mysql-common=9.0.1=h266115a_3 148 | - mysql-libs=9.0.1=he0572af_3 149 | - ncurses=6.5=he02047a_1 150 | - nest-asyncio=1.6.0=pyhd8ed1ab_1 151 | - nspr=4.36=h5888daf_0 152 | - nss=3.107=hdf54f9c_0 153 | - openjpeg=2.5.3=h5fbd93e_0 154 | - openldap=2.6.9=he970967_0 155 | - openssl=3.4.0=hb9d3cd8_0 156 | - packaging=24.2=pyhd8ed1ab_2 157 | - pandas=2.2.3=py310h5eaa309_1 158 | - parso=0.8.4=pyhd8ed1ab_1 159 | - pcre2=10.44=hba22ea6_2 160 | - pexpect=4.9.0=pyhd8ed1ab_1 161 | - pickleshare=0.7.5=pyhd8ed1ab_1004 162 | - pillow=11.0.0=py310hfeaa1f3_0 163 | - pip=24.3.1=pyh8b19718_2 164 | - pixman=0.44.2=h29eaf8c_0 165 | - platformdirs=4.3.6=pyhd8ed1ab_1 166 | - ply=3.11=pyhd8ed1ab_3 167 | - pmw=2.0.1=py310hff52083_1008 168 | - prompt-toolkit=3.0.48=pyha770c72_1 169 | - psutil=6.1.0=py310ha75aee5_0 170 | - pthread-stubs=0.4=hb9d3cd8_1002 171 | - ptyprocess=0.7.0=pyhd8ed1ab_1 172 | - pulseaudio-client=17.0=hb77b528_0 173 | - pure_eval=0.2.3=pyhd8ed1ab_1 174 | - pycairo=1.27.0=py310h25ff670_0 175 | - pygments=2.18.0=pyhd8ed1ab_1 176 | - pymol-open-source=3.0.0=py310h200f838_8 177 | - pyparsing=3.2.0=pyhd8ed1ab_2 178 | - pyqt=5.15.9=py310h04931ad_5 179 | - pyqt5-sip=12.12.2=py310hc6cd4ac_5 180 | - python=3.10.0=h543edf9_3_cpython 181 | - python-dateutil=2.9.0.post0=pyhff2d567_1 182 | - python-tzdata=2024.2=pyhd8ed1ab_1 183 | - python_abi=3.10=5_cp310 184 | - pyzmq=26.2.0=py310h71f11fc_3 185 | - qhull=2020.2=h434a139_5 186 | - qt-main=5.15.15=hc3cb62f_2 187 | - readline=8.2=h5eee18b_0 188 | - reportlab=4.2.5=py310ha75aee5_0 189 | - rlpycairo=0.2.0=pyhd8ed1ab_0 190 | - setuptools=75.6.0=pyhff2d567_1 191 | - sip=6.7.12=py310hc6cd4ac_0 192 | - six=1.17.0=pyhd8ed1ab_0 193 | - snappy=1.2.1=h8bd8927_1 194 | - sqlalchemy=2.0.36=py310ha75aee5_0 195 | - sqlite=3.47.2=h9eae976_0 196 | - stack_data=0.6.3=pyhd8ed1ab_1 197 | - tk=8.6.13=noxft_h4845f30_101 198 | - toml=0.10.2=pyhd8ed1ab_1 199 | - tomli=2.2.1=pyhd8ed1ab_1 200 | - tornado=6.4.2=py310ha75aee5_0 201 | - traitlets=5.14.3=pyhd8ed1ab_1 202 | - 
typing-extensions=4.12.2=hd8ed1ab_1 203 | - typing_extensions=4.12.2=pyha770c72_1 204 | - tzdata=2024b=h04d1e81_0 205 | - unicodedata2=15.1.0=py310ha75aee5_1 206 | - wcwidth=0.2.13=pyhd8ed1ab_1 207 | - wheel=0.45.1=pyhd8ed1ab_1 208 | - xcb-util=0.4.1=hb711507_2 209 | - xcb-util-image=0.4.0=hb711507_2 210 | - xcb-util-keysyms=0.4.1=hb711507_0 211 | - xcb-util-renderutil=0.3.10=hb711507_0 212 | - xcb-util-wm=0.4.2=hb711507_0 213 | - xkeyboard-config=2.42=h4ab18f5_0 214 | - xorg-kbproto=1.0.7=hb9d3cd8_1003 215 | - xorg-libice=1.1.2=hb9d3cd8_0 216 | - xorg-libsm=1.2.5=he73a12e_0 217 | - xorg-libx11=1.8.10=h4f16b4b_1 218 | - xorg-libxau=1.0.12=hb9d3cd8_0 219 | - xorg-libxdamage=1.1.6=hb9d3cd8_0 220 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0 221 | - xorg-libxext=1.3.6=hb9d3cd8_0 222 | - xorg-libxfixes=6.0.1=hb9d3cd8_0 223 | - xorg-libxrender=0.9.11=hd590300_0 224 | - xorg-libxxf86vm=1.1.6=hb9d3cd8_0 225 | - xorg-renderproto=0.11.1=hb9d3cd8_1003 226 | - xorg-xextproto=7.3.0=hb9d3cd8_1004 227 | - xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1005 228 | - xorg-xproto=7.0.31=hb9d3cd8_1008 229 | - xz=5.4.6=h5eee18b_1 230 | - zeromq=4.3.5=h3b0a872_7 231 | - zipp=3.21.0=pyhd8ed1ab_1 232 | - zlib=1.3.1=hb9d3cd8_2 233 | - zstd=1.5.6=ha6fb4c9_0 234 | - pip: 235 | - bcrypt==4.2.1 236 | - biopandas==0.5.1 237 | - biopython==1.84 238 | - biotite==1.0.1 239 | - biotraj==1.2.2 240 | - cffi==1.17.1 241 | - charset-normalizer==3.4.0 242 | - cryptography==44.0.0 243 | - dill==0.3.9 244 | - future==1.0.0 245 | - gitdb==4.0.11 246 | - gitpython==3.1.43 247 | - idna==3.10 248 | - jinja2==3.1.5 249 | - looseversion==1.1.2 250 | - markupsafe==3.0.2 251 | - mmcif==0.90.0 252 | - mmtf-python==1.1.3 253 | - msgpack==1.1.0 254 | - multiprocess==0.70.17 255 | - networkx==3.4.2 256 | - numpy==1.26.3 257 | - paramiko==3.5.0 258 | - pdbecif==1.5 259 | - posebusters==0.3.1 260 | - pycparser==2.22 261 | - pynacl==1.5.0 262 | - pytz==2024.2 263 | - pyyaml==6.0.2 264 | - rcsb-utils-config==0.41 265 | - rcsb-utils-io==1.49 266 | - rcsb-utils-validation==0.33 267 | - rdkit-pypi==2022.9.5 268 | - requests==2.32.3 269 | - ruamel-yaml==0.18.6 270 | - ruamel-yaml-clib==0.2.12 271 | - scipy==1.14.1 272 | - smmap==5.0.1 273 | - tqdm==4.67.1 274 | - urllib3==2.3.0 275 | -------------------------------------------------------------------------------- /dataset/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from dataclasses import asdict 5 | from posex.utils import DownloadConfig 6 | from posex.data import DatasetGenerator 7 | from posex.preprocess import DataPreprocessor 8 | 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--mode", type=str, default="cross_dock", help="benchmark mode(self_dock or cross_dock)") 14 | parser.add_argument("--pdbid_path", type=str, default="/data/dataset/posex/pdbid_cross_2.0.txt") 15 | parser.add_argument("--download_dir", type=str, default="/data/dataset/posex", help="folder to save the downloaded files") 16 | parser.add_argument("--mmseqs_exec", type=str, default="/data/software/mmseqs2/mmseqs/bin/mmseqs", help="path to mmseqs exec") 17 | args = parser.parse_args() 18 | np.random.seed(42) 19 | output_dir = os.path.abspath(args.mode) 20 | assert not os.path.exists(output_dir), f"The {args.mode} dataset already exists" 21 | download_config = asdict(DownloadConfig(args.download_dir)) 22 | with open(args.pdbid_path, "r") as f: 23 | pdbid_list = 
f.readline().strip().split(",") 24 | data_preprocessor = DataPreprocessor(pdbid_list, **download_config) 25 | pdb_ccd_instance_map = data_preprocessor.run() 26 | dataset_generator = DatasetGenerator(mode=args.mode, 27 | pdb_ccd_instance_map=pdb_ccd_instance_map, 28 | output_dir=output_dir, 29 | mmseqs_exec=args.mmseqs_exec, 30 | **download_config) 31 | dataset_generator.run() 32 | -------------------------------------------------------------------------------- /dataset/posex/align.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | from pymol import cmd 4 | from posex.utils import my_tqdm 5 | 6 | class CrossAlignment(): 7 | """ 8 | A class for creating cross groups for cross-docking 9 | 10 | Args: 11 | cif_dir (str) : folder to save pdb entries (cif format) 12 | cluster_map (dict[str, set[str]]): 13 | dict of the representative pdb to other pdbs in the cluster 14 | pdb_ccd_dict (dict[str, set[str]]): 15 | dict of PDBID to CCD set 16 | pdb_ccd_instance_map (dict[str, dict[str, list[str]]]): 17 | dict of pdb to dict of ccd to list of asym_ids 18 | downsample_num (int): maximum num of items in a cross group 19 | 20 | """ 21 | def __init__(self, 22 | cif_dir: str, 23 | cluster_map: dict[str, set[str]], 24 | pdb_ccd_dict: dict[str, set[str]], 25 | pdb_ccd_instance_map: dict[str, dict[str, list[str]]], 26 | downsample_num: int = 8): 27 | self.cif_dir = cif_dir 28 | self.cluster_map = cluster_map 29 | self.pdb_ccd_dict = pdb_ccd_dict 30 | self.pdb_ccd_instance_map = pdb_ccd_instance_map 31 | self.downsample_num = downsample_num 32 | 33 | def _get_func_name(self, idx=1) -> str: 34 | """Return the name of a function in the current call stack 35 | """ 36 | return inspect.stack()[idx].function 37 | 38 | def check_lig_from_candidate_to_reference(self, 39 | ref_pdb: str, 40 | ref_ccd: str, 41 | asym_ids: list, 42 | valid_pdb_set: set 43 | ) -> dict[tuple[str, str], str]: 44 | """Check if each instance of the reference ligand is within 4.0 Å of the candidate ligand 45 | 46 | Args: 47 | ref_pdb (str): PDBID 48 | ref_ccd (str): CCD 49 | asym_ids (list[str]): list of asym_ids 50 | valid_pdb_set (set): set of valid PDBIDs 51 | 52 | Returns: 53 | dict[tuple[str, str], str]: candidate_items(dict[tuple[pdb, ccd], ccd_asym_id]) 54 | """ 55 | candidate_items = {} 56 | for cand_pdb in sorted(valid_pdb_set): 57 | for cand_ccd in self.pdb_ccd_dict[cand_pdb]: 58 | if cand_ccd == ref_ccd: continue 59 | cand_ccd_asym_ids = self.pdb_ccd_instance_map[cand_pdb][cand_ccd] 60 | hit_asym_ids = set() 61 | for cand_ccd_asym_id in cand_ccd_asym_ids: 62 | for asym_id in asym_ids: 63 | ref_lig = f"{ref_pdb}_{ref_ccd}_{asym_id}" 64 | res_name = f"{ref_lig}_{cand_ccd}_{cand_ccd_asym_id}" 65 | cmd.select(ref_lig, f"resn {ref_ccd} and segi {asym_id} and {ref_pdb}") 66 | cmd.select(res_name, f"resn {cand_ccd} and segi {cand_ccd_asym_id} and ({cand_pdb} within 4.0 of {ref_lig})") 67 | if cmd.count_atoms(res_name) > 0: 68 | hit_asym_ids.add(asym_id) 69 | if asym_id == asym_ids[0]: 70 | selected_cand_ccd_asym_id = cand_ccd_asym_id 71 | cmd.delete(res_name) 72 | cmd.delete(ref_lig) 73 | if len(hit_asym_ids) == len(asym_ids): 74 | candidate_items[(cand_pdb, cand_ccd)] = selected_cand_ccd_asym_id 75 | return candidate_items 76 | 
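# Note: check_lig_from_candidate_to_reference() above and
# check_lig_from_reference_to_candidate() below together enforce a mutual
# 4.0 Å contact criterion: a candidate (pdb, ccd) pair survives only if every
# instance of the reference ligand has the candidate ligand within 4.0 Å, and
# every instance of the candidate ligand has the reference ligand within 4.0 Å.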
77 | def check_lig_from_reference_to_candidate(self, 78 | ref_pdb: str, 79 | ref_ccd: str, 80 | candidate_items: dict[tuple[str, str], str] 81 | ) -> dict[tuple[str, str], str]: 82 | """Check if each instance of the candidate ligand is within 4.0 Å of the reference ligand 83 | 84 | Args: 85 | ref_pdb (str): PDBID 86 | ref_ccd (str): CCD 87 | candidate_items (dict[tuple[str, str], str]): 88 | dict[tuple[pdb, ccd], ccd_asym_id] 89 | 90 | Returns: 91 | dict[tuple[str, str], str]: filtered candidate_items 92 | """ 93 | for cand_item in list(candidate_items.keys()): 94 | cand_pdb, cand_ccd = cand_item 95 | for asym_id in self.pdb_ccd_instance_map[cand_pdb][cand_ccd]: 96 | cand_lig = f"{cand_pdb}_{cand_ccd}_{asym_id}" 97 | res_name = f"{cand_lig}_{ref_ccd}" 98 | cmd.select(cand_lig, f"resn {cand_ccd} and segi {asym_id} and {cand_pdb}") 99 | cmd.select(res_name, f"resn {ref_ccd} and ({ref_pdb} within 4.0 of {cand_lig})") 100 | cmd.delete(cand_lig) 101 | if cmd.count_atoms(res_name) == 0: 102 | del candidate_items[cand_item] 103 | cmd.delete(res_name) 104 | break 105 | else: 106 | cmd.delete(res_name) 107 | return candidate_items 108 | 109 | def select_cross_candidates(self, 110 | ref_pdb: str, 111 | ref_ccd: str, 112 | asym_ids: list[str], 113 | valid_pdb_set: set 114 | ) -> tuple[tuple[str, str, str]] | None: 115 | """Given a reference protein and ligand, return a cross group containing tuples of (pdb, ccd, asym_id) 116 | 117 | Args: 118 | ref_pdb (str): PDBID 119 | ref_ccd (str): CCD 120 | asym_ids (list[str]): list of asym_ids 121 | valid_pdb_set (set): set of valid PDBIDs 122 | 123 | Returns: 124 | tuple[tuple[str, str, str]] | None: 125 | tuple[tuple[pdb, ccd, ccd_asym_id]] 126 | """ 127 | selected_ref_ccd_asym_id = asym_ids[0] 128 | candidate_items = self.check_lig_from_candidate_to_reference(ref_pdb, ref_ccd, asym_ids, valid_pdb_set) 129 | candidate_items = self.check_lig_from_reference_to_candidate(ref_pdb, ref_ccd, candidate_items) 130 | if candidate_items: 131 | # remove repeated ccd in candidate_items 132 | new_candidate_items = {} 133 | cand_ccds = set() 134 | for cand_item, cand_ccd_asym_id in candidate_items.items(): 135 | cand_pdb, cand_ccd = cand_item 136 | if cand_ccd not in cand_ccds: 137 | new_candidate_items[cand_item] = cand_ccd_asym_id 138 | cand_ccds.add(cand_ccd) 139 | cross_group = [(ref_pdb, ref_ccd, selected_ref_ccd_asym_id)] 140 | for cand_item, cand_ccd_asym_id in new_candidate_items.items(): 141 | cross_group.append((*cand_item, cand_ccd_asym_id)) 142 | cross_group = tuple(sorted(cross_group)) 143 | else: 144 | cross_group = None 145 | return cross_group 146 | 147 | def filter_pdb_with_alignment(self, ref_pdb: str, pdb_set: set, max_rmsd: float = 2.0) -> set: 148 | """Remove a PDB entry if the RMSD of the protein alignment is greater than max_rmsd (default 2.0 Å) 149 | 150 | Args: 151 | ref_pdb (str): reference PDBID 152 | pdb_set (set): PDBIDs of proteins to be aligned 153 | max_rmsd (float): the maximum acceptable alignment RMSD 154 | 155 | 156 | Returns: 157 | set: valid pdb set 158 | """ 159 | ref_pdb_path = os.path.join(self.cif_dir, f"{ref_pdb}.cif") 160 | cmd.load(ref_pdb_path, ref_pdb) 161 | valid_pdb_set = set() 162 | for pdb in pdb_set: 163 | pdb_path = os.path.join(self.cif_dir, f"{pdb}.cif") 164 | cmd.load(pdb_path, pdb) 165 | align_res = cmd.align(f"{pdb}////CA", f"{ref_pdb}////CA") 166 | rmsd = align_res[0] 167 | if rmsd <= max_rmsd: 168 | valid_pdb_set.add(pdb) 169 | else: 170 | cmd.delete(pdb) 171 | return valid_pdb_set 172 | 173 | def run(self) -> set[tuple[tuple[str, str, str]]]: 174 | """Return a set of cross groups, each cross group contains tuples of (pdb, ccd, asym_id) 175 | """ 176 | cross_groups = set() 177 | for ref_pdb, pdb_set in my_tqdm(self.cluster_map.items(), 
desc=self._get_func_name(2)): 178 | valid_pdb_set = self.filter_pdb_with_alignment(ref_pdb, pdb_set) 179 | for ref_ccd in self.pdb_ccd_dict[ref_pdb]: 180 | asym_ids = self.pdb_ccd_instance_map[ref_pdb][ref_ccd] 181 | cross_group = self.select_cross_candidates(ref_pdb, ref_ccd, asym_ids, valid_pdb_set) 182 | if cross_group is not None: 183 | # downsample 184 | if len(cross_group) > self.downsample_num: 185 | cross_group = cross_group[:self.downsample_num] 186 | cross_groups.add(cross_group) 187 | cmd.delete("all") 188 | return cross_groups -------------------------------------------------------------------------------- /dataset/posex/ccd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------- 2 | # Following code adapted from (https://github.com/bytedance/Protenix) 3 | # -------------------------------------------------------------------- 4 | 5 | import functools 6 | import logging 7 | from collections import defaultdict 8 | from typing import Any, Optional, Union 9 | 10 | import biotite 11 | import biotite.structure as struc 12 | import biotite.structure.io.pdbx as pdbx 13 | import numpy as np 14 | from biotite.structure import AtomArray 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | COMPONENTS_FILE = None 19 | 20 | 21 | @functools.lru_cache 22 | def biotite_load_ccd_cif() -> pdbx.CIFFile: 23 | """biotite load CCD components file 24 | 25 | Returns: 26 | pdbx.CIFFile: ccd components file 27 | """ 28 | return pdbx.CIFFile.read(COMPONENTS_FILE) 29 | 30 | 31 | def _map_central_to_leaving_groups(component) -> Optional[dict[str, list[list[str]]]]: 32 | """map each central atom (bonded atom) index to leaving atom groups in component (atom_array). 33 | 34 | Returns: 35 | dict[str, list[list[str]]]: central atom name to leaving atom groups (atom names). 36 | """ 37 | comp = component.copy() 38 | # Eg: ions 39 | if comp.bonds is None: 40 | return {} 41 | central_to_leaving_groups = defaultdict(list) 42 | for c_idx in np.flatnonzero(~comp.leaving_atom_flag): 43 | bonds, _ = comp.bonds.get_bonds(c_idx) 44 | for l_idx in bonds: 45 | if comp.leaving_atom_flag[l_idx]: 46 | comp.bonds.remove_bond(c_idx, l_idx) 47 | group_idx = struc.find_connected(comp.bonds, l_idx) 48 | if not np.all(comp.leaving_atom_flag[group_idx]): 49 | return None 50 | central_to_leaving_groups[comp.atom_name[c_idx]].append( 51 | comp.atom_name[group_idx].tolist() 52 | ) 53 | return central_to_leaving_groups 54 | 55 | 56 | @functools.lru_cache 57 | def get_component_atom_array( 58 | ccd_code: str, keep_leaving_atoms: bool = False, keep_hydrogens=False 59 | ) -> AtomArray: 60 | """get component atom array 61 | 62 | Args: 63 | ccd_code (str): ccd code 64 | keep_leaving_atoms (bool, optional): keep leaving atoms. Defaults to False. 65 | keep_hydrogens (bool, optional): keep hydrogens. Defaults to False. 66 | 67 | Returns: 68 | AtomArray: Biotite AtomArray of CCD component 69 | with additional attribute: leaving_atom_flag (bool) 70 | """ 71 | ccd_cif = biotite_load_ccd_cif() 72 | if ccd_code not in ccd_cif: 73 | logger.warning(f"Warning: get_component_atom_array() can not parse {ccd_code}") 74 | return None 75 | try: 76 | comp = pdbx.get_component(ccd_cif, data_block=ccd_code, use_ideal_coord=True) 77 | except biotite.InvalidFileError as e: 78 | # Eg: UNL without atom. 
79 | logger.warning( 80 | f"Warning: get_component_atom_array() can not parse {ccd_code} for {e}" 81 | ) 82 | return None 83 | atom_category = ccd_cif[ccd_code]["chem_comp_atom"] 84 | leaving_atom_flag = atom_category["pdbx_leaving_atom_flag"].as_array() 85 | comp.set_annotation("leaving_atom_flag", leaving_atom_flag == "Y") 86 | 87 | for atom_id in ["alt_atom_id", "pdbx_component_atom_id"]: 88 | comp.set_annotation(atom_id, atom_category[atom_id].as_array()) 89 | if not keep_leaving_atoms: 90 | comp = comp[~comp.leaving_atom_flag] 91 | if not keep_hydrogens: 92 | # EG: ND4 93 | comp = comp[~np.isin(comp.element, ["H", "D"])] 94 | 95 | # Map central atom index to leaving group (atom_indices) in component (atom_array). 96 | comp.central_to_leaving_groups = _map_central_to_leaving_groups(comp) 97 | if comp.central_to_leaving_groups is None: 98 | logger.warning( 99 | f"Warning: ccd {ccd_code} has leaving atom group bond to more than one central atom, central_to_leaving_groups is None." 100 | ) 101 | return comp 102 | 103 | 104 | @functools.lru_cache(maxsize=None) 105 | def get_one_letter_code(ccd_code: str) -> Union[str, None]: 106 | """get one_letter_code from CCD components file. 107 | 108 | normal return is one letter: ALA --> A, DT --> T 109 | unknown protein: X 110 | unknown DNA or RNA: N 111 | other unknown: None 112 | some ccd_code will return more than one letter: 113 | eg: XXY --> THG 114 | 115 | Args: 116 | ccd_code (str): _description_ 117 | 118 | Returns: 119 | str: one letter code 120 | """ 121 | ccd_cif = biotite_load_ccd_cif() 122 | if ccd_code not in ccd_cif: 123 | return None 124 | one = ccd_cif[ccd_code]["chem_comp"]["one_letter_code"].as_item() 125 | if one == "?": 126 | return None 127 | else: 128 | return one 129 | 130 | 131 | @functools.lru_cache(maxsize=None) 132 | def get_mol_type(ccd_code: str) -> str: 133 | """get mol_type from CCD components file. 134 | 135 | based on _chem_comp.type 136 | http://mmcif.rcsb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_chem_comp.type.html 137 | 138 | not use _chem_comp.pdbx_type, because it is not consistent with _chem_comp.type 139 | e.g. ccd 000 --> _chem_comp.type="NON-POLYMER" _chem_comp.pdbx_type="ATOMP" 140 | https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v5_next.dic/Items/_struct_asym.pdbx_type.html 141 | 142 | Args: 143 | ccd_code (str): ccd code 144 | 145 | Returns: 146 | str: mol_type, one of {"protein", "rna", "dna", "ligand"} 147 | """ 148 | ccd_cif = biotite_load_ccd_cif() 149 | if ccd_code not in ccd_cif: 150 | return "ligand" 151 | 152 | link_type = ccd_cif[ccd_code]["chem_comp"]["type"].as_item().upper() 153 | 154 | if "PEPTIDE" in link_type and link_type != "PEPTIDE-LIKE": 155 | return "protein" 156 | if "DNA" in link_type: 157 | return "dna" 158 | if "RNA" in link_type: 159 | return "rna" 160 | return "ligand" 161 | 162 | 163 | def get_all_ccd_code() -> list: 164 | """get all ccd code from components file""" 165 | ccd_cif = biotite_load_ccd_cif() 166 | return list(ccd_cif.keys()) 167 | 168 | # Modified from biotite to use consistent ccd components file 169 | def _connect_inter_residue( 170 | atoms: AtomArray, residue_starts: np.ndarray 171 | ) -> struc.BondList: 172 | """ 173 | Create a :class:`BondList` containing the bonds between adjacent 174 | amino acid or nucleotide residues. 175 | 176 | Parameters 177 | ---------- 178 | atoms : AtomArray or AtomArrayStack 179 | The structure to create the :class:`BondList` for. 
180 | residue_starts : ndarray, dtype=int 181 | Return value of 182 | ``get_residue_starts(atoms, add_exclusive_stop=True)``. 183 | 184 | Returns 185 | ------- 186 | BondList 187 | A bond list containing all inter residue bonds. 188 | """ 189 | 190 | bonds = [] 191 | 192 | atom_names = atoms.atom_name 193 | res_names = atoms.res_name 194 | res_ids = atoms.res_id 195 | chain_ids = atoms.chain_id 196 | 197 | # Iterate over all starts excluding: 198 | # - the last residue and 199 | # - exclusive end index of 'atoms' 200 | for i in range(len(residue_starts) - 2): 201 | curr_start_i = residue_starts[i] 202 | next_start_i = residue_starts[i + 1] 203 | after_next_start_i = residue_starts[i + 2] 204 | 205 | # Check if the current and next residue is in the same chain 206 | if chain_ids[next_start_i] != chain_ids[curr_start_i]: 207 | continue 208 | # Check if the current and next residue 209 | # have consecutive residue IDs 210 | # (Same residue ID is also possible if insertion code is used) 211 | if res_ids[next_start_i] - res_ids[curr_start_i] > 1: 212 | continue 213 | 214 | # Get link type for this residue from RCSB components.cif 215 | curr_link = get_mol_type(res_names[curr_start_i]) 216 | next_link = get_mol_type(res_names[next_start_i]) 217 | 218 | if curr_link == "protein" and next_link in "protein": 219 | curr_connect_atom_name = "C" 220 | next_connect_atom_name = "N" 221 | elif curr_link in ["dna", "rna"] and next_link in ["dna", "rna"]: 222 | curr_connect_atom_name = "O3'" 223 | next_connect_atom_name = "P" 224 | else: 225 | # Create no bond if the connection types of consecutive 226 | # residues are not compatible 227 | continue 228 | 229 | # Index in atom array for atom name in current residue 230 | # Addition of 'curr_start_i' is necessary, as only a slice of 231 | # 'atom_names' is taken, beginning at 'curr_start_i' 232 | curr_connect_indices = np.where( 233 | atom_names[curr_start_i:next_start_i] == curr_connect_atom_name 234 | )[0] 235 | curr_connect_indices += curr_start_i 236 | 237 | # Index in atom array for atom name in next residue 238 | next_connect_indices = np.where( 239 | atom_names[next_start_i:after_next_start_i] == next_connect_atom_name 240 | )[0] 241 | next_connect_indices += next_start_i 242 | 243 | if len(curr_connect_indices) == 0 or len(next_connect_indices) == 0: 244 | # The connector atoms are not found in the adjacent residues 245 | # -> skip this bond 246 | continue 247 | 248 | bonds.append( 249 | (curr_connect_indices[0], next_connect_indices[0], struc.BondType.SINGLE) 250 | ) 251 | 252 | return struc.BondList(atoms.array_length(), np.array(bonds, dtype=np.uint32)) 253 | 254 | 255 | def add_inter_residue_bonds( 256 | atom_array: AtomArray, 257 | exclude_struct_conn_pairs: bool = False, 258 | remove_far_inter_chain_pairs: bool = False, 259 | ) -> AtomArray: 260 | """ 261 | add polymer bonds (C-N or O3'-P) between adjacent residues based on auth_seq_id. 262 | 263 | exclude_struct_conn_pairs: if True, do not add bond between adjacent residues already has non-standard polymer bonds 264 | on atom C or N or O3' or P. 265 | 266 | remove_far_inter_chain_pairs: if True, remove inter chain (based on label_asym_id) bonds that are far away from each other. 
267 | 268 | returns: 269 | AtomArray: Biotite AtomArray merged inter residue bonds into atom_array.bonds 270 | """ 271 | res_starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True) 272 | inter_bonds = _connect_inter_residue(atom_array, res_starts) 273 | 274 | if atom_array.bonds is None: 275 | atom_array.bonds = inter_bonds 276 | return atom_array 277 | 278 | select_mask = np.ones(len(inter_bonds._bonds), dtype=bool) 279 | if exclude_struct_conn_pairs: 280 | for b_idx, (atom_i, atom_j, b_type) in enumerate(inter_bonds._bonds): 281 | atom_k = atom_i if atom_array.atom_name[atom_i] in ("N", "O3'") else atom_j 282 | bonds, types = atom_array.bonds.get_bonds(atom_k) 283 | if len(bonds) == 0: 284 | continue 285 | for b in bonds: 286 | if ( 287 | # adjacent residues 288 | abs((res_starts <= b).sum() - (res_starts <= atom_k).sum()) == 1 289 | and atom_array.chain_id[b] == atom_array.chain_id[atom_k] 290 | and atom_array.atom_name[b] not in ("C", "P") 291 | ): 292 | select_mask[b_idx] = False 293 | break 294 | 295 | if remove_far_inter_chain_pairs: 296 | if not hasattr(atom_array, "label_asym_id"): 297 | logging.warning( 298 | "label_asym_id not found, far inter chain bonds will not be removed" 299 | ) 300 | for b_idx, (atom_i, atom_j, b_type) in enumerate(inter_bonds._bonds): 301 | if atom_array.label_asym_id[atom_i] != atom_array.label_asym_id[atom_j]: 302 | coord_i = atom_array.coord[atom_i] 303 | coord_j = atom_array.coord[atom_j] 304 | if np.linalg.norm(coord_i - coord_j) > 2.5: 305 | select_mask[b_idx] = False 306 | 307 | # filter out removed_inter_bonds from atom_array.bonds 308 | remove_bonds = inter_bonds._bonds[~select_mask] 309 | remove_mask = np.isin(atom_array.bonds._bonds[:, 0], remove_bonds[:, 0]) & np.isin( 310 | atom_array.bonds._bonds[:, 1], remove_bonds[:, 1] 311 | ) 312 | atom_array.bonds._bonds = atom_array.bonds._bonds[~remove_mask] 313 | 314 | # merged normal inter_bonds into atom_array.bonds 315 | inter_bonds._bonds = inter_bonds._bonds[select_mask] 316 | atom_array.bonds = atom_array.bonds.merge(inter_bonds) 317 | return atom_array 318 | 319 | 320 | def res_names_to_sequence(res_names: list[str]) -> str: 321 | """convert res_names to sequences {chain_id: canonical_sequence} based on CCD 322 | 323 | Return 324 | str: canonical_sequence 325 | """ 326 | seq = "" 327 | for res_name in res_names: 328 | one = get_one_letter_code(res_name) 329 | one = "X" if one is None else one 330 | one = "X" if len(one) > 1 else one 331 | seq += one 332 | return seq 333 | -------------------------------------------------------------------------------- /dataset/posex/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import inspect 4 | import requests 5 | import subprocess 6 | import pandas as pd 7 | from pymol import cmd 8 | from typing import Any 9 | from rdkit import Chem 10 | from functools import partial 11 | from multiprocessing import Pool 12 | from collections import defaultdict 13 | from posex.utils import bcif2cif, my_tqdm 14 | from pdbecif.mmcif_io import CifFileReader 15 | from posex.utils import run_in_tmp_dir, get_ccd_instance_map 16 | 17 | 18 | NUM_CPUS = 100 19 | PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 20 | 21 | 22 | class DataPreprocessor(): 23 | """ 24 | A class for downloading and preprocessing data 25 | 26 | Args: 27 | pdbid_list : list of pdbids 28 | """ 29 | def __init__(self, pdbid_list: list[str], **kwargs: Any): 30 | self.bcif_dir = 
kwargs.get("bcif_dir") 31 | self.cif_dir = kwargs.get("cif_dir") 32 | self.vs_dir = kwargs.get("vs_dir") 33 | self.ccd_dir = kwargs.get("ccd_dir") 34 | self.ccd_path = kwargs.get("ccd_path") 35 | self.lig_dir = kwargs.get("lig_dir") 36 | self.molecule_dir = kwargs.get("molecule_dir") 37 | self.components_path = kwargs.get("components_path") 38 | self.pdbid_list = pdbid_list 39 | 40 | def _get_func_name(self, idx=1) -> str: 41 | """return the name of a function in the current call stack 42 | """ 43 | return inspect.stack()[idx].function 44 | 45 | def download_bcif(self) -> None: 46 | os.makedirs(self.bcif_dir, exist_ok=True) 47 | for pdbid in my_tqdm(self.pdbid_list, desc=self._get_func_name()): 48 | if os.path.exists(f"{self.bcif_dir}/{pdbid}.bcif"): 49 | continue 50 | try: 51 | url = f"https://models.rcsb.org/{pdbid}.bcif" 52 | subprocess.run(["wget", url], cwd=self.bcif_dir, check=True) 53 | except Exception as e: 54 | print(pdbid, e) 55 | 56 | def convert_bif_to_cif(self, num_cpus=NUM_CPUS) -> None: 57 | os.makedirs(self.cif_dir, exist_ok=True) 58 | bcif2cif_ = partial(bcif2cif, bcif_dir=self.bcif_dir, cif_dir=self.cif_dir) 59 | with Pool(processes=num_cpus) as pool: 60 | for bcif_file, success in my_tqdm(pool.imap_unordered(bcif2cif_, os.listdir(self.bcif_dir)), desc=self._get_func_name()): 61 | if not success: 62 | print(f" {bcif_file} failed") 63 | 64 | def download_validation_score(self) -> None: 65 | os.makedirs(self.vs_dir, exist_ok=True) 66 | url = 'https://data.rcsb.org/graphql' 67 | with open(f"{PROJECT_DIR}/template/vs_query.txt", "r") as f: 68 | lines = f.readlines() 69 | query_string = "".join([line.strip() for line in lines]) 70 | for pdbid in my_tqdm(self.pdbid_list, desc=self._get_func_name()): 71 | json_path = f"{self.vs_dir}/{pdbid}.json" 72 | if os.path.exists(json_path): 73 | continue 74 | content = { 75 | "query": query_string, 76 | "variables": {"id": pdbid} 77 | } 78 | response = requests.post(url, json=content) 79 | with open(json_path, 'w') as f: 80 | json.dump(response.json(), f) 81 | 82 | def query_ccd(self, ccd: str) -> tuple[float, str]: 83 | """Query formula_weight and inchi of a ligand using the RCSB API 84 | 85 | Args: 86 | ccd (str): CCD code of the ligand 87 | 88 | Returns: 89 | tuple[float, str]: formula_weight and inchi of the ligand 90 | """ 91 | json_path = os.path.join(self.ccd_dir, f"{ccd}.json") 92 | if not os.path.exists(json_path): 93 | url = 'https://data.rcsb.org/graphql' 94 | with open(f"{PROJECT_DIR}/template/ccd_query.txt", "r") as f: 95 | lines = f.readlines() 96 | query_string = "".join([line.strip() for line in lines]) 97 | content = { 98 | "query": query_string, 99 | "variables": {"id": ccd} 100 | } 101 | response = requests.post(url, json=content) 102 | with open(json_path, 'w') as f: 103 | json.dump(response.json(), f) 104 | 105 | with open(json_path, "r") as f: 106 | info = json.load(f) 107 | formula_weight = info["data"]["chem_comp"]["chem_comp"]["formula_weight"] 108 | inchi = None 109 | pdbx_chem_comp_descriptor = info["data"]["chem_comp"]["pdbx_chem_comp_descriptor"] 110 | if pdbx_chem_comp_descriptor is not None: 111 | for item in pdbx_chem_comp_descriptor: 112 | if item["type"] == "InChI": 113 | inchi = item["descriptor"] 114 | break 115 | return formula_weight, inchi 116 | 117 | def _download_ccd_dict(self) -> None: 118 | """Download components.cif 119 | """ 120 | os.makedirs(self.ccd_dir, exist_ok=True) 121 | if not os.path.exists(self.components_path): 122 | print("downloading components.cif...") 123 | url = "https://files.wwpdb.org/pub/pdb/data/monomers/components.cif" 124 | subprocess.run(["wget", url], cwd=self.ccd_dir, check=True) 125 | print("finish") 126 | 127 | 128 | def build_ccd_table(self) -> None: 129 | """Create a ccd table and save it to self.ccd_path 130 | """ 131 | if not os.path.exists(self.ccd_path): 132 | self._download_ccd_dict() 133 | cfr = CifFileReader() 134 | cif_obj = cfr.read(self.components_path, ignore = ['_atom_site']) 135 | df_dict = defaultdict(list) 136 | for ccd, values in cif_obj.items(): 137 | if ccd == "UNL": continue 138 | descriptor_types = values["_pdbx_chem_comp_descriptor"]["type"] 139 | inchi_idx = -1 140 | for idx, dtype in enumerate(descriptor_types): 141 | if dtype == "InChI": 142 | inchi_idx = idx 143 | break 144 | if inchi_idx == -1: 145 | print(f"no InChI for {ccd}") 146 | else: 147 | inchi = values["_pdbx_chem_comp_descriptor"]["descriptor"][inchi_idx] 148 | df_dict["CCD"].append(ccd) 149 | df_dict["InChI"].append(inchi) 150 | df_dict["MOLWT"].append(values["_chem_comp"]["formula_weight"]) 151 | else: 152 | df_dict = pd.read_csv(self.ccd_path).to_dict(orient="list") 153 | df_dict["CCD"] = [ccd.strip("''") for ccd in df_dict["CCD"]] 154 | 155 | # get all ccds 156 | ccd_set = set() 157 | for pdbid in self.pdbid_list: 158 | json_path = os.path.join(self.vs_dir, f"{pdbid}.json") 159 | with open(json_path, "r") as f: 160 | info = json.load(f) 161 | try: 162 | for nonpolymer_entity in info["data"]["entry"]["nonpolymer_entities"]: 163 | ccd = nonpolymer_entity["nonpolymer_comp"]["chem_comp"]["id"] 164 | ccd_set.add(ccd) 165 | except: 166 | print(json_path) 167 | raise 168 | 169 | uncover_ccds = ccd_set - set(df_dict["CCD"]) 170 | for ccd in my_tqdm(uncover_ccds, desc=self._get_func_name()): 171 | formula_weight, inchi = self.query_ccd(ccd) 172 | if formula_weight is not None and inchi is not None: 173 | df_dict["CCD"].append(ccd) 174 | df_dict["InChI"].append(inchi) 175 | df_dict["MOLWT"].append(formula_weight) 176 | df = pd.DataFrame(df_dict) 177 | 178 | df.CCD = df.CCD.apply(lambda x: f"\'{x}\'") 179 | df.to_csv(self.ccd_path, na_rep=None, index=False) 180 | 181 | @run_in_tmp_dir 182 | def download_ligand(self, num_cpus=20) -> dict[str, dict[str, list[str]]]: 183 | """Download sdf files from rcsb 184 | 185 | Returns: 186 | dict[str, dict[str, list[str]]]: 187 | dict of pdb to dict of ccd to list of asym_ids 188 | """ 189 | pdb_ccd_instance_map = {} 190 | fetch_ligands_ = partial(fetch_ligands, cif_dir=self.cif_dir, lig_dir=self.lig_dir) 191 | with Pool(processes=num_cpus) as pool: 192 | pool_iter = pool.imap_unordered(fetch_ligands_, self.pdbid_list) 193 | for pdb, ccd_instance_map in my_tqdm(pool_iter, total=len(self.pdbid_list), desc="downloading ligand"): 194 | pdb_ccd_instance_map[pdb] = ccd_instance_map 195 | return pdb_ccd_instance_map 196 | 197 | @run_in_tmp_dir 198 | def extract_molecule(self) -> None: 199 | """Extract sdf files of organic_molecule and metal_ion from cif files 200 | """ 201 | os.makedirs(self.molecule_dir, exist_ok=True) 202 | for pdb in my_tqdm(self.pdbid_list, desc=self._get_func_name()): 203 | cif_path = os.path.join(self.cif_dir, f'{pdb}.cif') 204 | organic_molecule_path = f"{self.molecule_dir}/{pdb}_organic_molecule.cif" 205 | metal_ion_path = f"{self.molecule_dir}/{pdb}_metal_ion.cif" 206 | if os.path.exists(organic_molecule_path) and os.path.exists(metal_ion_path): 207 | continue 208 | cmd.load(cif_path, pdb) 209 | cmd.select("organic_molecule", "organic") 210 | cmd.select("metal_ion", "metal") 211 | 
cmd.save(organic_molecule_path, "organic_molecule") 212 | cmd.save(metal_ion_path, "metal_ion") 213 | cmd.delete("all") 214 | 215 | def run(self) -> dict[str, dict[str, list[str]]]: 216 | """Preprocess 217 | 218 | Returns: 219 | dict[str, dict[str, list[str]]]: 220 | dict of pdb to dict of ccd to list of asym_ids 221 | """ 222 | self.download_bcif() 223 | self.convert_bif_to_cif() 224 | self.download_validation_score() 225 | self.build_ccd_table() 226 | self.extract_molecule() 227 | pdb_ccd_instance_map = self.download_ligand() 228 | return pdb_ccd_instance_map 229 | 230 | def fetch_ligands(pdb, cif_dir, lig_dir): 231 | _, ccd_instance_map = get_ccd_instance_map(pdb, cif_dir) 232 | for ccd, asym_ids in ccd_instance_map.items(): 233 | mols = [] 234 | ligand_path = f"{lig_dir}/{pdb}_{ccd}.sdf" 235 | if os.path.exists(ligand_path): 236 | continue 237 | for asym_id in asym_ids: 238 | instance_name = f"{pdb}_{ccd}_{asym_id}" 239 | tmp_ligand_path = f"{lig_dir}/{instance_name}.sdf" 240 | url = f"https://models.rcsb.org/v1/{pdb}/ligand?label_comp_id={ccd}&label_asym_id={asym_id}&encoding=sdf" 241 | try: 242 | r = requests.get(url) 243 | open(tmp_ligand_path , 'wb').write(r.content) 244 | except Exception as e: 245 | print(e, pdb, ccd, asym_id) 246 | mol = Chem.SDMolSupplier(tmp_ligand_path)[0] 247 | os.remove(tmp_ligand_path) 248 | if mol is not None: 249 | params = Chem.RemoveHsParameters() 250 | params.removeDegreeZero = True 251 | mol = Chem.RemoveHs(mol, params) 252 | props = list(mol.GetPropNames()) 253 | for prop in props: 254 | mol.ClearProp(prop) 255 | mol.SetProp('_Name', instance_name) 256 | mols.append(mol) 257 | else: 258 | # All ligand SDF files can be loaded with RDKit and pass its sanitization 259 | mols = [] 260 | break 261 | w = Chem.SDWriter(ligand_path) 262 | for mol in mols: 263 | w.write(mol) 264 | w.close() 265 | return pdb, ccd_instance_map -------------------------------------------------------------------------------- /dataset/template/ccd_query.txt: -------------------------------------------------------------------------------- 1 | query molecule ($id: String!) { 2 | chem_comp(comp_id:$id){ 3 | chem_comp { 4 | id, 5 | name, 6 | formula_weight, 7 | type 8 | } 9 | pdbx_chem_comp_descriptor { 10 | type, 11 | descriptor, 12 | program, 13 | program_version 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /dataset/template/cross_dock.txt: -------------------------------------------------------------------------------- 1 | \begin{table}[tb] 2 | \centering 3 | \caption{Selection process of the PDB entries and ligands for the PoseX Benchmark set (cross-dock).} 4 | \vspace{2mm} 5 | \resizebox{\columnwidth}{!}{ 6 | \begin{threeparttable} 7 | \begin{tabular}{p{0.60\columnwidth}cc} 8 | \toprule[1pt] 9 | Selection step & Number of proteins & Number of ligands\\ 10 | \hline 11 | PDB entries released from 1 January 2022 to 1 January 2025 feature a refinement resolution of 2 \angstrom~or better and include at least one protein and one ligand & {{ input.PDB }} & {{ input.CCD }} \\ 12 | Remove unknown ligands (e.g. 
UNX, UNL) & {{ filter_with_unknown_ccd.PDB }} & {{ filter_with_unknown_ccd.CCD }} \\ 13 | Remove proteins with a sequence length greater than 2000 & {{ filter_with_seq_length.PDB }} & {{ filter_with_seq_length.CCD }} \\ 14 | Ligands weighing from 100 Da to 900 Da & {{ filter_with_molwt.PDB }} & {{ filter_with_molwt.CCD }} \\ 15 | Ligands with at least 3 heavy atoms & {{ filter_with_num_heavy_atom.PDB }} & {{ filter_with_num_heavy_atom.CCD }} \\ 16 | Ligands containing only H, C, O, N, P, S, F, Cl atoms & {{ filter_with_mol_element.PDB }} & {{ filter_with_mol_element.CCD }} \\ 17 | Ligands that are not covalently bound to protein & {{ filter_with_covalent_bond.PDB }} & {{ filter_with_covalent_bond.CCD }} \\ 18 | Structures with no unknown atoms (e.g. element X) & {{ filter_with_unknown_atoms.PDB }} & {{ filter_with_unknown_atoms.CCD }} \\ 19 | Ligand real space R-factor is at most 0.2 & {{ filter_with_RSR.PDB }} & {{ filter_with_RSR.CCD }} \\ 20 | Ligand real space correlation coefficient is at least 0.95 & {{ filter_with_RSCC.PDB }} & {{ filter_with_RSCC.CCD }} \\ 21 | Ligand model completeness is 100\% & {{ filter_with_model_completeness.PDB }} & {{ filter_with_model_completeness.CCD }} \\ 22 | Ligand starting conformation could be generated with ETKDGv3 & {{ filter_with_ETKDG.PDB }} & {{ filter_with_ETKDG.CCD }} \\ 23 | All ligand SDF files can be loaded with RDKit and pass its sanitization & {{ filter_with_rdkit.PDB }} & {{ filter_with_rdkit.CCD }} \\ 24 | PDB ligand report does not list stereochemical errors & {{ filter_with_stereo_outliers.PDB }} & {{ filter_with_stereo_outliers.CCD }} \\ 25 | PDB ligand report does not list any atomic clashes & {{ filter_with_intermolecular_clashes.PDB }} & {{ filter_with_intermolecular_clashes.CCD }} \\ 26 | Select single protein-ligand conformation \tnote{1} & {{ select_single_conformation.PDB }} & {{ select_single_conformation.CCD }} \\ 27 | Intermolecular distance between the ligand(s) and the protein is at least 0.2 \angstrom & {{ filter_with_ligand_protein_distance.PDB }} & {{ filter_with_ligand_protein_distance.CCD }} \\ 28 | Intermolecular distance between the ligand(s) and the other ligands is at least 5.0 \angstrom & {{ filter_with_ligand_ligand_distance.PDB }} & {{ filter_with_ligand_ligand_distance.CCD }} \\ 29 | Remove ligands which are within 5.0 \angstrom~of any protein symmetry mate & {{ filter_with_crystal_contact.PDB }} & {{ filter_with_crystal_contact.CCD }} \\ 30 | Cluster proteins that have at least 90\% sequence identity \tnote{2} & {{ filter_with_clustering.PDB }} & {{ filter_with_clustering.CCD }} \\ 31 | Structures can be successfully aligned to the reference structure in each cluster \tnote{3} & {{ filter_with_cross_alignment.PDB }} & {{ filter_with_cross_alignment.CCD }} \\ 32 | \bottomrule[1pt] 33 | \end{tabular} 34 | \begin{tablenotes} 35 | \item[1] The first conformation was chosen when multiple conformations were available in the PDB entry. 36 | \item[2] Clustering with MMseqs2 is done with a sequence identity threshold of 90\% and a minimum coverage of 80\%. 37 | \item[3] Each candidate protein is structurally aligned to the reference protein via superposition of the $C_{\alpha}$ atoms of amino acid residues using PyMOL. A candidate PDB entry is removed if the RMSD of the protein alignment is greater than 2.0 \angstrom~and a candidate ligand is removed if it is more than 4.0 \angstrom~away from the reference ligand. 
38 | \end{tablenotes} 39 | \end{threeparttable} 40 | } 41 | \label{table:notation} 42 | \end{table} -------------------------------------------------------------------------------- /dataset/template/self_dock.txt: -------------------------------------------------------------------------------- 1 | \begin{table}[tb] 2 | \centering 3 | \caption{Selection process of the PDB entries and ligands for the PoseX Benchmark set (self-dock).} 4 | \vspace{2mm} 5 | \resizebox{\columnwidth}{!}{ 6 | \begin{threeparttable} 7 | \begin{tabular}{p{0.60\columnwidth}cc} 8 | \toprule[1pt] 9 | Selection step & Number of proteins & Number of ligands\\ 10 | \hline 11 | PDB entries released from 1 January 2022 to 1 January 2025 feature a refinement resolution of 2 \angstrom~or better and include at least one protein and one ligand & {{ input.PDB }} & {{ input.CCD }} \\ 12 | Remove unknown ligands (e.g. UNX, UNL) & {{ filter_with_unknown_ccd.PDB }} & {{ filter_with_unknown_ccd.CCD }} \\ 13 | Remove proteins with a sequence length greater than 2000 & {{ filter_with_seq_length.PDB }} & {{ filter_with_seq_length.CCD }} \\ 14 | Ligands weighing from 100 Da to 900 Da & {{ filter_with_molwt.PDB }} & {{ filter_with_molwt.CCD }} \\ 15 | Ligands with at least 3 heavy atoms & {{ filter_with_num_heavy_atom.PDB }} & {{ filter_with_num_heavy_atom.CCD }} \\ 16 | Ligands containing only H, C, O, N, P, S, F, Cl atoms & {{ filter_with_mol_element.PDB }} & {{ filter_with_mol_element.CCD }} \\ 17 | Ligands that are not covalently bound to protein & {{ filter_with_covalent_bond.PDB }} & {{ filter_with_covalent_bond.CCD }} \\ 18 | Structures with no unknown atoms (e.g. element X) & {{ filter_with_unknown_atoms.PDB }} & {{ filter_with_unknown_atoms.CCD }} \\ 19 | Ligand real space R-factor is at most 0.2 & {{ filter_with_RSR.PDB }} & {{ filter_with_RSR.CCD }} \\ 20 | Ligand real space correlation coefficient is at least 0.95 & {{ filter_with_RSCC.PDB }} & {{ filter_with_RSCC.CCD }} \\ 21 | Ligand model completeness is 100\% & {{ filter_with_model_completeness.PDB }} & {{ filter_with_model_completeness.CCD }} \\ 22 | Ligand starting conformation could be generated with ETKDGv3 & {{ filter_with_ETKDG.PDB }} & {{ filter_with_ETKDG.CCD }} \\ 23 | All ligand SDF files can be loaded with RDKit and pass its sanitization & {{ filter_with_rdkit.PDB }} & {{ filter_with_rdkit.CCD }} \\ 24 | PDB ligand report does not list stereochemical errors & {{ filter_with_stereo_outliers.PDB }} & {{ filter_with_stereo_outliers.CCD }} \\ 25 | PDB ligand report does not list any atomic clashes & {{ filter_with_intermolecular_clashes.PDB }} & {{ filter_with_intermolecular_clashes.CCD }} \\ 26 | Select single protein-ligand conformation \tnote{1} & {{ select_single_conformation.PDB }} & {{ select_single_conformation.CCD }} \\ 27 | Intermolecular distance between the ligand(s) and the protein is at least 0.2 \angstrom & {{ filter_with_ligand_protein_distance.PDB }} & {{ filter_with_ligand_protein_distance.CCD }} \\ 28 | Intermolecular distance between ligand(s) and other small organic molecules is at least 0.2 \angstrom & {{ filter_with_ligand_organic_molecule_distance.PDB }} & {{ filter_with_ligand_organic_molecule_distance.CCD }} \\ 29 | Intermolecular distance between ligand(s) and metal ions in the complex is at least 0.2 \angstrom & {{ filter_with_ligand_metal_ion_distance.PDB }} & {{ filter_with_ligand_metal_ion_distance.CCD }} \\ 30 | Remove ligands which are within 5.0 \angstrom~of any protein symmetry mate & {{ filter_with_crystal_contact.PDB }} & {{ filter_with_crystal_contact.CCD }} \\ 31 | Get a set with unique PDBs and unique CCDs using the Hopcroft–Karp matching algorithm & {{ filter_with_unique_pdb_ccd.PDB }} & {{ filter_with_unique_pdb_ccd.CCD }} \\ 32 | Select representative PDB entries by clustering protein sequences \tnote{2} & {{ filter_with_clustering.PDB }} & {{ filter_with_clustering.CCD }} \\ 33 | \bottomrule[1pt] 34 | \end{tabular} 35 | \begin{tablenotes} 36 | \item[1] The first conformation was chosen when multiple conformations were available in the PDB entry. 37 | \item[2] Clustering with MMseqs2 is done with a sequence identity threshold of 0\% and a minimum coverage of 100\%. 38 | \end{tablenotes} 39 | \end{threeparttable} 40 | } 41 | \label{table:notation} 42 | \end{table}
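The two tables above are Jinja2 templates; the code that renders them is not part of this listing (it presumably lives in posex/data.py). A minimal sketch of the assumed usage, with made-up counts:

    from jinja2 import Template

    with open("dataset/template/self_dock.txt") as f:
        template = Template(f.read())

    # One {"PDB": ..., "CCD": ...} mapping per selection step; the keys mirror
    # the placeholder names in the template. Steps that are left out simply
    # render as empty table cells.
    stats = {
        "input": {"PDB": 34861, "CCD": 11114},
        "filter_with_unknown_ccd": {"PDB": 34706, "CCD": 11089},
    }
    print(template.render(**stats))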
filter_with_crystal_contact.CCD }} \\ 31 | Get a set with unique PDB entries and unique CCD ligands via the Hopcroft–Karp matching algorithm & {{ filter_with_unique_pdb_ccd.PDB }} & {{ filter_with_unique_pdb_ccd.CCD }} \\ 32 | Select representative PDB entries by clustering protein sequences \tnote{2} & {{ filter_with_clustering.PDB }} & {{ filter_with_clustering.CCD }} \\ 33 | \bottomrule[1pt] 34 | \end{tabular} 35 | \begin{tablenotes} 36 | \item[1] The first conformation was chosen when multiple conformations were available in the PDB entry. 37 | \item[2] Clustering with MMseqs2 is done with a sequence identity threshold of 0\% and a minimum coverage of 100\%. 38 | \end{tablenotes} 39 | \end{threeparttable} 40 | } 41 | \label{table:notation} 42 | \end{table} -------------------------------------------------------------------------------- /dataset/template/vs_query.txt: -------------------------------------------------------------------------------- 1 | query ($id: String!) { 2 | entry(entry_id: $id) { 3 | rcsb_entry_info { 4 | experimental_method, 5 | resolution_combined 6 | } 7 | nonpolymer_entities { 8 | nonpolymer_comp { 9 | chem_comp { 10 | name, 11 | id 12 | } 13 | } 14 | rcsb_nonpolymer_entity_annotation { 15 | type 16 | } 17 | nonpolymer_entity_instances { 18 | rcsb_nonpolymer_entity_instance_container_identifiers { 19 | auth_seq_id, 20 | auth_asym_id, 21 | asym_id, 22 | entity_id, 23 | entry_id 24 | } 25 | rcsb_nonpolymer_instance_validation_score { 26 | RSCC, 27 | RSR, 28 | completeness, 29 | intermolecular_clashes, 30 | stereo_outliers, 31 | is_subject_of_investigation 32 | } 33 | } 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /dataset/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import contextlib, tempfile 2 | from pathlib import Path 3 | from functools import partial 4 | from subprocess import check_output as _call 5 | run_shell_cmd = partial(_call, shell=True) 6 | 7 | 8 | @contextlib.contextmanager 9 | def temprary_filename(mode='w+b', suffix=None): 10 | tmp_name = None 11 | try: 12 | tmp_file = tempfile.NamedTemporaryFile(mode=mode, suffix=suffix, delete=False) 13 | tmp_name = tmp_file.name 14 | tmp_file.close() 15 | yield tmp_name 16 | finally: 17 | Path(tmp_name).unlink(missing_ok=True) -------------------------------------------------------------------------------- /dataset/utils/common_helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import contextlib, tempfile 4 | from pathlib import Path 5 | from functools import partial 6 | from subprocess import check_output as _call 7 | run_shell_cmd = partial(_call, shell=True) 8 | 9 | DEBUG_MODE = bool(int(os.environ.get("DEBUG_MODE", 0))) 10 | 11 | 12 | def create_logger(filename): 13 | logger_ = logging.getLogger(filename) # type: logging.Logger 14 | 15 | if DEBUG_MODE: 16 | logger_.setLevel(logging.DEBUG) 17 | else: 18 | logger_.setLevel(logging.INFO) 19 | 20 | handler = logging.StreamHandler() 21 | handler.setLevel(logging.DEBUG) 22 | formatter = logging.Formatter( 23 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 24 | ) 25 | handler.setFormatter(formatter) 26 | logger_.addHandler(handler) 27 | return logger_ 28 | 29 | 30 | 31 | 32 | @contextlib.contextmanager 33 | def temprary_filename(mode='w+b', suffix=None): 34 | tmp_name = None 35 | try: 36 | tmp_file = tempfile.NamedTemporaryFile(mode=mode, suffix=suffix, delete=False) 37 | tmp_name = 
tmp_file.name 38 | tmp_file.close() 39 | yield tmp_name 40 | finally: 41 | Path(tmp_name).unlink(missing_ok=True) -------------------------------------------------------------------------------- /dataset/utils/mol_correct_helper.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | import openmm.app as app 3 | from openmm import unit, Vec3 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | 8 | CURRENTLY_ACCEPTABLE_PROTEIN_RESIDUES = { 9 | "ALA": "A", 10 | "ASN": "N", 11 | "CYS": "C", 12 | "GLU": "E", 13 | "HIS": "H", 14 | "LEU": "L", 15 | "MET": "M", 16 | "PRO": "P", 17 | "THR": "T", 18 | "TYR": "Y", 19 | "ARG": "R", 20 | "ASP": "D", 21 | "GLN": "Q", 22 | "GLY": "G", 23 | "ILE": "I", 24 | "LYS": "K", 25 | "PHE": "F", 26 | "SER": "S", 27 | "TRP": "W", 28 | "VAL": "V", 29 | "ACE": "B", 30 | "NME": "Z", 31 | "HOH": "O", 32 | } 33 | 34 | AA2SMILES = { 35 | "ACE": "CC(=O)", 36 | "NME": "NC", 37 | "ALA": "C[C@H](N)C=O", 38 | "CYS": "N[C@H](C=O)CS", 39 | "ASP": "N[C@H](C=O)CC(=O)[O-]", 40 | "GLU": "N[C@H](C=O)CCC(=O)[O-]", 41 | "PHE": "N[C@H](C=O)Cc1ccccc1", 42 | "GLY": "NCC=O", 43 | "HIS": [ 44 | "N([H])C([H])(C=O)C([H])([H])C1=C([H])N([H])C([H])=N1", 45 | "N(C(C(C1N([H])C([H])=NC=1[H])([H])[H])(C=O)[H])[H]", 46 | "NC(CC1N=CNC=1)C=O", 47 | ], 48 | "ILE": "CC[C@H](C)[C@H](N)C=O", 49 | "LYS": "[NH3+]CCCC[C@H](N)C=O", 50 | "LEU": "CC(C)C[C@H](N)C=O", 51 | "MET": "CSCC[C@H](N)C=O", 52 | "ASN": "NC(=O)C[C@H](N)C=O", 53 | "PRO": "O=C[C@@H]1CCCN1", 54 | "GLN": "NC(=O)CC[C@H](N)C=O", 55 | "ARG": "NC(=[NH2+])NCCC[C@H](N)C=O", 56 | "SER": "N[C@H](C=O)CO", 57 | "THR": "C[C@@H](O)[C@H](N)C=O", 58 | "VAL": "CC(C)[C@H](N)C=O", 59 | "TRP": "N[C@H](C=O)Cc1c[nH]c2ccccc12", 60 | "TYR": "N[C@H](C=O)Cc1ccc(O)cc1", 61 | "HOH": "O", 62 | } 63 | 64 | 65 | def format_4letter(atom_name: str): 66 | output = None 67 | if len(atom_name) == 4: 68 | output = atom_name 69 | elif len(atom_name) == 3: 70 | output = " " + atom_name 71 | elif len(atom_name) == 2: 72 | output = " " + atom_name + " " 73 | elif len(atom_name) == 1: 74 | output = " " + atom_name + " " 75 | else: 76 | raise ValueError() 77 | 78 | return output 79 | 80 | 81 | def assign_bo_with_template_smiles( 82 | mol: Chem.Mol, aa_name: str, slice_ids: list, connect_sites: list, max_match_num=10000, 83 | ): 84 | aa_smiles = AA2SMILES[aa_name] 85 | if isinstance(aa_smiles, str): 86 | aa_smiles_list = [aa_smiles] 87 | else: 88 | aa_smiles_list = aa_smiles 89 | 90 | for idx, aa_smi_ in enumerate(aa_smiles_list): 91 | params = Chem.SmilesParserParams() 92 | if "[H]" in aa_smi_: 93 | params.removeHs = False 94 | 95 | aa_mol = Chem.MolFromSmiles(aa_smi_, params) 96 | Chem.Kekulize(aa_mol, clearAromaticFlags=True) 97 | aa_mol2 = Chem.Mol(aa_mol) 98 | aa_mol2_chg = dict() 99 | 100 | for b in aa_mol2.GetBonds(): 101 | if b.GetBondType() != Chem.BondType.SINGLE: 102 | b.SetBondType(Chem.BondType.SINGLE) 103 | b.SetIsAromatic(False) 104 | 105 | # set atom charges to zero; 106 | for atom in aa_mol2.GetAtoms(): 107 | aa_mol2_chg[atom.GetIdx()] = atom.GetFormalCharge() 108 | atom.SetFormalCharge(0) 109 | 110 | matches = mol.GetSubstructMatches(aa_mol2, maxMatches=max_match_num) 111 | filtered_matches = [] 112 | for match_ in matches: 113 | # if match_[0] in visited_ids: continue 114 | # visited_ids.update(match_) 115 | atom = mol.GetAtomWithIdx(match_[0]) 116 | if atom.GetPDBResidueInfo().GetResidueName() == aa_name: 117 | filtered_matches.append(match_) 118 | 119 | if len(filtered_matches) == 0: 
120 | print(f"{aa_smi_}: template {idx} has no match!") 121 | 122 | aa_ids = list(range(aa_mol2.GetNumAtoms())) 123 | for slice_atoms in slice_ids: 124 | for match_ in filtered_matches: 125 | if len(set(match_) - set(slice_atoms)) == 0: 126 | mapping = dict() 127 | for idx1, idx2 in zip(match_, aa_ids): 128 | atom = mol.GetAtomWithIdx(idx1) 129 | if aa_mol2_chg[idx2] != 0: 130 | atom.SetFormalCharge(aa_mol2_chg[idx2]) 131 | else: 132 | if ( 133 | atom.GetSymbol() == "N" 134 | and len(list(atom.GetNeighbors())) > 3 135 | ): 136 | connect_sites.append(atom.GetIdx()) 137 | mapping[idx2] = idx1 138 | 139 | for bond in aa_mol.GetBonds(): 140 | at1 = bond.GetBeginAtomIdx() 141 | at2 = bond.GetEndAtomIdx() 142 | new_bond = mol.GetBondBetweenAtoms(mapping[at1], mapping[at2]) 143 | new_bond.SetBondType(bond.GetBondType()) 144 | return mol 145 | 146 | 147 | def rdmol_to_omm(rdmol: Chem.Mol) -> app.Modeller: 148 | # convert RDKit to OpenFF 149 | from openff.toolkit import Molecule 150 | 151 | off_mol = Molecule.from_rdkit(rdmol, hydrogens_are_explicit=True) 152 | 153 | # convert from OpenFF to OpenMM 154 | off_mol_topology = off_mol.to_topology() 155 | mol_topology = off_mol_topology.to_openmm() 156 | mol_positions = off_mol.conformers[0] 157 | # convert units from Ångström to nanometers 158 | mol_positions = mol_positions.to("nanometers") 159 | # combine topology and positions in modeller object 160 | omm_mol = app.Modeller(mol_topology, mol_positions) 161 | return omm_mol 162 | 163 | 164 | def omm_protein_to_rdmol( 165 | topology: app.Topology, positions: unit.Quantity = None 166 | ) -> Chem.Mol: 167 | new_mol = Chem.RWMol() 168 | atom_mapper = dict() 169 | residue_mapper = defaultdict(list) 170 | pos_chg_idxs = [] 171 | middle_xt_idxs = [] 172 | for chain in topology.chains(): 173 | chain: app.Chain 174 | num_res = len(chain._residues) 175 | for res_idx, residue in enumerate(chain.residues()): 176 | for atom in residue.atoms(): 177 | atom: app.Atom 178 | rdatom = Chem.Atom(atom.element.atomic_number) 179 | # rdatom.SetNoImplicit(True) 180 | mi = Chem.AtomPDBResidueInfo() 181 | mi.SetResidueName(atom.residue.name) 182 | mi.SetResidueNumber(int(atom.residue.id)) 183 | mi.SetChainId(atom.residue.chain.id) 184 | mi.SetName(format_4letter(atom.name)) 185 | mi.SetInsertionCode("<>") 186 | 187 | if ( 188 | res_idx == 0 and atom.name == "N" 189 | ): # if the first residue is not a capped residue, treat it as N+ 190 | # set the N of the first residue in each chain to +1 191 | rdatom.SetFormalCharge(1) 192 | 193 | elif (atom.name == "OXT") and (res_idx == num_res - 1): 194 | # set the C-terminal OXT atom to -1 195 | rdatom.SetFormalCharge(-1) 196 | else: 197 | rdatom.SetFormalCharge(0) 198 | 199 | rdatom.SetMonomerInfo(mi) 200 | index = new_mol.AddAtom(rdatom) 201 | atom_mapper[atom.index] = index 202 | key = (atom.residue.name, int(atom.residue.id), atom.residue.chain.id) 203 | 204 | if atom.name[-2:] != "XT": # capping atoms 205 | residue_mapper[key].append(index) 206 | else: 207 | if res_idx < num_res - 1: 208 | middle_xt_idxs.append(index) 209 | 210 | if res_idx == 1 and atom.name == "N": 211 | # backbone N of the second residue in each chain 212 | pos_chg_idxs.append(index) 213 | 214 | split_ids = [] 215 | for bond in topology.bonds(): 216 | if bond[0].index in atom_mapper and bond[1].index in atom_mapper: 217 | at1 = atom_mapper[bond[0].index] 218 | at2 = atom_mapper[bond[1].index] 219 | new_mol.AddBond(at1, at2, Chem.BondType.SINGLE) 220 | if bond[0].residue.id != bond[1].residue.id: 221 | rdbond = new_mol.GetBondBetweenAtoms(at1, at2) 222 | split_ids.append(rdbond.GetIdx()) 223 | 224 | residue_byres = defaultdict(list) 225 | for res_info, 
res_atoms in residue_mapper.items(): 226 | residue_byres[res_info[0]].append(res_atoms) 227 | 228 | connect_sites = [] 229 | # visited_ids = set() 230 | for res_name, res_atoms_list in residue_byres.items(): 231 | # if res_name=="ACE": 232 | # print('ACE') 233 | assign_bo_with_template_smiles(new_mol, res_name, res_atoms_list, connect_sites) 234 | 235 | # remove extra H atoms attached to the N of the second residue 236 | remove_h_idxs = [] 237 | for atidx in list(set(pos_chg_idxs + connect_sites)): 238 | atom = new_mol.GetAtomWithIdx(atidx) 239 | count = 0 240 | tmp_idxs = [] 241 | for nbr_atom in atom.GetNeighbors(): 242 | count += 1 243 | if nbr_atom.GetSymbol() == "H": 244 | tmp_idxs.append(nbr_atom.GetIdx()) 245 | remove_h_idxs.extend(tmp_idxs[-(count - 3) :]) 246 | 247 | atom_positions = np.array(positions._value) * 10.0 248 | remove_h_idxs += middle_xt_idxs 249 | if len(remove_h_idxs) > 0: 250 | remove_h_idxs = list(sorted(list(set(remove_h_idxs)), reverse=True)) 251 | [new_mol.RemoveAtom(idx) for idx in remove_h_idxs] 252 | atom_positions = np.delete(atom_positions, remove_h_idxs, axis=0) 253 | 254 | new_mol = new_mol.GetMol() 255 | new_mol.UpdatePropertyCache(strict=False) 256 | problems = Chem.DetectChemistryProblems(new_mol) 257 | for problem in problems: 258 | cur_idx = problem.GetAtomIdx() 259 | atom = new_mol.GetAtomWithIdx(cur_idx) 260 | pdb_info = atom.GetPDBResidueInfo() 261 | print( 262 | pdb_info.GetResidueName(), 263 | pdb_info.GetResidueNumber(), 264 | pdb_info.GetChainId(), 265 | ) 266 | 267 | Chem.SanitizeMol(new_mol) 268 | 269 | conf = Chem.Conformer(new_mol.GetNumAtoms()) 270 | for aidx in range(new_mol.GetNumAtoms()): 271 | conf.SetAtomPosition(aidx, atom_positions[aidx]) 272 | 273 | new_mol.RemoveAllConformers() 274 | new_mol.AddConformer(conf, assignId=True) 275 | new_mol_h = Chem.AddHs(new_mol, addCoords=True) 276 | # new_mol_h = new_mol 277 | 278 | h_records = defaultdict(int) 279 | for i in range(new_mol.GetNumAtoms(), new_mol_h.GetNumAtoms()): 280 | atom = new_mol_h.GetAtomWithIdx(i) 281 | nbr_atom = atom.GetNeighbors()[0] 282 | nbr_mi = nbr_atom.GetPDBResidueInfo() 283 | mi = Chem.AtomPDBResidueInfo() 284 | mi.SetResidueName(nbr_mi.GetResidueName()) 285 | mi.SetResidueNumber(int(nbr_mi.GetResidueNumber())) 286 | mi.SetChainId(nbr_mi.GetChainId()) 287 | mi.SetInsertionCode("<>") 288 | nbr_name = nbr_mi.GetName().strip() 289 | key = ( 290 | nbr_mi.GetResidueName(), 291 | int(nbr_mi.GetResidueNumber()), 292 | nbr_mi.GetChainId(), 293 | nbr_name, 294 | ) 295 | h_records[key] += 1 296 | if nbr_name == "N" and nbr_mi.GetResidueName() in ["NME"]: 297 | label = "H" 298 | elif len(nbr_name) == 2: 299 | label = f"H{nbr_name[1]}{h_records[key]}" 300 | elif len(nbr_name) == 3: 301 | label = f"H{nbr_name[1:]}{h_records[key]}" 302 | else: 303 | label = f"H{h_records[key]}" 304 | mi.SetName(format_4letter(label)) 305 | atom.SetMonomerInfo(mi) 306 | 307 | # reorder H atoms 308 | residue_mapper = defaultdict(list) 309 | for atom in new_mol_h.GetAtoms(): 310 | pdb_info = atom.GetPDBResidueInfo() 311 | key = ( 312 | pdb_info.GetResidueName(), 313 | pdb_info.GetResidueNumber(), 314 | pdb_info.GetChainId(), 315 | ) 316 | residue_mapper[key].append(atom.GetIdx()) 317 | 318 | resort_list = [] 319 | [resort_list.extend(values) for _, values in residue_mapper.items()] 320 | new_mol_h = Chem.RenumberAtoms(new_mol_h, resort_list) 321 | # Chem.AssignStereochemistryFrom3D(new_mol_h) 322 | # Chem.SetDoubleBondNeighborDirections(new_mol_h) 323 | return new_mol_h 324 | 
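A minimal usage sketch for `omm_protein_to_rdmol` above (illustrative only: the file name `capped_protein.pdb` is hypothetical, and the function assumes every residue is covered by the `AA2SMILES` templates, e.g. an ACE/NME-capped protein such as the output of `repair_pdb.fix_pdb`):

    import openmm.app as app
    from rdkit import Chem
    from dataset.utils.mol_correct_helper import omm_protein_to_rdmol

    # hypothetical input: a prepared, capped protein containing only template residues
    pdb = app.PDBFile("capped_protein.pdb")

    # rebuild formal charges and bond orders from the residue templates and
    # obtain an RDKit Mol carrying PDB metadata and explicit hydrogens
    rec_mol = omm_protein_to_rdmol(pdb.topology, pdb.positions)
    Chem.MolToPDBFile(rec_mol, "capped_protein_withH.pdb")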
-------------------------------------------------------------------------------- /dataset/utils/openmm_helper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | from collections import defaultdict 4 | from functools import partial 5 | from io import StringIO 6 | from pathlib import Path 7 | from subprocess import check_output as _call 8 | from typing import * 9 | 10 | import numpy as np 11 | import openmm as omm 12 | import openmm.app as app 13 | import pdbfixer 14 | from openff.toolkit.topology import Molecule 15 | from openmm import Vec3, unit, MinimizationReporter 16 | from openmm.app import PDBFile, StateDataReporter 17 | from openmmforcefields.generators import ( 18 | GAFFTemplateGenerator, 19 | SMIRNOFFTemplateGenerator, 20 | ) 21 | from rdkit import Chem 22 | 23 | from sys import stdout 24 | 25 | from .common_helper import create_logger, temprary_filename 26 | 27 | logger = create_logger(__name__) 28 | 29 | METALS = ["Na", "K", "Ca", "Mg", "Fe", "Zn", "Cu", "Mn", "Co", "Ni"] 30 | run_shell_cmd = partial(_call, shell=True) 31 | 32 | 33 | class MinReporter(MinimizationReporter): 34 | lastIteration = -1 35 | error = False 36 | 37 | def report(self, iteration, x, grad, args) -> bool: 38 | if iteration != self.lastIteration + 1: 39 | self.error = True 40 | self.lastIteration = iteration 41 | if iteration == 10: 42 | print(f"{iteration=}") 43 | return True 44 | if iteration > 10: 45 | self.error = True 46 | return False 47 | 48 | 49 | def assign_mol_with_pos(new_mol: Chem.Mol, pos: np.ndarray) -> Chem.Mol: 50 | conf = Chem.Conformer(new_mol.GetNumAtoms()) 51 | for idx in range(len(pos)): 52 | conf.SetAtomPosition(idx, pos[idx]) 53 | # conf.SetPositions(pos) 54 | new_mol.RemoveAllConformers() 55 | new_mol.AddConformer(conf, assignId=True) 56 | return new_mol 57 | 58 | 59 | def load_sdf_to_omm( 60 | sdf_fn: str | Path, 61 | ) -> Tuple[app.Modeller, Molecule, Chem.Mol]: 62 | # convert RDKit to OpenFF 63 | rdk_mol = Chem.SDMolSupplier(sdf_fn, removeHs=False)[0] 64 | rdk_mol = Chem.AddHs(rdk_mol, addCoords=True) 65 | off_mol = Molecule.from_rdkit( 66 | rdk_mol, hydrogens_are_explicit=True, allow_undefined_stereo=True 67 | ) 68 | props = rdk_mol.GetPropsAsDict() 69 | [rdk_mol.ClearProp(k) for k in props.keys()] 70 | 71 | # convert from OpenFF to OpenMM 72 | off_mol_topology = off_mol.to_topology() 73 | mol_topology = off_mol_topology.to_openmm() 74 | mol_positions = off_mol.conformers[0] 75 | # convert units from Ångström to nanometers 76 | mol_positions = mol_positions.to("nanometers") 77 | lig_pos = [Vec3(*x.m) for x in mol_positions] * unit.nanometers 78 | # combine topology and positions in modeller object 79 | for residue in mol_topology.residues(): 80 | residue.name = "UNL" 81 | omm_mol = app.Modeller(mol_topology, lig_pos) 82 | off_mol.name = "UNL" 83 | return omm_mol, off_mol, rdk_mol 84 | 85 | 86 | def get_am1bcc_charge(ligand_fn: Path): 87 | lig_mol = Chem.SDMolSupplier(ligand_fn, removeHs=False)[0] 88 | total_charge = 0 89 | for atom in lig_mol.GetAtoms(): 90 | total_charge += atom.GetFormalCharge() 91 | with temprary_filename(mode="w", suffix="_ligand.mol2") as tmp_out_fn: 92 | cmd1 = f"antechamber -i {ligand_fn} -fi sdf -o {tmp_out_fn} -fo mol2 -c bcc -nc {total_charge} -at gaff2 -ek \"qm_theory='AM1', scfconv=1.d-10, maxcyc=0, grms_tol=0.0005, ndiis_attempts=700\" -pf y" 93 | run_shell_cmd(cmd1) 94 | with open(tmp_out_fn, "r") as f: 95 | mol_data = f.read() 96 | atom_section = re.search(r"@ATOM([\s\S]+?)@BOND", mol_data) 97 | atom_lines = 
atom_section.group(1).strip().split("\n") 98 | partial_charges = [float(line.split()[-1]) for line in atom_lines] 99 | 100 | assert ( 101 | len(partial_charges) == lig_mol.GetNumAtoms() 102 | ), f"{ligand_fn} partial charges are not correct" 103 | 104 | from openff.toolkit import Quantity 105 | from openff.toolkit import unit as off_unit 106 | 107 | partial_charges = Quantity(np.asarray(partial_charges), off_unit.elementary_charge) 108 | return partial_charges 109 | 110 | 111 | class ProLigRelax: 112 | def __init__( 113 | self, 114 | protein_mol: Chem.Mol = None, 115 | platform="CPU:16", 116 | ligand_ff="gaff", 117 | receptor_ff="amber", 118 | charge_name="mmff94", 119 | missing_residues: list = None, 120 | is_constrain: tuple[bool, str] = (False, "None"), 121 | is_restrain: tuple[bool, str] = (False, "None"), 122 | max_iteration: int = 0, 123 | ) -> None: 124 | self.receptor_ffname = receptor_ff 125 | self.ligand_ffname = ligand_ff 126 | self.partial_chargename = charge_name 127 | 128 | self.platform_properties = {} 129 | if platform.startswith("CPU"): 130 | device, num_core = platform.split(":") 131 | self.platform = omm.Platform.getPlatformByName(device)  # getPlatform() takes an integer index, not a name 132 | self.platform.setPropertyDefaultValue("Threads", str(num_core)) 133 | 134 | elif platform.startswith("CUDA"): 135 | device, device_id = platform.split(":") 136 | self.platform = omm.Platform.getPlatformByName(device) 137 | self.platform_properties.update({"DeviceIndex": device_id}) 138 | self.platform_properties.update({"Precision": "double"}) 139 | 140 | else: 141 | raise NotImplementedError() 142 | 143 | self.receptor_rdmol = protein_mol 144 | self.receptor_omm = app.PDBFile( 145 | StringIO(Chem.MolToPDBBlock(protein_mol).replace("< ", " ")) 146 | ) 147 | self.base_forcefield = self._load_rec_forcefield() 148 | self.missing_residues = missing_residues 149 | self.is_constrain = is_constrain 150 | self.is_restrain = is_restrain 151 | self.max_iteration = max_iteration 152 | 153 | self.forcefield_kwargs = { 154 | "nonbondedMethod": app.CutoffNonPeriodic, 155 | "nonbondedCutoff": 1.0 * unit.nanometer, # (default: 1) 156 | # "rigidWater": False, 157 | # "removeCMMotion": True, 158 | "constraints": app.HBonds, 159 | } 160 | 161 | def _load_rec_forcefield(self) -> app.ForceField: 162 | if self.receptor_ffname == "amber": 163 | receptor_ffs = [ 164 | "amber14/protein.ff14SB.xml", 165 | "amber14/tip3pfb.xml", 166 | "implicit/obc2.xml", 167 | ] 168 | forcefield = app.ForceField(*receptor_ffs) 169 | else: 170 | raise NotImplementedError() 171 | 172 | return forcefield 173 | 174 | def _add_lig_forcefield( 175 | self, ligand_mol: Molecule, ligand_fn: Path = None 176 | ) -> app.ForceField: 177 | if getattr(ligand_mol, "partial_charges") is None: 178 | if self.partial_chargename != "am1bcc": 179 | ligand_mol.assign_partial_charges( 180 | partial_charge_method=self.partial_chargename, 181 | use_conformers=ligand_mol.conformers[0], 182 | ) 183 | else: 184 | # ligand_mol.assign_partial_charges(partial_charge_method='mmff94') 185 | ligand_mol.partial_charges = get_am1bcc_charge(ligand_fn) 186 | 187 | if self.ligand_ffname == "gaff": 188 | ffgen = GAFFTemplateGenerator(forcefield="gaff-2.11") 189 | ffxml_contents = ffgen.generate_residue_template(ligand_mol) 190 | 191 | elif self.ligand_ffname == "openff": 192 | ffgen = SMIRNOFFTemplateGenerator(forcefield="openff-2.1.0") 193 | ffxml_contents = ffgen.generate_residue_template(ligand_mol) 194 | 195 | else: 196 | raise NotImplementedError() 197 | 198 | forcefield: app.ForceField = copy.deepcopy(self.base_forcefield) 
199 | forcefield.loadFile(StringIO(ffxml_contents)) 200 | return forcefield 201 | 202 | @staticmethod 203 | def constrain_pocket( 204 | system, top, missing_residues: list = [], level="main" 205 | ) -> None: 206 | assert level in ["none", "main", "heavy", "all"] 207 | for atom in top.topology.atoms(): 208 | if atom.residue.name in ["UNL", "HOH"] + METALS: # water and metals are not constrained 209 | continue 210 | 211 | if len(missing_residues) > 0: # missing residues are not constrained 212 | if len(missing_residues[0]) == 3: 213 | tag = (atom.residue.chain.index, atom.residue.id, atom.residue.name) 214 | elif len(missing_residues[0]) == 4: 215 | tag = ( 216 | atom.residue.name, 217 | atom.residue.id, 218 | atom.residue.chain.index, 219 | atom.name, 220 | ) 221 | 222 | if tag in missing_residues: 223 | continue 224 | 225 | if level == "all": 226 | system.setParticleMass(atom.index, 0) 227 | elif level == "main" and atom.name in ["CA", "C", "N"]: 228 | system.setParticleMass(atom.index, 0) 229 | elif level == "heavy" and atom.name[0] != "H": 230 | system.setParticleMass(atom.index, 0) 231 | 232 | @staticmethod 233 | def restrain_pocket( 234 | system, 235 | model: app.Modeller, 236 | missing_residues: list = [], 237 | level="main", 238 | stiffness: float = 10, 239 | ) -> None: 240 | assert level in ["none", "main", "heavy", "all"] 241 | 242 | forces = omm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)") 243 | forces.addGlobalParameter("k", stiffness) 244 | for p in ["x0", "y0", "z0"]: 245 | forces.addPerParticleParameter(p) 246 | 247 | for i, atom in enumerate(model.topology.atoms()): 248 | if atom.residue.name in ["UNL", "HOH"] + METALS: 249 | continue 250 | 251 | if len(missing_residues) > 0: 252 | if len(missing_residues[0]) == 3: 253 | tag = (atom.residue.chain.index, atom.residue.id, atom.residue.name) 254 | elif len(missing_residues[0]) == 4: 255 | tag = ( 256 | atom.residue.name, 257 | atom.residue.id, 258 | atom.residue.chain.id, 259 | atom.name, 260 | ) 261 | 262 | if tag in missing_residues: 263 | continue 264 | 265 | if level == "all": 266 | forces.addParticle(i, model.positions[i]) 267 | elif level == "main" and atom.name in ["CA", "C", "N", "O"]: 268 | forces.addParticle(i, model.positions[i]) 269 | elif level == "heavy" and atom.name[0] != "H": 270 | forces.addParticle(i, model.positions[i]) 271 | 272 | system.addForce(forces) 273 | 274 | def _minimize_energy( 275 | self, cplx_forcefield: app.ForceField, cplx_model: app.Modeller 276 | ) -> np.ndarray: 277 | resname = "FE" 278 | choice = "FE2" 279 | residueTemplates = dict( 280 | (res, choice) 281 | for res in cplx_model.topology.residues() 282 | if res.name == resname 283 | ) 284 | if self.is_restrain[0]: 285 | system = cplx_forcefield.createSystem( 286 | cplx_model.topology, 287 | residueTemplates=residueTemplates, 288 | **self.forcefield_kwargs, 289 | ) 290 | self.restrain_pocket( 291 | system, 292 | cplx_model, 293 | missing_residues=self.missing_residues, 294 | level=self.is_restrain[1], 295 | ) 296 | elif self.is_constrain[0]: 297 | self.forcefield_kwargs["constraints"] = None 298 | system = cplx_forcefield.createSystem( 299 | cplx_model.topology, 300 | residueTemplates=residueTemplates, 301 | **self.forcefield_kwargs, 302 | ) 303 | self.constrain_pocket( 304 | system, 305 | cplx_model, 306 | missing_residues=self.missing_residues, 307 | level=self.is_constrain[1], 308 | ) 309 | else: 310 | system = cplx_forcefield.createSystem( 311 | cplx_model.topology, 312 | residueTemplates=residueTemplates, 313 | **self.forcefield_kwargs, 314 | ) 315 | 
316 | integrator = omm.LangevinMiddleIntegrator( 317 | 300, 1, 0.004 318 | ) # Langevin parameters: 300 K, friction 1 ps^-1, step 0.004 ps 319 | # only use one CPU per relaxation 320 | simulation = app.Simulation( 321 | cplx_model.topology, 322 | system, 323 | integrator, 324 | self.platform, 325 | self.platform_properties, 326 | ) 327 | simulation.context.setPositions(cplx_model.positions) 328 | logger.info( 329 | f"Minimizing with {self.receptor_ffname} (protein) + {self.ligand_ffname} (ligand) + {self.partial_chargename} charge..." 330 | ) 331 | min_reporter = MinReporter() 332 | simulation.minimizeEnergy( 333 | maxIterations=self.max_iteration, 334 | # reporter=min_reporter, 335 | # tolerance=10.0 * unit.kilocalories_per_mole / unit.angstrom, 336 | # reporter=StateDataReporter( 337 | # stdout, 10, step=True, potentialEnergy=True, temperature=True 338 | # ), 339 | ) 340 | 341 | state = simulation.context.getState(getPositions=True) 342 | minimized_positions = state.getPositions(asNumpy=True) 343 | return minimized_positions 344 | 345 | def prepare_one_cplx_and_relax(self, ligand_fn: Path) -> Tuple[Chem.Mol, Chem.Mol]: 346 | ligand_omm, ligand_off, ligand_rdk = load_sdf_to_omm(ligand_fn) 347 | receptor_rdmol = Chem.Mol(self.receptor_rdmol) 348 | 349 | cplx_model = app.Modeller( 350 | self.receptor_omm.topology, self.receptor_omm.positions 351 | ) 352 | cplx_model.add(ligand_omm.topology, ligand_omm.positions) 353 | logger.info(f"{len(cplx_model.positions)=}") 354 | 355 | cplx_forcefield = self._add_lig_forcefield(ligand_off, ligand_fn) 356 | relaxed_pos = self._minimize_energy(cplx_forcefield, cplx_model) 357 | 358 | angstroms_pos = relaxed_pos.value_in_unit(unit.angstroms) 359 | nanometers_pos = relaxed_pos.value_in_unit(unit.nanometers) 360 | 361 | rec_num = len(self.receptor_omm.positions) 362 | pos1 = angstroms_pos[:rec_num] 363 | pos2 = angstroms_pos[rec_num:] 364 | 365 | self.receptor_omm.positions = nanometers_pos[:rec_num] * unit.nanometers 366 | 367 | protein_mol = assign_mol_with_pos(receptor_rdmol, pos1) 368 | ligand_mol = assign_mol_with_pos(ligand_rdk, pos2) 369 | 370 | return protein_mol, ligand_mol 371 | 372 | def run_batch(self, batch_fns: List[Path]) -> str: 373 | for ligand_ in batch_fns: 374 | self.prepare_one_cplx_and_relax(ligand_) 375 | 376 | return f"finished: {batch_fns}" 377 | -------------------------------------------------------------------------------- /dataset/utils/pdb_helper.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import * 3 | 4 | import numpy as np 5 | from lxml import etree 6 | from pdbfixer.pdbfixer import PDBFixer 7 | from rdkit import Chem 8 | 9 | PROTEIN_RESIDUES = [ 10 | "ALA", 11 | "ASN", 12 | "CYS", 13 | "GLU", 14 | "HIS", 15 | "LEU", 16 | "MET", 17 | "PRO", 18 | "THR", 19 | "TYR", 20 | "ARG", 21 | "ASP", 22 | "GLN", 23 | "GLY", 24 | "ILE", 25 | "LYS", 26 | "PHE", 27 | "SER", 28 | "TRP", 29 | "VAL", 30 | ] 31 | METALS = ("NA", "K", "CA", "MG", "FE", "ZN", "CU", "MN", "CO", "NI") 32 | WATERS = ("HOH", "DOD") 33 | 34 | 35 | def format_4letter(atom_name: str): 36 | output = None 37 | if len(atom_name) == 4: 38 | output = atom_name 39 | elif len(atom_name) == 3: 40 | output = " " + atom_name 41 | elif len(atom_name) == 2: 42 | output = " " + atom_name + " " 43 | elif len(atom_name) == 1: 44 | output = " " + atom_name + " " 45 | else: 46 | raise ValueError() 47 | 48 | return output 49 | 50 | 51 | def load_amber_xml() -> Tuple[Dict, Dict, Dict]: 52 | charge_table = {} 53 | hydrogen_table = {} 54 | connect_table = {} 55 | 
protein_xml_fn = Path(__file__).parent.parent / "amberlib/protein.ff14SB.xml" 56 | # load the XML file 57 | tree = etree.parse(protein_xml_fn) 58 | root = tree.getroot() 59 | # walk the XML file and read each element's attribute values 60 | # traverse the XML tree 61 | for element in root: 62 | # print(element.tag, element.attrib) 63 | if element.tag == "Residues": 64 | for residue_element in element: 65 | res_name = residue_element.attrib["name"] 66 | charge_table[res_name] = {} 67 | connect_table[res_name] = {} 68 | hydrogen_table[res_name] = {} 69 | for element in residue_element: 70 | if element.tag == "Atom": 71 | atom_name = element.attrib["name"] 72 | if atom_name in ["NH2", "NZ"]: 73 | charge_table[res_name][atom_name] = 1 74 | elif ( 75 | len(res_name) == 4 76 | and res_name[0] == "N" 77 | and atom_name == "N" 78 | ): 79 | # uncapped N-terminus 80 | charge_table[res_name][atom_name] = 1 81 | 82 | elif res_name[-3:] == "HIP" and atom_name == "ND1": 83 | charge_table[res_name][atom_name] = 1 84 | 85 | elif atom_name in ["OXT", "OD2", "OE2"]: 86 | charge_table[res_name][atom_name] = -1 87 | else: 88 | charge_table[res_name][atom_name] = 0 89 | 90 | if element.tag == "Bond": 91 | atom1 = element.attrib["atomName1"] 92 | atom2 = element.attrib["atomName2"] 93 | if set([atom1, atom2]) == set(["C", "O"]): 94 | connect_table[res_name][(atom1, atom2)] = 2 95 | connect_table[res_name][(atom2, atom1)] = 2 96 | elif res_name[-3:] == "ARG" and set([atom1, atom2]) == set( 97 | ["CZ", "NH2"] 98 | ): 99 | connect_table[res_name][(atom1, atom2)] = 2 100 | connect_table[res_name][(atom2, atom1)] = 2 101 | 102 | elif res_name[-3:] in ["HIS", "HIE", "HIP"] and set( 103 | [atom1, atom2] 104 | ) in [set(["CG", "CD2"]), set(["CE1", "ND1"])]: 105 | connect_table[res_name][(atom1, atom2)] = 2 106 | connect_table[res_name][(atom2, atom1)] = 2 107 | elif res_name[-3:] == "HID" and set([atom1, atom2]) in [ 108 | set(["CG", "CD2"]), 109 | set(["CE1", "NE2"]), 110 | ]: 111 | connect_table[res_name][(atom1, atom2)] = 2 112 | connect_table[res_name][(atom2, atom1)] = 2 113 | 114 | elif res_name[-3:] in ["ASP", "ASN", "ASH"] and set( 115 | [atom1, atom2] 116 | ) == set(["CG", "OD1"]): 117 | connect_table[res_name][(atom1, atom2)] = 2 118 | connect_table[res_name][(atom2, atom1)] = 2 119 | 120 | elif res_name[-3:] in ["GLU", "GLN", "GLH"] and set( 121 | [atom1, atom2] 122 | ) == set(["CD", "OE1"]): 123 | connect_table[res_name][(atom1, atom2)] = 2 124 | connect_table[res_name][(atom2, atom1)] = 2 125 | 126 | elif res_name[-3:] in ["PHE", "TYR"] and set( 127 | [atom1, atom2] 128 | ) in [ 129 | set(["CG", "CD1"]), 130 | set(["CE1", "CZ"]), 131 | set(["CE2", "CD2"]), 132 | ]: 133 | connect_table[res_name][(atom1, atom2)] = 2 134 | connect_table[res_name][(atom2, atom1)] = 2 135 | 136 | elif res_name[-3:] == "TRP" and set([atom1, atom2]) in [ 137 | set(["CG", "CD1"]), 138 | set(["CE2", "CD2"]), 139 | set(["CH2", "CZ2"]), 140 | set(["CE3", "CZ3"]), 141 | ]: 142 | connect_table[res_name][(atom1, atom2)] = 2 143 | connect_table[res_name][(atom2, atom1)] = 2 144 | 145 | else: 146 | connect_table[res_name][(atom1, atom2)] = 1 147 | connect_table[res_name][(atom2, atom1)] = 1 148 | 149 | # if res_name == 'NME': 150 | # pass 151 | if atom2[0] == "H" and atom1[0] != "H": 152 | if atom1 not in hydrogen_table[res_name]: 153 | hydrogen_table[res_name][atom1] = [] 154 | hydrogen_table[res_name][atom1].append(atom2) 155 | elif atom1[0] == "H" and atom2[0] != "H": 156 | if atom2 not in hydrogen_table[res_name]: 157 | hydrogen_table[res_name][atom2] = [] 158 | hydrogen_table[res_name][atom2].append(atom1) 159 
| 160 | hydrogen_table["HOH"] = {} 161 | hydrogen_table["HOH"]["O"] = ["H1", "H2"] 162 | 163 | charge_table["HOH"] = {} 164 | charge_table["HOH"]["O"] = 0 165 | charge_table["HOH"]["H1"] = 0 166 | charge_table["HOH"]["H2"] = 0 167 | for metal_ in METALS: 168 | charge_table[metal_] = {} 169 | 170 | charge_table["NA"]["NA"] = 1 171 | charge_table["K"]["K"] = 1 172 | charge_table["CA"]["CA"] = 2 173 | charge_table["MG"]["MG"] = 2 174 | charge_table["FE"]["FE"] = 2 175 | # charge_table["FE2"]["FE2"] = 2 176 | charge_table["ZN"]["ZN"] = 2 177 | charge_table["CU"]["CU"] = 2 178 | charge_table["MN"]["MN"] = 2 179 | charge_table["CO"]["CO"] = 3 180 | charge_table["NI"]["NI"] = 2 181 | # charge_table["CL"]["CL"] = -1 182 | 183 | connect_table["HOH"] = {} 184 | connect_table["HOH"][("O", "H1")] = 1 185 | connect_table["HOH"][("H1", "O")] = 1 186 | connect_table["HOH"][("O", "H2")] = 1 187 | connect_table["HOH"][("H2", "O")] = 1 188 | 189 | hydrogen_table["HIS"] = hydrogen_table["HID"] 190 | charge_table["HIS"] = charge_table["HID"] 191 | connect_table["HIS"] = connect_table["HID"] 192 | 193 | return connect_table, charge_table, hydrogen_table 194 | 195 | 196 | def add_conformer_to_mol(new_mol: Chem.Mol, pos) -> Chem.Mol: 197 | conf = Chem.Conformer(new_mol.GetNumAtoms()) 198 | for idx in range(len(pos)): 199 | conf.SetAtomPosition(idx, pos[idx]) 200 | # conf.SetPositions(pos) 201 | new_mol.RemoveAllConformers() 202 | new_mol.AddConformer(conf, assignId=True) 203 | return new_mol 204 | 205 | 206 | def add_h_to_receptor_mol(rec_mol: Chem.Mol, hydrogen_table: Dict[str, List[str]]): 207 | new_mol = Chem.AddHs(rec_mol, addCoords=True) 208 | residue_mapper = {} 209 | for atom in new_mol.GetAtoms(): 210 | atom: Chem.Atom 211 | if atom.GetAtomicNum() != 1: 212 | mi = atom.GetPDBResidueInfo() 213 | res_key = [mi.GetResidueName(), int(mi.GetResidueNumber()), mi.GetChainId(), mi.GetInsertionCode()] 214 | atom_name = mi.GetName().strip() 215 | 216 | tmp_mi = Chem.AtomPDBResidueInfo() 217 | tmp_mi.SetResidueName(f"{res_key[0]:>3s}") 218 | tmp_mi.SetResidueNumber(res_key[1]) 219 | tmp_mi.SetChainId(res_key[2]) 220 | tmp_mi.SetInsertionCode(res_key[3]) 221 | 222 | if tuple(res_key) not in residue_mapper: 223 | residue_mapper[tuple(res_key)] = [] 224 | 225 | residue_mapper[tuple(res_key)].append(atom.GetIdx()) 226 | 227 | h_atoms: List[Chem.Atom] = [ 228 | nbr_atom 229 | for nbr_atom in atom.GetNeighbors() 230 | if nbr_atom.GetAtomicNum() == 1 231 | ] 232 | if len(h_atoms) > 0: 233 | std_h_names = hydrogen_table[res_key[0]][atom_name] 234 | assert len(h_atoms) == len(std_h_names) 235 | for h_atom, h_name in zip(h_atoms, list(sorted(std_h_names))): 236 | tmp_mi.SetName(format_4letter(h_name)) 237 | h_atom.SetMonomerInfo(tmp_mi) 238 | residue_mapper[tuple(res_key)].append(h_atom.GetIdx()) 239 | 240 | resort_list = [] 241 | [resort_list.extend(values) for _, values in residue_mapper.items()] 242 | assert len(resort_list) == new_mol.GetNumAtoms() 243 | new_mol_h = Chem.RenumberAtoms(new_mol, resort_list) 244 | 245 | return new_mol_h 246 | -------------------------------------------------------------------------------- /dataset/utils/repair_pdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import openmm as omm 4 | import openmm.app as app 5 | 6 | from pathlib import Path 7 | import pdbfixer 8 | from rdkit import Chem 9 | from openff.toolkit.topology import Molecule 10 | from collections import defaultdict 11 | import numpy 
as np 12 | import copy 13 | from openmm import unit, Vec3 14 | from openmm.app import PDBFile 15 | from openmmforcefields.generators import GAFFTemplateGenerator 16 | from io import StringIO 17 | from typing import * 18 | 19 | from dataset.utils.mol_correct_helper import omm_protein_to_rdmol 20 | from dataset.utils.common_helper import create_logger 21 | 22 | logger = create_logger(__name__) 23 | 24 | 25 | class PDBFixer(pdbfixer.PDBFixer): 26 | def addMissingHydrogens(self, pH=7.0, forcefield=None) -> None: 27 | extraDefinitions = self._downloadNonstandardDefinitions() 28 | variants = [ 29 | ( 30 | self._describeVariant(res, extraDefinitions) 31 | if res.name not in ["ACE", "NME"] 32 | else None 33 | ) 34 | for res in self.topology.residues() 35 | ] 36 | modeller = app.Modeller(self.topology, self.positions) 37 | modeller.addHydrogens(pH=pH, forcefield=forcefield, variants=variants) 38 | self.topology = modeller.topology 39 | self.positions = modeller.positions 40 | 41 | 42 | def fix_pdb(pdb_fn: Path, cap_n_ter="ACE", cap_c_ter="NME") -> Chem.Mol: 43 | fixer = PDBFixer(filename=str(pdb_fn)) 44 | 45 | fixer.findNonstandardResidues() 46 | logger.info(f"{fixer.nonstandardResidues=}") 47 | fixer.replaceNonstandardResidues() 48 | 49 | fixer.findMissingResidues() 50 | fixer.findMissingAtoms() 51 | logger.info(f"{fixer.missingResidues=}") 52 | logger.info(f"{fixer.missingAtoms=}") 53 | logger.info(f"{fixer.missingTerminals=}") 54 | 55 | for i in range(len(fixer.topology._chains)): 56 | chain = fixer.topology._chains[i] 57 | if chain._residues[0].name == "HOH": 58 | continue 59 | 60 | first_head = (i, 0) 61 | if first_head not in fixer.missingResidues: 62 | fixer.missingResidues[first_head] = [cap_n_ter] 63 | else: 64 | fixer.missingResidues[first_head] = [cap_n_ter] + fixer.missingResidues[ 65 | first_head 66 | ][-1:] 67 | 68 | last_head = (i, len(chain)) 69 | if last_head not in fixer.missingResidues: 70 | fixer.missingResidues[last_head] = [cap_c_ter] 71 | else: 72 | fixer.missingResidues[last_head] = fixer.missingResidues[last_head][-1:] + [ 73 | cap_c_ter 74 | ] 75 | 76 | fixer.missingTerminals = {} 77 | logger.info(f"{fixer.missingResidues=}") 78 | 79 | start_num = 0 80 | missing_residues = [] 81 | for chain in fixer.topology.chains(): 82 | if chain._residues[0].name == "HOH": 83 | continue 84 | for residue in chain.residues(): 85 | tag = (residue.chain.index, residue.index - start_num) 86 | if tag in fixer.missingResidues: 87 | cur_residues = fixer.missingResidues[tag] 88 | cur_len = len(cur_residues) 89 | [ 90 | missing_residues.append( 91 | (residue.chain.index, str(int(residue.id) - (cur_len - i)), res) 92 | ) 93 | for i, res in enumerate(cur_residues) 94 | ] 95 | 96 | if (residue.index + 1 - start_num) == len(chain): 97 | tag = (residue.chain.index, residue.index + 1 - start_num) 98 | cur_residues = fixer.missingResidues[tag] 99 | [ 100 | missing_residues.append( 101 | (residue.chain.index, str(int(residue.id) + (i + 1)), res) 102 | ) 103 | for i, res in enumerate(cur_residues) 104 | ] 105 | start_num += len(chain) 106 | 107 | logger.info(f"{len(list(fixer.topology.atoms()))=}") 108 | fixer.addMissingAtoms(seed=0) 109 | logger.info(f"add {missing_residues=}") 110 | logger.info(f"{len(list(fixer.topology.atoms()))=}") 111 | 112 | protein_mol = omm_protein_to_rdmol(fixer.topology, fixer.positions) 113 | return protein_mol, missing_residues 114 | 115 | 116 | def test_fix_cif(pdbx_fn: Path, cap_n_ter="ACE", cap_c_ter="NME"): 117 | with open(pdbx_fn) as f: 118 | fixer = 
pdbfixer.PDBFixer(pdbxfile=f) 119 | # fixer = pdbfixer.PDBFixer(pdbxfile=str(pdbx_fn)) 120 | fixer.findNonstandardResidues() 121 | logger.info(f"{fixer.nonstandardResidues=}") 122 | # fixer.replaceNonstandardResidues() 123 | 124 | fixer.findMissingResidues() 125 | fixer.findMissingAtoms() 126 | logger.info(f"{fixer.missingResidues=}") 127 | logger.info(f"{fixer.missingAtoms=}") 128 | logger.info(f"{fixer.missingTerminals=}") 129 | 130 | for i in range(len(fixer.topology._chains)): 131 | chain = fixer.topology._chains[i] 132 | if chain._residues[0].name == "HOH": 133 | continue 134 | 135 | first_head = (i, 0) 136 | if first_head not in fixer.missingResidues: 137 | fixer.missingResidues[first_head] = [cap_n_ter] 138 | else: 139 | fixer.missingResidues[first_head] = [cap_n_ter] + fixer.missingResidues[ 140 | first_head 141 | ][-1:] 142 | 143 | last_head = (i, len(chain)) 144 | if last_head not in fixer.missingResidues: 145 | fixer.missingResidues[last_head] = [cap_c_ter] 146 | else: 147 | fixer.missingResidues[last_head] = fixer.missingResidues[last_head][-1:] + [ 148 | cap_c_ter 149 | ] 150 | 151 | fixer.missingTerminals = {} 152 | logger.info(f"{fixer.missingResidues=}") 153 | 154 | logger.info(f"{len(list(fixer.topology.atoms()))=}") 155 | fixer.addMissingAtoms(seed=0) 156 | # logger.info(f'{missing_res}') 157 | logger.info(f"{len(list(fixer.topology.atoms()))=}") 158 | # fixer.addMissingHydrogens(7.0) 159 | 160 | protein_mol = omm_protein_to_rdmol(fixer.topology, fixer.positions) 161 | 162 | return protein_mol 163 | 164 | 165 | if __name__ == "__main__": 166 | # fix_pdb()  # fix_pdb requires a PDB path argument 167 | test_fix_cif( 168 | "/{your_path}/protein_ligand_docking_benchmark/posex/mmcif_raw/8UCB_X1T.cif" 169 | ) 170 | -------------------------------------------------------------------------------- /environments/base.yaml: -------------------------------------------------------------------------------- 1 | name: posex 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - alsa-lib=1.2.14=hb9d3cd8_0 9 | - attr=2.5.1=h166bdaf_1 10 | - biopandas=0.5.1=pyhd8ed1ab_1 11 | - biopython=1.78=py311h5eee18b_0 12 | - biotite=1.3.0=py311hfdbb021_0 13 | - biotraj=1.2.2=py311hfdbb021_0 14 | - blas=1.0=openblas 15 | - blosc=1.21.6=he440d0b_1 16 | - brotli=1.1.0=hb9d3cd8_2 17 | - brotli-bin=1.1.0=hb9d3cd8_2 18 | - brotli-python=1.1.0=py311hfdbb021_2 19 | - bzip2=1.0.8=h5eee18b_6 20 | - c-ares=1.34.5=hb9d3cd8_0 21 | - ca-certificates=2025.4.26=hbd8a1cb_0 22 | - cairo=1.18.4=h3394656_0 23 | - certifi=2025.4.26=pyhd8ed1ab_0 24 | - cffi=1.17.1=py311hf29c0ef_0 25 | - chardet=5.2.0=pyhd8ed1ab_3 26 | - charset-normalizer=3.4.2=pyhd8ed1ab_0 27 | - click=8.2.1=pyh707e725_0 28 | - colorama=0.4.6=pyhd8ed1ab_1 29 | - contourpy=1.3.2=py311hd18a35c_0 30 | - cycler=0.12.1=pyhd8ed1ab_1 31 | - cyrus-sasl=2.1.27=h54b06d7_7 32 | - dbus=1.13.6=h5008d03_3 33 | - expat=2.7.0=h5888daf_0 34 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 35 | - font-ttf-inconsolata=3.000=h77eed37_0 36 | - font-ttf-source-code-pro=2.038=h77eed37_0 37 | - font-ttf-ubuntu=0.83=h77eed37_3 38 | - fontconfig=2.15.0=h7e30c49_1 39 | - fonts-conda-ecosystem=1=0 40 | - fonts-conda-forge=1=0 41 | - fonttools=4.58.0=py311h2dc5d0c_0 42 | - freetype=2.13.3=ha770c72_1 43 | - freetype-py=2.3.0=pyhd8ed1ab_0 44 | - gettext=0.24.1=h5888daf_0 45 | - gettext-tools=0.24.1=h5888daf_0 46 | - glew=2.1.0=h9c3ff4c_2 47 | - glib=2.84.0=h07242d1_0 48 | - glib-tools=2.84.0=h4833e2c_0 49 | - glm=1.0.1=hdd259ec_0 50 | - 
graphite2=1.3.13=h59595ed_1003 51 | - greenlet=3.2.2=py311hfdbb021_0 52 | - gst-plugins-base=1.24.7=h0a52356_0 53 | - gstreamer=1.24.7=hf3bb09a_0 54 | - h2=4.2.0=pyhd8ed1ab_0 55 | - harfbuzz=9.0.0=hda332d3_1 56 | - hdf4=4.2.15=h2a13503_7 57 | - hdf5=1.14.6=nompi_h2d575fe_101 58 | - hpack=4.1.0=pyhd8ed1ab_0 59 | - hyperframe=6.1.0=pyhd8ed1ab_0 60 | - icu=75.1=he02047a_0 61 | - idna=3.10=pyhd8ed1ab_1 62 | - jinja2=3.1.6=pyhd8ed1ab_0 63 | - keyutils=1.6.1=h166bdaf_0 64 | - kiwisolver=1.4.7=py311hd18a35c_0 65 | - krb5=1.21.3=h659f571_0 66 | - lame=3.100=h166bdaf_1003 67 | - lcms2=2.17=h717163a_0 68 | - ld_impl_linux-64=2.40=h12ee557_0 69 | - lerc=4.0.0=h0aef613_1 70 | - libaec=1.1.3=h59595ed_0 71 | - libasprintf=0.24.1=h8e693c7_0 72 | - libasprintf-devel=0.24.1=h8e693c7_0 73 | - libblas=3.9.0=1_h86c2bf4_netlib 74 | - libboost=1.86.0=h6c02f8c_3 75 | - libboost-python=1.86.0=py311h5b7b71f_3 76 | - libbrotlicommon=1.1.0=hb9d3cd8_2 77 | - libbrotlidec=1.1.0=hb9d3cd8_2 78 | - libbrotlienc=1.1.0=hb9d3cd8_2 79 | - libcap=2.75=h39aace5_0 80 | - libcblas=3.9.0=12_h832b8c9_netlib 81 | - libclang-cpp19.1=19.1.7=default_hb5137d0_3 82 | - libclang-cpp20.1=20.1.5=default_h1df26ce_1 83 | - libclang13=20.1.5=default_he06ed0a_1 84 | - libcups=2.3.3=h4637d8d_4 85 | - libcurl=8.13.0=h332b0f4_0 86 | - libdeflate=1.24=h86f0d12_0 87 | - libdrm=2.4.124=hb9d3cd8_0 88 | - libedit=3.1.20250104=pl5321h7949ede_0 89 | - libegl=1.7.0=ha4b6fd6_2 90 | - libev=4.33=hd590300_2 91 | - libevent=2.1.12=hf998b51_1 92 | - libexpat=2.7.0=h5888daf_0 93 | - libffi=3.4.4=h6a678d5_1 94 | - libflac=1.4.3=h59595ed_0 95 | - libfreetype=2.13.3=ha770c72_1 96 | - libfreetype6=2.13.3=h48d6fc4_1 97 | - libgcc=15.1.0=h767d61c_2 98 | - libgcc-ng=15.1.0=h69a702a_2 99 | - libgcrypt-lib=1.11.1=hb9d3cd8_0 100 | - libgettextpo=0.24.1=h5888daf_0 101 | - libgettextpo-devel=0.24.1=h5888daf_0 102 | - libgfortran=15.1.0=h69a702a_2 103 | - libgfortran-ng=15.1.0=h69a702a_2 104 | - libgfortran5=15.1.0=hcea5267_2 105 | - libgl=1.7.0=ha4b6fd6_2 106 | - libglib=2.84.0=h2ff4ddf_0 107 | - libglu=9.0.3=h03adeef_0 108 | - libglvnd=1.7.0=ha4b6fd6_2 109 | - libglx=1.7.0=ha4b6fd6_2 110 | - libgomp=15.1.0=h767d61c_2 111 | - libgpg-error=1.55=h3f2d84a_0 112 | - libiconv=1.18=h4ce23a2_1 113 | - libjpeg-turbo=3.1.0=hb9d3cd8_0 114 | - liblapack=3.9.0=12_hd37a5e2_netlib 115 | - libllvm19=19.1.7=ha7bfdaf_1 116 | - libllvm20=20.1.5=he9d0ab4_0 117 | - liblzma=5.8.1=hb9d3cd8_1 118 | - liblzma-devel=5.8.1=hb9d3cd8_1 119 | - libnetcdf=4.9.2=nompi_h0134ee8_117 120 | - libnghttp2=1.64.0=h161d5f1_0 121 | - libnsl=2.0.1=hd590300_0 122 | - libntlm=1.8=hb9d3cd8_0 123 | - libogg=1.3.5=hd0c01bc_1 124 | - libopenblas=0.3.29=pthreads_h94d23a6_0 125 | - libopengl=1.7.0=ha4b6fd6_2 126 | - libopus=1.5.2=hd0c01bc_0 127 | - libpciaccess=0.18=hd590300_0 128 | - libpng=1.6.47=h943b412_0 129 | - libpq=16.9=h87c4ccc_0 130 | - librdkit=2024.09.1=h84b0b3c_3 131 | - libsndfile=1.2.2=hc60ed4a_1 132 | - libsqlite=3.49.2=hee588c1_0 133 | - libssh2=1.11.1=hcf80075_0 134 | - libstdcxx=15.1.0=h8f9b012_2 135 | - libstdcxx-ng=15.1.0=h4852527_2 136 | - libsystemd0=257.4=h4e0b6ca_1 137 | - libtiff=4.7.0=hf01ce69_5 138 | - libuuid=2.38.1=h0b41bf4_0 139 | - libvorbis=1.3.7=h9c3ff4c_0 140 | - libwebp-base=1.5.0=h851e524_0 141 | - libxcb=1.17.0=h8a09558_0 142 | - libxcrypt=4.4.36=hd590300_1 143 | - libxkbcommon=1.10.0=h65c71a3_0 144 | - libxml2=2.13.8=h4bc477f_0 145 | - libzip=1.11.2=h6991a6a_0 146 | - libzlib=1.3.1=hb9d3cd8_2 147 | - loguru=0.7.3=pyh707e725_0 148 | - looseversion=1.3.0=pyhd8ed1ab_0 149 | - 
lz4-c=1.10.0=h5888daf_1 150 | - markupsafe=3.0.2=py311h2dc5d0c_1 151 | - matplotlib-base=3.10.3=py311h2b939e6_0 152 | - mmtf-python=1.1.3=pyhd8ed1ab_0 153 | - mpg123=1.32.9=hc50e24c_0 154 | - msgpack-python=1.1.0=py311hd18a35c_0 155 | - munkres=1.1.4=pyh9f0ad1d_0 156 | - mysql-common=9.0.1=h266115a_6 157 | - mysql-libs=9.0.1=he0572af_6 158 | - ncurses=6.5=h2d0b736_3 159 | - networkx=3.4.2=pyh267e887_2 160 | - nspr=4.36=h5888daf_0 161 | - nss=3.111=h159eef7_0 162 | - numpy=1.26.4=py311h24aa872_0 163 | - numpy-base=1.26.4=py311hbfb1bba_0 164 | - openjpeg=2.5.3=h5fbd93e_0 165 | - openldap=2.6.10=he970967_0 166 | - openssl=3.5.0=h7b32b05_1 167 | - packaging=25.0=pyh29332c3_1 168 | - pandas=2.2.3=py311h7db5c69_3 169 | - pcre2=10.44=hc749103_2 170 | - pillow=11.2.1=py311h1322bbf_0 171 | - pip=25.1=pyhc872135_2 172 | - pixman=0.46.0=h29eaf8c_0 173 | - ply=3.11=pyhd8ed1ab_3 174 | - pmw=2.0.1=py311h38be061_1008 175 | - posebusters=0.4.4=pyhd8ed1ab_0 176 | - prody=2.4.1=pyh5e1b82b_2 177 | - pthread-stubs=0.4=hb9d3cd8_1002 178 | - pulseaudio-client=17.0=hac146a9_1 179 | - pycairo=1.28.0=py311hd785cd9_0 180 | - pycparser=2.22=pyh29332c3_1 181 | - pymol-open-source=3.0.0=py311hbd307dc_8 182 | - pyparsing=3.1.1=pyhd8ed1ab_0 183 | - pyqt=5.15.11=py311he22028a_0 184 | - pyqt5-sip=12.17.0=py311hfdbb021_0 185 | - pysocks=1.7.1=pyha55dd90_7 186 | - python=3.11.11=h9e4cc4f_2_cpython 187 | - python-dateutil=2.9.0.post0=pyhff2d567_1 188 | - python-tzdata=2025.2=pyhd8ed1ab_0 189 | - python_abi=3.11=7_cp311 190 | - pytz=2025.2=pyhd8ed1ab_0 191 | - pyyaml=6.0.2=py311h2dc5d0c_2 192 | - qhull=2020.2=h434a139_5 193 | - qt-main=5.15.15=h374914d_0 194 | - rdkit=2024.09.1=py311h75c149a_3 195 | - readline=8.2=h5eee18b_0 196 | - reportlab=4.4.1=py311h9ecbd09_0 197 | - requests=2.32.3=pyhd8ed1ab_1 198 | - rlpycairo=0.2.0=pyhd8ed1ab_0 199 | - scipy=1.15.2=py311h8f841c2_0 200 | - setuptools=78.1.1=py311h06a4308_0 201 | - sip=6.10.0=py311hfdbb021_0 202 | - six=1.17.0=pyhd8ed1ab_0 203 | - snappy=1.2.1=h8bd8927_1 204 | - sqlalchemy=2.0.41=py311h9ecbd09_0 205 | - sqlite=3.49.2=h9eae976_0 206 | - tk=8.6.13=noxft_h4845f30_101 207 | - toml=0.10.2=pyhd8ed1ab_1 208 | - tomli=2.2.1=pyhd8ed1ab_1 209 | - tqdm=4.67.1=pyhd8ed1ab_1 210 | - typing-extensions=4.13.2=h0e9735f_0 211 | - typing_extensions=4.13.2=pyh29332c3_0 212 | - tzdata=2025b=h04d1e81_0 213 | - unicodedata2=16.0.0=py311h9ecbd09_0 214 | - urllib3=2.4.0=pyhd8ed1ab_0 215 | - wheel=0.45.1=py311h06a4308_0 216 | - xcb-util=0.4.1=hb711507_2 217 | - xcb-util-image=0.4.0=hb711507_2 218 | - xcb-util-keysyms=0.4.1=hb711507_0 219 | - xcb-util-renderutil=0.3.10=hb711507_0 220 | - xcb-util-wm=0.4.2=hb711507_0 221 | - xkeyboard-config=2.44=hb9d3cd8_0 222 | - xorg-libice=1.1.2=hb9d3cd8_0 223 | - xorg-libsm=1.2.6=he73a12e_0 224 | - xorg-libx11=1.8.12=h4f16b4b_0 225 | - xorg-libxau=1.0.12=hb9d3cd8_0 226 | - xorg-libxcomposite=0.4.6=hb9d3cd8_2 227 | - xorg-libxdamage=1.1.6=hb9d3cd8_0 228 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0 229 | - xorg-libxext=1.3.6=hb9d3cd8_0 230 | - xorg-libxfixes=6.0.1=hb9d3cd8_0 231 | - xorg-libxrender=0.9.12=hb9d3cd8_0 232 | - xorg-libxxf86vm=1.1.6=hb9d3cd8_0 233 | - xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1005 234 | - xz=5.8.1=hbcc6ac9_1 235 | - xz-gpl-tools=5.8.1=hbcc6ac9_1 236 | - xz-tools=5.8.1=hb9d3cd8_1 237 | - yaml=0.2.5=h7f98852_2 238 | - zlib=1.3.1=hb9d3cd8_2 239 | - zstandard=0.23.0=py311h9ecbd09_2 240 | - zstd=1.5.7=hb8e6e7a_2 241 | - pip: 242 | - bcrypt==4.3.0 243 | - cryptography==45.0.3 244 | - dill==0.4.0 245 | - future==1.0.0 246 | - gitdb==4.0.12 247 | - 
gitpython==3.1.44 248 | - hatch-cython==0.5.1 249 | - mmcif==0.91.0 250 | - multiprocess==0.70.18 251 | - paramiko==3.5.1 252 | - pdbecif==1.5 253 | - psutil==7.0.0 254 | - pynacl==1.5.0 255 | - rcsb-utils-config==0.41 256 | - rcsb-utils-io==1.49 257 | - rcsb-utils-validation==0.33 258 | - ruamel-yaml==0.18.11 259 | - ruamel-yaml-clib==0.2.12 260 | - smmap==5.0.2 261 | -------------------------------------------------------------------------------- /environments/boltz-1.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.11.7 3 | aiosignal==1.3.1 4 | antlr4-python3-runtime==4.9.3 5 | async-timeout==5.0.1 6 | attrs==24.2.0 7 | biopython==1.84 8 | boltz==0.4.0 9 | certifi==2024.8.30 10 | charset-normalizer==3.4.0 11 | click==8.1.7 12 | dm-tree==0.1.8 13 | docker-pycreds==0.4.0 14 | einops==0.8.0 15 | einx==0.3.0 16 | fairscale==0.4.13 17 | filelock==3.16.1 18 | frozendict==2.4.6 19 | frozenlist==1.5.0 20 | fsspec==2024.10.0 21 | gitdb==4.0.11 22 | GitPython==3.1.43 23 | hydra-core==1.3.2 24 | idna==3.10 25 | ihm==1.7 26 | Jinja2==3.1.4 27 | lightning-utilities==0.11.9 28 | MarkupSafe==3.0.2 29 | mashumaro==3.14 30 | modelcif==1.2 31 | mpmath==1.3.0 32 | msgpack==1.1.0 33 | multidict==6.1.0 34 | networkx==3.4.2 35 | numpy==1.26.3 36 | nvidia-cublas-cu12==12.4.5.8 37 | nvidia-cuda-cupti-cu12==12.4.127 38 | nvidia-cuda-nvrtc-cu12==12.4.127 39 | nvidia-cuda-runtime-cu12==12.4.127 40 | nvidia-cudnn-cu12==9.1.0.70 41 | nvidia-cufft-cu12==11.2.1.3 42 | nvidia-curand-cu12==10.3.5.147 43 | nvidia-cusolver-cu12==11.6.1.9 44 | nvidia-cusparse-cu12==12.3.1.170 45 | nvidia-nccl-cu12==2.21.5 46 | nvidia-nvjitlink-cu12==12.4.127 47 | nvidia-nvtx-cu12==12.4.127 48 | omegaconf==2.3.0 49 | packaging==24.2 50 | pandas==2.2.3 51 | pillow==11.0.0 52 | platformdirs==4.3.6 53 | propcache==0.2.0 54 | protobuf==5.28.3 55 | psutil==6.1.0 56 | python-dateutil==2.9.0.post0 57 | pytorch-lightning==2.4.0 58 | pytz==2024.2 59 | PyYAML==6.0.2 60 | rdkit==2024.3.6 61 | requests==2.32.3 62 | scipy==1.13.1 63 | sentry-sdk==2.19.0 64 | setproctitle==1.3.4 65 | six==1.16.0 66 | smmap==5.0.1 67 | sympy==1.13.1 68 | torch==2.5.1 69 | torchmetrics==1.6.0 70 | tqdm==4.67.1 71 | triton==3.1.0 72 | types-requests==2.32.0.20241016 73 | typing_extensions==4.12.2 74 | tzdata==2024.2 75 | urllib3==2.2.3 76 | wandb==0.18.7 77 | yarl==1.18.0 78 | -------------------------------------------------------------------------------- /environments/boltz-1x.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.6.1 2 | aiohttp==3.11.18 3 | aiosignal==1.3.2 4 | antlr4-python3-runtime==4.9.3 5 | attrs==25.3.0 6 | biopython==1.84 7 | boltz==1.0.0 8 | certifi==2025.4.26 9 | charset-normalizer==3.4.1 10 | click==8.1.7 11 | dm-tree==0.1.8 12 | docker-pycreds==0.4.0 13 | einops==0.8.0 14 | einx==0.3.0 15 | fairscale==0.4.13 16 | filelock==3.18.0 17 | frozendict==2.4.6 18 | frozenlist==1.6.0 19 | fsspec==2025.3.2 20 | gitdb==4.0.12 21 | GitPython==3.1.44 22 | hydra-core==1.3.2 23 | idna==3.10 24 | ihm==2.5 25 | jaxtyping==0.3.2 26 | Jinja2==3.1.6 27 | lightning-utilities==0.14.3 28 | llvmlite==0.44.0 29 | MarkupSafe==3.0.2 30 | mashumaro==3.14 31 | modelcif==1.2 32 | mpmath==1.3.0 33 | msgpack==1.1.0 34 | multidict==6.4.3 35 | networkx==3.4.2 36 | numba==0.61.0 37 | numpy==1.26.3 38 | nvidia-cublas-cu12==12.6.4.1 39 | nvidia-cuda-cupti-cu12==12.6.80 40 | nvidia-cuda-nvrtc-cu12==12.6.77 41 | nvidia-cuda-runtime-cu12==12.6.77 42 
| nvidia-cudnn-cu12==9.5.1.17 43 | nvidia-cufft-cu12==11.3.0.4 44 | nvidia-cufile-cu12==1.11.1.6 45 | nvidia-curand-cu12==10.3.7.77 46 | nvidia-cusolver-cu12==11.7.1.2 47 | nvidia-cusparse-cu12==12.5.4.2 48 | nvidia-cusparselt-cu12==0.6.3 49 | nvidia-nccl-cu12==2.26.2 50 | nvidia-nvjitlink-cu12==12.6.85 51 | nvidia-nvtx-cu12==12.6.77 52 | omegaconf==2.3.0 53 | packaging==25.0 54 | pandas==2.2.3 55 | pillow==11.2.1 56 | platformdirs==4.3.7 57 | propcache==0.3.1 58 | protobuf==5.29.4 59 | psutil==7.0.0 60 | python-dateutil==2.9.0.post0 61 | pytorch-lightning==2.4.0 62 | pytz==2025.2 63 | PyYAML==6.0.2 64 | rdkit==2024.9.6 65 | requests==2.32.3 66 | scipy==1.13.1 67 | sentry-sdk==2.27.0 68 | setproctitle==1.3.5 69 | six==1.17.0 70 | smmap==5.0.2 71 | sympy==1.14.0 72 | torch==2.7.0 73 | torchmetrics==1.7.1 74 | tqdm==4.67.1 75 | trifast==0.1.12 76 | triton==3.3.0 77 | types-requests==2.32.0.20250328 78 | typing_extensions==4.13.2 79 | tzdata==2025.2 80 | urllib3==2.4.0 81 | wadler_lindig==0.1.5 82 | wandb==0.18.7 83 | yarl==1.20.0 84 | -------------------------------------------------------------------------------- /environments/chai-1.txt: -------------------------------------------------------------------------------- 1 | aiobotocore==2.15.2 2 | aiohappyeyeballs==2.4.3 3 | aiohttp==3.11.2 4 | aioitertools==0.12.0 5 | aiosignal==1.3.1 6 | annotated-types==0.7.0 7 | antipickle==0.2.0 8 | asttokens==2.4.1 9 | attrs==24.2.0 10 | beartype==0.19.0 11 | biopython==1.83 12 | botocore==1.35.36 13 | cachetools==5.5.0 14 | certifi==2024.8.30 15 | cfgv==3.4.0 16 | chai_lab==0.4.1 17 | charset-normalizer==3.4.0 18 | click==8.1.7 19 | comm==0.2.2 20 | contourpy==1.3.1 21 | cycler==0.12.1 22 | db-dtypes==1.3.1 23 | debugpy==1.8.8 24 | decorator==5.1.1 25 | distlib==0.3.9 26 | einops==0.8.0 27 | executing==2.1.0 28 | filelock==3.16.1 29 | fonttools==4.55.0 30 | frozenlist==1.5.0 31 | fsspec==2024.10.0 32 | gcsfs==2024.10.0 33 | gemmi==0.6.7 34 | google-api-core==2.23.0 35 | google-auth==2.36.0 36 | google-auth-oauthlib==1.2.1 37 | google-cloud-bigquery==3.27.0 38 | google-cloud-core==2.4.1 39 | google-cloud-storage==2.18.2 40 | google-crc32c==1.6.0 41 | google-resumable-media==2.7.2 42 | googleapis-common-protos==1.66.0 43 | grpcio==1.68.0 44 | grpcio-status==1.68.0 45 | huggingface-hub==0.26.2 46 | identify==2.6.2 47 | idna==3.10 48 | ihm==1.7 49 | iniconfig==2.0.0 50 | ipykernel==6.29.5 51 | ipython==8.29.0 52 | jaxtyping==0.2.34 53 | jedi==0.19.2 54 | Jinja2==3.1.4 55 | jmespath==1.0.1 56 | jupyter_client==8.6.3 57 | jupyter_core==5.7.2 58 | kiwisolver==1.4.7 59 | llvmlite==0.43.0 60 | markdown-it-py==3.0.0 61 | MarkupSafe==3.0.2 62 | matplotlib==3.9.2 63 | matplotlib-inline==0.1.7 64 | mdurl==0.1.2 65 | modelcif==1.2 66 | mpmath==1.3.0 67 | msgpack==1.1.0 68 | multidict==6.1.0 69 | multimethod==1.12 70 | mypy==1.13.0 71 | mypy-extensions==1.0.0 72 | nest-asyncio==1.6.0 73 | networkx==3.4.2 74 | nodeenv==1.9.1 75 | numba==0.60.0 76 | numpy==1.26.4 77 | nvidia-cublas-cu12==12.1.3.1 78 | nvidia-cuda-cupti-cu12==12.1.105 79 | nvidia-cuda-nvrtc-cu12==12.1.105 80 | nvidia-cuda-runtime-cu12==12.1.105 81 | nvidia-cudnn-cu12==8.9.2.26 82 | nvidia-cufft-cu12==11.0.2.54 83 | nvidia-curand-cu12==10.3.2.106 84 | nvidia-cusolver-cu12==11.4.5.107 85 | nvidia-cusparse-cu12==12.1.0.106 86 | nvidia-nccl-cu12==2.20.5 87 | nvidia-nvjitlink-cu12==12.6.77 88 | nvidia-nvtx-cu12==12.1.105 89 | oauthlib==3.2.2 90 | packaging==24.2 91 | pandas==2.2.3 92 | pandas-gbq==0.24.0 93 | pandas-stubs==2.2.3.241009 94 | pandera==0.21.0 
95 | parso==0.8.4 96 | pexpect==4.9.0 97 | pillow==11.0.0 98 | platformdirs==4.3.6 99 | pluggy==1.5.0 100 | pre_commit==4.0.1 101 | prompt_toolkit==3.0.48 102 | propcache==0.2.0 103 | proto-plus==1.25.0 104 | protobuf==5.28.3 105 | psutil==6.1.0 106 | ptyprocess==0.7.0 107 | pure_eval==0.2.3 108 | pyarrow==18.0.0 109 | pyasn1==0.6.1 110 | pyasn1_modules==0.4.1 111 | pydantic==2.9.2 112 | pydantic_core==2.23.4 113 | pydata-google-auth==1.8.2 114 | Pygments==2.18.0 115 | pyparsing==3.2.0 116 | pytest==8.3.3 117 | python-dateutil==2.9.0.post0 118 | pytz==2024.2 119 | PyYAML==6.0.2 120 | pyzmq==26.2.0 121 | rdkit==2023.9.5 122 | regex==2024.11.6 123 | requests==2.32.3 124 | requests-oauthlib==2.0.0 125 | rich==13.9.4 126 | rsa==4.9 127 | ruff==0.6.3 128 | s3fs==2024.10.0 129 | safetensors==0.4.5 130 | setuptools==75.5.0 131 | shellingham==1.5.4 132 | six==1.16.0 133 | stack-data==0.6.3 134 | sympy==1.13.3 135 | tmtools==0.2.0 136 | tokenizers==0.20.3 137 | torch==2.3.1 138 | tornado==6.4.1 139 | tqdm==4.67.0 140 | traitlets==5.14.3 141 | transformers==4.46.2 142 | typeguard==2.13.3 143 | typer==0.13.0 144 | types-pytz==2024.2.0.20241003 145 | types-PyYAML==6.0.12.20240917 146 | types-requests==2.32.0.20241016 147 | types-tqdm==4.66.0.20240417 148 | typing-inspect==0.9.0 149 | typing_extensions==4.12.2 150 | tzdata==2024.2 151 | urllib3==2.2.3 152 | virtualenv==20.27.1 153 | wcwidth==0.2.13 154 | wheel==0.45.0 155 | wrapt==1.16.0 156 | yarl==1.17.1 157 | -------------------------------------------------------------------------------- /environments/relax.yaml: -------------------------------------------------------------------------------- 1 | name: relax 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - openmmforcefields 6 | - openmm 7 | - openff-toolkit 8 | - biotite 9 | - pdbfixer 10 | - ambertools 11 | - rdkit 12 | - click -------------------------------------------------------------------------------- /environments/rfaa.yaml: -------------------------------------------------------------------------------- 1 | name: rfaa 2 | channels: 3 | - pyg 4 | - biocore 5 | - pytorch 6 | - nvidia 7 | - bioconda 8 | - conda-forge 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - absl-py=2.1.0=pyhd8ed1ab_0 13 | - aiohttp=3.9.3=py310h2372a71_0 14 | - aiosignal=1.3.1=pyhd8ed1ab_0 15 | - alsa-lib=1.2.8=h166bdaf_0 16 | - asttokens=2.4.1=pyhd8ed1ab_0 17 | - astunparse=1.6.3=pyhd8ed1ab_0 18 | - async-timeout=4.0.3=pyhd8ed1ab_0 19 | - attr=2.5.1=h166bdaf_1 20 | - attrs=23.2.0=pyh71513ae_0 21 | - blas=2.121=mkl 22 | - blas-devel=3.9.0=21_linux64_mkl 23 | - blast-legacy=2.2.26=2 24 | - blinker=1.7.0=pyhd8ed1ab_0 25 | - brotli=1.1.0=hd590300_1 26 | - brotli-bin=1.1.0=hd590300_1 27 | - brotli-python=1.1.0=py310hc6cd4ac_1 28 | - bzip2=1.0.8=hd590300_5 29 | - c-ares=1.27.0=hd590300_0 30 | - ca-certificates=2024.2.2=hbcca054_0 31 | - cached-property=1.5.2=hd8ed1ab_1 32 | - cached_property=1.5.2=pyha770c72_1 33 | - cachetools=5.3.3=pyhd8ed1ab_0 34 | - cairo=1.16.0=ha61ee94_1014 35 | - certifi=2024.2.2=pyhd8ed1ab_0 36 | - cffi=1.16.0=py310h2fee648_0 37 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 38 | - click=8.1.7=unix_pyh707e725_0 39 | - colorama=0.4.6=pyhd8ed1ab_0 40 | - contourpy=1.2.0=py310hd41b1e2_0 41 | - cryptography=42.0.2=py310hb8475ec_0 42 | - cuda-cudart=11.8.89=0 43 | - cuda-cupti=11.8.87=0 44 | - cuda-libraries=11.8.0=0 45 | - cuda-nvrtc=11.8.89=0 46 | - cuda-nvtx=11.8.86=0 47 | - cuda-runtime=11.8.0=0 48 | - cuda-version=11.8=h70ddcb2_3 49 | - 
cudatoolkit=11.8.0=h4ba93d1_13 50 | - cudnn=8.8.0.121=hcdd5f01_4 51 | - cycler=0.12.1=pyhd8ed1ab_0 52 | - dbus=1.13.6=h5008d03_3 53 | - deepdiff=6.7.1=pyhd8ed1ab_0 54 | - dgl=1.1.2=cuda112py310hc641c19_2 55 | - executing=2.0.1=pyhd8ed1ab_0 56 | - expat=2.6.1=h59595ed_0 57 | - ffmpeg=4.3=hf484d3e_0 58 | - fftw=3.3.10=nompi_hc118613_108 59 | - filelock=3.13.1=pyhd8ed1ab_0 60 | - flatbuffers=22.12.06=hcb278e6_2 61 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 62 | - font-ttf-inconsolata=3.000=h77eed37_0 63 | - font-ttf-source-code-pro=2.038=h77eed37_0 64 | - font-ttf-ubuntu=0.83=h77eed37_1 65 | - fontconfig=2.14.2=h14ed4e7_0 66 | - fonts-conda-ecosystem=1=0 67 | - fonts-conda-forge=1=0 68 | - fonttools=4.49.0=py310h2372a71_0 69 | - freetype=2.12.1=h267a509_2 70 | - frozenlist=1.4.1=py310h2372a71_0 71 | - fsspec=2024.2.0=pyhca7485f_0 72 | - gast=0.4.0=pyh9f0ad1d_0 73 | - gettext=0.21.1=h27087fc_0 74 | - giflib=5.2.1=h0b41bf4_3 75 | - glib=2.78.4=hfc55251_4 76 | - glib-tools=2.78.4=hfc55251_4 77 | - gmp=6.3.0=h59595ed_0 78 | - gmpy2=2.1.2=py310h3ec546c_1 79 | - gnutls=3.6.13=h85f3911_1 80 | - google-auth=2.28.2=pyhca7485f_0 81 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 82 | - google-pasta=0.2.0=pyh8c360ce_0 83 | - graphite2=1.3.13=h58526e2_1001 84 | - grpcio=1.51.1=py310h4a5735c_1 85 | - gst-plugins-base=1.22.0=h4243ec0_2 86 | - gstreamer=1.22.0=h25f0c4b_2 87 | - gstreamer-orc=0.4.38=hd590300_0 88 | - gzip=1.13=hd590300_0 89 | - h5py=3.9.0=nompi_py310hcca72df_101 90 | - harfbuzz=6.0.0=h8e241bc_0 91 | - hdf5=1.14.1=nompi_h4f84152_100 92 | - hhsuite=3.3.0=py310pl5321h068649b_10 93 | - icecream=2.1.3=pyhd8ed1ab_0 94 | - icu=70.1=h27087fc_0 95 | - idna=3.6=pyhd8ed1ab_0 96 | - importlib-metadata=7.0.2=pyha770c72_0 97 | - jack=1.9.22=h11f4161_0 98 | - jinja2=3.1.3=pyhd8ed1ab_0 99 | - joblib=1.3.2=pyhd8ed1ab_0 100 | - jpeg=9e=h0b41bf4_3 101 | - keras=2.11.0=pyhd8ed1ab_0 102 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0 103 | - keyutils=1.6.1=h166bdaf_0 104 | - kiwisolver=1.4.5=py310hd41b1e2_1 105 | - krb5=1.20.1=h81ceb04_0 106 | - lame=3.100=h166bdaf_1003 107 | - lcms2=2.15=hfd0df8a_0 108 | - ld_impl_linux-64=2.40=h41732ed_0 109 | - lerc=4.0.0=h27087fc_0 110 | - libabseil=20220623.0=cxx17_h05df665_6 111 | - libaec=1.1.2=h59595ed_1 112 | - libblas=3.9.0=21_linux64_mkl 113 | - libbrotlicommon=1.1.0=hd590300_1 114 | - libbrotlidec=1.1.0=hd590300_1 115 | - libbrotlienc=1.1.0=hd590300_1 116 | - libcap=2.67=he9d0100_0 117 | - libcblas=3.9.0=21_linux64_mkl 118 | - libclang=15.0.7=default_hb11cfb5_4 119 | - libclang13=15.0.7=default_ha2b6cf4_4 120 | - libcublas=11.11.3.6=0 121 | - libcufft=10.9.0.58=0 122 | - libcufile=1.9.0.20=0 123 | - libcups=2.3.3=h36d4200_3 124 | - libcurand=10.3.5.119=0 125 | - libcurl=8.1.2=h409715c_0 126 | - libcusolver=11.4.1.48=0 127 | - libcusparse=11.7.5.86=0 128 | - libdb=6.2.32=h9c3ff4c_0 129 | - libdeflate=1.17=h0b41bf4_0 130 | - libedit=3.1.20191231=he28a2e2_2 131 | - libev=4.33=hd590300_2 132 | - libevent=2.1.10=h28343ad_4 133 | - libexpat=2.6.1=h59595ed_0 134 | - libffi=3.4.2=h7f98852_5 135 | - libflac=1.4.3=h59595ed_0 136 | - libgcc-ng=13.2.0=h807b86a_5 137 | - libgcrypt=1.10.3=hd590300_0 138 | - libgfortran-ng=13.2.0=h69a702a_5 139 | - libgfortran5=13.2.0=ha4646dd_5 140 | - libglib=2.78.4=hf2295e7_4 141 | - libgomp=13.2.0=h807b86a_5 142 | - libgpg-error=1.48=h71f35ed_0 143 | - libgrpc=1.51.1=h4fad500_1 144 | - libhwloc=2.9.1=hd6dc26d_0 145 | - libiconv=1.17=hd590300_2 146 | - liblapack=3.9.0=21_linux64_mkl 147 | - liblapacke=3.9.0=21_linux64_mkl 148 | - 
libllvm15=15.0.7=hadd5161_1 149 | - libnghttp2=1.58.0=h47da74e_0 150 | - libnpp=11.8.0.86=0 151 | - libnsl=2.0.1=hd590300_0 152 | - libnvjpeg=11.9.0.86=0 153 | - libogg=1.3.4=h7f98852_1 154 | - libopus=1.3.1=h7f98852_1 155 | - libpng=1.6.43=h2797004_0 156 | - libpq=15.3=hbcd7760_1 157 | - libprotobuf=3.21.12=hfc55251_2 158 | - libsndfile=1.2.2=hc60ed4a_1 159 | - libsqlite=3.45.1=h2797004_0 160 | - libssh2=1.11.0=h0841786_0 161 | - libstdcxx-ng=13.2.0=h7e041cc_5 162 | - libsystemd0=253=h8c4010b_1 163 | - libtiff=4.5.0=h6adf6a1_2 164 | - libtool=2.4.7=h27087fc_0 165 | - libudev1=253=h0b41bf4_1 166 | - libuuid=2.38.1=h0b41bf4_0 167 | - libuv=1.48.0=hd590300_0 168 | - libvorbis=1.3.7=h9c3ff4c_0 169 | - libwebp-base=1.3.2=hd590300_0 170 | - libxcb=1.13=h7f98852_1004 171 | - libxcrypt=4.4.36=hd590300_1 172 | - libxkbcommon=1.5.0=h79f4944_1 173 | - libxml2=2.10.3=hca2bb57_4 174 | - libzlib=1.2.13=hd590300_5 175 | - llvm-openmp=17.0.6=h4dfa4b3_0 176 | - lz4-c=1.9.4=hcb278e6_0 177 | - markdown=3.5.2=pyhd8ed1ab_0 178 | - markupsafe=2.1.5=py310h2372a71_0 179 | - matplotlib=3.8.3=py310hff52083_0 180 | - matplotlib-base=3.8.3=py310h62c0568_0 181 | - metis=5.1.1=h59595ed_2 182 | - mkl=2024.0.0=ha957f24_49657 183 | - mkl-devel=2024.0.0=ha770c72_49657 184 | - mkl-include=2024.0.0=ha957f24_49657 185 | - mpc=1.3.1=hfe3b2da_0 186 | - mpfr=4.2.1=h9458935_0 187 | - mpg123=1.32.4=h59595ed_0 188 | - mpmath=1.3.0=pyhd8ed1ab_0 189 | - multidict=6.0.5=py310h2372a71_0 190 | - munkres=1.1.4=pyh9f0ad1d_0 191 | - mysql-common=8.0.33=hf1915f5_6 192 | - mysql-libs=8.0.33=hca2cd23_6 193 | - nccl=2.20.5.1=h6103f9b_0 194 | - ncurses=6.4=h59595ed_2 195 | - nettle=3.6=he412f7d_0 196 | - networkx=3.2.1=pyhd8ed1ab_0 197 | - nspr=4.35=h27087fc_0 198 | - nss=3.98=h1d7d5a4_0 199 | - numpy=1.26.4=py310hb13e2d6_0 200 | - oauthlib=3.2.2=pyhd8ed1ab_0 201 | - openbabel=3.1.1=py310heaf86c6_5 202 | - openh264=2.1.1=h780b84a_0 203 | - openjpeg=2.5.0=hfec8fc6_2 204 | - openssl=3.1.5=hd590300_0 205 | - opt_einsum=3.3.0=pyhc1e730c_2 206 | - ordered-set=4.1.0=pyhd8ed1ab_0 207 | - orjson=3.9.15=py310hcb5633a_0 208 | - packaging=23.2=pyhd8ed1ab_0 209 | - pandas=2.2.1=py310hcc13569_0 210 | - pcre2=10.43=hcad00b1_0 211 | - perl=5.32.1=7_hd590300_perl5 212 | - pillow=9.4.0=py310h023d228_1 213 | - pip=24.0=pyhd8ed1ab_0 214 | - pixman=0.43.2=h59595ed_0 215 | - ply=3.11=py_1 216 | - protobuf=4.21.12=py310heca2aa9_0 217 | - psipred=4.01=1 218 | - psutil=5.9.8=py310h2372a71_0 219 | - pthread-stubs=0.4=h36c2ea0_1001 220 | - pulseaudio=16.1=hcb278e6_3 221 | - pulseaudio-client=16.1=h5195f5e_3 222 | - pulseaudio-daemon=16.1=ha8d29e2_3 223 | - pyasn1=0.5.1=pyhd8ed1ab_0 224 | - pyasn1-modules=0.3.0=pyhd8ed1ab_0 225 | - pycparser=2.21=pyhd8ed1ab_0 226 | - pyg=2.5.0=py310_torch_2.0.0_cu118 227 | - pygments=2.17.2=pyhd8ed1ab_0 228 | - pyjwt=2.8.0=pyhd8ed1ab_1 229 | - pyopenssl=24.0.0=pyhd8ed1ab_0 230 | - pyparsing=3.1.2=pyhd8ed1ab_0 231 | - pyqt=5.15.9=py310h04931ad_5 232 | - pyqt5-sip=12.12.2=py310hc6cd4ac_5 233 | - pysocks=1.7.1=pyha2e5f31_6 234 | - python=3.10.13=hd12c33a_0_cpython 235 | - python-dateutil=2.9.0=pyhd8ed1ab_0 236 | - python-flatbuffers=24.3.6=pyh59ac667_0 237 | - python-tzdata=2024.1=pyhd8ed1ab_0 238 | - python_abi=3.10=4_cp310 239 | - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0 240 | - pytorch-cuda=11.8=h7e8668a_5 241 | - pytorch-mutex=1.0=cuda 242 | - pytz=2024.1=pyhd8ed1ab_0 243 | - pyu2f=0.1.5=pyhd8ed1ab_0 244 | - qt-main=5.15.8=h5d23da1_6 245 | - re2=2023.02.01=hcb278e6_0 246 | - readline=8.2=h8228510_1 247 | - 
requests=2.31.0=pyhd8ed1ab_0 248 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 249 | - rsa=4.9=pyhd8ed1ab_0 250 | - scikit-learn=1.4.1.post1=py310h1fdf081_0 251 | - scipy=1.12.0=py310hb13e2d6_2 252 | - setuptools=69.1.1=pyhd8ed1ab_0 253 | - sip=6.7.12=py310hc6cd4ac_0 254 | - six=1.16.0=pyh6c4a22f_0 255 | - snappy=1.1.10=h9fff704_0 256 | - sympy=1.12=pypyh9d50eac_103 257 | - tbb=2021.9.0=hf52228f_0 258 | - tensorboard=2.11.2=pyhd8ed1ab_0 259 | - tensorboard-data-server=0.6.1=py310h600f1e7_4 260 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 261 | - tensorflow=2.11.0=cuda112py310he87a039_0 262 | - tensorflow-base=2.11.0=cuda112py310h52da4a5_0 263 | - tensorflow-estimator=2.11.0=cuda112py310h37add04_0 264 | - termcolor=2.4.0=pyhd8ed1ab_0 265 | - threadpoolctl=3.3.0=pyhc1e730c_0 266 | - tk=8.6.13=noxft_h4845f30_101 267 | - toml=0.10.2=pyhd8ed1ab_0 268 | - tomli=2.0.1=pyhd8ed1ab_0 269 | - torchaudio=2.0.2=py310_cu118 270 | - torchtriton=2.0.0=py310 271 | - torchvision=0.15.2=py310_cu118 272 | - tornado=6.4=py310h2372a71_0 273 | - tqdm=4.66.2=pyhd8ed1ab_0 274 | - typing-extensions=4.10.0=hd8ed1ab_0 275 | - typing_extensions=4.10.0=pyha770c72_0 276 | - tzdata=2024a=h0c530f3_0 277 | - unicodedata2=15.1.0=py310h2372a71_0 278 | - unzip=6.0=h7f98852_3 279 | - urllib3=2.2.1=pyhd8ed1ab_0 280 | - werkzeug=3.0.1=pyhd8ed1ab_0 281 | - wheel=0.42.0=pyhd8ed1ab_0 282 | - wrapt=1.16.0=py310h2372a71_0 283 | - xcb-util=0.4.0=h516909a_0 284 | - xcb-util-image=0.4.0=h166bdaf_0 285 | - xcb-util-keysyms=0.4.0=h516909a_0 286 | - xcb-util-renderutil=0.3.9=h166bdaf_0 287 | - xcb-util-wm=0.4.1=h516909a_0 288 | - xkeyboard-config=2.38=h0b41bf4_0 289 | - xorg-kbproto=1.0.7=h7f98852_1002 290 | - xorg-libice=1.1.1=hd590300_0 291 | - xorg-libsm=1.2.4=h7391055_0 292 | - xorg-libx11=1.8.4=h0b41bf4_0 293 | - xorg-libxau=1.0.11=hd590300_0 294 | - xorg-libxdmcp=1.1.3=h7f98852_0 295 | - xorg-libxext=1.3.4=h0b41bf4_2 296 | - xorg-libxrender=0.9.10=h7f98852_1003 297 | - xorg-renderproto=0.11.1=h7f98852_1002 298 | - xorg-xextproto=7.3.0=h0b41bf4_1003 299 | - xorg-xproto=7.0.31=h7f98852_1007 300 | - xz=5.2.6=h166bdaf_0 301 | - yarl=1.9.4=py310h2372a71_0 302 | - zip=3.0=hd590300_3 303 | - zipp=3.17.0=pyhd8ed1ab_0 304 | - zlib=1.2.13=hd590300_5 305 | - zstd=1.5.5=hfc55251_0 306 | - pip: 307 | - antlr4-python3-runtime==4.9.3 308 | - assertpy==1.1 309 | - configparser==6.0.1 310 | - dllogger==1.0.0 311 | - docker-pycreds==0.4.0 312 | - e3nn==0.3.3 313 | - gitdb==4.0.11 314 | - gitpython==3.1.42 315 | - hydra-core==1.3.2 316 | - omegaconf==2.3.0 317 | - opt-einsum-fx==0.1.4 318 | - pathtools==0.1.2 319 | - promise==2.3 320 | - pynvml==11.0.0 321 | - pyrsistent==0.20.0 322 | - pyyaml==6.0.1 323 | - sentry-sdk==1.41.0 324 | - shortuuid==1.0.12 325 | - signalp6==6.0+h 326 | - smmap==5.0.1 327 | - subprocess32==3.5.4 328 | - wandb==0.12.0 329 | 330 | -------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CataAI/PoseX/cce9ad3854828744b6bddda32aaa9eece95919a1/figures/logo.png -------------------------------------------------------------------------------- /figures/posex_cross_dock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CataAI/PoseX/cce9ad3854828744b6bddda32aaa9eece95919a1/figures/posex_cross_dock.png -------------------------------------------------------------------------------- /figures/posex_self_dock.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CataAI/PoseX/cce9ad3854828744b6bddda32aaa9eece95919a1/figures/posex_self_dock.png -------------------------------------------------------------------------------- /scripts/calculate_benchmark_result.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import defaultdict 4 | 5 | import pandas as pd 6 | from posebusters import PoseBusters 7 | from tqdm import tqdm 8 | 9 | POSEBUSTER_TEST_COLUMNS = [ 10 | # accuracy # 11 | "rmsd_≤_2å", 12 | # chemical validity and consistency # 13 | "mol_pred_loaded", 14 | "mol_true_loaded", 15 | "mol_cond_loaded", 16 | "sanitization", 17 | "molecular_formula", 18 | "molecular_bonds", 19 | "tetrahedral_chirality", 20 | "double_bond_stereochemistry", 21 | # intramolecular validity # 22 | "bond_lengths", 23 | "bond_angles", 24 | "internal_steric_clash", 25 | "aromatic_ring_flatness", 26 | "double_bond_flatness", 27 | "internal_energy", 28 | # intermolecular validity # 29 | "minimum_distance_to_protein", 30 | "minimum_distance_to_organic_cofactors", 31 | "minimum_distance_to_inorganic_cofactors", 32 | "volume_overlap_with_protein", 33 | "volume_overlap_with_organic_cofactors", 34 | "volume_overlap_with_inorganic_cofactors", 35 | ] 36 | 37 | 38 | def get_group_info(dataset: str, dataset_folder: str) -> pd.DataFrame: 39 | group_dict = defaultdict(list) 40 | for item_name in os.listdir(dataset_folder): 41 | item_dir = os.path.join(dataset_folder, item_name) 42 | if not os.path.isdir(item_dir): 43 | continue 44 | group_dict["PDB_CCD_ID"].append(item_name) 45 | if dataset == "posex_cross_dock": 46 | group_path = os.path.join(dataset_folder, item_name, "group_id.txt") 47 | with open(group_path, "r") as f: 48 | lines = f.readlines() 49 | lines = [line.strip() for line in lines] 50 | group_dict["PDB_GROUP"].append(lines[0]) 51 | group_dict["GROUP"].append(lines[1]) 52 | elif dataset in ["posex_self_dock", "posex_supp"]: 53 | group_dict["PDB_GROUP"].append(item_name) 54 | group_dict["GROUP"].append(item_name) 55 | else: 56 | raise RuntimeError() 57 | df_group = pd.DataFrame(group_dict) 58 | return df_group 59 | 60 | def main(args: argparse.Namespace): 61 | docking_data = pd.read_csv(args.input_file) 62 | bust_dict = defaultdict(list) 63 | 64 | total_samples = len(docking_data["PDB_CCD_ID"]) 65 | for pdb_ccd_id in tqdm(docking_data["PDB_CCD_ID"]): 66 | mol_true = os.path.join(args.dataset_folder, f"{pdb_ccd_id}/{pdb_ccd_id}_ligands.sdf") 67 | if args.model_type == "alphafold3": 68 | pdb_ccd_id = pdb_ccd_id.lower() 69 | mol_cond = os.path.join(args.model_output_folder, f"{pdb_ccd_id}/{pdb_ccd_id}_model_protein_aligned.pdb") 70 | mol_pred = os.path.join(args.model_output_folder, f"{pdb_ccd_id}/{pdb_ccd_id}_model_ligand_aligned.sdf") 71 | if not os.path.exists(mol_pred): 72 | print(f"File {mol_pred} does not exist") 73 | continue 74 | bust_dict["PDB_CCD_ID"].append(pdb_ccd_id.upper()) 75 | bust_dict["mol_pred"].append(mol_pred) 76 | bust_dict["mol_true"].append(mol_true) 77 | bust_dict["mol_cond"].append(mol_cond) 78 | 79 | bust_data = pd.DataFrame(bust_dict) 80 | print("Number of Benchmark Data: ", total_samples) 81 | print("Number of Posebusters Data: ", len(bust_data)) 82 | save_folder = os.path.dirname(args.input_file) 83 | 84 | # Calculate posebusters result 85 | buster = PoseBusters(config="redock", top_n=None, max_workers=None) 86 | 
buster.config["loading"]["mol_true"]["load_all"] = True 87 | bust_results = buster.bust_table(bust_data, full_report=True) 88 | bust_results["PDB_CCD_ID"] = bust_dict["PDB_CCD_ID"] 89 | if args.dataset in ["posex_self_dock", "posex_cross_dock", "posex_supp"]: 90 | df_group = get_group_info(args.dataset, args.dataset_folder) 91 | df_group_sim = pd.read_csv(os.path.join(args.dataset_folder, "qtm.csv")) 92 | bust_results = pd.merge(bust_results, df_group, on="PDB_CCD_ID") 93 | bust_results = pd.merge(bust_results, df_group_sim, on="GROUP") 94 | if args.dataset == "posex_cross_dock": 95 | total_samples = df_group.GROUP.unique().shape[0] 96 | if args.relax == "true": 97 | res_path = os.path.join(save_folder, f"{args.dataset}_benchmark_result_{args.model_type}_relax.csv") 98 | else: 99 | res_path = os.path.join(save_folder, f"{args.dataset}_benchmark_result_{args.model_type}.csv") 100 | bust_results.to_csv(res_path, index=False) 101 | if args.dataset in ["posex_self_dock", "posex_cross_dock", "posex_supp"]: 102 | test_data = bust_results[POSEBUSTER_TEST_COLUMNS].copy() 103 | bust_results.loc[:, "pb_valid"] = test_data.iloc[:, 1:].all(axis=1) 104 | bust_results = bust_results.groupby("PDB_GROUP").agg({"rmsd": "mean", "pb_valid": "mean", "GROUP": "first"}) 105 | bust_results = bust_results.groupby("GROUP").agg({"rmsd": "mean", "pb_valid": "mean"}) 106 | accuracy = len(bust_results[bust_results["rmsd"] <= 2.0]) / total_samples 107 | print(f"RMSD ≤ 2 Å: {accuracy * 100:.2f}%") 108 | valid_data = bust_results[(bust_results["rmsd"] <= 2) & (bust_results["pb_valid"] >= 0.5)] 109 | print(f"RMSD ≤ 2 Å and PB Valid: {len(valid_data) / total_samples * 100:.2f}%") 110 | else: 111 | # Calculate accuracy 112 | accuracy = len(bust_results[bust_results["rmsd_≤_2å"] == True]) / total_samples 113 | print(f"RMSD ≤ 2 Å: {accuracy * 100:.2f}%") 114 | 115 | # Calculate posebusters test result 116 | test_data = bust_results[POSEBUSTER_TEST_COLUMNS].copy() 117 | test_data.loc[:, "pb_valid"] = test_data.iloc[:, 1:].all(axis=1) 118 | valid_data = test_data[test_data["rmsd_≤_2å"] & test_data["pb_valid"]] 119 | print(f"RMSD ≤ 2 Å and PB Valid: {len(valid_data) / total_samples * 100:.2f}%") 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--input_file", type=str, required=True, help="Path to the benchmark input file") 125 | parser.add_argument("--dataset_folder", type=str, required=True, help="Path to the dataset folder") 126 | parser.add_argument("--model_output_folder", type=str, required=True, help="Path to the model output folder") 127 | parser.add_argument("--dataset", type=str, required=True, help="Dataset name") 128 | parser.add_argument("--model_type", type=str, required=True, help="Model type") 129 | parser.add_argument("--relax", type=str, required=True, help="relax mode (true or false)") 130 | args = parser.parse_args() 131 | 132 | main(args) 133 | -------------------------------------------------------------------------------- /scripts/calculate_benchmark_result.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # Check if all three arguments are provided 4 | if [ $# -ne 3 ]; then 5 | echo "Error: Missing arguments" 6 | echo "Usage: $0 <dataset> <model_type> <relax_mode>" 7 | echo "Example: $0 posex_self_dock alphafold3 false" 8 | exit 1 9 | fi 10 | 11 | # Get the dataset, model_type and relax_mode from command-line arguments 12 | DATASET=$1 13 | MODEL_TYPE=$2 14 | RELAX_MODE=$3 15 | 16 | # Set dataset folder based on DATASET 17 | if [ "$DATASET" = "posex_self_dock" ]; then 18 | DATASET_FOLDER="data/dataset/posex/posex_self_docking_set" 19 | elif [ "$DATASET" = "posex_cross_dock" ]; then 20 | DATASET_FOLDER="data/dataset/posex/posex_cross_docking_set" 21 | elif [ "$DATASET" = "posex_supp" ]; then 22 | DATASET_FOLDER="data/dataset/posex/posex_supp_set" 23 | elif [ "$DATASET" = "astex" ]; then 24 | DATASET_FOLDER="data/dataset/posex/astex_diverse_set" 25 | else 26 | echo "Error: Unknown dataset ${DATASET}" 27 | exit 1 28 | fi 29 | 30 | if [ "$RELAX_MODE" = "true" ]; then 31 | MODEL_OUTPUT_FOLDER="data/benchmark/${DATASET}/${MODEL_TYPE}/processed" 32 | elif [ "$RELAX_MODE" = "false" ]; then 33 | MODEL_OUTPUT_FOLDER="data/benchmark/${DATASET}/${MODEL_TYPE}/output" 34 | else 35 | echo "Error: Unknown relax_mode ${RELAX_MODE}" 36 | exit 1 37 | fi 38 | 39 | python scripts/calculate_benchmark_result.py \ 40 | --input_file data/benchmark/${DATASET}/${DATASET}_benchmark.csv \ 41 | --dataset_folder ${DATASET_FOLDER} \ 42 | --model_output_folder ${MODEL_OUTPUT_FOLDER} \ 43 | --model_type ${MODEL_TYPE} \ 44 | --dataset ${DATASET} \ 45 | --relax ${RELAX_MODE} 46 | -------------------------------------------------------------------------------- /scripts/complex_structure_alignment.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Check if all three arguments are provided 4 | if [ $# -ne 3 ]; then 5 | echo "Error: Missing arguments" 6 | echo "Usage: $0 <dataset> <model_type> <relax_mode>" 7 | echo "Example: $0 posex_self_dock alphafold3 false" 8 | exit 1 9 | fi 10 | 11 | # Get the dataset, model_type and relax_mode from command-line arguments 12 | DATASET=$1 13 | MODEL_TYPE=$2 14 | RELAX_MODE=$3 15 | 16 | # Set dataset folder based on DATASET 17 | if [ "$DATASET" = "posex_self_dock" ]; then 18 | DATASET_FOLDER="data/dataset/posex/posex_self_docking_set" 19 | elif [ "$DATASET" = "posex_cross_dock" ]; then 20 | DATASET_FOLDER="data/dataset/posex/posex_cross_docking_set" 21 | elif [ "$DATASET" = "posex_supp" ]; then 22 | DATASET_FOLDER="data/dataset/posex/posex_supp_set" 23 | elif [ "$DATASET" = "astex" ]; then 24 | DATASET_FOLDER="data/dataset/posex/astex_diverse_set" 25 | else 26 | echo "Error: Unknown dataset ${DATASET}" 27 | exit 1 28 | fi 29 | 30 | if [ "$RELAX_MODE" = "true" ]; then 31 | MODEL_OUTPUT_FOLDER="data/benchmark/${DATASET}/${MODEL_TYPE}/processed" 32 | elif [ "$RELAX_MODE" = "false" ]; then 33 | MODEL_OUTPUT_FOLDER="data/benchmark/${DATASET}/${MODEL_TYPE}/output" 34 | else 35 | echo "Error: Unknown relax_mode ${RELAX_MODE}" 36 | exit 1 37 | fi 38 | 39 | python scripts/complex_structure_alignment.py \ 40 | --input_file data/benchmark/${DATASET}/${DATASET}_benchmark.csv \ 41 | --dataset_folder ${DATASET_FOLDER} \ 42 | --model_output_folder ${MODEL_OUTPUT_FOLDER} \ 43 | --model_type ${MODEL_TYPE} 44 | -------------------------------------------------------------------------------- /scripts/convert_to_model_input.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if both arguments are provided 5 | if [ $# -ne 2 ]; then 6 | echo "Error: Missing arguments" 7 | echo "Usage: $0 <dataset> <model_type>" 8 | echo "Example: $0 astex alphafold3" 9 | exit 1 10 | fi 11 | 12 | # Get the dataset and model_type from command-line arguments 13 | DATASET=$1 14 | MODEL_TYPE=$2 15 | 16 | python scripts/convert_to_model_input.py \ 17 | --input_file data/benchmark/${DATASET}/${DATASET}_benchmark.csv \ 18 | --output_folder data/benchmark/${DATASET}/${MODEL_TYPE}/input \ 19 | --model_type ${MODEL_TYPE} 20 | -------------------------------------------------------------------------------- /scripts/extract_model_output.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Check if both arguments are provided 4 | if [ $# -ne 2 ]; then 5 | echo "Error: Missing arguments" 6 | echo "Usage: $0 <dataset> <model_type>" 7 | echo "Example: $0 posebusters alphafold3" 8 | exit 1 9 | fi 10 | 11 | # Get the dataset and model_type from command-line arguments 12 | DATASET=$1 13 | MODEL_TYPE=$2 14 | 15 | python scripts/extract_model_output.py \ 16 | --input_file data/benchmark/${DATASET}/${DATASET}_benchmark.csv \ 17 | --output_folder data/benchmark/${DATASET}/${MODEL_TYPE}/output \ 18 | --model_type ${MODEL_TYPE} -------------------------------------------------------------------------------- /scripts/generate_docking_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import pandas as pd 6 | from loguru import logger 7 | 8 | 9 | def generate_posex_benchmark(args: argparse.Namespace): 10 | """Generate the docking benchmark for the PoseX dataset 11 | 12 | Args: 13 | args (argparse.Namespace): The input arguments 14 | """ 15 | # Get the docking dataset input folder 16 | if args.dataset == "astex": 17 | input_folder = os.path.join(args.input_folder, "astex_diverse_set") 18 | elif args.dataset == "posex_self_dock": 19 | input_folder = os.path.join(args.input_folder, "posex_self_docking_set") 20 | elif args.dataset == "posex_cross_dock": 21 | input_folder = os.path.join(args.input_folder, "posex_cross_docking_set") 22 | else: 23 | input_folder = os.path.join(args.input_folder, "posex_supp_set") 24 | 25 | # Get the PDB_CCD_IDs 26 | pdb_ccd_ids = [] 27 | for pdb_ccd_id in os.listdir(input_folder): 28 | if os.path.isdir(os.path.join(input_folder, pdb_ccd_id)): 29 | pdb_ccd_ids.append(pdb_ccd_id) 30 | 31 | logger.info(f"Number of PoseX {args.dataset} Data: {len(pdb_ccd_ids)}") 32 | 33 | molecule_smiles_list, protein_sequence_list, pdb_path_list, sdf_path_list = [], [], [], [] 34 | for pdb_ccd_id in pdb_ccd_ids: 35 | with open(os.path.join(input_folder, pdb_ccd_id, f"{pdb_ccd_id}.json"), "r") as f: 36 | input_dict = json.load(f) 37 | 38 | # Get the protein sequences and the SMILES of the ligand 39 | chain_list = [] 40 | for entity in input_dict["sequences"]: 41 | if "protein" in entity: 42 | chain_list.append(entity["protein"]["sequence"]) 43 | elif "ligand" in entity: 44 | molecule_smiles_list.append(entity["ligand"]["smiles"]) 45 | if len(chain_list) > 20: 46 | logger.warning(f"Warning: {pdb_ccd_id} has more than 20 chains.") 47 | if sum(len(chain) for chain in chain_list) > 2000: 48 | logger.warning(f"Warning: {pdb_ccd_id} has more than 2000 amino acids.") 49 | protein_sequence_list.append("|".join(chain_list)) 50 | 51 | # Get the PDB file path and the SDF file path 52 | pdb_path_list.append(os.path.join(input_folder, pdb_ccd_id, 
f"{pdb_ccd_id}_protein.pdb")) 53 | sdf_path_list.append(os.path.join(input_folder, pdb_ccd_id, f"{pdb_ccd_id}_ligand_start_conf.sdf")) 54 | 55 | # Create the output data 56 | output_data = pd.DataFrame({ 57 | "PDB_CCD_ID": pdb_ccd_ids, 58 | "LIGAND_SMILES": molecule_smiles_list, 59 | "PROTEIN_SEQUENCE": protein_sequence_list, 60 | "PROTEIN_PDB_PATH": pdb_path_list, 61 | "LIGAND_SDF_PATH": sdf_path_list 62 | }) 63 | 64 | # Save the output data to a CSV file 65 | output_path = os.path.join(args.output_folder, f"{args.dataset}_benchmark.csv") 66 | output_data.to_csv(output_path, index=False) 67 | logger.info(f"Saved PoseX {args.dataset} Benchmark to {output_path}") 68 | 69 | 70 | def main(args: argparse.Namespace): 71 | # Check if the dataset is valid 72 | if args.dataset not in ["astex", "posex_self_dock", "posex_cross_dock", "posex_supp"]: 73 | raise ValueError(f"Unknown dataset: {args.dataset}") 74 | 75 | # Check if the output folder exists, if not create it 76 | if not os.path.exists(args.output_folder): 77 | logger.info(f"Output folder {args.output_folder} does not exist, creating it.") 78 | os.makedirs(args.output_folder) 79 | 80 | # Generate the docking benchmark 81 | generate_posex_benchmark(args) 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--input_folder", type=str, required=True, help="Path to the input folder containing the PoseX dataset") 87 | parser.add_argument("--output_folder", type=str, required=True, help="Path to the output folder") 88 | parser.add_argument("--dataset", type=str, required=True, help="Dataset name (astex, posex_self_dock, posex_cross_dock or posex_supp)") 89 | args = parser.parse_args() 90 | 91 | main(args) 92 | -------------------------------------------------------------------------------- /scripts/generate_docking_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | 13 | python scripts/generate_docking_benchmark.py \ 14 | --input_folder data/dataset/posex \ 15 | --output_folder data/benchmark/${DATASET} \ 16 | --dataset ${DATASET} 17 | -------------------------------------------------------------------------------- /scripts/relax_model_outputs.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import traceback 3 | from functools import partial 4 | from pathlib import Path 5 | from typing import Dict 6 | import click 7 | from tqdm import tqdm 8 | import os 9 | 10 | from dataset.utils.pdb_helper import load_amber_xml 11 | from dataset.utils.pdb_process import ( 12 | fixer_into_protein_mol, 13 | pdb_to_fixer, 14 | protein_mol_to_file, 15 | ligand_rdmol_to_file, 16 | ) 17 | 18 | from dataset.utils.openmm_helper import ProLigRelax 19 | from dataset.utils.common_helper import create_logger 20 | 21 | 22 | logger = create_logger(__name__) 23 | NUM_CPUS = mp.cpu_count() 24 | 25 | 26 | def run_once( 27 | pdb_fn: Path, 28 | cif_dir: Path, 29 | output_dir: Path, 30 | residues_tables: tuple[Dict, Dict, Dict], 31 | num_threads: int = 1, 32 | ): 33 | print(pdb_fn) 34 | tmp_items = pdb_fn.stem.split("_") 35 | tmp_name = "_".join(tmp_items[:2]) 36 | sdf_fn = pdb_fn.parent / f"{tmp_name}_model_ligand.sdf" 37 | cif_fn = cif_dir / f"{tmp_items[0].upper()}.cif" 38 | work_dir = output_dir / f"{tmp_name}" 39 | work_dir.mkdir(parents=True, exist_ok=True) 40 | out_pdb_fn = work_dir / f"{tmp_name}_protein_step1.pdb" 41 | 42 | out_pdb_fn2 = work_dir / f"{tmp_name}_model_protein.pdb" 43 | out_ligand_fn2 = work_dir / f"{tmp_name}_model_ligand.sdf" 44 | if out_ligand_fn2.exists() and out_pdb_fn2.exists(): 45 | return True 46 | 47 | processed_dir = output_dir / pdb_fn.parent.name 48 | processed_dir.mkdir(parents=True, exist_ok=True) 49 | try: 50 | 51 | fixer_noh = pdb_to_fixer(pdb_fn, cif_fn) 52 | prot_mol_h = fixer_into_protein_mol(fixer_noh, residues_tables) 53 | protein_mol_to_file(prot_mol_h, out_pdb_fn) 54 | print(f"{out_pdb_fn=}") 55 | # relax 56 | logger.info("Running relaxation...") 57 | relax_tool = ProLigRelax( 58 | prot_mol_h, 59 | missing_residues=[], 60 | # platform=f"CPU:{num_threads}", 61 | platform="CUDA:0", 62 | ligand_ff="openff", 63 | charge_name="mmff94", 64 | is_restrain=(True, "main"), 65 | ) 66 | receptor_rdmol_relaxed, ligand_rdmol_relaxed = ( 67 | relax_tool.prepare_one_cplx_and_relax(sdf_fn) 68 | ) 69 | 70 | logger.info(f"Relax success: {cif_fn=}") 71 | 72 | ligand_rdmol_to_file(ligand_rdmol_relaxed, out_ligand_fn2) 73 | logger.info(f"Writing ligand to {out_ligand_fn2}") 74 | 75 | protein_mol_to_file(receptor_rdmol_relaxed, out_pdb_fn2) 76 | logger.info(f"Writing protein to {out_pdb_fn2}") 77 | logger.info(f'{"#"*10} Succeed: {pdb_fn=}') 78 | 79 | del relax_tool 80 | 81 | except Exception: 82 | tb = traceback.format_exc() 83 | tmp_error_log = output_dir.parent / "error.log" 84 | with open(tmp_error_log, "a") as f: 85 | f.write(f"#PDBNAME#: {tmp_name}\n{tb}") 86 | logger.info(f"{tmp_error_log=}") 87 | 88 | return True 89 | 90 | 91 | def run_batch(pdb_fns, cif_dir, output_dir, num_proc: int = 6): 92 | residues_tables = load_amber_xml() 93 | num_threads = int(NUM_CPUS / num_proc) 94 | with mp.Pool(num_proc) as pool: 95 | func = partial( 96 | run_once, 97 | output_dir=output_dir, 98 | 
cif_dir=cif_dir, 99 | residues_tables=residues_tables, 100 | num_threads=num_threads, 101 | ) 102 | results = list( 103 | tqdm( 104 | pool.imap(func, pdb_fns, chunksize=1), 105 | total=len(pdb_fns), 106 | desc="Processing", 107 | ) 108 | ) 109 | 110 | print(f"{sum(results)=}") 111 | 112 | 113 | @click.command() 114 | @click.option( 115 | "--input_dir", 116 | type=str, 117 | default="/Users/josephxu/PycharmProjects/debug_iip_code/protein_ligand_docking_benchmark/relax_debug/output", 118 | ) 119 | @click.option("--cif_dir", type=str) 120 | @click.option("--output_dir", type=str, default=None) 121 | @click.option("--num_proc", type=int, default=1) 122 | def main(**kwargs): 123 | input_dir = Path(kwargs["input_dir"]) 124 | if kwargs["output_dir"] is None: 125 | output_dir = input_dir.parent / "processed" 126 | error_fn = input_dir.parent / "error.log" 127 | if error_fn.exists(): 128 | error_fn.unlink(missing_ok=True) 129 | output_dir.mkdir(parents=True, exist_ok=True) 130 | else: 131 | output_dir = Path(kwargs["output_dir"]) 132 | cif_dir = Path(kwargs["cif_dir"]) 133 | pdb_fns = list(input_dir.glob("*/*model_protein.pdb")) 134 | run_batch(pdb_fns, cif_dir, output_dir, int(kwargs["num_proc"])) 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /scripts/run_alphafold3/run_alphafold3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if dataset is provided as argument 4 | if [ $# -eq 0 ]; then 5 | echo "Error: Please provide dataset as argument" 6 | echo "Usage: $0 <dataset>" 7 | exit 1 8 | fi 9 | 10 | DATASET="$1" 11 | 12 | # Run AlphaFold3 13 | docker run -it --rm --gpus all --shm-size=32g -e CUDA_VISIBLE_DEVICES=3 \ 14 | -v ./data/benchmark/${DATASET}/alphafold3/input:/root/af_input \ 15 | -v ./data/benchmark/${DATASET}/alphafold3/output:/root/af_output \ 16 | -v /data/dataset/alphafold3/models:/root/models \ 17 | -v /data/dataset/alphafold3/databases:/root/public_databases \ 18 | brandonsoubasis/alphafold3 \ 19 | python run_alphafold.py \ 20 | --input_dir=/root/af_input \ 21 | --model_dir=/root/models \ 22 | --output_dir=/root/af_output 23 | -------------------------------------------------------------------------------- /scripts/run_boltz/run_boltz.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | 13 | BOLTZ_INPUT_FOLDER="data/benchmark/${DATASET}/boltz/input" 14 | BOLTZ_OUTPUT_FOLDER="data/benchmark/${DATASET}/boltz/output" 15 | GPU_ID=2 16 | 17 | # init conda 18 | eval "$(conda shell.bash hook)" 19 | conda activate boltz 20 | 21 | for yaml_file in ${BOLTZ_INPUT_FOLDER}/*.yaml; do 22 | filename=$(basename "${yaml_file}" .yaml) 23 | output_folder="${BOLTZ_OUTPUT_FOLDER}/${filename}" 24 | 25 | echo "Predicting ${yaml_file} to ${output_folder} ..." 26 | CUDA_VISIBLE_DEVICES=${GPU_ID} boltz predict ${yaml_file} --out_dir ${output_folder} --cache /data/models/boltz --use_msa_server --diffusion_samples 5 27 | done 28 | 29 | conda activate posex -------------------------------------------------------------------------------- /scripts/run_boltz1x/run_boltz1x.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | 13 | BOLTZ_INPUT_FOLDER="data/benchmark/${DATASET}/boltz1x/input" 14 | BOLTZ_OUTPUT_FOLDER="data/benchmark/${DATASET}/boltz1x/output" 15 | GPU_ID=2 16 | 17 | # init conda 18 | eval "$(conda shell.bash hook)" 19 | conda activate boltz-1x 20 | 21 | for yaml_file in ${BOLTZ_INPUT_FOLDER}/*.yaml; do 22 | filename=$(basename "${yaml_file}" .yaml) 23 | output_folder="${BOLTZ_OUTPUT_FOLDER}/${filename}" 24 | 25 | echo "Predicting ${yaml_file} to ${output_folder} ..." 26 | CUDA_VISIBLE_DEVICES=${GPU_ID} boltz predict ${yaml_file} --out_dir ${output_folder} --cache /data/models/boltz --use_msa_server --diffusion_samples 5 27 | done 28 | 29 | conda activate posex -------------------------------------------------------------------------------- /scripts/run_chai/run_chai.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | from chai_lab.chai1 import run_inference 6 | 7 | 8 | def main(args: argparse.Namespace): 9 | run_inference( 10 | fasta_file=Path(args.fasta_file), 11 | output_dir=Path(args.output_dir), 12 | num_trunk_recycles=3, 13 | num_diffn_timesteps=200, 14 | seed=42, 15 | device=torch.device(f"cuda:{args.gpu_id}"), 16 | use_esm_embeddings=True, 17 | ) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--fasta_file", type=str, required=True, help="Path to the fasta file") 23 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 24 | parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID") 25 | args = parser.parse_args() 26 | 27 | main(args) 28 | -------------------------------------------------------------------------------- /scripts/run_chai/run_chai.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | 13 | CHAI_INPUT_FOLDER="data/benchmark/${DATASET}/chai/input" 14 | CHAI_OUTPUT_FOLDER="data/benchmark/${DATASET}/chai/output" 15 | GPU_ID=1 16 | 17 | # init conda 18 | eval "$(conda shell.bash hook)" 19 | conda activate chai 20 | 21 | for fasta_file in ${CHAI_INPUT_FOLDER}/*.fasta; do 22 | filename=$(basename "${fasta_file}" .fasta) 23 | output_folder="${CHAI_OUTPUT_FOLDER}/${filename}" 24 | 25 | echo "Predicting ${fasta_file} to ${output_folder} ..." 
26 | CUDA_VISIBLE_DEVICES=${GPU_ID} chai fold --use-msa-server ${fasta_file} ${output_folder} 27 | done 28 | 29 | conda activate posex 30 | -------------------------------------------------------------------------------- /scripts/run_deepdock/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from deepdock.prepare_target.computeTargetMesh import compute_inp_surface 4 | from rdkit import Chem 5 | from deepdock.models import * 6 | from deepdock.DockingFunction import dock_compound, get_random_conformation 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def main(args: argparse.Namespace): 13 | np.random.seed(123) 14 | torch.cuda.manual_seed_all(123) 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | ligand_model = LigandNet(28, residual_layers=10, dropout_rate=0.10) 17 | target_model = TargetNet(4, residual_layers=10, dropout_rate=0.10) 18 | model = DeepDock(ligand_model, target_model, hidden_dim=64, n_gaussians=10, dropout_rate=0.10, dist_threhold=7.).to( 19 | device) 20 | checkpoint = torch.load('/DeepDock/Trained_models/DeepDock_pdbbindv2019_13K_minTestLoss.chk', 21 | map_location=torch.device(device)) 22 | model.load_state_dict(checkpoint['model_state_dict']) 23 | 24 | ligand_filename = f'{args.pdb_ccd_id}_ligand.mol2' 25 | sdf_filename = f'{args.pdb_ccd_id}_ligand_start_conf.sdf' 26 | target_filename = f'{args.pdb_ccd_id}_protein.pdb' 27 | target_ply = f'{args.pdb_ccd_id}_protein.ply' 28 | output_filename = f'{args.pdb_ccd_id}_optimal.sdf' 29 | 30 | compute_inp_surface(target_filename, ligand_filename, dist_threshold=10) 31 | 32 | real_mol = Chem.MolFromMolFile(sdf_filename) 33 | opt_mol, init_mol, result = dock_compound(real_mol, target_ply, model, dist_threshold=3., popsize=150, seed=123, 34 | device=device) 35 | 36 | writer = Chem.SDWriter(output_filename) 37 | writer.write(opt_mol, confId=0) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--pdb_ccd_id", type=str, required=True, help="The PDB-CCD id") 43 | args = parser.parse_args() 44 | 45 | main(args) 46 | -------------------------------------------------------------------------------- /scripts/run_deepdock/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def main(args: argparse.Namespace): 6 | with open(args.evaluation_template_path) as f: 7 | data = f.read() 8 | 9 | for pdb_ccd in os.listdir(args.input_folder): 10 | if os.path.isdir(os.path.join(args.input_folder, pdb_ccd)):  # listdir yields bare names, so join with the input folder before checking 11 | new_data = data.replace("LIGAND", pdb_ccd) 12 | evaluation_output = os.path.join(args.input_folder, pdb_ccd, "evaluation.py") 13 | with open(evaluation_output, "w") as fw: 14 | fw.write(new_data) 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--input_folder", type=str, required=True, help="Path to the input folder") 20 | parser.add_argument("--evaluation_template_path", type=str, required=True, help="Path to evaluation template") 21 | args = parser.parse_args() 22 | 23 | main(args) 24 | -------------------------------------------------------------------------------- /scripts/run_deepdock/run_deepdock.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def main(args: argparse.Namespace): 6 | for pdb_ccd in os.listdir(args.input_folder): 7 | os.chdir(os.path.join(args.input_folder, pdb_ccd)) 8 | 
os.system(f"python evaluate.py --pdb_ccd_id {pdb_ccd}")  # use the full flag name defined in evaluate.py 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input_folder", type=str, required=True, help="Path to the input folder") 14 | args = parser.parse_args() 15 | 16 | main(args) 17 | -------------------------------------------------------------------------------- /scripts/run_deepdock/run_deepdock.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | if [ $# -eq 0 ]; then 4 | echo "Error: Please provide dataset as argument" 5 | echo "Usage: $0 <dataset>" 6 | exit 1 7 | fi 8 | 9 | DATASET="$1" 10 | DEEPDOCK_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/deepdock/input" 11 | DEEPDOCK_RUNNING_FOLDER="${PWD}/scripts/run_deepdock" 12 | 13 | EVALUATE_FILE="scripts/run_deepdock/evaluate.py" 14 | for subdir in $DEEPDOCK_INPUT_FOLDER/*; do 15 | output_dir=$subdir 16 | cp $EVALUATE_FILE $output_dir 17 | done 18 | 19 | # DeepDock only accepts mol2 input for the ligand 20 | # Install openbabel to convert the sdf file into mol2 21 | echo "convert sdf file to mol2 using obabel" 22 | eval "$(conda shell.bash hook)" 23 | conda activate openbabel 24 | 25 | for subdir in $DEEPDOCK_INPUT_FOLDER/*; do 26 | subdir_name=$(basename $subdir) 27 | sdf_file="$subdir/"$subdir_name"_ligand.sdf" 28 | output_file="$subdir/"$subdir_name"_ligand.mol2" 29 | obabel $sdf_file -O $output_file 30 | done 31 | 32 | start_time=$(date +%s) 33 | docker run -it \ 34 | -v $DEEPDOCK_INPUT_FOLDER:/DeepDock/eval \ 35 | -v $DEEPDOCK_RUNNING_FOLDER:/DeepDock/run \ 36 | omendezlucio/deepdock \ 37 | python DeepDock/run/run_deepdock.py --input_folder /DeepDock/eval 38 | end_time=$(date +%s) 39 | cost_time=$[ $end_time-$start_time ] 40 | echo "Running time for ${DATASET}: ${cost_time}" -------------------------------------------------------------------------------- /scripts/run_diffdock/run_diffdock.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | 4 | 5 | def main(args: argparse.Namespace): 6 | 7 | subprocess.run( 8 | [ 9 | "python", 10 | "inference.py", 11 | "--protein_ligand_csv", 12 | args.input_csv_path, 13 | "--out_dir", 14 | args.output_dir, 15 | "--cuda_device_index", 16 | str(args.gpu_id), 17 | "--model_dir", 18 | args.model_dir, 19 | "--confidence_model_dir", 20 | args.confidence_model_dir, 21 | "--inference_steps", 22 | "20", 23 | "--samples_per_complex", 24 | "40", 25 | "--actual_steps", 26 | "18", 27 | "--no_final_step_noise", 28 | ], 29 | cwd=args.diffdock_exec_dir, 30 | check=True 31 | ) # nosec 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the protein_ligand_csv file") 38 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 39 | parser.add_argument("--diffdock_exec_dir", type=str, required=True, help="Path to the DiffDock codebase") 40 | parser.add_argument("--model_dir", type=str, required=True, help="Path to the model_dir") 41 | parser.add_argument("--confidence_model_dir", type=str, required=True, help="Path to the confidence_model_dir") 42 | parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID") 43 | args = parser.parse_args() 44 | 45 | main(args) -------------------------------------------------------------------------------- /scripts/run_diffdock/run_diffdock.sh: 
-------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | DIFFDOCK_EXEC_FOLDER="path/to/DiffDock" 13 | MODEL_DIR="${DIFFDOCK_EXEC_FOLDER}/workdir/paper_score_model" 14 | CONFIDENCE_MODEL_DIR="${DIFFDOCK_EXEC_FOLDER}/workdir/paper_confidence_model" 15 | DIFFDOCK_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/diffdock/input" 16 | DIFFDOCK_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/diffdock/output" 17 | GPU_ID=1 18 | 19 | # init conda 20 | eval "$(conda shell.bash hook)" 21 | conda activate diffdock 22 | 23 | input_csv_path="${DIFFDOCK_INPUT_FOLDER}/data.csv" 24 | python scripts/run_diffdock/run_diffdock.py \ 25 | --input_csv_path ${input_csv_path} \ 26 | --output_dir ${DIFFDOCK_OUTPUT_FOLDER} \ 27 | --gpu_id ${GPU_ID} \ 28 | --diffdock_exec_dir ${DIFFDOCK_EXEC_FOLDER} \ 29 | --model_dir ${MODEL_DIR} \ 30 | --confidence_model_dir ${CONFIDENCE_MODEL_DIR} 31 | -------------------------------------------------------------------------------- /scripts/run_diffdock_l/run_diffdock_l.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | 4 | 5 | def main(args: argparse.Namespace): 6 | subprocess.run( 7 | [ 8 | "python", 9 | "inference.py", 10 | "--protein_ligand_csv", 11 | args.input_csv_path, 12 | "--out_dir", 13 | args.output_dir, 14 | "--config", 15 | args.config_path, 16 | ], 17 | cwd=args.diffdock_exec_dir, 18 | check=True 19 | ) 20 | 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the protein_ligand_csv file") 26 | parser.add_argument("--diffdock_exec_dir", type=str, required=True, help="Path to the DiffDock_L codebase") 27 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 28 | parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") 29 | args = parser.parse_args() 30 | 31 | main(args) -------------------------------------------------------------------------------- /scripts/run_diffdock_l/run_diffdock_l.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | DIFFDOCK_EXEC_FOLDER="path/to/DiffDock_L" 13 | DIFFDOCK_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/diffdock_l/input" 14 | DIFFDOCK_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/diffdock_l/output" 15 | 16 | 17 | # init conda 18 | eval "$(conda shell.bash hook)" 19 | conda activate diffdock 20 | 21 | input_csv_path="${DIFFDOCK_INPUT_FOLDER}/data.csv" 22 | config_path="${DIFFDOCK_EXEC_FOLDER}/default_inference_args.yaml" 23 | python scripts/run_diffdock_l/run_diffdock_l.py \ 24 | --input_csv_path ${input_csv_path} \ 25 | --config_path ${config_path} \ 26 | --output_dir ${DIFFDOCK_OUTPUT_FOLDER} \ 27 | --diffdock_exec_dir ${DIFFDOCK_EXEC_FOLDER} 28 | -------------------------------------------------------------------------------- /scripts/run_diffdock_pocket/run_diffdock_pocket.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | DIFFDOCK_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/diffdock_pocket/input" 13 | DIFFDOCK_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/diffdock_pocket/output" 14 | DIFFDOCK_EXEC_PATH="path/to/Diffdock_pocket" 15 | 16 | cd $DIFFDOCK_EXEC_PATH 17 | 18 | start_time=$(date +%s) 19 | python inference.py --protein_ligand_csv "${DIFFDOCK_INPUT_FOLDER}/example.csv" --out_dir "${DIFFDOCK_OUTPUT_FOLDER}/results" --batch_size 12 --samples_per_complex 40 --keep_local_structures 20 | end_time=$(date +%s) 21 | cost_time=$[ $end_time-$start_time ] 22 | echo "Running time for ${DATASET}: ${cost_time}" -------------------------------------------------------------------------------- /scripts/run_dynamicbind/run_dynamicbind.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import subprocess 5 | 6 | 7 | def main(args: argparse.Namespace): 8 | os.environ["MKL_THREADING_LAYER"] = "GNU" 9 | header = args.itemname 10 | result_root = os.path.join(args.dynamicbind_exec_dir, "inference", "outputs", "results") 11 | result_dir = os.path.join(result_root, header) 12 | subprocess.run( 13 | [ 14 | "python", 15 | "run_single_protein_inference.py", 16 | args.protein_filepath, 17 | args.ligand_filepath, 18 | "--samples_per_complex", 19 | "40", 20 | "--header", 21 | header, 22 | "--device", 23 | str(args.gpu_id), 24 | "--python", 25 | "python", 26 | "--relax_python", 27 | "python", 28 | "--results", 29 | result_root, 30 | "--no_relax", 31 | "--paper", 32 | ], 33 | cwd=args.dynamicbind_exec_dir, 34 | check=True 35 | ) 36 | 37 | shutil.move(result_dir, args.output_dir) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--itemname", type=str, required=True, help="PDB_CCD_ID") 43 | parser.add_argument("--protein_filepath", type=str, required=True, help="Path to the protein pdb file") 44 | parser.add_argument("--ligand_filepath", type=str, required=True, help="Path to the ligand csv file") 45 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 46 | parser.add_argument("--dynamicbind_exec_dir", type=str, required=True, help="Path to the DynamicBind project") 47 | parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID") 48 | args = parser.parse_args() 49 | 50 | main(args) -------------------------------------------------------------------------------- /scripts/run_dynamicbind/run_dynamicbind.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | DYNAMICBIND_EXEC_FOLDER="path/to/DynamicBind" 13 | DYNAMICBIND_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/dynamicbind/input" 14 | DYNAMICBIND_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/dynamicbind/output" 15 | GPU_ID=1 16 | 17 | # init conda 18 | eval "$(conda shell.bash hook)" 19 | conda activate dynamicbind 20 | 21 | for protein_filepath in ${DYNAMICBIND_INPUT_FOLDER}/*.pdb; do 22 | itemname=$(basename "${protein_filepath}" .pdb) 23 | output_folder="${DYNAMICBIND_OUTPUT_FOLDER}/${itemname}" 24 | ligand_filepath="${DYNAMICBIND_INPUT_FOLDER}/${itemname}.csv" 25 | echo "Predicting ${itemname}..." 26 | python scripts/run_dynamicbind/run_dynamicbind.py \ 27 | --itemname ${itemname} \ 28 | --protein_filepath ${protein_filepath} \ 29 | --ligand_filepath ${ligand_filepath} \ 30 | --dynamicbind_exec_dir ${DYNAMICBIND_EXEC_FOLDER} \ 31 | --output_dir ${output_folder} \ 32 | --gpu_id ${GPU_ID} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/run_equibind/run_equibind.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import yaml 4 | 5 | 6 | def main(args: argparse.Namespace): 7 | os.chdir(os.path.join(args.equibind_exec_dir)) 8 | with open(args.yml_path, 'r', encoding='utf-8') as f: 9 | file_content = f.read() 10 | data = yaml.load(file_content, yaml.FullLoader) 11 | data["inference_path"] = args.input_dir 12 | data["output_directory"] = args.output_dir 13 | 14 | modified_yml = os.path.join(os.path.dirname(args.yml_path), "modified_inference.yml") 15 | with open(modified_yml, 'w', encoding='utf-8') as f: 16 | yaml.dump(data, stream=f, allow_unicode=True, encoding='utf-8') 17 | cmd = f"python inference.py --config={modified_yml}" 18 | os.system(cmd) 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--input_dir", type=str, required=True, help="Path to the input files") 24 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 25 | parser.add_argument("--equibind_exec_dir", type=str, required=True, help="Path to the Equibind codebase") 26 | parser.add_argument("--yml_path", type=str, required=True, help="Path to the yml config file") 27 | args = parser.parse_args() 28 | 29 | main(args) 30 | -------------------------------------------------------------------------------- /scripts/run_equibind/run_equibind.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | EQUIBIND_EXEC_FOLDER="path/to/equibind" 13 | EQUIBIND_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/equibind/input" 14 | EQUIBIND_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/equibind/output" 15 | 16 | # init conda 17 | eval "$(conda shell.bash hook)" 18 | conda activate equibind 19 | 20 | start_time=$(date +%s) 21 | python scripts/run_equibind/run_equibind.py \ 22 | --input_dir ${EQUIBIND_INPUT_FOLDER} \ 23 | --output_dir ${EQUIBIND_OUTPUT_FOLDER} \ 24 | --equibind_exec_dir ${EQUIBIND_EXEC_FOLDER} \ 25 | --yml_path "configs_clean/inference.yml" 26 | end_time=$(date +%s) 27 | cost_time=$[ $end_time-$start_time ] 28 | echo "Running time for ${DATASET}: ${cost_time}" -------------------------------------------------------------------------------- /scripts/run_fabind/run_fabind.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | 5 | 6 | def main(args: argparse.Namespace): 7 | save_pt_dir = os.path.join(args.output_dir, "temp_files") 8 | save_mols_dir = os.path.join(save_pt_dir, "mol") 9 | # preprocess ligand 10 | subprocess.run( 11 | [ 12 | "python", 13 | "inference_preprocess_mol_confs.py", 14 | "--index_csv", 15 | args.input_csv_path, 16 | "--save_mols_dir", 17 | save_mols_dir, 18 | "--num_threads", 19 | "1", 20 | ], 21 | cwd=args.fabind_exec_dir, 22 | check=True 23 | ) 24 | # preprocess protein 25 | subprocess.run( 26 | [ 27 | "python", 28 | "inference_preprocess_protein.py", 29 | "--pdb_file_dir", 30 | args.input_data_dir, 31 | "--save_pt_dir", 32 | save_pt_dir, 33 | "--cuda_device_index", 34 | str(args.gpu_id), 35 | ], 36 | cwd=args.fabind_exec_dir, 37 | check=True 38 | ) 39 | # inference 40 | subprocess.run( 41 | [ 42 | "python", 43 | "fabind_inference.py", 44 | "--ckpt", 45 | args.ckpt_path, 46 | "--batch_size", 47 | "4", 48 | "--seed", 49 | "42", 50 | "--test-gumbel-soft", 51 | "--redocking", 52 | "--post-optim", 53 | "--write-mol-to-file", 54 | "--sdf-output-path-post-optim", 55 | args.output_dir, 56 | "--index-csv", 57 | args.input_csv_path, 58 | "--preprocess-dir", 59 | save_pt_dir, 60 | "--cuda_device_index", 61 | str(args.gpu_id), 62 | ], 63 | cwd=args.fabind_exec_dir, 64 | check=True 65 | ) 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the ligand_csv file") 70 | parser.add_argument("--input_data_dir", type=str, required=True, help="Path to the protein pdb dir") 71 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 72 | parser.add_argument("--fabind_exec_dir", type=str, required=True, help="Path to the FABind codebase") 73 | parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the model checkpoint") 74 | parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID") 75 | args = parser.parse_args() 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /scripts/run_fabind/run_fabind.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 <dataset>" 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | FABIND_EXEC_FOLDER="path/to/fabind" 13 | CKPT_PATH="${FABIND_EXEC_FOLDER}/ckpt/best_model.bin" 14 | FABIND_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/fabind/input" 15 | FABIND_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/fabind/output" 16 | GPU_ID=0 17 | 18 | # init conda 19 | eval "$(conda shell.bash hook)" 20 | conda activate fabind 21 | 22 | input_csv_path="${FABIND_INPUT_FOLDER}/ligand.csv" 23 | input_data_dir="${FABIND_INPUT_FOLDER}/protein" 24 | 25 | python scripts/run_fabind/run_fabind.py \ 26 | --input_csv_path ${input_csv_path} \ 27 | --input_data_dir ${input_data_dir} \ 28 | --output_dir ${FABIND_OUTPUT_FOLDER} \ 29 | --fabind_exec_dir ${FABIND_EXEC_FOLDER} \ 30 | --ckpt_path ${CKPT_PATH} \ 31 | --gpu_id ${GPU_ID} 32 | -------------------------------------------------------------------------------- /scripts/run_gnina/run_gnina.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | if [ $# -eq 0 ]; then 4 | echo "Error: Please provide dataset as argument" 5 | echo "Usage: $0 <dataset>" 6 | exit 1 7 | fi 8 | 9 | DATASET="$1" 10 | GNINA_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/gnina/input" 11 | GNINA_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/gnina/output" 12 | RUNNING_SCRIPTS="${PWD}/scripts/run_gnina/run_gnina_help.sh" 13 | 14 | 15 | start_time=$(date +%s) 16 | docker run -it \ 17 | --privileged=true \ 18 | --env CUDA_VISIBLE_DEVICES="1" \ 19 | --gpus "device=0" \ 20 | -v $GNINA_INPUT_FOLDER:/input \ 21 | -v $GNINA_OUTPUT_FOLDER:/output \ 22 | -v $RUNNING_SCRIPTS:/run_gnina_help.sh \ 23 | gnina/gnina:latest \ 24 | bash -c "/run_gnina_help.sh /input /output" 25 | end_time=$(date +%s) 26 | cost_time=$[ $end_time-$start_time ] 27 | echo "Running time for ${DATASET}: ${cost_time}" 28 | -------------------------------------------------------------------------------- /scripts/run_gnina/run_gnina_help.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | INPUT_FOLDER="$1" 4 | OUTPUT_FOLDER="$2" 5 | 6 | for pdb_ccd in "$INPUT_FOLDER"/*; do 7 | if [ -d "$pdb_ccd" ]; then 8 | protein_path="$pdb_ccd/${pdb_ccd##*/}_protein.pdb" 9 | ref_path="$pdb_ccd/${pdb_ccd##*/}_ligand.sdf" 10 | ligand_path="$pdb_ccd/${pdb_ccd##*/}_ligand_start_conf.sdf" 11 | 12 | output_dir="$OUTPUT_FOLDER/$(basename "$pdb_ccd")" 13 | output_path="$output_dir/$(basename "$pdb_ccd")_ligand.sdf" 14 | 15 | mkdir -p "$output_dir" 16 | 17 | cmd="gnina -r $protein_path -l $ligand_path --autobox_ligand $ref_path -o $output_path" 18 | 19 | eval "$cmd" 20 | fi 21 | done 22 | -------------------------------------------------------------------------------- /scripts/run_interformer/run_interformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from collections import defaultdict 4 | 5 | from rdkit import Chem 6 | import sys 7 | import pandas as pd 8 | import argparse 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = "1" 11 | 12 | 13 | def gen_demo_dock_csv(sdf_f, target, csv_path, isuff=True): 14 | data = [] 15 | suppl = Chem.SDMolSupplier(sdf_f, sanitize=False) 16 | for i, mol in enumerate(suppl): 17 | if mol is not None:  # SDMolSupplier yields None for unparsable records 18 | m_id = mol.GetProp('_Name')  # read the title only after the None check 19 | if isuff: 20 | data.append([target, 0, i, m_id]) 21 | else: 22 | 
23 | 
24 |     df = pd.DataFrame(data, columns=['Target', 'pose_rank', 'uff_pose_rank', 'Molecule ID'])
25 |     df.to_csv(csv_path, index=False)
26 | 
27 | 
28 | def main(args: argparse.Namespace):
29 |     os.chdir(os.path.join(args.interformer_exec_dir))
30 |     for pdb_ccd in os.listdir(args.input_dir):
31 |         pdb = pdb_ccd.split("_")[0]
32 |         task_dir = os.path.join(args.input_dir, pdb_ccd)
33 |         raw_folder = os.path.join(task_dir, "raw")
34 |         ligand_folder = os.path.join(task_dir, "ligand")
35 |         crystal_ligand_folder = os.path.join(task_dir, "crystal_ligand")
36 |         uff_folder = os.path.join(task_dir, "uff")
37 |         pocket_folder = os.path.join(task_dir, "pocket")
38 |         raw_pocket_folder = os.path.join(raw_folder, "pocket")
39 |         os.makedirs(ligand_folder, exist_ok=True)
40 |         os.makedirs(crystal_ligand_folder, exist_ok=True)
41 |         os.makedirs(uff_folder, exist_ok=True)
42 |         os.makedirs(pocket_folder, exist_ok=True)
43 |         input_protein = os.path.join(raw_folder, f"{pdb_ccd}_protein.pdb")
44 |         input_ligand = os.path.join(raw_folder, f"{pdb_ccd}_ligand.sdf")
45 |         start_conf = os.path.join(raw_folder, f"{pdb_ccd}_ligand_start_conf.sdf")
46 | 
47 |         crystal_ligand_output_path = os.path.join(crystal_ligand_folder, f"{pdb}_docked.sdf")
48 |         os.system(f"obabel {start_conf} -p 7.4 -O {crystal_ligand_output_path}")
49 |         os.system(f"python tools/rdkit_ETKDG_3d_gen.py {crystal_ligand_folder} {uff_folder}")
50 | 
51 |         ligand_with_hydrogen_path = os.path.join(ligand_folder, f"{pdb}_docked.sdf")
52 |         os.system(f"obabel {input_ligand} -p 7.4 -O {ligand_with_hydrogen_path}")
53 |         # os.system(f"python tools/rdkit_ETKDG_3d_gen.py {ligand_folder} {uff_folder}")
54 | 
55 |         reduced_protein_path = os.path.join(raw_pocket_folder, f"{pdb}_reduce.pdb")
56 |         os.system(f"mkdir -p {raw_pocket_folder} && reduce {input_protein} > {reduced_protein_path}")
57 | 
58 |         pocket_output_path = os.path.join(raw_pocket_folder, "output", f"{pdb}_pocket.pdb")
59 |         os.system(
60 |             f"python tools/extract_pocket_by_ligand.py {raw_pocket_folder} {ligand_folder} 0 "
61 |             f"&& mv {pocket_output_path} {pocket_folder}")
62 | 
63 |         energy_output = os.path.join(task_dir, "energy_output")
64 |         csv_path = os.path.join(task_dir, "demo_dock.csv")
65 | 
66 |         gen_demo_dock_csv(input_ligand, pdb, csv_path)
67 | 
68 |         predict_energy_cmd = f"PYTHONPATH=interformer/ python inference.py -test_csv {csv_path} \
69 |             -work_path {task_dir} \
70 |             -ensemble checkpoints/v0.2_energy_model \
71 |             -batch_size 1 \
72 |             -posfix *val_loss* \
73 |             -energy_output_folder {energy_output} \
74 |             -reload \
75 |             -debug"
76 |         os.system(predict_energy_cmd)
77 |         # os.system(f"cp {start_conf} {energy_output}/uff/{pdb}_uff.sdf")
78 |         os.system(
79 |             f'OMP_NUM_THREADS="64,64" python docking/reconstruct_ligands.py -y --cwd {energy_output} --find_all --uff_folder uff find')
80 |         os.system(f'python docking/reconstruct_ligands.py --cwd {energy_output} --find_all stat')
81 |         os.system(
82 |             f'python docking/merge_summary_input.py {os.path.join(energy_output, "ligand_reconstructing/stat_concated.csv")} {csv_path}')
83 |         infer_dir = os.path.join(task_dir, "infer")
84 |         os.makedirs(infer_dir, exist_ok=True)
85 |         os.system(f'cp -r {os.path.join(energy_output, "ligand_reconstructing")} {infer_dir}')
86 | 
87 | 
88 | if __name__ == "__main__":
89 |     parser = argparse.ArgumentParser()
90 |     parser.add_argument("--input_dir", type=str, required=True, help="Path to the input files")
91 |     parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output")
92 |     parser.add_argument("--interformer_exec_dir", type=str, required=True, help="Path to the Interformer codebase")
parser.add_argument("--interformer_exec_dir", type=str, required=True, help="Path to the Unimol codebase") 93 | args = parser.parse_args() 94 | 95 | main(args) 96 | -------------------------------------------------------------------------------- /scripts/run_interformer/run_interformer.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 " 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | INTERFORMER_EXEC_FOLDER="path/to/interformer" 13 | INTERFORMER_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/interformer/input" 14 | INTERFORMER_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/interformer/output" 15 | 16 | # init conda 17 | eval "$(conda shell.bash hook)" 18 | conda activate interformer 19 | 20 | export CUDA_VISIBLE_DEVICES="1" 21 | 22 | start_time=$(date +%s) 23 | python scripts/run_interformer/run_interformer.py \ 24 | --input_dir ${INTERFORMER_INPUT_FOLDER} \ 25 | --output_dir ${INTERFORMER_OUTPUT_FOLDER} \ 26 | --interformer_exec_dir ${INTERFORMER_EXEC_FOLDER} 27 | end_time=$(date +%s) 28 | cost_time=$[ $end_time-$start_time ] 29 | echo "Running time for ${DATASET}: ${cost_time}" -------------------------------------------------------------------------------- /scripts/run_neuralplexer/run_neuralplexer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def main(args: argparse.Namespace): 6 | for pdb_ccd in os.listdir(args.input_folder): 7 | pdb_ccd_dir = os.path.join(args.input_folder, pdb_ccd) 8 | protein_path = os.path.join(pdb_ccd_dir, f"{pdb_ccd}_protein.pdb") 9 | ligand_path = os.path.join(pdb_ccd_dir, f"{pdb_ccd}_ligand_start_conf.sdf") 10 | output_dir = os.path.join(args.output_folder, pdb_ccd) 11 | os.makedirs(output_dir, exist_ok=True) 12 | if os.path.isdir(pdb_ccd_dir): 13 | cmd = f"neuralplexer-inference --task=batched_structure_sampling \ 14 | --input-receptor {protein_path} \ 15 | --input-ligand {ligand_path} \ 16 | --use-template --input-template {protein_path} \ 17 | --out-path {output_dir}\ 18 | --model-checkpoint {args.model_checkpoint} \ 19 | --n-samples {args.n_samples} \ 20 | --chunk-size {args.chunk_size} \ 21 | --num-steps={args.num_steps} \ 22 | --cuda \ 23 | --sampler=langevin_simulated_annealing" 24 | os.system(cmd) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--input_folder", type=str, required=True, help="Path to the input folder") 30 | parser.add_argument("--output_folder", type=str, required=True, help="Path to the output folder") 31 | parser.add_argument("--model_checkpoint", type=str, required=True, help="Path to the model checkpoint") 32 | parser.add_argument("--n_samples", type=int, default=16, help="The number of conformations to generate in total") 33 | parser.add_argument("--chunk_size", type=int, default=4, help="The number of conformation to generate in parallel") 34 | parser.add_argument("--num_steps", type=int, default=40, 35 | help="The number of steps for the diffusion part of the sampling process") 36 | args = parser.parse_args() 37 | 38 | main(args) 39 | -------------------------------------------------------------------------------- /scripts/run_neuralplexer/run_neuralplexer.sh: -------------------------------------------------------------------------------- 1 | #! 
2 | 
3 | if [ $# -eq 0 ]; then
4 |     echo "Error: Please provide dataset as argument"
5 |     echo "Usage: $0 <dataset>"
6 |     exit 1
7 | fi
8 | 
9 | DATASET="$1"
10 | NEURALPLEXER_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/neuralplexer/input"
11 | NEURALPLEXER_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/neuralplexer/output"
12 | RUNNING_SCRIPTS="${PWD}/scripts/run_neuralplexer/run_neuralplexer.py"
13 | MODEL_CHECKPOINT="${PWD}/data/benchmark/astex/neuralplexer/input/complex_structure_prediction.ckpt"
14 | 
15 | if [ ! -e "$MODEL_CHECKPOINT" ]; then
16 |     echo "File path ${MODEL_CHECKPOINT} does not exist"
17 |     echo "Please download the model checkpoint"
18 |     exit 1
19 | fi
20 | 
21 | start_time=$(date +%s)
22 | docker run -it \
23 |     --privileged=true \
24 |     --env CUDA_VISIBLE_DEVICES="0" \
25 |     --gpus "device=0" \
26 |     -v $NEURALPLEXER_INPUT_FOLDER:/input \
27 |     -v $MODEL_CHECKPOINT:/input/complex_structure_prediction.ckpt \
28 |     -v $NEURALPLEXER_OUTPUT_FOLDER:/output \
29 |     -v $RUNNING_SCRIPTS:/run_neuralplexer.py \
30 |     neuralplexer:latest \
31 |     bash -c "source /opt/conda/etc/profile.d/conda.sh && conda activate NeuralPLexer && python /run_neuralplexer.py \
32 |         --input_folder /input \
33 |         --output_folder /output \
34 |         --model_checkpoint /input/complex_structure_prediction.ckpt"
35 | end_time=$(date +%s)
36 | cost_time=$(( end_time - start_time ))
37 | echo "Running time for ${DATASET}: ${cost_time}"
--------------------------------------------------------------------------------
/scripts/run_protenix/run_protenix.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | 
3 | 
4 | # Check if dataset is provided as argument
5 | if [ $# -eq 0 ]; then
6 |     echo "Error: Please provide dataset as argument"
7 |     echo "Usage: $0 <dataset>"
8 |     exit 1
9 | fi
10 | 
11 | DATASET="$1"
12 | PROTENIX_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/protenix/input"
13 | PROTENIX_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/protenix/output"
14 | 
15 | eval "$(conda shell.bash hook)"
16 | conda activate protenix
17 | 
18 | export CUDA_VISIBLE_DEVICES="1"
19 | 
20 | start_time=$(date +%s)
21 | protenix predict --input $PROTENIX_INPUT_FOLDER --out_dir $PROTENIX_OUTPUT_FOLDER --seeds 101 --use_msa_server
22 | end_time=$(date +%s)
23 | cost_time=$(( end_time - start_time ))
24 | echo "Running time for ${DATASET}: ${cost_time}"
25 | 
26 | 
--------------------------------------------------------------------------------
/scripts/run_rfaa/run_rfaa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | # Check if dataset is provided as argument
5 | if [ $# -eq 0 ]; then
6 |     echo "Error: Please provide dataset as argument"
7 |     echo "Usage: $0 <dataset>"
8 |     exit 1
9 | fi
10 | 
11 | DATASET="$1"
12 | 
13 | GPU_ID=2
14 | RFAA_INPUT_FOLDER="data/benchmark/${DATASET}/rfaa/input"
15 | RFAA_OUTPUT_FOLDER="data/benchmark/${DATASET}/rfaa/output"
16 | 
17 | # Create output folder
18 | mkdir -p ${RFAA_OUTPUT_FOLDER}
19 | 
20 | # init conda
21 | eval "$(conda shell.bash hook)"
22 | conda activate rfaa
23 | 
24 | RFAA_REPO_FOLDER="path/to/RFAA"
25 | 
26 | for yaml_file in ${RFAA_INPUT_FOLDER}/*.yaml; do
27 |     filename=$(basename "${yaml_file}" .yaml)
28 |     cp ${yaml_file} ${RFAA_REPO_FOLDER}/rf2aa/config/inference/
29 | 
30 |     pushd ${RFAA_REPO_FOLDER}
31 |     echo "Predicting ${filename} ..."
32 |     CUDA_VISIBLE_DEVICES=${GPU_ID} python -m rf2aa.run_inference --config-name ${filename}
33 |     popd
34 | 
35 |     echo "Copying predictions of ${filename} ..."
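    # The YAML was copied into rf2aa/config/inference/ above so that
    # --config-name can resolve it from inside the repo; the two lines below
    # move the finished predictions out of the repo and delete the temporary
    # config again.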
36 |     mv ${RFAA_REPO_FOLDER}/predictions/${filename} ${RFAA_OUTPUT_FOLDER}/
37 |     rm ${RFAA_REPO_FOLDER}/rf2aa/config/inference/${filename}.yaml
38 | done
39 | 
40 | conda activate posex
--------------------------------------------------------------------------------
/scripts/run_surfdock/run_surfdock.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | 
3 | 
4 | # Check if dataset is provided as argument
5 | if [ $# -eq 0 ]; then
6 |     echo "Error: Please provide dataset as argument"
7 |     echo "Usage: $0 <dataset>"
8 |     exit 1
9 | fi
10 | 
11 | DATASET="$1"
12 | SURFDOCK_EXEC_PATH="path/to/surfdock"
13 | SURFDOCK_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/surfdock/input"
14 | SURFDOCK_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/surfdock/output"
15 | RUNNING_SCRIPTS="${PWD}/scripts/run_surfdock/run_surfdock_help.sh"
16 | 
17 | start_time=$(date +%s)
18 | docker run -it \
19 |     --privileged=true \
20 |     --gpus all \
21 |     -e CUDA_VISIBLE_DEVICES="1" \
22 |     -v $SURFDOCK_EXEC_PATH:/SurfDock \
23 |     -v $SURFDOCK_INPUT_FOLDER:/input \
24 |     -v $SURFDOCK_OUTPUT_FOLDER:/output \
25 |     -v $RUNNING_SCRIPTS:/run_surfdock_help.sh \
26 |     surfdock:v1 \
27 |     bash -c "/run_surfdock_help.sh /input /output"
28 | end_time=$(date +%s)
29 | cost_time=$(( end_time - start_time ))
30 | echo "Running time for ${DATASET}: ${cost_time}"
--------------------------------------------------------------------------------
/scripts/run_surfdock/run_surfdock_help.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | start_time=$(date +%s)
4 | # This script is used to run SurfDock on test samples
5 | path=$(readlink -f "$0")
6 | INPUT_DIR="$1"
7 | OUTPUT_DIR="$2"
8 | 
9 | 
10 | SurfDockdir="/SurfDock"
11 | 
12 | temp="/"
13 | model_temp="/SurfDock"
14 | 
15 | 
16 | #------------------------------------------------------------------------------------------------#
17 | #------------------------------------ Step0 : Setup Params --------------------------------------#
18 | #------------------------------------------------------------------------------------------------#
19 | echo '------------------------------------ Step0 : Setup Params --------------------------------------'
20 | export precomputed_arrays="${temp}/precomputed/precomputed_arrays"
21 | gpu_string="1"
22 | echo "Using GPU devices: ${gpu_string}"
23 | IFS=',' read -ra gpu_array <<< "$gpu_string"
24 | NUM_GPUS=${#gpu_array[@]}
25 | export CUDA_VISIBLE_DEVICES=${gpu_string}
26 | 
27 | main_process_port=2951${gpu_array[-1]}
28 | project_name='SurfDock_eval_samples/repeat_astex'
29 | surface_out_dir=${SurfDockdir}/data/eval_sample_dirs/${project_name}/test_samples_8A_surface
30 | data_dir=${INPUT_DIR}
31 | out_csv_file=${SurfDockdir}/data/eval_sample_dirs/${project_name}/input_csv_files/test_samples.csv
32 | esmbedding_dir=${SurfDockdir}/data/eval_sample_dirs/${project_name}/test_samples_esmbedding
33 | # project_name='SurfDock_Screen_samples/repeat5'
34 | 
35 | #------------------------------------------------------------------------------------------------#
36 | #----------------------------- Step1 : Compute Target Surface -----------------------------------#
37 | #------------------------------------------------------------------------------------------------#
38 | echo '----------------------------- Step1 : Compute Target Surface -----------------------------------'
39 | mkdir -p $surface_out_dir
40 | cd $surface_out_dir
41 | command=`
42 | python ${SurfDockdir}/comp_surface/prepare_target/computeTargetMesh_test_samples.py \
43 |     --data_dir ${data_dir} \
44 |     --out_dir ${surface_out_dir} \
45 | `
46 | state=$command
47 | 
48 | #------------------------------------------------------------------------------------------------#
49 | #-------------------------------- Step2 : Get Input CSV File -----------------------------------#
50 | #------------------------------------------------------------------------------------------------#
51 | echo '-------------------------------- Step2 : Get Input CSV File -----------------------------------'
52 | command=` python \
53 |     ${SurfDockdir}/inference_utils/construct_csv_input.py \
54 |     --data_dir ${data_dir} \
55 |     --surface_out_dir ${surface_out_dir} \
56 |     --output_csv_file ${out_csv_file} \
57 | `
58 | state=$command
59 | 
60 | #------------------------------------------------------------------------------------------------#
61 | #-------------------------------- Step3 : Get Pocket ESM Embedding ----------------------------#
62 | #------------------------------------------------------------------------------------------------#
63 | echo '-------------------------------- Step3 : Get Pocket ESM Embedding ----------------------------'
64 | 
65 | esm_dir=${SurfDockdir}/esm
66 | sequence_out_file="${esmbedding_dir}/test_samples.fasta"
67 | protein_pocket_csv=${out_csv_file}
68 | full_protein_esm_embedding_dir="${esmbedding_dir}/esm_embedding_output"
69 | pocket_emb_save_dir="${esmbedding_dir}/esm_embedding_pocket_output"
70 | pocket_emb_save_to_single_file="${esmbedding_dir}/esm_embedding_pocket_output_for_train/esm2_3billion_pdbbind_embeddings.pt"
71 | # get fasta sequence
72 | command=`python ${SurfDockdir}/datasets/esm_embedding_preparation.py \
73 |     --out_file ${sequence_out_file} \
74 |     --protein_ligand_csv ${protein_pocket_csv}`
75 | state=$command
76 | # esm embedding preparation
77 | 
78 | command=`python ${esm_dir}/scripts/extract.py \
79 |     "esm2_t33_650M_UR50D" \
80 |     ${sequence_out_file} \
81 |     ${full_protein_esm_embedding_dir} \
82 |     --repr_layers 33 \
83 |     --include "per_tok" \
84 |     --truncation_seq_length 4096`
85 | state=$command
86 | 
87 | 
88 | # map pocket esm embedding
89 | command=`python ${SurfDockdir}/datasets/get_pocket_embedding.py \
90 |     --protein_pocket_csv ${protein_pocket_csv} \
91 |     --embeddings_dir ${full_protein_esm_embedding_dir} \
92 |     --pocket_emb_save_dir ${pocket_emb_save_dir}`
93 | state=$command
94 | 
95 | # save pocket esm embedding to single file
96 | command=`python ${SurfDockdir}/datasets/esm_pocket_embeddings_to_pt.py \
97 |     --esm_embeddings_path ${pocket_emb_save_dir} \
98 |     --output_path ${pocket_emb_save_to_single_file}`
99 | state=$command
100 | 
101 | 
102 | #------------------------------------------------------------------------------------------------#
103 | #------------------------ Step4 : Start Sampling Ligand Conformers ----------------------------#
104 | #------------------------------------------------------------------------------------------------#
105 | echo '------------------------ Step4 : Start Sampling Ligand Conformers ----------------------------'
106 | 
107 | diffusion_model_dir=${model_temp}/model_weights/docking
108 | confidence_model_base_dir=${model_temp}/model_weights/posepredict
109 | protein_embedding=${pocket_emb_save_to_single_file}
110 | test_data_csv=${out_csv_file}
111 | 
112 | mdn_dist_threshold_test=3.0
113 | version=6
114 | dist_arrays=(3)
115 | for i in ${dist_arrays[@]}
116 | do
117 |     mdn_dist_threshold_test=${i}
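    # One sampling pass is run per MDN distance threshold in dist_arrays (only
    # 3 Angstrom here); accelerate spreads the complexes listed in
    # test_data_csv across the GPUs named in gpu_string.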
118 |     command=`accelerate launch \
119 |         --multi_gpu \
120 |         --main_process_port ${main_process_port} \
121 |         --num_processes ${NUM_GPUS} \
122 |         ${SurfDockdir}/inference_accelerate.py \
123 |         --data_csv ${test_data_csv} \
124 |         --model_dir ${diffusion_model_dir} \
125 |         --ckpt best_ema_inference_epoch_model.pt \
126 |         --confidence_model_dir ${confidence_model_base_dir} \
127 |         --confidence_ckpt best_model.pt \
128 |         --save_docking_result \
129 |         --mdn_dist_threshold_test ${mdn_dist_threshold_test} \
130 |         --esm_embeddings_path ${protein_embedding} \
131 |         --run_name ${confidence_model_base_dir}_test_dist_${mdn_dist_threshold_test} \
132 |         --project ${project_name} \
133 |         --out_dir $OUTPUT_DIR \
134 |         --batch_size 40 \
135 |         --batch_size_molecule 1 \
136 |         --samples_per_complex 40 \
137 |         --save_docking_result_number 40 \
138 |         --head_index 0 \
139 |         --tail_index 10000 \
140 |         --inference_mode evaluate \
141 |         --wandb_dir ${temp}/docking_result/test_workdir`
142 |     state=$command
143 | done
144 | end_time=$(date +%s)
145 | cost_time=$(( end_time - start_time ))
146 | echo "Running time : ${cost_time}"
--------------------------------------------------------------------------------
/scripts/run_tankbind/run_tankbind.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import shutil
5 | import logging
6 | import argparse
7 | import subprocess
8 | import numpy as np
9 | import pandas as pd
10 | from tqdm import tqdm
11 | from rdkit import Chem
12 | from Bio.PDB import PDBParser
13 | from torch_geometric.loader import DataLoader
14 | 
15 | torch.set_num_threads(1)
16 | 
17 | class TankBindRunner():
18 |     def __init__(self, args) -> None:
19 |         self.input_dir = args.input_dir
20 |         self.output_dir = args.output_dir
21 |         self.tankbind_exec_dir = args.tankbind_exec_dir
22 |         self.p2rank_exec_path = args.p2rank_exec_path
23 |         input_data_path = os.path.join(self.input_dir, "data.csv")
24 |         self.input_data = pd.read_csv(input_data_path)
25 |         self.pdb_list = self.input_data.PDB_CCD_ID.values
26 |         self.rdkit_folder = f"{self.output_dir}/rdkit"
27 |         os.makedirs(self.rdkit_folder, exist_ok=True)
28 | 
29 |     def _create_data(self, pockets_dict, protein_dict):
30 |         info = []
31 |         for pdb in self.pdb_list:
32 |             protein_name = pdb
33 |             compound_name = pdb
34 |             pocket = pockets_dict[pdb].head(10)
35 |             pocket.columns = pocket.columns.str.strip()
36 |             pocket_coms = pocket[['center_x', 'center_y', 'center_z']].values
37 |             # native block.
38 |             info.append([protein_name, compound_name, pdb, None, 0, True, False])
39 |             # protein center as a block.
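            # (Every row appended here must match the seven columns declared
            # for the DataFrame below - protein_name, compound_name, pdb,
            # pocket_com, affinity, use_compound_com, use_whole_protein - with
            # 0 as an affinity placeholder that predict() later overwrites.)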
40 |             protein_com = protein_dict[protein_name][0].numpy().mean(axis=0).astype(float).reshape(1, 3)
41 |             info.append([protein_name, compound_name, pdb+"_c", protein_com, 0, False, False])
42 |             for idx, pocket_line in pocket.iterrows():
43 |                 pdb_idx = f"{pdb}_{idx}"
44 |                 info.append([protein_name, compound_name, pdb_idx, pocket_coms[idx].reshape(1, 3), 0, False, False])
45 |         info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'pdb', 'pocket_com', 'affinity',
46 |                                            'use_compound_com', 'use_whole_protein'])
47 |         return info
48 | 
49 |     def _predict_protein_feature(self):
50 |         protein_dict = {}
51 |         for pdb in self.pdb_list:
52 |             proteinFile = f"{self.input_dir}/protein_remove_extra_chains_10A/{pdb}_protein.pdb"
53 |             parser = PDBParser(QUIET=True)
54 |             s = parser.get_structure(pdb, proteinFile)
55 |             res_list = get_clean_res_list(s.get_residues(), verbose=False, ensure_ca_exist=True)
56 |             protein_dict[pdb] = get_protein_feature(res_list)
57 |         return protein_dict
58 | 
59 |     def _predict_ligand_feature(self):
60 |         compound_dict = {}
61 |         for _, row in self.input_data.iterrows():
62 |             pdb = row["PDB_CCD_ID"]
63 |             sdf_path = row["LIGAND_SDF_PATH"]
64 |             mol = Chem.SDMolSupplier(sdf_path)[0]
65 |             smiles = Chem.MolToSmiles(mol)
66 |             rdkit_mol_path = f"{self.rdkit_folder}/{pdb}_ligand.sdf"
67 |             generate_sdf_from_smiles_using_rdkit(smiles, rdkit_mol_path, shift_dis=0)
68 |             mol = Chem.SDMolSupplier(rdkit_mol_path)[0]
69 |             compound_dict[pdb] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)  # self-dock set has_LAS_mask to true
70 |         return compound_dict
71 | 
72 |     def _predict_pockets(self):
73 |         # predict pockets by p2rank
74 |         p2rank_prediction_folder = f"{self.input_dir}/p2rank_protein_remove_extra_chains_10A"
75 |         os.system(f"mkdir -p {p2rank_prediction_folder}")
76 |         ds = f"{p2rank_prediction_folder}/protein_list.ds"
77 |         with open(ds, "w") as out:
78 |             for pdb in self.pdb_list:
79 |                 out.write(f"../protein_remove_extra_chains_10A/{pdb}_protein.pdb\n")
80 |         cmd = ["bash", self.p2rank_exec_path, "predict", ds, "-o", f"{p2rank_prediction_folder}/p2rank", "-threads", "8"]
81 |         subprocess.run(cmd, check=True)
82 | 
83 |         # handle predictions
84 |         d_list = []
85 |         for name in self.pdb_list:
86 |             p2rankFile = f"{self.input_dir}/p2rank_protein_remove_extra_chains_10A/p2rank/{name}_protein.pdb_predictions.csv"
87 |             d = pd.read_csv(p2rankFile)
88 |             d.columns = d.columns.str.strip()
89 |             d_list.append(d.assign(name=name))
90 |         d = pd.concat(d_list).reset_index(drop=True)
91 |         d.reset_index(drop=True).to_feather(f"{self.input_dir}/p2rank_result.feather")
92 |         d = pd.read_feather(f"{self.input_dir}/p2rank_result.feather")
93 | 
94 |         pockets_dict = {}
95 |         for name in self.pdb_list:
96 |             pockets_dict[name] = d[d.name == name].reset_index(drop=True)
97 |         return pockets_dict
98 | 
99 |     def _save_conformation(self, dataset, chosen, y_pred_list):
100 |         device = "cpu"
101 |         for _, line in tqdm(chosen.iterrows(), total=chosen.shape[0]):
102 |             name = line['compound_name']
103 |             dataset_index = line['dataset_index']
104 |             coords = dataset[dataset_index].coords.to(device)
105 |             protein_nodes_xyz = dataset[dataset_index].node_xyz.to(device)
106 |             n_compound = coords.shape[0]
107 |             n_protein = protein_nodes_xyz.shape[0]
108 |             y_pred = y_pred_list[dataset_index].reshape(n_protein, n_compound).to(device)
109 |             y = dataset[dataset_index].dis_map.reshape(n_protein, n_compound).to(device)
110 |             compound_pair_dis_constraint = torch.cdist(coords, coords)
111 |             rdkit_mol_path = f"{self.rdkit_folder}/{name}_ligand.sdf"
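            # Reload the RDKit conformer, then fit its coordinates to the
            # predicted protein-ligand distance map under the LAS distance
            # constraints, keeping the lowest-loss pose for writing out.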
112 |             mol = Chem.SDMolSupplier(rdkit_mol_path)[0]
113 |             LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
114 |             pred_dist_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
115 |                                                     LAS_distance_constraint_mask=LAS_distance_constraint_mask,
116 |                                                     n_repeat=1, show_progress=False)
117 |             toFile = f'{self.output_dir}/{name}_tankbind_chosen.sdf'
118 |             new_coords = pred_dist_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
119 |             write_with_new_coords(mol, new_coords, toFile)
120 | 
121 |     def process_testset(self):
122 |         toFolder = f"{self.input_dir}/protein_remove_extra_chains_10A/"
123 |         os.makedirs(toFolder, exist_ok=True)
124 |         for _, row in self.input_data.iterrows():
125 |             cutoff = 10
126 |             itemname = row["PDB_CCD_ID"]
127 |             toFile = f"{toFolder}/{itemname}_protein.pdb"
128 |             shutil.copy(row["PROTEIN_PDB_PATH"], toFile)
129 |             # x = (row["PROTEIN_PDB_PATH"], row["LIGAND_SDF_PATH"], cutoff, toFile)
130 |             # select_chain_within_cutoff_to_ligand_v2(x)
131 |         pockets_dict = self._predict_pockets()
132 |         protein_dict = self._predict_protein_feature()
133 |         compound_dict = self._predict_ligand_feature()
134 |         data = self._create_data(pockets_dict, protein_dict)
135 |         dataset_dir = f"{self.input_dir}/dataset"
136 |         if os.path.exists(dataset_dir):
137 |             shutil.rmtree(dataset_dir)
138 |         os.makedirs(dataset_dir)
139 |         testset = TankBindDataSet(dataset_dir, data=data, protein_dict=protein_dict, compound_dict=compound_dict)  # build and cache the dataset on disk
140 |         testset = TankBindDataSet(dataset_dir, proteinMode=0, compoundMode=1, pocket_radius=20, predDis=True)  # reload it with inference-time settings
141 |         return testset
142 | 
143 |     def predict(self, dataset, device="cpu"):
144 |         data_loader = DataLoader(dataset, batch_size=1,
145 |                                  follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=8, pin_memory=True)
146 |         logging.basicConfig(level=logging.INFO)
147 |         model = get_model(0, logging, device)
148 |         model.eval()
149 |         model.load_state_dict(torch.load(f"{self.tankbind_exec_dir}/saved_models/self_dock.pt", map_location=device))
150 |         y_pred_list, affinity_pred_list = [], []
151 |         for data in tqdm(data_loader):
152 |             data = data.to(device)
153 |             with torch.no_grad():
154 |                 y_pred, affinity_pred = model(data)
155 |             affinity_pred_list.append(affinity_pred.detach().cpu())
156 |             for i in range(data.y_batch.max() + 1):
157 |                 y_pred_list.append((y_pred[data['y_batch'] == i]).detach().cpu())
158 |         affinity_pred_list = torch.cat(affinity_pred_list)
159 |         output_info_chosen = dataset.data
160 |         output_info_chosen['affinity'] = affinity_pred_list
161 |         output_info_chosen['dataset_index'] = range(len(output_info_chosen))
162 |         output_info_chosen = output_info_chosen.query("not use_compound_com").reset_index(drop=True)
163 |         chosen = output_info_chosen.loc[output_info_chosen.groupby(['protein_name', 'compound_name'], sort=False)['affinity'].agg('idxmax')].reset_index()
164 |         self._save_conformation(dataset, chosen, y_pred_list)
165 | 
166 |     def run(self):
167 |         testset = self.process_testset()
168 |         self.predict(testset)
169 | 
170 | if __name__ == "__main__":
171 |     parser = argparse.ArgumentParser()
172 |     parser.add_argument("--tankbind_exec_dir", type=str, required=True, help="Path to the TankBind codebase")
173 |     parser.add_argument("--p2rank_exec_path", type=str, required=True, help="Path to the p2rank executable")
174 |     parser.add_argument("--input_dir", type=str, required=True, help="Path to the input dir")
175 |     parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output")
output") 176 | parser.add_argument("--gpu_id", type=int, required=True, help="GPU ID") 177 | args = parser.parse_args() 178 | 179 | tankbind_module_dir = os.path.join(args.tankbind_exec_dir, "tankbind") 180 | sys.path.insert(0, tankbind_module_dir) 181 | from data import TankBindDataSet 182 | from model import get_model 183 | from feature_utils import select_chain_within_cutoff_to_ligand_v2, get_protein_feature, \ 184 | get_clean_res_list, extract_torchdrug_feature_from_mol, generate_sdf_from_smiles_using_rdkit 185 | from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords 186 | 187 | tankbind_runner = TankBindRunner(args) 188 | tankbind_runner.run() 189 | -------------------------------------------------------------------------------- /scripts/run_tankbind/run_tankbind.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # Check if dataset is provided as argument 5 | if [ $# -eq 0 ]; then 6 | echo "Error: Please provide dataset as argument" 7 | echo "Usage: $0 " 8 | exit 1 9 | fi 10 | 11 | DATASET="$1" 12 | TANKBIND_EXEC_FOLDER="path/to/TankBind" 13 | TANKBIND_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/tankbind/input" 14 | TANKBIND_OUTPUT_FOLDER="${PWD}data/benchmark/${DATASET}/tankbind/output" 15 | GPU_ID=0 16 | 17 | # init conda 18 | eval "$(conda shell.bash hook)" 19 | conda activate tankbind_py38 20 | 21 | p2rank_exec_path="${TANKBIND_EXEC_FOLDER}/package/p2rank/prank" 22 | python scripts/run_tankbind/run_tankbind.py \ 23 | --tankbind_exec_dir ${TANKBIND_EXEC_FOLDER} \ 24 | --p2rank_exec_path ${p2rank_exec_path} \ 25 | --input_dir ${TANKBIND_INPUT_FOLDER} \ 26 | --output_dir ${TANKBIND_OUTPUT_FOLDER} \ 27 | --gpu_id ${GPU_ID} 28 | -------------------------------------------------------------------------------- /scripts/run_unimol/run_unimol.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | 6 | def main(args: argparse.Namespace): 7 | os.chdir(os.path.join(args.unimol_exec_dir, "interface")) 8 | for pdb_ccd in os.listdir(args.input_dir): 9 | input_protein = os.path.join(args.input_dir, pdb_ccd, f"{pdb_ccd}_protein.pdb") 10 | input_docking_grid = os.path.join(args.input_dir, pdb_ccd, f"{pdb_ccd}.json") 11 | input_ligand = os.path.join(args.input_dir, pdb_ccd, f"{pdb_ccd}_ligand_start_conf.sdf") 12 | output_dir = os.path.join(args.output_dir, pdb_ccd) 13 | os.makedirs(output_dir, exist_ok=True) 14 | cmd = f"python demo.py --mode single --conf-size 10 --cluster \ 15 | --input-protein {input_protein} \ 16 | --input-ligand {input_ligand} \ 17 | --input-docking-grid {input_docking_grid} \ 18 | --output-ligand-name ligand_predict \ 19 | --output-ligand-dir {output_dir} \ 20 | --steric-clash-fix \ 21 | --model-dir {args.ckpt_path}" 22 | os.system(cmd) 23 | 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--input_dir", type=str, required=True, help="Path to the input files") 28 | parser.add_argument("--output_dir", type=str, required=True, help="Path to save the output") 29 | parser.add_argument("--unimol_exec_dir", type=str, required=True, help="Path to the Unimol codebase") 30 | parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the model chekpoint") 31 | args = parser.parse_args() 32 | 33 | main(args) 34 | -------------------------------------------------------------------------------- /scripts/run_unimol/run_unimol.sh: 
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | 
3 | 
4 | # Check if dataset is provided as argument
5 | if [ $# -eq 0 ]; then
6 |     echo "Error: Please provide dataset as argument"
7 |     echo "Usage: $0 <dataset>"
8 |     exit 1
9 | fi
10 | 
11 | DATASET="$1"
12 | UNIMOL_EXEC_FOLDER="path/to/unimol"
13 | CKPT_PATH="${UNIMOL_EXEC_FOLDER}/ckpt/unimol_docking_v2_240517.pt"
14 | UNIMOL_INPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/unimol/input"
15 | UNIMOL_OUTPUT_FOLDER="${PWD}/data/benchmark/${DATASET}/unimol/output"
16 | 
17 | # init conda
18 | eval "$(conda shell.bash hook)"
19 | conda activate unicore
20 | 
21 | export MKL_SERVICE_FORCE_INTEL=1
22 | export CUDA_VISIBLE_DEVICES=1
23 | 
24 | 
25 | start_time=$(date +%s)
26 | python scripts/run_unimol/run_unimol.py \
27 |     --input_dir ${UNIMOL_INPUT_FOLDER} \
28 |     --output_dir ${UNIMOL_OUTPUT_FOLDER} \
29 |     --unimol_exec_dir ${UNIMOL_EXEC_FOLDER} \
30 |     --ckpt_path ${CKPT_PATH}
31 | end_time=$(date +%s)
32 | cost_time=$(( end_time - start_time ))
33 | echo "Running time for ${DATASET}: ${cost_time}"
--------------------------------------------------------------------------------
/tests/s1_download_mmcif.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import wget
3 | 
4 | in_dir = "posex/posex_self_docking_set"
5 | out_dir = "posex/mmcif_raw"
6 | 
7 | 
8 | def download_one(pdb_id: str, save_fn: Path):
9 |     url = f"https://files.rcsb.org/download/{pdb_id}.cif"  # replace with the actual URL if needed
10 |     # download the file
11 |     wget.download(url, str(save_fn))
12 | 
13 | 
14 | def run():
15 |     input_fns = list(Path(in_dir).glob("*/*.json"))
16 |     cur_dir = Path(out_dir)
17 |     cur_dir.mkdir(exist_ok=True, parents=True)
18 |     for input_fn in input_fns[:10]:
19 |         item_name = input_fn.parent.stem
20 |         pdb_id = item_name.split("_")[0].lower()
21 |         save_fn = cur_dir / f"{item_name}.cif"
22 |         download_one(pdb_id, Path(save_fn))
23 |         print(input_fn)
24 | 
25 | 
26 | run()
27 | 
--------------------------------------------------------------------------------