├── .gitignore
├── LICENSE
├── README.md
├── citation.bib
├── environment.yml
├── notebooks
│   ├── data_usage.ipynb
│   ├── data_usage.py
│   ├── feature_generation.ipynb
│   └── feature_generation.py
├── project
│   ├── datasets
│   │   ├── DB5
│   │   │   ├── README
│   │   │   ├── db5_dgl_data_module.py
│   │   │   └── db5_dgl_dataset.py
│   │   ├── DIPS
│   │   │   └── filters
│   │   │       ├── buried_surface_over_500.txt
│   │   │       ├── nmr_res_less_3.5.txt
│   │   │       ├── seq_id_less_30.txt
│   │   │       └── size_over_50.0.txt
│   │   ├── EVCoupling
│   │   │   ├── evcoupling_heterodimer_list.txt
│   │   │   ├── final
│   │   │   │   └── stub
│   │   │   ├── group_files.py
│   │   │   └── interim
│   │   │       └── stub
│   │   ├── analysis
│   │   │   ├── analyze_experiment_types_and_resolution.py
│   │   │   ├── analyze_feature_correlation.py
│   │   │   └── analyze_interface_waters.py
│   │   └── builder
│   │       ├── add_new_feature.py
│   │       ├── annotate_idr_residues.py
│   │       ├── collect_dataset_statistics.py
│   │       ├── compile_casp_capri_dataset_on_andes.sh
│   │       ├── compile_db5_dataset_on_andes.sh
│   │       ├── compile_dips_dataset_on_andes.sh
│   │       ├── compile_evcoupling_dataset_on_andes.sh
│   │       ├── convert_complexes_to_graphs.py
│   │       ├── create_hdf5_dataset.py
│   │       ├── download_missing_pruned_pair_pdbs.py
│   │       ├── extract_raw_pdb_gz_archives.py
│   │       ├── generate_hhsuite_features.py
│   │       ├── generate_psaia_features.py
│   │       ├── impute_missing_feature_values.py
│   │       ├── launch_parallel_slurm_jobs_on_andes.sh
│   │       ├── log_dataset_statistics.py
│   │       ├── make_dataset.py
│   │       ├── partition_dataset_filenames.py
│   │       ├── postprocess_pruned_pairs.py
│   │       ├── prune_pairs.py
│   │       ├── psaia_chothia.radii
│   │       ├── psaia_config_file_casp_capri.txt
│   │       ├── psaia_config_file_db5.txt
│   │       ├── psaia_config_file_dips.txt
│   │       ├── psaia_config_file_evcoupling.txt
│   │       ├── psaia_hydrophobicity.hpb
│   │       └── psaia_natural_asa.asa
│   └── utils
│       ├── constants.py
│       ├── modules.py
│       ├── training_constants.py
│       ├── training_utils.py
│       └── utils.py
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .github
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | .conda/
94 | miniconda3
95 | venv.tar.gz
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
110 | # IDEs
111 | .idea
112 | .vscode
113 |
114 | # TensorBoard
115 | tb_logs/
116 |
117 | # Feature Processing
118 | *idr_annotation*.txt
119 | *work_filenames*.csv
120 |
121 | # DIPS
122 | project/datasets/DIPS/complexes/**
123 | project/datasets/DIPS/interim/**
124 | project/datasets/DIPS/pairs/**
125 | project/datasets/DIPS/parsed/**
126 | project/datasets/DIPS/raw/**
127 | project/datasets/DIPS/final/raw/**
128 | project/datasets/DIPS/final/load_pair_example.py
129 | project/datasets/DIPS/final/final_raw_dips*.tar.gz*
130 | project/datasets/DIPS/final/processed/**
131 |
132 | # DB5
133 | project/datasets/DB5/processed/**
134 | project/datasets/DB5/raw/**
135 | project/datasets/DB5/interim/**
136 | project/datasets/DB5/final/raw/**
137 | project/datasets/DB5/final/final_raw_db5.tar.gz*
138 | project/datasets/DB5/final/processed/**
139 |
140 | # EVCoupling
141 | project/datasets/EVCoupling/raw/**
142 | project/datasets/EVCoupling/interim/**
143 | project/datasets/EVCoupling/final/raw/**
144 | project/datasets/EVCoupling/final/processed/**
145 |
146 | # CASP-CAPRI
147 | project/datasets/CASP-CAPRI/raw/**
148 | project/datasets/CASP-CAPRI/interim/**
149 | project/datasets/CASP-CAPRI/final/raw/**
150 | project/datasets/CASP-CAPRI/final/processed/**
151 |
152 | # Input
153 | project/datasets/Input/**
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # DIPS-Plus
4 |
5 | The Enhanced Database of Interacting Protein Structures for Interface Prediction
6 |
7 | [Paper](https://www.nature.com/articles/s41597-023-02409-3) [![CC BY 4.0][cc-by-shield]][cc-by] [Zenodo DOI](https://doi.org/10.5281/zenodo.5134732) [Supplementary Data DOI](https://doi.org/10.5281/zenodo.8140981)
8 |
9 | [cc-by]: http://creativecommons.org/licenses/by/4.0/
10 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png
11 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg
12 |
13 | [comment]: <> ([](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-35-2021))
14 |
15 | [PyPI](https://pypi.org/project/DIPS-Plus/)
16 |
17 |
18 |
19 | ## Versioning
20 |
21 | * Version 1.0.0: Initial release of DIPS-Plus and DB5-Plus (DOI: 10.5281/zenodo.4815267)
22 | * Version 1.1.0: Minor updates to DIPS-Plus and DB5-Plus' tar archives (DOI: 10.5281/zenodo.5134732)
23 | * DIPS-Plus' final 'raw' tar archive now includes standardized 80%-20% lists of filenames for training and validation, respectively
24 | * DB5-Plus' final 'raw' tar archive now includes (optional) standardized lists of filenames for training and validation, respectively
25 | * DB5-Plus' final 'raw' tar archive now also includes a corrected (i.e. de-duplicated) list of filenames for its 55 test complexes
26 | * Benchmark results included in our paper were run after this issue was resolved
27 | * However, if you ran experiments using DB5-Plus' filename list for its test complexes, please re-run them using the latest list
28 | * Version 1.2.0: Minor additions to DIPS-Plus tar archives, including new residue-level intrinsic disorder region annotations and raw Jackhmmer-small BFD MSAs (Supplementary Data DOI: 10.5281/zenodo.8071136)
29 | * Version 1.3.0: Minor additions to DIPS-Plus tar archives, including new FoldSeek-based structure-focused training and validation splits, residue-level (scalar) disorder propensities, and a Graphein-based featurization pipeline (Supplementary Data DOI: 10.5281/zenodo.8140981)
30 |
31 | ## How to set up
32 |
33 | First, download Mamba (if not already downloaded):
34 | ```bash
35 | wget "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh"
36 | bash Mambaforge-$(uname)-$(uname -m).sh # Accept all terms and install to the default location
37 | rm Mambaforge-$(uname)-$(uname -m).sh # (Optionally) Remove installer after using it
38 | source ~/.bashrc # Alternatively, one can restart their shell session to achieve the same result
39 | ```
40 |
41 | Then, create and configure the Mamba environment:
42 |
43 | ```bash
44 | # Clone project:
45 | git clone https://github.com/BioinfoMachineLearning/DIPS-Plus
46 | cd DIPS-Plus
47 |
48 | # Create Conda environment using local 'environment.yml' file:
49 | mamba env create -f environment.yml
50 | conda activate DIPS-Plus # Note: One still needs to use `conda` to (de)activate environments
51 |
52 | # Install local project as package:
53 | pip3 install -e .
54 | ```
55 |
56 | To install PSAIA for feature generation, first install GCC 10:
57 |
58 | ```bash
59 | # Install GCC 10 for Ubuntu 20.04:
60 | sudo apt install software-properties-common
61 | sudo add-apt-repository ppa:ubuntu-toolchain-r/ppa
62 | sudo apt update
63 | sudo apt install gcc-10 g++-10
64 |
65 | # Or install GCC 10 for Arch Linux/Manjaro:
66 | yay -S gcc10
67 | ```
68 |
69 | Then install QT4 for PSAIA:
70 |
71 | ```bash
72 | # Install QT4 for Ubuntu 20.04:
73 | sudo add-apt-repository ppa:rock-core/qt4
74 | sudo apt update
75 | sudo apt install libqt4* libqtcore4 libqtgui4 libqtwebkit4 qt4* libxext-dev
76 |
77 | # Or install QT4 for Arch Linux/Manjaro:
78 | yay -S qt4
79 | ```
80 |
81 | Conclude by compiling PSAIA from source:
82 |
83 | ```bash
84 | # Select the location to install the software:
85 | MY_LOCAL=~/Programs
86 |
87 | # Download and extract PSAIA's source code:
88 | mkdir "$MY_LOCAL"
89 | cd "$MY_LOCAL"
90 | wget http://complex.zesoi.fer.hr/data/PSAIA-1.0-source.tar.gz
91 | tar -xvzf PSAIA-1.0-source.tar.gz
92 |
93 | # Compile PSAIA (i.e., a GUI for PSA):
94 | cd PSAIA_1.0_source/make/linux/psaia/
95 | qmake-qt4 psaia.pro
96 | make
97 |
98 | # Compile PSA (i.e., the protein structure analysis (PSA) program):
99 | cd ../psa/
100 | qmake-qt4 psa.pro
101 | make
102 |
103 | # Compile PIA (i.e., the protein interaction analysis (PIA) program):
104 | cd ../pia/
105 | qmake-qt4 pia.pro
106 | make
107 |
108 | # Test run any of the above-compiled programs:
109 | cd "$MY_LOCAL"/PSAIA_1.0_source/bin/linux
110 | # Test run PSA inside a GUI:
111 | ./psaia/psaia
112 | # Test run PIA through a terminal:
113 | ./pia/pia
114 | # Test run PSA through a terminal:
115 | ./psa/psa
116 | ```
117 |
118 | Lastly, install Docker by following the instructions at https://docs.docker.com/engine/install/.
119 |
120 | ## How to generate protein feature inputs
121 | In our [feature generation notebook](notebooks/feature_generation.ipynb), we provide examples of how users can generate the protein features described in our [accompanying manuscript](https://arxiv.org/abs/2106.04362) for individual protein inputs.
122 |
123 | ## How to use data
124 | In our [data usage notebook](notebooks/data_usage.ipynb), we provide examples of how users might use DIPS-Plus (or DB5-Plus) for downstream analysis or prediction tasks. For example, to train a new NeiA model with DB5-Plus as its cross-validation dataset, first download DB5-Plus' raw files and process them via the `data_usage` notebook:
125 |
126 | ```bash
127 | mkdir -p project/datasets/DB5/final
128 | wget https://zenodo.org/record/5134732/files/final_raw_db5.tar.gz -O project/datasets/DB5/final/final_raw_db5.tar.gz
129 | tar -xzf project/datasets/DB5/final/final_raw_db5.tar.gz -C project/datasets/DB5/final/
130 |
131 | # To process these raw files for training and subsequently train a model:
132 | python3 notebooks/data_usage.py
133 | ```
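Alternatively, after extracting the raw files, the same data module used by the notebook can be driven directly from Python. The sketch below mirrors the `DB5DGLDataModule` arguments set in the `data_usage` notebook; the specific values shown (e.g., `knn=20`, `percent_to_use=1.0`, a single dataloader worker) are illustrative assumptions, not required settings:

```python
# A minimal sketch (values are illustrative assumptions) of loading DB5-Plus via the repository's DGL data module
from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule

db5_data_module = DB5DGLDataModule(data_dir="project/datasets/DB5/final/raw",
                                   batch_size=1,
                                   num_dataloader_workers=1,
                                   knn=20,                  # number of nearest neighbors per residue (assumed value)
                                   self_loops=True,
                                   percent_to_use=1.0,      # use all available complexes (assumed value)
                                   process_complexes=True,  # process raw complexes into graphs on first use
                                   input_indep=False)
db5_data_module.setup()

# Inspect the test split's feature dimensionalities, as the notebook does when constructing a model
print(db5_data_module.db5_test.num_node_features, db5_data_module.db5_test.num_edge_features)
```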
134 |
135 | ## How to split data using FoldSeek
136 | We provide users with the [ability](https://github.com/BioinfoMachineLearning/DIPS-Plus/blob/75775d98f0923faf11fb50639eb58cd510a10ffd/project/datasets/builder/partition_dataset_filenames.py#L486) to perform structure-based splits of the complexes in DIPS-Plus using FoldSeek. The script lets users customize how stringent FoldSeek's searches should be when creating structure-based splits. Moreover, we provide standardized structure-based splits of DIPS-Plus' complexes in the corresponding [supplementary Zenodo data record](https://zenodo.org/record/8140981).
137 |
138 | ## How to featurize DIPS-Plus complexes using Graphein
139 | In the new [graph featurization script](https://github.com/BioinfoMachineLearning/DIPS-Plus/blob/main/project/datasets/builder/add_new_feature.py), we provide an example of how users may add new Expasy protein scale features using the Graphein library. The script is designed to be easily customized, so users can insert arbitrary new Graphein-based features into each DIPS-Plus complex's pair file for downstream tasks.
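For orientation, the snippet below is a minimal sketch of the underlying Graphein pattern, independent of DIPS-Plus' pair files. It assumes Graphein's `expasy_protein_scale` node featurizer and `construct_graph` helper, and the PDB code shown is only an example:

```python
# A minimal sketch (assumes Graphein's Expasy node featurizer; the PDB code is only an example)
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.features.nodes.amino_acid import expasy_protein_scale
from graphein.protein.graphs import construct_graph

# Attach Expasy protein scale values to each residue node at graph-construction time
config = ProteinGraphConfig(node_metadata_functions=[expasy_protein_scale])
graph = construct_graph(config=config, pdb_code="12gs")

# Each node's metadata now carries its Expasy scale features
_, first_node_data = next(iter(graph.nodes(data=True)))
print(first_node_data.keys())
```

The `add_new_feature.py` script wraps this idea so that the resulting features are written back into each DIPS-Plus complex's pair file.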
140 |
141 | ## Standard DIPS-Plus directory structure
142 |
143 | ```
144 | DIPS-Plus
145 | │
146 | └───project
147 | │
148 | └───datasets
149 | │
150 | └───DB5
151 | │ │
152 | │ └───final
153 | │ │ │
154 | │ │ └───processed # task-ready features for each dataset example
155 | │ │ │
156 | │ │ └───raw # generic features for each dataset example
157 | │ │
158 | │ └───interim
159 | │ │ │
160 | │ │ └───complexes # metadata for each dataset example
161 | │ │ │
162 | │ │ └───external_feats # features curated for each dataset example using external tools
163 | │ │ │
164 | │ │ └───pairs # pair-wise features for each dataset example
165 | │ │
166 | │ └───raw # raw PDB data downloads for each dataset example
167 | │
168 | └───DIPS
169 | │
170 | └───filters # filters to apply to each (un-pruned) dataset example
171 | │
172 | └───final
173 | │ │
174 | │ └───processed # task-ready features for each dataset example
175 | │ │
176 | │ └───raw # generic features for each dataset example
177 | │
178 | └───interim
179 | │ │
180 | │ └───complexes # metadata for each dataset example
181 | │ │
182 | │ └───external_feats # features curated for each dataset example using external tools
183 | │ │
184 | │ └───pairs-pruned # filtered pair-wise features for each dataset example
185 | │ │
186 | │ └───parsed # pair-wise features for each dataset example after initial parsing
187 | │
188 | └───raw
189 | │
190 | └───pdb # raw PDB data downloads for each dataset example
191 | ```
192 |
193 | ## How to compile DIPS-Plus from scratch
194 |
195 | Retrieve protein complexes from the RCSB PDB and build out directory structure:
196 |
197 | ```bash
198 | # Remove all existing training/testing sample lists
199 | rm project/datasets/DIPS/final/raw/pairs-postprocessed.txt project/datasets/DIPS/final/raw/pairs-postprocessed-train.txt project/datasets/DIPS/final/raw/pairs-postprocessed-val.txt project/datasets/DIPS/final/raw/pairs-postprocessed-test.txt
200 |
201 | # Create data directories (if not already created):
202 | mkdir project/datasets/DIPS/raw project/datasets/DIPS/raw/pdb project/datasets/DIPS/interim project/datasets/DIPS/interim/pairs-pruned project/datasets/DIPS/interim/external_feats project/datasets/DIPS/final project/datasets/DIPS/final/raw project/datasets/DIPS/final/processed
203 |
204 | # Download the raw PDB files:
205 | rsync -rlpt -v -z --delete --port=33444 --include='*.gz' --include='*.xz' --include='*/' --exclude '*' \
206 | rsync.rcsb.org::ftp_data/biounit/coordinates/divided/ project/datasets/DIPS/raw/pdb
207 |
208 | # Extract the raw PDB files:
209 | python3 project/datasets/builder/extract_raw_pdb_gz_archives.py project/datasets/DIPS/raw/pdb
210 |
211 | # Process the raw PDB data into associated pair files:
212 | python3 project/datasets/builder/make_dataset.py project/datasets/DIPS/raw/pdb project/datasets/DIPS/interim --num_cpus 28 --source_type rcsb --bound
213 |
214 | # Apply additional filtering criteria:
215 | python3 project/datasets/builder/prune_pairs.py project/datasets/DIPS/interim/pairs project/datasets/DIPS/filters project/datasets/DIPS/interim/pairs-pruned --num_cpus 28
216 |
217 | # Generate externally-sourced features:
218 | python3 project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_dips.txt "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats --source_type rcsb
219 | python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB" "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --write_file # Note: After this, one needs to re-run this command with `--read_file` instead
220 |
221 | # Generate multiple sequence alignments (MSAs) using a smaller sequence database (if not already created using the standard BFD):
222 | DOWNLOAD_DIR="$HHSUITE_DB_DIR" && ROOT_DIR="${DOWNLOAD_DIR}/small_bfd" && SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz" && BASENAME=$(basename "${SOURCE_URL}") && mkdir --parents "${ROOT_DIR}" && aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" && pushd "${ROOT_DIR}" && gunzip "${ROOT_DIR}/${BASENAME}" && popd # e.g., Download the small BFD
223 | python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB_DIR"/small_bfd "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --generate_msa_only --write_file # Note: After this, one needs to re-run this command with `--read_file` instead
224 |
225 | # Identify interfaces within intrinsically disordered regions (IDRs) #
226 | # (1) Pull down the Docker image for `flDPnn`
227 | docker pull docker.io/sinaghadermarzi/fldpnn
228 | # (2) For all sequences in the dataset, predict which interface residues reside within IDRs
229 | python3 project/datasets/builder/annotate_idr_residues.py "$PROJDIR"/project/datasets/DIPS/final/raw --num_cpus 16
230 |
231 | # Add new features to the filtered pairs, ensuring that the pruned pairs' original PDB files are stored locally for DSSP:
232 | python3 project/datasets/builder/download_missing_pruned_pair_pdbs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2"
233 | python3 project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats "$PROJDIR"/project/datasets/DIPS/final/raw --num_cpus 32
234 |
235 | # Partition dataset filenames, aggregate statistics, and impute missing features
236 | python3 project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DIPS/final/raw --source_type rcsb --filter_by_atom_count True --max_atom_count 17500 --rank "$1" --size "$2"
237 | python3 project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2"
238 | python3 project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2"
239 | python3 project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DIPS/final/raw --impute_atom_features False --advanced_logging False --num_cpus 32 --rank "$1" --size "$2"
240 |
241 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels
242 | python3 project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DIPS/final/raw "$PROJDIR"/project/datasets/DIPS/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2"
243 | ```
244 |
245 | ## How to assemble DB5-Plus
246 |
247 | Fetch prepared protein complexes from Dataverse:
248 |
249 | ```bash
250 | # Download the prepared DB5 files:
251 | wget -O project/datasets/DB5.tar.gz https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/H93ZKK/BXXQCG
252 |
253 | # Extract downloaded DB5 archive:
254 | tar -xzf project/datasets/DB5.tar.gz --directory project/datasets/
255 |
256 | # Remove (now) redundant DB5 archive and other miscellaneous files:
257 | rm project/datasets/DB5.tar.gz project/datasets/DB5/.README.swp
258 | rm -rf project/datasets/DB5/interim project/datasets/DB5/processed
259 |
260 | # Create relevant interim and final data directories:
261 | mkdir project/datasets/DB5/interim project/datasets/DB5/interim/external_feats
262 | mkdir project/datasets/DB5/final project/datasets/DB5/final/raw project/datasets/DB5/final/processed
263 |
264 | # Construct DB5 dataset pairs:
265 | python3 project/datasets/builder/make_dataset.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim --num_cpus 32 --source_type db5 --unbound
266 |
267 | # Generate externally-sourced features:
268 | python3 project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_db5.txt "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/external_feats --source_type db5
269 | python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/DB5/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type db5 --write_file
270 |
271 | # Add new features to the filtered pairs:
272 | python3 project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/pairs "$PROJDIR"/project/datasets/DB5/interim/external_feats "$PROJDIR"/project/datasets/DB5/final/raw --num_cpus 32 --source_type db5
273 |
274 | # Partition dataset filenames, aggregate statistics, and impute missing features
275 | python3 project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DB5/final/raw --source_type db5 --rank "$1" --size "$2"
276 | python3 project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2"
277 | python3 project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2"
278 | python3 project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DB5/final/raw --impute_atom_features False --advanced_logging False --num_cpus 32 --rank "$1" --size "$2"
279 |
280 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels
281 | python3 project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DB5/final/raw "$PROJDIR"/project/datasets/DB5/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2"
282 | ```
283 |
284 | ## How to reassemble DIPS-Plus' "interim" external features
285 |
286 | We split the (tar.gz) archive into eight separate parts with
287 | `split -b 4096M interim_external_feats_dips.tar.gz "interim_external_feats_dips.tar.gz.part"`
288 | so that it could be uploaded to the dataset's primary Zenodo record. To recover the original archive:
289 |
290 | ```bash
291 | # Reassemble external features archive with 'cat'
292 | cat interim_external_feats_dips.tar.gz.parta* >interim_external_feats_dips.tar.gz
293 | ```
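As an optional sanity check before extracting, one can confirm that the reassembled archive opens correctly. Below is a small sketch using Python's standard `tarfile` module; the archive name matches the command above:

```python
# Optional sanity check: list the first few members of the reassembled archive without extracting it
import tarfile

with tarfile.open("interim_external_feats_dips.tar.gz", "r:gz") as tar:
    for _ in range(5):
        member = tar.next()  # returns the next TarInfo entry, or None at the end of the archive
        if member is None:
            break
        print(member.name)
```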
294 |
295 | ## Python 2 to 3 pickle file solution
296 |
297 | While using Python 3 in this project, you may encounter the following error if you try to postprocess '.dill' pruned
298 | pairs that were created using Python 2.
299 |
300 | ModuleNotFoundError: No module named 'dill.dill'
301 |
302 | 1. To resolve it, ensure that the 'dill' package's version is greater than 0.3.2.
303 | 2. If the problem persists, edit the pickle.py file corresponding to your Conda environment's Python 3 installation (
304 | e.g. ~/DIPS-Plus/venv/lib/python3.8/pickle.py) and add the statement
305 |
306 | ```python
307 | if module == 'dill.dill': module = 'dill._dill'
308 | ```
309 |
310 | to the end of the
311 |
312 | ```python
313 | if self.proto < 3 and self.fix_imports:
314 | ```
315 |
316 | block in the Unpickler class' find_class() function
317 | (e.g. line 1577 of Python 3.8.5's pickle.py).
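3. Alternatively, if you prefer not to edit your Python installation's pickle.py, you can apply the same module remapping from your own loading code. The sketch below subclasses the standard library's `pickle.Unpickler`; the file path is only a placeholder:

```python
# A minimal sketch of remapping the legacy 'dill.dill' module name at load time (path is a placeholder)
import pickle


class DillCompatUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Python 2-era dill pickles reference 'dill.dill', which modern dill renamed to 'dill._dill'
        if module == 'dill.dill':
            module = 'dill._dill'
        return super().find_class(module, name)


with open('path/to/pruned_pair.dill', 'rb') as f:
    pair = DillCompatUnpickler(f).load()
```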
318 |
319 | ## Citation
320 | If you find DIPS-Plus useful in your research, please cite:
321 |
322 | ```bibtex
323 | @article{morehead2023dips,
324 | title={DIPS-Plus: The enhanced database of interacting protein structures for interface prediction},
325 | author={Morehead, Alex and Chen, Chen and Sedova, Ada and Cheng, Jianlin},
326 | journal={Scientific Data},
327 | volume={10},
328 | number={1},
329 | pages={509},
330 | year={2023},
331 | publisher={Nature Publishing Group UK London}
332 | }
333 | ```
334 |
--------------------------------------------------------------------------------
/citation.bib:
--------------------------------------------------------------------------------
1 | @misc{morehead2021dipsplus,
2 | title={DIPS-Plus: The Enhanced Database of Interacting Protein Structures for Interface Prediction},
3 | author={Alex Morehead and Chen Chen and Ada Sedova and Jianlin Cheng},
4 | year={2021},
5 | eprint={2106.04362},
6 | archivePrefix={arXiv},
7 | primaryClass={q-bio.QM}
8 | }
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: DIPS-Plus
2 | channels:
3 | - pytorch
4 | - salilab
5 | - dglteam/label/cu116
6 | - nvidia
7 | - bioconda
8 | - defaults
9 | - conda-forge
10 | dependencies:
11 | - _libgcc_mutex=0.1=conda_forge
12 | - _openmp_mutex=4.5=2_kmp_llvm
13 | - appdirs=1.4.4=pyhd3eb1b0_0
14 | - aria2=1.23.0=0
15 | - asttokens=2.2.1=pyhd8ed1ab_0
16 | - backcall=0.2.0=pyh9f0ad1d_0
17 | - backports=1.0=pyhd8ed1ab_3
18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
19 | - binutils_impl_linux-64=2.36.1=h193b22a_2
20 | - biopython=1.78=py38h7f8727e_0
21 | - blas=1.0=openblas
22 | - bottleneck=1.3.5=py38h7deecbd_0
23 | - brotlipy=0.7.0=py38h27cfd23_1003
24 | - bzip2=1.0.8=h7b6447c_0
25 | - c-ares=1.19.1=hd590300_0
26 | - ca-certificates=2023.5.7=hbcca054_0
27 | - certifi=2023.5.7=py38h06a4308_0
28 | - cffi=1.15.1=py38h5eee18b_3
29 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
30 | - colorama=0.4.6=pyhd8ed1ab_0
31 | - comm=0.1.3=pyhd8ed1ab_0
32 | - cryptography=39.0.0=py38h1724139_0
33 | - cuda=11.6.1=0
34 | - cuda-cccl=11.6.55=hf6102b2_0
35 | - cuda-command-line-tools=11.6.2=0
36 | - cuda-compiler=11.6.2=0
37 | - cuda-cudart=11.6.55=he381448_0
38 | - cuda-cudart-dev=11.6.55=h42ad0f4_0
39 | - cuda-cuobjdump=11.6.124=h2eeebcb_0
40 | - cuda-cupti=11.6.124=h86345e5_0
41 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0
42 | - cuda-driver-dev=11.6.55=0
43 | - cuda-gdb=12.1.105=0
44 | - cuda-libraries=11.6.1=0
45 | - cuda-libraries-dev=11.6.1=0
46 | - cuda-memcheck=11.8.86=0
47 | - cuda-nsight=12.1.105=0
48 | - cuda-nsight-compute=12.1.1=0
49 | - cuda-nvcc=11.6.124=hbba6d2d_0
50 | - cuda-nvdisasm=12.1.105=0
51 | - cuda-nvml-dev=11.6.55=haa9ef22_0
52 | - cuda-nvprof=12.1.105=0
53 | - cuda-nvprune=11.6.124=he22ec0a_0
54 | - cuda-nvrtc=11.6.124=h020bade_0
55 | - cuda-nvrtc-dev=11.6.124=h249d397_0
56 | - cuda-nvtx=11.6.124=h0630a44_0
57 | - cuda-nvvp=12.1.105=0
58 | - cuda-runtime=11.6.1=0
59 | - cuda-samples=11.6.101=h8efea70_0
60 | - cuda-sanitizer-api=12.1.105=0
61 | - cuda-toolkit=11.6.1=0
62 | - cuda-tools=11.6.1=0
63 | - cuda-version=11.7=h67201e3_2
64 | - cuda-visual-tools=11.6.1=0
65 | - cudatoolkit=11.7.0=hd8887f6_10
66 | - cudnn=8.8.0.121=h0800d71_0
67 | - cycler=0.11.0=pyhd8ed1ab_0
68 | - debugpy=1.6.7=py38h8dc9893_0
69 | - decorator=5.1.1=pyhd8ed1ab_0
70 | - dgl=1.1.0.cu116=py38_0
71 | - dssp=3.0.0=h3fd9d12_4
72 | - executing=1.2.0=pyhd8ed1ab_0
73 | - ffmpeg=4.3=hf484d3e_0
74 | - foldseek=7.04e0ec8=pl5321hb365157_0
75 | - freetype=2.12.1=hca18f0e_1
76 | - gawk=5.1.0=h7f98852_0
77 | - gcc=10.3.0=he2824d0_10
78 | - gcc_impl_linux-64=10.3.0=hf2f2afa_16
79 | - gds-tools=1.6.1.9=0
80 | - gettext=0.21.1=h27087fc_0
81 | - gmp=6.2.1=h58526e2_0
82 | - gnutls=3.6.13=h85f3911_1
83 | - gxx=10.3.0=he2824d0_10
84 | - gxx_impl_linux-64=10.3.0=hf2f2afa_16
85 | - hhsuite=3.3.0=py38pl5321hcbe9525_8
86 | - hmmer=3.3.2=hdbdd923_4
87 | - icu=58.2=he6710b0_3
88 | - idna=3.4=py38h06a4308_0
89 | - importlib-metadata=6.6.0=pyha770c72_0
90 | - importlib_metadata=6.6.0=hd8ed1ab_0
91 | - ipykernel=6.23.1=pyh210e3f2_0
92 | - ipython=8.4.0=py38h578d9bd_0
93 | - jedi=0.18.2=pyhd8ed1ab_0
94 | - joblib=1.1.1=py38h06a4308_0
95 | - jpeg=9e=h0b41bf4_3
96 | - jupyter_client=8.2.0=pyhd8ed1ab_0
97 | - jupyter_core=4.12.0=py38h578d9bd_0
98 | - kernel-headers_linux-64=2.6.32=he073ed8_15
99 | - kiwisolver=1.4.4=py38h43d8883_1
100 | - lame=3.100=h166bdaf_1003
101 | - lcms2=2.15=hfd0df8a_0
102 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
103 | - lerc=4.0.0=h27087fc_0
104 | - libblas=3.9.0=16_linux64_openblas
105 | - libboost=1.73.0=h28710b8_12
106 | - libcblas=3.9.0=16_linux64_openblas
107 | - libcublas=11.9.2.110=h5e84587_0
108 | - libcublas-dev=11.9.2.110=h5c901ab_0
109 | - libcufft=10.7.1.112=hf425ae0_0
110 | - libcufft-dev=10.7.1.112=ha5ce4c0_0
111 | - libcufile=1.6.1.9=0
112 | - libcufile-dev=1.6.1.9=0
113 | - libcurand=10.3.2.106=0
114 | - libcurand-dev=10.3.2.106=0
115 | - libcusolver=11.3.4.124=h33c3c4e_0
116 | - libcusparse=11.7.2.124=h7538f96_0
117 | - libcusparse-dev=11.7.2.124=hbbe9722_0
118 | - libdeflate=1.17=h0b41bf4_0
119 | - libffi=3.4.4=h6a678d5_0
120 | - libgcc=7.2.0=h69d50b8_2
121 | - libgcc-devel_linux-64=10.3.0=he6cfe16_16
122 | - libgcc-ng=12.2.0=h65d4601_19
123 | - libgfortran-ng=12.2.0=h69a702a_19
124 | - libgfortran5=12.2.0=h337968e_19
125 | - libgomp=12.2.0=h65d4601_19
126 | - libiconv=1.17=h166bdaf_0
127 | - libidn2=2.3.4=h166bdaf_0
128 | - liblapack=3.9.0=16_linux64_openblas
129 | - libnpp=11.6.3.124=hd2722f0_0
130 | - libnpp-dev=11.6.3.124=h3c42840_0
131 | - libnsl=2.0.0=h5eee18b_0
132 | - libnvjpeg=11.6.2.124=hd473ad6_0
133 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0
134 | - libopenblas=0.3.21=h043d6bf_0
135 | - libpng=1.6.39=h753d276_0
136 | - libprotobuf=3.21.12=h3eb15da_0
137 | - libsanitizer=10.3.0=h26c7422_16
138 | - libsodium=1.0.18=h36c2ea0_1
139 | - libsqlite=3.42.0=h2797004_0
140 | - libssh2=1.10.0=haa6b8db_3
141 | - libstdcxx-devel_linux-64=10.3.0=he6cfe16_16
142 | - libstdcxx-ng=12.2.0=h46fd767_19
143 | - libtiff=4.5.0=h6adf6a1_2
144 | - libunistring=0.9.10=h7f98852_0
145 | - libuuid=2.38.1=h0b41bf4_0
146 | - libwebp-base=1.3.0=h0b41bf4_0
147 | - libxcb=1.13=h7f98852_1004
148 | - libxml2=2.9.9=h13577e0_2
149 | - libzlib=1.2.13=h166bdaf_4
150 | - llvm-openmp=16.0.4=h4dfa4b3_0
151 | - lz4-c=1.9.4=h6a678d5_0
152 | - magma=2.6.2=hc72dce7_0
153 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0
154 | - mkl=2022.2.1=h84fe81f_16997
155 | - mpi=1.0=openmpi
156 | - msms=2.6.1=h516909a_0
157 | - nccl=2.15.5.1=h0800d71_0
158 | - ncurses=6.4=h6a678d5_0
159 | - nest-asyncio=1.5.6=pyhd8ed1ab_0
160 | - nettle=3.6=he412f7d_0
161 | - networkx=3.1=pyhd8ed1ab_0
162 | - ninja=1.11.1=h924138e_0
163 | - nsight-compute=2023.1.1.4=0
164 | - numexpr=2.8.4=py38hd2a5715_1
165 | - openh264=2.1.1=h780b84a_0
166 | - openjpeg=2.5.0=hfec8fc6_2
167 | - openmpi=4.1.5=h414af15_101
168 | - openssl=1.1.1u=hd590300_0
169 | - packaging=23.0=py38h06a4308_0
170 | - pandas=1.5.3=py38h417a72b_0
171 | - parso=0.8.3=pyhd8ed1ab_0
172 | - perl=5.32.1=0_h5eee18b_perl5
173 | - pexpect=4.8.0=pyh1a96a4e_2
174 | - pickleshare=0.7.5=py_1003
175 | - pillow=9.4.0=py38hde6dc18_1
176 | - pip=23.1.2=py38h06a4308_0
177 | - pooch=1.4.0=pyhd3eb1b0_0
178 | - prompt-toolkit=3.0.38=pyha770c72_0
179 | - psutil=5.9.5=py38h1de0b5d_0
180 | - pthread-stubs=0.4=h36c2ea0_1001
181 | - ptyprocess=0.7.0=pyhd3deb0d_0
182 | - pure_eval=0.2.2=pyhd8ed1ab_0
183 | - pycparser=2.21=pyhd3eb1b0_0
184 | - pygments=2.15.1=pyhd8ed1ab_0
185 | - pyopenssl=23.1.1=pyhd8ed1ab_0
186 | - pysocks=1.7.1=py38h06a4308_0
187 | - python=3.8.16=he550d4f_1_cpython
188 | - python-dateutil=2.8.2=pyhd3eb1b0_0
189 | - python_abi=3.8=2_cp38
190 | - pytorch=1.13.1=cuda112py38hd94e077_200
191 | - pytorch-cuda=11.6=h867d48c_1
192 | - pytorch-mutex=1.0=cuda
193 | - pytz=2022.7=py38h06a4308_0
194 | - pyzmq=25.0.2=py38he24dcef_0
195 | - readline=8.2=h5eee18b_0
196 | - requests=2.29.0=py38h06a4308_0
197 | - scikit-learn=1.2.2=py38h6a678d5_0
198 | - scipy=1.10.1=py38h32ae08f_1
199 | - setuptools=67.8.0=py38h06a4308_0
200 | - six=1.16.0=pyhd3eb1b0_1
201 | - sleef=3.5.1=h9b69904_2
202 | - sqlite=3.41.2=h5eee18b_0
203 | - stack_data=0.6.2=pyhd8ed1ab_0
204 | - sysroot_linux-64=2.12=he073ed8_15
205 | - tbb=2021.7.0=h924138e_0
206 | - threadpoolctl=2.2.0=pyh0d69192_0
207 | - tk=8.6.12=h1ccaba5_0
208 | - torchaudio=0.13.1=py38_cu116
209 | - torchvision=0.14.1=py38_cu116
210 | - tornado=6.3.2=py38h01eb140_0
211 | - tqdm=4.65.0=py38hb070fc8_0
212 | - traitlets=5.9.0=pyhd8ed1ab_0
213 | - typing_extensions=4.6.0=pyha770c72_0
214 | - urllib3=1.26.15=py38h06a4308_0
215 | - wcwidth=0.2.6=pyhd8ed1ab_0
216 | - wheel=0.38.4=py38h06a4308_0
217 | - xorg-libxau=1.0.11=hd590300_0
218 | - xorg-libxdmcp=1.1.3=h7f98852_0
219 | - xz=5.2.6=h166bdaf_0
220 | - zeromq=4.3.4=h9c3ff4c_1
221 | - zipp=3.15.0=pyhd8ed1ab_0
222 | - zlib=1.2.13=h166bdaf_4
223 | - zstd=1.5.2=h3eb15da_6
224 | - pip:
225 | - absl-py==1.4.0
226 | - aiohttp==3.8.4
227 | - aiosignal==1.3.1
228 | - alabaster==0.7.13
229 | - async-timeout==4.0.2
230 | - git+https://github.com/amorehead/atom3.git@83987404ceed38a1f5a5abd517aa38128d0a4f2c
231 | - attrs==23.1.0
232 | - babel==2.12.1
233 | - beautifulsoup4==4.12.2
234 | - biopandas==0.5.0.dev0
235 | - bioservices==1.11.2
236 | - cachetools==5.3.1
237 | - cattrs==23.1.2
238 | - click==7.0
239 | - colorlog==6.7.0
240 | - configparser==5.3.0
241 | - contourpy==1.1.0
242 | - deepdiff==6.3.1
243 | - dill==0.3.3
244 | - docker-pycreds==0.4.0
245 | - docutils==0.17.1
246 | - easy-parallel-py3==0.1.6.4
247 | - easydev==0.12.1
248 | - exceptiongroup==1.1.2
249 | - fairscale==0.4.0
250 | - fonttools==4.40.0
251 | - frozenlist==1.3.3
252 | - fsspec==2023.5.0
253 | - future==0.18.3
254 | - gevent==22.10.2
255 | - gitdb==4.0.10
256 | - gitpython==3.1.31
257 | - google-auth==2.19.0
258 | - google-auth-oauthlib==1.0.0
259 | - git+https://github.com/a-r-j/graphein.git@371ce9a462b610529488e87a712484328a89de36
260 | - greenlet==2.0.2
261 | - grequests==0.7.0
262 | - grpcio==1.54.2
263 | - h5py==3.8.0
264 | - hickle==5.0.2
265 | - imagesize==1.4.1
266 | - importlib-resources==6.0.0
267 | - install==1.3.5
268 | - jaxtyping==0.2.19
269 | - jinja2==2.11.3
270 | - loguru==0.7.0
271 | - looseversion==1.1.2
272 | - lxml==4.9.3
273 | - markdown==3.4.3
274 | - markdown-it-py==3.0.0
275 | - markupsafe==1.1.1
276 | - matplotlib==3.7.2
277 | - mdurl==0.1.2
278 | - mmtf-python==1.1.3
279 | - mpi4py==3.0.3
280 | - msgpack==1.0.5
281 | - multidict==6.0.4
282 | - multipledispatch==1.0.0
283 | - multiprocess==0.70.11.1
284 | - numpy==1.23.5
285 | - oauthlib==3.2.2
286 | - ordered-set==4.1.0
287 | - pathos==0.2.7
288 | - pathtools==0.1.2
289 | - pdb-tools==2.5.0
290 | - platformdirs==3.8.1
291 | - plotly==5.15.0
292 | - pox==0.3.2
293 | - ppft==1.7.6.6
294 | - promise==2.3
295 | - protobuf==3.20.3
296 | - pyasn1==0.5.0
297 | - pyasn1-modules==0.3.0
298 | - pydantic==1.10.11
299 | - pydeprecate==0.3.1
300 | - pyparsing==3.0.9
301 | - pytorch-lightning==1.4.8
302 | - pyyaml==5.4.1
303 | - requests-cache==1.1.0
304 | - requests-oauthlib==1.3.1
305 | - rich==13.4.2
306 | - rich-click==1.6.1
307 | - rsa==4.9
308 | - seaborn==0.12.2
309 | - sentry-sdk==1.24.0
310 | - shortuuid==1.0.11
311 | - smmap==5.0.0
312 | - snowballstemmer==2.2.0
313 | - soupsieve==2.4.1
314 | - subprocess32==3.5.4
315 | - suds-community==1.1.2
316 | - tenacity==8.2.2
317 | - tensorboard==2.13.0
318 | - tensorboard-data-server==0.7.0
319 | - termcolor==2.3.0
320 | - torchmetrics==0.5.1
321 | - typeguard==4.0.0
322 | - url-normalize==1.4.3
323 | - wandb==0.12.2
324 | - werkzeug==2.3.6
325 | - wget==3.2
326 | - wrapt==1.15.0
327 | - xarray==2023.1.0
328 | - xmltodict==0.13.0
329 | - yarl==1.9.2
330 | - yaspin==2.3.0
331 | - zope-event==5.0
332 | - zope-interface==6.0
333 |
--------------------------------------------------------------------------------
/notebooks/data_usage.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Example of data usage"
9 | ]
10 | },
11 | {
12 | "attachments": {},
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "### Neural network model training"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# -------------------------------------------------------------------------------------------------------------------------------------\n",
26 | "# Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch):\n",
27 | "# -------------------------------------------------------------------------------------------------------------------------------------\n",
28 | "\n",
29 | "import os\n",
30 | "import sys\n",
31 | "from pathlib import Path\n",
32 | "\n",
33 | "import pytorch_lightning as pl\n",
34 | "import torch.nn as nn\n",
35 | "from pytorch_lightning.plugins import DDPPlugin\n",
36 | "\n",
37 | "from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule\n",
38 | "from project.utils.modules import LitNeiA\n",
39 | "from project.utils.training_utils import collect_args, process_args, construct_pl_logger"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "def main(args):\n",
49 | " # -----------\n",
50 | " # Data\n",
51 | " # -----------\n",
52 | " # Load Docking Benchmark 5 (DB5) data module\n",
53 | " db5_data_module = DB5DGLDataModule(data_dir=args.db5_data_dir,\n",
54 | " batch_size=args.batch_size,\n",
55 | " num_dataloader_workers=args.num_workers,\n",
56 | " knn=args.knn,\n",
57 | " self_loops=args.self_loops,\n",
58 | " percent_to_use=args.db5_percent_to_use,\n",
59 | " process_complexes=args.process_complexes,\n",
60 | " input_indep=args.input_indep)\n",
61 | " db5_data_module.setup()\n",
62 | "\n",
63 | " # ------------\n",
64 | " # Model\n",
65 | " # ------------\n",
66 | " # Assemble a dictionary of model arguments\n",
67 | " dict_args = vars(args)\n",
68 | " use_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB\n",
69 | "\n",
70 | " # Pick model and supply it with a dictionary of arguments\n",
71 | " if args.model_name.lower() == 'neiwa': # Neighborhood Weighted Average (NeiWA)\n",
72 | " model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,\n",
73 | " num_edge_input_feats=db5_data_module.db5_test.num_edge_features,\n",
74 | " gnn_activ_fn=nn.Tanh(),\n",
75 | " interact_activ_fn=nn.ReLU(),\n",
76 | " num_classes=db5_data_module.db5_test.num_classes,\n",
77 | " weighted_avg=True, # Use the neighborhood weighted average variant of NeiA\n",
78 | " num_gnn_layers=dict_args['num_gnn_layers'],\n",
79 | " num_interact_layers=dict_args['num_interact_layers'],\n",
80 | " num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],\n",
81 | " num_epochs=dict_args['num_epochs'],\n",
82 | " pn_ratio=dict_args['pn_ratio'],\n",
83 | " knn=dict_args['knn'],\n",
84 | " dropout_rate=dict_args['dropout_rate'],\n",
85 | " metric_to_track=dict_args['metric_to_track'],\n",
86 | " weight_decay=dict_args['weight_decay'],\n",
87 | " batch_size=dict_args['batch_size'],\n",
88 | " lr=dict_args['lr'],\n",
89 | " multi_gpu_backend=dict_args[\"accelerator\"])\n",
90 | " args.experiment_name = f'LitNeiWA-b{args.batch_size}-gl{args.num_gnn_layers}' \\\n",
91 | " f'-n{db5_data_module.db5_test.num_node_features}' \\\n",
92 | " f'-e{db5_data_module.db5_test.num_edge_features}' \\\n",
93 | " f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \\\n",
94 | " if not args.experiment_name \\\n",
95 | " else args.experiment_name\n",
96 | " template_ckpt_filename = 'LitNeiWA-{epoch:02d}-{val_ce:.2f}'\n",
97 | "\n",
98 | " else: # Default Model - Neighborhood Average (NeiA)\n",
99 | " model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,\n",
100 | " num_edge_input_feats=db5_data_module.db5_test.num_edge_features,\n",
101 | " gnn_activ_fn=nn.Tanh(),\n",
102 | " interact_activ_fn=nn.ReLU(),\n",
103 | " num_classes=db5_data_module.db5_test.num_classes,\n",
104 | " weighted_avg=False,\n",
105 | " num_gnn_layers=dict_args['num_gnn_layers'],\n",
106 | " num_interact_layers=dict_args['num_interact_layers'],\n",
107 | " num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],\n",
108 | " num_epochs=dict_args['num_epochs'],\n",
109 | " pn_ratio=dict_args['pn_ratio'],\n",
110 | " knn=dict_args['knn'],\n",
111 | " dropout_rate=dict_args['dropout_rate'],\n",
112 | " metric_to_track=dict_args['metric_to_track'],\n",
113 | " weight_decay=dict_args['weight_decay'],\n",
114 | " batch_size=dict_args['batch_size'],\n",
115 | " lr=dict_args['lr'],\n",
116 | " multi_gpu_backend=dict_args[\"accelerator\"])\n",
117 | " args.experiment_name = f'LitNeiA-b{args.batch_size}-gl{args.num_gnn_layers}' \\\n",
118 | " f'-n{db5_data_module.db5_test.num_node_features}' \\\n",
119 | " f'-e{db5_data_module.db5_test.num_edge_features}' \\\n",
120 | " f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \\\n",
121 | " if not args.experiment_name \\\n",
122 | " else args.experiment_name\n",
123 | " template_ckpt_filename = 'LitNeiA-{epoch:02d}-{val_ce:.2f}'\n",
124 | "\n",
125 | " # ------------\n",
126 | " # Checkpoint\n",
127 | " # ------------\n",
128 | " ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)\n",
129 | " ckpt_path_exists = os.path.exists(ckpt_path)\n",
130 | " ckpt_provided = args.ckpt_name != '' and ckpt_path_exists\n",
131 | " model = model.load_from_checkpoint(ckpt_path,\n",
132 | " use_wandb_logger=use_wandb_logger,\n",
133 | " batch_size=args.batch_size,\n",
134 | " lr=args.lr,\n",
135 | " weight_decay=args.weight_decay,\n",
136 | " dropout_rate=args.dropout_rate) if ckpt_provided else model\n",
137 | "\n",
138 | " # ------------\n",
139 | " # Trainer\n",
140 | " # ------------\n",
141 | " trainer = pl.Trainer.from_argparse_args(args)\n",
142 | "\n",
143 | " # -------------\n",
144 | " # Learning Rate\n",
145 | " # -------------\n",
146 | " if args.find_lr:\n",
147 | " lr_finder = trainer.tuner.lr_find(model, datamodule=db5_data_module) # Run learning rate finder\n",
148 | " fig = lr_finder.plot(suggest=True) # Plot learning rates\n",
149 | " fig.savefig('optimal_lr.pdf')\n",
150 | " fig.show()\n",
151 | " model.hparams.lr = lr_finder.suggestion() # Save optimal learning rate\n",
152 | " print(f'Optimal learning rate found: {model.hparams.lr}')\n",
153 | "\n",
154 | " # ------------\n",
155 | " # Logger\n",
156 | " # ------------\n",
157 | " pl_logger = construct_pl_logger(args) # Log everything to an external logger\n",
158 | " trainer.logger = pl_logger # Assign specified logger (e.g. TensorBoardLogger) to Trainer instance\n",
159 | "\n",
160 | " # -----------\n",
161 | " # Callbacks\n",
162 | " # -----------\n",
163 | " # Create and use callbacks\n",
164 | " mode = 'min' if 'ce' in args.metric_to_track else 'max'\n",
165 | " early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track,\n",
166 | " mode=mode,\n",
167 | " min_delta=args.min_delta,\n",
168 | " patience=args.patience)\n",
169 | " ckpt_callback = pl.callbacks.ModelCheckpoint(\n",
170 | " monitor=args.metric_to_track,\n",
171 | " mode=mode,\n",
172 | " verbose=True,\n",
173 | " save_last=True,\n",
174 | " save_top_k=3,\n",
175 | " filename=template_ckpt_filename # Warning: May cause a race condition if calling trainer.test() with many GPUs\n",
176 | " )\n",
177 | " lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True)\n",
178 | " trainer.callbacks = [early_stop_callback, ckpt_callback, lr_monitor_callback]\n",
179 | "\n",
180 | " # ------------\n",
181 | " # Restore\n",
182 | " # ------------\n",
183 | " # If using WandB, download checkpoint artifact from their servers if the checkpoint is not already stored locally\n",
184 | " if use_wandb_logger and args.ckpt_name != '' and not os.path.exists(ckpt_path):\n",
185 | " checkpoint_reference = f'{args.entity}/{args.project_name}/model-{args.run_id}:best'\n",
186 | " artifact = trainer.logger.experiment.use_artifact(checkpoint_reference, type='model')\n",
187 | " artifact_dir = artifact.download()\n",
188 | " model = model.load_from_checkpoint(Path(artifact_dir) / 'model.ckpt',\n",
189 | " use_wandb_logger=use_wandb_logger,\n",
190 | " batch_size=args.batch_size,\n",
191 | " lr=args.lr,\n",
192 | " weight_decay=args.weight_decay)\n",
193 | "\n",
194 | " # -------------\n",
195 | " # Training\n",
196 | " # -------------\n",
197 | " # Train with the provided model and DataModule\n",
198 | " trainer.fit(model=model, datamodule=db5_data_module)\n",
199 | "\n",
200 | " # -------------\n",
201 | " # Testing\n",
202 | " # -------------\n",
203 | " trainer.test()\n"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "# -----------\n",
213 | "# Jupyter\n",
214 | "# -----------\n",
215 | "sys.argv = ['']"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "# -----------\n",
225 | "# Arguments\n",
226 | "# -----------\n",
227 | "# Collect all arguments\n",
228 | "parser = collect_args()\n",
229 | "\n",
230 | "# Parse all known and unknown arguments\n",
231 | "args, unparsed_argv = parser.parse_known_args()\n",
232 | "\n",
233 | "# Let the model add what it wants\n",
234 | "parser = LitNeiA.add_model_specific_args(parser)\n",
235 | "\n",
236 | "# Re-parse all known and unknown arguments after adding those that are model specific\n",
237 | "args, unparsed_argv = parser.parse_known_args()\n",
238 | "\n",
239 | "# TODO: Manually set arguments within a Jupyter notebook from here\n",
240 | "args.model_name = \"neia\"\n",
241 | "args.multi_gpu_backend = \"dp\"\n",
242 | "args.db5_data_dir = \"../project/datasets/DB5/final/raw\"\n",
243 | "args.process_complexes = True\n",
244 | "args.batch_size = 1 # Note: `batch_size` must be `1` for compatibility with the current model implementation\n",
245 | "\n",
246 | "# Set Lightning-specific parameter values before constructing Trainer instance\n",
247 | "args.max_time = {'hours': args.max_hours, 'minutes': args.max_minutes}\n",
248 | "args.max_epochs = args.num_epochs\n",
249 | "args.profiler = args.profiler_method\n",
250 | "args.accelerator = args.multi_gpu_backend\n",
251 | "args.auto_select_gpus = args.auto_choose_gpus\n",
252 | "args.gpus = args.num_gpus\n",
253 | "args.num_nodes = args.num_compute_nodes\n",
254 | "args.precision = args.gpu_precision\n",
255 | "args.accumulate_grad_batches = args.accum_grad_batches\n",
256 | "args.gradient_clip_val = args.grad_clip_val\n",
257 | "args.gradient_clip_algo = args.grad_clip_algo\n",
258 | "args.stochastic_weight_avg = args.stc_weight_avg\n",
259 | "args.deterministic = True # Make LightningModule's training reproducible\n",
260 | "\n",
261 | "# Set plugins for Lightning\n",
262 | "args.plugins = [\n",
263 | " # 'ddp_sharded', # For sharded model training (to reduce GPU requirements)\n",
264 | " # DDPPlugin(find_unused_parameters=False),\n",
265 | "]\n",
266 | "\n",
267 | "# Finalize all arguments as necessary\n",
268 | "args = process_args(args)\n",
269 | "\n",
270 | "# Begin execution of model training with given args above\n",
271 | "main(args)"
272 | ]
273 | }
274 | ],
275 | "metadata": {
276 | "kernelspec": {
277 | "display_name": "DIPS-Plus",
278 | "language": "python",
279 | "name": "python3"
280 | },
281 | "language_info": {
282 | "codemirror_mode": {
283 | "name": "ipython",
284 | "version": 3
285 | },
286 | "file_extension": ".py",
287 | "mimetype": "text/x-python",
288 | "name": "python",
289 | "nbconvert_exporter": "python",
290 | "pygments_lexer": "ipython3",
291 | "version": "3.8.16"
292 | },
293 | "orig_nbformat": 4
294 | },
295 | "nbformat": 4,
296 | "nbformat_minor": 2
297 | }
298 |
--------------------------------------------------------------------------------
/notebooks/data_usage.py:
--------------------------------------------------------------------------------
1 | # %% [markdown]
2 | # # Example of data usage
3 |
4 | # %% [markdown]
5 | # ### Neural network model training
6 |
7 | # %%
8 | # -------------------------------------------------------------------------------------------------------------------------------------
9 | # Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch):
10 | # -------------------------------------------------------------------------------------------------------------------------------------
11 |
12 | import os
13 | import sys
14 | from pathlib import Path
15 |
16 | import pytorch_lightning as pl
17 | import torch.nn as nn
18 | from pytorch_lightning.plugins import DDPPlugin
19 |
20 | from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule
21 | from project.utils.modules import LitNeiA
22 | from project.utils.training_utils import collect_args, process_args, construct_pl_logger
23 |
24 | # %%
25 | def main(args):
26 | # -----------
27 | # Data
28 | # -----------
29 | # Load Docking Benchmark 5 (DB5) data module
30 | db5_data_module = DB5DGLDataModule(data_dir=args.db5_data_dir,
31 | batch_size=args.batch_size,
32 | num_dataloader_workers=args.num_workers,
33 | knn=args.knn,
34 | self_loops=args.self_loops,
35 | percent_to_use=args.db5_percent_to_use,
36 | process_complexes=args.process_complexes,
37 | input_indep=args.input_indep)
38 | db5_data_module.setup()
39 |
40 | # ------------
41 | # Model
42 | # ------------
43 | # Assemble a dictionary of model arguments
44 | dict_args = vars(args)
45 | use_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB
46 |
47 | # Pick model and supply it with a dictionary of arguments
48 | if args.model_name.lower() == 'neiwa': # Neighborhood Weighted Average (NeiWA)
49 | model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,
50 | num_edge_input_feats=db5_data_module.db5_test.num_edge_features,
51 | gnn_activ_fn=nn.Tanh(),
52 | interact_activ_fn=nn.ReLU(),
53 | num_classes=db5_data_module.db5_test.num_classes,
54 | weighted_avg=True, # Use the neighborhood weighted average variant of NeiA
55 | num_gnn_layers=dict_args['num_gnn_layers'],
56 | num_interact_layers=dict_args['num_interact_layers'],
57 | num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],
58 | num_epochs=dict_args['num_epochs'],
59 | pn_ratio=dict_args['pn_ratio'],
60 | knn=dict_args['knn'],
61 | dropout_rate=dict_args['dropout_rate'],
62 | metric_to_track=dict_args['metric_to_track'],
63 | weight_decay=dict_args['weight_decay'],
64 | batch_size=dict_args['batch_size'],
65 | lr=dict_args['lr'],
66 | multi_gpu_backend=dict_args["accelerator"])
67 | args.experiment_name = f'LitNeiWA-b{args.batch_size}-gl{args.num_gnn_layers}' \
68 | f'-n{db5_data_module.db5_test.num_node_features}' \
69 | f'-e{db5_data_module.db5_test.num_edge_features}' \
70 | f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \
71 | if not args.experiment_name \
72 | else args.experiment_name
73 | template_ckpt_filename = 'LitNeiWA-{epoch:02d}-{val_ce:.2f}'
74 |
75 | else: # Default Model - Neighborhood Average (NeiA)
76 | model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,
77 | num_edge_input_feats=db5_data_module.db5_test.num_edge_features,
78 | gnn_activ_fn=nn.Tanh(),
79 | interact_activ_fn=nn.ReLU(),
80 | num_classes=db5_data_module.db5_test.num_classes,
81 | weighted_avg=False,
82 | num_gnn_layers=dict_args['num_gnn_layers'],
83 | num_interact_layers=dict_args['num_interact_layers'],
84 | num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],
85 | num_epochs=dict_args['num_epochs'],
86 | pn_ratio=dict_args['pn_ratio'],
87 | knn=dict_args['knn'],
88 | dropout_rate=dict_args['dropout_rate'],
89 | metric_to_track=dict_args['metric_to_track'],
90 | weight_decay=dict_args['weight_decay'],
91 | batch_size=dict_args['batch_size'],
92 | lr=dict_args['lr'],
93 | multi_gpu_backend=dict_args["accelerator"])
94 | args.experiment_name = f'LitNeiA-b{args.batch_size}-gl{args.num_gnn_layers}' \
95 | f'-n{db5_data_module.db5_test.num_node_features}' \
96 | f'-e{db5_data_module.db5_test.num_edge_features}' \
97 | f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \
98 | if not args.experiment_name \
99 | else args.experiment_name
100 | template_ckpt_filename = 'LitNeiA-{epoch:02d}-{val_ce:.2f}'
101 |
102 | # ------------
103 | # Checkpoint
104 | # ------------
105 | ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)
106 | ckpt_path_exists = os.path.exists(ckpt_path)
107 | ckpt_provided = args.ckpt_name != '' and ckpt_path_exists
108 | model = model.load_from_checkpoint(ckpt_path,
109 | use_wandb_logger=use_wandb_logger,
110 | batch_size=args.batch_size,
111 | lr=args.lr,
112 | weight_decay=args.weight_decay,
113 | dropout_rate=args.dropout_rate) if ckpt_provided else model
114 |
115 | # ------------
116 | # Trainer
117 | # ------------
118 | trainer = pl.Trainer.from_argparse_args(args)
119 |
120 | # -------------
121 | # Learning Rate
122 | # -------------
123 | if args.find_lr:
124 | lr_finder = trainer.tuner.lr_find(model, datamodule=db5_data_module) # Run learning rate finder
125 | fig = lr_finder.plot(suggest=True) # Plot learning rates
126 | fig.savefig('optimal_lr.pdf')
127 | fig.show()
128 | model.hparams.lr = lr_finder.suggestion() # Save optimal learning rate
129 | print(f'Optimal learning rate found: {model.hparams.lr}')
130 |
131 | # ------------
132 | # Logger
133 | # ------------
134 | pl_logger = construct_pl_logger(args) # Log everything to an external logger
135 | trainer.logger = pl_logger # Assign specified logger (e.g. TensorBoardLogger) to Trainer instance
136 |
137 | # -----------
138 | # Callbacks
139 | # -----------
140 | # Create and use callbacks
141 | mode = 'min' if 'ce' in args.metric_to_track else 'max'
142 | early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track,
143 | mode=mode,
144 | min_delta=args.min_delta,
145 | patience=args.patience)
146 | ckpt_callback = pl.callbacks.ModelCheckpoint(
147 | monitor=args.metric_to_track,
148 | mode=mode,
149 | verbose=True,
150 | save_last=True,
151 | save_top_k=3,
152 | filename=template_ckpt_filename # Warning: May cause a race condition if calling trainer.test() with many GPUs
153 | )
154 | lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True)
155 | trainer.callbacks = [early_stop_callback, ckpt_callback, lr_monitor_callback]
156 |
157 | # ------------
158 | # Restore
159 | # ------------
160 | # If using WandB, download checkpoint artifact from their servers if the checkpoint is not already stored locally
161 | if use_wandb_logger and args.ckpt_name != '' and not os.path.exists(ckpt_path):
162 | checkpoint_reference = f'{args.entity}/{args.project_name}/model-{args.run_id}:best'
163 | artifact = trainer.logger.experiment.use_artifact(checkpoint_reference, type='model')
164 | artifact_dir = artifact.download()
165 | model = model.load_from_checkpoint(Path(artifact_dir) / 'model.ckpt',
166 | use_wandb_logger=use_wandb_logger,
167 | batch_size=args.batch_size,
168 | lr=args.lr,
169 | weight_decay=args.weight_decay)
170 |
171 | # -------------
172 | # Training
173 | # -------------
174 | # Train with the provided model and DataModule
175 | trainer.fit(model=model, datamodule=db5_data_module)
176 |
177 | # -------------
178 | # Testing
179 | # -------------
180 | trainer.test()
181 |
182 |
183 | # %%
184 | # -----------
185 | # Jupyter
186 | # -----------
187 | # sys.argv = ['']
188 |
189 | # %%
190 | # -----------
191 | # Arguments
192 | # -----------
193 | # Collect all arguments
194 | parser = collect_args()
195 |
196 | # Parse all known and unknown arguments
197 | args, unparsed_argv = parser.parse_known_args()
198 |
199 | # Let the model add what it wants
200 | parser = LitNeiA.add_model_specific_args(parser)
201 |
202 | # Re-parse all known and unknown arguments after adding those that are model specific
203 | args, unparsed_argv = parser.parse_known_args()
204 |
205 | # TODO: Manually set arguments within a Jupyter notebook from here
206 | args.model_name = "neia"
207 | args.multi_gpu_backend = "dp"
208 | args.db5_data_dir = "project/datasets/DB5/final/raw"
209 | args.process_complexes = True
210 | args.batch_size = 1 # Note: `batch_size` must be `1` for compatibility with the current model implementation
211 |
212 | # Set Lightning-specific parameter values before constructing Trainer instance
213 | args.max_time = {'hours': args.max_hours, 'minutes': args.max_minutes}
214 | args.max_epochs = args.num_epochs
215 | args.profiler = args.profiler_method
216 | args.accelerator = args.multi_gpu_backend
217 | args.auto_select_gpus = args.auto_choose_gpus
218 | args.gpus = args.num_gpus
219 | args.num_nodes = args.num_compute_nodes
220 | args.precision = args.gpu_precision
221 | args.accumulate_grad_batches = args.accum_grad_batches
222 | args.gradient_clip_val = args.grad_clip_val
223 | args.gradient_clip_algo = args.grad_clip_algo
224 | args.stochastic_weight_avg = args.stc_weight_avg
225 | args.deterministic = True # Make LightningModule's training reproducible
226 |
227 | # Set plugins for Lightning
228 | args.plugins = [
229 | # 'ddp_sharded', # For sharded model training (to reduce GPU requirements)
230 | # DDPPlugin(find_unused_parameters=False),
231 | ]
232 |
233 | # Finalize all arguments as necessary
234 | args = process_args(args)
235 |
236 | # Begin execution of model training with given args above
237 | main(args)
238 |
239 |
240 |
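241 | # Note: outside a Jupyter session, the manual overrides above can instead be passed on the command line
242 | # and parsed by `collect_args()` plus `LitNeiA.add_model_specific_args()`, e.g. (a sketch; the flag names
243 | # below assume `collect_args()` exposes them under the same names used above):
244 | #   python notebooks/data_usage.py --model_name neia --multi_gpu_backend dp --batch_size 1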
--------------------------------------------------------------------------------
/notebooks/feature_generation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Feature generation for PDB file inputs"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import os\n",
18 | "\n",
19 | "import atom3.complex as comp\n",
20 | "import atom3.conservation as con\n",
21 | "import atom3.neighbors as nb\n",
22 | "import atom3.pair as pair\n",
23 | "import atom3.parse as parse\n",
24 | "import dill as pickle\n",
25 | "\n",
26 | "from pathlib import Path\n",
27 | "\n",
28 | "from project.utils.utils import annotate_idr_residues, impute_missing_feature_values, postprocess_pruned_pair, process_raw_file_into_dgl_graphs"
29 | ]
30 | },
31 | {
32 | "attachments": {},
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "### 1. Parse PDB file input to pair-wise features"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "pdb_filename = \"../project/datasets/Input/raw/pdb/2g/12gs.pdb1\" # note: an input PDB must be uncompressed (e.g., not in `.gz` archive format) using e.g., `gunzip`\n",
46 | "output_pkl = \"../project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl\"\n",
47 | "complexes_dill = \"../project/datasets/Input/interim/complexes/complexes.dill\"\n",
48 | "pairs_dir = \"../project/datasets/Input/interim/pairs\"\n",
49 | "pkl_filenames = [output_pkl]\n",
50 | "source_type = \"rcsb\" # note: this default value will likely work for common use cases (i.e., those concerning bound-state PDB protein complex structure inputs)\n",
51 | "neighbor_def = \"non_heavy_res\"\n",
52 | "cutoff = 6 # note: distance threshold (in Angstrom) for classifying inter-chain interactions can be customized here\n",
53 | "unbound = False # note: if `source_type` is set to `rcsb`, this value should likely be `False`\n",
54 | "\n",
55 | "for item in [\n",
56 | " Path(pdb_filename).parent,\n",
57 | " Path(output_pkl).parent,\n",
58 | " Path(complexes_dill).parent,\n",
59 | " pairs_dir,\n",
60 | "]:\n",
61 | " os.makedirs(item, exist_ok=True)\n",
62 | "\n",
63 | "# note: the following replicates the logic within `make_dataset.py` for a single PDB file input\n",
64 | "parse.parse(\n",
65 | " # note: assumes the PDB file input (i.e., `pdb_filename`) is not compressed\n",
66 | " pdb_filename=pdb_filename,\n",
67 | " output_pkl=output_pkl\n",
68 | ")\n",
69 | "complexes = comp.get_complexes(filenames=pkl_filenames, type=source_type)\n",
70 | "comp.write_complexes(complexes=complexes, output_dill=complexes_dill)\n",
71 | "get_neighbors = nb.build_get_neighbors(criteria=neighbor_def, cutoff=cutoff)\n",
72 | "get_pairs = pair.build_get_pairs(\n",
73 | " neighbor_def=neighbor_def,\n",
74 | " type=source_type,\n",
75 | " unbound=unbound,\n",
76 | " nb_fn=get_neighbors,\n",
77 | " full=False\n",
78 | ")\n",
79 | "complexes = comp.read_complexes(input_dill=complexes_dill)\n",
80 | "pair.complex_to_pairs(\n",
81 | " complex=list(complexes['data'].values())[0],\n",
82 | " source_type=source_type,\n",
83 | " get_pairs=get_pairs,\n",
84 | " output_dir=pairs_dir\n",
85 | ")"
86 | ]
87 | },
88 | {
89 | "attachments": {},
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "### 2. Compute sequence-based features using external tools"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "psaia_dir = \"~/Programs/PSAIA-1.0/bin/linux/psa\" # note: replace this with the path to your local installation of PSAIA\n",
103 | "psaia_config_file = \"../project/datasets/builder/psaia_config_file_dips.txt\" # note: choose `psaia_config_file_dips.txt` according to the `source_type` selected above\n",
104 | "file_list_file = os.path.join(\"../project/datasets/Input/interim/external_feats/\", 'PSAIA', source_type.upper(), 'pdb_list.fls')\n",
105 | "num_cpus = 8\n",
106 | "pkl_filename = \"../project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl\"\n",
107 | "output_filename = \"../project/datasets/Input/interim/external_feats/parsed/2g/12gs.pdb1.pkl\"\n",
108 | "hhsuite_db = \"~/Data/Databases/pfamA_35.0/pfam\" # note: substitute the path to your local HHsuite3 database here\n",
109 | "num_iter = 2\n",
110 | "msa_only = False\n",
111 | "\n",
112 | "for item in [\n",
113 | " Path(file_list_file).parent,\n",
114 | " Path(output_filename).parent,\n",
115 | "]:\n",
116 | " os.makedirs(item, exist_ok=True)\n",
117 | "\n",
118 | "# note: the following replicates the logic within `generate_psaia_features.py` and `generate_hhsuite_features.py` for a single PDB file input\n",
119 | "with open(file_list_file, 'w') as file:\n",
120 | " file.write(f'{pdb_filename}\\n') # note: references the `pdb_filename` as defined previously\n",
121 | "con.gen_protrusion_index(\n",
122 | " psaia_dir=psaia_dir,\n",
123 | " psaia_config_file=psaia_config_file,\n",
124 | " file_list_file=file_list_file,\n",
125 | ")\n",
126 | "con.map_profile_hmms(\n",
127 | " num_cpus=num_cpus,\n",
128 | " pkl_filename=pkl_filename,\n",
129 | " output_filename=output_filename,\n",
130 | " hhsuite_db=hhsuite_db,\n",
131 | " source_type=source_type,\n",
132 | " num_iter=num_iter,\n",
133 | " msa_only=msa_only,\n",
134 | ")"
135 | ]
136 | },
137 | {
138 | "attachments": {},
139 | "cell_type": "markdown",
140 | "metadata": {},
141 | "source": [
142 | "### 3. Compute structure-based features"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "from project.utils.utils import __should_keep_postprocessed\n",
152 | "\n",
153 | "\n",
154 | "raw_pdb_dir = \"../project/datasets/Input/raw/pdb\"\n",
155 | "pair_filename = \"../project/datasets/Input/interim/pairs/2g/12gs.pdb1_0.dill\"\n",
156 | "source_type = \"rcsb\"\n",
157 | "external_feats_dir = \"../project/datasets/Input/interim/external_feats/parsed\"\n",
158 | "output_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n",
159 | "\n",
160 | "unprocessed_pair, raw_pdb_filenames, should_keep = __should_keep_postprocessed(raw_pdb_dir, pair_filename, source_type)\n",
161 | "if should_keep:\n",
162 | " # note: save `postprocessed_pair` to local storage within `project/datasets/Input/final/raw` for future reference as desired\n",
163 | " postprocessed_pair = postprocess_pruned_pair(\n",
164 | " raw_pdb_filenames=raw_pdb_filenames,\n",
165 | " external_feats_dir=external_feats_dir,\n",
166 | " original_pair=unprocessed_pair,\n",
167 | " source_type=source_type,\n",
168 | " )\n",
169 | " # write into output_filenames if not exist\n",
170 | " os.makedirs(Path(output_filename).parent, exist_ok=True)\n",
171 | " with open(output_filename, 'wb') as f:\n",
172 | " pickle.dump(postprocessed_pair, f)"
173 | ]
174 | },
175 | {
176 | "attachments": {},
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": [
180 | "### 4. Embed deep learning-based IDR features"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "# note: ensures the Docker image for `flDPnn` is available locally before trying to run inference with the model\n",
190 | "!docker pull docker.io/sinaghadermarzi/fldpnn\n",
191 | "\n",
192 | "input_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n",
193 | "pickle_filepaths = [input_pair_filename]\n",
194 | "\n",
195 | "annotate_idr_residues(\n",
196 | " pickle_filepaths=pickle_filepaths\n",
197 | ")"
198 | ]
199 | },
200 | {
201 | "attachments": {},
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 | "### 5. Impute missing feature values (optional)"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "input_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n",
215 | "output_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill\"\n",
216 | "impute_atom_features = False\n",
217 | "advanced_logging = False\n",
218 | "\n",
219 | "impute_missing_feature_values(\n",
220 | " input_pair_filename=input_pair_filename,\n",
221 | " output_pair_filename=output_pair_filename,\n",
222 | " impute_atom_features=impute_atom_features,\n",
223 | " advanced_logging=advanced_logging,\n",
224 | ")"
225 | ]
226 | },
227 | {
228 | "attachments": {},
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "### 6. Convert pair-wise features into graph inputs (optional)"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "raw_filepath = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill\"\n",
242 | "new_graph_dir = \"../project/datasets/Input/final/processed/2g\"\n",
243 | "processed_filepath = \"../project/datasets/Input/final/processed/2g/12gs.pdb1.pt\"\n",
244 | "edge_dist_cutoff = 15.0\n",
245 | "edge_limit = 5000\n",
246 | "self_loops = True\n",
247 | "\n",
248 | "os.makedirs(new_graph_dir, exist_ok=True)\n",
249 | "\n",
250 | "process_raw_file_into_dgl_graphs(\n",
251 | " raw_filepath=raw_filepath,\n",
252 | " new_graph_dir=new_graph_dir,\n",
253 | " processed_filepath=processed_filepath,\n",
254 | " edge_dist_cutoff=edge_dist_cutoff,\n",
255 | " edge_limit=edge_limit,\n",
256 | " self_loops=self_loops,\n",
257 | ")"
258 | ]
259 | }
260 | ],
261 | "metadata": {
262 | "kernelspec": {
263 | "display_name": "DIPS-Plus",
264 | "language": "python",
265 | "name": "python3"
266 | },
267 | "language_info": {
268 | "codemirror_mode": {
269 | "name": "ipython",
270 | "version": 3
271 | },
272 | "file_extension": ".py",
273 | "mimetype": "text/x-python",
274 | "name": "python",
275 | "nbconvert_exporter": "python",
276 | "pygments_lexer": "ipython3",
277 | "version": "3.8.16"
278 | },
279 | "orig_nbformat": 4
280 | },
281 | "nbformat": 4,
282 | "nbformat_minor": 2
283 | }
284 |
--------------------------------------------------------------------------------
/notebooks/feature_generation.py:
--------------------------------------------------------------------------------
1 | # %% [markdown]
2 | # # Feature generation for PDB file inputs
3 |
4 | # %%
5 | import os
6 |
7 | import atom3.complex as comp
8 | import atom3.conservation as con
9 | import atom3.neighbors as nb
10 | import atom3.pair as pair
11 | import atom3.parse as parse
12 | import dill as pickle
13 |
14 | from pathlib import Path
15 |
16 | from project.utils.utils import annotate_idr_residues, impute_missing_feature_values, postprocess_pruned_pair, process_raw_file_into_dgl_graphs
17 |
18 | # %% [markdown]
19 | # ### 1. Parse PDB file input to pair-wise features
20 |
21 | # %%
22 | pdb_filename = "project/datasets/Input/raw/pdb/2g/12gs.pdb1" # note: an input PDB must be uncompressed (e.g., not in `.gz` archive format) using e.g., `gunzip`
23 | output_pkl = "project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl"
24 | complexes_dill = "project/datasets/Input/interim/complexes/complexes.dill"
25 | pairs_dir = "project/datasets/Input/interim/pairs"
26 | pkl_filenames = [output_pkl]
27 | source_type = "rcsb" # note: this default value will likely work for common use cases (i.e., those concerning bound-state PDB protein complex structure inputs)
28 | neighbor_def = "non_heavy_res"
29 | cutoff = 6 # note: distance threshold (in Angstrom) for classifying inter-chain interactions can be customized here
30 | unbound = False # note: if `source_type` is set to `rcsb`, this value should likely be `False`
31 |
32 | for item in [
33 | Path(pdb_filename).parent,
34 | Path(output_pkl).parent,
35 | Path(complexes_dill).parent,
36 | pairs_dir,
37 | ]:
38 | os.makedirs(item, exist_ok=True)
39 |
40 | # note: the following replicates the logic within `make_dataset.py` for a single PDB file input
41 | parse.parse(
42 | # note: assumes the PDB file input (i.e., `pdb_filename`) is not compressed
43 | pdb_filename=pdb_filename,
44 | output_pkl=output_pkl
45 | )
46 | complexes = comp.get_complexes(filenames=pkl_filenames, type=source_type)
47 | comp.write_complexes(complexes=complexes, output_dill=complexes_dill)
48 | get_neighbors = nb.build_get_neighbors(criteria=neighbor_def, cutoff=cutoff)
49 | get_pairs = pair.build_get_pairs(
50 | neighbor_def=neighbor_def,
51 | type=source_type,
52 | unbound=unbound,
53 | nb_fn=get_neighbors,
54 | full=False
55 | )
56 | complexes = comp.read_complexes(input_dill=complexes_dill)
57 | pair.complex_to_pairs(
58 | complex=list(complexes['data'].values())[0],
59 | source_type=source_type,
60 | get_pairs=get_pairs,
61 | output_dir=pairs_dir
62 | )
63 |
64 | # %% [markdown]
65 | # ### 2. Compute sequence-based features using external tools
66 |
67 | # %%
68 | psaia_dir = "~/Programs/PSAIA-1.0/bin/linux/psa" # note: replace this with the path to your local installation of PSAIA
69 | psaia_config_file = "project/datasets/builder/psaia_config_file_dips.txt" # note: choose `psaia_config_file_dips.txt` according to the `source_type` selected above
70 | file_list_file = os.path.join("project/datasets/Input/interim/external_feats/", 'PSAIA', source_type.upper(), 'pdb_list.fls')
71 | num_cpus = 8
72 | pkl_filename = "project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl"
73 | output_filename = "project/datasets/Input/interim/external_feats/parsed/2g/12gs.pdb1.pkl"
74 | hhsuite_db = "~/Data/Databases/pfamA_35.0/pfam" # note: substitute the path to your local HHsuite3 database here
75 | num_iter = 2
76 | msa_only = False
77 |
78 | for item in [
79 | Path(file_list_file).parent,
80 | Path(output_filename).parent,
81 | ]:
82 | os.makedirs(item, exist_ok=True)
83 |
84 | # note: the following replicates the logic within `generate_psaia_features.py` and `generate_hhsuite_features.py` for a single PDB file input
85 | with open(file_list_file, 'w') as file:
86 | file.write(f'{pdb_filename}\n') # note: references the `pdb_filename` as defined previously
87 | con.gen_protrusion_index(
88 | psaia_dir=psaia_dir,
89 | psaia_config_file=psaia_config_file,
90 | file_list_file=file_list_file,
91 | )
92 | con.map_profile_hmms(
93 | num_cpus=num_cpus,
94 | pkl_filename=pkl_filename,
95 | output_filename=output_filename,
96 | hhsuite_db=hhsuite_db,
97 | source_type=source_type,
98 | num_iter=num_iter,
99 | msa_only=msa_only,
100 | )
101 |
102 | # %% [markdown]
103 | # ### 3. Compute structure-based features
104 |
105 | # %%
106 | from project.utils.utils import __should_keep_postprocessed
107 |
108 |
109 | raw_pdb_dir = "project/datasets/Input/raw/pdb"
110 | pair_filename = "project/datasets/Input/interim/pairs/2g/12gs.pdb1_0.dill"
111 | source_type = "rcsb"
112 | external_feats_dir = "project/datasets/Input/interim/external_feats/parsed"
113 | output_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill"
114 |
115 | unprocessed_pair, raw_pdb_filenames, should_keep = __should_keep_postprocessed(raw_pdb_dir, pair_filename, source_type)
116 | if should_keep:
117 | # note: save `postprocessed_pair` to local storage within `project/datasets/Input/final/raw` for future reference as desired
118 | postprocessed_pair = postprocess_pruned_pair(
119 | raw_pdb_filenames=raw_pdb_filenames,
120 | external_feats_dir=external_feats_dir,
121 | original_pair=unprocessed_pair,
122 | source_type=source_type,
123 | )
124 | # write the postprocessed pair to `output_filename`, creating its parent directory if needed
125 | os.makedirs(Path(output_filename).parent, exist_ok=True)
126 | with open(output_filename, 'wb') as f:
127 | pickle.dump(postprocessed_pair, f)
128 |
129 | # %% [markdown]
130 | # ### 4. Embed deep learning-based IDR features
131 |
132 | # %%
133 | # note: ensures the Docker image for `flDPnn` is available locally before trying to run inference with the model
134 | # !docker pull docker.io/sinaghadermarzi/fldpnn
135 |
136 | input_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill"
137 | pickle_filepaths = [input_pair_filename]
138 |
139 | annotate_idr_residues(
140 | pickle_filepaths=pickle_filepaths
141 | )
142 |
143 | # %% [markdown]
144 | # ### 5. Impute missing feature values (optional)
145 |
146 | # %%
147 | input_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill"
148 | output_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill"
149 | impute_atom_features = False
150 | advanced_logging = False
151 |
152 | impute_missing_feature_values(
153 | input_pair_filename=input_pair_filename,
154 | output_pair_filename=output_pair_filename,
155 | impute_atom_features=impute_atom_features,
156 | advanced_logging=advanced_logging,
157 | )
158 |
159 | # %% [markdown]
160 | # ### 6. Convert pair-wise features into graph inputs (optional)
161 |
162 | # %%
163 | raw_filepath = "project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill"
164 | new_graph_dir = "project/datasets/Input/final/processed/2g"
165 | processed_filepath = "project/datasets/Input/final/processed/2g/12gs.pdb1.pt"
166 | edge_dist_cutoff = 15.0
167 | edge_limit = 5000
168 | self_loops = True
169 |
170 | os.makedirs(new_graph_dir, exist_ok=True)
171 |
172 | process_raw_file_into_dgl_graphs(
173 | raw_filepath=raw_filepath,
174 | new_graph_dir=new_graph_dir,
175 | processed_filepath=processed_filepath,
176 | edge_dist_cutoff=edge_dist_cutoff,
177 | edge_limit=edge_limit,
178 | self_loops=self_loops,
179 | )
180 |
181 |
182 |
--------------------------------------------------------------------------------
/project/datasets/DB5/README:
--------------------------------------------------------------------------------
1 | Cleaned up version of Docking Benchmark 5 (https://zlab.umassmed.edu/benchmark/).
2 |
3 | Released with "End-to-End Learning on 3D Protein Structure for Interface Prediction."
4 | by Raphael J.L. Townshend, Rishi Bedi, Patricia Suriana, Ron O. Dror
5 | https://arxiv.org/abs/1807.01297
6 |
7 | Specifically, bound chains and residue indexes were aligned across unbound and bound complexes.
8 |
9 | A total of 230 binary protein complexes are included.
10 |
11 | Processing code to regenerate and use the provided tfrecords is located at
12 | https://github.com/drorlab/DIPS
13 |
14 | MANIFEST
15 |
16 | raw/ - All pre-aligned and cleaned DB5 structures, organized into directories
17 | with individual files for ligand-unbound, ligand-bound, receptor-unbound,
18 | receptor-bound.
19 | interim/
20 | parsed/ - All DB5 structures processed to pickled dataframes.
21 | complexes/ - List of all possible pairs in parsed.
22 | pairs/ - Dill files of individual pairs listed in complexes.
23 | processed/
24 | tfrecords/ - pairs converted to tfrecords.
25 |
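26 | Example (a minimal sketch for loading one pair; the pair path below is a placeholder, and it assumes
27 | the interim/pairs directory has already been generated as described above):
28 |
29 |     import dill as pickle
30 |     with open('interim/pairs/<pdb_code>/<pair_name>.dill', 'rb') as f:
31 |         pair = pickle.load(f)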
--------------------------------------------------------------------------------
/project/datasets/DB5/db5_dgl_data_module.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader
5 |
6 | from project.datasets.DB5.db5_dgl_dataset import DB5DGLDataset
7 |
8 |
9 | class DB5DGLDataModule(LightningDataModule):
10 | """Unbound protein complex data module for DGL with PyTorch."""
11 |
12 | # Dataset partition instantiation
13 | db5_train = None
14 | db5_val = None
15 | db5_test = None
16 |
17 | def __init__(self, data_dir: str, batch_size: int, num_dataloader_workers: int, knn: int,
18 | self_loops: bool, percent_to_use: float, process_complexes: bool, input_indep: bool):
19 | super().__init__()
20 |
21 | self.data_dir = data_dir
22 | self.batch_size = batch_size
23 | self.num_dataloader_workers = num_dataloader_workers
24 | self.knn = knn
25 | self.self_loops = self_loops
26 | self.percent_to_use = percent_to_use # Fraction of DB5 dataset to use
27 | self.process_complexes = process_complexes # Whether to process any unprocessed complexes before training
28 | self.input_indep = input_indep # Whether to use an input-independent pipeline to train the model
29 |
30 | def setup(self, stage: Optional[str] = None):
31 | # Assign training/validation/testing data set for use in DataLoaders - called on every GPU
32 | self.db5_train = DB5DGLDataset(mode='train', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops,
33 | percent_to_use=self.percent_to_use, process_complexes=self.process_complexes,
34 | input_indep=self.input_indep)
35 | self.db5_val = DB5DGLDataset(mode='val', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops,
36 | percent_to_use=self.percent_to_use, process_complexes=self.process_complexes,
37 | input_indep=self.input_indep)
38 | self.db5_test = DB5DGLDataset(mode='test', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops,
39 | percent_to_use=self.percent_to_use, process_complexes=self.process_complexes,
40 | input_indep=self.input_indep)
41 |
42 | def train_dataloader(self) -> DataLoader:
43 | return DataLoader(self.db5_train, batch_size=self.batch_size, shuffle=True,
44 | num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True)
45 |
46 | def val_dataloader(self) -> DataLoader:
47 | return DataLoader(self.db5_val, batch_size=self.batch_size, shuffle=False,
48 | num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True)
49 |
50 | def test_dataloader(self) -> DataLoader:
51 | return DataLoader(self.db5_test, batch_size=self.batch_size, shuffle=False,
52 | num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True)
53 |
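54 | # A minimal usage sketch (assuming `project/datasets/DB5/final/raw` already holds the DB5-Plus pairs):
55 | #
56 | #   data_module = DB5DGLDataModule(data_dir='project/datasets/DB5/final/raw', batch_size=1,
57 | #                                  num_dataloader_workers=4, knn=20, self_loops=True, percent_to_use=1.0,
58 | #                                  process_complexes=True, input_indep=False)
59 | #   data_module.setup()
60 | #   train_dataloader = data_module.train_dataloader()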
--------------------------------------------------------------------------------
/project/datasets/DB5/db5_dgl_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import pandas as pd
5 | from dgl.data import DGLDataset, download, check_sha1
6 |
7 | from project.utils.training_utils import construct_filenames_frame_txt_filenames, build_filenames_frame_error_message, process_complex_into_dict, zero_out_complex_features
8 |
9 |
10 | class DB5DGLDataset(DGLDataset):
11 | r"""Unbound protein complex dataset for DGL with PyTorch.
12 |
13 | Statistics:
14 |
15 | - Train examples: 140
16 | - Validation examples: 35
17 | - Test examples: 55
18 | - Number of structures per complex: 2
19 | ----------------------
20 | - Total examples: 230
21 | ----------------------
22 |
23 | Parameters
24 | ----------
25 | mode: str, optional
26 | Should be one of ['train', 'val', 'test']. Default: 'train'.
27 | raw_dir: str
28 | Directory that contains (or into which to download) the input data. Default: 'final/raw'.
29 | knn: int
30 | Number of nearest neighbors to which to connect a given node. Default: 20.
31 | self_loops: bool
32 | Whether to connect a given node to itself. Default: True.
33 | percent_to_use: float
34 | How much of the dataset to load. Default: 1.0.
35 | process_complexes: bool
36 | Whether to process each unprocessed complex as we load in the dataset. Default: True.
37 | input_indep: bool
38 | Whether to zero-out each input node and edge feature for an input-independent baseline. Default: False.
39 | force_reload: bool
40 | Whether to reload the dataset. Default: False.
41 | verbose: bool
42 | Whether to print out progress information. Default: False.
43 |
44 | Notes
45 | -----
46 | All the samples will be loaded and preprocessed in the memory first.
47 |
48 | Examples
49 | --------
50 | >>> # Get dataset
51 | >>> train_data = DB5DGLDataset()
52 | >>> val_data = DB5DGLDataset(mode='val')
53 | >>> test_data = DB5DGLDataset(mode='test')
54 | >>>
55 | >>> len(test_data)
56 | 55
57 | >>> test_data.num_chains
58 | 2
59 | """
60 |
61 | def __init__(self,
62 | mode='test',
63 | raw_dir=f'final{os.sep}raw',
64 | knn=20,
65 | self_loops=True,
66 | percent_to_use=1.0,
67 | process_complexes=True,
68 | input_indep=False,
69 | force_reload=False,
70 | verbose=False):
71 | assert mode in ['train', 'val', 'test']
72 | assert 0.0 < percent_to_use <= 1.0
73 | self.mode = mode
74 | self.root = raw_dir
75 | self.knn = knn
76 | self.self_loops = self_loops
77 | self.percent_to_use = percent_to_use # How much of the DB5 dataset to use
78 | self.process_complexes = process_complexes # Whether to process any unprocessed complexes before training
79 | self.input_indep = input_indep # Whether to use an input-independent pipeline to train the model
80 | self.final_dir = os.path.join(*self.root.split(os.sep)[:-1])
81 | self.processed_dir = os.path.join(self.final_dir, 'processed')
82 |
83 | self.filename_sampling = 0.0 < self.percent_to_use < 1.0
84 | self.base_txt_filename, self.filenames_frame_txt_filename, self.filenames_frame_txt_filepath = \
85 | construct_filenames_frame_txt_filenames(self.mode, self.percent_to_use, self.filename_sampling, self.root)
86 |
87 | # Try to load the text file containing all DB5 filenames, and alert the user if it is missing or corrupted
88 | filenames_frame_to_be_written = not os.path.exists(self.filenames_frame_txt_filepath)
89 |
90 | # Randomly sample DataFrame of filenames with requested cross validation ratio
91 | if self.filename_sampling:
92 | if filenames_frame_to_be_written:
93 | try:
94 | self.filenames_frame = pd.read_csv(
95 | os.path.join(self.root, self.base_txt_filename + '.txt'), header=None)
96 | except Exception:
97 | raise FileNotFoundError(
98 | build_filenames_frame_error_message('DB5-Plus', 'load', self.filenames_frame_txt_filepath))
99 | self.filenames_frame = self.filenames_frame.sample(frac=self.percent_to_use).reset_index()
100 | try:
101 | self.filenames_frame[0].to_csv(self.filenames_frame_txt_filepath, header=None, index=None)
102 | except Exception:
103 | raise Exception(
104 | build_filenames_frame_error_message('DB5-Plus', 'write', self.filenames_frame_txt_filepath))
105 |
106 | # Load in existing DataFrame of filenames as requested (or if a sampled DataFrame .txt has already been written)
107 | if not filenames_frame_to_be_written:
108 | try:
109 | self.filenames_frame = pd.read_csv(self.filenames_frame_txt_filepath, header=None)
110 | except Exception:
111 | raise FileNotFoundError(
112 | build_filenames_frame_error_message('DB5-Plus', 'load', self.filenames_frame_txt_filepath))
113 |
114 | # Process any unprocessed examples prior to using the dataset
115 | self.process()
116 |
117 | super(DB5DGLDataset, self).__init__(name='DB5-Plus',
118 | raw_dir=raw_dir,
119 | force_reload=force_reload,
120 | verbose=verbose)
121 | print(f"Loaded DB5-Plus {mode}-set, source: {self.processed_dir}, length: {len(self)}")
122 |
123 | def download(self):
124 | """Download and extract a pre-packaged version of the raw pairs if 'self.raw_dir' is not already populated."""
125 | # Path to store the file
126 | gz_file_path = os.path.join(os.path.join(*self.raw_dir.split(os.sep)[:-1]), 'final_raw_db5.tar.gz')
127 |
128 | # Download file
129 | download(self.url, path=gz_file_path)
130 |
131 | # Check SHA-1
132 | if not check_sha1(gz_file_path, self._sha1_str):
133 | raise UserWarning('File {} is downloaded but the content hash does not match. '
134 | 'The repo may be outdated or the download may be incomplete. '
135 | 'Otherwise you can create an issue for it.'.format(gz_file_path))
136 |
137 | # Remove existing raw directory to make way for the new archive to be extracted
138 | if os.path.exists(self.raw_dir):
139 | os.removedirs(self.raw_dir)
140 |
141 | # Extract archive to parent directory of `self.raw_dir`
142 | self._extract_gz(gz_file_path, os.path.join(*self.raw_dir.split(os.sep)[:-1]))
143 |
144 | def process(self):
145 | """Process each protein complex into a testing-ready dictionary representing both structures."""
146 | if self.process_complexes:
147 | # Ensure the directory of processed complexes is already created
148 | os.makedirs(self.processed_dir, exist_ok=True)
149 | # Process each unprocessed protein complex
150 | for (i, raw_path) in self.filenames_frame.iterrows():
151 | raw_filepath = os.path.join(self.root, f'{os.path.splitext(raw_path[0])[0]}.dill')
152 | processed_filepath = os.path.join(self.processed_dir, f'{os.path.splitext(raw_path[0])[0]}.dill')
153 | if not os.path.exists(processed_filepath):
154 | processed_parent_dir_to_make = os.path.join(self.processed_dir, os.path.split(raw_path[0])[0])
155 | os.makedirs(processed_parent_dir_to_make, exist_ok=True)
156 | process_complex_into_dict(raw_filepath, processed_filepath, self.knn, self.self_loops, False)
157 |
158 | def has_cache(self):
159 | """Check if each complex is downloaded and available for training, validation, or testing."""
160 | for (i, raw_path) in self.filenames_frame.iterrows():
161 | processed_filepath = os.path.join(self.processed_dir, f'{os.path.splitext(raw_path[0])[0]}.dill')
162 | if not os.path.exists(processed_filepath):
163 | print(
164 | f'Unable to load at least one processed DB5 pair. '
165 | f'Please make sure all processed pairs have been successfully downloaded and are not corrupted.')
166 | raise FileNotFoundError
167 | print('DB5 cache found') # Otherwise, a cache was found!
168 |
169 | def __getitem__(self, idx):
170 | r""" Get feature dictionary by index of complex.
171 |
172 | Parameters
173 | ----------
174 | idx : int
175 |
176 | Returns
177 | -------
178 | :class:`dict`
179 |
180 | - ``complex['graph1_node_feats']``: PyTorch Tensor containing each of the first graph's encoded node features
181 | - ``complex['graph2_node_feats']``: PyTorch Tensor containing each of the second graph's encoded node features
182 | - ``complex['graph1_node_coords']``: PyTorch Tensor containing each of the first graph's node coordinates
183 | - ``complex['graph2_node_coords']``: PyTorch Tensor containing each of the second graph's node coordinates
184 | - ``complex['graph1_edge_feats']``: PyTorch Tensor containing each of the first graph's edge features for each node
185 | - ``complex['graph2_edge_feats']``: PyTorch Tensor containing each of the second graph's edge features for each node
186 | - ``complex['graph1_nbrhd_indices']``: PyTorch Tensor containing each of the first graph's neighboring node indices
187 | - ``complex['graph2_nbrhd_indices']``: PyTorch Tensor containing each of the second graph's neighboring node indices
188 | - ``complex['examples']``: PyTorch Tensor containing the labels for inter-graph node pairs
189 | - ``complex['complex']``: Python string describing the complex's code and original pdb filename
190 | """
191 | # Assemble filepath of processed protein complex
192 | complex_filepath = f'{os.path.splitext(self.filenames_frame[0][idx])[0]}.dill'
193 | processed_filepath = os.path.join(self.processed_dir, complex_filepath)
194 |
195 | # Load in processed complex
196 | with open(processed_filepath, 'rb') as f:
197 | processed_complex = pickle.load(f)
198 | processed_complex['filepath'] = complex_filepath # Add filepath to each complex dictionary
199 |
200 | # Optionally zero-out input data for an input-independent pipeline (per Karpathy's suggestion)
201 | if self.input_indep:
202 | processed_complex = zero_out_complex_features(processed_complex)
203 |
204 | # Manually filter for desired node and edge features
205 | # n_feat_idx_1, n_feat_idx_2 = 43, 85 # HSAAC
206 | # processed_complex['graph1'].ndata['f'] = processed_complex['graph1'].ndata['f'][:, n_feat_idx_1: n_feat_idx_2]
207 | # processed_complex['graph2'].ndata['f'] = processed_complex['graph2'].ndata['f'][:, n_feat_idx_1: n_feat_idx_2]
208 |
209 | # g1_rsa = processed_complex['graph1'].ndata['f'][:, 35: 36].reshape(-1, 1) # RSA
210 | # g1_psaia = processed_complex['graph1'].ndata['f'][:, 37: 43] # PSAIA
211 | # g1_hsaac = processed_complex['graph1'].ndata['f'][:, 43: 85] # HSAAC
212 | # processed_complex['graph1'].ndata['f'] = torch.cat((g1_rsa, g1_psaia, g1_hsaac), dim=1)
213 | #
214 | # g2_rsa = processed_complex['graph2'].ndata['f'][:, 35: 36].reshape(-1, 1) # RSA
215 | # g2_psaia = processed_complex['graph2'].ndata['f'][:, 37: 43] # PSAIA
216 | # g2_hsaac = processed_complex['graph2'].ndata['f'][:, 43: 85] # HSAAC
217 | # processed_complex['graph2'].ndata['f'] = torch.cat((g2_rsa, g2_psaia, g2_hsaac), dim=1)
218 |
219 | # processed_complex['graph1'].edata['f'] = processed_complex['graph1'].edata['f'][:, 1].reshape(-1, 1)
220 | # processed_complex['graph2'].edata['f'] = processed_complex['graph2'].edata['f'][:, 1].reshape(-1, 1)
221 |
222 | # Return requested complex to DataLoader
223 | return processed_complex
224 |
225 | def __len__(self) -> int:
226 | r"""Number of graph batches in the dataset."""
227 | return len(self.filenames_frame)
228 |
229 | @property
230 | def num_chains(self) -> int:
231 | """Number of protein chains in each complex."""
232 | return 2
233 |
234 | @property
235 | def num_classes(self) -> int:
236 | """Number of classes for each pair of inter-protein residues."""
237 | return 2
238 |
239 | @property
240 | def num_node_features(self) -> int:
241 | """Number of node feature values after encoding them."""
242 | return 107
243 |
244 | @property
245 | def num_edge_features(self) -> int:
246 | """Number of edge feature values after encoding them."""
247 | return 3
248 |
249 | @property
250 | def raw_path(self) -> str:
251 | """Directory in which to locate raw pairs."""
252 | return self.raw_dir
253 |
254 | @property
255 | def url(self) -> str:
256 | """URL with which to download TAR archive of preprocessed pairs."""
257 | # TODO: Update URL
258 | return 'https://zenodo.org/record/4815267/files/final_raw_db5.tar.gz?download=1'
259 |
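260 | # A minimal iteration sketch (assuming the processed `.dill` complexes are already available locally):
261 | #
262 | #   test_data = DB5DGLDataset(mode='test', raw_dir=f'final{os.sep}raw')
263 | #   processed_complex = test_data[0]
264 | #   print(processed_complex['complex'], processed_complex['examples'].shape)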
--------------------------------------------------------------------------------
/project/datasets/EVCoupling/final/stub:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/EVCoupling/final/stub
--------------------------------------------------------------------------------
/project/datasets/EVCoupling/group_files.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 |
4 | import numpy as np
5 |
6 | # Get all filenames of interest
7 | filename_list = os.listdir(os.getcwd())
8 |
9 | # Group filenames that share the same first four characters (i.e., the same PDB code)
10 | util_func = lambda x, y: x[:4] == y[:4]
11 | res = []
12 | for sub in filename_list:
13 | ele = next((x for x in res if util_func(sub, x[0])), [])
14 | if ele == []:
15 | res.append(ele)
16 | ele.append(sub)
17 |
18 | # Get names of EVCoupling heterodimers
19 | hd_list = np.loadtxt('../../evcoupling_heterodimer_list.txt', dtype=str).tolist()
20 |
21 | # Find first heterodimer for each complex and move it to a directory categorized by the complex's PDB code
22 | for r in res:
23 | if len(r) >= 2:
24 | r_combos = list(itertools.combinations(r, 2))[:1] # Only select a single heterodimer from each complex
25 | for combo in r_combos:
26 | heterodimer_name = combo[0][:5] + '_' + combo[1][:5]
27 | if heterodimer_name in hd_list:
28 | pdb_code = combo[0][:4]
29 | os.makedirs(pdb_code, exist_ok=True)
30 | l_b_chain_filename = combo[0]
31 | l_b_chain_filename = l_b_chain_filename[:-4] + '_l_b' + '.pdb'
32 | r_b_chain_filename = combo[1]
33 | r_b_chain_filename = r_b_chain_filename[:-4] + '_r_b' + '.pdb'
34 | if not os.path.exists(os.path.join(pdb_code, l_b_chain_filename)):
35 | os.rename(combo[0], os.path.join(pdb_code, l_b_chain_filename))
36 | if not os.path.exists(os.path.join(pdb_code, r_b_chain_filename)):
37 | os.rename(combo[1], os.path.join(pdb_code, r_b_chain_filename))
38 |
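39 | # Usage note (a sketch of the assumed layout): this script is intended to be run from inside a directory
40 | # of per-chain PDB files (it lists the current working directory via `os.getcwd()`), with
41 | # `evcoupling_heterodimer_list.txt` located two directory levels above, as referenced by the relative
42 | # path near the top of the script.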
--------------------------------------------------------------------------------
/project/datasets/EVCoupling/interim/stub:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/EVCoupling/interim/stub
--------------------------------------------------------------------------------
/project/datasets/analysis/analyze_experiment_types_and_resolution.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import click
3 | import logging
4 | import os
5 |
6 | import atom3.pair as pa
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import pandas as pd
10 |
11 | from graphein.ml.datasets import PDBManager
12 | from pathlib import Path
13 | from tqdm import tqdm
14 |
15 | from project.utils.utils import download_pdb_file, gunzip_file
16 |
17 |
18 | def plot_mean_std_freq(means, stds, freqs, title):
19 | plt.figure(figsize=(10, 6))
20 | num_experiments = len(means)
21 | bar_width = 0.4
22 | index = np.arange(num_experiments)
23 |
24 | # Plot bars for mean values with customized colors
25 | plt.barh(index, means, height=bar_width, color='salmon', label='Mean')
26 |
27 | # Plot error bars representing standard deviation with customized colors
28 | for i, (mean, std) in enumerate(zip(means, stds)):
29 | plt.errorbar(mean, i, xerr=std, color='red')
30 |
31 | # Plot vertical lines for the first and second standard deviation ranges with customized colors
32 | plt.vlines([mean - std, mean + std], i - bar_width/2, i + bar_width/2, color='royalblue', linestyles='dashed', label='1st Std Dev')
33 | plt.vlines([mean - 2*std, mean + 2*std], i - bar_width/2, i + bar_width/2, color='limegreen', linestyles='dotted', label='2nd Std Dev')
34 |
35 | plt.yticks(index, means.index)
36 | plt.xlabel('Resolution')
37 | plt.ylabel('Experiment Type')
38 | plt.title(title)
39 | plt.legend(loc="upper left")
40 |
41 | # Sort means, stds, and freqs based on the experiment types
42 | means = means[means.index.sort_values()]
43 | stds = stds[stds.index.sort_values()]
44 | freqs = freqs[freqs.index.sort_values()]
45 | freqs = (freqs / freqs.sum()) * 100
46 |
47 | # Calculate middle points of each bar
48 | middles = means
49 |
50 | # Calculate the visual center of each bar
51 | visual_center = middles / 2
52 |
53 | # Add frequency (count) of each experiment type at the middle point of each bar
54 | for i, freq in enumerate(freqs):
55 | plt.text(visual_center[i], i, f"{freq:.4f}%", va='center', ha='center', color='black', fontweight='bold')
56 |
57 | plt.tight_layout()
58 | plt.savefig(title.lower().replace(' ', '_') + ".png")  # Save before calling plt.show() so the written figure is not blank
59 | plt.show()
60 |
61 |
62 | @click.command()
63 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
64 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5']))
65 | def main(output_dir: str, source_type: str):
66 | logger = logging.getLogger(__name__)
67 | logger.info("Analyzing experiment types and resolution for each dataset example...")
68 |
69 | if source_type.lower() == "rcsb":
70 | train_metadata_csv_filepath = os.path.join(output_dir, "train_pdb_metadata.csv")
71 | val_metadata_csv_filepath = os.path.join(output_dir, "val_pdb_metadata.csv")
72 | train_val_metadata_csv_filepath = os.path.join(output_dir, "train_val_pdb_metadata.csv")
73 | metadata_csv_filepaths = [train_metadata_csv_filepath, val_metadata_csv_filepath, train_val_metadata_csv_filepath]
74 |
75 | if any(not os.path.exists(fp) for fp in metadata_csv_filepaths):
76 | with tempfile.TemporaryDirectory() as temp_dir:
77 | pdb_manager = PDBManager(root_dir=temp_dir)
78 |
79 | # Collect (and, if necessary, extract) all training PDB files
80 | train_pdb_codes = []
81 | pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
82 | assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
83 | with open(pairs_postprocessed_train_txt, "r") as f:
84 | train_filenames = [line.strip() for line in f.readlines()]
85 | for train_filename in tqdm(train_filenames):
86 | try:
87 | postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename))
88 | except Exception as e:
89 | logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}")
90 | continue
91 | pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3]
92 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
93 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0])
94 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0])
95 | l_b_df0_chains = postprocessed_train_pair.df0.chain.unique()
96 | r_b_df1_chains = postprocessed_train_pair.df1.chain.unique()
97 | assert (
98 | len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
99 | ), "Only a single PDB filename and chain identifier can be associated with a single training example."
100 | assert (
101 | len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
102 | ), "Only a single PDB filename and chain identifier can be associated with a single training example."
103 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
104 | gunzip_file(l_b_pdb_filepath)
105 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
106 | gunzip_file(r_b_pdb_filepath)
107 | if not os.path.exists(l_b_pdb_filepath):
108 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
109 | if not os.path.exists(r_b_pdb_filepath):
110 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
111 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
112 |
113 | l_b_pdb_code = Path(l_b_pdb_filepath).stem + "_" + l_b_df0_chains[0]
114 | r_b_pdb_code = Path(r_b_pdb_filepath).stem + "_" + r_b_df1_chains[0]
115 | train_pdb_codes.extend([l_b_pdb_code, r_b_pdb_code])
116 |
117 | # Collect (and, if necessary, extract) all validation PDB files
118 | val_pdb_codes = []
119 | pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt')
120 | assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance."
121 | with open(pairs_postprocessed_val_txt, "r") as f:
122 | val_filenames = [line.strip() for line in f.readlines()]
123 | for val_filename in tqdm(val_filenames):
124 | try:
125 | postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename))
126 | except Exception as e:
127 | logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}")
128 | continue
129 | pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3]
130 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
131 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0])
132 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0])
133 | l_b_df0_chains = postprocessed_val_pair.df0.chain.unique()
134 | r_b_df1_chains = postprocessed_val_pair.df1.chain.unique()
135 | assert (
136 | len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
137 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
138 | assert (
139 | len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
140 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
141 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
142 | gunzip_file(l_b_pdb_filepath)
143 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
144 | gunzip_file(r_b_pdb_filepath)
145 | if not os.path.exists(l_b_pdb_filepath):
146 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
147 | if not os.path.exists(r_b_pdb_filepath):
148 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
149 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
150 |
151 | l_b_pdb_code = Path(l_b_pdb_filepath).stem + "_" + l_b_df0_chains[0]
152 | r_b_pdb_code = Path(r_b_pdb_filepath).stem + "_" + r_b_df1_chains[0]
153 | val_pdb_codes.extend([l_b_pdb_code, r_b_pdb_code])
154 |
155 | # Record training and validation PDBs as a metadata CSV file
156 | train_pdbs_df = pdb_manager.df[pdb_manager.df.id.isin(train_pdb_codes)]
157 | train_pdbs_df.to_csv(train_metadata_csv_filepath)
158 |
159 | val_pdbs_df = pdb_manager.df[pdb_manager.df.id.isin(val_pdb_codes)]
160 | val_pdbs_df.to_csv(val_metadata_csv_filepath)
161 |
162 | train_val_pdbs_df = pdb_manager.df[pdb_manager.df.id.isin(train_pdb_codes + val_pdb_codes)]
163 | train_val_pdbs_df.to_csv(train_val_metadata_csv_filepath)
164 |
165 | assert all(os.path.exists(fp) for fp in metadata_csv_filepaths), "To analyze RCSB complexes, the corresponding metadata must previously have been collected."
166 | train_pdbs_df = pd.read_csv(train_metadata_csv_filepath, index_col=0)
167 | val_pdbs_df = pd.read_csv(val_metadata_csv_filepath, index_col=0)
168 | train_val_pdbs_df = pd.read_csv(train_val_metadata_csv_filepath, index_col=0)
169 |
170 | # Train PDBs
171 | train_pdbs_df = train_pdbs_df[~train_pdbs_df.experiment_type.isin(["other"])]
172 | means_train = train_pdbs_df.groupby('experiment_type')['resolution'].mean()
173 | stds_train = train_pdbs_df.groupby('experiment_type')['resolution'].std()
174 | freqs_train = train_pdbs_df['experiment_type'].value_counts()
175 | plot_mean_std_freq(means_train, stds_train, freqs_train, 'Resolution vs. Experiment Type (Train)')
176 |
177 | # Validation PDBs
178 | val_pdbs_df = val_pdbs_df[~val_pdbs_df.experiment_type.isin(["other"])]
179 | means_val = val_pdbs_df.groupby('experiment_type')['resolution'].mean()
180 | stds_val = val_pdbs_df.groupby('experiment_type')['resolution'].std()
181 | freqs_val = val_pdbs_df['experiment_type'].value_counts()
182 | plot_mean_std_freq(means_val, stds_val, freqs_val, 'Resolution vs. Experiment Type (Validation)')
183 |
184 | # Train + Validation PDBs
185 | train_val_pdbs_df = train_val_pdbs_df[~train_val_pdbs_df.experiment_type.isin(["other"])]
186 | means_train_val = train_val_pdbs_df.groupby('experiment_type')['resolution'].mean()
187 | stds_train_val = train_val_pdbs_df.groupby('experiment_type')['resolution'].std()
188 | freqs_train_val = train_val_pdbs_df['experiment_type'].value_counts()
189 | plot_mean_std_freq(means_train_val, stds_train_val, freqs_train_val, 'Resolution vs. Experiment Type (Train + Validation)')
190 |
191 | logger.info("Finished analyzing experiment types and resolution for all training and validation PDBs")
192 |
193 | else:
194 | raise NotImplementedError(f"Source type {source_type} is currently not supported.")
195 |
196 |
197 | if __name__ == "__main__":
198 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
199 | logging.basicConfig(level=logging.INFO, format=log_fmt)
200 |
201 | main()
202 |
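203 | # Example invocation (a sketch; run from `project/datasets/analysis` so the default relative paths resolve,
204 | # and note that the output directory must already contain the curated
205 | # `pairs-postprocessed-{train,val}-before-structure-based-filtering.txt` files referenced above):
206 | #   python analyze_experiment_types_and_resolution.py ../DIPS/final/raw --source_type rcsb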
--------------------------------------------------------------------------------
/project/datasets/analysis/analyze_feature_correlation.py:
--------------------------------------------------------------------------------
1 | import click
2 | import logging
3 | import os
4 |
5 | import atom3.pair as pa
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pandas as pd
9 | import seaborn as sns
10 |
11 | from pathlib import Path
12 | from scipy.stats import pearsonr
13 | from tqdm import tqdm
14 |
15 | from project.utils.utils import download_pdb_file, gunzip_file
16 |
17 |
18 | @click.command()
19 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
20 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5']))
21 | @click.option('--feature_types_to_correlate', default='rsa_value-rd_value', type=click.Choice(['rsa_value-rd_value', 'rsa_value-cn_value', 'rd_value-cn_value']))
22 | def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
23 | logger = logging.getLogger(__name__)
24 | logger.info("Analyzing feature correlation for each dataset example...")
25 |
26 | features_to_correlate = feature_types_to_correlate.split("-")
27 | assert len(features_to_correlate) == 2, "Exactly two features may currently be compared for correlation."
28 |
29 | if source_type.lower() == "rcsb":
30 | # Collect (and, if necessary, extract) all training PDB files
31 | train_feature_values = []
32 | pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
33 | assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
34 | with open(pairs_postprocessed_train_txt, "r") as f:
35 | train_filenames = [line.strip() for line in f.readlines()]
36 | for train_filename in tqdm(train_filenames):
37 | try:
38 | postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename))
39 | except Exception as e:
40 | logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}")
41 | continue
42 | pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3]
43 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
44 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0])
45 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0])
46 | l_b_df0_chains = postprocessed_train_pair.df0.chain.unique()
47 | r_b_df1_chains = postprocessed_train_pair.df1.chain.unique()
48 | assert (
49 | len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
50 | ), "Only a single PDB filename and chain identifier can be associated with a single training example."
51 | assert (
52 | len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
53 | ), "Only a single PDB filename and chain identifier can be associated with a single training example."
54 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
55 | gunzip_file(l_b_pdb_filepath)
56 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
57 | gunzip_file(r_b_pdb_filepath)
58 | if not os.path.exists(l_b_pdb_filepath):
59 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
60 | if not os.path.exists(r_b_pdb_filepath):
61 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
62 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
63 |
64 | l_b_df0_feature_values = postprocessed_train_pair.df0[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
65 | r_b_df1_feature_values = postprocessed_train_pair.df1[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
66 | train_feature_values.append(pd.concat([l_b_df0_feature_values, r_b_df1_feature_values]))
67 |
68 | # Collect (and, if necessary, extract) all validation PDB files
69 | val_feature_values = []
70 | pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt')
71 | assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance."
72 | with open(pairs_postprocessed_val_txt, "r") as f:
73 | val_filenames = [line.strip() for line in f.readlines()]
74 | for val_filename in tqdm(val_filenames):
75 | try:
76 | postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename))
77 | except Exception as e:
78 | logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}")
79 | continue
80 | pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3]
81 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
82 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0])
83 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0])
84 | l_b_df0_chains = postprocessed_val_pair.df0.chain.unique()
85 | r_b_df1_chains = postprocessed_val_pair.df1.chain.unique()
86 | assert (
87 | len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
88 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
89 | assert (
90 | len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
91 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
92 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
93 | gunzip_file(l_b_pdb_filepath)
94 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
95 | gunzip_file(r_b_pdb_filepath)
96 | if not os.path.exists(l_b_pdb_filepath):
97 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
98 | if not os.path.exists(r_b_pdb_filepath):
99 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
100 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
101 |
102 | l_b_df0_feature_values = postprocessed_val_pair.df0[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
103 | r_b_df1_feature_values = postprocessed_val_pair.df1[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
104 | val_feature_values.append(pd.concat([l_b_df0_feature_values, r_b_df1_feature_values]))
105 |
106 | # Train PDBs
107 | train_feature_values_df = pd.concat(train_feature_values)
108 | train_feature_values_correlation, train_feature_values_p_value = pearsonr(train_feature_values_df[features_to_correlate[0]], train_feature_values_df[features_to_correlate[1]])
109 | logger.info(f"With a p-value of {train_feature_values_p_value}, the Pearson's correlation of feature values `{feature_types_to_correlate}` within the training dataset is: {train_feature_values_correlation}")
110 | train_joint_plot = sns.jointplot(x=features_to_correlate[0], y=features_to_correlate[1], data=train_feature_values_df, kind='hex')
111 | # Add correlation value to the jointplot
112 | train_joint_plot.ax_joint.annotate(f"Training correlation: {train_feature_values_correlation:.2f}", xy=(0.5, 0.95), xycoords='axes fraction', ha='center', fontsize=12)
113 | # Save the jointplot
114 | train_joint_plot.savefig(f"train_{feature_types_to_correlate}_correlation.png")
115 | plt.close()
116 |
117 | # Validation PDBs
118 | val_feature_values_df = pd.concat(val_feature_values)
119 | val_feature_values_correlation, val_feature_values_p_value = pearsonr(val_feature_values_df[features_to_correlate[0]], val_feature_values_df[features_to_correlate[1]])
120 | logger.info(f"With a p-value of {val_feature_values_p_value}, the Pearson's correlation of feature values `{feature_types_to_correlate}` within the validation dataset is: {val_feature_values_correlation}")
121 | val_joint_plot = sns.jointplot(x=features_to_correlate[0], y=features_to_correlate[1], data=val_feature_values_df, kind='hex')
122 | # Add correlation value to the jointplot
123 | val_joint_plot.ax_joint.annotate(f"Validation correlation: {val_feature_values_correlation:.2f}", xy=(0.5, 0.95), xycoords='axes fraction', ha='center', fontsize=12)
124 | # Save the jointplot
125 | val_joint_plot.savefig(f"val_{feature_types_to_correlate}_correlation.png")
126 | plt.close()
127 |
128 | # Train + Validation PDBs
129 | train_val_feature_values_df = pd.concat([train_feature_values_df, val_feature_values_df])
130 | train_val_feature_values_correlation, train_val_feature_values_p_value = pearsonr(train_val_feature_values_df[features_to_correlate[0]], train_val_feature_values_df[features_to_correlate[1]])
131 | logger.info(f"With a p-value of {train_val_feature_values_p_value}, the Pearson's correlation of feature values `{feature_types_to_correlate}` within the training and validation dataset is: {train_val_feature_values_correlation}")
132 | train_val_joint_plot = sns.jointplot(x=features_to_correlate[0], y=features_to_correlate[1], data=train_val_feature_values_df, kind='hex')
133 | # Add correlation value to the jointplot
134 | train_val_joint_plot.ax_joint.annotate(f"Training + Validation Correlation: {train_val_feature_values_correlation:.2f}", xy=(0.5, 0.95), xycoords='axes fraction', ha='center', fontsize=12)
135 | # Save the jointplot
136 | train_val_joint_plot.savefig(f"train_val_{feature_types_to_correlate}_correlation.png")
137 | plt.close()
138 |
139 | logger.info("Finished analyzing feature correlation for all training and validation PDBs")
140 |
141 | else:
142 | raise NotImplementedError(f"Source type {source_type} is currently not supported.")
143 |
144 |
145 | if __name__ == "__main__":
146 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
147 | logging.basicConfig(level=logging.INFO, format=log_fmt)
148 |
149 | main()
150 |
--------------------------------------------------------------------------------
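Note: the correlation blocks in analyze_feature_correlation.py above repeat the same cleanup before calling pearsonr: string 'NA' placeholders are mapped to NaN, incomplete rows are dropped, and the survivors are coerced to numeric. A minimal, self-contained sketch of that pattern (column names here are hypothetical, and DataFrame.replace stands in for the script's applymap call):

import numpy as np
import pandas as pd
from scipy.stats import pearsonr

# Hypothetical two-feature frame containing 'NA' placeholders, as in the pair dataframes above
df = pd.DataFrame({"feature_a": ["0.25", "NA", "0.40", "0.10", "0.75"],
                   "feature_b": ["12", "7", "NA", "3", "20"]})

# 'NA' -> NaN, drop incomplete rows, coerce the remainder to numeric
clean = df.replace("NA", np.nan).dropna().apply(pd.to_numeric)

r, p = pearsonr(clean["feature_a"], clean["feature_b"])
print(f"Pearson r = {r:.3f} (p-value = {p:.3g})")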
/project/datasets/analysis/analyze_interface_waters.py:
--------------------------------------------------------------------------------
1 | import click
2 | import logging
3 | import os
4 | import warnings
5 |
6 | import atom3.pair as pa
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import pandas as pd
10 |
11 | from Bio import BiopythonWarning
12 | from Bio.PDB import NeighborSearch
13 | from Bio.PDB import PDBParser
14 | from pathlib import Path
15 | from tqdm import tqdm
16 |
17 | from project.utils.utils import download_pdb_file, gunzip_file
18 |
19 |
20 | @click.command()
21 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
22 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5']))
23 | @click.option('--interfacing_water_distance_cutoff', default=10.0, type=float)
24 | def main(output_dir: str, source_type: str, interfacing_water_distance_cutoff: float):
25 | logger = logging.getLogger(__name__)
26 | logger.info("Analyzing interface waters within each dataset example...")
27 |
28 | if source_type.lower() == "rcsb":
29 | parser = PDBParser()
30 |
31 | # Filter and suppress BioPython warnings
32 | warnings.filterwarnings("ignore", category=BiopythonWarning)
33 |
34 | # Collect (and, if necessary, extract) all training PDB files
35 | train_complex_num_waters = []
36 | pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
37 | assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
38 | with open(pairs_postprocessed_train_txt, "r") as f:
39 | train_filenames = [line.strip() for line in f.readlines()]
40 | for train_filename in tqdm(train_filenames):
41 | complex_num_waters = 0
42 | try:
43 | postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename))
44 | except Exception as e:
45 | logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}")
46 | continue
47 | pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3]
48 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
49 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0])
50 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0])
51 | l_b_df0_chains = postprocessed_train_pair.df0.chain.unique()
52 | r_b_df1_chains = postprocessed_train_pair.df1.chain.unique()
53 | assert (
54 | len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
55 | ), "Only a single PDB filename and chain identifier can be associated with a single training example."
56 | assert (
57 | len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
58 | ), "Only a single PDB filename and chain identifier can be associated with a single training example."
59 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
60 | gunzip_file(l_b_pdb_filepath)
61 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
62 | gunzip_file(r_b_pdb_filepath)
63 | if not os.path.exists(l_b_pdb_filepath):
64 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
65 | if not os.path.exists(r_b_pdb_filepath):
66 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
67 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
68 |
69 | l_b_structure = parser.get_structure('protein', l_b_pdb_filepath)
70 | r_b_structure = parser.get_structure('protein', r_b_pdb_filepath)
71 |
72 | l_b_interface_residues = postprocessed_train_pair.df0[postprocessed_train_pair.df0.index.isin(postprocessed_train_pair.pos_idx[:, 0])]
73 | r_b_interface_residues = postprocessed_train_pair.df1[postprocessed_train_pair.df1.index.isin(postprocessed_train_pair.pos_idx[:, 1])]
74 |
75 | try:
76 | l_b_ns = NeighborSearch(list(l_b_structure.get_atoms()))
77 | for index, row in l_b_interface_residues.iterrows():
78 | chain_id = row['chain']
79 | residue = row['residue'].strip()
80 | model = l_b_structure[0]
81 | chain = model[chain_id]
82 | if residue.lstrip("-").isdigit():
83 | residue = int(residue)
84 | else:
85 | residue_index, residue_icode = residue[:-1], residue[-1:]
86 | if residue_icode.strip() == "":
87 | residue = int(residue)
88 | else:
89 | residue = (" ", int(residue_index), residue_icode)
90 | target_residue = chain[residue]
91 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
92 | interfacing_atoms = l_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
93 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
94 | complex_num_waters += len(waters_within_threshold)
95 | except Exception as e:
96 | logging.error(f"Unable to locate interface waters for left-bound training structure {l_b_pdb_filepath} due to: {e}. Skipping...")
97 | continue
98 |
99 | try:
100 | r_b_ns = NeighborSearch(list(r_b_structure.get_atoms()))
101 | for index, row in r_b_interface_residues.iterrows():
102 | chain_id = row['chain']
103 | residue = row['residue'].strip()
104 | model = r_b_structure[0]
105 | chain = model[chain_id]
106 | if residue.lstrip("-").isdigit():
107 | residue = int(residue)
108 | else:
109 | residue_index, residue_icode = residue[:-1], residue[-1:]
110 | residue = (" ", int(residue_index), residue_icode)
111 | target_residue = chain[residue]
112 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
113 | interfacing_atoms = r_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
114 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
115 | complex_num_waters += len(waters_within_threshold)
116 | except Exception as e:
117 | logging.error(f"Unable to locate interface waters for right-bound training structure {r_b_pdb_filepath} due to: {e}. Skipping...")
118 | continue
119 |
120 | train_complex_num_waters.append(complex_num_waters)
121 |
122 | # Collect (and, if necessary, extract) all validation PDB files
123 | val_complex_num_waters = []
124 | pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt')
125 | assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance."
126 | with open(pairs_postprocessed_val_txt, "r") as f:
127 | val_filenames = [line.strip() for line in f.readlines()]
128 | for val_filename in tqdm(val_filenames):
129 | complex_num_waters = 0
130 | try:
131 | postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename))
132 | except Exception as e:
133 | logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}")
134 | continue
135 | pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3]
136 | pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
137 | l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0])
138 | r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0])
139 | l_b_df0_chains = postprocessed_val_pair.df0.chain.unique()
140 | r_b_df1_chains = postprocessed_val_pair.df1.chain.unique()
141 | assert (
142 | len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
143 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
144 | assert (
145 | len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
146 | ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
147 | if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
148 | gunzip_file(l_b_pdb_filepath)
149 | if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
150 | gunzip_file(r_b_pdb_filepath)
151 | if not os.path.exists(l_b_pdb_filepath):
152 | download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
153 | if not os.path.exists(r_b_pdb_filepath):
154 | download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
155 | assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
156 |
157 | l_b_structure = parser.get_structure('protein', l_b_pdb_filepath)
158 | r_b_structure = parser.get_structure('protein', r_b_pdb_filepath)
159 |
160 | l_b_interface_residues = postprocessed_val_pair.df0[postprocessed_val_pair.df0.index.isin(postprocessed_val_pair.pos_idx[:, 0])]
161 | r_b_interface_residues = postprocessed_val_pair.df1[postprocessed_val_pair.df1.index.isin(postprocessed_val_pair.pos_idx[:, 1])]
162 |
163 | try:
164 | l_b_ns = NeighborSearch(list(l_b_structure.get_atoms()))
165 | for index, row in l_b_interface_residues.iterrows():
166 | chain_id = row['chain']
167 | residue = row['residue'].strip()
168 | model = l_b_structure[0]
169 | chain = model[chain_id]
170 | if residue.lstrip("-").isdigit():
171 | residue = int(residue)
172 | else:
173 | residue_index, residue_icode = residue[:-1], residue[-1:]
174 | residue = (" ", int(residue_index), residue_icode)
175 | target_residue = chain[residue]
176 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
177 | interfacing_atoms = l_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
178 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
179 | complex_num_waters += len(waters_within_threshold)
180 | except Exception as e:
181 | logging.error(f"Unable to locate interface waters for left-bound validation structure {l_b_pdb_filepath} due to: {e}. Skipping...")
182 | continue
183 |
184 | try:
185 | r_b_ns = NeighborSearch(list(r_b_structure.get_atoms()))
186 | for index, row in r_b_interface_residues.iterrows():
187 | chain_id = row['chain']
188 | residue = row['residue'].strip()
189 | model = r_b_structure[0]
190 | chain = model[chain_id]
191 | if residue.lstrip("-").isdigit():
192 | residue = int(residue)
193 | else:
194 | residue_index, residue_icode = residue[:-1], residue[-1:]
195 | residue = (" ", int(residue_index), residue_icode)
196 | target_residue = chain[residue]
197 | target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
198 | interfacing_atoms = r_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
199 | waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
200 | complex_num_waters += len(waters_within_threshold)
201 | except Exception as e:
202 | logging.error(f"Unable to locate interface waters for right-bound validation structure {r_b_pdb_filepath} due to: {e}. Skipping...")
203 | continue
204 |
205 | val_complex_num_waters.append(complex_num_waters)
206 |
207 | train_val_complex_num_waters = train_complex_num_waters + val_complex_num_waters
208 |
209 | # Calculate mean values
210 | training_mean = np.mean(train_complex_num_waters)
211 | validation_mean = np.mean(val_complex_num_waters)
212 | training_validation_mean = np.mean(train_val_complex_num_waters)
213 |
214 | # Plotting the distributions
215 | plt.figure(figsize=(10, 6)) # Set the size of the figure
216 |
217 | # Training data distribution
218 | plt.subplot(131) # 1 row, 3 columns, plot 1 (leftmost)
219 | plt.hist(train_complex_num_waters, bins=10, color='royalblue')
220 | plt.axvline(training_mean, color='limegreen', linestyle='dashed', linewidth=2)
221 | plt.text(training_mean + 0.1, plt.ylim()[1] * 0.9, f' Mean: {training_mean:.2f}', color='limegreen')
222 | plt.title('Train Interface Waters')
223 | plt.xlabel('Count')
224 | plt.ylabel('Frequency')
225 |
226 | # Validation data distribution
227 | plt.subplot(132) # 1 row, 3 columns, plot 2 (middle)
228 | plt.hist(val_complex_num_waters, bins=10, color='royalblue')
229 | plt.axvline(validation_mean, color='limegreen', linestyle='dashed', linewidth=2)
230 | plt.text(validation_mean + 0.1, plt.ylim()[1] * 0.9, f' Mean: {validation_mean:.2f}', color='limegreen')
231 | plt.title('Validation Interface Waters')
232 | plt.xlabel('Count')
233 | plt.ylabel('Frequency')
234 |
235 | # Combined data distribution
236 | plt.subplot(133) # 1 row, 3 columns, plot 3 (rightmost)
237 | plt.hist(train_val_complex_num_waters, bins=10, color='royalblue')
238 | plt.axvline(training_validation_mean, color='limegreen', linestyle='dashed', linewidth=2)
239 | plt.text(training_validation_mean + 0.1, plt.ylim()[1] * 0.9, f' Mean: {training_validation_mean:.2f}', color='limegreen')
240 | plt.title('Train+Validation Interface Waters')
241 | plt.xlabel('Count')
242 | plt.ylabel('Frequency')
243 |
244 | plt.tight_layout() # Adjust the spacing between subplots
245 |         plt.savefig("train_val_interface_waters_analysis.png")  # Save the figure to disk before displaying it
246 |         plt.show()  # Display the plots (showing before saving can leave the written image blank)
247 |
248 | logger.info("Finished analyzing interface waters for all training and validation complexes")
249 |
250 | else:
251 | raise NotImplementedError(f"Source type {source_type} is currently not supported.")
252 |
253 |
254 | if __name__ == "__main__":
255 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
256 | logging.basicConfig(level=logging.INFO, format=log_fmt)
257 |
258 | main()
259 |
--------------------------------------------------------------------------------
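Note: analyze_interface_waters.py above repeats one Biopython pattern four times: build a NeighborSearch over all atoms, resolve the interface residue via its full residue ID (hetero flag, sequence number, insertion code), take the residue's CA coordinates, and count HOH/WAT atoms within the cutoff. A condensed sketch of that pattern, with a hypothetical PDB path, chain, and residue number:

import numpy as np
from Bio.PDB import NeighborSearch, PDBParser

def count_waters_near_residue(pdb_filepath, chain_id, residue_number, cutoff=10.0):
    """Count water (HOH/WAT) atoms within `cutoff` Angstroms of a residue's CA atom."""
    structure = PDBParser(QUIET=True).get_structure("protein", pdb_filepath)
    neighbor_search = NeighborSearch(list(structure.get_atoms()))
    residue = structure[0][chain_id][residue_number]  # an integer ID resolves to (' ', residue_number, ' ')
    ca_coords = np.array([atom.get_coord() for atom in residue.get_atoms() if atom.get_name() == "CA"]).squeeze()
    nearby_atoms = neighbor_search.search(ca_coords, cutoff, "A")  # 'A' -> return matches at the atom level
    return len([atom for atom in nearby_atoms if atom.get_parent().get_resname() in ("HOH", "WAT")])

# Hypothetical usage; path, chain, and residue number are placeholders:
# print(count_waters_near_residue("10gs.pdb1", "A", 42, cutoff=10.0))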
/project/datasets/builder/add_new_feature.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from pandas.errors import SettingWithCopyWarning
4 |
5 | warnings.simplefilter("ignore", category=FutureWarning)
6 | warnings.simplefilter("ignore", category=SettingWithCopyWarning)
7 |
8 | import click
9 | import graphein
10 | import logging
11 | import loguru
12 | import multiprocessing
13 | import os
14 | import sys
15 |
16 | from functools import partial
17 | from pathlib import Path
18 |
19 | from project.utils.utils import add_new_feature
20 |
21 | GRAPHEIN_FEATURE_NAME_MAPPING = {
22 | # TODO: Fill out remaining mappings for available Graphein residue-level features
23 | "expasy_protein_scale": "expasy",
24 | }
25 |
26 |
27 | @click.command()
28 | @click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True))
29 | @click.option('--num_cpus', '-c', default=1)
30 | @click.option('--modify_pair_data/--dry_run_only', '-m', default=False)
31 | @click.option('--graphein_feature_to_add', default='expasy_protein_scale', type=str)
32 | def main(raw_data_dir: str, num_cpus: int, modify_pair_data: bool, graphein_feature_to_add: str):
33 | # Validate requested feature function
34 | assert (
35 | hasattr(graphein.protein.features.nodes.amino_acid, graphein_feature_to_add)
36 | ), f"Graphein must provide the requested node featurization function {graphein_feature_to_add}"
37 |
38 | # Disable DEBUG messages coming from Graphein
39 | loguru.logger.disable("graphein")
40 | loguru.logger.remove()
41 | loguru.logger.add(lambda message: message["level"].name != "DEBUG")
42 |
43 | # Collect paths of files to modify
44 | raw_data_dir = Path(raw_data_dir)
45 | raw_data_pickle_filepaths = []
46 | for root, dirs, files in os.walk(raw_data_dir):
47 | for dir in dirs:
48 | for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir):
49 | for file in subfiles:
50 | if file.endswith('.dill'):
51 | raw_data_pickle_filepaths.append(raw_data_dir / dir / file)
52 |
53 | # Add to each file the values corresponding to a new feature, using multiprocessing #
54 | # Define the number of processes to use
55 | num_processes = min(num_cpus, multiprocessing.cpu_count())
56 |
57 | # Split the list of file paths into chunks
58 |     chunk_size = max(1, len(raw_data_pickle_filepaths) // num_processes)  # Guard against having fewer files than worker processes
59 | file_path_chunks = [
60 | raw_data_pickle_filepaths[i:i+chunk_size]
61 | for i in range(0, len(raw_data_pickle_filepaths), chunk_size)
62 | ]
63 | assert (
64 | len(raw_data_pickle_filepaths) == len([fp for chunk in file_path_chunks for fp in chunk])
65 | ), "Number of input files must match number of files across all file chunks."
66 |
67 | # Create a pool of worker processes
68 | pool = multiprocessing.Pool(processes=num_processes)
69 |
70 | # Process each chunk of file paths in parallel
71 | parallel_fn = partial(
72 | add_new_feature,
73 | modify_pair_data=modify_pair_data,
74 | graphein_feature_to_add=graphein_feature_to_add,
75 | graphein_feature_name_mapping=GRAPHEIN_FEATURE_NAME_MAPPING,
76 | )
77 | pool.map(parallel_fn, file_path_chunks)
78 |
79 | # Close the pool and wait for all processes to finish
80 | pool.close()
81 | pool.join()
82 |
83 |
84 | if __name__ == "__main__":
85 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
86 | logging.basicConfig(level=logging.INFO, format=log_fmt)
87 |
88 | main()
89 |
--------------------------------------------------------------------------------
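Note: add_new_feature.py above (and annotate_idr_residues.py below) split the collected .dill paths into contiguous chunks before mapping them over a multiprocessing.Pool. A small standalone helper capturing that pattern, including the max(1, ...) guard used above for the fewer-files-than-workers case, might look like:

import multiprocessing
from typing import List

def chunk_filepaths(filepaths: List, num_cpus: int) -> List[List]:
    """Split `filepaths` into contiguous chunks for a multiprocessing.Pool."""
    num_processes = min(num_cpus, multiprocessing.cpu_count())
    chunk_size = max(1, len(filepaths) // num_processes)  # avoid a zero step when files are fewer than workers
    chunks = [filepaths[i:i + chunk_size] for i in range(0, len(filepaths), chunk_size)]
    assert sum(len(chunk) for chunk in chunks) == len(filepaths), "Chunks must cover every input file."
    return chunks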
/project/datasets/builder/annotate_idr_residues.py:
--------------------------------------------------------------------------------
1 | import click
2 | import logging
3 | import multiprocessing
4 | import os
5 |
6 | from pathlib import Path
7 |
8 | from project.utils.utils import annotate_idr_residues
9 |
10 |
11 | @click.command()
12 | @click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True))
13 | @click.option('--num_cpus', '-c', default=1)
14 | def main(raw_data_dir: str, num_cpus: int):
15 | # Collect paths of files to modify
16 | raw_data_dir = Path(raw_data_dir)
17 | raw_data_pickle_filepaths = []
18 | for root, dirs, files in os.walk(raw_data_dir):
19 | for dir in dirs:
20 | for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir):
21 | for file in subfiles:
22 | if file.endswith('.dill'):
23 | raw_data_pickle_filepaths.append(raw_data_dir / dir / file)
24 |
25 | # Annotate whether each residue resides in an IDR, using multiprocessing #
26 | # Define the number of processes to use
27 | num_processes = min(num_cpus, multiprocessing.cpu_count())
28 |
29 | # Split the list of file paths into chunks
30 |     chunk_size = max(1, len(raw_data_pickle_filepaths) // num_processes)  # Guard against having fewer files than worker processes
31 | file_path_chunks = [
32 | raw_data_pickle_filepaths[i:i+chunk_size]
33 | for i in range(0, len(raw_data_pickle_filepaths), chunk_size)
34 | ]
35 | assert (
36 | len(raw_data_pickle_filepaths) == len([fp for chunk in file_path_chunks for fp in chunk])
37 | ), "Number of input files must match number of files across all file chunks."
38 |
39 | # Create a pool of worker processes
40 | pool = multiprocessing.Pool(processes=num_processes)
41 |
42 | # Process each chunk of file paths in parallel
43 | pool.map(annotate_idr_residues, file_path_chunks)
44 |
45 | # Close the pool and wait for all processes to finish
46 | pool.close()
47 | pool.join()
48 |
49 |
50 | if __name__ == "__main__":
51 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
52 | logging.basicConfig(level=logging.INFO, format=log_fmt)
53 |
54 | main()
55 |
--------------------------------------------------------------------------------
/project/datasets/builder/collect_dataset_statistics.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import click
5 |
6 | from project.utils.utils import get_global_node_rank, collect_dataset_statistics, DEFAULT_DATASET_STATISTICS
7 |
8 |
9 | @click.command()
10 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
11 | @click.option('--rank', '-r', default=0)
12 | @click.option('--size', '-s', default=1)
13 | def main(output_dir: str, rank: int, size: int):
14 | """Collect all dataset statistics."""
15 | # Reestablish global rank
16 | rank = get_global_node_rank(rank, size)
17 |
18 | # Ensure that this task only gets run on a single node to prevent race conditions
19 | if rank == 0:
20 | logger = logging.getLogger(__name__)
21 | logger.info('Aggregating statistics for given dataset')
22 |
23 | # Make sure the output_dir exists
24 | if not os.path.exists(output_dir):
25 | os.mkdir(output_dir)
26 |
27 | # Create dataset statistics CSV if not already existent
28 | dataset_statistics_csv = os.path.join(output_dir, 'dataset_statistics.csv')
29 | if not os.path.exists(dataset_statistics_csv):
30 |             # Initialize the dataset statistics CSV with default values
31 | with open(dataset_statistics_csv, 'w') as f:
32 | for key in DEFAULT_DATASET_STATISTICS.keys():
33 | f.write("%s, %s\n" % (key, DEFAULT_DATASET_STATISTICS[key]))
34 |
35 | # Aggregate dataset statistics in a readable fashion
36 | dataset_statistics = collect_dataset_statistics(output_dir)
37 |
38 | # Write out updated dataset statistics
39 | with open(dataset_statistics_csv, 'w') as f:
40 | for key in dataset_statistics.keys():
41 | f.write("%s, %s\n" % (key, dataset_statistics[key]))
42 |
43 |
44 | if __name__ == '__main__':
45 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
46 | logging.basicConfig(level=logging.INFO, format=log_fmt)
47 |
48 | main()
49 |
--------------------------------------------------------------------------------
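Note: collect_dataset_statistics.py above writes the statistics out as manually formatted "key, value" lines. An equivalent sketch using the csv module (statistic names here are hypothetical, and csv.writer omits the space after the comma that the script emits):

import csv

dataset_statistics = {"num_of_train_complexes": 0, "num_of_val_complexes": 0}  # hypothetical statistic names

with open("dataset_statistics.csv", "w", newline="") as f:
    csv.writer(f).writerows(dataset_statistics.items())  # one "key,value" row per statistic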
/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ####################### Batch Headers #########################
4 | #SBATCH -A BIF135
5 | #SBATCH -p batch
6 | #SBATCH -J make_casp_capri_dataset
7 | #SBATCH -t 0-24:00
8 | #SBATCH --mem 224G
9 | #SBATCH --nodes 4
10 | #SBATCH --ntasks-per-node 1
11 | ###############################################################
12 |
13 | # Remote paths #
14 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/bif135/Repositories/Lab_Repositories/DIPS-Plus
15 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa
16 | export OMP_NUM_THREADS=8
17 |
18 | # Remote Conda environment #
19 | source "$PROJDIR"/miniconda3/bin/activate
20 | conda activate DIPS-Plus
21 |
22 | # Load CUDA module for DGL
23 | module load cuda/10.2.89
24 |
25 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 270GB compressed)
26 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/bif132/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq
27 |
28 | # Run dataset compilation scripts
29 | cd "$PROJDIR"/project || exit
30 |
31 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/CASP-CAPRI/interim/parsed "$PROJDIR"/project/datasets/CASP-CAPRI/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/CASP-CAPRI/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type casp_capri --write_file
32 |
33 | #srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/CASP-CAPRI/raw "$PROJDIR"/project/datasets/CASP-CAPRI/interim/pairs "$PROJDIR"/project/datasets/CASP-CAPRI/interim/external_feats "$PROJDIR"/project/datasets/CASP-CAPRI/final/raw --num_cpus 32 --rank "$1" --size "$2" --source_type CASP-CAPRI
34 |
--------------------------------------------------------------------------------
/project/datasets/builder/compile_db5_dataset_on_andes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ####################### Batch Headers #########################
4 | #SBATCH -A BIP198
5 | #SBATCH -p batch
6 | #SBATCH -J make_db5_dataset
7 | #SBATCH -t 0-12:00
8 | #SBATCH --mem 224G
9 | #SBATCH --nodes 4
10 | #SBATCH --ntasks-per-node 1
11 | ###############################################################
12 |
13 | # Remote paths #
14 | export PROJID=bif132
15 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/$PROJID/Repositories/Lab_Repositories/DIPS-Plus
16 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa
17 | export OMP_NUM_THREADS=8
18 |
19 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 270GB compressed)
20 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/$PROJID/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq
21 |
22 | # Remote Conda environment #
23 | source "$PROJDIR"/miniconda3/bin/activate
24 |
25 | # Load CUDA module for DGL
26 | module load cuda/10.2.89
27 |
28 | # Run dataset compilation scripts
29 | cd "$PROJDIR"/project || exit
30 | wget -O "$PROJDIR"/project/datasets/DB5.tar.gz https://dataverse.harvard.edu/api/access/datafile/:persistentId?persistentId=doi:10.7910/DVN/H93ZKK/BXXQCG
31 | tar -xzf "$PROJDIR"/project/datasets/DB5.tar.gz --directory "$PROJDIR"/project/datasets/
32 | rm "$PROJDIR"/project/datasets/DB5.tar.gz "$PROJDIR"/project/datasets/DB5/.README.swp
33 | rm -rf "$PROJDIR"/project/datasets/DB5/interim "$PROJDIR"/project/datasets/DB5/processed
34 | mkdir "$PROJDIR"/project/datasets/DB5/interim "$PROJDIR"/project/datasets/DB5/interim/external_feats "$PROJDIR"/project/datasets/DB5/interim/external_feats/PSAIA "$PROJDIR"/project/datasets/DB5/interim/external_feats/PSAIA/DB5 "$PROJDIR"/project/datasets/DB5/final "$PROJDIR"/project/datasets/DB5/final/raw
35 |
36 | python3 "$PROJDIR"/project/datasets/builder/make_dataset.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim --num_cpus 32 --rank "$1" --size "$2" --source_type db5
37 |
38 | python3 "$PROJDIR"/project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_db5.txt "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/external_feats --source_type db5 --rank "$1" --size "$2"
39 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DB5/interim/parsed "$PROJDIR"/project/datasets/DB5/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/DB5/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type db5 --write_file
40 |
41 | srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DB5/raw "$PROJDIR"/project/datasets/DB5/interim/pairs "$PROJDIR"/project/datasets/DB5/interim/external_feats "$PROJDIR"/project/datasets/DB5/final/raw --num_cpus 32 --rank "$1" --size "$2" --source_type db5
42 |
43 | python3 "$PROJDIR"/project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DB5/final/raw --source_type db5 --rank "$1" --size "$2"
44 | python3 "$PROJDIR"/project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2"
45 | python3 "$PROJDIR"/project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DB5/final/raw --rank "$1" --size "$2"
46 | python3 "$PROJDIR"/project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DB5/final/raw --impute_atom_features False --num_cpus 32 --rank "$1" --size "$2"
47 |
48 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels
49 | python3 "$PROJDIR"/project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DB5/final/raw "$PROJDIR"/project/datasets/DB5/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2"
50 |
--------------------------------------------------------------------------------
/project/datasets/builder/compile_dips_dataset_on_andes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ####################### Batch Headers #########################
4 | #SBATCH -A BIF132
5 | #SBATCH -p batch
6 | #SBATCH -J make_dips_dataset
7 | #SBATCH -t 0-24:00
8 | #SBATCH --mem 224G
9 | #SBATCH --nodes 4
10 | #SBATCH --ntasks-per-node 1
11 | ###############################################################
12 |
13 | # Remote paths #
14 | export PROJID=bif132
15 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/$PROJID/Repositories/Lab_Repositories/DIPS-Plus
16 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa
17 |
18 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 270GB compressed)
19 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/$PROJID/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq
20 |
21 | # Remote Conda environment #
22 | source "$PROJDIR"/miniconda3/bin/activate
23 |
24 | # Load CUDA module for DGL
25 | module load cuda/10.2.89
26 |
27 | # Load GCC 10 module for OpenMPI support
28 | module load gcc/10.3.0
29 |
30 | # Run dataset compilation scripts
31 | cd "$PROJDIR"/project || exit
32 | rm "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed.txt "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed-train.txt "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed-val.txt "$PROJDIR"/project/datasets/DIPS/final/raw/pairs-postprocessed-test.txt
33 | mkdir "$PROJDIR"/project/datasets/DIPS/raw "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim "$PROJDIR"/project/datasets/DIPS/interim/external_feats "$PROJDIR"/project/datasets/DIPS/interim/external_feats/PSAIA "$PROJDIR"/project/datasets/DIPS/interim/external_feats/PSAIA/RCSB "$PROJDIR"/project/datasets/DIPS/final "$PROJDIR"/project/datasets/DIPS/final/raw
34 | rsync -rlpt -v -z --delete --port=33444 --include='*.gz' --include='*/' --exclude '*' rsync.rcsb.org::ftp_data/biounit/coordinates/divided/ "$PROJDIR"/project/datasets/DIPS/raw/pdb
35 |
36 | python3 "$PROJDIR"/project/datasets/builder/extract_raw_pdb_gz_archives.py "$PROJDIR"/project/datasets/DIPS/raw/pdb --rank "$1" --size "$2"
37 |
38 | python3 "$PROJDIR"/project/datasets/builder/make_dataset.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim --num_cpus 32 --rank "$1" --size "$2" --source_type rcsb --bound
39 |
40 | python3 "$PROJDIR"/project/datasets/builder/prune_pairs.py "$PROJDIR"/project/datasets/DIPS/interim/pairs "$PROJDIR"/project/datasets/DIPS/filters "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2"
41 |
42 | python3 "$PROJDIR"/project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_dips.txt "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats --source_type rcsb --rank "$1" --size "$2"
43 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB" "$PROJDIR"/project/datasets/DIPS/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --write_file
44 |
45 | # Retroactively download the PDB files corresponding to complexes that made it through DIPS-Plus' RCSB complex pruning to reduce storage requirements
46 | python3 "$PROJDIR"/project/datasets/builder/download_missing_pruned_pair_pdbs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2"
47 | srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats "$PROJDIR"/project/datasets/DIPS/final/raw --num_cpus 32 --rank "$1" --size "$2"
48 |
49 | python3 "$PROJDIR"/project/datasets/builder/partition_dataset_filenames.py "$PROJDIR"/project/datasets/DIPS/final/raw --source_type rcsb --filter_by_atom_count True --max_atom_count 17500 --rank "$1" --size "$2"
50 | python3 "$PROJDIR"/project/datasets/builder/collect_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2"
51 | python3 "$PROJDIR"/project/datasets/builder/log_dataset_statistics.py "$PROJDIR"/project/datasets/DIPS/final/raw --rank "$1" --size "$2"
52 | python3 "$PROJDIR"/project/datasets/builder/impute_missing_feature_values.py "$PROJDIR"/project/datasets/DIPS/final/raw --impute_atom_features False --num_cpus 32 --rank "$1" --size "$2"
53 |
54 | # Optionally convert each postprocessed (final 'raw') complex into a pair of DGL graphs (final 'processed') with labels
55 | python3 "$PROJDIR"/project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/project/datasets/DIPS/final/raw "$PROJDIR"/project/datasets/DIPS/final/processed --num_cpus 32 --edge_dist_cutoff 15.0 --edge_limit 5000 --self_loops True --rank "$1" --size "$2"
56 |
--------------------------------------------------------------------------------
/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ####################### Batch Headers #########################
4 | #SBATCH -A BIF135
5 | #SBATCH -p batch
6 | #SBATCH -J make_evcoupling_dataset
7 | #SBATCH -t 0-24:00
8 | #SBATCH --mem 224G
9 | #SBATCH --nodes 4
10 | #SBATCH --ntasks-per-node 1
11 | ###############################################################
12 |
13 | # Remote paths #
14 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/bif135/Repositories/Lab_Repositories/DIPS-Plus
15 | export PSAIADIR=/ccs/home/"$USER"/Programs/PSAIA_1.0_source/bin/linux/psa
16 | export OMP_NUM_THREADS=8
17 |
18 | # Remote Conda environment #
19 | source "$PROJDIR"/miniconda3/bin/activate
20 | conda activate DIPS-Plus
21 |
22 | # Load CUDA module for DGL
23 | module load cuda/10.2.89
24 |
25 | # Default to using the Big Fantastic Database (BFD) of protein sequences (approx. 270GB compressed)
26 | export HHSUITE_DB=/gpfs/alpine/scratch/$USER/bif132/Data/Databases/bfd_metaclust_clu_complete_id30_c90_final_seq
27 |
28 | # Run dataset compilation scripts
29 | cd "$PROJDIR"/project || exit
30 |
31 | srun python3 "$PROJDIR"/project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/EVCoupling/interim/parsed "$PROJDIR"/project/datasets/EVCoupling/interim/parsed "$HHSUITE_DB" "$PROJDIR"/project/datasets/EVCoupling/interim/external_feats --rank "$1" --size "$2" --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type evcoupling --read_file
32 |
33 | #srun python3 "$PROJDIR"/project/datasets/builder/postprocess_pruned_pairs.py "$PROJDIR"/project/datasets/EVCoupling/raw "$PROJDIR"/project/datasets/EVCoupling/interim/pairs "$PROJDIR"/project/datasets/EVCoupling/interim/external_feats "$PROJDIR"/project/datasets/EVCoupling/final/raw --num_cpus 32 --rank "$1" --size "$2" --source_type EVCoupling
34 |
--------------------------------------------------------------------------------
/project/datasets/builder/convert_complexes_to_graphs.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 |
5 | import click
6 | from parallel import submit_jobs
7 |
8 | from project.utils.utils import make_dgl_graphs, get_global_node_rank
9 |
10 |
11 | @click.command()
12 | @click.argument('input_dir', default='../DIPS/final/raw', type=click.Path(exists=True))
13 | @click.argument('output_dir', default='../DIPS/final/processed', type=click.Path())
14 | @click.option('--ext', '-e', default='pt')
15 | @click.option('--num_cpus', '-c', default=1)
16 | @click.option('--edge_dist_cutoff', '-d', default=15.0)
17 | @click.option('--edge_limit', '-l', default=5000)
18 | @click.option('--self_loops', default=True)  # No short flag here; '-s' is reserved for '--size' below
19 | @click.option('--rank', '-r', default=0)
20 | @click.option('--size', '-s', default=1)
21 | def main(input_dir: str, output_dir: str, ext: str, num_cpus: int, edge_dist_cutoff: float,
22 | edge_limit: int, self_loops: bool, rank: int, size: int):
23 | """Make DGL graphs out of postprocessed pairs."""
24 | # Reestablish global rank
25 | rank = get_global_node_rank(rank, size)
26 |
27 | # Ensure that this task only gets run on a single node to prevent race conditions
28 | if rank == 0:
29 | logger = logging.getLogger(__name__)
30 | logger.info('Creating DGL graphs (final \'processed\' pairs) from postprocessed (final \'raw\') pairs')
31 |
32 | # Make sure the output_dir exists
33 | if not os.path.exists(output_dir):
34 | os.mkdir(output_dir)
35 |
36 | input_files = [file for file in glob.glob(os.path.join(input_dir, '**', '*.dill'), recursive=True)]
37 | inputs = [(output_dir, input_file, ext, edge_dist_cutoff, edge_limit, self_loops) for input_file in input_files]
38 | submit_jobs(make_dgl_graphs, inputs, num_cpus)
39 |
40 |
41 | if __name__ == '__main__':
42 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
43 | logging.basicConfig(level=logging.INFO, format=log_fmt)
44 |
45 | main()
46 |
--------------------------------------------------------------------------------
/project/datasets/builder/create_hdf5_dataset.py:
--------------------------------------------------------------------------------
1 | import click
2 | import logging
3 | import os
4 | import warnings
5 |
6 | from hickle.lookup import SerializedWarning
7 |
8 | warnings.simplefilter("ignore", category=SerializedWarning)
9 |
10 | from pathlib import Path
11 | from parallel import submit_jobs
12 |
13 | from project.utils.utils import convert_pair_pickle_to_hdf5
14 |
15 |
16 | @click.command()
17 | @click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True))
18 | @click.option('--num_cpus', '-c', default=1)
19 | def main(raw_data_dir: str, num_cpus: int):
20 | raw_data_dir = Path(raw_data_dir)
21 | raw_data_pickle_filepaths = []
22 | for root, dirs, files in os.walk(raw_data_dir):
23 | for dir in dirs:
24 | for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir):
25 | for file in subfiles:
26 | if file.endswith('.dill'):
27 | raw_data_pickle_filepaths.append(raw_data_dir / dir / file)
28 | inputs = [(pickle_filepath, Path(pickle_filepath).with_suffix(".hdf5")) for pickle_filepath in raw_data_pickle_filepaths]
29 | submit_jobs(convert_pair_pickle_to_hdf5, inputs, num_cpus)
30 |
31 | # filepath = Path("project/datasets/DIPS/final/raw/0g/10gs.pdb1_0.dill")
32 | # pickle_example = convert_pair_hdf5_to_pickle(
33 | # hdf5_filepath=Path(filepath).with_suffix(".hdf5")
34 | # )
35 | # hdf5_file_example = convert_pair_hdf5_to_hdf5_file(
36 | # hdf5_filepath=Path(filepath).with_suffix(".hdf5")
37 | # )
38 | # print(f"pickle_example: {pickle_example}")
39 | # print(f"hdf5_file_example: {hdf5_file_example}")
40 |
41 |
42 | if __name__ == "__main__":
43 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
44 | logging.basicConfig(level=logging.INFO, format=log_fmt)
45 |
46 | main()
47 |
--------------------------------------------------------------------------------
/project/datasets/builder/download_missing_pruned_pair_pdbs.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import click
5 | from atom3.database import get_structures_filenames, get_pdb_name, get_pdb_code
6 | from atom3.utils import slice_list
7 | from mpi4py import MPI
8 | from parallel import submit_jobs
9 |
10 | from project.utils.utils import get_global_node_rank, download_missing_pruned_pair_pdbs
11 |
12 |
13 | @click.command()
14 | @click.argument('output_dir', default='../DIPS/raw/pdb', type=click.Path(exists=True))
15 | @click.argument('pruned_pairs_dir', default='../DIPS/interim/pairs-pruned', type=click.Path(exists=True))
16 | @click.option('--num_cpus', '-c', default=1)
17 | @click.option('--rank', '-r', default=0)
18 | @click.option('--size', '-s', default=1)
19 | def main(output_dir: str, pruned_pairs_dir: str, num_cpus: int, rank: int, size: int):
20 | """Download missing pruned pair PDB files."""
21 | # Reestablish global rank
22 | rank = get_global_node_rank(rank, size)
23 | logger = logging.getLogger(__name__)
24 | logger.info(f'Beginning missing PDB downloads for node {rank + 1} out of a global MPI world of size {size},'
25 | f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}')
26 |
27 | # Make sure the output_dir exists
28 | if not os.path.exists(output_dir):
29 | os.mkdir(output_dir)
30 |
31 | # Get work filenames
32 | logger.info(f'Looking for all pairs in {pruned_pairs_dir}')
33 | requested_filenames = get_structures_filenames(pruned_pairs_dir, extension='.dill')
34 | requested_filenames = [filename for filename in requested_filenames]
35 | requested_keys = [get_pdb_name(x) for x in requested_filenames]
36 | work_keys = [key for key in requested_keys]
37 | work_filenames = [os.path.join(pruned_pairs_dir, get_pdb_code(work_key)[1:3], work_key + '.dill')
38 | for work_key in work_keys]
39 | logger.info(f'Found {len(work_keys)} work pair(s) in {pruned_pairs_dir}')
40 |
41 | # Reserve an equally-sized portion of the full work load for a given rank in the MPI world
42 | work_filenames = list(set(work_filenames)) # Remove any duplicate filenames
43 | work_filename_rank_batches = slice_list(work_filenames, size)
44 | work_filenames = work_filename_rank_batches[rank]
45 |
46 | # Collect thread inputs
47 | inputs = [(logger, output_dir, work_filename) for work_filename in work_filenames]
48 | submit_jobs(download_missing_pruned_pair_pdbs, inputs, num_cpus)
49 |
50 |
51 | if __name__ == '__main__':
52 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
53 | logging.basicConfig(level=logging.INFO, format=log_fmt)
54 |
55 | main()
56 |
--------------------------------------------------------------------------------
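Note: download_missing_pruned_pair_pdbs.py above hands each MPI rank an equally sized share of the deduplicated work via atom3's slice_list. A self-contained sketch of the same idea, using simple striding instead of atom3's helper (so the exact batch boundaries may differ):

from typing import List, Sequence

def partition_for_rank(work_filenames: Sequence[str], rank: int, size: int) -> List[str]:
    """Give each MPI rank a disjoint, roughly equal share of the deduplicated work filenames."""
    unique_filenames = sorted(set(work_filenames))  # deduplicate and fix an ordering shared by all ranks
    return unique_filenames[rank::size]  # ranks 0..size-1 together cover the list exactly once

# Hypothetical usage for rank 1 of a 4-node world:
# my_work = partition_for_rank(all_filenames, rank=1, size=4)

Sorting before striding guarantees that every rank derives the same ordering from the deduplicated set; the script above relies on slice_list for its batching instead.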
/project/datasets/builder/extract_raw_pdb_gz_archives.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import logging
3 | import os
4 | import shutil
5 |
6 | import click
7 | from project.utils.utils import get_global_node_rank
8 |
9 |
10 | @click.command()
11 | @click.argument('gz_data_dir', type=click.Path(exists=True))
12 | @click.option('--rank', '-r', default=0)
13 | @click.option('--size', '-s', default=1)
14 | def main(gz_data_dir: str, rank: int, size: int):
15 | """ Run GZ extraction logic to turn raw data from (raw/) into extracted data ready to be analyzed by DSSP."""
16 | # Reestablish global rank
17 | rank = get_global_node_rank(rank, size)
18 |
19 | # Ensure that this task only gets run on a single node to prevent race conditions
20 | if rank == 0:
21 | logger = logging.getLogger(__name__)
22 | logger.info('Extracting raw GZ archives')
23 |
24 | # Iterate directory structure, extracting all GZ archives along the way
25 | data_dir = os.path.abspath(gz_data_dir) + '/'
26 | raw_pdb_list = os.listdir(data_dir)
27 | for pdb_dir in raw_pdb_list:
28 | for pdb_gz in os.listdir(data_dir + pdb_dir):
29 | if pdb_gz.endswith('.gz'):
30 |                     base, ext = os.path.splitext(pdb_gz)
31 |                     gzip_dir = data_dir + pdb_dir + '/' + pdb_gz
32 |                     extract_dir = data_dir + pdb_dir + '/' + base
33 | if not os.path.exists(extract_dir):
34 | with gzip.open(gzip_dir, 'rb') as f_in:
35 | with open(extract_dir, 'wb') as f_out:
36 | shutil.copyfileobj(f_in, f_out)
37 |
38 |
39 | if __name__ == '__main__':
40 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
41 | logging.basicConfig(level=logging.INFO, format=log_fmt)
42 |
43 | main()
44 |
--------------------------------------------------------------------------------
/project/datasets/builder/generate_hhsuite_features.py:
--------------------------------------------------------------------------------
1 | """
2 | Source code (MIT-Licensed) inspired by Atom3 (https://github.com/drorlab/atom3 & https://github.com/amorehead/atom3)
3 | """
4 |
5 | import glob
6 | import logging
7 | import os
8 | from os import cpu_count
9 |
10 | import click
11 | from atom3.conservation import map_all_profile_hmms
12 | from mpi4py import MPI
13 |
14 | from project.utils.utils import get_global_node_rank
15 |
16 |
17 | @click.command()
18 | @click.argument('pkl_dataset', type=click.Path(exists=True))
19 | @click.argument('pruned_dataset', type=click.Path(exists=True))
20 | @click.argument('hhsuite_db', type=click.Path())
21 | @click.argument('output_dir', type=click.Path())
22 | @click.option('--rank', default=0, type=int)
23 | @click.option('--size', default=1, type=int)
24 | @click.option('--num_cpu_jobs', '-j', default=cpu_count() // 2, type=int)
25 | @click.option('--num_cpus_per_job', '-c', default=2, type=int)
26 | @click.option('--num_iter', '-i', default=2, type=int)
27 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri']))
28 | @click.option('--generate_hmm_profile/--generate_msa_only', '-p', default=True)
29 | @click.option('--write_file/--read_file', '-w', default=True)
30 | def main(pkl_dataset: str, pruned_dataset: str, hhsuite_db: str, output_dir: str, rank: int,
31 | size: int, num_cpu_jobs: int, num_cpus_per_job: int, num_iter: int, source_type: str,
32 | generate_hmm_profile: bool, write_file: bool):
33 | """Run external programs for feature generation to turn raw PDB files from (../raw) into sequence or structure-based residue features (saved in ../interim/external_feats by default)."""
34 | logger = logging.getLogger(__name__)
35 | logger.info(f'Generating external features from PDB files in {pkl_dataset}')
36 |
37 | # Reestablish global rank
38 | rank = get_global_node_rank(rank, size)
39 | logger.info(f"Assigned global rank {rank} of world size {size}")
40 |
41 |     # Select which copy of the BFD this node should search, based on its global rank
42 | bfd_copy_ids = ["_1", "_2", "_3", "_4", "_5", "_6", "_7", "_8",
43 | "_9", "_10", "_11", "_12", "_17", "_21", "_25", "_29"]
44 | bfd_copy_id = bfd_copy_ids[rank]
45 |
46 | # Assemble true ID of the BFD copy to use for generating profile HMMs
47 | hhsuite_dbs = glob.glob(os.path.join(hhsuite_db + bfd_copy_id, "*bfd*"))
48 |     assert len(hhsuite_dbs) == 1, "Exactly one copy of the BFD must be present in the given database directory."
49 | hhsuite_db = hhsuite_dbs[0]
50 | logger.info(f'Starting HH-suite for node {rank + 1} out of a global MPI world of size {size},'
51 | f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}.'
52 | f' This node\'s copy of the BFD is {hhsuite_db}')
53 |
54 | # Generate profile HMMs #
55 | # Run with --write_file=True using one node
56 | # Then run with --read_file=True using multiple nodes to distribute workload across nodes and their CPU cores
57 | map_all_profile_hmms(pkl_dataset, pruned_dataset, output_dir, hhsuite_db, num_cpu_jobs,
58 | num_cpus_per_job, source_type, num_iter, not generate_hmm_profile, rank, size, write_file)
59 |
60 |
61 | if __name__ == '__main__':
62 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
63 | logging.basicConfig(level=logging.INFO, format=log_fmt)
64 |
65 | main()
66 |
--------------------------------------------------------------------------------
/project/datasets/builder/generate_psaia_features.py:
--------------------------------------------------------------------------------
1 | """
2 | Source code (MIT-Licensed) inspired by Atom3 (https://github.com/drorlab/atom3 & https://github.com/amorehead/atom3)
3 | """
4 |
5 | import logging
6 | import os
7 |
8 | import click
9 | import atom3.conservation as con
10 |
11 | from project.utils.utils import get_global_node_rank
12 |
13 |
14 | @click.command()
15 | @click.argument('psaia_dir', type=click.Path(exists=True))
16 | @click.argument('psaia_config', type=click.Path(exists=True))
17 | @click.argument('pdb_dataset', type=click.Path(exists=True))
18 | @click.argument('pkl_dataset', type=click.Path(exists=True))
19 | @click.argument('pruned_dataset', type=click.Path(exists=True))
20 | @click.argument('output_dir', type=click.Path())
21 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri']))
22 | @click.option('--rank', '-r', default=0)
23 | @click.option('--size', '-s', default=1)
24 | def main(psaia_dir: str, psaia_config: str, pdb_dataset: str, pkl_dataset: str,
25 | pruned_dataset: str, output_dir: str, source_type: str, rank: int, size: int):
26 | """Run external programs for feature generation to turn raw PDB files from (../raw) into sequence or structure-based residue features (saved in ../interim/external_feats by default)."""
27 | # Reestablish global rank
28 | rank = get_global_node_rank(rank, size)
29 |
30 | # Ensure that this task only gets run on a single node to prevent race conditions
31 | if rank == 0:
32 | logger = logging.getLogger(__name__)
33 | logger.info(f'Generating PSAIA features from PDB files in {pkl_dataset}')
34 |
35 | # Ensure PSAIA is in PATH
36 | PSAIA_PATH = psaia_dir
37 | if PSAIA_PATH not in os.environ["PATH"]:
38 | logger.info('Adding ' + PSAIA_PATH + ' to system path')
39 |             os.environ["PATH"] += os.pathsep + PSAIA_PATH  # os.pathsep (':') separates PATH entries; os.path.sep ('/') would corrupt them
40 |
41 | # Generate protrusion indices
42 | con.map_all_protrusion_indices(psaia_dir, psaia_config, pdb_dataset, pkl_dataset,
43 | pruned_dataset, output_dir, source_type)
44 |
45 |
46 | if __name__ == '__main__':
47 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
48 | logging.basicConfig(level=logging.INFO, format=log_fmt)
49 |
50 | main()
51 |
--------------------------------------------------------------------------------
/project/datasets/builder/impute_missing_feature_values.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 |
5 | import click
6 | from parallel import submit_jobs
7 |
8 | from project.utils.utils import get_global_node_rank, impute_missing_feature_values
9 |
10 |
11 | @click.command()
12 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
13 | @click.option('--impute_atom_features', '-a', default=False)
14 | @click.option('--advanced_logging', '-l', default=False)
15 | @click.option('--num_cpus', '-c', default=1)
16 | @click.option('--rank', '-r', default=0)
17 | @click.option('--size', '-s', default=1)
18 | def main(output_dir: str, impute_atom_features: bool, advanced_logging: bool, num_cpus: int, rank: int, size: int):
19 | """Impute missing feature values."""
20 | # Reestablish global rank
21 | rank = get_global_node_rank(rank, size)
22 |
23 | # Ensure that this task only gets run on a single node to prevent race conditions
24 | if rank == 0:
25 | logger = logging.getLogger(__name__)
26 | logger.info('Imputing missing feature values for given dataset')
27 |
28 | # Make sure the output_dir exists
29 | if not os.path.exists(output_dir):
30 | os.mkdir(output_dir)
31 |
32 | # Collect thread inputs
33 | inputs = [(pair_filename.as_posix(), pair_filename.as_posix(), impute_atom_features, advanced_logging)
34 | for pair_filename in Path(output_dir).rglob('*.dill')]
35 | # Without impute_atom_features set to True, non-CA atoms will be filtered out after writing updated pairs
36 | submit_jobs(impute_missing_feature_values, inputs, num_cpus)
37 |
38 |
39 | if __name__ == '__main__':
40 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
41 | logging.basicConfig(level=logging.INFO, format=log_fmt)
42 |
43 | main()
44 |
--------------------------------------------------------------------------------
/project/datasets/builder/launch_parallel_slurm_jobs_on_andes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Where the project is stored
4 | export PROJID=bif135
5 | export PROJDIR=/gpfs/alpine/scratch/"$USER"/$PROJID/Repositories/Lab_Repositories/DIPS-Plus
6 |
7 | # Which copies of the BFD to use for the available nodes' search batches (i.e. in range ['_1', '_2', ..., '_31', '_32'])
8 | BFD_COPY_IDS=("_1" "_2" "_3" "_4" "_5" "_6" "_7" "_8" "_9" "_10" "_11" "_12" "_17" "_21" "_25" "_29")
9 | NUM_BFD_COPIES=${#BFD_COPY_IDS[@]}
10 |
11 | # Whether to launch multi-node compilation jobs for DIPS or instead for DB5
12 | compile_dips=false
13 | compile_db5=false
14 | compile_evcoupling=false
15 | compile_casp_capri=true
16 |
17 | if [ "$compile_dips" = true ]; then
18 | # Job 1 - DIPS
19 | echo Submitting job 1 for compile_dips_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES"
20 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 0 "$NUM_BFD_COPIES"
21 |
22 | # Job 2 - DIPS
23 | echo Submitting job 2 for compile_dips_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES"
24 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 1 "$NUM_BFD_COPIES"
25 |
26 | # Job 3 - DIPS
27 | echo Submitting job 3 for compile_dips_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES"
28 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 2 "$NUM_BFD_COPIES"
29 |
30 | # Job 4 - DIPS
31 | echo Submitting job 4 for compile_dips_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES"
32 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 3 "$NUM_BFD_COPIES"
33 |
34 | elif [ "$compile_db5" = true ]; then
35 | # Job 1 - DB5
36 | echo Submitting job 1 for compile_db5_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES"
37 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 0 "$NUM_BFD_COPIES"
38 |
39 | # Job 2 - DB5
40 | echo Submitting job 2 for compile_db5_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES"
41 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 1 "$NUM_BFD_COPIES"
42 |
43 | # Job 3 - DB5
44 | echo Submitting job 3 for compile_db5_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES"
45 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 2 "$NUM_BFD_COPIES"
46 |
47 | # Job 4 - DB5
48 | echo Submitting job 4 for compile_db5_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES"
49 | sbatch "$PROJDIR"/project/datasets/builder/compile_db5_dataset_on_andes.sh 3 "$NUM_BFD_COPIES"
50 |
51 | elif [ "$compile_evcoupling" = true ]; then
52 | # Job 1 - EVCoupling
53 | echo Submitting job 1 for compile_evcoupling_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES"
54 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 0 "$NUM_BFD_COPIES"
55 |
56 | # Job 2 - EVCoupling
57 | echo Submitting job 2 for compile_evcoupling_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES"
58 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 1 "$NUM_BFD_COPIES"
59 |
60 | # Job 3 - EVCoupling
61 | echo Submitting job 3 for compile_evcoupling_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES"
62 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 2 "$NUM_BFD_COPIES"
63 |
64 | # Job 4 - EVCoupling
65 | echo Submitting job 4 for compile_evcoupling_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES"
66 | sbatch "$PROJDIR"/project/datasets/builder/compile_evcoupling_dataset_on_andes.sh 3 "$NUM_BFD_COPIES"
67 |
68 | elif [ "$compile_casp_capri" = true ]; then
69 | # Job 1 - CASP-CAPRI
70 | echo Submitting job 1 for compile_casp_capri_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES"
71 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 0 "$NUM_BFD_COPIES"
72 |
73 | # Job 2 - CASP-CAPRI
74 | echo Submitting job 2 for compile_casp_capri_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES"
75 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 1 "$NUM_BFD_COPIES"
76 |
77 | # Job 3 - CASP-CAPRI
78 | echo Submitting job 3 for compile_casp_capri_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES"
79 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 2 "$NUM_BFD_COPIES"
80 |
81 | # Job 4 - CASP-CAPRI
82 | echo Submitting job 4 for compile_casp_capri_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES"
83 | sbatch "$PROJDIR"/project/datasets/builder/compile_casp_capri_dataset_on_andes.sh 3 "$NUM_BFD_COPIES"
84 |
85 | else
86 | # Job 1 - DIPS
87 | echo Submitting job 1 for compile_dips_dataset_on_andes.sh with parameters: 0 "$NUM_BFD_COPIES"
88 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 0 "$NUM_BFD_COPIES"
89 |
90 | # Job 2 - DIPS
91 | echo Submitting job 2 for compile_dips_dataset_on_andes.sh with parameters: 1 "$NUM_BFD_COPIES"
92 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 1 "$NUM_BFD_COPIES"
93 |
94 | # Job 3 - DIPS
95 | echo Submitting job 3 for compile_dips_dataset_on_andes.sh with parameters: 2 "$NUM_BFD_COPIES"
96 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 2 "$NUM_BFD_COPIES"
97 |
98 | # Job 4 - DIPS
99 | echo Submitting job 4 for compile_dips_dataset_on_andes.sh with parameters: 3 "$NUM_BFD_COPIES"
100 | sbatch "$PROJDIR"/project/datasets/builder/compile_dips_dataset_on_andes.sh 3 "$NUM_BFD_COPIES"
101 | fi
102 |
--------------------------------------------------------------------------------
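Each branch of the launcher above submits the same compile script four times, varying only the node-batch index passed as the first argument. Purely to illustrate that submission pattern, a rough Python equivalent of the DIPS branch (paths and counts mirrored from the script above; this is not part of the repository):

import os
import subprocess

projid = 'bif135'
projdir = f"/gpfs/alpine/scratch/{os.environ['USER']}/{projid}/Repositories/Lab_Repositories/DIPS-Plus"
num_bfd_copies = 16  # i.e. the length of BFD_COPY_IDS above

compile_script = os.path.join(projdir, 'project/datasets/builder/compile_dips_dataset_on_andes.sh')
for node_batch_index in range(4):
    print(f'Submitting job {node_batch_index + 1} with parameters: {node_batch_index} {num_bfd_copies}')
    subprocess.run(['sbatch', compile_script, str(node_batch_index), str(num_bfd_copies)], check=True)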
/project/datasets/builder/log_dataset_statistics.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import click
5 |
6 | from project.utils.utils import get_global_node_rank, log_dataset_statistics, DEFAULT_DATASET_STATISTICS
7 |
8 |
9 | @click.command()
10 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
11 | @click.option('--rank', '-r', default=0)
12 | @click.option('--size', '-s', default=1)
13 | def main(output_dir: str, rank: int, size: int):
14 | """Log all collected dataset statistics."""
15 | # Reestablish global rank
16 | rank = get_global_node_rank(rank, size)
17 |
18 | # Ensure that this task only gets run on a single node to prevent race conditions
19 | if rank == 0:
20 | logger = logging.getLogger(__name__)
21 |
22 | # Make sure the output_dir exists
23 | if not os.path.exists(output_dir):
24 | os.mkdir(output_dir)
25 |
26 | # Create dataset statistics CSV if not already existent
27 | dataset_statistics_csv = os.path.join(output_dir, 'dataset_statistics.csv')
28 | if not os.path.exists(dataset_statistics_csv):
29 | # Initialize the dataset statistics CSV with default values
30 | with open(dataset_statistics_csv, 'w') as f:
31 | for key in DEFAULT_DATASET_STATISTICS.keys():
32 | f.write("%s, %s\n" % (key, DEFAULT_DATASET_STATISTICS[key]))
33 |
34 | with open(dataset_statistics_csv, 'r') as f:
35 | # Read-in existing dataset statistics
36 | dataset_statistics = {}
37 | for line in f.readlines():
38 | dataset_statistics[line.split(',')[0].strip()] = int(line.split(',')[1].strip())
39 |
40 | # Log dataset statistics in a readable fashion
41 | if dataset_statistics is not None:
42 | log_dataset_statistics(logger, dataset_statistics)
43 |
44 |
45 | if __name__ == '__main__':
46 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
47 | logging.basicConfig(level=logging.INFO, format=log_fmt)
48 |
49 | main()
50 |
--------------------------------------------------------------------------------
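The dataset_statistics.csv handled above is a plain two-column 'key, value' listing rather than a full CSV table. A small sketch of the same write-then-read round trip, using a hypothetical stand-in for DEFAULT_DATASET_STATISTICS:

import os

stats_csv = 'dataset_statistics.csv'  # hypothetical local path
default_stats = {'num_of_processed_complexes': 0, 'num_of_res_pairs': 0}  # stand-in for DEFAULT_DATASET_STATISTICS

# Write one "key, value" line per statistic, as done above
if not os.path.exists(stats_csv):
    with open(stats_csv, 'w') as f:
        for key, value in default_stats.items():
            f.write(f'{key}, {value}\n')

# Read the statistics back into a dict of integers, as done above
with open(stats_csv, 'r') as f:
    dataset_statistics = {line.split(',')[0].strip(): int(line.split(',')[1].strip()) for line in f}
print(dataset_statistics)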
/project/datasets/builder/make_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Source code (MIT-Licensed) originally from DIPS (https://github.com/drorlab/DIPS)
3 | """
4 |
5 | import logging
6 | import os
7 |
8 | import atom3.complex as comp
9 | import atom3.neighbors as nb
10 | import atom3.pair as pair
11 | import atom3.parse as pa
12 | import click
13 | from project.utils.utils import get_global_node_rank
14 |
15 |
16 | @click.command()
17 | @click.argument('input_dir', type=click.Path(exists=True))
18 | @click.argument('output_dir', type=click.Path())
19 | @click.option('--num_cpus', '-c', default=1)
20 | @click.option('--rank', '-r', default=0)
21 | @click.option('--size', '-s', default=1)
22 | @click.option('--neighbor_def', default='non_heavy_res',
23 | type=click.Choice(['non_heavy_res', 'non_heavy_atom', 'ca_res', 'ca_atom']))
24 | @click.option('--cutoff', default=6)
25 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri']))
26 | @click.option('--unbound/--bound', default=False)
27 | def main(input_dir: str, output_dir: str, num_cpus: int, rank: int, size: int,
28 | neighbor_def: str, cutoff: int, source_type: str, unbound: bool):
29 | """Run data processing scripts to turn raw data from (../raw) into cleaned data ready to be analyzed (saved in ../interim). For reference, pos_idx indicates the IDs of residues in interaction with non-heavy atoms in a cross-protein residue."""
30 | # Reestablish global rank
31 | rank = get_global_node_rank(rank, size)
32 |
33 | # Ensure that this task only gets run on a single node to prevent race conditions
34 | if rank == 0:
35 | logger = logging.getLogger(__name__)
36 | logger.info('Making interim data set from raw data')
37 |
38 | parsed_dir = os.path.join(output_dir, 'parsed')
39 | pa.parse_all(input_dir, parsed_dir, num_cpus)
40 |
41 | complexes_dill = os.path.join(output_dir, 'complexes/complexes.dill')
42 | comp.complexes(parsed_dir, complexes_dill, source_type)
43 | pairs_dir = os.path.join(output_dir, 'pairs')
44 | get_neighbors = nb.build_get_neighbors(neighbor_def, cutoff)
45 | get_pairs = pair.build_get_pairs(neighbor_def, source_type, unbound, get_neighbors, False)
46 | complexes = comp.read_complexes(complexes_dill)
47 | pair.all_complex_to_pairs(complexes, source_type, get_pairs, pairs_dir, num_cpus)
48 |
49 |
50 | if __name__ == '__main__':
51 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
52 | logging.basicConfig(level=logging.INFO, format=log_fmt)
53 |
54 | main()
55 |
--------------------------------------------------------------------------------
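Because the entry point above is a click command, it can also be exercised programmatically, which is handy for smoke tests. A sketch using click's CliRunner with placeholder directories (the flags shown correspond to the options declared above):

from click.testing import CliRunner

from project.datasets.builder.make_dataset import main

runner = CliRunner()
# Placeholder directories; in practice input_dir holds the raw PDBs and output_dir an interim folder
result = runner.invoke(main, ['../DIPS/raw/pdb', '../DIPS/interim',
                              '--num_cpus', '8', '--source_type', 'rcsb', '--neighbor_def', 'non_heavy_res'])
print(result.exit_code)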
/project/datasets/builder/postprocess_pruned_pairs.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import click
5 | from atom3.database import get_structures_filenames, get_pdb_name, get_pdb_code
6 | from atom3.utils import slice_list
7 | from mpi4py import MPI
8 | from parallel import submit_jobs
9 |
10 | from project.utils.utils import get_global_node_rank
11 | from project.utils.utils import postprocess_pruned_pairs
12 |
13 |
14 | @click.command()
15 | @click.argument('raw_pdb_dir', default='../DIPS/raw/pdb', type=click.Path(exists=True))
16 | @click.argument('pruned_pairs_dir', default='../DIPS/interim/pairs-pruned', type=click.Path(exists=True))
17 | @click.argument('external_feats_dir', default='../DIPS/interim/external_feats', type=click.Path(exists=True))
18 | @click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
19 | @click.option('--num_cpus', '-c', default=1)
20 | @click.option('--rank', '-r', default=0)
21 | @click.option('--size', '-s', default=1)
22 | @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri']))
23 | def main(raw_pdb_dir: str, pruned_pairs_dir: str, external_feats_dir: str, output_dir: str,
24 | num_cpus: int, rank: int, size: int, source_type: str):
25 | """Run postprocess_pruned_pairs on all provided complexes."""
26 | # Reestablish global rank
27 | rank = get_global_node_rank(rank, size)
28 | logger = logging.getLogger(__name__)
29 | logger.info(f'Starting postprocessing for node {rank + 1} out of a global MPI world of size {size},'
30 | f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}')
31 |
32 | # Make sure the output_dir exists
33 | if not os.path.exists(output_dir):
34 | os.mkdir(output_dir)
35 |
36 | # Get work filenames
37 | logger.info(f'Looking for all pairs in {pruned_pairs_dir}')
38 | requested_filenames = get_structures_filenames(pruned_pairs_dir, extension='.dill')
39 | requested_filenames = [filename for filename in requested_filenames]
40 | requested_keys = [get_pdb_name(x) for x in requested_filenames]
41 | produced_filenames = get_structures_filenames(output_dir, extension='.dill')
42 | produced_keys = [get_pdb_name(x) for x in produced_filenames]
43 | work_keys = [key for key in requested_keys if key not in produced_keys]
44 | rcsb_pruned_pair_ext = '.dill' if source_type.lower() in ['rcsb', 'evcoupling', 'casp_capri'] else ''
45 | work_filenames = [os.path.join(pruned_pairs_dir, get_pdb_code(work_key)[1:3], work_key + rcsb_pruned_pair_ext)
46 | for work_key in work_keys]
47 | logger.info(f'Found {len(work_keys)} work pair(s) in {pruned_pairs_dir}')
48 |
49 | # Reserve an equally-sized portion of the full work load for a given rank in the MPI world
50 | work_filenames = list(set(work_filenames)) # Remove any duplicate filenames
51 | work_filename_rank_batches = slice_list(work_filenames, size)
52 | work_filenames = work_filename_rank_batches[rank]
53 |
54 | # Get filenames in which our threads will store output
55 | output_filenames = []
56 | for pdb_filename in work_filenames:
57 | sub_dir = output_dir + '/' + get_pdb_code(pdb_filename)[1:3]
58 | if not os.path.exists(sub_dir):
59 | os.mkdir(sub_dir)
60 | new_output_filename = sub_dir + '/' + get_pdb_name(pdb_filename) + ".dill" if \
61 | source_type in ['rcsb', 'evcoupling', 'casp_capri'] else \
62 | sub_dir + '/' + get_pdb_name(pdb_filename)
63 | output_filenames.append(new_output_filename)
64 |
65 | # Collect thread inputs
66 | inputs = [(raw_pdb_dir, external_feats_dir, i, o, source_type)
67 | for i, o in zip(work_filenames, output_filenames)]
68 | submit_jobs(postprocess_pruned_pairs, inputs, num_cpus)
69 |
70 |
71 | if __name__ == '__main__':
72 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
73 | logging.basicConfig(level=logging.INFO, format=log_fmt)
74 |
75 | main()
76 |
--------------------------------------------------------------------------------
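The postprocessing work above is partitioned across MPI ranks with atom3's slice_list, and each rank keeps only its own batch of filenames. A minimal sketch of that partitioning idea (not atom3's implementation), using hypothetical filenames and a hypothetical world size:

def slice_into_rank_batches(items, num_batches):
    # Split items into num_batches roughly equal batches, one per MPI rank
    return [items[batch_index::num_batches] for batch_index in range(num_batches)]


work_filenames = [f'pairs-pruned/ab/{i}.dill' for i in range(10)]  # hypothetical work items
size, rank = 3, 1  # hypothetical global MPI world size and this node's rank
work_filename_rank_batches = slice_into_rank_batches(work_filenames, size)
my_work_filenames = work_filename_rank_batches[rank]
print(f'Rank {rank} handles {len(my_work_filenames)} of {len(work_filenames)} pairs')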
/project/datasets/builder/prune_pairs.py:
--------------------------------------------------------------------------------
1 | """
2 | Source code (MIT-Licensed) originally from DIPS (https://github.com/drorlab/DIPS)
3 | """
4 |
5 | import logging
6 | import os
7 |
8 | import atom3.database as db
9 | import click
10 | import numpy as np
11 | import parallel as par
12 |
13 | from project.utils.utils import __load_to_keep_files_into_dataframe, process_pairs_to_keep
14 | from project.utils.utils import get_global_node_rank
15 |
16 |
17 | @click.command()
18 | @click.argument('pair_dir', type=click.Path(exists=True))
19 | @click.argument('to_keep_dir', type=click.Path(exists=True))
20 | @click.argument('output_dir', type=click.Path())
21 | @click.option('--num_cpus', '-c', default=1)
22 | @click.option('--rank', '-r', default=0)
23 | @click.option('--size', '-s', default=1)
24 | def main(pair_dir: str, to_keep_dir: str, output_dir: str, num_cpus: int, rank: int, size: int):
25 | """Run write_pairs on all provided complexes."""
26 | # Reestablish global rank
27 | rank = get_global_node_rank(rank, size)
28 |
29 | # Ensure that this task only gets run on a single node to prevent race conditions
30 | if rank == 0:
31 | logger = logging.getLogger(__name__)
32 | to_keep_filenames = \
33 | db.get_structures_filenames(to_keep_dir, extension='.txt')
34 | if len(to_keep_filenames) == 0:
35 | logger.warning(f'There is no to_keep file in {to_keep_dir}.'
36 | f' All pair files from {pair_dir} will be copied into {output_dir}')
37 |
38 | to_keep_df = __load_to_keep_files_into_dataframe(to_keep_filenames)
39 | logger.info(f'There are {to_keep_df.shape[0]} rows and {to_keep_df.shape[1]} columns in to_keep_df')
40 |
41 | # Get work filenames
42 | logger.info(f'Looking for all pairs in {pair_dir}')
43 | work_filenames = db.get_structures_filenames(pair_dir, extension='.dill')
44 | work_filenames = list(set(work_filenames)) # Remove any duplicate filenames
45 | work_keys = [db.get_pdb_name(x) for x in work_filenames]
46 | logger.info(f'Found {len(work_keys)} pairs in {pair_dir}')
47 |
48 | # Get filenames in which our threads will store output
49 | output_filenames = []
50 | for pdb_filename in work_filenames:
51 | sub_dir = output_dir + '/' + db.get_pdb_code(pdb_filename)[1:3]
52 | if not os.path.exists(sub_dir):
53 | os.mkdir(sub_dir)
54 | output_filenames.append(
55 | sub_dir + '/' + db.get_pdb_name(pdb_filename) + ".dill")
56 |
57 | # Collect thread inputs
58 | n_copied = 0
59 | inputs = [(i, o, to_keep_df) for i, o in zip(work_filenames, output_filenames)]
60 | n_copied += np.sum(par.submit_jobs(process_pairs_to_keep, inputs, num_cpus))
61 | logger.info(f'{n_copied} out of {len(work_keys)} pairs were copied')
62 |
63 |
64 | if __name__ == '__main__':
65 | log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
66 | logging.basicConfig(level=logging.INFO, format=log_fmt)
67 |
68 | main()
69 |
--------------------------------------------------------------------------------
/project/datasets/builder/psaia_chothia.radii:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/builder/psaia_chothia.radii
--------------------------------------------------------------------------------
/project/datasets/builder/psaia_config_file_casp_capri.txt:
--------------------------------------------------------------------------------
1 | analyze_bound: 1
2 | analyze_unbound: 0
3 | calc_asa: 0
4 | z_slice: 0.5
5 | r_solvent: 1.4
6 | write_asa: 0
7 | calc_rasa: 0
8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa
9 | calc_dpx: 0
10 | calc_cx: 1
11 | cx_threshold: 10
12 | cx_volume: 20.1
13 | calc_hydro: 0
14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb
15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii
16 | write_xml: 0
17 | write_table: 1
18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/CASP-CAPRI/interim/external_feats/PSAIA/CASP_CAPRI
--------------------------------------------------------------------------------
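The PSAIA configuration above (and the DB5, DIPS, and EVCoupling variants below) is a flat 'keyword: value' listing. A small sketch of parsing such a file into a dict, assuming a hypothetical local copy of the file:

def read_psaia_config(config_path: str) -> dict:
    # Parse a flat 'keyword: value' PSAIA config file into a dict of strings
    config = {}
    with open(config_path, 'r') as f:
        for line in f:
            if ':' in line:
                key, value = line.split(':', 1)
                config[key.strip()] = value.strip()
    return config


config = read_psaia_config('psaia_config_file_casp_capri.txt')  # hypothetical local copy
print(config['calc_cx'], config['output_dir'])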
/project/datasets/builder/psaia_config_file_db5.txt:
--------------------------------------------------------------------------------
1 | analyze_bound: 0
2 | analyze_unbound: 1
3 | calc_asa: 0
4 | z_slice: 0.25
5 | r_solvent: 1.4
6 | write_asa: 0
7 | calc_rasa: 0
8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa
9 | calc_dpx: 0
10 | calc_cx: 1
11 | cx_threshold: 10
12 | cx_volume: 20.1
13 | calc_hydro: 0
14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb
15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii
16 | write_xml: 0
17 | write_table: 1
18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/DB5/interim/external_feats/PSAIA/DB5
--------------------------------------------------------------------------------
/project/datasets/builder/psaia_config_file_dips.txt:
--------------------------------------------------------------------------------
1 | analyze_bound: 1
2 | analyze_unbound: 0
3 | calc_asa: 0
4 | z_slice: 0.5
5 | r_solvent: 1.4
6 | write_asa: 0
7 | calc_rasa: 0
8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa
9 | calc_dpx: 0
10 | calc_cx: 1
11 | cx_threshold: 10
12 | cx_volume: 20.1
13 | calc_hydro: 0
14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb
15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii
16 | write_xml: 0
17 | write_table: 1
18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/DIPS/interim/external_feats/PSAIA/RCSB
--------------------------------------------------------------------------------
/project/datasets/builder/psaia_config_file_evcoupling.txt:
--------------------------------------------------------------------------------
1 | analyze_bound: 1
2 | analyze_unbound: 0
3 | calc_asa: 0
4 | z_slice: 0.5
5 | r_solvent: 1.4
6 | write_asa: 0
7 | calc_rasa: 0
8 | standard_asa: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_natural_asa.asa
9 | calc_dpx: 0
10 | calc_cx: 1
11 | cx_threshold: 10
12 | cx_volume: 20.1
13 | calc_hydro: 0
14 | hydro_file: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_hydrophobicity.hpb
15 | radii_filename: /home/acmwhb/data/DIPS-Plus/project/datasets/builder/psaia_chothia.radii
16 | write_xml: 0
17 | write_table: 1
18 | output_dir: /home/acmwhb/data/DIPS-Plus/project/datasets/EVCoupling/interim/external_feats/PSAIA/EVCOUPLING
--------------------------------------------------------------------------------
/project/datasets/builder/psaia_hydrophobicity.hpb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/builder/psaia_hydrophobicity.hpb
--------------------------------------------------------------------------------
/project/datasets/builder/psaia_natural_asa.asa:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BioinfoMachineLearning/DIPS-Plus/bed1cd22aa8fcd53b514de094ade21f52e922ddb/project/datasets/builder/psaia_natural_asa.asa
--------------------------------------------------------------------------------
/project/utils/constants.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from Bio.PDB.PDBParser import PDBParser
3 | from Bio.PDB.Polypeptide import CaPPBuilder
4 |
5 | # Cluster-specific limit to the number of compute nodes available to each Slurm job
6 | MAX_NODES_PER_JOB = 4
7 |
8 | # Dataset-global node count limits to restrict computational learning complexity
9 | ATOM_COUNT_LIMIT = 17500 # Default filter for both datasets when encoding complexes at an atom-based level
10 |
11 | # From where we can get bound PDB complexes
12 | RCSB_BASE_URL = 'ftp://ftp.wwpdb.org/pub/pdb/data/biounit/coordinates/divided/'
13 |
14 | # The PDB codes of structures added between DB4 and DB5 (to be used for testing dataset)
15 | DB5_TEST_PDB_CODES = ['3R9A', '4GAM', '3AAA', '4H03', '1EXB',
16 | '2GAF', '2GTP', '3RVW', '3SZK', '4IZ7',
17 | '4GXU', '3BX7', '2YVJ', '3V6Z', '1M27',
18 | '4FQI', '4G6J', '3BIW', '3PC8', '3HI6',
19 | '2X9A', '3HMX', '2W9E', '4G6M', '3LVK',
20 | '1JTD', '3H2V', '4DN4', 'BP57', '3L5W',
21 | '3A4S', 'CP57', '3DAW', '3VLB', '3K75',
22 | '2VXT', '3G6D', '3EO1', '4JCV', '4HX3',
23 | '3F1P', '3AAD', '3EOA', '3MXW', '3L89',
24 | '4M76', 'BAAD', '4FZA', '4LW4', '1RKE',
25 | '3FN1', '3S9D', '3H11', '2A1A', '3P57']
26 |
27 | # Postprocessing logger dictionary
28 | DEFAULT_DATASET_STATISTICS = dict(num_of_processed_complexes=0, num_of_df0_residues=0, num_of_df1_residues=0,
29 | num_of_df0_interface_residues=0, num_of_df1_interface_residues=0,
30 | num_of_pos_res_pairs=0, num_of_neg_res_pairs=0, num_of_res_pairs=0,
31 | num_of_valid_df0_ss_values=0, num_of_valid_df1_ss_values=0,
32 | num_of_valid_df0_rsa_values=0, num_of_valid_df1_rsa_values=0,
33 | num_of_valid_df0_rd_values=0, num_of_valid_df1_rd_values=0,
34 | num_of_valid_df0_protrusion_indices=0, num_of_valid_df1_protrusion_indices=0,
35 | num_of_valid_df0_hsaacs=0, num_of_valid_df1_hsaacs=0,
36 | num_of_valid_df0_cn_values=0, num_of_valid_df1_cn_values=0,
37 | num_of_valid_df0_sequence_feats=0, num_of_valid_df1_sequence_feats=0,
38 | num_of_valid_df0_amide_normal_vecs=0, num_of_valid_df1_amide_normal_vecs=0)
39 |
40 | # Parsing utilities for PDB files (i.e. relevant for sequence and structure analysis)
41 | PDB_PARSER = PDBParser()
42 | CA_PP_BUILDER = CaPPBuilder()
43 |
44 | # Dict for converting three letter codes to one letter codes
45 | D3TO1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
46 | 'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
47 | 'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
48 | 'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}
49 | RES_NAMES_LIST = list(D3TO1.keys())
50 |
51 | # PSAIA features to encode as DataFrame columns
52 | PSAIA_COLUMNS = ['avg_cx', 's_avg_cx', 's_ch_avg_cx', 's_ch_s_avg_cx', 'max_cx', 'min_cx']
53 |
54 | # Constants for calculating half sphere exposure statistics
55 | AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY-'
56 | AMINO_ACID_IDX = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
57 |
58 | # Default fill values for missing features
59 | HSAAC_DIM = 42 # 2 * (20 + 1) HSAAC values: the 20 standard residues plus the unknown residue symbol '-' for each of the two half-spheres
60 | DEFAULT_MISSING_FEAT_VALUE = np.nan
61 | DEFAULT_MISSING_SS = '-'
62 | DEFAULT_MISSING_RSA = DEFAULT_MISSING_FEAT_VALUE
63 | DEFAULT_MISSING_RD = DEFAULT_MISSING_FEAT_VALUE
64 | DEFAULT_MISSING_PROTRUSION_INDEX = [DEFAULT_MISSING_FEAT_VALUE for _ in range(6)]
65 | DEFAULT_MISSING_HSAAC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(HSAAC_DIM)]
66 | DEFAULT_MISSING_CN = DEFAULT_MISSING_FEAT_VALUE
67 | DEFAULT_MISSING_SEQUENCE_FEATS = np.array([DEFAULT_MISSING_FEAT_VALUE for _ in range(27)])
68 | DEFAULT_MISSING_NORM_VEC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(3)]
69 |
70 | # Default number of NaN values allowed in a specific column before imputing missing features of the column with zero
71 | NUM_ALLOWABLE_NANS = 5
72 |
73 | # Features to be one-hot encoded during graph processing and what their values could be
74 | FEAT_COLS = [
75 | 'resname', # By default, leave out one-hot encoding of residues' type to decrease feature redundancy
76 | 'ss_value',
77 | 'rsa_value',
78 | 'rd_value'
79 | ]
80 | FEAT_COLS.extend(
81 | PSAIA_COLUMNS +
82 | ['hsaac',
83 | 'cn_value',
84 | 'sequence_feats',
85 | 'amide_norm_vec',
86 | # 'element' # For atom-level learning only
87 | ])
88 |
89 | ALLOWABLE_FEATS = [
90 | # By default, leave out one-hot encoding of residues' type to decrease feature redundancy
91 | ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR",
92 | "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"],
93 | ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-'], # Populated 1D list means restrict column feature values by list values
94 | [], # Empty list means take scalar value as is
95 | [],
96 | [],
97 | [],
98 | [],
99 | [],
100 | [],
101 | [],
102 | [[]], # Doubly-nested, empty list means take first-level nested list as is
103 | [],
104 | [[]],
105 | [[]],
106 | # ['C', 'O', 'N', 'S'] # For atom-level learning only
107 | ]
108 |
--------------------------------------------------------------------------------
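As the inline comments above spell out, each ALLOWABLE_FEATS entry tells the graph-conversion step how to encode the matching FEAT_COLS column: a populated 1D list means one-hot encode against that vocabulary, an empty list means keep the scalar as is, and a doubly-nested empty list means keep the first-level nested list as is. A hypothetical encoder sketching that convention (not the project's convert_complexes_to_graphs code):

import numpy as np


def encode_feature(value, allowable):
    # Hypothetical encoder following the ALLOWABLE_FEATS convention documented above
    if allowable and isinstance(allowable[0], list):  # [[]]: take the first-level nested list (e.g. hsaac) as is
        return np.asarray(value, dtype=float)
    if allowable:  # populated 1D list: one-hot encode against the allowed values
        one_hot = np.zeros(len(allowable))
        one_hot[allowable.index(value)] = 1.0
        return one_hot
    return np.asarray([float(value)])  # []: take the scalar value as is


ss_vocab = ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-']
print(encode_feature('E', ss_vocab))     # one-hot secondary structure value
print(encode_feature(0.37, []))          # scalar feature (e.g. rsa_value) passed through
print(encode_feature([0.1, 0.2], [[]]))  # nested list feature (e.g. protrusion index) passed through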
/project/utils/training_constants.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code curated for NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 | import numpy as np
5 |
6 | # Dataset-global node count limits to restrict computational learning complexity
7 | ATOM_COUNT_LIMIT = 2048 # Default atom count filter for DIPS-Plus when encoding complexes at an atom-based level
8 | RESIDUE_COUNT_LIMIT = 256 # Default residue count limit for DIPS-Plus (empirically determined for smoother training)
9 | NODE_COUNT_LIMIT = 2304 # An upper bound on the node count for Geometric Transformers - equal to a batch of 9 complexes at the 256-residue limit
10 | KNN = 20 # Default number of nearest neighbors to query for during graph message passing
11 |
12 | # The PDB codes of structures added between DB4 and DB5 (to be used for testing dataset)
13 | DB5_TEST_PDB_CODES = ['3R9A', '4GAM', '3AAA', '4H03', '1EXB',
14 | '2GAF', '2GTP', '3RVW', '3SZK', '4IZ7',
15 | '4GXU', '3BX7', '2YVJ', '3V6Z', '1M27',
16 | '4FQI', '4G6J', '3BIW', '3PC8', '3HI6',
17 | '2X9A', '3HMX', '2W9E', '4G6M', '3LVK',
18 | '1JTD', '3H2V', '4DN4', 'BP57', '3L5W',
19 | '3A4S', 'CP57', '3DAW', '3VLB', '3K75',
20 | '2VXT', '3G6D', '3EO1', '4JCV', '4HX3',
21 | '3F1P', '3AAD', '3EOA', '3MXW', '3L89',
22 | '4M76', 'BAAD', '4FZA', '4LW4', '1RKE',
23 | '3FN1', '3S9D', '3H11', '2A1A', '3P57']
24 |
25 | # Default fill values for missing features
26 | HSAAC_DIM = 42 # 2 * (20 + 1) HSAAC values: the 20 standard residues plus the unknown residue symbol '-' for each of the two half-spheres
27 | DEFAULT_MISSING_FEAT_VALUE = np.nan
28 | DEFAULT_MISSING_SS = '-'
29 | DEFAULT_MISSING_RSA = DEFAULT_MISSING_FEAT_VALUE
30 | DEFAULT_MISSING_RD = DEFAULT_MISSING_FEAT_VALUE
31 | DEFAULT_MISSING_PROTRUSION_INDEX = [DEFAULT_MISSING_FEAT_VALUE for _ in range(6)]
32 | DEFAULT_MISSING_HSAAC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(HSAAC_DIM)]
33 | DEFAULT_MISSING_CN = DEFAULT_MISSING_FEAT_VALUE
34 | DEFAULT_MISSING_SEQUENCE_FEATS = np.array([DEFAULT_MISSING_FEAT_VALUE for _ in range(27)])
35 | DEFAULT_MISSING_NORM_VEC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(3)]
36 |
37 | # Dict for converting three letter codes to one letter codes
38 | D3TO1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
39 | 'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
40 | 'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
41 | 'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}
42 |
43 | # PSAIA features to encode as DataFrame columns
44 | PSAIA_COLUMNS = ['avg_cx', 's_avg_cx', 's_ch_avg_cx', 's_ch_s_avg_cx', 'max_cx', 'min_cx']
45 |
46 | # Features to be one-hot encoded during graph processing and what their values could be
47 | FEAT_COLS = [
48 | 'resname', # [7:27]
49 | 'ss_value', # [27:35]
50 | 'rsa_value', # [35:36]
51 | 'rd_value' # [36:37]
52 | ]
53 | FEAT_COLS.extend(
54 | PSAIA_COLUMNS + # [37:43]
55 | ['hsaac', # [43:85]
56 | 'cn_value', # [85:86]
57 | 'sequence_feats', # [86:113]
58 | 'amide_norm_vec', # [Stored separately]
59 | # 'element' # For atom-level learning only
60 | ])
61 |
62 | ALLOWABLE_FEATS = [
63 | ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR",
64 | "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"],
65 | ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-'], # Populated 1D list means restrict column feature values by list values
66 | [], # Empty list means take scalar value as is
67 | [],
68 | [],
69 | [],
70 | [],
71 | [],
72 | [],
73 | [],
74 | [[]], # Doubly-nested, empty list means take first-level nested list as is
75 | [],
76 | [[]],
77 | [[]],
78 | # ['C', 'O', 'N', 'S'] # For atom-level learning only
79 | ]
80 |
81 | # A schematic of which tensor indices correspond to which node and edge features
82 | FEATURE_INDICES = {
83 | # Node feature indices
84 | 'node_pos_enc': 0,
85 | 'node_geo_feats_start': 1,
86 | 'node_geo_feats_end': 7,
87 | 'node_dips_plus_feats_start': 7,
88 | 'node_dips_plus_feats_end': 113,
89 | # Edge feature indices
90 | 'edge_pos_enc': 0,
91 | 'edge_weights': 1,
92 | 'edge_dist_feats_start': 2,
93 | 'edge_dist_feats_end': 20,
94 | 'edge_dir_feats_start': 20,
95 | 'edge_dir_feats_end': 23,
96 | 'edge_orient_feats_start': 23,
97 | 'edge_orient_feats_end': 27,
98 | 'edge_amide_angles': 27
99 | }
100 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | norecursedirs =
3 | .git
4 | dist
5 | build
6 | addopts =
7 | --strict
8 | --doctest-modules
9 | --durations=0
10 |
11 | [coverage:report]
12 | exclude_lines =
13 | pragma: no-cover
14 | pass
15 |
16 | [flake8]
17 | max-line-length = 120
18 | exclude = .tox,*.egg,build,temp
19 | select = E,W,F
20 | doctests = True
21 | verbose = 2
22 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
23 | format = pylint
24 | # see: https://www.flake8rules.com/
25 | ignore =
26 | E731 # Do not assign a lambda expression, use a def
27 | W504 # Line break occurred after a binary operator
28 | F401 # Module imported but unused
29 | F841 # Local variable name is assigned to but never used
30 | W605 # Invalid escape sequence 'x'
31 |
32 | # setup.cfg or tox.ini
33 | [check-manifest]
34 | ignore =
35 | *.yml
36 | .github
37 | .github/*
38 |
39 | [metadata]
40 | # license_file = LICENSE
41 | # description-file = README.md
42 | long_description = file:README.md
43 | long_description_content_type = text/markdown
44 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from setuptools import setup, find_packages
4 |
5 | setup(
6 | name='DIPS-Plus',
7 | version='1.2.0',
8 | description='The Enhanced Database of Interacting Protein Structures for Interface Prediction',
9 | author='Alex Morehead',
10 | author_email='acmwhb@umsystem.edu',
11 | url='https://github.com/BioinfoMachineLearning/DIPS-Plus',
12 | install_requires=[
13 | 'setuptools==65.5.1',
14 | 'dill==0.3.3',
15 | 'tqdm==4.49.0',
16 | 'Sphinx==4.0.1',
17 | 'easy-parallel-py3==0.1.6.4',
18 | 'click==7.0.0',
19 | ],
20 | packages=find_packages(),
21 | )
22 |
--------------------------------------------------------------------------------