├── .gitignore ├── LICENSE ├── README.md ├── citation.bib ├── environment.yml ├── notebooks ├── data_usage.ipynb ├── data_usage.py ├── feature_generation.ipynb └── feature_generation.py ├── project ├── datasets │ ├── DB5 │ │ ├── README │ │ ├── db5_dgl_data_module.py │ │ └── db5_dgl_dataset.py │ ├── DIPS │ │ └── filters │ │ │ ├── buried_surface_over_500.txt │ │ │ ├── nmr_res_less_3.5.txt │ │ │ ├── seq_id_less_30.txt │ │ │ └── size_over_50.0.txt │ ├── EVCoupling │ │ ├── evcoupling_heterodimer_list.txt │ │ ├── final │ │ │ └── stub │ │ ├── group_files.py │ │ └── interim │ │ │ └── stub │ ├── analysis │ │ ├── analyze_experiment_types_and_resolution.py │ │ ├── analyze_feature_correlation.py │ │ └── analyze_interface_waters.py │ └── builder │ │ ├── add_new_feature.py │ │ ├── annotate_idr_residues.py │ │ ├── collect_dataset_statistics.py │ │ ├── compile_casp_capri_dataset_on_andes.sh │ │ ├── compile_db5_dataset_on_andes.sh │ │ ├── compile_dips_dataset_on_andes.sh │ │ ├── compile_evcoupling_dataset_on_andes.sh │ │ ├── convert_complexes_to_graphs.py │ │ ├── create_hdf5_dataset.py │ │ ├── download_missing_pruned_pair_pdbs.py │ │ ├── extract_raw_pdb_gz_archives.py │ │ ├── generate_hhsuite_features.py │ │ ├── generate_psaia_features.py │ │ ├── impute_missing_feature_values.py │ │ ├── launch_parallel_slurm_jobs_on_andes.sh │ │ ├── log_dataset_statistics.py │ │ ├── make_dataset.py │ │ ├── partition_dataset_filenames.py │ │ ├── postprocess_pruned_pairs.py │ │ ├── prune_pairs.py │ │ ├── psaia_chothia.radii │ │ ├── psaia_config_file_casp_capri.txt │ │ ├── psaia_config_file_db5.txt │ │ ├── psaia_config_file_dips.txt │ │ ├── psaia_config_file_evcoupling.txt │ │ ├── psaia_hydrophobicity.hpb │ │ └── psaia_natural_asa.asa └── utils │ ├── constants.py │ ├── modules.py │ ├── training_constants.py │ ├── training_utils.py │ └── utils.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .github 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | .conda/ 94 | miniconda3 95 | venv.tar.gz 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | # IDEs 111 | .idea 112 | .vscode 113 | 114 | # TensorBoard 115 | tb_logs/ 116 | 117 | # Feature Processing 118 | *idr_annotation*.txt 119 | *work_filenames*.csv 120 | 121 | # DIPS 122 | project/datasets/DIPS/complexes/** 123 | project/datasets/DIPS/interim/** 124 | project/datasets/DIPS/pairs/** 125 | project/datasets/DIPS/parsed/** 126 | project/datasets/DIPS/raw/** 127 | project/datasets/DIPS/final/raw/** 128 | project/datasets/DIPS/final/load_pair_example.py 129 | project/datasets/DIPS/final/final_raw_dips*.tar.gz* 130 | project/datasets/DIPS/final/processed/** 131 | 132 | # DB5 133 | project/datasets/DB5/processed/** 134 | project/datasets/DB5/raw/** 135 | project/datasets/DB5/interim/** 136 | project/datasets/DB5/final/raw/** 137 | project/datasets/DB5/final/final_raw_db5.tar.gz* 138 | project/datasets/DB5/final/processed/** 139 | 140 | # EVCoupling 141 | project/datasets/EVCoupling/raw/** 142 | project/datasets/EVCoupling/interim/** 143 | project/datasets/EVCoupling/final/raw/** 144 | project/datasets/EVCoupling/final/processed/** 145 | 146 | # CASP-CAPRI 147 | project/datasets/CASP-CAPRI/raw/** 148 | project/datasets/CASP-CAPRI/interim/** 149 | project/datasets/CASP-CAPRI/final/raw/** 150 | project/datasets/CASP-CAPRI/final/processed/** 151 | 152 | # Input 153 | project/datasets/Input/** -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # DIPS-Plus 4 | 5 | The Enhanced Database of Interacting Protein Structures for Interface Prediction 6 | 7 | [![Paper](http://img.shields.io/badge/paper-arxiv.2106.04362-B31B1B.svg)](https://www.nature.com/articles/s41597-023-02409-3) [![CC BY 4.0][cc-by-shield]][cc-by] [![Primary Data DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5134732.svg)](https://doi.org/10.5281/zenodo.5134732) [![Supplementary Data DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.8140981.svg)](https://doi.org/10.5281/zenodo.8140981) 8 | 9 | [cc-by]: http://creativecommons.org/licenses/by/4.0/ 10 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png 11 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg 12 | 13 | [comment]: <> ([![Conference](http://img.shields.io/badge/NeurIPS-2021-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-35-2021)) 14 | 15 | [](https://pypi.org/project/DIPS-Plus/) 16 | 17 |
18 | 19 | ## Versioning 20 | 21 | * Version 1.0.0: Initial release of DIPS-Plus and DB5-Plus (DOI: 10.5281/zenodo.4815267) 22 | * Version 1.1.0: Minor updates to DIPS-Plus and DB5-Plus' tar archives (DOI: 10.5281/zenodo.5134732) 23 | * DIPS-Plus' final 'raw' tar archive now includes standardized 80%-20% lists of filenames for training and validation, respectively 24 | * DB5-Plus' final 'raw' tar archive now includes (optional) standardized lists of filenames for training and validation, respectively 25 | * DB5-Plus' final 'raw' tar archive now also includes a corrected (i.e. de-duplicated) list of filenames for its 55 test complexes 26 | * Benchmark results included in our paper were run after this issue was resolved 27 | * However, if you ran experiments using DB5-Plus' filename list for its test complexes, please re-run them using the latest list 28 | * Version 1.2.0: Minor additions to DIPS-Plus tar archives, including new residue-level intrinsic disorder region annotations and raw Jackhmmer-small BFD MSAs (Supplementary Data DOI: 10.5281/zenodo.8071136) 29 | * Version 1.3.0: Minor additions to DIPS-Plus tar archives, including new FoldSeek-based structure-focused training and validation splits, residue-level (scalar) disorder propensities, and a Graphein-based featurization pipeline (Supplementary Data DOI: 10.5281/zenodo.8140981) 30 | 31 | ## How to set up 32 | 33 | First, download Mamba (if not already downloaded): 34 | ```bash 35 | wget "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" 36 | bash Mambaforge-$(uname)-$(uname -m).sh # Accept all terms and install to the default location 37 | rm Mambaforge-$(uname)-$(uname -m).sh # (Optionally) Remove installer after using it 38 | source ~/.bashrc # Alternatively, one can restart their shell session to achieve the same result 39 | ``` 40 | 41 | Then, create and configure Mamba environment: 42 | 43 | ```bash 44 | # Clone project: 45 | git clone https://github.com/BioinfoMachineLearning/DIPS-Plus 46 | cd DIPS-Plus 47 | 48 | # Create Conda environment using local 'environment.yml' file: 49 | mamba env create -f environment.yml 50 | conda activate DIPS-Plus # Note: One still needs to use `conda` to (de)activate environments 51 | 52 | # Install local project as package: 53 | pip3 install -e . 
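# (Optional) Sanity-check the new environment; this snippet is illustrative and assumes
# the `torch` and `dgl` packages from `environment.yml` installed successfully:
python3 -c "import torch, dgl; print('DIPS-Plus environment is ready')"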
54 | ``` 55 | 56 | To install PSAIA for feature generation, install GCC 10 for PSAIA: 57 | 58 | ```bash 59 | # Install GCC 10 for Ubuntu 20.04: 60 | sudo apt install software-properties-common 61 | sudo add-apt-repository ppa:ubuntu-toolchain-r/ppa 62 | sudo apt update 63 | sudo apt install gcc-10 g++-10 64 | 65 | # Or install GCC 10 for Arch Linux/Manjaro: 66 | yay -S gcc10 67 | ``` 68 | 69 | Then install QT4 for PSAIA: 70 | 71 | ```bash 72 | # Install QT4 for Ubuntu 20.04: 73 | sudo add-apt-repository ppa:rock-core/qt4 74 | sudo apt update 75 | sudo apt install libqt4* libqtcore4 libqtgui4 libqtwebkit4 qt4* libxext-dev 76 | 77 | # Or install QT4 for Arch Linux/Manjaro: 78 | yay -S qt4 79 | ``` 80 | 81 | Conclude by compiling PSAIA from source: 82 | 83 | ```bash 84 | # Select the location to install the software: 85 | MY_LOCAL=~/Programs 86 | 87 | # Download and extract PSAIA's source code: 88 | mkdir "$MY_LOCAL" 89 | cd "$MY_LOCAL" 90 | wget http://complex.zesoi.fer.hr/data/PSAIA-1.0-source.tar.gz 91 | tar -xvzf PSAIA-1.0-source.tar.gz 92 | 93 | # Compile PSAIA (i.e., a GUI for PSA): 94 | cd PSAIA_1.0_source/make/linux/psaia/ 95 | qmake-qt4 psaia.pro 96 | make 97 | 98 | # Compile PSA (i.e., the protein structure analysis (PSA) program): 99 | cd ../psa/ 100 | qmake-qt4 psa.pro 101 | make 102 | 103 | # Compile PIA (i.e., the protein interaction analysis (PIA) program): 104 | cd ../pia/ 105 | qmake-qt4 pia.pro 106 | make 107 | 108 | # Test run any of the above-compiled programs: 109 | cd "$MY_LOCAL"/PSAIA_1.0_source/bin/linux 110 | # Test run PSA inside a GUI: 111 | ./psaia/psaia 112 | # Test run PIA through a terminal: 113 | ./pia/pia 114 | # Test run PSA through a terminal: 115 | ./psa/psa 116 | ``` 117 | 118 | Lastly, install Docker following the instructions from https://docs.docker.com/engine/install/ 119 | 120 | ## How to generate protein feature inputs 121 | In our [feature generation notebook](notebooks/feature_generation.ipynb), we provide examples of how users can generate the protein features described in our [accompanying manuscript](https://arxiv.org/abs/2106.04362) for individual protein inputs. 122 | 123 | ## How to use data 124 | In our [data usage notebook](notebooks/data_usage.ipynb), we provide examples of how users might use DIPS-Plus (or DB5-Plus) for downstream analysis or prediction tasks. For example, to train a new NeiA model with DB5-Plus as its cross-validation dataset, first download DB5-Plus' raw files and process them via the `data_usage` notebook: 125 | 126 | ```bash 127 | mkdir -p project/datasets/DB5/final 128 | wget https://zenodo.org/record/5134732/files/final_raw_db5.tar.gz -O project/datasets/DB5/final/final_raw_db5.tar.gz 129 | tar -xzf project/datasets/DB5/final/final_raw_db5.tar.gz -C project/datasets/DB5/final/ 130 | 131 | # To process these raw files for training and subsequently train a model: 132 | python3 notebooks/data_usage.py 133 | ``` 134 | 135 | ## How to split data using FoldSeek 136 | We provide users with the [ability](https://github.com/BioinfoMachineLearning/DIPS-Plus/blob/75775d98f0923faf11fb50639eb58cd510a10ffd/project/datasets/builder/partition_dataset_filenames.py#L486) to perform structure-based splits of the complexes in DIPS-Plus using FoldSeek. This script is designed to allow users to customize how stringent one would like FoldSeek's searches to be for structure-based splitting. 
Moreover, we provide standardized structure-based splits of DIPS-Plus' complexes in the corresponding [supplementary Zenodo data record](https://zenodo.org/record/8140981). 137 | 138 | ## How to featurize DIPS-Plus complexes using Graphein 139 | In the new [graph featurization script](https://github.com/BioinfoMachineLearning/DIPS-Plus/blob/main/project/datasets/builder/add_new_feature.py), we provide an example of how users may install new Expasy protein scale features using the Graphein library. The script is designed to be amenable to simple user customization such that users can use this script to insert arbitrary new Graphein-based features into each DIPS-Plus complex's pair file, for downstream tasks. 140 | 141 | ## Standard DIPS-Plus directory structure 142 | 143 | ``` 144 | DIPS-Plus 145 | │ 146 | └───project 147 | │ 148 | └───datasets 149 | │ 150 | └───DB5 151 | │ │ 152 | │ └───final 153 | │ │ │ 154 | │ │ └───processed # task-ready features for each dataset example 155 | │ │ │ 156 | │ │ └───raw # generic features for each dataset example 157 | │ │ 158 | │ └───interim 159 | │ │ │ 160 | │ │ └───complexes # metadata for each dataset example 161 | │ │ │ 162 | │ │ └───external_feats # features curated for each dataset example using external tools 163 | │ │ │ 164 | │ │ └───pairs # pair-wise features for each dataset example 165 | │ │ 166 | │ └───raw # raw PDB data downloads for each dataset example 167 | │ 168 | └───DIPS 169 | │ 170 | └───filters # filters to apply to each (un-pruned) dataset example 171 | │ 172 | └───final 173 | │ │ 174 | │ └───processed # task-ready features for each dataset example 175 | │ │ 176 | │ └───raw # generic features for each dataset example 177 | │ 178 | └───interim 179 | │ │ 180 | │ └───complexes # metadata for each dataset example 181 | │ │ 182 | │ └───external_feats # features curated for each dataset example using external tools 183 | │ │ 184 | │ └───pairs-pruned # filtered pair-wise features for each dataset example 185 | │ │ 186 | │ └───parsed # pair-wise features for each dataset example after initial parsing 187 | │ 188 | └───raw 189 | │ 190 | └───pdb # raw PDB data downloads for each dataset example 191 | ``` 192 | 193 | ## How to compile DIPS-Plus from scratch 194 | 195 | Retrieve protein complexes from the RCSB PDB and build out directory structure: 196 | 197 | ```bash 198 | # Remove all existing training/testing sample lists 199 | rm project/datasets/DIPS/final/raw/pairs-postprocessed.txt project/datasets/DIPS/final/raw/pairs-postprocessed-train.txt project/datasets/DIPS/final/raw/pairs-postprocessed-val.txt project/datasets/DIPS/final/raw/pairs-postprocessed-test.txt 200 | 201 | # Create data directories (if not already created): 202 | mkdir project/datasets/DIPS/raw project/datasets/DIPS/raw/pdb project/datasets/DIPS/interim project/datasets/DIPS/interim/pairs-pruned project/datasets/DIPS/interim/external_feats project/datasets/DIPS/final project/datasets/DIPS/final/raw project/datasets/DIPS/final/processed 203 | 204 | # Download the raw PDB files: 205 | rsync -rlpt -v -z --delete --port=33444 --include='*.gz' --include='*.xz' --include='*/' --exclude '*' \ 206 | rsync.rcsb.org::ftp_data/biounit/coordinates/divided/ project/datasets/DIPS/raw/pdb 207 | 208 | # Extract the raw PDB files: 209 | python3 project/datasets/builder/extract_raw_pdb_gz_archives.py project/datasets/DIPS/raw/pdb 210 | 211 | # Process the raw PDB data into associated pair files: 212 | python3 project/datasets/builder/make_dataset.py project/datasets/DIPS/raw/pdb 
project/datasets/DIPS/interim --num_cpus 28 --source_type rcsb --bound 213 | 214 | # Apply additional filtering criteria: 215 | python3 project/datasets/builder/prune_pairs.py project/datasets/DIPS/interim/pairs project/datasets/DIPS/filters project/datasets/DIPS/interim/pairs-pruned --num_cpus 28 216 | 217 | # Generate externally-sourced features: 218 | python3 project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_dips.txt "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats --source_type rcsb 219 | python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB" "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --write_file # Note: After this, one needs to re-run this command with `--read_file` instead 220 | 221 | # Generate multiple sequence alignments (MSAs) using a smaller sequence database (if not already created using the standard BFD): 222 | DOWNLOAD_DIR="$HHSUITE_DB_DIR" && ROOT_DIR="${DOWNLOAD_DIR}/small_bfd" && SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz" && BASENAME=$(basename "${SOURCE_URL}") && mkdir --parents "${ROOT_DIR}" && aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" && pushd "${ROOT_DIR}" && gunzip "${ROOT_DIR}/${BASENAME}" && popd # e.g., Download the small BFD 223 | python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB_DIR"/small_bfd "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --generate_msa_only --write_file # Note: After this, one needs to re-run this command with `--read_file` instead 224 | 225 | # Identify interfaces within intrinsically disordered regions (IDRs) # 226 | # (1) Pull down the Docker image for `flDPnn` 227 | docker pull docker.io/sinaghadermarzi/fldpnn 228 | # (2) For all sequences in the dataset, predict which interface residues reside within IDRs 229 | python3 project/datasets/builder/annotate_idr_residues.py "$PROJDIR"/project/datasets/DIPS/final/raw --num_cpus 16 230 | 231 | # Add new features to the filtered pairs, ensuring that the pruned pairs' original PDB files are stored locally for DSSP: 232 | python3 project/datasets/builder/download_missing_pruned_pair_pdbs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2" 233 | python3 project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats "$PROJDIR"/project/datasets/DIPS/final/raw --num_cpus 32 234 | 235 | # Partition dataset filenames, aggregate statistics, and impute missing features 236 | python3 project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DIPS/final/raw --source_type rcsb --filter_by_atom_count True --max_atom_count 17500 --rank "$1" --size "$2" 237 | python3 project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank
"$1" --size "$2" 238 | python3 project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2" 239 | python3 project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DIPS/final/raw --impute_atom_features False --advanced_logging False --num_cpus 32 --rank "$1" --size "$2" 240 | 241 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels 242 | python3 project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DIPS/final/raw "$PROJDIR"/project/datasets/DIPS/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2" 243 | ``` 244 | 245 | ## How to assemble DB5-Plus 246 | 247 | Fetch prepared protein complexes from Dataverse: 248 | 249 | ```bash 250 | # Download the prepared DB5 files: 251 | wget -O project/datasets/DB5.tar.gz https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/H93ZKK/BXXQCG 252 | 253 | # Extract downloaded DB5 archive: 254 | tar -xzf project/datasets/DB5.tar.gz --directory project/datasets/ 255 | 256 | # Remove (now) redundant DB5 archive and other miscellaneous files: 257 | rm project/datasets/DB5.tar.gz project/datasets/DB5/.README.swp 258 | rm -rf project/datasets/DB5/interim project/datasets/DB5/processed 259 | 260 | # Create relevant interim and final data directories: 261 | mkdir project/datasets/DB5/interim project/datasets/DB5/interim/external_feats 262 | mkdir project/datasets/DB5/final project/datasets/DB5/final/raw project/datasets/DB5/final/processed 263 | 264 | # Construct DB5 dataset pairs: 265 | python3 project/datasets/builder/make_dataset.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim --num_cpus 32 --source_type db5 --unbound 266 | 267 | # Generate externally-sourced features: 268 | python3 project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_db5.txt "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/external_feats --source_type db5 269 | python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/DB5/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type db5 --write_file 270 | 271 | # Add new features to the filtered pairs: 272 | python3 project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/pairs "$PROJDIR"/project/datasets/DB5/interim/external_feats "$PROJDIR"/project/datasets/DB5/final/raw --num_cpus 32 --source_type db5 273 | 274 | # Partition dataset filenames, aggregate statistics, and impute missing features 275 | python3 project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DB5/final/raw --source_type db5 --rank "$1" --size "$2" 276 | python3 project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2" 277 | python3 project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2" 278 | python3 project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DB5/final/raw 
--impute_atom_features False --advanced_logging False --num_cpus 32 --rank "$1" --size "$2" 279 | 280 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels 281 | python3 project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DB5/final/raw "$PROJDIR"/project/datasets/DB5/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2" 282 | ``` 283 | 284 | ## How to reassemble DIPS-Plus' "interim" external features 285 | 286 | We split the (tar.gz) archive into eight separate parts with 287 | 'split -b 4096M interim_external_feats_dips.tar.gz "interim_external_feats_dips.tar.gz.part"' 288 | to upload it to the dataset's primary Zenodo record, so to recover the original archive: 289 | 290 | ```bash 291 | # Reassemble external features archive with 'cat' 292 | cat interim_external_feats_dips.tar.gz.parta* >interim_external_feats_dips.tar.gz 293 | ``` 294 | 295 | ## Python 2 to 3 pickle file solution 296 | 297 | While using Python 3 in this project, you may encounter the following error if you try to postprocess '.dill' pruned 298 | pairs that were created using Python 2. 299 | 300 | ModuleNotFoundError: No module named 'dill.dill' 301 | 302 | 1. To resolve it, ensure that the 'dill' package's version is greater than 0.3.2. 303 | 2. If the problem persists, edit the pickle.py file corresponding to your Conda environment's Python 3 installation ( 304 | e.g. ~/DIPS-Plus/venv/lib/python3.8/pickle.py) and add the statement 305 | 306 | ```python 307 | if module == 'dill.dill': module = 'dill._dill' 308 | ``` 309 | 310 | to the end of the 311 | 312 | ```python 313 | if self.proto < 3 and self.fix_imports: 314 | ``` 315 | 316 | block in the Unpickler class' find_class() function 317 | (e.g. line 1577 of Python 3.8.5's pickle.py). 
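For reference, after applying the edit, the patched region of `find_class()` would look roughly like the sketch below. The surrounding mapping lines paraphrase CPython 3.8's `pickle.py` and may differ slightly between patch versions; only the final two lines are the addition described above.

```python
# Inside class Unpickler, method find_class(), of your environment's pickle.py
if self.proto < 3 and self.fix_imports:
    if (module, name) in _compat_pickle.NAME_MAPPING:
        module, name = _compat_pickle.NAME_MAPPING[(module, name)]
    elif module in _compat_pickle.IMPORT_MAPPING:
        module = _compat_pickle.IMPORT_MAPPING[module]
    # Added statement: remap Python 2-era 'dill.dill' pickles to the Python 3 module path
    if module == 'dill.dill': module = 'dill._dill'
```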
318 | 319 | ## Citation 320 | If you find DIPS-Plus useful in your research, please cite: 321 | 322 | ```bibtex 323 | @article{morehead2023dips, 324 | title={DIPS-Plus: The enhanced database of interacting protein structures for interface prediction}, 325 | author={Morehead, Alex and Chen, Chen and Sedova, Ada and Cheng, Jianlin}, 326 | journal={Scientific Data}, 327 | volume={10}, 328 | number={1}, 329 | pages={509}, 330 | year={2023}, 331 | publisher={Nature Publishing Group UK London} 332 | } 333 | ``` 334 | -------------------------------------------------------------------------------- /citation.bib: -------------------------------------------------------------------------------- 1 | @misc{morehead2021dipsplus, 2 | title={DIPS-Plus: The Enhanced Database of Interacting Protein Structures for Interface Prediction}, 3 | author={Alex Morehead and Chen Chen and Ada Sedova and Jianlin Cheng}, 4 | year={2021}, 5 | eprint={2106.04362}, 6 | archivePrefix={arXiv}, 7 | primaryClass={q-bio.QM} 8 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DIPS-Plus 2 | channels: 3 | - pytorch 4 | - salilab 5 | - dglteam/label/cu116 6 | - nvidia 7 | - bioconda 8 | - defaults 9 | - conda-forge 10 | dependencies: 11 | - _libgcc_mutex=0.1=conda_forge 12 | - _openmp_mutex=4.5=2_kmp_llvm 13 | - appdirs=1.4.4=pyhd3eb1b0_0 14 | - aria2=1.23.0=0 15 | - asttokens=2.2.1=pyhd8ed1ab_0 16 | - backcall=0.2.0=pyh9f0ad1d_0 17 | - backports=1.0=pyhd8ed1ab_3 18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 19 | - binutils_impl_linux-64=2.36.1=h193b22a_2 20 | - biopython=1.78=py38h7f8727e_0 21 | - blas=1.0=openblas 22 | - bottleneck=1.3.5=py38h7deecbd_0 23 | - brotlipy=0.7.0=py38h27cfd23_1003 24 | - bzip2=1.0.8=h7b6447c_0 25 | - c-ares=1.19.1=hd590300_0 26 | - ca-certificates=2023.5.7=hbcca054_0 27 | - certifi=2023.5.7=py38h06a4308_0 28 | - cffi=1.15.1=py38h5eee18b_3 29 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 30 | - colorama=0.4.6=pyhd8ed1ab_0 31 | - comm=0.1.3=pyhd8ed1ab_0 32 | - cryptography=39.0.0=py38h1724139_0 33 | - cuda=11.6.1=0 34 | - cuda-cccl=11.6.55=hf6102b2_0 35 | - cuda-command-line-tools=11.6.2=0 36 | - cuda-compiler=11.6.2=0 37 | - cuda-cudart=11.6.55=he381448_0 38 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 39 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 40 | - cuda-cupti=11.6.124=h86345e5_0 41 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 42 | - cuda-driver-dev=11.6.55=0 43 | - cuda-gdb=12.1.105=0 44 | - cuda-libraries=11.6.1=0 45 | - cuda-libraries-dev=11.6.1=0 46 | - cuda-memcheck=11.8.86=0 47 | - cuda-nsight=12.1.105=0 48 | - cuda-nsight-compute=12.1.1=0 49 | - cuda-nvcc=11.6.124=hbba6d2d_0 50 | - cuda-nvdisasm=12.1.105=0 51 | - cuda-nvml-dev=11.6.55=haa9ef22_0 52 | - cuda-nvprof=12.1.105=0 53 | - cuda-nvprune=11.6.124=he22ec0a_0 54 | - cuda-nvrtc=11.6.124=h020bade_0 55 | - cuda-nvrtc-dev=11.6.124=h249d397_0 56 | - cuda-nvtx=11.6.124=h0630a44_0 57 | - cuda-nvvp=12.1.105=0 58 | - cuda-runtime=11.6.1=0 59 | - cuda-samples=11.6.101=h8efea70_0 60 | - cuda-sanitizer-api=12.1.105=0 61 | - cuda-toolkit=11.6.1=0 62 | - cuda-tools=11.6.1=0 63 | - cuda-version=11.7=h67201e3_2 64 | - cuda-visual-tools=11.6.1=0 65 | - cudatoolkit=11.7.0=hd8887f6_10 66 | - cudnn=8.8.0.121=h0800d71_0 67 | - cycler=0.11.0=pyhd8ed1ab_0 68 | - debugpy=1.6.7=py38h8dc9893_0 69 | - decorator=5.1.1=pyhd8ed1ab_0 70 | - dgl=1.1.0.cu116=py38_0 71 | - dssp=3.0.0=h3fd9d12_4 72 | - executing=1.2.0=pyhd8ed1ab_0 73 | - 
ffmpeg=4.3=hf484d3e_0 74 | - foldseek=7.04e0ec8=pl5321hb365157_0 75 | - freetype=2.12.1=hca18f0e_1 76 | - gawk=5.1.0=h7f98852_0 77 | - gcc=10.3.0=he2824d0_10 78 | - gcc_impl_linux-64=10.3.0=hf2f2afa_16 79 | - gds-tools=1.6.1.9=0 80 | - gettext=0.21.1=h27087fc_0 81 | - gmp=6.2.1=h58526e2_0 82 | - gnutls=3.6.13=h85f3911_1 83 | - gxx=10.3.0=he2824d0_10 84 | - gxx_impl_linux-64=10.3.0=hf2f2afa_16 85 | - hhsuite=3.3.0=py38pl5321hcbe9525_8 86 | - hmmer=3.3.2=hdbdd923_4 87 | - icu=58.2=he6710b0_3 88 | - idna=3.4=py38h06a4308_0 89 | - importlib-metadata=6.6.0=pyha770c72_0 90 | - importlib_metadata=6.6.0=hd8ed1ab_0 91 | - ipykernel=6.23.1=pyh210e3f2_0 92 | - ipython=8.4.0=py38h578d9bd_0 93 | - jedi=0.18.2=pyhd8ed1ab_0 94 | - joblib=1.1.1=py38h06a4308_0 95 | - jpeg=9e=h0b41bf4_3 96 | - jupyter_client=8.2.0=pyhd8ed1ab_0 97 | - jupyter_core=4.12.0=py38h578d9bd_0 98 | - kernel-headers_linux-64=2.6.32=he073ed8_15 99 | - kiwisolver=1.4.4=py38h43d8883_1 100 | - lame=3.100=h166bdaf_1003 101 | - lcms2=2.15=hfd0df8a_0 102 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 103 | - lerc=4.0.0=h27087fc_0 104 | - libblas=3.9.0=16_linux64_openblas 105 | - libboost=1.73.0=h28710b8_12 106 | - libcblas=3.9.0=16_linux64_openblas 107 | - libcublas=11.9.2.110=h5e84587_0 108 | - libcublas-dev=11.9.2.110=h5c901ab_0 109 | - libcufft=10.7.1.112=hf425ae0_0 110 | - libcufft-dev=10.7.1.112=ha5ce4c0_0 111 | - libcufile=1.6.1.9=0 112 | - libcufile-dev=1.6.1.9=0 113 | - libcurand=10.3.2.106=0 114 | - libcurand-dev=10.3.2.106=0 115 | - libcusolver=11.3.4.124=h33c3c4e_0 116 | - libcusparse=11.7.2.124=h7538f96_0 117 | - libcusparse-dev=11.7.2.124=hbbe9722_0 118 | - libdeflate=1.17=h0b41bf4_0 119 | - libffi=3.4.4=h6a678d5_0 120 | - libgcc=7.2.0=h69d50b8_2 121 | - libgcc-devel_linux-64=10.3.0=he6cfe16_16 122 | - libgcc-ng=12.2.0=h65d4601_19 123 | - libgfortran-ng=12.2.0=h69a702a_19 124 | - libgfortran5=12.2.0=h337968e_19 125 | - libgomp=12.2.0=h65d4601_19 126 | - libiconv=1.17=h166bdaf_0 127 | - libidn2=2.3.4=h166bdaf_0 128 | - liblapack=3.9.0=16_linux64_openblas 129 | - libnpp=11.6.3.124=hd2722f0_0 130 | - libnpp-dev=11.6.3.124=h3c42840_0 131 | - libnsl=2.0.0=h5eee18b_0 132 | - libnvjpeg=11.6.2.124=hd473ad6_0 133 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0 134 | - libopenblas=0.3.21=h043d6bf_0 135 | - libpng=1.6.39=h753d276_0 136 | - libprotobuf=3.21.12=h3eb15da_0 137 | - libsanitizer=10.3.0=h26c7422_16 138 | - libsodium=1.0.18=h36c2ea0_1 139 | - libsqlite=3.42.0=h2797004_0 140 | - libssh2=1.10.0=haa6b8db_3 141 | - libstdcxx-devel_linux-64=10.3.0=he6cfe16_16 142 | - libstdcxx-ng=12.2.0=h46fd767_19 143 | - libtiff=4.5.0=h6adf6a1_2 144 | - libunistring=0.9.10=h7f98852_0 145 | - libuuid=2.38.1=h0b41bf4_0 146 | - libwebp-base=1.3.0=h0b41bf4_0 147 | - libxcb=1.13=h7f98852_1004 148 | - libxml2=2.9.9=h13577e0_2 149 | - libzlib=1.2.13=h166bdaf_4 150 | - llvm-openmp=16.0.4=h4dfa4b3_0 151 | - lz4-c=1.9.4=h6a678d5_0 152 | - magma=2.6.2=hc72dce7_0 153 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 154 | - mkl=2022.2.1=h84fe81f_16997 155 | - mpi=1.0=openmpi 156 | - msms=2.6.1=h516909a_0 157 | - nccl=2.15.5.1=h0800d71_0 158 | - ncurses=6.4=h6a678d5_0 159 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 160 | - nettle=3.6=he412f7d_0 161 | - networkx=3.1=pyhd8ed1ab_0 162 | - ninja=1.11.1=h924138e_0 163 | - nsight-compute=2023.1.1.4=0 164 | - numexpr=2.8.4=py38hd2a5715_1 165 | - openh264=2.1.1=h780b84a_0 166 | - openjpeg=2.5.0=hfec8fc6_2 167 | - openmpi=4.1.5=h414af15_101 168 | - openssl=1.1.1u=hd590300_0 169 | - packaging=23.0=py38h06a4308_0 170 | - pandas=1.5.3=py38h417a72b_0 171 | 
- parso=0.8.3=pyhd8ed1ab_0 172 | - perl=5.32.1=0_h5eee18b_perl5 173 | - pexpect=4.8.0=pyh1a96a4e_2 174 | - pickleshare=0.7.5=py_1003 175 | - pillow=9.4.0=py38hde6dc18_1 176 | - pip=23.1.2=py38h06a4308_0 177 | - pooch=1.4.0=pyhd3eb1b0_0 178 | - prompt-toolkit=3.0.38=pyha770c72_0 179 | - psutil=5.9.5=py38h1de0b5d_0 180 | - pthread-stubs=0.4=h36c2ea0_1001 181 | - ptyprocess=0.7.0=pyhd3deb0d_0 182 | - pure_eval=0.2.2=pyhd8ed1ab_0 183 | - pycparser=2.21=pyhd3eb1b0_0 184 | - pygments=2.15.1=pyhd8ed1ab_0 185 | - pyopenssl=23.1.1=pyhd8ed1ab_0 186 | - pysocks=1.7.1=py38h06a4308_0 187 | - python=3.8.16=he550d4f_1_cpython 188 | - python-dateutil=2.8.2=pyhd3eb1b0_0 189 | - python_abi=3.8=2_cp38 190 | - pytorch=1.13.1=cuda112py38hd94e077_200 191 | - pytorch-cuda=11.6=h867d48c_1 192 | - pytorch-mutex=1.0=cuda 193 | - pytz=2022.7=py38h06a4308_0 194 | - pyzmq=25.0.2=py38he24dcef_0 195 | - readline=8.2=h5eee18b_0 196 | - requests=2.29.0=py38h06a4308_0 197 | - scikit-learn=1.2.2=py38h6a678d5_0 198 | - scipy=1.10.1=py38h32ae08f_1 199 | - setuptools=67.8.0=py38h06a4308_0 200 | - six=1.16.0=pyhd3eb1b0_1 201 | - sleef=3.5.1=h9b69904_2 202 | - sqlite=3.41.2=h5eee18b_0 203 | - stack_data=0.6.2=pyhd8ed1ab_0 204 | - sysroot_linux-64=2.12=he073ed8_15 205 | - tbb=2021.7.0=h924138e_0 206 | - threadpoolctl=2.2.0=pyh0d69192_0 207 | - tk=8.6.12=h1ccaba5_0 208 | - torchaudio=0.13.1=py38_cu116 209 | - torchvision=0.14.1=py38_cu116 210 | - tornado=6.3.2=py38h01eb140_0 211 | - tqdm=4.65.0=py38hb070fc8_0 212 | - traitlets=5.9.0=pyhd8ed1ab_0 213 | - typing_extensions=4.6.0=pyha770c72_0 214 | - urllib3=1.26.15=py38h06a4308_0 215 | - wcwidth=0.2.6=pyhd8ed1ab_0 216 | - wheel=0.38.4=py38h06a4308_0 217 | - xorg-libxau=1.0.11=hd590300_0 218 | - xorg-libxdmcp=1.1.3=h7f98852_0 219 | - xz=5.2.6=h166bdaf_0 220 | - zeromq=4.3.4=h9c3ff4c_1 221 | - zipp=3.15.0=pyhd8ed1ab_0 222 | - zlib=1.2.13=h166bdaf_4 223 | - zstd=1.5.2=h3eb15da_6 224 | - pip: 225 | - absl-py==1.4.0 226 | - aiohttp==3.8.4 227 | - aiosignal==1.3.1 228 | - alabaster==0.7.13 229 | - async-timeout==4.0.2 230 | - git+https://github.com/amorehead/atom3.git@83987404ceed38a1f5a5abd517aa38128d0a4f2c 231 | - attrs==23.1.0 232 | - babel==2.12.1 233 | - beautifulsoup4==4.12.2 234 | - biopandas==0.5.0.dev0 235 | - bioservices==1.11.2 236 | - cachetools==5.3.1 237 | - cattrs==23.1.2 238 | - click==7.0 239 | - colorlog==6.7.0 240 | - configparser==5.3.0 241 | - contourpy==1.1.0 242 | - deepdiff==6.3.1 243 | - dill==0.3.3 244 | - docker-pycreds==0.4.0 245 | - docutils==0.17.1 246 | - easy-parallel-py3==0.1.6.4 247 | - easydev==0.12.1 248 | - exceptiongroup==1.1.2 249 | - fairscale==0.4.0 250 | - fonttools==4.40.0 251 | - frozenlist==1.3.3 252 | - fsspec==2023.5.0 253 | - future==0.18.3 254 | - gevent==22.10.2 255 | - gitdb==4.0.10 256 | - gitpython==3.1.31 257 | - google-auth==2.19.0 258 | - google-auth-oauthlib==1.0.0 259 | - git+https://github.com/a-r-j/graphein.git@371ce9a462b610529488e87a712484328a89de36 260 | - greenlet==2.0.2 261 | - grequests==0.7.0 262 | - grpcio==1.54.2 263 | - h5py==3.8.0 264 | - hickle==5.0.2 265 | - imagesize==1.4.1 266 | - importlib-resources==6.0.0 267 | - install==1.3.5 268 | - jaxtyping==0.2.19 269 | - jinja2==2.11.3 270 | - loguru==0.7.0 271 | - looseversion==1.1.2 272 | - lxml==4.9.3 273 | - markdown==3.4.3 274 | - markdown-it-py==3.0.0 275 | - markupsafe==1.1.1 276 | - matplotlib==3.7.2 277 | - mdurl==0.1.2 278 | - mmtf-python==1.1.3 279 | - mpi4py==3.0.3 280 | - msgpack==1.0.5 281 | - multidict==6.0.4 282 | - multipledispatch==1.0.0 283 | - 
multiprocess==0.70.11.1 284 | - numpy==1.23.5 285 | - oauthlib==3.2.2 286 | - ordered-set==4.1.0 287 | - pathos==0.2.7 288 | - pathtools==0.1.2 289 | - pdb-tools==2.5.0 290 | - platformdirs==3.8.1 291 | - plotly==5.15.0 292 | - pox==0.3.2 293 | - ppft==1.7.6.6 294 | - promise==2.3 295 | - protobuf==3.20.3 296 | - pyasn1==0.5.0 297 | - pyasn1-modules==0.3.0 298 | - pydantic==1.10.11 299 | - pydeprecate==0.3.1 300 | - pyparsing==3.0.9 301 | - pytorch-lightning==1.4.8 302 | - pyyaml==5.4.1 303 | - requests-cache==1.1.0 304 | - requests-oauthlib==1.3.1 305 | - rich==13.4.2 306 | - rich-click==1.6.1 307 | - rsa==4.9 308 | - seaborn==0.12.2 309 | - sentry-sdk==1.24.0 310 | - shortuuid==1.0.11 311 | - smmap==5.0.0 312 | - snowballstemmer==2.2.0 313 | - soupsieve==2.4.1 314 | - subprocess32==3.5.4 315 | - suds-community==1.1.2 316 | - tenacity==8.2.2 317 | - tensorboard==2.13.0 318 | - tensorboard-data-server==0.7.0 319 | - termcolor==2.3.0 320 | - torchmetrics==0.5.1 321 | - typeguard==4.0.0 322 | - url-normalize==1.4.3 323 | - wandb==0.12.2 324 | - werkzeug==2.3.6 325 | - wget==3.2 326 | - wrapt==1.15.0 327 | - xarray==2023.1.0 328 | - xmltodict==0.13.0 329 | - yarl==1.9.2 330 | - yaspin==2.3.0 331 | - zope-event==5.0 332 | - zope-interface==6.0 333 | -------------------------------------------------------------------------------- /notebooks/data_usage.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Example of data usage" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "### Neural network model training" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# -------------------------------------------------------------------------------------------------------------------------------------\n", 26 | "# Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch):\n", 27 | "# -------------------------------------------------------------------------------------------------------------------------------------\n", 28 | "\n", 29 | "import os\n", 30 | "import sys\n", 31 | "from pathlib import Path\n", 32 | "\n", 33 | "import pytorch_lightning as pl\n", 34 | "import torch.nn as nn\n", 35 | "from pytorch_lightning.plugins import DDPPlugin\n", 36 | "\n", 37 | "from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule\n", 38 | "from project.utils.modules import LitNeiA\n", 39 | "from project.utils.training_utils import collect_args, process_args, construct_pl_logger" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def main(args):\n", 49 | " # -----------\n", 50 | " # Data\n", 51 | " # -----------\n", 52 | " # Load Docking Benchmark 5 (DB5) data module\n", 53 | " db5_data_module = DB5DGLDataModule(data_dir=args.db5_data_dir,\n", 54 | " batch_size=args.batch_size,\n", 55 | " num_dataloader_workers=args.num_workers,\n", 56 | " knn=args.knn,\n", 57 | " self_loops=args.self_loops,\n", 58 | " percent_to_use=args.db5_percent_to_use,\n", 59 | " process_complexes=args.process_complexes,\n", 60 | " input_indep=args.input_indep)\n", 61 | " db5_data_module.setup()\n", 62 | "\n", 63 | " # ------------\n", 64 | " # Model\n", 65 | " # ------------\n", 66 | " # Assemble a 
dictionary of model arguments\n", 67 | " dict_args = vars(args)\n", 68 | " use_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB\n", 69 | "\n", 70 | " # Pick model and supply it with a dictionary of arguments\n", 71 | " if args.model_name.lower() == 'neiwa': # Neighborhood Weighted Average (NeiWA)\n", 72 | " model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,\n", 73 | " num_edge_input_feats=db5_data_module.db5_test.num_edge_features,\n", 74 | " gnn_activ_fn=nn.Tanh(),\n", 75 | " interact_activ_fn=nn.ReLU(),\n", 76 | " num_classes=db5_data_module.db5_test.num_classes,\n", 77 | " weighted_avg=True, # Use the neighborhood weighted average variant of NeiA\n", 78 | " num_gnn_layers=dict_args['num_gnn_layers'],\n", 79 | " num_interact_layers=dict_args['num_interact_layers'],\n", 80 | " num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],\n", 81 | " num_epochs=dict_args['num_epochs'],\n", 82 | " pn_ratio=dict_args['pn_ratio'],\n", 83 | " knn=dict_args['knn'],\n", 84 | " dropout_rate=dict_args['dropout_rate'],\n", 85 | " metric_to_track=dict_args['metric_to_track'],\n", 86 | " weight_decay=dict_args['weight_decay'],\n", 87 | " batch_size=dict_args['batch_size'],\n", 88 | " lr=dict_args['lr'],\n", 89 | " multi_gpu_backend=dict_args[\"accelerator\"])\n", 90 | " args.experiment_name = f'LitNeiWA-b{args.batch_size}-gl{args.num_gnn_layers}' \\\n", 91 | " f'-n{db5_data_module.db5_test.num_node_features}' \\\n", 92 | " f'-e{db5_data_module.db5_test.num_edge_features}' \\\n", 93 | " f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \\\n", 94 | " if not args.experiment_name \\\n", 95 | " else args.experiment_name\n", 96 | " template_ckpt_filename = 'LitNeiWA-{epoch:02d}-{val_ce:.2f}'\n", 97 | "\n", 98 | " else: # Default Model - Neighborhood Average (NeiA)\n", 99 | " model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,\n", 100 | " num_edge_input_feats=db5_data_module.db5_test.num_edge_features,\n", 101 | " gnn_activ_fn=nn.Tanh(),\n", 102 | " interact_activ_fn=nn.ReLU(),\n", 103 | " num_classes=db5_data_module.db5_test.num_classes,\n", 104 | " weighted_avg=False,\n", 105 | " num_gnn_layers=dict_args['num_gnn_layers'],\n", 106 | " num_interact_layers=dict_args['num_interact_layers'],\n", 107 | " num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],\n", 108 | " num_epochs=dict_args['num_epochs'],\n", 109 | " pn_ratio=dict_args['pn_ratio'],\n", 110 | " knn=dict_args['knn'],\n", 111 | " dropout_rate=dict_args['dropout_rate'],\n", 112 | " metric_to_track=dict_args['metric_to_track'],\n", 113 | " weight_decay=dict_args['weight_decay'],\n", 114 | " batch_size=dict_args['batch_size'],\n", 115 | " lr=dict_args['lr'],\n", 116 | " multi_gpu_backend=dict_args[\"accelerator\"])\n", 117 | " args.experiment_name = f'LitNeiA-b{args.batch_size}-gl{args.num_gnn_layers}' \\\n", 118 | " f'-n{db5_data_module.db5_test.num_node_features}' \\\n", 119 | " f'-e{db5_data_module.db5_test.num_edge_features}' \\\n", 120 | " f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \\\n", 121 | " if not args.experiment_name \\\n", 122 | " else args.experiment_name\n", 123 | " template_ckpt_filename = 'LitNeiA-{epoch:02d}-{val_ce:.2f}'\n", 124 | "\n", 125 | " # ------------\n", 126 | " # Checkpoint\n", 127 | " # ------------\n", 128 | " ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)\n", 129 | " ckpt_path_exists = os.path.exists(ckpt_path)\n", 130 | " 
ckpt_provided = args.ckpt_name != '' and ckpt_path_exists\n", 131 | " model = model.load_from_checkpoint(ckpt_path,\n", 132 | " use_wandb_logger=use_wandb_logger,\n", 133 | " batch_size=args.batch_size,\n", 134 | " lr=args.lr,\n", 135 | " weight_decay=args.weight_decay,\n", 136 | " dropout_rate=args.dropout_rate) if ckpt_provided else model\n", 137 | "\n", 138 | " # ------------\n", 139 | " # Trainer\n", 140 | " # ------------\n", 141 | " trainer = pl.Trainer.from_argparse_args(args)\n", 142 | "\n", 143 | " # -------------\n", 144 | " # Learning Rate\n", 145 | " # -------------\n", 146 | " if args.find_lr:\n", 147 | " lr_finder = trainer.tuner.lr_find(model, datamodule=db5_data_module) # Run learning rate finder\n", 148 | " fig = lr_finder.plot(suggest=True) # Plot learning rates\n", 149 | " fig.savefig('optimal_lr.pdf')\n", 150 | " fig.show()\n", 151 | " model.hparams.lr = lr_finder.suggestion() # Save optimal learning rate\n", 152 | " print(f'Optimal learning rate found: {model.hparams.lr}')\n", 153 | "\n", 154 | " # ------------\n", 155 | " # Logger\n", 156 | " # ------------\n", 157 | " pl_logger = construct_pl_logger(args) # Log everything to an external logger\n", 158 | " trainer.logger = pl_logger # Assign specified logger (e.g. TensorBoardLogger) to Trainer instance\n", 159 | "\n", 160 | " # -----------\n", 161 | " # Callbacks\n", 162 | " # -----------\n", 163 | " # Create and use callbacks\n", 164 | " mode = 'min' if 'ce' in args.metric_to_track else 'max'\n", 165 | " early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track,\n", 166 | " mode=mode,\n", 167 | " min_delta=args.min_delta,\n", 168 | " patience=args.patience)\n", 169 | " ckpt_callback = pl.callbacks.ModelCheckpoint(\n", 170 | " monitor=args.metric_to_track,\n", 171 | " mode=mode,\n", 172 | " verbose=True,\n", 173 | " save_last=True,\n", 174 | " save_top_k=3,\n", 175 | " filename=template_ckpt_filename # Warning: May cause a race condition if calling trainer.test() with many GPUs\n", 176 | " )\n", 177 | " lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True)\n", 178 | " trainer.callbacks = [early_stop_callback, ckpt_callback, lr_monitor_callback]\n", 179 | "\n", 180 | " # ------------\n", 181 | " # Restore\n", 182 | " # ------------\n", 183 | " # If using WandB, download checkpoint artifact from their servers if the checkpoint is not already stored locally\n", 184 | " if use_wandb_logger and args.ckpt_name != '' and not os.path.exists(ckpt_path):\n", 185 | " checkpoint_reference = f'{args.entity}/{args.project_name}/model-{args.run_id}:best'\n", 186 | " artifact = trainer.logger.experiment.use_artifact(checkpoint_reference, type='model')\n", 187 | " artifact_dir = artifact.download()\n", 188 | " model = model.load_from_checkpoint(Path(artifact_dir) / 'model.ckpt',\n", 189 | " use_wandb_logger=use_wandb_logger,\n", 190 | " batch_size=args.batch_size,\n", 191 | " lr=args.lr,\n", 192 | " weight_decay=args.weight_decay)\n", 193 | "\n", 194 | " # -------------\n", 195 | " # Training\n", 196 | " # -------------\n", 197 | " # Train with the provided model and DataModule\n", 198 | " trainer.fit(model=model, datamodule=db5_data_module)\n", 199 | "\n", 200 | " # -------------\n", 201 | " # Testing\n", 202 | " # -------------\n", 203 | " trainer.test()\n" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# -----------\n", 213 | "# Jupyter\n", 214 | "# -----------\n", 
215 | "sys.argv = ['']" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# -----------\n", 225 | "# Arguments\n", 226 | "# -----------\n", 227 | "# Collect all arguments\n", 228 | "parser = collect_args()\n", 229 | "\n", 230 | "# Parse all known and unknown arguments\n", 231 | "args, unparsed_argv = parser.parse_known_args()\n", 232 | "\n", 233 | "# Let the model add what it wants\n", 234 | "parser = LitNeiA.add_model_specific_args(parser)\n", 235 | "\n", 236 | "# Re-parse all known and unknown arguments after adding those that are model specific\n", 237 | "args, unparsed_argv = parser.parse_known_args()\n", 238 | "\n", 239 | "# TODO: Manually set arguments within a Jupyter notebook from here\n", 240 | "args.model_name = \"neia\"\n", 241 | "args.multi_gpu_backend = \"dp\"\n", 242 | "args.db5_data_dir = \"../project/datasets/DB5/final/raw\"\n", 243 | "args.process_complexes = True\n", 244 | "args.batch_size = 1 # Note: `batch_size` must be `1` for compatibility with the current model implementation\n", 245 | "\n", 246 | "# Set Lightning-specific parameter values before constructing Trainer instance\n", 247 | "args.max_time = {'hours': args.max_hours, 'minutes': args.max_minutes}\n", 248 | "args.max_epochs = args.num_epochs\n", 249 | "args.profiler = args.profiler_method\n", 250 | "args.accelerator = args.multi_gpu_backend\n", 251 | "args.auto_select_gpus = args.auto_choose_gpus\n", 252 | "args.gpus = args.num_gpus\n", 253 | "args.num_nodes = args.num_compute_nodes\n", 254 | "args.precision = args.gpu_precision\n", 255 | "args.accumulate_grad_batches = args.accum_grad_batches\n", 256 | "args.gradient_clip_val = args.grad_clip_val\n", 257 | "args.gradient_clip_algo = args.grad_clip_algo\n", 258 | "args.stochastic_weight_avg = args.stc_weight_avg\n", 259 | "args.deterministic = True # Make LightningModule's training reproducible\n", 260 | "\n", 261 | "# Set plugins for Lightning\n", 262 | "args.plugins = [\n", 263 | " # 'ddp_sharded', # For sharded model training (to reduce GPU requirements)\n", 264 | " # DDPPlugin(find_unused_parameters=False),\n", 265 | "]\n", 266 | "\n", 267 | "# Finalize all arguments as necessary\n", 268 | "args = process_args(args)\n", 269 | "\n", 270 | "# Begin execution of model training with given args above\n", 271 | "main(args)" 272 | ] 273 | } 274 | ], 275 | "metadata": { 276 | "kernelspec": { 277 | "display_name": "DIPS-Plus", 278 | "language": "python", 279 | "name": "python3" 280 | }, 281 | "language_info": { 282 | "codemirror_mode": { 283 | "name": "ipython", 284 | "version": 3 285 | }, 286 | "file_extension": ".py", 287 | "mimetype": "text/x-python", 288 | "name": "python", 289 | "nbconvert_exporter": "python", 290 | "pygments_lexer": "ipython3", 291 | "version": "3.8.16" 292 | }, 293 | "orig_nbformat": 4 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 2 297 | } 298 | -------------------------------------------------------------------------------- /notebooks/data_usage.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Example of data usage 3 | 4 | # %% [markdown] 5 | # ### Neural network model training 6 | 7 | # %% 8 | # ------------------------------------------------------------------------------------------------------------------------------------- 9 | # Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch): 10 | # 
------------------------------------------------------------------------------------------------------------------------------------- 11 | 12 | import os 13 | import sys 14 | from pathlib import Path 15 | 16 | import pytorch_lightning as pl 17 | import torch.nn as nn 18 | from pytorch_lightning.plugins import DDPPlugin 19 | 20 | from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule 21 | from project.utils.modules import LitNeiA 22 | from project.utils.training_utils import collect_args, process_args, construct_pl_logger 23 | 24 | # %% 25 | def main(args): 26 | # ----------- 27 | # Data 28 | # ----------- 29 | # Load Docking Benchmark 5 (DB5) data module 30 | db5_data_module = DB5DGLDataModule(data_dir=args.db5_data_dir, 31 | batch_size=args.batch_size, 32 | num_dataloader_workers=args.num_workers, 33 | knn=args.knn, 34 | self_loops=args.self_loops, 35 | percent_to_use=args.db5_percent_to_use, 36 | process_complexes=args.process_complexes, 37 | input_indep=args.input_indep) 38 | db5_data_module.setup() 39 | 40 | # ------------ 41 | # Model 42 | # ------------ 43 | # Assemble a dictionary of model arguments 44 | dict_args = vars(args) 45 | use_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB 46 | 47 | # Pick model and supply it with a dictionary of arguments 48 | if args.model_name.lower() == 'neiwa': # Neighborhood Weighted Average (NeiWA) 49 | model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features, 50 | num_edge_input_feats=db5_data_module.db5_test.num_edge_features, 51 | gnn_activ_fn=nn.Tanh(), 52 | interact_activ_fn=nn.ReLU(), 53 | num_classes=db5_data_module.db5_test.num_classes, 54 | weighted_avg=True, # Use the neighborhood weighted average variant of NeiA 55 | num_gnn_layers=dict_args['num_gnn_layers'], 56 | num_interact_layers=dict_args['num_interact_layers'], 57 | num_interact_hidden_channels=dict_args['num_interact_hidden_channels'], 58 | num_epochs=dict_args['num_epochs'], 59 | pn_ratio=dict_args['pn_ratio'], 60 | knn=dict_args['knn'], 61 | dropout_rate=dict_args['dropout_rate'], 62 | metric_to_track=dict_args['metric_to_track'], 63 | weight_decay=dict_args['weight_decay'], 64 | batch_size=dict_args['batch_size'], 65 | lr=dict_args['lr'], 66 | multi_gpu_backend=dict_args["accelerator"]) 67 | args.experiment_name = f'LitNeiWA-b{args.batch_size}-gl{args.num_gnn_layers}' \ 68 | f'-n{db5_data_module.db5_test.num_node_features}' \ 69 | f'-e{db5_data_module.db5_test.num_edge_features}' \ 70 | f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \ 71 | if not args.experiment_name \ 72 | else args.experiment_name 73 | template_ckpt_filename = 'LitNeiWA-{epoch:02d}-{val_ce:.2f}' 74 | 75 | else: # Default Model - Neighborhood Average (NeiA) 76 | model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features, 77 | num_edge_input_feats=db5_data_module.db5_test.num_edge_features, 78 | gnn_activ_fn=nn.Tanh(), 79 | interact_activ_fn=nn.ReLU(), 80 | num_classes=db5_data_module.db5_test.num_classes, 81 | weighted_avg=False, 82 | num_gnn_layers=dict_args['num_gnn_layers'], 83 | num_interact_layers=dict_args['num_interact_layers'], 84 | num_interact_hidden_channels=dict_args['num_interact_hidden_channels'], 85 | num_epochs=dict_args['num_epochs'], 86 | pn_ratio=dict_args['pn_ratio'], 87 | knn=dict_args['knn'], 88 | dropout_rate=dict_args['dropout_rate'], 89 | metric_to_track=dict_args['metric_to_track'], 90 | weight_decay=dict_args['weight_decay'], 91 | 
batch_size=dict_args['batch_size'], 92 | lr=dict_args['lr'], 93 | multi_gpu_backend=dict_args["accelerator"]) 94 | args.experiment_name = f'LitNeiA-b{args.batch_size}-gl{args.num_gnn_layers}' \ 95 | f'-n{db5_data_module.db5_test.num_node_features}' \ 96 | f'-e{db5_data_module.db5_test.num_edge_features}' \ 97 | f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \ 98 | if not args.experiment_name \ 99 | else args.experiment_name 100 | template_ckpt_filename = 'LitNeiA-{epoch:02d}-{val_ce:.2f}' 101 | 102 | # ------------ 103 | # Checkpoint 104 | # ------------ 105 | ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name) 106 | ckpt_path_exists = os.path.exists(ckpt_path) 107 | ckpt_provided = args.ckpt_name != '' and ckpt_path_exists 108 | model = model.load_from_checkpoint(ckpt_path, 109 | use_wandb_logger=use_wandb_logger, 110 | batch_size=args.batch_size, 111 | lr=args.lr, 112 | weight_decay=args.weight_decay, 113 | dropout_rate=args.dropout_rate) if ckpt_provided else model 114 | 115 | # ------------ 116 | # Trainer 117 | # ------------ 118 | trainer = pl.Trainer.from_argparse_args(args) 119 | 120 | # ------------- 121 | # Learning Rate 122 | # ------------- 123 | if args.find_lr: 124 | lr_finder = trainer.tuner.lr_find(model, datamodule=db5_data_module) # Run learning rate finder 125 | fig = lr_finder.plot(suggest=True) # Plot learning rates 126 | fig.savefig('optimal_lr.pdf') 127 | fig.show() 128 | model.hparams.lr = lr_finder.suggestion() # Save optimal learning rate 129 | print(f'Optimal learning rate found: {model.hparams.lr}') 130 | 131 | # ------------ 132 | # Logger 133 | # ------------ 134 | pl_logger = construct_pl_logger(args) # Log everything to an external logger 135 | trainer.logger = pl_logger # Assign specified logger (e.g. 
TensorBoardLogger) to Trainer instance 136 | 137 | # ----------- 138 | # Callbacks 139 | # ----------- 140 | # Create and use callbacks 141 | mode = 'min' if 'ce' in args.metric_to_track else 'max' 142 | early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track, 143 | mode=mode, 144 | min_delta=args.min_delta, 145 | patience=args.patience) 146 | ckpt_callback = pl.callbacks.ModelCheckpoint( 147 | monitor=args.metric_to_track, 148 | mode=mode, 149 | verbose=True, 150 | save_last=True, 151 | save_top_k=3, 152 | filename=template_ckpt_filename # Warning: May cause a race condition if calling trainer.test() with many GPUs 153 | ) 154 | lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True) 155 | trainer.callbacks = [early_stop_callback, ckpt_callback, lr_monitor_callback] 156 | 157 | # ------------ 158 | # Restore 159 | # ------------ 160 | # If using WandB, download checkpoint artifact from their servers if the checkpoint is not already stored locally 161 | if use_wandb_logger and args.ckpt_name != '' and not os.path.exists(ckpt_path): 162 | checkpoint_reference = f'{args.entity}/{args.project_name}/model-{args.run_id}:best' 163 | artifact = trainer.logger.experiment.use_artifact(checkpoint_reference, type='model') 164 | artifact_dir = artifact.download() 165 | model = model.load_from_checkpoint(Path(artifact_dir) / 'model.ckpt', 166 | use_wandb_logger=use_wandb_logger, 167 | batch_size=args.batch_size, 168 | lr=args.lr, 169 | weight_decay=args.weight_decay) 170 | 171 | # ------------- 172 | # Training 173 | # ------------- 174 | # Train with the provided model and DataModule 175 | trainer.fit(model=model, datamodule=db5_data_module) 176 | 177 | # ------------- 178 | # Testing 179 | # ------------- 180 | trainer.test() 181 | 182 | 183 | # %% 184 | # ----------- 185 | # Jupyter 186 | # ----------- 187 | # sys.argv = [''] 188 | 189 | # %% 190 | # ----------- 191 | # Arguments 192 | # ----------- 193 | # Collect all arguments 194 | parser = collect_args() 195 | 196 | # Parse all known and unknown arguments 197 | args, unparsed_argv = parser.parse_known_args() 198 | 199 | # Let the model add what it wants 200 | parser = LitNeiA.add_model_specific_args(parser) 201 | 202 | # Re-parse all known and unknown arguments after adding those that are model specific 203 | args, unparsed_argv = parser.parse_known_args() 204 | 205 | # TODO: Manually set arguments within a Jupyter notebook from here 206 | args.model_name = "neia" 207 | args.multi_gpu_backend = "dp" 208 | args.db5_data_dir = "project/datasets/DB5/final/raw" 209 | args.process_complexes = True 210 | args.batch_size = 1 # Note: `batch_size` must be `1` for compatibility with the current model implementation 211 | 212 | # Set Lightning-specific parameter values before constructing Trainer instance 213 | args.max_time = {'hours': args.max_hours, 'minutes': args.max_minutes} 214 | args.max_epochs = args.num_epochs 215 | args.profiler = args.profiler_method 216 | args.accelerator = args.multi_gpu_backend 217 | args.auto_select_gpus = args.auto_choose_gpus 218 | args.gpus = args.num_gpus 219 | args.num_nodes = args.num_compute_nodes 220 | args.precision = args.gpu_precision 221 | args.accumulate_grad_batches = args.accum_grad_batches 222 | args.gradient_clip_val = args.grad_clip_val 223 | args.gradient_clip_algo = args.grad_clip_algo 224 | args.stochastic_weight_avg = args.stc_weight_avg 225 | args.deterministic = True # Make LightningModule's training reproducible 226 | 227 | # Set 
plugins for Lightning 228 | args.plugins = [ 229 | # 'ddp_sharded', # For sharded model training (to reduce GPU requirements) 230 | # DDPPlugin(find_unused_parameters=False), 231 | ] 232 | 233 | # Finalize all arguments as necessary 234 | args = process_args(args) 235 | 236 | # Begin execution of model training with given args above 237 | main(args) 238 | 239 | 240 | -------------------------------------------------------------------------------- /notebooks/feature_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Feature generation for PDB file inputs" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import os\n", 18 | "\n", 19 | "import atom3.complex as comp\n", 20 | "import atom3.conservation as con\n", 21 | "import atom3.neighbors as nb\n", 22 | "import atom3.pair as pair\n", 23 | "import atom3.parse as parse\n", 24 | "import dill as pickle\n", 25 | "\n", 26 | "from pathlib import Path\n", 27 | "\n", 28 | "from project.utils.utils import annotate_idr_residues, impute_missing_feature_values, postprocess_pruned_pair, process_raw_file_into_dgl_graphs" 29 | ] 30 | }, 31 | { 32 | "attachments": {}, 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "### 1. Parse PDB file input to pair-wise features" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "pdb_filename = \"../project/datasets/Input/raw/pdb/2g/12gs.pdb1\" # note: an input PDB must be uncompressed (e.g., not in `.gz` archive format) using e.g., `gunzip`\n", 46 | "output_pkl = \"../project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl\"\n", 47 | "complexes_dill = \"../project/datasets/Input/interim/complexes/complexes.dill\"\n", 48 | "pairs_dir = \"../project/datasets/Input/interim/pairs\"\n", 49 | "pkl_filenames = [output_pkl]\n", 50 | "source_type = \"rcsb\" # note: this default value will likely work for common use cases (i.e., those concerning bound-state PDB protein complex structure inputs)\n", 51 | "neighbor_def = \"non_heavy_res\"\n", 52 | "cutoff = 6 # note: distance threshold (in Angstrom) for classifying inter-chain interactions can be customized here\n", 53 | "unbound = False # note: if `source_type` is set to `rcsb`, this value should likely be `False`\n", 54 | "\n", 55 | "for item in [\n", 56 | " Path(pdb_filename).parent,\n", 57 | " Path(output_pkl).parent,\n", 58 | " Path(complexes_dill).parent,\n", 59 | " pairs_dir,\n", 60 | "]:\n", 61 | " os.makedirs(item, exist_ok=True)\n", 62 | "\n", 63 | "# note: the following replicates the logic within `make_dataset.py` for a single PDB file input\n", 64 | "parse.parse(\n", 65 | " # note: assumes the PDB file input (i.e., `pdb_filename`) is not compressed\n", 66 | " pdb_filename=pdb_filename,\n", 67 | " output_pkl=output_pkl\n", 68 | ")\n", 69 | "complexes = comp.get_complexes(filenames=pkl_filenames, type=source_type)\n", 70 | "comp.write_complexes(complexes=complexes, output_dill=complexes_dill)\n", 71 | "get_neighbors = nb.build_get_neighbors(criteria=neighbor_def, cutoff=cutoff)\n", 72 | "get_pairs = pair.build_get_pairs(\n", 73 | " neighbor_def=neighbor_def,\n", 74 | " type=source_type,\n", 75 | " unbound=unbound,\n", 76 | " nb_fn=get_neighbors,\n", 77 | " full=False\n", 78 | ")\n", 79 | "complexes = 
comp.read_complexes(input_dill=complexes_dill)\n", 80 | "pair.complex_to_pairs(\n", 81 | " complex=list(complexes['data'].values())[0],\n", 82 | " source_type=source_type,\n", 83 | " get_pairs=get_pairs,\n", 84 | " output_dir=pairs_dir\n", 85 | ")" 86 | ] 87 | }, 88 | { 89 | "attachments": {}, 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### 2. Compute sequence-based features using external tools" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "psaia_dir = \"~/Programs/PSAIA-1.0/bin/linux/psa\" # note: replace this with the path to your local installation of PSAIA\n", 103 | "psaia_config_file = \"../project/datasets/builder/psaia_config_file_dips.txt\" # note: choose `psaia_config_file_dips.txt` according to the `source_type` selected above\n", 104 | "file_list_file = os.path.join(\"../project/datasets/Input/interim/external_feats/\", 'PSAIA', source_type.upper(), 'pdb_list.fls')\n", 105 | "num_cpus = 8\n", 106 | "pkl_filename = \"../project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl\"\n", 107 | "output_filename = \"../project/datasets/Input/interim/external_feats/parsed/2g/12gs.pdb1.pkl\"\n", 108 | "hhsuite_db = \"~/Data/Databases/pfamA_35.0/pfam\" # note: substitute the path to your local HHsuite3 database here\n", 109 | "num_iter = 2\n", 110 | "msa_only = False\n", 111 | "\n", 112 | "for item in [\n", 113 | " Path(file_list_file).parent,\n", 114 | " Path(output_filename).parent,\n", 115 | "]:\n", 116 | " os.makedirs(item, exist_ok=True)\n", 117 | "\n", 118 | "# note: the following replicates the logic within `generate_psaia_features.py` and `generate_hhsuite_features.py` for a single PDB file input\n", 119 | "with open(file_list_file, 'w') as file:\n", 120 | " file.write(f'{pdb_filename}\\n') # note: references the `pdb_filename` as defined previously\n", 121 | "con.gen_protrusion_index(\n", 122 | " psaia_dir=psaia_dir,\n", 123 | " psaia_config_file=psaia_config_file,\n", 124 | " file_list_file=file_list_file,\n", 125 | ")\n", 126 | "con.map_profile_hmms(\n", 127 | " num_cpus=num_cpus,\n", 128 | " pkl_filename=pkl_filename,\n", 129 | " output_filename=output_filename,\n", 130 | " hhsuite_db=hhsuite_db,\n", 131 | " source_type=source_type,\n", 132 | " num_iter=num_iter,\n", 133 | " msa_only=msa_only,\n", 134 | ")" 135 | ] 136 | }, 137 | { 138 | "attachments": {}, 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "### 3. 
Compute structure-based features" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "from project.utils.utils import __should_keep_postprocessed\n", 152 | "\n", 153 | "\n", 154 | "raw_pdb_dir = \"../project/datasets/Input/raw/pdb\"\n", 155 | "pair_filename = \"../project/datasets/Input/interim/pairs/2g/12gs.pdb1_0.dill\"\n", 156 | "source_type = \"rcsb\"\n", 157 | "external_feats_dir = \"../project/datasets/Input/interim/external_feats/parsed\"\n", 158 | "output_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n", 159 | "\n", 160 | "unprocessed_pair, raw_pdb_filenames, should_keep = __should_keep_postprocessed(raw_pdb_dir, pair_filename, source_type)\n", 161 | "if should_keep:\n", 162 | " # note: save `postprocessed_pair` to local storage within `project/datasets/Input/final/raw` for future reference as desired\n", 163 | " postprocessed_pair = postprocess_pruned_pair(\n", 164 | " raw_pdb_filenames=raw_pdb_filenames,\n", 165 | " external_feats_dir=external_feats_dir,\n", 166 | " original_pair=unprocessed_pair,\n", 167 | " source_type=source_type,\n", 168 | " )\n", 169 | " # write into output_filenames if not exist\n", 170 | " os.makedirs(Path(output_filename).parent, exist_ok=True)\n", 171 | " with open(output_filename, 'wb') as f:\n", 172 | " pickle.dump(postprocessed_pair, f)" 173 | ] 174 | }, 175 | { 176 | "attachments": {}, 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### 4. Embed deep learning-based IDR features" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# note: ensures the Docker image for `flDPnn` is available locally before trying to run inference with the model\n", 190 | "!docker pull docker.io/sinaghadermarzi/fldpnn\n", 191 | "\n", 192 | "input_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n", 193 | "pickle_filepaths = [input_pair_filename]\n", 194 | "\n", 195 | "annotate_idr_residues(\n", 196 | " pickle_filepaths=pickle_filepaths\n", 197 | ")" 198 | ] 199 | }, 200 | { 201 | "attachments": {}, 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "### 5. Impute missing feature values (optional)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "input_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n", 215 | "output_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill\"\n", 216 | "impute_atom_features = False\n", 217 | "advanced_logging = False\n", 218 | "\n", 219 | "impute_missing_feature_values(\n", 220 | " input_pair_filename=input_pair_filename,\n", 221 | " output_pair_filename=output_pair_filename,\n", 222 | " impute_atom_features=impute_atom_features,\n", 223 | " advanced_logging=advanced_logging,\n", 224 | ")" 225 | ] 226 | }, 227 | { 228 | "attachments": {}, 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "### 6. 
Convert pair-wise features into graph inputs (optional)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "raw_filepath = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill\"\n", 242 | "new_graph_dir = \"../project/datasets/Input/final/processed/2g\"\n", 243 | "processed_filepath = \"../project/datasets/Input/final/processed/2g/12gs.pdb1.pt\"\n", 244 | "edge_dist_cutoff = 15.0\n", 245 | "edge_limit = 5000\n", 246 | "self_loops = True\n", 247 | "\n", 248 | "os.makedirs(new_graph_dir, exist_ok=True)\n", 249 | "\n", 250 | "process_raw_file_into_dgl_graphs(\n", 251 | " raw_filepath=raw_filepath,\n", 252 | " new_graph_dir=new_graph_dir,\n", 253 | " processed_filepath=processed_filepath,\n", 254 | " edge_dist_cutoff=edge_dist_cutoff,\n", 255 | " edge_limit=edge_limit,\n", 256 | " self_loops=self_loops,\n", 257 | ")" 258 | ] 259 | } 260 | ], 261 | "metadata": { 262 | "kernelspec": { 263 | "display_name": "DIPS-Plus", 264 | "language": "python", 265 | "name": "python3" 266 | }, 267 | "language_info": { 268 | "codemirror_mode": { 269 | "name": "ipython", 270 | "version": 3 271 | }, 272 | "file_extension": ".py", 273 | "mimetype": "text/x-python", 274 | "name": "python", 275 | "nbconvert_exporter": "python", 276 | "pygments_lexer": "ipython3", 277 | "version": "3.8.16" 278 | }, 279 | "orig_nbformat": 4 280 | }, 281 | "nbformat": 4, 282 | "nbformat_minor": 2 283 | } 284 | -------------------------------------------------------------------------------- /notebooks/feature_generation.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # Feature generation for PDB file inputs 3 | 4 | # %% 5 | import os 6 | 7 | import atom3.complex as comp 8 | import atom3.conservation as con 9 | import atom3.neighbors as nb 10 | import atom3.pair as pair 11 | import atom3.parse as parse 12 | import dill as pickle 13 | 14 | from pathlib import Path 15 | 16 | from project.utils.utils import annotate_idr_residues, impute_missing_feature_values, postprocess_pruned_pair, process_raw_file_into_dgl_graphs 17 | 18 | # %% [markdown] 19 | # ### 1. 
Parse PDB file input to pair-wise features 20 | 21 | # %% 22 | pdb_filename = "project/datasets/Input/raw/pdb/2g/12gs.pdb1" # note: an input PDB must be uncompressed (e.g., not in `.gz` archive format) using e.g., `gunzip` 23 | output_pkl = "project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl" 24 | complexes_dill = "project/datasets/Input/interim/complexes/complexes.dill" 25 | pairs_dir = "project/datasets/Input/interim/pairs" 26 | pkl_filenames = [output_pkl] 27 | source_type = "rcsb" # note: this default value will likely work for common use cases (i.e., those concerning bound-state PDB protein complex structure inputs) 28 | neighbor_def = "non_heavy_res" 29 | cutoff = 6 # note: distance threshold (in Angstrom) for classifying inter-chain interactions can be customized here 30 | unbound = False # note: if `source_type` is set to `rcsb`, this value should likely be `False` 31 | 32 | for item in [ 33 | Path(pdb_filename).parent, 34 | Path(output_pkl).parent, 35 | Path(complexes_dill).parent, 36 | pairs_dir, 37 | ]: 38 | os.makedirs(item, exist_ok=True) 39 | 40 | # note: the following replicates the logic within `make_dataset.py` for a single PDB file input 41 | parse.parse( 42 | # note: assumes the PDB file input (i.e., `pdb_filename`) is not compressed 43 | pdb_filename=pdb_filename, 44 | output_pkl=output_pkl 45 | ) 46 | complexes = comp.get_complexes(filenames=pkl_filenames, type=source_type) 47 | comp.write_complexes(complexes=complexes, output_dill=complexes_dill) 48 | get_neighbors = nb.build_get_neighbors(criteria=neighbor_def, cutoff=cutoff) 49 | get_pairs = pair.build_get_pairs( 50 | neighbor_def=neighbor_def, 51 | type=source_type, 52 | unbound=unbound, 53 | nb_fn=get_neighbors, 54 | full=False 55 | ) 56 | complexes = comp.read_complexes(input_dill=complexes_dill) 57 | pair.complex_to_pairs( 58 | complex=list(complexes['data'].values())[0], 59 | source_type=source_type, 60 | get_pairs=get_pairs, 61 | output_dir=pairs_dir 62 | ) 63 | 64 | # %% [markdown] 65 | # ### 2. 
Compute sequence-based features using external tools 66 | 67 | # %% 68 | psaia_dir = "~/Programs/PSAIA-1.0/bin/linux/psa" # note: replace this with the path to your local installation of PSAIA 69 | psaia_config_file = "project/datasets/builder/psaia_config_file_dips.txt" # note: choose `psaia_config_file_dips.txt` according to the `source_type` selected above 70 | file_list_file = os.path.join("project/datasets/Input/interim/external_feats/", 'PSAIA', source_type.upper(), 'pdb_list.fls') 71 | num_cpus = 8 72 | pkl_filename = "project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl" 73 | output_filename = "project/datasets/Input/interim/external_feats/parsed/2g/12gs.pdb1.pkl" 74 | hhsuite_db = "~/Data/Databases/pfamA_35.0/pfam" # note: substitute the path to your local HHsuite3 database here 75 | num_iter = 2 76 | msa_only = False 77 | 78 | for item in [ 79 | Path(file_list_file).parent, 80 | Path(output_filename).parent, 81 | ]: 82 | os.makedirs(item, exist_ok=True) 83 | 84 | # note: the following replicates the logic within `generate_psaia_features.py` and `generate_hhsuite_features.py` for a single PDB file input 85 | with open(file_list_file, 'w') as file: 86 | file.write(f'{pdb_filename}\n') # note: references the `pdb_filename` as defined previously 87 | con.gen_protrusion_index( 88 | psaia_dir=psaia_dir, 89 | psaia_config_file=psaia_config_file, 90 | file_list_file=file_list_file, 91 | ) 92 | con.map_profile_hmms( 93 | num_cpus=num_cpus, 94 | pkl_filename=pkl_filename, 95 | output_filename=output_filename, 96 | hhsuite_db=hhsuite_db, 97 | source_type=source_type, 98 | num_iter=num_iter, 99 | msa_only=msa_only, 100 | ) 101 | 102 | # %% [markdown] 103 | # ### 3. Compute structure-based features 104 | 105 | # %% 106 | from project.utils.utils import __should_keep_postprocessed 107 | 108 | 109 | raw_pdb_dir = "project/datasets/Input/raw/pdb" 110 | pair_filename = "project/datasets/Input/interim/pairs/2g/12gs.pdb1_0.dill" 111 | source_type = "rcsb" 112 | external_feats_dir = "project/datasets/Input/interim/external_feats/parsed" 113 | output_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill" 114 | 115 | unprocessed_pair, raw_pdb_filenames, should_keep = __should_keep_postprocessed(raw_pdb_dir, pair_filename, source_type) 116 | if should_keep: 117 | # note: save `postprocessed_pair` to local storage within `project/datasets/Input/final/raw` for future reference as desired 118 | postprocessed_pair = postprocess_pruned_pair( 119 | raw_pdb_filenames=raw_pdb_filenames, 120 | external_feats_dir=external_feats_dir, 121 | original_pair=unprocessed_pair, 122 | source_type=source_type, 123 | ) 124 | # write into output_filenames if not exist 125 | os.makedirs(Path(output_filename).parent, exist_ok=True) 126 | with open(output_filename, 'wb') as f: 127 | pickle.dump(postprocessed_pair, f) 128 | 129 | # %% [markdown] 130 | # ### 4. Embed deep learning-based IDR features 131 | 132 | # %% 133 | # note: ensures the Docker image for `flDPnn` is available locally before trying to run inference with the model 134 | # !docker pull docker.io/sinaghadermarzi/fldpnn 135 | 136 | input_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill" 137 | pickle_filepaths = [input_pair_filename] 138 | 139 | annotate_idr_residues( 140 | pickle_filepaths=pickle_filepaths 141 | ) 142 | 143 | # %% [markdown] 144 | # ### 5. 
Impute missing feature values (optional)
145 | 
146 | # %%
147 | input_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill"
148 | output_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill"
149 | impute_atom_features = False
150 | advanced_logging = False
151 | 
152 | impute_missing_feature_values(
153 |     input_pair_filename=input_pair_filename,
154 |     output_pair_filename=output_pair_filename,
155 |     impute_atom_features=impute_atom_features,
156 |     advanced_logging=advanced_logging,
157 | )
158 | 
159 | # %% [markdown]
160 | # ### 6. Convert pair-wise features into graph inputs (optional)
161 | 
162 | # %%
163 | raw_filepath = "project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill"
164 | new_graph_dir = "project/datasets/Input/final/processed/2g"
165 | processed_filepath = "project/datasets/Input/final/processed/2g/12gs.pdb1.pt"
166 | edge_dist_cutoff = 15.0
167 | edge_limit = 5000
168 | self_loops = True
169 | 
170 | os.makedirs(new_graph_dir, exist_ok=True)
171 | 
172 | process_raw_file_into_dgl_graphs(
173 |     raw_filepath=raw_filepath,
174 |     new_graph_dir=new_graph_dir,
175 |     processed_filepath=processed_filepath,
176 |     edge_dist_cutoff=edge_dist_cutoff,
177 |     edge_limit=edge_limit,
178 |     self_loops=self_loops,
179 | )
180 | 
181 | 
182 | 
--------------------------------------------------------------------------------
/project/datasets/DB5/README:
--------------------------------------------------------------------------------
1 | Cleaned up version of Docking Benchmark 5 (https://zlab.umassmed.edu/benchmark/).
2 | 
3 | Released with "End-to-End Learning on 3D Protein Structure for Interface Prediction."
4 | by Raphael J.L. Townshend, Rishi Bedi, Patricia Suriana, Ron O. Dror
5 | https://arxiv.org/abs/1807.01297
6 | 
7 | Specifically, bound chains and residue indexes were aligned across unbound and bound complexes.
8 | 
9 | A total of 230 binary protein complexes are included.
10 | 
11 | Processing code to regenerate and use the provided tfrecords is located at
12 | https://github.com/drorlab/DIPS
13 | 
14 | MANIFEST
15 | 
16 | raw/ - All pre-aligned and cleaned DB5 structures, organized into directories
17 |     with individual files for ligand-unbound, ligand-bound, receptor-unbound,
18 |     receptor-bound.
19 | interim/
20 |     parsed/ - All DB5 structures processed to pickled dataframes.
21 |     complexes/ - List of all possible pairs in parsed.
22 |     pairs/ - Dill files of individual pairs listed in complexes.
23 | processed/
24 |     tfrecords/ - pairs converted to tfrecords.
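
For a quick look at one of the interim pair files from Python, something like the
following minimal sketch can be used (attribute names follow the atom3 Pair objects
used elsewhere in this repository; the path below is only a placeholder):

    import pandas as pd

    # each file under interim/pairs/ holds one pickled atom3 Pair
    pair = pd.read_pickle("interim/pairs/<pdb_code>/<pair_filename>.dill")
    print(pair.df0.head())   # parsed dataframe for the first bound structure
    print(pair.df1.head())   # parsed dataframe for the second bound structure
    print(pair.pos_idx[:5])  # indices of interacting (positive) residue pairs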
25 | -------------------------------------------------------------------------------- /project/datasets/DB5/db5_dgl_data_module.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | 6 | from project.datasets.DB5.db5_dgl_dataset import DB5DGLDataset 7 | 8 | 9 | class DB5DGLDataModule(LightningDataModule): 10 | """Unbound protein complex data module for DGL with PyTorch.""" 11 | 12 | # Dataset partition instantiation 13 | db5_train = None 14 | db5_val = None 15 | db5_test = None 16 | 17 | def __init__(self, data_dir: str, batch_size: int, num_dataloader_workers: int, knn: int, 18 | self_loops: bool, percent_to_use: float, process_complexes: bool, input_indep: bool): 19 | super().__init__() 20 | 21 | self.data_dir = data_dir 22 | self.batch_size = batch_size 23 | self.num_dataloader_workers = num_dataloader_workers 24 | self.knn = knn 25 | self.self_loops = self_loops 26 | self.percent_to_use = percent_to_use # Fraction of DB5 dataset to use 27 | self.process_complexes = process_complexes # Whether to process any unprocessed complexes before training 28 | self.input_indep = input_indep # Whether to use an input-independent pipeline to train the model 29 | 30 | def setup(self, stage: Optional[str] = None): 31 | # Assign training/validation/testing data set for use in DataLoaders - called on every GPU 32 | self.db5_train = DB5DGLDataset(mode='train', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops, 33 | percent_to_use=self.percent_to_use, process_complexes=self.process_complexes, 34 | input_indep=self.input_indep) 35 | self.db5_val = DB5DGLDataset(mode='val', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops, 36 | percent_to_use=self.percent_to_use, process_complexes=self.process_complexes, 37 | input_indep=self.input_indep) 38 | self.db5_test = DB5DGLDataset(mode='test', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops, 39 | percent_to_use=self.percent_to_use, process_complexes=self.process_complexes, 40 | input_indep=self.input_indep) 41 | 42 | def train_dataloader(self) -> DataLoader: 43 | return DataLoader(self.db5_train, batch_size=self.batch_size, shuffle=True, 44 | num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True) 45 | 46 | def val_dataloader(self) -> DataLoader: 47 | return DataLoader(self.db5_val, batch_size=self.batch_size, shuffle=False, 48 | num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True) 49 | 50 | def test_dataloader(self) -> DataLoader: 51 | return DataLoader(self.db5_test, batch_size=self.batch_size, shuffle=False, 52 | num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True) 53 | -------------------------------------------------------------------------------- /project/datasets/DB5/db5_dgl_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import pandas as pd 5 | from dgl.data import DGLDataset, download, check_sha1 6 | 7 | from project.utils.training_utils import construct_filenames_frame_txt_filenames, build_filenames_frame_error_message, process_complex_into_dict, zero_out_complex_features 8 | 9 | 10 | class DB5DGLDataset(DGLDataset): 11 | r"""Unbound protein complex dataset for DGL with PyTorch. 
12 | 13 | Statistics: 14 | 15 | - Train examples: 140 16 | - Validation examples: 35 17 | - Test examples: 55 18 | - Number of structures per complex: 2 19 | ---------------------- 20 | - Total examples: 230 21 | ---------------------- 22 | 23 | Parameters 24 | ---------- 25 | mode: str, optional 26 | Should be one of ['train', 'val', 'test']. Default: 'train'. 27 | raw_dir: str 28 | Raw file directory to download/contains the input data directory. Default: 'final/raw'. 29 | knn: int 30 | How many nearest neighbors to which to connect a given node. Default: 20. 31 | self_loops: bool 32 | Whether to connect a given node to itself. Default: True. 33 | percent_to_use: float 34 | How much of the dataset to load. Default: 1.0. 35 | process_complexes: bool 36 | Whether to process each unprocessed complex as we load in the dataset. Default: True. 37 | input_indep: bool 38 | Whether to zero-out each input node and edge feature for an input-independent baseline. Default: False. 39 | force_reload: bool 40 | Whether to reload the dataset. Default: False. 41 | verbose: bool 42 | Whether to print out progress information. Default: False. 43 | 44 | Notes 45 | ----- 46 | All the samples will be loaded and preprocessed in the memory first. 47 | 48 | Examples 49 | -------- 50 | >>> # Get dataset 51 | >>> train_data = DB5DGLDataset() 52 | >>> val_data = DB5DGLDataset(mode='val') 53 | >>> test_data = DB5DGLDataset(mode='test') 54 | >>> 55 | >>> len(test_data) 56 | 55 57 | >>> test_data.num_chains 58 | 2 59 | """ 60 | 61 | def __init__(self, 62 | mode='test', 63 | raw_dir=f'final{os.sep}raw', 64 | knn=20, 65 | self_loops=True, 66 | percent_to_use=1.0, 67 | process_complexes=True, 68 | input_indep=False, 69 | force_reload=False, 70 | verbose=False): 71 | assert mode in ['train', 'val', 'test'] 72 | assert 0.0 < percent_to_use <= 1.0 73 | self.mode = mode 74 | self.root = raw_dir 75 | self.knn = knn 76 | self.self_loops = self_loops 77 | self.percent_to_use = percent_to_use # How much of the DB5 dataset to use 78 | self.process_complexes = process_complexes # Whether to process any unprocessed complexes before training 79 | self.input_indep = input_indep # Whether to use an input-independent pipeline to train the model 80 | self.final_dir = os.path.join(*self.root.split(os.sep)[:-1]) 81 | self.processed_dir = os.path.join(self.final_dir, 'processed') 82 | 83 | self.filename_sampling = 0.0 < self.percent_to_use < 1.0 84 | self.base_txt_filename, self.filenames_frame_txt_filename, self.filenames_frame_txt_filepath = \ 85 | construct_filenames_frame_txt_filenames(self.mode, self.percent_to_use, self.filename_sampling, self.root) 86 | 87 | # Try to load the text file containing all DB5 filenames, and alert the user if it is missing or corrupted 88 | filenames_frame_to_be_written = not os.path.exists(self.filenames_frame_txt_filepath) 89 | 90 | # Randomly sample DataFrame of filenames with requested cross validation ratio 91 | if self.filename_sampling: 92 | if filenames_frame_to_be_written: 93 | try: 94 | self.filenames_frame = pd.read_csv( 95 | os.path.join(self.root, self.base_txt_filename + '.txt'), header=None) 96 | except Exception: 97 | raise FileNotFoundError( 98 | build_filenames_frame_error_message('DB5-Plus', 'load', self.filenames_frame_txt_filepath)) 99 | self.filenames_frame = self.filenames_frame.sample(frac=self.percent_to_use).reset_index() 100 | try: 101 | self.filenames_frame[0].to_csv(self.filenames_frame_txt_filepath, header=None, index=None) 102 | except Exception: 103 | raise Exception( 104 | 
build_filenames_frame_error_message('DB5-Plus', 'write', self.filenames_frame_txt_filepath)) 105 | 106 | # Load in existing DataFrame of filenames as requested (or if a sampled DataFrame .txt has already been written) 107 | if not filenames_frame_to_be_written: 108 | try: 109 | self.filenames_frame = pd.read_csv(self.filenames_frame_txt_filepath, header=None) 110 | except Exception: 111 | raise FileNotFoundError( 112 | build_filenames_frame_error_message('DB5-Plus', 'load', self.filenames_frame_txt_filepath)) 113 | 114 | # Process any unprocessed examples prior to using the dataset 115 | self.process() 116 | 117 | super(DB5DGLDataset, self).__init__(name='DB5-Plus', 118 | raw_dir=raw_dir, 119 | force_reload=force_reload, 120 | verbose=verbose) 121 | print(f"Loaded DB5-Plus {mode}-set, source: {self.processed_dir}, length: {len(self)}") 122 | 123 | def download(self): 124 | """Download and extract a pre-packaged version of the raw pairs if 'self.raw_dir' is not already populated.""" 125 | # Path to store the file 126 | gz_file_path = os.path.join(os.path.join(*self.raw_dir.split(os.sep)[:-1]), 'final_raw_db5.tar.gz') 127 | 128 | # Download file 129 | download(self.url, path=gz_file_path) 130 | 131 | # Check SHA-1 132 | if not check_sha1(gz_file_path, self._sha1_str): 133 | raise UserWarning('File {} is downloaded but the content hash does not match.' 134 | 'The repo may be outdated or download may be incomplete. ' 135 | 'Otherwise you can create an issue for it.'.format(gz_file_path)) 136 | 137 | # Remove existing raw directory to make way for the new archive to be extracted 138 | if os.path.exists(self.raw_dir): 139 | os.removedirs(self.raw_dir) 140 | 141 | # Extract archive to parent directory of `self.raw_dir` 142 | self._extract_gz(gz_file_path, os.path.join(*self.raw_dir.split(os.sep)[:-1])) 143 | 144 | def process(self): 145 | """Process each protein complex into a testing-ready dictionary representing both structures.""" 146 | if self.process_complexes: 147 | # Ensure the directory of processed complexes is already created 148 | os.makedirs(self.processed_dir, exist_ok=True) 149 | # Process each unprocessed protein complex 150 | for (i, raw_path) in self.filenames_frame.iterrows(): 151 | raw_filepath = os.path.join(self.root, f'{os.path.splitext(raw_path[0])[0]}.dill') 152 | processed_filepath = os.path.join(self.processed_dir, f'{os.path.splitext(raw_path[0])[0]}.dill') 153 | if not os.path.exists(processed_filepath): 154 | processed_parent_dir_to_make = os.path.join(self.processed_dir, os.path.split(raw_path[0])[0]) 155 | os.makedirs(processed_parent_dir_to_make, exist_ok=True) 156 | process_complex_into_dict(raw_filepath, processed_filepath, self.knn, self.self_loops, False) 157 | 158 | def has_cache(self): 159 | """Check if each complex is downloaded and available for training, validation, or testing.""" 160 | for (i, raw_path) in self.filenames_frame.iterrows(): 161 | processed_filepath = os.path.join(self.processed_dir, f'{os.path.splitext(raw_path[0])[0]}.dill') 162 | if not os.path.exists(processed_filepath): 163 | print( 164 | f'Unable to load at least one processed DB5 pair. ' 165 | f'Please make sure all processed pairs have been successfully downloaded and are not corrupted.') 166 | raise FileNotFoundError 167 | print('DB5 cache found') # Otherwise, a cache was found! 168 | 169 | def __getitem__(self, idx): 170 | r""" Get feature dictionary by index of complex. 
171 | 172 | Parameters 173 | ---------- 174 | idx : int 175 | 176 | Returns 177 | ------- 178 | :class:`dict` 179 | 180 | - ``complex['graph1_node_feats']:`` PyTorch Tensor containing each of the first graph's encoded node features 181 | - ``complex['graph2_node_feats']``: PyTorch Tensor containing each of the second graph's encoded node features 182 | - ``complex['graph1_node_coords']:`` PyTorch Tensor containing each of the first graph's node coordinates 183 | - ``complex['graph2_node_coords']``: PyTorch Tensor containing each of the second graph's node coordinates 184 | - ``complex['graph1_edge_feats']:`` PyTorch Tensor containing each of the first graph's edge features for each node 185 | - ``complex['graph2_edge_feats']:`` PyTorch Tensor containing each of the second graph's edge features for each node 186 | - ``complex['graph1_nbrhd_indices']:`` PyTorch Tensor containing each of the first graph's neighboring node indices 187 | - ``complex['graph2_nbrhd_indices']:`` PyTorch Tensor containing each of the second graph's neighboring node indices 188 | - ``complex['examples']:`` PyTorch Tensor containing the labels for inter-graph node pairs 189 | - ``complex['complex']:`` Python string describing the complex's code and original pdb filename 190 | """ 191 | # Assemble filepath of processed protein complex 192 | complex_filepath = f'{os.path.splitext(self.filenames_frame[0][idx])[0]}.dill' 193 | processed_filepath = os.path.join(self.processed_dir, complex_filepath) 194 | 195 | # Load in processed complex 196 | with open(processed_filepath, 'rb') as f: 197 | processed_complex = pickle.load(f) 198 | processed_complex['filepath'] = complex_filepath # Add filepath to each complex dictionary 199 | 200 | # Optionally zero-out input data for an input-independent pipeline (per Karpathy's suggestion) 201 | if self.input_indep: 202 | processed_complex = zero_out_complex_features(processed_complex) 203 | 204 | # Manually filter for desired node and edge features 205 | # n_feat_idx_1, n_feat_idx_2 = 43, 85 # HSAAC 206 | # processed_complex['graph1'].ndata['f'] = processed_complex['graph1'].ndata['f'][:, n_feat_idx_1: n_feat_idx_2] 207 | # processed_complex['graph2'].ndata['f'] = processed_complex['graph2'].ndata['f'][:, n_feat_idx_1: n_feat_idx_2] 208 | 209 | # g1_rsa = processed_complex['graph1'].ndata['f'][:, 35: 36].reshape(-1, 1) # RSA 210 | # g1_psaia = processed_complex['graph1'].ndata['f'][:, 37: 43] # PSAIA 211 | # g1_hsaac = processed_complex['graph1'].ndata['f'][:, 43: 85] # HSAAC 212 | # processed_complex['graph1'].ndata['f'] = torch.cat((g1_rsa, g1_psaia, g1_hsaac), dim=1) 213 | # 214 | # g2_rsa = processed_complex['graph2'].ndata['f'][:, 35: 36].reshape(-1, 1) # RSA 215 | # g2_psaia = processed_complex['graph2'].ndata['f'][:, 37: 43] # PSAIA 216 | # g2_hsaac = processed_complex['graph2'].ndata['f'][:, 43: 85] # HSAAC 217 | # processed_complex['graph2'].ndata['f'] = torch.cat((g2_rsa, g2_psaia, g2_hsaac), dim=1) 218 | 219 | # processed_complex['graph1'].edata['f'] = processed_complex['graph1'].edata['f'][:, 1].reshape(-1, 1) 220 | # processed_complex['graph2'].edata['f'] = processed_complex['graph2'].edata['f'][:, 1].reshape(-1, 1) 221 | 222 | # Return requested complex to DataLoader 223 | return processed_complex 224 | 225 | def __len__(self) -> int: 226 | r"""Number of graph batches in the dataset.""" 227 | return len(self.filenames_frame) 228 | 229 | @property 230 | def num_chains(self) -> int: 231 | """Number of protein chains in each complex.""" 232 | return 2 233 | 234 | @property 
235 | def num_classes(self) -> int: 236 | """Number of classes for each pair of inter-protein residues.""" 237 | return 2 238 | 239 | @property 240 | def num_node_features(self) -> int: 241 | """Number of node feature values after encoding them.""" 242 | return 107 243 | 244 | @property 245 | def num_edge_features(self) -> int: 246 | """Number of edge feature values after encoding them.""" 247 | return 3 248 | 249 | @property 250 | def raw_path(self) -> str: 251 | """Directory in which to locate raw pairs.""" 252 | return self.raw_dir 253 | 254 | @property 255 | def url(self) -> str: 256 | """URL with which to download TAR archive of preprocessed pairs.""" 257 | # TODO: Update URL 258 | return 'https://zenodo.org/record/4815267/files/final_raw_db5.tar.gz?download=1' 259 | -------------------------------------------------------------------------------- /project/datasets/EVCoupling/final/stub: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/EVCoupling/final/stub -------------------------------------------------------------------------------- /project/datasets/EVCoupling/group_files.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | 4 | import numpy as np 5 | 6 | # Get all filenames of interest 7 | filename_list = os.listdir(os.getcwd()) 8 | 9 | # Use next() + lambda + loop for Initial-Four-Characters Case Categorization 10 | util_func = lambda x, y: x[0] == y[0] and x[1] == y[1] and x[2] == y[2] and x[3] == y[3] 11 | res = [] 12 | for sub in filename_list: 13 | ele = next((x for x in res if util_func(sub, x[0])), []) 14 | if ele == []: 15 | res.append(ele) 16 | ele.append(sub) 17 | 18 | # Get names of EVCoupling heterodimers 19 | hd_list = np.loadtxt('../../evcoupling_heterodimer_list.txt', dtype=str).tolist() 20 | 21 | # Find first heterodimer for each complex and move it to a directory categorized by the complex's PDB code 22 | for r in res: 23 | if len(r) >= 2: 24 | r_combos = list(itertools.combinations(r, 2))[:1] # Only select a single heterodimer from each complex 25 | for combo in r_combos: 26 | heterodimer_name = combo[0][:5] + '_' + combo[1][:5] 27 | if heterodimer_name in hd_list: 28 | pdb_code = combo[0][:4] 29 | os.makedirs(pdb_code, exist_ok=True) 30 | l_b_chain_filename = combo[0] 31 | l_b_chain_filename = l_b_chain_filename[:-4] + '_l_b' + '.pdb' 32 | r_b_chain_filename = combo[1] 33 | r_b_chain_filename = r_b_chain_filename[:-4] + '_r_b' + '.pdb' 34 | if not os.path.exists(os.path.join(pdb_code, l_b_chain_filename)): 35 | os.rename(combo[0], os.path.join(pdb_code, l_b_chain_filename)) 36 | if not os.path.exists(os.path.join(pdb_code, r_b_chain_filename)): 37 | os.rename(combo[1], os.path.join(pdb_code, r_b_chain_filename)) 38 | -------------------------------------------------------------------------------- /project/datasets/EVCoupling/interim/stub: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/EVCoupling/interim/stub -------------------------------------------------------------------------------- /project/datasets/analysis/analyze_experiment_types_and_resolution.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import click 3 | 
import logging 4 | import os 5 | 6 | import atom3.pair as pa 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from graphein.ml.datasets import PDBManager 12 | from pathlib import Path 13 | from tqdm import tqdm 14 | 15 | from project.utils.utils import download_pdb_file, gunzip_file 16 | 17 | 18 | def plot_mean_std_freq(means, stds, freqs, title): 19 | plt.figure(figsize=(10, 6)) 20 | num_experiments = len(means) 21 | bar_width = 0.4 22 | index = np.arange(num_experiments) 23 | 24 | # Plot bars for mean values with customized colors 25 | plt.barh(index, means, height=bar_width, color='salmon', label='Mean') 26 | 27 | # Plot error bars representing standard deviation with customized colors 28 | for i, (mean, std) in enumerate(zip(means, stds)): 29 | plt.errorbar(mean, i, xerr=std, color='red') 30 | 31 | # Plot vertical lines for the first and second standard deviation ranges with customized colors 32 | plt.vlines([mean - std, mean + std], i - bar_width/2, i + bar_width/2, color='royalblue', linestyles='dashed', label='1st Std Dev') 33 | plt.vlines([mean - 2*std, mean + 2*std], i - bar_width/2, i + bar_width/2, color='limegreen', linestyles='dotted', label='2nd Std Dev') 34 | 35 | plt.yticks(index, means.index) 36 | plt.xlabel('Resolution') 37 | plt.ylabel('Experiment Type') 38 | plt.title(title) 39 | plt.legend(loc="upper left") 40 | 41 | # Sort means, stds, and freqs based on the experiment types 42 | means = means[means.index.sort_values()] 43 | stds = stds[stds.index.sort_values()] 44 | freqs = freqs[freqs.index.sort_values()] 45 | freqs = (freqs / freqs.sum()) * 100 46 | 47 | # Calculate middle points of each bar 48 | middles = means 49 | 50 | # Calculate the visual center of each bar 51 | visual_center = middles / 2 52 | 53 | # Add frequency (count) of each experiment type at the middle point of each bar 54 | for i, freq in enumerate(freqs): 55 | plt.text(visual_center[i], i, f"{freq:.4f}%", va='center', ha='center', color='black', fontweight='bold') 56 | 57 | plt.tight_layout() 58 | plt.show() 59 | plt.savefig(title.lower().replace(' ', '_') + ".png") 60 | 61 | 62 | @click.command() 63 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path()) 64 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5'])) 65 | def main(output_dir: str, source_type: str): 66 | logger = logging.getLogger(__name__) 67 | logger.info("Analyzing experiment types and resolution for each dataset example...") 68 | 69 | if source_type.lower() == "rcsb": 70 | train_metadata_csv_filepath = os.path.join(output_dir, "train_pdb_metadata.csv") 71 | val_metadata_csv_filepath = os.path.join(output_dir, "val_pdb_metadata.csv") 72 | train_val_metadata_csv_filepath = os.path.join(output_dir, "train_val_pdb_metadata.csv") 73 | metadata_csv_filepaths = [train_metadata_csv_filepath, val_metadata_csv_filepath, train_val_metadata_csv_filepath] 74 | 75 | if any(not os.path.exists(fp) for fp in metadata_csv_filepaths): 76 | with tempfile.TemporaryDirectory() as temp_dir: 77 | pdb_manager = PDBManager(root_dir=temp_dir) 78 | 79 | # Collect (and, if necessary, extract) all training PDB files 80 | train_pdb_codes = [] 81 | pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt') 82 | assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance." 
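                # For each curated training filename below: load the pickled atom3 Pair, make sure the two bound
                # PDB files it references are available locally (gunzip an existing archive or re-download them),
                # and record a per-chain identifier (PDB filename stem plus chain ID) so that each chain's
                # experiment type and resolution can later be looked up via Graphein's PDBManager metadata.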
83 | with open(pairs_postprocessed_train_txt, "r") as f: 84 | train_filenames = [line.strip() for line in f.readlines()] 85 | for train_filename in tqdm(train_filenames): 86 | try: 87 | postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename)) 88 | except Exception as e: 89 | logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}") 90 | continue 91 | pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3] 92 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code) 93 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0]) 94 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0]) 95 | l_b_df0_chains = postprocessed_train_pair.df0.chain.unique() 96 | r_b_df1_chains = postprocessed_train_pair.df1.chain.unique() 97 | assert ( 98 | len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1 99 | ), "Only a single PDB filename and chain identifier can be associated with a single training example." 100 | assert ( 101 | len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1 102 | ), "Only a single PDB filename and chain identifier can be associated with a single training example." 103 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"): 104 | gunzip_file(l_b_pdb_filepath) 105 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"): 106 | gunzip_file(r_b_pdb_filepath) 107 | if not os.path.exists(l_b_pdb_filepath): 108 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath) 109 | if not os.path.exists(r_b_pdb_filepath): 110 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath) 111 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist." 112 | 113 | l_b_pdb_code = Path(l_b_pdb_filepath).stem + "_" + l_b_df0_chains[0] 114 | r_b_pdb_code = Path(r_b_pdb_filepath).stem + "_" + r_b_df1_chains[0] 115 | train_pdb_codes.extend([l_b_pdb_code, r_b_pdb_code]) 116 | 117 | # Collect (and, if necessary, extract) all validation PDB files 118 | val_pdb_codes = [] 119 | pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt') 120 | assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance." 
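                # The validation filenames below are handled identically to the training filenames above;
                # only the curated file list differs.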
121 | with open(pairs_postprocessed_val_txt, "r") as f: 122 | val_filenames = [line.strip() for line in f.readlines()] 123 | for val_filename in tqdm(val_filenames): 124 | try: 125 | postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename)) 126 | except Exception as e: 127 | logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}") 128 | continue 129 | pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3] 130 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code) 131 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0]) 132 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0]) 133 | l_b_df0_chains = postprocessed_val_pair.df0.chain.unique() 134 | r_b_df1_chains = postprocessed_val_pair.df1.chain.unique() 135 | assert ( 136 | len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1 137 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example." 138 | assert ( 139 | len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1 140 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example." 141 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"): 142 | gunzip_file(l_b_pdb_filepath) 143 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"): 144 | gunzip_file(r_b_pdb_filepath) 145 | if not os.path.exists(l_b_pdb_filepath): 146 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath) 147 | if not os.path.exists(r_b_pdb_filepath): 148 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath) 149 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist." 150 | 151 | l_b_pdb_code = Path(l_b_pdb_filepath).stem + "_" + l_b_df0_chains[0] 152 | r_b_pdb_code = Path(r_b_pdb_filepath).stem + "_" + r_b_df1_chains[0] 153 | val_pdb_codes.extend([l_b_pdb_code, r_b_pdb_code]) 154 | 155 | # Record training and validation PDBs as a metadata CSV file 156 | train_pdbs_df = pdb_manager.df[pdb_manager.df.id.isin(train_pdb_codes)] 157 | train_pdbs_df.to_csv(train_metadata_csv_filepath) 158 | 159 | val_pdbs_df = pdb_manager.df[pdb_manager.df.id.isin(val_pdb_codes)] 160 | val_pdbs_df.to_csv(val_metadata_csv_filepath) 161 | 162 | train_val_pdbs_df = pdb_manager.df[pdb_manager.df.id.isin(train_pdb_codes + val_pdb_codes)] 163 | train_val_pdbs_df.to_csv(train_val_metadata_csv_filepath) 164 | 165 | assert all(os.path.exists(fp) for fp in metadata_csv_filepaths), "To analyze RCSB complexes, the corresponding metadata must previously have been collected." 166 | train_pdbs_df = pd.read_csv(train_metadata_csv_filepath, index_col=0) 167 | val_pdbs_df = pd.read_csv(val_metadata_csv_filepath, index_col=0) 168 | train_val_pdbs_df = pd.read_csv(train_val_metadata_csv_filepath, index_col=0) 169 | 170 | # Train PDBs 171 | train_pdbs_df = train_pdbs_df[~train_pdbs_df.experiment_type.isin(["other"])] 172 | means_train = train_pdbs_df.groupby('experiment_type')['resolution'].mean() 173 | stds_train = train_pdbs_df.groupby('experiment_type')['resolution'].std() 174 | freqs_train = train_pdbs_df['experiment_type'].value_counts() 175 | plot_mean_std_freq(means_train, stds_train, freqs_train, 'Resolution vs. 
Experiment Type (Train)')
176 | 
177 |         # Validation PDBs
178 |         val_pdbs_df = val_pdbs_df[~val_pdbs_df.experiment_type.isin(["other"])]
179 |         means_val = val_pdbs_df.groupby('experiment_type')['resolution'].mean()
180 |         stds_val = val_pdbs_df.groupby('experiment_type')['resolution'].std()
181 |         freqs_val = val_pdbs_df['experiment_type'].value_counts()
182 |         plot_mean_std_freq(means_val, stds_val, freqs_val, 'Resolution vs. Experiment Type (Validation)')
183 | 
184 |         # Train + Validation PDBs
185 |         train_val_pdbs_df = train_val_pdbs_df[~train_val_pdbs_df.experiment_type.isin(["other"])]
186 |         means_train_val = train_val_pdbs_df.groupby('experiment_type')['resolution'].mean()
187 |         stds_train_val = train_val_pdbs_df.groupby('experiment_type')['resolution'].std()
188 |         freqs_train_val = train_val_pdbs_df['experiment_type'].value_counts()
189 |         plot_mean_std_freq(means_train_val, stds_train_val, freqs_train_val, 'Resolution vs. Experiment Type (Train + Validation)')
190 | 
191 |         logger.info("Finished analyzing experiment types and resolution for all training and validation PDBs")
192 | 
193 |     else:
194 |         raise NotImplementedError(f"Source type {source_type} is currently not supported.")
195 | 
196 | 
197 | if __name__ == "__main__":
198 |     log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
199 |     logging.basicConfig(level=logging.INFO, format=log_fmt)
200 | 
201 |     main()
202 | 
--------------------------------------------------------------------------------
/project/datasets/analysis/analyze_feature_correlation.py:
--------------------------------------------------------------------------------
1 | import click
2 | import logging
3 | import os
4 | 
5 | import atom3.pair as pa
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pandas as pd
9 | import seaborn as sns
10 | 
11 | from pathlib import Path
12 | from scipy.stats import pearsonr
13 | from tqdm import tqdm
14 | 
15 | from project.utils.utils import download_pdb_file, gunzip_file
16 | 
17 | 
18 | @click.command()
19 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
20 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5']))
21 | @click.option('--feature_types_to_correlate', default='rsa_value-rd_value', type=click.Choice(['rsa_value-rd_value', 'rsa_value-cn_value', 'rd_value-cn_value']))
22 | def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
23 |     logger = logging.getLogger(__name__)
24 |     logger.info("Analyzing feature correlation for each dataset example...")
25 | 
26 |     features_to_correlate = feature_types_to_correlate.split("-")
27 |     assert len(features_to_correlate) == 2, "Exactly two features may currently be compared for correlation measures."
28 | 
29 |     if source_type.lower() == "rcsb":
30 |         # Collect (and, if necessary, extract) all training PDB files
31 |         train_feature_values = []
32 |         pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
33 |         assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
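        # For each curated training pair below: load the pickled atom3 Pair, ensure both bound PDB files are
        # present locally (gunzip or re-download them as needed), then pull the two selected feature columns
        # from df0/df1, mapping 'NA' entries to NaN, dropping them, and casting to numeric before accumulating
        # the values used in the correlation analysis.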
34 | with open(pairs_postprocessed_train_txt, "r") as f: 35 | train_filenames = [line.strip() for line in f.readlines()] 36 | for train_filename in tqdm(train_filenames): 37 | try: 38 | postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename)) 39 | except Exception as e: 40 | logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}") 41 | continue 42 | pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3] 43 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code) 44 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0]) 45 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0]) 46 | l_b_df0_chains = postprocessed_train_pair.df0.chain.unique() 47 | r_b_df1_chains = postprocessed_train_pair.df1.chain.unique() 48 | assert ( 49 | len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1 50 | ), "Only a single PDB filename and chain identifier can be associated with a single training example." 51 | assert ( 52 | len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1 53 | ), "Only a single PDB filename and chain identifier can be associated with a single training example." 54 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"): 55 | gunzip_file(l_b_pdb_filepath) 56 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"): 57 | gunzip_file(r_b_pdb_filepath) 58 | if not os.path.exists(l_b_pdb_filepath): 59 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath) 60 | if not os.path.exists(r_b_pdb_filepath): 61 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath) 62 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist." 63 | 64 | l_b_df0_feature_values = postprocessed_train_pair.df0[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric) 65 | r_b_df1_feature_values = postprocessed_train_pair.df1[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric) 66 | train_feature_values.append(pd.concat([l_b_df0_feature_values, r_b_df1_feature_values])) 67 | 68 | # Collect (and, if necessary, extract) all validation PDB files 69 | val_feature_values = [] 70 | pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt') 71 | assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance." 
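        # The validation pairs below are processed in the same way as the training pairs above.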
72 | with open(pairs_postprocessed_val_txt, "r") as f: 73 | val_filenames = [line.strip() for line in f.readlines()] 74 | for val_filename in tqdm(val_filenames): 75 | try: 76 | postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename)) 77 | except Exception as e: 78 | logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}") 79 | continue 80 | pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3] 81 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code) 82 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0]) 83 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0]) 84 | l_b_df0_chains = postprocessed_val_pair.df0.chain.unique() 85 | r_b_df1_chains = postprocessed_val_pair.df1.chain.unique() 86 | assert ( 87 | len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1 88 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example." 89 | assert ( 90 | len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1 91 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example." 92 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"): 93 | gunzip_file(l_b_pdb_filepath) 94 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"): 95 | gunzip_file(r_b_pdb_filepath) 96 | if not os.path.exists(l_b_pdb_filepath): 97 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath) 98 | if not os.path.exists(r_b_pdb_filepath): 99 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath) 100 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist." 
101 | 102 | l_b_df0_feature_values = postprocessed_val_pair.df0[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric) 103 | r_b_df1_feature_values = postprocessed_val_pair.df1[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric) 104 | val_feature_values.append(pd.concat([l_b_df0_feature_values, r_b_df1_feature_values])) 105 | 106 | # Train PDBs 107 | train_feature_values_df = pd.concat(train_feature_values) 108 | train_feature_values_correlation, train_feature_values_p_value = pearsonr(train_feature_values_df[features_to_correlate[0]], train_feature_values_df[features_to_correlate[1]]) 109 | logger.info(f"With a p-value of {train_feature_values_p_value}, the Pearson's correlation of feature values `{feature_types_to_correlate}` within the training dataset is: {train_feature_values_correlation}") 110 | train_joint_plot = sns.jointplot(x=features_to_correlate[0], y=features_to_correlate[1], data=train_feature_values_df, kind='hex') 111 | # Add correlation value to the jointplot 112 | train_joint_plot.ax_joint.annotate(f"Training correlation: {train_feature_values_correlation:.2f}", xy=(0.5, 0.95), xycoords='axes fraction', ha='center', fontsize=12) 113 | # Save the jointplot 114 | train_joint_plot.savefig(f"train_{feature_types_to_correlate}_correlation.png") 115 | plt.close() 116 | 117 | # Validation PDBs 118 | val_feature_values_df = pd.concat(val_feature_values) 119 | val_feature_values_correlation, val_feature_values_p_value = pearsonr(val_feature_values_df[features_to_correlate[0]], val_feature_values_df[features_to_correlate[1]]) 120 | logger.info(f"With a p-value of {val_feature_values_p_value}, the Pearson's correlation of feature values `{feature_types_to_correlate}` within the validation dataset is: {val_feature_values_correlation}") 121 | val_joint_plot = sns.jointplot(x=features_to_correlate[0], y=features_to_correlate[1], data=val_feature_values_df, kind='hex') 122 | # Add correlation value to the jointplot 123 | val_joint_plot.ax_joint.annotate(f"Validation correlation: {val_feature_values_correlation:.2f}", xy=(0.5, 0.95), xycoords='axes fraction', ha='center', fontsize=12) 124 | # Save the jointplot 125 | val_joint_plot.savefig(f"val_{feature_types_to_correlate}_correlation.png") 126 | plt.close() 127 | 128 | # Train + Validation PDBs 129 | train_val_feature_values_df = pd.concat([train_feature_values_df, val_feature_values_df]) 130 | train_val_feature_values_correlation, train_val_feature_values_p_value = pearsonr(train_val_feature_values_df[features_to_correlate[0]], train_val_feature_values_df[features_to_correlate[1]]) 131 | logger.info(f"With a p-value of {train_val_feature_values_p_value}, the Pearson's correlation of feature values `{feature_types_to_correlate}` within the training and validation dataset is: {train_val_feature_values_correlation}") 132 | train_val_joint_plot = sns.jointplot(x=features_to_correlate[0], y=features_to_correlate[1], data=train_val_feature_values_df, kind='hex') 133 | # Add correlation value to the jointplot 134 | train_val_joint_plot.ax_joint.annotate(f"Training + Validation Correlation: {train_val_feature_values_correlation:.2f}", xy=(0.5, 0.95), xycoords='axes fraction', ha='center', fontsize=12) 135 | # Save the jointplot 136 | train_val_joint_plot.savefig(f"train_val_{feature_types_to_correlate}_correlation.png") 137 | plt.close() 138 | 139 | logger.info("Finished analyzing feature correlation for all training and validation PDBs") 140 | 141 | 
else: 142 | raise NotImplementedError(f"Source type {source_type} is currently not supported.") 143 | 144 | 145 | if __name__ == "__main__": 146 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 147 | logging.basicConfig(level=logging.INFO, format=log_fmt) 148 | 149 | main() 150 | -------------------------------------------------------------------------------- /project/datasets/analysis/analyze_interface_waters.py: -------------------------------------------------------------------------------- 1 | import click 2 | import logging 3 | import os 4 | import warnings 5 | 6 | import atom3.pair as pa 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from Bio import BiopythonWarning 12 | from Bio.PDB import NeighborSearch 13 | from Bio.PDB import PDBParser 14 | from pathlib import Path 15 | from tqdm import tqdm 16 | 17 | from project.utils.utils import download_pdb_file, gunzip_file 18 | 19 | 20 | @click.command() 21 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path()) 22 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5'])) 23 | @click.option('--interfacing_water_distance_cutoff', default=10.0, type=float) 24 | def main(output_dir: str, source_type: str, interfacing_water_distance_cutoff: float): 25 | logger = logging.getLogger(__name__) 26 | logger.info("Analyzing interface waters within each dataset example...") 27 | 28 | if source_type.lower() == "rcsb": 29 | parser = PDBParser() 30 | 31 | # Filter and suppress BioPython warnings 32 | warnings.filterwarnings("ignore", category=BiopythonWarning) 33 | 34 | # Collect (and, if necessary, extract) all training PDB files 35 | train_complex_num_waters = [] 36 | pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt') 37 | assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance." 38 | with open(pairs_postprocessed_train_txt, "r") as f: 39 | train_filenames = [line.strip() for line in f.readlines()] 40 | for train_filename in tqdm(train_filenames): 41 | complex_num_waters = 0 42 | try: 43 | postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename)) 44 | except Exception as e: 45 | logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}") 46 | continue 47 | pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3] 48 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code) 49 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0]) 50 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0]) 51 | l_b_df0_chains = postprocessed_train_pair.df0.chain.unique() 52 | r_b_df1_chains = postprocessed_train_pair.df1.chain.unique() 53 | assert ( 54 | len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1 55 | ), "Only a single PDB filename and chain identifier can be associated with a single training example." 56 | assert ( 57 | len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1 58 | ), "Only a single PDB filename and chain identifier can be associated with a single training example." 
59 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"): 60 | gunzip_file(l_b_pdb_filepath) 61 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"): 62 | gunzip_file(r_b_pdb_filepath) 63 | if not os.path.exists(l_b_pdb_filepath): 64 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath) 65 | if not os.path.exists(r_b_pdb_filepath): 66 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath) 67 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist." 68 | 69 | l_b_structure = parser.get_structure('protein', l_b_pdb_filepath) 70 | r_b_structure = parser.get_structure('protein', r_b_pdb_filepath) 71 | 72 | l_b_interface_residues = postprocessed_train_pair.df0[postprocessed_train_pair.df0.index.isin(postprocessed_train_pair.pos_idx[:, 0])] 73 | r_b_interface_residues = postprocessed_train_pair.df1[postprocessed_train_pair.df1.index.isin(postprocessed_train_pair.pos_idx[:, 1])] 74 | 75 | try: 76 | l_b_ns = NeighborSearch(list(l_b_structure.get_atoms())) 77 | for index, row in l_b_interface_residues.iterrows(): 78 | chain_id = row['chain'] 79 | residue = row['residue'].strip() 80 | model = l_b_structure[0] 81 | chain = model[chain_id] 82 | if residue.lstrip("-").isdigit(): 83 | residue = int(residue) 84 | else: 85 | residue_index, residue_icode = residue[:-1], residue[-1:] 86 | if residue_icode.strip() == "": 87 | residue = int(residue) 88 | else: 89 | residue = (" ", int(residue_index), residue_icode) 90 | target_residue = chain[residue] 91 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze() 92 | interfacing_atoms = l_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A') 93 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']] 94 | complex_num_waters += len(waters_within_threshold) 95 | except Exception as e: 96 | logging.error(f"Unable to locate interface waters for left-bound training structure {l_b_pdb_filepath} due to: {e}. Skipping...") 97 | continue 98 | 99 | try: 100 | r_b_ns = NeighborSearch(list(r_b_structure.get_atoms())) 101 | for index, row in r_b_interface_residues.iterrows(): 102 | chain_id = row['chain'] 103 | residue = row['residue'].strip() 104 | model = r_b_structure[0] 105 | chain = model[chain_id] 106 | if residue.lstrip("-").isdigit(): 107 | residue = int(residue) 108 | else: 109 | residue_index, residue_icode = residue[:-1], residue[-1:] 110 | residue = (" ", int(residue_index), residue_icode) 111 | target_residue = chain[residue] 112 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze() 113 | interfacing_atoms = r_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A') 114 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']] 115 | complex_num_waters += len(waters_within_threshold) 116 | except Exception as e: 117 | logging.error(f"Unable to locate interface waters for right-bound training structure {r_b_pdb_filepath} due to: {e}. 
Skipping...") 118 | continue 119 | 120 | train_complex_num_waters.append(complex_num_waters) 121 | 122 | # Collect (and, if necessary, extract) all validation PDB files 123 | val_complex_num_waters = [] 124 | pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt') 125 | assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance." 126 | with open(pairs_postprocessed_val_txt, "r") as f: 127 | val_filenames = [line.strip() for line in f.readlines()] 128 | for val_filename in tqdm(val_filenames): 129 | complex_num_waters = 0 130 | try: 131 | postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename)) 132 | except Exception as e: 133 | logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}") 134 | continue 135 | pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3] 136 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code) 137 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0]) 138 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0]) 139 | l_b_df0_chains = postprocessed_val_pair.df0.chain.unique() 140 | r_b_df1_chains = postprocessed_val_pair.df1.chain.unique() 141 | assert ( 142 | len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1 143 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example." 144 | assert ( 145 | len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1 146 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example." 147 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"): 148 | gunzip_file(l_b_pdb_filepath) 149 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"): 150 | gunzip_file(r_b_pdb_filepath) 151 | if not os.path.exists(l_b_pdb_filepath): 152 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath) 153 | if not os.path.exists(r_b_pdb_filepath): 154 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath) 155 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist." 
156 | 157 | l_b_structure = parser.get_structure('protein', l_b_pdb_filepath) 158 | r_b_structure = parser.get_structure('protein', r_b_pdb_filepath) 159 | 160 | l_b_interface_residues = postprocessed_val_pair.df0[postprocessed_val_pair.df0.index.isin(postprocessed_val_pair.pos_idx[:, 0])] 161 | r_b_interface_residues = postprocessed_val_pair.df1[postprocessed_val_pair.df1.index.isin(postprocessed_val_pair.pos_idx[:, 1])] 162 | 163 | try: 164 | l_b_ns = NeighborSearch(list(l_b_structure.get_atoms())) 165 | for index, row in l_b_interface_residues.iterrows(): 166 | chain_id = row['chain'] 167 | residue = row['residue'].strip() 168 | model = l_b_structure[0] 169 | chain = model[chain_id] 170 | if residue.lstrip("-").isdigit(): 171 | residue = int(residue) 172 | else: 173 | residue_index, residue_icode = residue[:-1], residue[-1:] 174 | residue = (" ", int(residue_index), residue_icode) 175 | target_residue = chain[residue] 176 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze() 177 | interfacing_atoms = l_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A') 178 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']] 179 | complex_num_waters += len(waters_within_threshold) 180 | except Exception as e: 181 | logging.error(f"Unable to locate interface waters for left-bound validation structure {l_b_pdb_filepath} due to: {e}. Skipping...") 182 | continue 183 | 184 | try: 185 | r_b_ns = NeighborSearch(list(r_b_structure.get_atoms())) 186 | for index, row in r_b_interface_residues.iterrows(): 187 | chain_id = row['chain'] 188 | residue = row['residue'].strip() 189 | model = r_b_structure[0] 190 | chain = model[chain_id] 191 | if residue.lstrip("-").isdigit(): 192 | residue = int(residue) 193 | else: 194 | residue_index, residue_icode = residue[:-1], residue[-1:] 195 | residue = (" ", int(residue_index), residue_icode) 196 | target_residue = chain[residue] 197 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze() 198 | interfacing_atoms = r_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A') 199 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']] 200 | complex_num_waters += len(waters_within_threshold) 201 | except Exception as e: 202 | logging.error(f"Unable to locate interface waters for right-bound validation structure {r_b_pdb_filepath} due to: {e}. 
Skipping...") 203 | continue 204 | 205 | val_complex_num_waters.append(complex_num_waters) 206 | 207 | train_val_complex_num_waters = train_complex_num_waters + val_complex_num_waters 208 | 209 | # Calculate mean values 210 | training_mean = np.mean(train_complex_num_waters) 211 | validation_mean = np.mean(val_complex_num_waters) 212 | training_validation_mean = np.mean(train_val_complex_num_waters) 213 | 214 | # Plotting the distributions 215 | plt.figure(figsize=(10, 6)) # Set the size of the figure 216 | 217 | # Training data distribution 218 | plt.subplot(131) # 1 row, 3 columns, plot 1 (leftmost) 219 | plt.hist(train_complex_num_waters, bins=10, color='royalblue') 220 | plt.axvline(training_mean, color='limegreen', linestyle='dashed', linewidth=2) 221 | plt.text(training_mean + 0.1, plt.ylim()[1] * 0.9, f' Mean: {training_mean:.2f}', color='limegreen') 222 | plt.title('Train Interface Waters') 223 | plt.xlabel('Count') 224 | plt.ylabel('Frequency') 225 | 226 | # Validation data distribution 227 | plt.subplot(132) # 1 row, 3 columns, plot 2 (middle) 228 | plt.hist(val_complex_num_waters, bins=10, color='royalblue') 229 | plt.axvline(validation_mean, color='limegreen', linestyle='dashed', linewidth=2) 230 | plt.text(validation_mean + 0.1, plt.ylim()[1] * 0.9, f' Mean: {validation_mean:.2f}', color='limegreen') 231 | plt.title('Validation Interface Waters') 232 | plt.xlabel('Count') 233 | plt.ylabel('Frequency') 234 | 235 | # Combined data distribution 236 | plt.subplot(133) # 1 row, 3 columns, plot 3 (rightmost) 237 | plt.hist(train_val_complex_num_waters, bins=10, color='royalblue') 238 | plt.axvline(training_validation_mean, color='limegreen', linestyle='dashed', linewidth=2) 239 | plt.text(training_validation_mean + 0.1, plt.ylim()[1] * 0.9, f' Mean: {training_validation_mean:.2f}', color='limegreen') 240 | plt.title('Train+Validation Interface Waters') 241 | plt.xlabel('Count') 242 | plt.ylabel('Frequency') 243 | 244 | plt.tight_layout() # Adjust the spacing between subplots 245 | plt.show() # Display the plots 246 | plt.savefig("train_val_interface_waters_analysis.png") 247 | 248 | logger.info("Finished analyzing interface waters for all training and validation complexes") 249 | 250 | else: 251 | raise NotImplementedError(f"Source type {source_type} is currently not supported.") 252 | 253 | 254 | if __name__ == "__main__": 255 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 256 | logging.basicConfig(level=logging.INFO, format=log_fmt) 257 | 258 | main() 259 | -------------------------------------------------------------------------------- /project/datasets/builder/add_new_feature.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from pandas.errors import SettingWithCopyWarning 4 | 5 | warnings.simplefilter("ignore", category=FutureWarning) 6 | warnings.simplefilter("ignore", category=SettingWithCopyWarning) 7 | 8 | import click 9 | import graphein 10 | import logging 11 | import loguru 12 | import multiprocessing 13 | import os 14 | import sys 15 | 16 | from functools import partial 17 | from pathlib import Path 18 | 19 | from project.utils.utils import add_new_feature 20 | 21 | GRAPHEIN_FEATURE_NAME_MAPPING = { 22 | # TODO: Fill out remaining mappings for available Graphein residue-level features 23 | "expasy_protein_scale": "expasy", 24 | } 25 | 26 | 27 | @click.command() 28 | @click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True)) 29 | @click.option('--num_cpus', 
'-c', default=1) 30 | @click.option('--modify_pair_data/--dry_run_only', '-m', default=False) 31 | @click.option('--graphein_feature_to_add', default='expasy_protein_scale', type=str) 32 | def main(raw_data_dir: str, num_cpus: int, modify_pair_data: bool, graphein_feature_to_add: str): 33 | # Validate requested feature function 34 | assert ( 35 | hasattr(graphein.protein.features.nodes.amino_acid, graphein_feature_to_add) 36 | ), f"Graphein must provide the requested node featurization function {graphein_feature_to_add}" 37 | 38 | # Disable DEBUG messages coming from Graphein 39 | loguru.logger.disable("graphein") 40 | loguru.logger.remove() 41 | loguru.logger.add(lambda message: message["level"].name != "DEBUG") 42 | 43 | # Collect paths of files to modify 44 | raw_data_dir = Path(raw_data_dir) 45 | raw_data_pickle_filepaths = [] 46 | for root, dirs, files in os.walk(raw_data_dir): 47 | for dir in dirs: 48 | for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir): 49 | for file in subfiles: 50 | if file.endswith('.dill'): 51 | raw_data_pickle_filepaths.append(raw_data_dir / dir / file) 52 | 53 | # Add to each file the values corresponding to a new feature, using multiprocessing # 54 | # Define the number of processes to use 55 | num_processes = min(num_cpus, multiprocessing.cpu_count()) 56 | 57 | # Split the list of file paths into chunks 58 | chunk_size = len(raw_data_pickle_filepaths) // num_processes 59 | file_path_chunks = [ 60 | raw_data_pickle_filepaths[i:i+chunk_size] 61 | for i in range(0, len(raw_data_pickle_filepaths), chunk_size) 62 | ] 63 | assert ( 64 | len(raw_data_pickle_filepaths) == len([fp for chunk in file_path_chunks for fp in chunk]) 65 | ), "Number of input files must match number of files across all file chunks." 66 | 67 | # Create a pool of worker processes 68 | pool = multiprocessing.Pool(processes=num_processes) 69 | 70 | # Process each chunk of file paths in parallel 71 | parallel_fn = partial( 72 | add_new_feature, 73 | modify_pair_data=modify_pair_data, 74 | graphein_feature_to_add=graphein_feature_to_add, 75 | graphein_feature_name_mapping=GRAPHEIN_FEATURE_NAME_MAPPING, 76 | ) 77 | pool.map(parallel_fn, file_path_chunks) 78 | 79 | # Close the pool and wait for all processes to finish 80 | pool.close() 81 | pool.join() 82 | 83 | 84 | if __name__ == "__main__": 85 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 86 | logging.basicConfig(level=logging.INFO, format=log_fmt) 87 | 88 | main() 89 | -------------------------------------------------------------------------------- /project/datasets/builder/annotate_idr_residues.py: -------------------------------------------------------------------------------- 1 | import click 2 | import logging 3 | import multiprocessing 4 | import os 5 | 6 | from pathlib import Path 7 | 8 | from project.utils.utils import annotate_idr_residues 9 | 10 | 11 | @click.command() 12 | @click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True)) 13 | @click.option('--num_cpus', '-c', default=1) 14 | def main(raw_data_dir: str, num_cpus: int): 15 | # Collect paths of files to modify 16 | raw_data_dir = Path(raw_data_dir) 17 | raw_data_pickle_filepaths = [] 18 | for root, dirs, files in os.walk(raw_data_dir): 19 | for dir in dirs: 20 | for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir): 21 | for file in subfiles: 22 | if file.endswith('.dill'): 23 | raw_data_pickle_filepaths.append(raw_data_dir / dir / file) 24 | 25 | # Annotate whether each residue resides in an IDR, using 
multiprocessing # 26 | # Define the number of processes to use 27 | num_processes = min(num_cpus, multiprocessing.cpu_count()) 28 | 29 | # Split the list of file paths into chunks 30 | chunk_size = len(raw_data_pickle_filepaths) // num_processes 31 | file_path_chunks = [ 32 | raw_data_pickle_filepaths[i:i+chunk_size] 33 | for i in range(0, len(raw_data_pickle_filepaths), chunk_size) 34 | ] 35 | assert ( 36 | len(raw_data_pickle_filepaths) == len([fp for chunk in file_path_chunks for fp in chunk]) 37 | ), "Number of input files must match number of files across all file chunks." 38 | 39 | # Create a pool of worker processes 40 | pool = multiprocessing.Pool(processes=num_processes) 41 | 42 | # Process each chunk of file paths in parallel 43 | pool.map(annotate_idr_residues, file_path_chunks) 44 | 45 | # Close the pool and wait for all processes to finish 46 | pool.close() 47 | pool.join() 48 | 49 | 50 | if __name__ == "__main__": 51 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 52 | logging.basicConfig(level=logging.INFO, format=log_fmt) 53 | 54 | main() 55 | -------------------------------------------------------------------------------- /project/datasets/builder/collect_dataset_statistics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import click 5 | 6 | from project.utils.utils import get_global_node_rank, collect_dataset_statistics, DEFAULT_DATASET_STATISTICS 7 | 8 | 9 | @click.command() 10 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path()) 11 | @click.option('--rank', '-r', default=0) 12 | @click.option('--size', '-s', default=1) 13 | def main(output_dir: str, rank: int, size: int): 14 | """Collect all dataset statistics.""" 15 | # Reestablish global rank 16 | rank = get_global_node_rank(rank, size) 17 | 18 | # Ensure that this task only gets run on a single node to prevent race conditions 19 | if rank == 0: 20 | logger = logging.getLogger(__name__) 21 | logger.info('Aggregating statistics for given dataset') 22 | 23 | # Make sure the output_dir exists 24 | if not os.path.exists(output_dir): 25 | os.mkdir(output_dir) 26 | 27 | # Create dataset statistics CSV if not already existent 28 | dataset_statistics_csv = os.path.join(output_dir, 'dataset_statistics.csv') 29 | if not os.path.exists(dataset_statistics_csv): 30 | # Reset dataset statistics CSV 31 | with open(dataset_statistics_csv, 'w') as f: 32 | for key in DEFAULT_DATASET_STATISTICS.keys(): 33 | f.write("%s, %s\n" % (key, DEFAULT_DATASET_STATISTICS[key])) 34 | 35 | # Aggregate dataset statistics in a readable fashion 36 | dataset_statistics = collect_dataset_statistics(output_dir) 37 | 38 | # Write out updated dataset statistics 39 | with open(dataset_statistics_csv, 'w') as f: 40 | for key in dataset_statistics.keys(): 41 | f.write("%s, %s\n" % (key, dataset_statistics[key])) 42 | 43 | 44 | if __name__ == '__main__': 45 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 46 | logging.basicConfig(level=logging.INFO, format=log_fmt) 47 | 48 | main() 49 | -------------------------------------------------------------------------------- /project/datasets/builder/compile_casp_capri_dataset_on_andes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ####################### Batch Headers ######################### 4 | #SBATCH -A BIF135 5 | #SBATCH -p batch 6 | #SBATCH -J make_casp_capri_dataset 7 | #SBATCH -t 0-24:00 8 | #SBATCH --mem 224G 9 | 
#SBATCH --nodes 4 10 | #SBATCH --ntasks-per-node 1 11 | ############################################################### 12 | 13 | # Remote paths # 14 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/bif135/Repositories/Lab_Repositories/DIPS-Plus 15 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa 16 | export OMP_NUM_THREADS=8 17 | 18 | # Remote Conda environment # 19 | source "$PROJDIR"/miniconda3/bin/activate 20 | conda activate DIPS-Plus 21 | 22 | # Load CUDA module for DGL 23 | module load cuda/10.2.89 24 | 25 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 270GB compressed) 26 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/bif132/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq 27 | 28 | # Run dataset compilation scripts 29 | cd "$PROJDIR"/project || exit 30 | 31 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/CASP-CAPRI/interim/parsed "$PROJDIR"/project/datasets/CASP-CAPRI/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/CASP-CAPRI/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type casp_capri --write_file 32 | 33 | #srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/CASP-CAPRI/raw "$PROJDIR"/project/datasets/CASP-CAPRI/interim/pairs "$PROJDIR"/project/datasets/CASP-CAPRI/interim/external_feats "$PROJDIR"/project/datasets/CASP-CAPRI/final/raw --num_cpus 32 --rank "$1" --size "$2" --source_type CASP-CAPRI 34 | -------------------------------------------------------------------------------- /project/datasets/builder/compile_db5_dataset_on_andes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ####################### Batch Headers ######################### 4 | #SBATCH -A BIP198 5 | #SBATCH -p batch 6 | #SBATCH -J make_db5_dataset 7 | #SBATCH -t 0-12:00 8 | #SBATCH --mem 224G 9 | #SBATCH --nodes 4 10 | #SBATCH --ntasks-per-node 1 11 | ############################################################### 12 | 13 | # Remote paths # 14 | export PROJID=bif132 15 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/$PROJID/Repositories/Lab_Repositories/DIPS-Plus 16 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa 17 | export OMP_NUM_THREADS=8 18 | 19 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 
270GB compressed) 20 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/$PROJID/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq 21 | 22 | # Remote Conda environment # 23 | source "$PROJDIR"/miniconda3/bin/activate 24 | 25 | # Load CUDA module for DGL 26 | module load cuda/10.2.89 27 | 28 | # Run dataset compilation scripts 29 | cd "$PROJDIR"/project || exit 30 | wget -O "$PROJDIR"/project/datasets/DB5.tar.gz https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/H93ZKK/BXXQCG 31 | tar -xzf "$PROJDIR"/project/datasets/DB5.tar.gz --directory "$PROJDIR"/project/datasets/ 32 | rm "$PROJDIR"/project/datasets/DB5.tar.gz "$PROJDIR"/project/datasets/DB5/.README.swp 33 | rm -rf "$PROJDIR"/project/datasets/DB5/interim "$PROJDIR"/project/datasets/DB5/processed 34 | mkdir "$PROJDIR"/project/datasets/DB5/interim "$PROJDIR"/project/datasets/DB5/interim/external_feats "$PROJDIR"/project/datasets/DB5/interim/external_feats/PSAIA "$PROJDIR"/project/datasets/DB5/interim/external_feats/PSAIA/DB5 "$PROJDIR"/project/datasets/DB5/final "$PROJDIR"/project/datasets/DB5/final/raw 35 | 36 | python3 "$PROJDIR"/project/datasets/builder/make_dataset.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim --num_cpus 32 --rank "$1" --size "$2" --source_type db5 37 | 38 | python3 "$PROJDIR"/project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_db5.txt "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/external_feats --source_type db5 --rank "$1" --size "$2" 39 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/DB5/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type db5 --write_file 40 | 41 | srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/pairs "$PROJDIR"/project/datasets/DB5/interim/external_feats "$PROJDIR"/project/datasets/DB5/final/raw --num_cpus 32 --rank "$1" --size "$2" --source_type db5 42 | 43 | python3 "$PROJDIR"/project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DB5/final/raw --source_type db5 --rank "$1" --size "$2" 44 | python3 "$PROJDIR"/project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2" 45 | python3 "$PROJDIR"/project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2" 46 | python3 "$PROJDIR"/project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DB5/final/raw --impute_atom_features False --num_cpus 32 --rank "$1" --size "$2" 47 | 48 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels 49 | python3 "$PROJDIR"/project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DB5/final/raw "$PROJDIR"/project/datasets/DB5/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2" 50 | -------------------------------------------------------------------------------- 
/project/datasets/builder/compile_dips_dataset_on_andes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ####################### Batch Headers ######################### 4 | #SBATCH -A BIF132 5 | #SBATCH -p batch 6 | #SBATCH -J make_dips_dataset 7 | #SBATCH -t 0-24:00 8 | #SBATCH --mem 224G 9 | #SBATCH --nodes 4 10 | #SBATCH --ntasks-per-node 1 11 | ############################################################### 12 | 13 | # Remote paths # 14 | export PROJID=bif132 15 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/$PROJID/Repositories/Lab_Repositories/DIPS-Plus 16 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa 17 | 18 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 270GB compressed) 19 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/$PROJID/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq 20 | 21 | # Remote Conda environment # 22 | source "$PROJDIR"/miniconda3/bin/activate 23 | 24 | # Load CUDA module for DGL 25 | module load cuda/10.2.89 26 | 27 | # Load GCC 10 module for OpenMPI support 28 | module load gcc/10.3.0 29 | 30 | # Run dataset compilation scripts 31 | cd "$PROJDIR"/project || exit 32 | rm "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed.txt "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed-train.txt "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed-val.txt "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed-test.txt 33 | mkdir "$PROJDIR"/project/datasets/DIPS/raw "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim "$PROJDIR"/project/datasets/DIPS/interim/external_feats "$PROJDIR"/project/datasets/DIPS/interim/external_feats/PSAIA "$PROJDIR"/project/datasets/DIPS/interim/external_feats/PSAIA/RCSB "$PROJDIR"/project/datasets/DIPS/final "$PROJDIR"/project/datasets/DIPS/final/raw 34 | rsync -rlpt -v -z --delete --port=33444 --include='*.gz' --include='*/' --exclude '*' rsync.rcsb.org::ftp_data/biounit/coordinates/divided/ "$PROJDIR"/project/datasets/DIPS/raw/pdb 35 | 36 | python3 "$PROJDIR"/project/datasets/builder/extract_raw_pdb_gz_archives.py "$PROJDIR"/project/datasets/DIPS/raw/pdb --rank "$1" --size "$2" 37 | 38 | python3 "$PROJDIR"/project/datasets/builder/make_dataset.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim --num_cpus 32 --rank "$1" --size "$2" --source_type rcsb --bound 39 | 40 | python3 "$PROJDIR"/project/datasets/builder/prune_pairs.py "$PROJDIR"/project/datasets/DIPS/interim/pairs "$PROJDIR"/project/datasets/DIPS/filters "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2" 41 | 42 | python3 "$PROJDIR"/project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_dips.txt "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats --source_type rcsb --rank "$1" --size "$2" 43 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB" "$PROJDIR"/project/datasets/DIPS/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --write_file 44 | 45 | # Retroactively download the PDB files 
corresponding to complexes that made it through DIPS-Plus' RCSB complex pruning to reduce storage requirements 46 | python3 "$PROJDIR"/project/datasets/builder/download_missing_pruned_pair_pdbs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2" 47 | srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats "$PROJDIR"/project/datasets/DIPS/final/raw --num_cpus 32 --rank "$1" --size "$2" 48 | 49 | python3 "$PROJDIR"/project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DIPS/final/raw --source_type rcsb --filter_by_atom_count True --max_atom_count 17500 --rank "$1" --size "$2" 50 | python3 "$PROJDIR"/project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2" 51 | python3 "$PROJDIR"/project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2" 52 | python3 "$PROJDIR"/project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DIPS/final/raw --impute_atom_features False --num_cpus 32 --rank "$1" --size "$2" 53 | 54 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels 55 | python3 "$PROJDIR"/project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DIPS/final/raw "$PROJDIR"/project/datasets/DIPS/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2" 56 | -------------------------------------------------------------------------------- /project/datasets/builder/compile_evcoupling_dataset_on_andes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ####################### Batch Headers ######################### 4 | #SBATCH -A BIF135 5 | #SBATCH -p batch 6 | #SBATCH -J make_evcoupling_dataset 7 | #SBATCH -t 0-24:00 8 | #SBATCH --mem 224G 9 | #SBATCH --nodes 4 10 | #SBATCH --ntasks-per-node 1 11 | ############################################################### 12 | 13 | # Remote paths # 14 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/bif135/Repositories/Lab_Repositories/DIPS-Plus 15 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa 16 | export OMP_NUM_THREADS=8 17 | 18 | # Remote Conda environment # 19 | source "$PROJDIR"/miniconda3/bin/activate 20 | conda activate DIPS-Plus 21 | 22 | # Load CUDA module for DGL 23 | module load cuda/10.2.89 24 | 25 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 
270GB compressed) 26 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/bif132/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq 27 | 28 | # Run dataset compilation scripts 29 | cd "$PROJDIR"/project || exit 30 | 31 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/EVCoupling/interim/parsed "$PROJDIR"/project/datasets/EVCoupling/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/EVCoupling/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type evcoupling --read_file 32 | 33 | #srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/EVCoupling/raw "$PROJDIR"/project/datasets/EVCoupling/interim/pairs "$PROJDIR"/project/datasets/EVCoupling/interim/external_feats "$PROJDIR"/project/datasets/EVCoupling/final/raw --num_cpus 32 --rank "$1" --size "$2" --source_type EVCoupling 34 | -------------------------------------------------------------------------------- /project/datasets/builder/convert_complexes_to_graphs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | 5 | import click 6 | from parallel import submit_jobs 7 | 8 | from project.utils.utils import make_dgl_graphs, get_global_node_rank 9 | 10 | 11 | @click.command() 12 | @click.argument('input_dir', default='../DIPS/final/raw', type=click.Path(exists=True)) 13 | @click.argument('output_dir', default='../DIPS/final/processed', type=click.Path()) 14 | @click.option('--ext', '-e', default='pt') 15 | @click.option('--num_cpus', '-c', default=1) 16 | @click.option('--edge_dist_cutoff', '-d', default=15.0) 17 | @click.option('--edge_limit', '-l', default=5000) 18 | @click.option('--self_loops', '-s', default=True) 19 | @click.option('--rank', '-r', default=0) 20 | @click.option('--size', '-s', default=1) 21 | def main(input_dir: str, output_dir: str, ext: str, num_cpus: int, edge_dist_cutoff: float, 22 | edge_limit: int, self_loops: bool, rank: int, size: int): 23 | """Make DGL graphs out of postprocessed pairs.""" 24 | # Reestablish global rank 25 | rank = get_global_node_rank(rank, size) 26 | 27 | # Ensure that this task only gets run on a single node to prevent race conditions 28 | if rank == 0: 29 | logger = logging.getLogger(__name__) 30 | logger.info('Creating DGL graphs (final \'processed\' pairs) from postprocessed (final \'raw\') pairs') 31 | 32 | # Make sure the output_dir exists 33 | if not os.path.exists(output_dir): 34 | os.mkdir(output_dir) 35 | 36 | input_files = [file for file in glob.glob(os.path.join(input_dir, '**', '*.dill'), recursive=True)] 37 | inputs = [(output_dir, input_file, ext, edge_dist_cutoff, edge_limit, self_loops) for input_file in input_files] 38 | submit_jobs(make_dgl_graphs, inputs, num_cpus) 39 | 40 | 41 | if __name__ == '__main__': 42 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 43 | logging.basicConfig(level=logging.INFO, format=log_fmt) 44 | 45 | main() 46 | -------------------------------------------------------------------------------- /project/datasets/builder/create_hdf5_dataset.py: -------------------------------------------------------------------------------- 1 | import click 2 | import logging 3 | import os 4 | import warnings 5 | 6 | from hickle.lookup import SerializedWarning 7 | 8 | warnings.simplefilter("ignore", category=SerializedWarning) 9 | 10 | from pathlib import Path 11 | from parallel import submit_jobs 12 
| 13 | from project.utils.utils import convert_pair_pickle_to_hdf5 14 | 15 | 16 | @click.command() 17 | @click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True)) 18 | @click.option('--num_cpus', '-c', default=1) 19 | def main(raw_data_dir: str, num_cpus: int): 20 | raw_data_dir = Path(raw_data_dir) 21 | raw_data_pickle_filepaths = [] 22 | for root, dirs, files in os.walk(raw_data_dir): 23 | for dir in dirs: 24 | for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir): 25 | for file in subfiles: 26 | if file.endswith('.dill'): 27 | raw_data_pickle_filepaths.append(raw_data_dir / dir / file) 28 | inputs = [(pickle_filepath, Path(pickle_filepath).with_suffix(".hdf5")) for pickle_filepath in raw_data_pickle_filepaths] 29 | submit_jobs(convert_pair_pickle_to_hdf5, inputs, num_cpus) 30 | 31 | # filepath = Path("project/datasets/DIPS/final/raw/0g/10gs.pdb1_0.dill") 32 | # pickle_example = convert_pair_hdf5_to_pickle( 33 | # hdf5_filepath=Path(filepath).with_suffix(".hdf5") 34 | # ) 35 | # hdf5_file_example = convert_pair_hdf5_to_hdf5_file( 36 | # hdf5_filepath=Path(filepath).with_suffix(".hdf5") 37 | # ) 38 | # print(f"pickle_example: {pickle_example}") 39 | # print(f"hdf5_file_example: {hdf5_file_example}") 40 | 41 | 42 | if __name__ == "__main__": 43 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 44 | logging.basicConfig(level=logging.INFO, format=log_fmt) 45 | 46 | main() 47 | -------------------------------------------------------------------------------- /project/datasets/builder/download_missing_pruned_pair_pdbs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import click 5 | from atom3.database import get_structures_filenames, get_pdb_name, get_pdb_code 6 | from atom3.utils import slice_list 7 | from mpi4py import MPI 8 | from parallel import submit_jobs 9 | 10 | from project.utils.utils import get_global_node_rank, download_missing_pruned_pair_pdbs 11 | 12 | 13 | @click.command() 14 | @click.argument('output_dir', default='../DIPS/raw/pdb', type=click.Path(exists=True)) 15 | @click.argument('pruned_pairs_dir', default='../DIPS/interim/pairs-pruned', type=click.Path(exists=True)) 16 | @click.option('--num_cpus', '-c', default=1) 17 | @click.option('--rank', '-r', default=0) 18 | @click.option('--size', '-s', default=1) 19 | def main(output_dir: str, pruned_pairs_dir: str, num_cpus: int, rank: int, size: int): 20 | """Download missing pruned pair PDB files.""" 21 | # Reestablish global rank 22 | rank = get_global_node_rank(rank, size) 23 | logger = logging.getLogger(__name__) 24 | logger.info(f'Beginning missing PDB downloads for node {rank + 1} out of a global MPI world of size {size},' 25 | f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}') 26 | 27 | # Make sure the output_dir exists 28 | if not os.path.exists(output_dir): 29 | os.mkdir(output_dir) 30 | 31 | # Get work filenames 32 | logger.info(f'Looking for all pairs in {pruned_pairs_dir}') 33 | requested_filenames = get_structures_filenames(pruned_pairs_dir, extension='.dill') 34 | requested_filenames = [filename for filename in requested_filenames] 35 | requested_keys = [get_pdb_name(x) for x in requested_filenames] 36 | work_keys = [key for key in requested_keys] 37 | work_filenames = [os.path.join(pruned_pairs_dir, get_pdb_code(work_key)[1:3], work_key + '.dill') 38 | for work_key in work_keys] 39 | logger.info(f'Found {len(work_keys)} work pair(s) in {pruned_pairs_dir}') 40 | 41 | # 
Reserve an equally-sized portion of the full work load for a given rank in the MPI world 42 | work_filenames = list(set(work_filenames)) # Remove any duplicate filenames 43 | work_filename_rank_batches = slice_list(work_filenames, size) 44 | work_filenames = work_filename_rank_batches[rank] 45 | 46 | # Collect thread inputs 47 | inputs = [(logger, output_dir, work_filename) for work_filename in work_filenames] 48 | submit_jobs(download_missing_pruned_pair_pdbs, inputs, num_cpus) 49 | 50 | 51 | if __name__ == '__main__': 52 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 53 | logging.basicConfig(level=logging.INFO, format=log_fmt) 54 | 55 | main() 56 | -------------------------------------------------------------------------------- /project/datasets/builder/extract_raw_pdb_gz_archives.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import logging 3 | import os 4 | import shutil 5 | 6 | import click 7 | from project.utils.utils import get_global_node_rank 8 | 9 | 10 | @click.command() 11 | @click.argument('gz_data_dir', type=click.Path(exists=True)) 12 | @click.option('--rank', '-r', default=0) 13 | @click.option('--size', '-s', default=1) 14 | def main(gz_data_dir: str, rank: int, size: int): 15 | """ Run GZ extraction logic to turn raw data from (raw/) into extracted data ready to be analyzed by DSSP.""" 16 | # Reestablish global rank 17 | rank = get_global_node_rank(rank, size) 18 | 19 | # Ensure that this task only gets run on a single node to prevent race conditions 20 | if rank == 0: 21 | logger = logging.getLogger(__name__) 22 | logger.info('Extracting raw GZ archives') 23 | 24 | # Iterate directory structure, extracting all GZ archives along the way 25 | data_dir = os.path.abspath(gz_data_dir) + '/' 26 | raw_pdb_list = os.listdir(data_dir) 27 | for pdb_dir in raw_pdb_list: 28 | for pdb_gz in os.listdir(data_dir + pdb_dir): 29 | if pdb_gz.endswith('.gz'): 30 | _, ext = os.path.splitext(pdb_gz) 31 | gzip_dir = data_dir + pdb_dir + '/' + pdb_gz 32 | extract_dir = data_dir + pdb_dir + '/' + _ 33 | if not os.path.exists(extract_dir): 34 | with gzip.open(gzip_dir, 'rb') as f_in: 35 | with open(extract_dir, 'wb') as f_out: 36 | shutil.copyfileobj(f_in, f_out) 37 | 38 | 39 | if __name__ == '__main__': 40 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 41 | logging.basicConfig(level=logging.INFO, format=log_fmt) 42 | 43 | main() 44 | -------------------------------------------------------------------------------- /project/datasets/builder/generate_hhsuite_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source code (MIT-Licensed) inspired by Atom3 (https://github.com/drorlab/atom3 & https://github.com/amorehead/atom3) 3 | """ 4 | 5 | import glob 6 | import logging 7 | import os 8 | from os import cpu_count 9 | 10 | import click 11 | from atom3.conservation import map_all_profile_hmms 12 | from mpi4py import MPI 13 | 14 | from project.utils.utils import get_global_node_rank 15 | 16 | 17 | @click.command() 18 | @click.argument('pkl_dataset', type=click.Path(exists=True)) 19 | @click.argument('pruned_dataset', type=click.Path(exists=True)) 20 | @click.argument('hhsuite_db', type=click.Path()) 21 | @click.argument('output_dir', type=click.Path()) 22 | @click.option('--rank', default=0, type=int) 23 | @click.option('--size', default=1, type=int) 24 | @click.option('--num_cpu_jobs', '-j', default=cpu_count() // 2, type=int) 25 | 
@click.option('--num_cpus_per_job', '-c', default=2, type=int) 26 | @click.option('--num_iter', '-i', default=2, type=int) 27 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri'])) 28 | @click.option('--generate_hmm_profile/--generate_msa_only', '-p', default=True) 29 | @click.option('--write_file/--read_file', '-w', default=True) 30 | def main(pkl_dataset: str, pruned_dataset: str, hhsuite_db: str, output_dir: str, rank: int, 31 | size: int, num_cpu_jobs: int, num_cpus_per_job: int, num_iter: int, source_type: str, 32 | generate_hmm_profile: bool, write_file: bool): 33 | """Run external programs for feature generation to turn raw PDB files from (../raw) into sequence or structure-based residue features (saved in ../interim/external_feats by default).""" 34 | logger = logging.getLogger(__name__) 35 | logger.info(f'Generating external features from PDB files in {pkl_dataset}') 36 | 37 | # Reestablish global rank 38 | rank = get_global_node_rank(rank, size) 39 | logger.info(f"Assigned global rank {rank} of world size {size}") 40 | 41 | # Determine true rank and size for a given node 42 | bfd_copy_ids = ["_1", "_2", "_3", "_4", "_5", "_6", "_7", "_8", 43 | "_9", "_10", "_11", "_12", "_17", "_21", "_25", "_29"] 44 | bfd_copy_id = bfd_copy_ids[rank] 45 | 46 | # Assemble true ID of the BFD copy to use for generating profile HMMs 47 | hhsuite_dbs = glob.glob(os.path.join(hhsuite_db + bfd_copy_id, "*bfd*")) 48 | assert len(hhsuite_dbs) == 1, "Only a single BFD database must be present in the given database directory." 49 | hhsuite_db = hhsuite_dbs[0] 50 | logger.info(f'Starting HH-suite for node {rank + 1} out of a global MPI world of size {size},' 51 | f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}.' 
52 | f' This node\'s copy of the BFD is {hhsuite_db}') 53 | 54 | # Generate profile HMMs # 55 | # Run with --write_file=True using one node 56 | # Then run with --read_file=True using multiple nodes to distribute workload across nodes and their CPU cores 57 | map_all_profile_hmms(pkl_dataset, pruned_dataset, output_dir, hhsuite_db, num_cpu_jobs, 58 | num_cpus_per_job, source_type, num_iter, not generate_hmm_profile, rank, size, write_file) 59 | 60 | 61 | if __name__ == '__main__': 62 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 63 | logging.basicConfig(level=logging.INFO, format=log_fmt) 64 | 65 | main() 66 | -------------------------------------------------------------------------------- /project/datasets/builder/generate_psaia_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source code (MIT-Licensed) inspired by Atom3 (https://github.com/drorlab/atom3 & https://github.com/amorehead/atom3) 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | import click 9 | import atom3.conservation as con 10 | 11 | from project.utils.utils import get_global_node_rank 12 | 13 | 14 | @click.command() 15 | @click.argument('psaia_dir', type=click.Path(exists=True)) 16 | @click.argument('psaia_config', type=click.Path(exists=True)) 17 | @click.argument('pdb_dataset', type=click.Path(exists=True)) 18 | @click.argument('pkl_dataset', type=click.Path(exists=True)) 19 | @click.argument('pruned_dataset', type=click.Path(exists=True)) 20 | @click.argument('output_dir', type=click.Path()) 21 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri'])) 22 | @click.option('--rank', '-r', default=0) 23 | @click.option('--size', '-s', default=1) 24 | def main(psaia_dir: str, psaia_config: str, pdb_dataset: str, pkl_dataset: str, 25 | pruned_dataset: str, output_dir: str, source_type: str, rank: int, size: int): 26 | """Run external programs for feature generation to turn raw PDB files from (../raw) into sequence or structure-based residue features (saved in ../interim/external_feats by default).""" 27 | # Reestablish global rank 28 | rank = get_global_node_rank(rank, size) 29 | 30 | # Ensure that this task only gets run on a single node to prevent race conditions 31 | if rank == 0: 32 | logger = logging.getLogger(__name__) 33 | logger.info(f'Generating PSAIA features from PDB files in {pkl_dataset}') 34 | 35 | # Ensure PSAIA is in PATH 36 | PSAIA_PATH = psaia_dir 37 | if PSAIA_PATH not in os.environ["PATH"]: 38 | logger.info('Adding ' + PSAIA_PATH + ' to system path') 39 | os.environ["PATH"] += os.path.sep + PSAIA_PATH 40 | 41 | # Generate protrusion indices 42 | con.map_all_protrusion_indices(psaia_dir, psaia_config, pdb_dataset, pkl_dataset, 43 | pruned_dataset, output_dir, source_type) 44 | 45 | 46 | if __name__ == '__main__': 47 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 48 | logging.basicConfig(level=logging.INFO, format=log_fmt) 49 | 50 | main() 51 | -------------------------------------------------------------------------------- /project/datasets/builder/impute_missing_feature_values.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import click 6 | from parallel import submit_jobs 7 | 8 | from project.utils.utils import get_global_node_rank, impute_missing_feature_values 9 | 10 | 11 | @click.command() 12 | @click.argument('output_dir', 
default='../DIPS/final/raw', type=click.Path()) 13 | @click.option('--impute_atom_features', '-a', default=False) 14 | @click.option('--advanced_logging', '-l', default=False) 15 | @click.option('--num_cpus', '-c', default=1) 16 | @click.option('--rank', '-r', default=0) 17 | @click.option('--size', '-s', default=1) 18 | def main(output_dir: str, impute_atom_features: bool, advanced_logging: bool, num_cpus: int, rank: int, size: int): 19 | """Impute missing feature values.""" 20 | # Reestablish global rank 21 | rank = get_global_node_rank(rank, size) 22 | 23 | # Ensure that this task only gets run on a single node to prevent race conditions 24 | if rank == 0: 25 | logger = logging.getLogger(__name__) 26 | logger.info('Imputing missing feature values for given dataset') 27 | 28 | # Make sure the output_dir exists 29 | if not os.path.exists(output_dir): 30 | os.mkdir(output_dir) 31 | 32 | # Collect thread inputs 33 | inputs = [(pair_filename.as_posix(), pair_filename.as_posix(), impute_atom_features, advanced_logging) 34 | for pair_filename in Path(output_dir).rglob('*.dill')] 35 | # Without impute_atom_features set to True, non-CA atoms will be filtered out after writing updated pairs 36 | submit_jobs(impute_missing_feature_values, inputs, num_cpus) 37 | 38 | 39 | if __name__ == '__main__': 40 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 41 | logging.basicConfig(level=logging.INFO, format=log_fmt) 42 | 43 | main() 44 | -------------------------------------------------------------------------------- /project/datasets/builder/launch_parallel_slurm_jobs_on_andes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Where the project is stored 4 | export PROJID=bif135 5 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/$PROJID/Repositories/Lab_Repositories/DIPS-Plus 6 | 7 | # Which copies of the BFD to use for the available nodes' search batches (i.e. 
in range ['_1', '_2', ..., '_31', '_32']) 8 | BFD_COPY_IDS=("_1" "_2" "_3" "_4" "_5" "_6" "_7" "_8" "_9" "_10" "_11" "_12" "_17" "_21" "_25" "_29") 9 | NUM_BFD_COPIES=${#BFD_COPY_IDS[@]} 10 | 11 | # Whether to launch multi-node compilation jobs for DIPS or instead for DB5 12 | compile_dips=false 13 | compile_db5=false 14 | compile_evcoupling=false 15 | compile_casp_capri=true 16 | 17 | if [ "$compile_dips" = true ]; then 18 | # Job 1 - DIPS 19 | echo Submitting job 1 for compile_dips_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES" 20 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 0 "$NUM_BFD_COPIES" 21 | 22 | # Job 2 - DIPS 23 | echo Submitting job 2 for compile_dips_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES" 24 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 1 "$NUM_BFD_COPIES" 25 | 26 | # Job 3 - DIPS 27 | echo Submitting job 3 for compile_dips_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES" 28 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 2 "$NUM_BFD_COPIES" 29 | 30 | # Job 4 - DIPS 31 | echo Submitting job 4 for compile_dips_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES" 32 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 3 "$NUM_BFD_COPIES" 33 | 34 | elif [ "$compile_db5" = true ]; then 35 | # Job 1 - DB5 36 | echo Submitting job 1 for compile_db5_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES" 37 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 0 "$NUM_BFD_COPIES" 38 | 39 | # Job 2 - DB5 40 | echo Submitting job 2 for compile_db5_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES" 41 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 1 "$NUM_BFD_COPIES" 42 | 43 | # Job 3 - DB5 44 | echo Submitting job 3 for compile_db5_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES" 45 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 2 "$NUM_BFD_COPIES" 46 | 47 | # Job 4 - DB5 48 | echo Submitting job 4 for compile_db5_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES" 49 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 3 "$NUM_BFD_COPIES" 50 | 51 | elif [ "$compile_evcoupling" = true ]; then 52 | # Job 1 - EVCoupling 53 | echo Submitting job 1 for compile_evcoupling_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES" 54 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 0 "$NUM_BFD_COPIES" 55 | 56 | # Job 2 - EVCoupling 57 | echo Submitting job 2 for compile_evcoupling_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES" 58 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 1 "$NUM_BFD_COPIES" 59 | 60 | # Job 3 - EVCoupling 61 | echo Submitting job 3 for compile_evcoupling_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES" 62 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 2 "$NUM_BFD_COPIES" 63 | 64 | # Job 4 - EVCoupling 65 | echo Submitting job 4 for compile_evcoupling_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES" 66 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 3 "$NUM_BFD_COPIES" 67 | 68 | elif [ "$compile_casp_capri" = true ]; then 69 | # Job 1 - CASP-CAPRI 70 | echo Submitting job 1 for compile_casp_capri_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES" 71 | sbatch 
"$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 0 "$NUM_BFD_COPIES" 72 | 73 | # Job 2 - CASP-CAPRI 74 | echo Submitting job 2 for compile_casp_capri_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES" 75 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 1 "$NUM_BFD_COPIES" 76 | 77 | # Job 3 - CASP-CAPRI 78 | echo Submitting job 3 for compile_casp_capri_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES" 79 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 2 "$NUM_BFD_COPIES" 80 | 81 | # Job 4 - CASP-CAPRI 82 | echo Submitting job 4 for compile_casp_capri_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES" 83 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 3 "$NUM_BFD_COPIES" 84 | 85 | else 86 | # Job 1 - DIPS 87 | echo Submitting job 1 for compile_dips_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES" 88 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 0 "$NUM_BFD_COPIES" 89 | 90 | # Job 2 - DIPS 91 | echo Submitting job 2 for compile_dips_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES" 92 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 1 "$NUM_BFD_COPIES" 93 | 94 | # Job 3 - DIPS 95 | echo Submitting job 3 for compile_dips_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES" 96 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 2 "$NUM_BFD_COPIES" 97 | 98 | # Job 4 - DIPS 99 | echo Submitting job 4 for compile_dips_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES" 100 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 3 "$NUM_BFD_COPIES" 101 | fi 102 | -------------------------------------------------------------------------------- /project/datasets/builder/log_dataset_statistics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import click 5 | 6 | from project.utils.utils import get_global_node_rank, log_dataset_statistics, DEFAULT_DATASET_STATISTICS 7 | 8 | 9 | @click.command() 10 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path()) 11 | @click.option('--rank', '-r', default=0) 12 | @click.option('--size', '-s', default=1) 13 | def main(output_dir: str, rank: int, size: int): 14 | """Log all collected dataset statistics.""" 15 | # Reestablish global rank 16 | rank = get_global_node_rank(rank, size) 17 | 18 | # Ensure that this task only gets run on a single node to prevent race conditions 19 | if rank == 0: 20 | logger = logging.getLogger(__name__) 21 | 22 | # Make sure the output_dir exists 23 | if not os.path.exists(output_dir): 24 | os.mkdir(output_dir) 25 | 26 | # Create dataset statistics CSV if not already existent 27 | dataset_statistics_csv = os.path.join(output_dir, 'dataset_statistics.csv') 28 | if not os.path.exists(dataset_statistics_csv): 29 | # Reset dataset statistics CSV 30 | with open(dataset_statistics_csv, 'w') as f: 31 | for key in DEFAULT_DATASET_STATISTICS.keys(): 32 | f.write("%s, %s\n" % (key, DEFAULT_DATASET_STATISTICS[key])) 33 | 34 | with open(dataset_statistics_csv, 'r') as f: 35 | # Read-in existing dataset statistics 36 | dataset_statistics = {} 37 | for line in f.readlines(): 38 | dataset_statistics[line.split(',')[0].strip()] = int(line.split(',')[1].strip()) 39 | 40 | # Log dataset statistics in a readable fashion 41 | if dataset_statistics is not None: 42 | 
log_dataset_statistics(logger, dataset_statistics) 43 | 44 | 45 | if __name__ == '__main__': 46 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 47 | logging.basicConfig(level=logging.INFO, format=log_fmt) 48 | 49 | main() 50 | -------------------------------------------------------------------------------- /project/datasets/builder/make_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source code (MIT-Licensed) originally from DIPS (https://github.com/drorlab/DIPS) 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | import atom3.complex as comp 9 | import atom3.neighbors as nb 10 | import atom3.pair as pair 11 | import atom3.parse as pa 12 | import click 13 | from project.utils.utils import get_global_node_rank 14 | 15 | 16 | @click.command() 17 | @click.argument('input_dir', type=click.Path(exists=True)) 18 | @click.argument('output_dir', type=click.Path()) 19 | @click.option('--num_cpus', '-c', default=1) 20 | @click.option('--rank', '-r', default=0) 21 | @click.option('--size', '-s', default=1) 22 | @click.option('--neighbor_def', default='non_heavy_res', 23 | type=click.Choice(['non_heavy_res', 'non_heavy_atom', 'ca_res', 'ca_atom'])) 24 | @click.option('--cutoff', default=6) 25 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri'])) 26 | @click.option('--unbound/--bound', default=False) 27 | def main(input_dir: str, output_dir: str, num_cpus: int, rank: int, size: int, 28 | neighbor_def: str, cutoff: int, source_type: str, unbound: bool): 29 | """Run data processing scripts to turn raw data from (../raw) into cleaned data ready to be analyzed (saved in ../interim). For reference, pos_idx indicates the IDs of residues in interaction with non-heavy atoms in a cross-protein residue.""" 30 | # Reestablish global rank 31 | rank = get_global_node_rank(rank, size) 32 | 33 | # Ensure that this task only gets run on a single node to prevent race conditions 34 | if rank == 0: 35 | logger = logging.getLogger(__name__) 36 | logger.info('Making interim data set from raw data') 37 | 38 | parsed_dir = os.path.join(output_dir, 'parsed') 39 | pa.parse_all(input_dir, parsed_dir, num_cpus) 40 | 41 | complexes_dill = os.path.join(output_dir, 'complexes/complexes.dill') 42 | comp.complexes(parsed_dir, complexes_dill, source_type) 43 | pairs_dir = os.path.join(output_dir, 'pairs') 44 | get_neighbors = nb.build_get_neighbors(neighbor_def, cutoff) 45 | get_pairs = pair.build_get_pairs(neighbor_def, source_type, unbound, get_neighbors, False) 46 | complexes = comp.read_complexes(complexes_dill) 47 | pair.all_complex_to_pairs(complexes, source_type, get_pairs, pairs_dir, num_cpus) 48 | 49 | 50 | if __name__ == '__main__': 51 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 52 | logging.basicConfig(level=logging.INFO, format=log_fmt) 53 | 54 | main() 55 | -------------------------------------------------------------------------------- /project/datasets/builder/postprocess_pruned_pairs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import click 5 | from atom3.database import get_structures_filenames, get_pdb_name, get_pdb_code 6 | from atom3.utils import slice_list 7 | from mpi4py import MPI 8 | from parallel import submit_jobs 9 | 10 | from project.utils.utils import get_global_node_rank 11 | from project.utils.utils import postprocess_pruned_pairs 12 | 13 | 14 | @click.command() 15 | 
@click.argument('raw_pdb_dir', default='../DIPS/raw/pdb', type=click.Path(exists=True)) 16 | @click.argument('pruned_pairs_dir', default='../DIPS/interim/pairs-pruned', type=click.Path(exists=True)) 17 | @click.argument('external_feats_dir', default='../DIPS/interim/external_feats', type=click.Path(exists=True)) 18 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path()) 19 | @click.option('--num_cpus', '-c', default=1) 20 | @click.option('--rank', '-r', default=0) 21 | @click.option('--size', '-s', default=1) 22 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri'])) 23 | def main(raw_pdb_dir: str, pruned_pairs_dir: str, external_feats_dir: str, output_dir: str, 24 | num_cpus: int, rank: int, size: int, source_type: str): 25 | """Run postprocess_pruned_pairs on all provided complexes.""" 26 | # Reestablish global rank 27 | rank = get_global_node_rank(rank, size) 28 | logger = logging.getLogger(__name__) 29 | logger.info(f'Starting postprocessing for node {rank + 1} out of a global MPI world of size {size},' 30 | f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}') 31 | 32 | # Make sure the output_dir exists 33 | if not os.path.exists(output_dir): 34 | os.mkdir(output_dir) 35 | 36 | # Get work filenames 37 | logger.info(f'Looking for all pairs in {pruned_pairs_dir}') 38 | requested_filenames = get_structures_filenames(pruned_pairs_dir, extension='.dill') 39 | requested_filenames = [filename for filename in requested_filenames] 40 | requested_keys = [get_pdb_name(x) for x in requested_filenames] 41 | produced_filenames = get_structures_filenames(output_dir, extension='.dill') 42 | produced_keys = [get_pdb_name(x) for x in produced_filenames] 43 | work_keys = [key for key in requested_keys if key not in produced_keys] 44 | rcsb_pruned_pair_ext = '.dill' if source_type.lower() in ['rcsb', 'evcoupling', 'casp_capri'] else '' 45 | work_filenames = [os.path.join(pruned_pairs_dir, get_pdb_code(work_key)[1:3], work_key + rcsb_pruned_pair_ext) 46 | for work_key in work_keys] 47 | logger.info(f'Found {len(work_keys)} work pair(s) in {pruned_pairs_dir}') 48 | 49 | # Reserve an equally-sized portion of the full workload for a given rank in the MPI world 50 | work_filenames = list(set(work_filenames)) # Remove any duplicate filenames 51 | work_filename_rank_batches = slice_list(work_filenames, size) 52 | work_filenames = work_filename_rank_batches[rank] 53 | 54 | # Get filenames in which our threads will store output 55 | output_filenames = [] 56 | for pdb_filename in work_filenames: 57 | sub_dir = output_dir + '/' + get_pdb_code(pdb_filename)[1:3] 58 | if not os.path.exists(sub_dir): 59 | os.mkdir(sub_dir) 60 | new_output_filename = sub_dir + '/' + get_pdb_name(pdb_filename) + ".dill" if \ 61 | source_type in ['rcsb', 'evcoupling', 'casp_capri'] else \ 62 | sub_dir + '/' + get_pdb_name(pdb_filename) 63 | output_filenames.append(new_output_filename) 64 | 65 | # Collect thread inputs 66 | inputs = [(raw_pdb_dir, external_feats_dir, i, o, source_type) 67 | for i, o in zip(work_filenames, output_filenames)] 68 | submit_jobs(postprocess_pruned_pairs, inputs, num_cpus) 69 | 70 | 71 | if __name__ == '__main__': 72 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 73 | logging.basicConfig(level=logging.INFO, format=log_fmt) 74 | 75 | main() 76 | -------------------------------------------------------------------------------- /project/datasets/builder/prune_pairs.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Source code (MIT-Licensed) originally from DIPS (https://github.com/drorlab/DIPS) 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | import atom3.database as db 9 | import click 10 | import numpy as np 11 | import parallel as par 12 | 13 | from project.utils.utils import __load_to_keep_files_into_dataframe, process_pairs_to_keep 14 | from project.utils.utils import get_global_node_rank 15 | 16 | 17 | @click.command() 18 | @click.argument('pair_dir', type=click.Path(exists=True)) 19 | @click.argument('to_keep_dir', type=click.Path(exists=True)) 20 | @click.argument('output_dir', type=click.Path()) 21 | @click.option('--num_cpus', '-c', default=1) 22 | @click.option('--rank', '-r', default=0) 23 | @click.option('--size', '-s', default=1) 24 | def main(pair_dir: str, to_keep_dir: str, output_dir: str, num_cpus: int, rank: int, size: int): 25 | """Prune all provided pairs, keeping only those listed in the to_keep files.""" 26 | # Reestablish global rank 27 | rank = get_global_node_rank(rank, size) 28 | 29 | # Ensure that this task only gets run on a single node to prevent race conditions 30 | if rank == 0: 31 | logger = logging.getLogger(__name__) 32 | to_keep_filenames = \ 33 | db.get_structures_filenames(to_keep_dir, extension='.txt') 34 | if len(to_keep_filenames) == 0: 35 | logger.warning(f'There is no to_keep file in {to_keep_dir}.' 36 | f' All pair files from {pair_dir} will be copied into {output_dir}') 37 | 38 | to_keep_df = __load_to_keep_files_into_dataframe(to_keep_filenames) 39 | logger.info(f'There are {to_keep_df.shape[0]} rows and {to_keep_df.shape[1]} columns in to_keep_df') 40 | 41 | # Get work filenames 42 | logger.info(f'Looking for all pairs in {pair_dir}') 43 | work_filenames = db.get_structures_filenames(pair_dir, extension='.dill') 44 | work_filenames = list(set(work_filenames)) # Remove any duplicate filenames 45 | work_keys = [db.get_pdb_name(x) for x in work_filenames] 46 | logger.info(f'Found {len(work_keys)} pairs in {pair_dir}') 47 | 48 | # Get filenames in which our threads will store output 49 | output_filenames = [] 50 | for pdb_filename in work_filenames: 51 | sub_dir = output_dir + '/' + db.get_pdb_code(pdb_filename)[1:3] 52 | if not os.path.exists(sub_dir): 53 | os.mkdir(sub_dir) 54 | output_filenames.append( 55 | sub_dir + '/' + db.get_pdb_name(pdb_filename) + ".dill") 56 | 57 | # Collect thread inputs 58 | n_copied = 0 59 | inputs = [(i, o, to_keep_df) for i, o in zip(work_filenames, output_filenames)] 60 | n_copied += np.sum(par.submit_jobs(process_pairs_to_keep, inputs, num_cpus)) 61 | logger.info(f'{n_copied} out of {len(work_keys)} pairs were copied') 62 | 63 | 64 | if __name__ == '__main__': 65 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' 66 | logging.basicConfig(level=logging.INFO, format=log_fmt) 67 | 68 | main() 69 | -------------------------------------------------------------------------------- /project/datasets/builder/psaia_chothia.radii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/builder/psaia_chothia.radii -------------------------------------------------------------------------------- /project/datasets/builder/psaia_config_file_casp_capri.txt: -------------------------------------------------------------------------------- 1 | analyze_bound: 1 2 | analyze_unbound: 0 3 | calc_asa: 0 4 | z_slice: 0.5 5 | r_solvent: 1.4 6 | 
write_asa: 0 7 | calc_rasa: 0 8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa 9 | calc_dpx: 0 10 | calc_cx: 1 11 | cx_threshold: 10 12 | cx_volume: 20.1 13 | calc_hydro: 0 14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb 15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii 16 | write_xml: 0 17 | write_table: 1 18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/CASP-CAPRI/interim/external_feats/PSAIA/CASP_CAPRI -------------------------------------------------------------------------------- /project/datasets/builder/psaia_config_file_db5.txt: -------------------------------------------------------------------------------- 1 | analyze_bound: 0 2 | analyze_unbound: 1 3 | calc_asa: 0 4 | z_slice: 0.25 5 | r_solvent: 1.4 6 | write_asa: 0 7 | calc_rasa: 0 8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa 9 | calc_dpx: 0 10 | calc_cx: 1 11 | cx_threshold: 10 12 | cx_volume: 20.1 13 | calc_hydro: 0 14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb 15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii 16 | write_xml: 0 17 | write_table: 1 18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/DB5/interim/external_feats/PSAIA/DB5 -------------------------------------------------------------------------------- /project/datasets/builder/psaia_config_file_dips.txt: -------------------------------------------------------------------------------- 1 | analyze_bound: 1 2 | analyze_unbound: 0 3 | calc_asa: 0 4 | z_slice: 0.5 5 | r_solvent: 1.4 6 | write_asa: 0 7 | calc_rasa: 0 8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa 9 | calc_dpx: 0 10 | calc_cx: 1 11 | cx_threshold: 10 12 | cx_volume: 20.1 13 | calc_hydro: 0 14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb 15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii 16 | write_xml: 0 17 | write_table: 1 18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/DIPS/interim/external_feats/PSAIA/RCSB -------------------------------------------------------------------------------- /project/datasets/builder/psaia_config_file_evcoupling.txt: -------------------------------------------------------------------------------- 1 | analyze_bound: 1 2 | analyze_unbound: 0 3 | calc_asa: 0 4 | z_slice: 0.5 5 | r_solvent: 1.4 6 | write_asa: 0 7 | calc_rasa: 0 8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa 9 | calc_dpx: 0 10 | calc_cx: 1 11 | cx_threshold: 10 12 | cx_volume: 20.1 13 | calc_hydro: 0 14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb 15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii 16 | write_xml: 0 17 | write_table: 1 18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/EVCoupling/interim/external_feats/PSAIA/EVCOUPLING -------------------------------------------------------------------------------- /project/datasets/builder/psaia_hydrophobicity.hpb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/builder/psaia_hydrophobicity.hpb 
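
Note: the four PSAIA configuration files above differ only in their analyze_bound/analyze_unbound flags, their z_slice value, and their dataset-specific output_dir; the standard_asa, hydro_file, and radii_filename entries always point at the shared resources under project/datasets/builder. As a minimal, hypothetical sketch (the write_psaia_config helper and its parameter names are illustrative and not part of this repository), such a config could be generated for a new dataset as follows:

# Hypothetical helper (not part of DIPS-Plus) that mirrors the key: value layout
# of the psaia_config_file_*.txt files shown above.
from pathlib import Path


def write_psaia_config(project_root: str, dataset_name: str, output_subdir: str,
                       analyze_bound: bool = True, z_slice: float = 0.5) -> Path:
    """Write a PSAIA config for the given dataset under the given project root."""
    builder_dir = Path(project_root) / 'project' / 'datasets' / 'builder'
    config = {
        'analyze_bound': int(analyze_bound),
        'analyze_unbound': int(not analyze_bound),
        'calc_asa': 0,
        'z_slice': z_slice,
        'r_solvent': 1.4,
        'write_asa': 0,
        'calc_rasa': 0,
        'standard_asa': builder_dir / 'psaia_natural_asa.asa',
        'calc_dpx': 0,
        'calc_cx': 1,
        'cx_threshold': 10,
        'cx_volume': 20.1,
        'calc_hydro': 0,
        'hydro_file': builder_dir / 'psaia_hydrophobicity.hpb',
        'radii_filename': builder_dir / 'psaia_chothia.radii',
        'write_xml': 0,
        'write_table': 1,
        'output_dir': Path(project_root) / 'project' / 'datasets' / dataset_name
                      / 'interim' / 'external_feats' / 'PSAIA' / output_subdir,
    }
    config_path = builder_dir / f'psaia_config_file_{dataset_name.lower()}.txt'
    config_path.write_text('\n'.join(f'{key}: {value}' for key, value in config.items()) + '\n')
    return config_path

For example, write_psaia_config('/home/acmwhb/data/DIPS-Plus', 'DB5', 'DB5', analyze_bound=False, z_slice=0.25) would reproduce a config equivalent to psaia_config_file_db5.txt above.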
-------------------------------------------------------------------------------- /project/datasets/builder/psaia_natural_asa.asa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/builder/psaia_natural_asa.asa -------------------------------------------------------------------------------- /project/utils/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from Bio.PDB.PDBParser import PDBParser 3 | from Bio.PDB.Polypeptide import CaPPBuilder 4 | 5 | # Cluster-specific limit to the number of compute nodes available to each Slurm job 6 | MAX_NODES_PER_JOB = 4 7 | 8 | # Dataset-global node count limits to restrict computational learning complexity 9 | ATOM_COUNT_LIMIT = 17500 # Default filter for both datasets when encoding complexes at an atom-based level 10 | 11 | # From where we can get bound PDB complexes 12 | RCSB_BASE_URL = 'ftp://ftp.wwpdb.org/pub/pdb/data/biounit/coordinates/divided/' 13 | 14 | # The PDB codes of structures added between DB4 and DB5 (to be used for testing dataset) 15 | DB5_TEST_PDB_CODES = ['3R9A', '4GAM', '3AAA', '4H03', '1EXB', 16 | '2GAF', '2GTP', '3RVW', '3SZK', '4IZ7', 17 | '4GXU', '3BX7', '2YVJ', '3V6Z', '1M27', 18 | '4FQI', '4G6J', '3BIW', '3PC8', '3HI6', 19 | '2X9A', '3HMX', '2W9E', '4G6M', '3LVK', 20 | '1JTD', '3H2V', '4DN4', 'BP57', '3L5W', 21 | '3A4S', 'CP57', '3DAW', '3VLB', '3K75', 22 | '2VXT', '3G6D', '3EO1', '4JCV', '4HX3', 23 | '3F1P', '3AAD', '3EOA', '3MXW', '3L89', 24 | '4M76', 'BAAD', '4FZA', '4LW4', '1RKE', 25 | '3FN1', '3S9D', '3H11', '2A1A', '3P57'] 26 | 27 | # Postprocessing logger dictionary 28 | DEFAULT_DATASET_STATISTICS = dict(num_of_processed_complexes=0, num_of_df0_residues=0, num_of_df1_residues=0, 29 | num_of_df0_interface_residues=0, num_of_df1_interface_residues=0, 30 | num_of_pos_res_pairs=0, num_of_neg_res_pairs=0, num_of_res_pairs=0, 31 | num_of_valid_df0_ss_values=0, num_of_valid_df1_ss_values=0, 32 | num_of_valid_df0_rsa_values=0, num_of_valid_df1_rsa_values=0, 33 | num_of_valid_df0_rd_values=0, num_of_valid_df1_rd_values=0, 34 | num_of_valid_df0_protrusion_indices=0, num_of_valid_df1_protrusion_indices=0, 35 | num_of_valid_df0_hsaacs=0, num_of_valid_df1_hsaacs=0, 36 | num_of_valid_df0_cn_values=0, num_of_valid_df1_cn_values=0, 37 | num_of_valid_df0_sequence_feats=0, num_of_valid_df1_sequence_feats=0, 38 | num_of_valid_df0_amide_normal_vecs=0, num_of_valid_df1_amide_normal_vecs=0) 39 | 40 | # Parsing utilities for PDB files (i.e. 
relevant for sequence and structure analysis) 41 | PDB_PARSER = PDBParser() 42 | CA_PP_BUILDER = CaPPBuilder() 43 | 44 | # Dict for converting three letter codes to one letter codes 45 | D3TO1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K', 46 | 'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', 47 | 'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 48 | 'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'} 49 | RES_NAMES_LIST = list(D3TO1.keys()) 50 | 51 | # PSAIA features to encode as DataFrame columns 52 | PSAIA_COLUMNS = ['avg_cx', 's_avg_cx', 's_ch_avg_cx', 's_ch_s_avg_cx', 'max_cx', 'min_cx'] 53 | 54 | # Constants for calculating half sphere exposure statistics 55 | AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY-' 56 | AMINO_ACID_IDX = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS)))) 57 | 58 | # Default fill values for missing features 59 | HSAAC_DIM = 42 # We have 2 + (2 * 20) HSAAC values from the two instances of the unknown residue symbol '-' 60 | DEFAULT_MISSING_FEAT_VALUE = np.nan 61 | DEFAULT_MISSING_SS = '-' 62 | DEFAULT_MISSING_RSA = DEFAULT_MISSING_FEAT_VALUE 63 | DEFAULT_MISSING_RD = DEFAULT_MISSING_FEAT_VALUE 64 | DEFAULT_MISSING_PROTRUSION_INDEX = [DEFAULT_MISSING_FEAT_VALUE for _ in range(6)] 65 | DEFAULT_MISSING_HSAAC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(HSAAC_DIM)] 66 | DEFAULT_MISSING_CN = DEFAULT_MISSING_FEAT_VALUE 67 | DEFAULT_MISSING_SEQUENCE_FEATS = np.array([DEFAULT_MISSING_FEAT_VALUE for _ in range(27)]) 68 | DEFAULT_MISSING_NORM_VEC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(3)] 69 | 70 | # Default number of NaN values allowed in a specific column before imputing missing features of the column with zero 71 | NUM_ALLOWABLE_NANS = 5 72 | 73 | # Features to be one-hot encoded during graph processing and what their values could be 74 | FEAT_COLS = [ 75 | 'resname', # By default, leave out one-hot encoding of residues' type to decrease feature redundancy 76 | 'ss_value', 77 | 'rsa_value', 78 | 'rd_value' 79 | ] 80 | FEAT_COLS.extend( 81 | PSAIA_COLUMNS + 82 | ['hsaac', 83 | 'cn_value', 84 | 'sequence_feats', 85 | 'amide_norm_vec', 86 | # 'element' # For atom-level learning only 87 | ]) 88 | 89 | ALLOWABLE_FEATS = [ 90 | # By default, leave out one-hot encoding of residues' type to decrease feature redundancy 91 | ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR", 92 | "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"], 93 | ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-'], # Populated 1D list means restrict column feature values by list values 94 | [], # Empty list means take scalar value as is 95 | [], 96 | [], 97 | [], 98 | [], 99 | [], 100 | [], 101 | [], 102 | [[]], # Doubly-nested, empty list means take first-level nested list as is 103 | [], 104 | [[]], 105 | [[]], 106 | # ['C', 'O', 'N', 'S'] # For atom-level learning only 107 | ] 108 | -------------------------------------------------------------------------------- /project/utils/training_constants.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | import numpy as np 5 | 6 | # Dataset-global node count limits to restrict computational 
learning complexity 7 | ATOM_COUNT_LIMIT = 2048 # Default atom count filter for DIPS-Plus when encoding complexes at an atom-based level 8 | RESIDUE_COUNT_LIMIT = 256 # Default residue count limit for DIPS-Plus (empirically determined for smoother training) 9 | NODE_COUNT_LIMIT = 2304 # An upper-bound on the node count limit for Geometric Transformers - equal to 9-sized batch 10 | KNN = 20 # Default number of nearest neighbors to query for during graph message passing 11 | 12 | # The PDB codes of structures added between DB4 and DB5 (to be used for testing dataset) 13 | DB5_TEST_PDB_CODES = ['3R9A', '4GAM', '3AAA', '4H03', '1EXB', 14 | '2GAF', '2GTP', '3RVW', '3SZK', '4IZ7', 15 | '4GXU', '3BX7', '2YVJ', '3V6Z', '1M27', 16 | '4FQI', '4G6J', '3BIW', '3PC8', '3HI6', 17 | '2X9A', '3HMX', '2W9E', '4G6M', '3LVK', 18 | '1JTD', '3H2V', '4DN4', 'BP57', '3L5W', 19 | '3A4S', 'CP57', '3DAW', '3VLB', '3K75', 20 | '2VXT', '3G6D', '3EO1', '4JCV', '4HX3', 21 | '3F1P', '3AAD', '3EOA', '3MXW', '3L89', 22 | '4M76', 'BAAD', '4FZA', '4LW4', '1RKE', 23 | '3FN1', '3S9D', '3H11', '2A1A', '3P57'] 24 | 25 | # Default fill values for missing features 26 | HSAAC_DIM = 42 # We have 2 + (2 * 20) HSAAC values from the two instances of the unknown residue symbol '-' 27 | DEFAULT_MISSING_FEAT_VALUE = np.nan 28 | DEFAULT_MISSING_SS = '-' 29 | DEFAULT_MISSING_RSA = DEFAULT_MISSING_FEAT_VALUE 30 | DEFAULT_MISSING_RD = DEFAULT_MISSING_FEAT_VALUE 31 | DEFAULT_MISSING_PROTRUSION_INDEX = [DEFAULT_MISSING_FEAT_VALUE for _ in range(6)] 32 | DEFAULT_MISSING_HSAAC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(HSAAC_DIM)] 33 | DEFAULT_MISSING_CN = DEFAULT_MISSING_FEAT_VALUE 34 | DEFAULT_MISSING_SEQUENCE_FEATS = np.array([DEFAULT_MISSING_FEAT_VALUE for _ in range(27)]) 35 | DEFAULT_MISSING_NORM_VEC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(3)] 36 | 37 | # Dict for converting three letter codes to one letter codes 38 | D3TO1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K', 39 | 'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', 40 | 'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 41 | 'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'} 42 | 43 | # PSAIA features to encode as DataFrame columns 44 | PSAIA_COLUMNS = ['avg_cx', 's_avg_cx', 's_ch_avg_cx', 's_ch_s_avg_cx', 'max_cx', 'min_cx'] 45 | 46 | # Features to be one-hot encoded during graph processing and what their values could be 47 | FEAT_COLS = [ 48 | 'resname', # [7:27] 49 | 'ss_value', # [27:35] 50 | 'rsa_value', # [35:36] 51 | 'rd_value' # [36:37] 52 | ] 53 | FEAT_COLS.extend( 54 | PSAIA_COLUMNS + # [37:43] 55 | ['hsaac', # [43:85] 56 | 'cn_value', # [85:86] 57 | 'sequence_feats', # [86:113] 58 | 'amide_norm_vec', # [Stored separately] 59 | # 'element' # For atom-level learning only 60 | ]) 61 | 62 | ALLOWABLE_FEATS = [ 63 | ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR", 64 | "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"], 65 | ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-'], # Populated 1D list means restrict column feature values by list values 66 | [], # Empty list means take scalar value as is 67 | [], 68 | [], 69 | [], 70 | [], 71 | [], 72 | [], 73 | [], 74 | [[]], # Doubly-nested, empty list means take first-level nested list as is 75 | [], 76 | [[]], 77 | [[]], 78 | # ['C', 'O', 'N', 'S'] # For atom-level learning only 79 | ] 80 | 81 | # A schematic of which tensor indices correspond to which node and edge features 82 | FEATURE_INDICES = { 83 | # Node feature indices 84 | 'node_pos_enc': 
0, 85 | 'node_geo_feats_start': 1, 86 | 'node_geo_feats_end': 7, 87 | 'node_dips_plus_feats_start': 7, 88 | 'node_dips_plus_feats_end': 113, 89 | # Edge feature indices 90 | 'edge_pos_enc': 0, 91 | 'edge_weights': 1, 92 | 'edge_dist_feats_start': 2, 93 | 'edge_dist_feats_end': 20, 94 | 'edge_dir_feats_start': 20, 95 | 'edge_dir_feats_end': 23, 96 | 'edge_orient_feats_start': 23, 97 | 'edge_orient_feats_end': 27, 98 | 'edge_amide_angles': 27 99 | } 100 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | norecursedirs = 3 | .git 4 | dist 5 | build 6 | addopts = 7 | --strict 8 | --doctest-modules 9 | --durations=0 10 | 11 | [coverage:report] 12 | exclude_lines = 13 | pragma: no-cover 14 | pass 15 | 16 | [flake8] 17 | max-line-length = 120 18 | exclude = .tox,*.egg,build,temp 19 | select = E,W,F 20 | doctests = True 21 | verbose = 2 22 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 23 | format = pylint 24 | # see: https://www.flake8rules.com/ 25 | ignore = 26 | E731 # Do not assign a lambda expression, use a def 27 | W504 # Line break occurred after a binary operator 28 | F401 # Module imported but unused 29 | F841 # Local variable name is assigned to but never used 30 | W605 # Invalid escape sequence 'x' 31 | 32 | # setup.cfg or tox.ini 33 | [check-manifest] 34 | ignore = 35 | *.yml 36 | .github 37 | .github/* 38 | 39 | [metadata] 40 | # license_file = LICENSE 41 | # description-file = README.md 42 | long_description = file:README.md 43 | long_description_content_type = text/markdown 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='DIPS-Plus', 7 | version='1.2.0', 8 | description='The Enhanced Database of Interacting Protein Structures for Interface Prediction', 9 | author='Alex Morehead', 10 | author_email='acmwhb@umsystem.edu', 11 | url='https://github.com/BioinfoMachineLearning/DIPS-Plus', 12 | install_requires=[ 13 | 'setuptools==65.5.1', 14 | 'dill==0.3.3', 15 | 'tqdm==4.49.0', 16 | 'Sphinx==4.0.1', 17 | 'easy-parallel-py3==0.1.6.4', 18 | 'click==7.0.0', 19 | ], 20 | packages=find_packages(), 21 | ) 22 | --------------------------------------------------------------------------------
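
Note: the FEATURE_INDICES schematic at the end of project/utils/training_constants.py is easiest to read next to a concrete slicing example. The sketch below is illustrative only - the node_feats array and its construction are assumptions for demonstration, not an API defined in this repository - and simply shows how the recorded offsets carve a per-node feature matrix into its positional-encoding, geometric, and DIPS-Plus blocks.

# Illustrative only: slicing a per-node feature matrix using the offsets recorded
# in FEATURE_INDICES from project/utils/training_constants.py.
import numpy as np

from project.utils.training_constants import FEATURE_INDICES

num_nodes = 256
node_feats = np.zeros((num_nodes, FEATURE_INDICES['node_dips_plus_feats_end']))  # 113 columns in total

pos_enc = node_feats[:, FEATURE_INDICES['node_pos_enc']]  # positional encoding, shape (num_nodes,)
geo_feats = node_feats[:, FEATURE_INDICES['node_geo_feats_start']:FEATURE_INDICES['node_geo_feats_end']]  # shape (num_nodes, 6)
dips_plus_feats = node_feats[:, FEATURE_INDICES['node_dips_plus_feats_start']:FEATURE_INDICES['node_dips_plus_feats_end']]  # shape (num_nodes, 106), i.e. the one-hot and scalar features listed in FEAT_COLS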