├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Explore_Propedia.ipynb ├── Explore_ProtCID-updated.ipynb ├── Filter_out_protcid_large_complex.ipynb ├── LICENSE ├── Preprocess_PDBBind_full_complexes.ipynb ├── README.md ├── THIRD-PARTY-NOTICE ├── ablation_pdbbind.sh ├── ablation_protcid.sh ├── e2e_hgvp_experiments.sh ├── eval_casf.sh ├── evaluate.py ├── evaluate_casf2016.py ├── figs ├── GoGs_of_complexes.png └── GoGs_of_molecules.png ├── h_gvp_experiments.sh ├── h_gvp_experiments_bin_clf.sh ├── integration_test.sh ├── multistage_experiments.sh ├── ppi ├── __init__.py ├── data.py ├── data_utils │ ├── __init__.py │ ├── camp │ │ ├── crawl.py │ │ ├── query-mapping.py │ │ ├── step1_pdb_process.py │ │ ├── step2_pepBDB_pep_bindingsites.py │ │ ├── step3_iupred2a.py │ │ ├── step3_pssm.py │ │ ├── step3_ss.py │ │ ├── step4_agg_files.py │ │ └── step5_PreProcessFeatures.py │ ├── contact_map_utils.py │ ├── pignet_featurizers.py │ ├── polypeptide_featurizers.py │ ├── residue_featurizers.py │ └── xpdb.py ├── gvp.py ├── model.py ├── modules.py └── transfer.py ├── preprocess_diffdock_output.py ├── requirements.txt ├── setup_env.sh ├── test.txt ├── test_run.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EGGNet: Equivariant Graph-of-Graphs Neural Network 2 | 3 | Source code for "[EGGNet, a generalizable geometric deep learning framework for protein complex pose scoring](https://www.biorxiv.org/content/10.1101/2023.03.22.533800v1)" 4 | 5 | 6 | 7 | ## Dependencies 8 | 9 | All experiments were performed in Python 3.8 with Pytorch (v1.10). 10 | 11 | To install all dependencies run: 12 | ``` 13 | $ pip install dgl-cu111 dglgo -f https://data.dgl.ai/wheels/repo.html 14 | $ pip install -r requirements.txt 15 | ``` 16 | 17 | 18 | ## Data preparation 19 | 20 | PDBbind/CASF-2016 data can be downloaded using [the script](https://github.com/ACE-KAIST/PIGNet/blob/main/data/download_train_data.sh) from the PIGNet repository. 
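For example, the download script can be fetched and run as follows (a minimal sketch; it assumes `wget` is available and that the raw-file URL inferred from the link above is still current):
```
wget https://raw.githubusercontent.com/ACE-KAIST/PIGNet/main/data/download_train_data.sh
bash download_train_data.sh
```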
21 | The included Python notebooks can be used as a guide for data preparation, either to reproduce our results or to train on new datasets. The following command was used to download the DC and MANY data from DeepRank: `rsync -av rsync://data.sbgrid.org/10.15785/SBGRID/843`. Note that the whole download is about 500 GB. A script (`prep_eggnet_data_protcid_model.py`) prepares ProtCID-like data for classification tasks. The script is run twice, once for each label, and takes a directory of PDB structures as input. 22 | ``` 23 | prep_eggnet_data_protcid_model.py --dir many_xtal --label 0 --threshold 12 --skip_filter 24 | prep_eggnet_data_protcid_model.py --dir many_bio --label 1 --threshold 12 --skip_filter --datafile processed/train_full.csv 25 | ``` 26 | 27 | Note that datasets are hard-coded, so if you have a new dataset that is very different from what EGGNet was trained on, you will need to modify the code to add a new dataset. 28 | 29 | 30 | ## Training 31 | 32 | Training of EGGNet and competing models for protein complex scoring tasks can be done in `train.py`, which utilizes the [PyTorch Lightning Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#). All of the [trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags) in PyTorch Lightning are supported. To see the usage, run: 33 | 34 | ``` 35 | $ python train.py -h 36 | usage: train.py [-h] [--logger [LOGGER]] [--enable_checkpointing [ENABLE_CHECKPOINTING]] [--default_root_dir DEFAULT_ROOT_DIR] [--gradient_clip_val GRADIENT_CLIP_VAL] 37 | [--gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM] [--num_nodes NUM_NODES] [--num_processes NUM_PROCESSES] [--devices DEVICES] [--gpus GPUS] [--auto_select_gpus [AUTO_SELECT_GPUS]] 38 | [--tpu_cores TPU_CORES] [--ipus IPUS] [--enable_progress_bar [ENABLE_PROGRESS_BAR]] [--overfit_batches OVERFIT_BATCHES] [--track_grad_norm TRACK_GRAD_NORM] 39 | [--check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH] [--fast_dev_run [FAST_DEV_RUN]] [--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES] [--max_epochs MAX_EPOCHS] 40 | [--min_epochs MIN_EPOCHS] [--max_steps MAX_STEPS] [--min_steps MIN_STEPS] [--max_time MAX_TIME] [--limit_train_batches LIMIT_TRAIN_BATCHES] [--limit_val_batches LIMIT_VAL_BATCHES] 41 | [--limit_test_batches LIMIT_TEST_BATCHES] [--limit_predict_batches LIMIT_PREDICT_BATCHES] [--val_check_interval VAL_CHECK_INTERVAL] [--log_every_n_steps LOG_EVERY_N_STEPS] 42 | [--accelerator ACCELERATOR] [--strategy STRATEGY] [--sync_batchnorm [SYNC_BATCHNORM]] [--precision PRECISION] [--enable_model_summary [ENABLE_MODEL_SUMMARY]] 43 | [--weights_save_path WEIGHTS_SAVE_PATH] [--num_sanity_val_steps NUM_SANITY_VAL_STEPS] [--resume_from_checkpoint RESUME_FROM_CHECKPOINT] [--profiler PROFILER] [--benchmark [BENCHMARK]] 44 | [--deterministic [DETERMINISTIC]] [--reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS] [--auto_lr_find [AUTO_LR_FIND]] [--replace_sampler_ddp [REPLACE_SAMPLER_DDP]] 45 | [--detect_anomaly [DETECT_ANOMALY]] [--auto_scale_batch_size [AUTO_SCALE_BATCH_SIZE]] [--plugins PLUGINS] [--amp_backend AMP_BACKEND] [--amp_level AMP_LEVEL] 46 | [--move_metrics_to_cpu [MOVE_METRICS_TO_CPU]] [--multiple_trainloader_mode MULTIPLE_TRAINLOADER_MODE] [--model_name MODEL_NAME] 47 | 48 | optional arguments: 49 | -h, --help show this help message and exit 50 | --model_name MODEL_NAME 51 | Choose from gvp, hgvp, multistage-gvp, multistage-hgvp 52 | 53 | pl.Trainer: 54 | --logger [LOGGER] Logger (or iterable collection of loggers) for experiment
tracking. A ``True`` value uses the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers 55 | # other pl.Trainer flags... 56 | ``` 57 | 58 | Training scripts for ProtCID and PDBbind can be found in `ablation_protcid.sh` and `ablation_pdbbind.sh`. 59 | An example of joint training on ProtCID-like data with the GIN featurizer is shown below. 60 | 61 | ``` 62 | n_gpus=4 63 | num_workers=8 64 | 65 | suffix=full 66 | 67 | residue_featurizer_name=gin-supervised-contextpred-mean # to change this to pretrained GNN residue featurizer 68 | dataset_name=ProtCID 69 | bs=16 70 | lr=1e-4 71 | max_epochs=1000 72 | early_stopping_patience=50 73 | seed=42 74 | 75 | node_h_dim=200\ 32 76 | edge_h_dim=64\ 2 77 | num_layers=3 78 | crop=12 79 | 80 | data_dir=/home/ec2-user/SageMaker/eggnet-equivariant-graph-of-graph-neural-network/crop_${crop}_no_filter 81 | root_dir=/home/ec2-user/SageMaker/eggnet_training_results/crop${crop} 82 | 83 | # 3: pretrained GNN joint training GVP None 84 | python train.py --accelerator gpu \ 85 | --model_name hgvp \ 86 | --devices $n_gpus \ 87 | --num_workers 16 \ 88 | --precision 32 \ 89 | --dataset_name $dataset_name \ 90 | --input_type complex \ 91 | --residue_featurizer_name $residue_featurizer_name-grad \ 92 | --data_dir $data_dir \ 93 | --data_suffix $suffix \ 94 | --bs $bs \ 95 | --lr $lr \ 96 | --max_epochs $max_epochs \ 97 | --early_stopping_patience $early_stopping_patience \ 98 | --residual \ 99 | --node_h_dim $node_h_dim \ 100 | --edge_h_dim $edge_h_dim \ 101 | --num_layers $num_layers \ 102 | --default_root_dir ${root_dir}/3_ProtCID_t6_small_HGVP_GIN \ 103 | --random_seed $seed 104 | ``` 105 | 106 | ## Evaluation 107 | EGGNet is dataset-centric, so all inputs need to be prepared with the notebooks or scripts described above. Once the data are prepared, an example evaluation command is: 108 | ``` 109 | python evaluate.py --checkpoint_path ../eggnet_training_results/crop12/6_ProtCID_Molt5-small/lightning_logs/version_0 --evaluate_type classification --dataset_name ProtCID --input_type complex --data_suffix full --data_dir /home/ec2-user/SageMaker/eggnet-equivariant-graph-of-graph-neural-network/crop_12_no_filter --residue_featurizer_name MolT5-small-grad --model_name hgvp --num_workers 8 --bs 4 --dataset_alias protcid_test 110 | ``` 111 | 112 | ## Citation 113 | 114 | Please cite the following preprint: 115 | ``` 116 | @article {Wang2023.03.22.533800, 117 | author = {Wang, Zichen and Brand, Ryan and Adolf-Bryfogle, Jared and Grewal, Jasleen and Qi, Yanjun and Combs, Steven A. and Golovach, Nataliya and Alford, Rebecca and Rangwala, Huzefa and Clark, Peter M.}, 118 | title = {EGGNet, a generalizable geometric deep learning framework for protein complex pose scoring}, 119 | elocation-id = {2023.03.22.533800}, 120 | year = {2023}, 121 | doi = {10.1101/2023.03.22.533800}, 122 | publisher = {Cold Spring Harbor Laboratory}, 123 | abstract = {Computational prediction of molecule-protein interactions has been key for developing new molecules to interact with a target protein for therapeutics development. Literature includes two independent streams of approaches: (1) predicting protein-protein interactions between naturally occurring proteins and (2) predicting the binding affinities between proteins and small molecule ligands (aka drug target interaction, or DTI). Studying the two problems in isolation has limited computational models{\textquoteright} ability to generalize across tasks, both of which ultimately involve non-covalent interactions with a protein target.
In this work, we developed Equivariant Graph of Graphs neural Network (EGGNet), a geometric deep learning framework for molecule-protein binding predictions that can handle three types of molecules for interacting with a target protein: (1) small molecules, (2) synthetic peptides and (3) natural proteins. EGGNet leverages a graph of graphs (GoGs) representation constructed from the molecule structures at atomic-resolution and utilizes a multiresolution equivariant graph neural network (GNN) to learn from such representations. In addition, EGGNet gets inspired by biophysics and makes use of both atom- and residue-level interactions, which greatly improve EGGNet{\textquoteright}s ability to rank candidate poses from blind docking. EGGNet achieves competitive performance on both a public proteinsmall molecule binding affinity prediction task (80.2\% top-1 success rate on CASF-2016) and an synthetic protein interface prediction task (88.4\% AUPR). We envision that the proposed geometric deep learning framework can generalize to many other protein interaction prediction problems, such as binding site prediction and molecular docking, helping to accelerate protein engineering and structure-based drug development.Competing Interest StatementThe authors have declared no competing interest.}, 124 | URL = {https://www.biorxiv.org/content/early/2023/03/22/2023.03.22.533800}, 125 | eprint = {https://www.biorxiv.org/content/early/2023/03/22/2023.03.22.533800.full.pdf}, 126 | journal = {bioRxiv} 127 | } 128 | ``` 129 | 130 | ## Security 131 | 132 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 133 | 134 | ## License 135 | 136 | This library is licensed under the MIT-0 License. See the LICENSE file. 137 | -------------------------------------------------------------------------------- /ablation_pdbbind.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | # ablation studies for the GVP models on PDBBind regression task 6 | # global variables shared across all runs: 7 | n_gpus=4 8 | num_workers=8 9 | pdbbind_data=/home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring 10 | residue_featurizer_name=gin-supervised-contextpred-mean # to change this to pretrained GNN residue featurizer 11 | dataset_name=PDBBind 12 | bs=16 13 | lr=1e-4 14 | max_epochs=1000 15 | early_stopping_patience=50 16 | seed=42 17 | 18 | node_h_dim=200\ 32 19 | edge_h_dim=64\ 2 20 | num_layers=3 21 | 22 | # row2: pretrained GNN GVP None 23 | for seed in 43 44; do 24 | python train.py --accelerator gpu \ 25 | --model_name gvp \ 26 | --devices $n_gpus \ 27 | --num_workers $num_workers \ 28 | --persistent_workers True \ 29 | --precision 16 \ 30 | --dataset_name $dataset_name \ 31 | --input_type complex \ 32 | --residue_featurizer_name $residue_featurizer_name \ 33 | --data_dir $pdbbind_data \ 34 | --bs $bs \ 35 | --lr $lr \ 36 | --max_epochs $max_epochs \ 37 | --early_stopping_patience $early_stopping_patience \ 38 | --residual \ 39 | --node_h_dim $node_h_dim \ 40 | --edge_h_dim $edge_h_dim \ 41 | --num_layers $num_layers \ 42 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_GIN \ 43 | --random_seed $seed 44 | done 45 | 46 | # row4: pretrained GNN joint training GVP None 47 | for seed in 43 44; do 48 | CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py --accelerator gpu \ 49 | --model_name hgvp \ 50 | --devices $n_gpus \ 51 | 
--num_workers $num_workers \ 52 | --precision 32 \ 53 | --dataset_name $dataset_name \ 54 | --input_type complex \ 55 | --residue_featurizer_name $residue_featurizer_name-grad \ 56 | --data_dir $pdbbind_data \ 57 | --bs $bs \ 58 | --lr $lr \ 59 | --max_epochs $max_epochs \ 60 | --early_stopping_patience $early_stopping_patience \ 61 | --residual \ 62 | --node_h_dim $node_h_dim \ 63 | --edge_h_dim $edge_h_dim \ 64 | --num_layers $num_layers \ 65 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_HGVP_GIN \ 66 | --random_seed $seed 67 | done 68 | 69 | # row6: pretrained GNN GVP E_int 70 | for seed in 43 44; do 71 | python train.py --accelerator gpu \ 72 | --model_name gvp \ 73 | --devices $n_gpus \ 74 | --num_workers $num_workers \ 75 | --persistent_workers True \ 76 | --precision 16 \ 77 | --dataset_name $dataset_name \ 78 | --input_type complex \ 79 | --residue_featurizer_name $residue_featurizer_name \ 80 | --use_energy_decoder \ 81 | --is_hetero \ 82 | --data_dir $pdbbind_data \ 83 | --bs $bs \ 84 | --lr $lr \ 85 | --max_epochs $max_epochs \ 86 | --early_stopping_patience $early_stopping_patience \ 87 | --residual \ 88 | --node_h_dim $node_h_dim \ 89 | --edge_h_dim $edge_h_dim \ 90 | --num_layers $num_layers \ 91 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_GIN_energy \ 92 | --random_seed $seed 93 | done 94 | 95 | seed=42 96 | num_layers=6 97 | python train.py --accelerator gpu \ 98 | --model_name gvp \ 99 | --devices $n_gpus \ 100 | --num_workers $num_workers \ 101 | --persistent_workers True \ 102 | --precision 16 \ 103 | --dataset_name $dataset_name \ 104 | --input_type complex \ 105 | --residue_featurizer_name $residue_featurizer_name \ 106 | --use_energy_decoder \ 107 | --is_hetero \ 108 | --data_dir $pdbbind_data \ 109 | --bs $bs \ 110 | --lr $lr \ 111 | --max_epochs $max_epochs \ 112 | --early_stopping_patience $early_stopping_patience \ 113 | --residual \ 114 | --node_h_dim $node_h_dim \ 115 | --edge_h_dim $edge_h_dim \ 116 | --num_layers $num_layers \ 117 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_GIN_energy \ 118 | --random_seed $seed 119 | 120 | # row8: pretrained GNN joint training GVP E_int 121 | n_gpus=8 122 | bs=8 123 | num_workers=8 124 | for seed in 43 44; do 125 | python train.py --accelerator gpu \ 126 | --model_name hgvp \ 127 | --devices $n_gpus \ 128 | --num_workers $num_workers \ 129 | --precision 32 \ 130 | --dataset_name $dataset_name \ 131 | --input_type complex \ 132 | --residue_featurizer_name $residue_featurizer_name-grad \ 133 | --use_energy_decoder \ 134 | --is_hetero \ 135 | --data_dir $pdbbind_data \ 136 | --bs $bs \ 137 | --lr $lr \ 138 | --max_epochs $max_epochs \ 139 | --early_stopping_patience $early_stopping_patience \ 140 | --residual \ 141 | --node_h_dim $node_h_dim \ 142 | --edge_h_dim $edge_h_dim \ 143 | --num_layers $num_layers \ 144 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_HGVP_GIN_energy \ 145 | --random_seed $seed 146 | done 147 | 148 | ## Evaluation 149 | eval_data_dir=/home/ec2-user/SageMaker/efs/data/PIGNet/data/casf2016_processed 150 | 151 | # row2: pretrained GNN GVP None 152 | python evaluate_casf2016.py --model_name gvp \ 153 | --num_workers 8 \ 154 | --data_dir $eval_data_dir \ 155 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_GIN/lightning_logs/version_2 \ 156 | --residue_featurizer_name $residue_featurizer_name 157 | 158 | # row4: pretrained GNN joint training GVP None 
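# (the lightning_logs/version_* directories used below are specific to our local training runs; point --checkpoint_path at your own checkpoint directory)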
159 | python evaluate_casf2016.py --model_name hgvp \ 160 | --num_workers 8 \ 161 | --data_dir $eval_data_dir \ 162 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_HGVP_GIN/lightning_logs/version_3 \ 163 | --residue_featurizer_name $residue_featurizer_name-grad 164 | 165 | 166 | # row6: pretrained GNN GVP E_int 167 | python evaluate_casf2016.py --model_name gvp \ 168 | --num_workers 8 \ 169 | --data_dir $eval_data_dir \ 170 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_GIN_energy/lightning_logs/version_4 \ 171 | --residue_featurizer_name $residue_featurizer_name \ 172 | --use_energy_decoder \ 173 | --is_hetero \ 174 | --bs 16 175 | 176 | # row8: pretrained GNN joint training GVP E_int 177 | python evaluate_casf2016.py --model_name hgvp \ 178 | --num_workers 8 \ 179 | --data_dir $eval_data_dir \ 180 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_HGVP_GIN_energy/lightning_logs/version_5 \ 181 | --residue_featurizer_name $residue_featurizer_name-grad \ 182 | --use_energy_decoder \ 183 | --is_hetero \ 184 | --bs 16 -------------------------------------------------------------------------------- /ablation_protcid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | n_gpus=4 6 | num_workers=8 7 | # data_dir=/home/ec2-user/SageMaker/efs/data/ProtCID/JaredJanssen_Benchmark_thres_10 8 | data_dir=/home/ec2-user/SageMaker/efs/data/ProtCID/JaredJanssen_Benchmark_thres_6 9 | residue_featurizer_name=gin-supervised-contextpred-mean # to change this to pretrained GNN residue featurizer 10 | dataset_name=ProtCID 11 | bs=16 12 | lr=1e-4 13 | max_epochs=1000 14 | early_stopping_patience=50 15 | seed=42 16 | 17 | node_h_dim=200\ 32 18 | edge_h_dim=64\ 2 19 | num_layers=3 20 | 21 | # row2: pretrained GNN GVP None 22 | python train.py --accelerator gpu \ 23 | --model_name gvp \ 24 | --devices $n_gpus \ 25 | --num_workers $num_workers \ 26 | --persistent_workers True \ 27 | --precision 16 \ 28 | --dataset_name $dataset_name \ 29 | --input_type complex \ 30 | --residue_featurizer_name $residue_featurizer_name \ 31 | --data_dir $data_dir \ 32 | --data_suffix small_filt1e5 \ 33 | --bs $bs \ 34 | --lr $lr \ 35 | --max_epochs $max_epochs \ 36 | --early_stopping_patience $early_stopping_patience \ 37 | --residual \ 38 | --node_h_dim $node_h_dim \ 39 | --edge_h_dim $edge_h_dim \ 40 | --num_layers $num_layers \ 41 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_t6_small_GVP_GIN \ 42 | --random_seed $seed 43 | 44 | # row3: pretrained GNN MS-GVP None 45 | python train.py --accelerator gpu \ 46 | --model_name multistage-gvp \ 47 | --devices $n_gpus \ 48 | --num_workers $num_workers \ 49 | --persistent_workers True \ 50 | --precision 16 \ 51 | --dataset_name $dataset_name \ 52 | --input_type multistage-complex \ 53 | --residue_featurizer_name $residue_featurizer_name \ 54 | --data_dir $data_dir \ 55 | --data_suffix small_filt1e5 \ 56 | --bs $bs \ 57 | --lr $lr \ 58 | --max_epochs $max_epochs \ 59 | --early_stopping_patience $early_stopping_patience \ 60 | --residual \ 61 | --stage1_node_h_dim $node_h_dim \ 62 | --stage1_edge_h_dim $edge_h_dim \ 63 | --stage1_num_layers $num_layers \ 64 | --stage2_node_h_dim $node_h_dim \ 65 | --stage2_edge_h_dim $edge_h_dim \ 66 | --stage2_num_layers $num_layers \ 67 | --default_root_dir 
/home/ec2-user/SageMaker/efs/model_logs/brandry/ProtCID_t6_small_MS-GVP_GIN \ 68 | --random_seed $seed 69 | 70 | # row4: pretrained GNN joint training GVP None 71 | python train.py --accelerator gpu \ 72 | --model_name hgvp \ 73 | --devices $n_gpus \ 74 | --num_workers 16 \ 75 | --precision 32 \ 76 | --dataset_name $dataset_name \ 77 | --input_type complex \ 78 | --residue_featurizer_name $residue_featurizer_name-grad \ 79 | --data_dir $data_dir \ 80 | --data_suffix small_filt1e5 \ 81 | --bs $bs \ 82 | --lr $lr \ 83 | --max_epochs $max_epochs \ 84 | --early_stopping_patience $early_stopping_patience \ 85 | --residual \ 86 | --node_h_dim $node_h_dim \ 87 | --edge_h_dim $edge_h_dim \ 88 | --num_layers $num_layers \ 89 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_t6_small_HGVP_GIN \ 90 | --random_seed $seed 91 | 92 | python train.py --accelerator gpu \ 93 | --model_name hgvp \ 94 | --devices $n_gpus \ 95 | --num_workers 16 \ 96 | --precision 32 \ 97 | --dataset_name $dataset_name \ 98 | --input_type complex \ 99 | --residue_featurizer_name $residue_featurizer_name-grad \ 100 | --data_dir $data_dir \ 101 | --data_suffix full_filt1e5 \ 102 | --bs $bs \ 103 | --lr $lr \ 104 | --max_epochs $max_epochs \ 105 | --early_stopping_patience $early_stopping_patience \ 106 | --residual \ 107 | --node_h_dim $node_h_dim \ 108 | --edge_h_dim $edge_h_dim \ 109 | --num_layers $num_layers \ 110 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_t6_full_HGVP_GIN \ 111 | --random_seed $seed 112 | 113 | # row5: pretrained GNN joint training MS-GVP None 114 | python train.py --accelerator gpu \ 115 | --model_name multistage-hgvp \ 116 | --devices $n_gpus \ 117 | --num_workers $num_workers \ 118 | --precision 32 \ 119 | --dataset_name $dataset_name \ 120 | --input_type multistage-complex \ 121 | --residue_featurizer_name $residue_featurizer_name-grad \ 122 | --data_dir $data_dir \ 123 | --data_suffix small_filt1e5 \ 124 | --bs $bs \ 125 | --lr $lr \ 126 | --max_epochs $max_epochs \ 127 | --early_stopping_patience $early_stopping_patience \ 128 | --residual \ 129 | --stage1_node_h_dim $node_h_dim \ 130 | --stage1_edge_h_dim $edge_h_dim \ 131 | --stage1_num_layers $num_layers \ 132 | --stage2_node_h_dim $node_h_dim \ 133 | --stage2_edge_h_dim $edge_h_dim \ 134 | --stage2_num_layers $num_layers \ 135 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/brandry/ProtCID_t6_small_MS-HGVP_GIN \ 136 | --random_seed $seed 137 | 138 | # row6: pretrained GNN GVP E_int 139 | bs=4 140 | lr=1e-4 141 | n_gpus=4 142 | python train.py --accelerator gpu \ 143 | --model_name gvp \ 144 | --devices $n_gpus \ 145 | --num_workers $num_workers \ 146 | --persistent_workers True \ 147 | --precision 16 \ 148 | --dataset_name $dataset_name \ 149 | --input_type complex \ 150 | --residue_featurizer_name $residue_featurizer_name \ 151 | --use_energy_decoder \ 152 | --is_hetero \ 153 | --data_dir $data_dir \ 154 | --data_suffix small_filt1e5 \ 155 | --bs $bs \ 156 | --lr $lr \ 157 | --max_epochs $max_epochs \ 158 | --early_stopping_patience $early_stopping_patience \ 159 | --residual \ 160 | --node_h_dim $node_h_dim \ 161 | --edge_h_dim $edge_h_dim \ 162 | --num_layers $num_layers \ 163 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_t6_small_GVP_GIN_energy \ 164 | --random_seed $seed \ 165 | --loss_der1_ratio 0 \ 166 | --loss_der2_ratio 0 167 | 168 | # row6: with final energy bias 169 | python train.py --accelerator gpu \ 170 | --model_name gvp \ 171 | 
--devices $n_gpus \ 172 | --num_workers $num_workers \ 173 | --persistent_workers True \ 174 | --precision 16 \ 175 | --dataset_name $dataset_name \ 176 | --input_type complex \ 177 | --residue_featurizer_name $residue_featurizer_name \ 178 | --use_energy_decoder \ 179 | --is_hetero \ 180 | --data_dir $data_dir \ 181 | --data_suffix small_filt1e5 \ 182 | --bs $bs \ 183 | --lr $lr \ 184 | --max_epochs $max_epochs \ 185 | --early_stopping_patience $early_stopping_patience \ 186 | --residual \ 187 | --node_h_dim $node_h_dim \ 188 | --edge_h_dim $edge_h_dim \ 189 | --num_layers $num_layers \ 190 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_t6_small_GVP_GIN_energy \ 191 | --random_seed $seed \ 192 | --loss_der1_ratio 0 \ 193 | --loss_der2_ratio 0 \ 194 | --final_energy_bias 195 | 196 | # row6: smaller network 197 | bs=4 198 | lr=1e-3 199 | n_gpus=8 200 | python train.py --accelerator gpu \ 201 | --model_name gvp \ 202 | --devices $n_gpus \ 203 | --num_workers $num_workers \ 204 | --persistent_workers True \ 205 | --precision 16 \ 206 | --dataset_name $dataset_name \ 207 | --input_type complex \ 208 | --residue_featurizer_name $residue_featurizer_name \ 209 | --use_energy_decoder \ 210 | --is_hetero \ 211 | --data_dir $data_dir \ 212 | --data_suffix small_filt1e6 \ 213 | --bs $bs \ 214 | --lr $lr \ 215 | --max_epochs $max_epochs \ 216 | --early_stopping_patience $early_stopping_patience \ 217 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_t10_small_GVP_GIN_energy \ 218 | --random_seed $seed \ 219 | --loss_der1_ratio 0 \ 220 | --loss_der2_ratio 0 221 | 222 | -------------------------------------------------------------------------------- /e2e_hgvp_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | python train.py --model_name hgvp \ 6 | --accelerator gpu \ 7 | --devices 1 \ 8 | --max_epochs 500 \ 9 | --precision 32 \ 10 | --num_layers 3 \ 11 | --node_h_dim 200 32 \ 12 | --edge_h_dim 64 2 \ 13 | --dataset_name PDBBind \ 14 | --input_type complex \ 15 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019_processed/scoring \ 16 | --residual \ 17 | --num_workers 8 \ 18 | --lr 1e-4 \ 19 | --bs 8 \ 20 | --early_stopping_patience 10 \ 21 | --residue_featurizer_name MolT5-small-grad \ 22 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_MolT5_grad 23 | 24 | python train.py --model_name hgvp \ 25 | --accelerator gpu \ 26 | --devices 4 \ 27 | --max_epochs 500 \ 28 | --precision 32 \ 29 | --num_layers 3 \ 30 | --node_h_dim 200 32 \ 31 | --edge_h_dim 64 2 \ 32 | --dataset_name PDBBind \ 33 | --input_type complex \ 34 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019_processed/scoring \ 35 | --residual \ 36 | --num_workers 8 \ 37 | --lr 1e-4 \ 38 | --bs 4 \ 39 | --early_stopping_patience 200 \ 40 | --residue_featurizer_name MolT5-small-grad \ 41 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_MolT5_grad 42 | 43 | python evaluate_casf2016.py --model_name hgvp \ 44 | --num_workers 8 \ 45 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/casf2016_processed \ 46 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_GVP_MolT5_grad/lightning_logs/version_23 \ 47 | --residue_featurizer_name MolT5-small-grad 48 | 49 | 50 | ## PDBBind bin-clf 51 | python train.py --model_name hgvp \ 
52 | --accelerator gpu \ 53 | --devices 4 \ 54 | --max_epochs 500 \ 55 | --precision 32 \ 56 | --num_layers 3 \ 57 | --node_h_dim 200 32 \ 58 | --edge_h_dim 64 2 \ 59 | --dataset_name PDBBind \ 60 | --input_type complex \ 61 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019_processed/scoring \ 62 | --residual \ 63 | --num_workers 8 \ 64 | --lr 1e-5 \ 65 | --bs 4 \ 66 | --early_stopping_patience 10 \ 67 | --residue_featurizer_name MolT5-small-grad \ 68 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_bin_GVP_MolT5_grad \ 69 | --binary_cutoff 6.7 70 | -------------------------------------------------------------------------------- /eval_casf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | python evaluate_casf2016.py \ 6 | --model_name multistage-hgvp \ 7 | --input_type multistage-hetero \ 8 | --residue_featurizer_name MolT5-small-grad \ 9 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/brandry/PDBBind_MS-HGVP_hetero_energy/lightning_logs/version_9 \ 10 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/casf2016_processed/ \ 11 | --num_workers 8 \ 12 | --bs 16 \ 13 | --is_hetero \ 14 | --use_energy_decoder -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | Evaluate a trained pytorch-lightning model on a given dataset. 5 | """ 6 | from train import ( 7 | evaluate_graph_classification, 8 | get_datasets, 9 | evaluate_graph_regression, 10 | MODEL_CONSTRUCTORS, 11 | ) 12 | 13 | import pytorch_lightning as pl 14 | from torch.utils.data import DataLoader 15 | 16 | import argparse 17 | import os 18 | import json 19 | from pprint import pprint 20 | 21 | 22 | def load_model_from_checkpoint( 23 | checkpoint_path: str, model_name: str, classify=False 24 | ) -> pl.LightningModule: 25 | """Load a ptl model from checkpoint path. 26 | Args: 27 | checkpoint_path: the path to `lightning_logs/version_x` or 28 | the .ckpt file itself. 29 | model_name: should be a key in `MODEL_CONSTRUCTORS` 30 | """ 31 | if not checkpoint_path.endswith(".ckpt"): 32 | # find the .ckpt file 33 | ckpt_file = os.listdir(os.path.join(checkpoint_path, "checkpoints"))[0] 34 | ckpt_file_path = os.path.join( 35 | checkpoint_path, "checkpoints", ckpt_file 36 | ) 37 | else: 38 | ckpt_file_path = checkpoint_path 39 | # load the model from checkpoint 40 | ModelConstructor = MODEL_CONSTRUCTORS[model_name] 41 | model = ModelConstructor.load_from_checkpoint( 42 | ckpt_file_path, strict=False, classify=classify 43 | ) 44 | return model 45 | 46 | 47 | def main(args): 48 | pl.seed_everything(42, workers=True) 49 | # 1. Load data 50 | test_dataset = get_datasets( 51 | name=args.dataset_name, 52 | input_type=args.input_type, 53 | data_dir=args.data_dir, 54 | residue_featurizer_name=args.residue_featurizer_name, 55 | use_energy_decoder=args.use_energy_decoder, 56 | data_suffix=args.data_suffix, 57 | binary_cutoff=args.binary_cutoff, 58 | test_only=True, 59 | ) 60 | print( 61 | "Data loaded:", 62 | len(test_dataset), 63 | ) 64 | # 2. 
Prepare data loaders 65 | test_loader = DataLoader( 66 | test_dataset, 67 | batch_size=args.bs, 68 | shuffle=False, 69 | num_workers=args.num_workers, 70 | collate_fn=test_dataset.collate_fn, 71 | ) 72 | # 3. Prepare model 73 | classify = args.evaluate_type == "classification" 74 | model = load_model_from_checkpoint( 75 | args.checkpoint_path, 76 | args.model_name, 77 | classify=classify, 78 | ) 79 | # 4. Evaluate 80 | if not classify: 81 | eval_func = evaluate_graph_regression 82 | else: 83 | eval_func = evaluate_graph_classification 84 | 85 | scores = eval_func( 86 | model, 87 | test_loader, 88 | model_name=args.model_name, 89 | use_energy_decoder=args.use_energy_decoder, 90 | is_hetero=args.is_hetero, 91 | ) 92 | pprint(scores) 93 | # save scores to file 94 | json.dump( 95 | scores, 96 | open( 97 | os.path.join( 98 | args.checkpoint_path, f"{args.dataset_alias}_scores.json" 99 | ), 100 | "w", 101 | ), 102 | ) 103 | return 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument( 109 | "--model_name", 110 | type=str, 111 | default="gvp", 112 | help="Choose from %s" % ", ".join(list(MODEL_CONSTRUCTORS.keys())), 113 | ) 114 | parser.add_argument( 115 | "--checkpoint_path", 116 | type=str, 117 | help="ptl checkpoint path like `lightning_logs/version_x`", 118 | required=True, 119 | ) 120 | parser.add_argument( 121 | "--evaluate_type", 122 | type=str, 123 | help="regression or classification", 124 | default="regression", 125 | ) 126 | 127 | # dataset params 128 | parser.add_argument( 129 | "--dataset_name", 130 | help="dataset name", 131 | type=str, 132 | default="PepBDB", 133 | ) 134 | parser.add_argument( 135 | "--input_type", 136 | help="data input type", 137 | type=str, 138 | default="complex", 139 | ) 140 | parser.add_argument( 141 | "--data_dir", 142 | help="directory to dataset", 143 | type=str, 144 | default="", 145 | ) 146 | parser.add_argument( 147 | "--dataset_alias", 148 | help="Short name for the test dataset", 149 | type=str, 150 | required=True, 151 | ) 152 | parser.add_argument( 153 | "--data_suffix", 154 | help="used to distinguish different verions of the same dataset", 155 | type=str, 156 | default="full", 157 | ) 158 | parser.add_argument( 159 | "--binary_cutoff", 160 | help="used to convert PDBBind to a binary classification problem", 161 | type=float, 162 | default=None, 163 | ) 164 | parser.add_argument( 165 | "--bs", type=int, default=64, help="batch size for test data" 166 | ) 167 | parser.add_argument( 168 | "--num_workers", 169 | type=int, 170 | default=0, 171 | help="num_workers used in DataLoader", 172 | ) 173 | # featurizer params 174 | parser.add_argument( 175 | "--residue_featurizer_name", 176 | help="name of the residue featurizer", 177 | type=str, 178 | default="MACCS", 179 | ) 180 | parser.add_argument("--use_energy_decoder", action="store_true") 181 | parser.add_argument("--is_hetero", action="store_true") 182 | parser.set_defaults( 183 | use_energy_decoder=False, 184 | is_hetero=False, 185 | ) 186 | args = parser.parse_args() 187 | 188 | print("args:", args) 189 | # evaluate 190 | main(args) 191 | -------------------------------------------------------------------------------- /evaluate_casf2016.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | Evaluate a trained pytorch-lightning model on the three tasks on CASF2016: 5 | - Scoring => Spearman rho, R2 6 | - Docking => top1, 2, 3 success rates 7 | - Screening => Average EF, success rates 8 | """ 9 | 10 | from train import ( 11 | get_datasets, 12 | evaluate_graph_regression, 13 | MODEL_CONSTRUCTORS, 14 | predict_step, 15 | ) 16 | from evaluate import load_model_from_checkpoint 17 | 18 | import pytorch_lightning as pl 19 | from torch.utils.data import DataLoader 20 | 21 | from typing import Dict, List 22 | import argparse 23 | import os 24 | import json 25 | import glob 26 | from pprint import pprint 27 | from tqdm import tqdm 28 | import torch 29 | import numpy as np 30 | 31 | 32 | def choose_best_pose(id_to_pred: Dict[str, float]) -> Dict[str, float]: 33 | pairs = ["_".join(k.split("_")[:-1]) for k in id_to_pred.keys()] 34 | pairs = sorted(list(set(pairs))) 35 | retval = {p: [] for p in pairs} 36 | for key in id_to_pred.keys(): 37 | pair = "_".join(key.split("_")[:-1]) 38 | retval[pair].append(id_to_pred[key]) 39 | for key in retval.keys(): 40 | retval[key] = min(retval[key]) 41 | return retval 42 | 43 | 44 | def predict( 45 | model, 46 | data_loader, 47 | model_name="gvp", 48 | use_energy_decoder=False, 49 | is_hetero=False, 50 | ): 51 | """Make predictions on data from the data_loader""" 52 | # make predictions on test set 53 | device = torch.device("cuda:0") 54 | model = model.to(device) 55 | model.eval() 56 | 57 | all_preds = [] 58 | with torch.no_grad(): 59 | for batch in tqdm(data_loader): 60 | preds = predict_step( 61 | model, 62 | batch, 63 | device, 64 | model_name=model_name, 65 | use_energy_decoder=use_energy_decoder, 66 | is_hetero=is_hetero, 67 | ) 68 | preds = preds.to("cpu") 69 | preds = list(preds.numpy().reshape(-1)) 70 | all_preds.extend(preds) 71 | return all_preds 72 | 73 | 74 | def load_rmsd(rmsd_dir): 75 | """Load decoys docking RMSD from files""" 76 | rmsd_dir = os.path.join(rmsd_dir, "*_rmsd.dat") 77 | rmsd_filenames = glob.glob(rmsd_dir) 78 | id_to_rmsd = dict() 79 | for file in rmsd_filenames: 80 | with open(file, "r") as f: 81 | lines = f.readlines()[1:] 82 | lines = [line.split() for line in lines] 83 | lines = [[line[0], float(line[1])] for line in lines] 84 | dic = dict(lines) 85 | id_to_rmsd.update(dic) 86 | 87 | return id_to_rmsd 88 | 89 | 90 | def load_screening_target_file(target_file): 91 | # Load target file 92 | target_file = "/home/ec2-user/SageMaker/efs/data/PIGNet/casf2016_benchmark/TargetInfo.dat" 93 | 94 | true_binder_list = [] 95 | with open(target_file, "r") as f: 96 | lines = f.readlines()[9:] 97 | for line in lines: 98 | line = line.split() 99 | true_binder_list += [(line[0], elem) for elem in line[1:6]] 100 | return true_binder_list 101 | 102 | 103 | def evaluate_docking(id_to_pred, id_to_rmsd): 104 | # modified from PIGNet/casf2016_benchmark/docking_power.py 105 | # calculate topn success 106 | pdbs = sorted( 107 | list(set(key.split()[0].split("_")[0] for key in id_to_pred)) 108 | ) 109 | topn_successed_pdbs = [] 110 | for pdb in pdbs: 111 | selected_keys = [key for key in id_to_pred if pdb in key] 112 | pred = [id_to_pred[key] for key in selected_keys] 113 | pred, sorted_keys = zip(*sorted(zip(pred, selected_keys))) 114 | rmsd = [id_to_rmsd[key] for key in sorted_keys] 115 | topn_successed = [] 116 | for topn in [1, 2, 3]: 117 | if min(rmsd[:topn]) < 2.0: 118 | topn_successed.append(1) 119 | else: 120 | topn_successed.append(0) 121 | topn_successed_pdbs.append(topn_successed) 122 | 
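    # summarize the per-target results collected above into overall top-1/2/3 success rates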
123 | scores = {} 124 | for topn in [1, 2, 3]: 125 | successed = [success[topn - 1] for success in topn_successed_pdbs] 126 | success_rate = np.mean(successed) 127 | scores["success_rate_top%d" % topn] = success_rate 128 | print(round(success_rate, 3), end="\t") 129 | 130 | return scores 131 | 132 | 133 | def evaluate_screening(id_to_pred, true_binder_list): 134 | ntb_top = [] 135 | ntb_total = [] 136 | high_affinity_success = [] 137 | pdbs = sorted(list(set([key.split("_")[0] for key in id_to_pred.keys()]))) 138 | for pdb in pdbs: 139 | selected_keys = [ 140 | key for key in id_to_pred.keys() if key.split("_")[0] == pdb 141 | ] 142 | preds = [id_to_pred[key] for key in selected_keys] 143 | preds, selected_keys = zip(*sorted(zip(preds, selected_keys))) 144 | true_binders = [ 145 | key 146 | for key in selected_keys 147 | if (key.split("_")[0], key.split("_")[1]) in true_binder_list 148 | ] 149 | ntb_top_pdb, ntb_total_pdb, high_affinity_success_pdb = [], [], [] 150 | for topn in [0.01, 0.05, 0.1]: 151 | n = int(topn * len(selected_keys)) 152 | top_keys = selected_keys[:n] 153 | n_top_true_binder = len(list(set(top_keys) & set(true_binders))) 154 | ntb_top_pdb.append(n_top_true_binder) 155 | ntb_total_pdb.append(len(true_binders) * topn) 156 | if f"{pdb}_{pdb}" in top_keys: 157 | high_affinity_success_pdb.append(1) 158 | else: 159 | high_affinity_success_pdb.append(0) 160 | ntb_top.append(ntb_top_pdb) 161 | ntb_total.append(ntb_total_pdb) 162 | high_affinity_success.append(high_affinity_success_pdb) 163 | 164 | scores = {} 165 | for i in range(3): 166 | ef = [] 167 | for j in range(len(ntb_total)): 168 | if ntb_total[j][i] == 0: 169 | continue 170 | ef.append(ntb_top[j][i] / ntb_total[j][i]) 171 | 172 | avg_ef = np.mean(ef) 173 | scores["avgEF_top_%d_pct" % (i + 1)] = avg_ef 174 | print(round(avg_ef, 3), end="\t") 175 | 176 | for i in range(3): 177 | success = [] 178 | for j in range(len(ntb_total)): 179 | if high_affinity_success[j][i] > 0: 180 | success.append(1) 181 | else: 182 | success.append(0) 183 | 184 | success_rate = np.mean(success) 185 | scores["success_rate_top%d" % (i + 1)] = success_rate 186 | print(round(success_rate, 3), end="\t") 187 | return scores 188 | 189 | 190 | def main(args): 191 | pl.seed_everything(42, workers=True) 192 | # 0. Prepare model 193 | model = load_model_from_checkpoint(args.checkpoint_path, args.model_name) 194 | if args.checkpoint_path.endswith(".ckpt"): 195 | checkpoint_path = os.path.dirname( 196 | os.path.dirname(args.checkpoint_path) 197 | ) 198 | else: 199 | checkpoint_path = args.checkpoint_path 200 | # 1. 
Scoring data 201 | print("Performing scoring task...") 202 | scoring_dataset = get_datasets( 203 | name="PDBBind", 204 | input_type=args.input_type, 205 | data_dir=os.path.join(args.data_dir, "scoring"), 206 | test_only=True, 207 | residue_featurizer_name=args.residue_featurizer_name, 208 | use_energy_decoder=args.use_energy_decoder, 209 | intra_mol_energy=args.intra_mol_energy, 210 | ) 211 | print( 212 | "Data loaded:", 213 | len(scoring_dataset), 214 | ) 215 | scoring_data_loader = DataLoader( 216 | scoring_dataset, 217 | batch_size=args.bs, 218 | shuffle=False, 219 | num_workers=args.num_workers, 220 | collate_fn=scoring_dataset.collate_fn, 221 | ) 222 | scores = evaluate_graph_regression( 223 | model, 224 | scoring_data_loader, 225 | model_name=args.model_name, 226 | use_energy_decoder=args.use_energy_decoder, 227 | is_hetero=args.is_hetero, 228 | ) 229 | pprint(scores) 230 | # save scores to file 231 | json.dump( 232 | scores, 233 | open( 234 | os.path.join(checkpoint_path, "casf2016_scoring_scores.json"), 235 | "w", 236 | ), 237 | ) 238 | 239 | # 2. Docking data 240 | print("Performing docking task...") 241 | id_to_rmsd = load_rmsd( 242 | os.path.join( 243 | args.data_dir, "../../casf2016_benchmark/decoys_docking_rmsd" 244 | ) 245 | ) 246 | 247 | docking_dataset = get_datasets( 248 | name="PDBBind", 249 | input_type=args.input_type, 250 | data_dir=os.path.join(args.data_dir, "docking"), 251 | test_only=True, 252 | residue_featurizer_name=args.residue_featurizer_name, 253 | use_energy_decoder=args.use_energy_decoder, 254 | intra_mol_energy=args.intra_mol_energy, 255 | ) 256 | print( 257 | "Data loaded:", 258 | len(docking_dataset), 259 | ) 260 | docking_data_loader = DataLoader( 261 | docking_dataset, 262 | batch_size=args.bs, 263 | shuffle=False, 264 | num_workers=args.num_workers, 265 | collate_fn=docking_dataset.collate_fn, 266 | ) 267 | all_preds = predict( 268 | model, 269 | docking_data_loader, 270 | model_name=args.model_name, 271 | use_energy_decoder=args.use_energy_decoder, 272 | is_hetero=args.is_hetero, 273 | ) 274 | id_to_pred = dict(zip(docking_dataset.keys, all_preds)) 275 | 276 | docking_scores = evaluate_docking(id_to_pred, id_to_rmsd) 277 | # save scores to file 278 | json.dump( 279 | docking_scores, 280 | open( 281 | os.path.join(checkpoint_path, "casf2016_docking_scores.json"), 282 | "w", 283 | ), 284 | ) 285 | # 3. 
Screening data 286 | print("Performing screening task...") 287 | true_binder_list = load_screening_target_file( 288 | os.path.join(args.data_dir, "../../casf2016_benchmark/TargetInfo.dat") 289 | ) 290 | 291 | screening_dataset = get_datasets( 292 | name="PDBBind", 293 | input_type=args.input_type, 294 | data_dir=os.path.join(args.data_dir, "screening"), 295 | test_only=True, 296 | residue_featurizer_name=args.residue_featurizer_name, 297 | use_energy_decoder=args.use_energy_decoder, 298 | intra_mol_energy=args.intra_mol_energy, 299 | ) 300 | print( 301 | "Data loaded:", 302 | len(screening_dataset), 303 | ) 304 | screening_data_loader = DataLoader( 305 | screening_dataset, 306 | batch_size=args.bs, 307 | shuffle=False, 308 | num_workers=args.num_workers, 309 | collate_fn=screening_dataset.collate_fn, 310 | ) 311 | all_preds = predict( 312 | model, 313 | screening_data_loader, 314 | model_name=args.model_name, 315 | use_energy_decoder=args.use_energy_decoder, 316 | is_hetero=args.is_hetero, 317 | ) 318 | id_to_pred = dict(zip(screening_dataset.keys, all_preds)) 319 | screening_scores = evaluate_screening(id_to_pred, true_binder_list) 320 | # save scores to file 321 | json.dump( 322 | screening_scores, 323 | open( 324 | os.path.join(checkpoint_path, "casf2016_screening_scores.json"), 325 | "w", 326 | ), 327 | ) 328 | return 329 | 330 | 331 | if __name__ == "__main__": 332 | parser = argparse.ArgumentParser() 333 | parser.add_argument( 334 | "--model_name", 335 | type=str, 336 | default="gvp", 337 | help="Choose from %s" % ", ".join(list(MODEL_CONSTRUCTORS.keys())), 338 | ) 339 | parser.add_argument( 340 | "--input_type", 341 | help="data input type", 342 | type=str, 343 | default="complex", 344 | ) 345 | parser.add_argument( 346 | "--checkpoint_path", 347 | type=str, 348 | help="ptl checkpoint path like `lightning_logs/version_x`", 349 | required=True, 350 | ) 351 | 352 | # dataset params 353 | parser.add_argument( 354 | "--data_dir", 355 | help="directory to dataset", 356 | type=str, 357 | default="", 358 | ) 359 | parser.add_argument( 360 | "--bs", type=int, default=64, help="batch size for test data" 361 | ) 362 | parser.add_argument( 363 | "--num_workers", 364 | type=int, 365 | default=0, 366 | help="num_workers used in DataLoader", 367 | ) 368 | # featurizer params 369 | parser.add_argument( 370 | "--residue_featurizer_name", 371 | help="name of the residue featurizer", 372 | type=str, 373 | default="MACCS", 374 | ) 375 | 376 | parser.add_argument("--use_energy_decoder", action="store_true") 377 | parser.add_argument("--is_hetero", action="store_true") 378 | parser.add_argument("--intra_mol_energy", action="store_true") 379 | parser.set_defaults( 380 | use_energy_decoder=False, is_hetero=False, intra_mol_energy=False 381 | ) 382 | 383 | args = parser.parse_args() 384 | 385 | print("args:", args) 386 | # evaluate 387 | main(args) 388 | -------------------------------------------------------------------------------- /figs/GoGs_of_complexes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/eggnet-equivariant-graph-of-graph-neural-network/87ee428c8a79171f2d5331e1cae7c6ac82d84dd8/figs/GoGs_of_complexes.png -------------------------------------------------------------------------------- /figs/GoGs_of_molecules.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws-samples/eggnet-equivariant-graph-of-graph-neural-network/87ee428c8a79171f2d5331e1cae7c6ac82d84dd8/figs/GoGs_of_molecules.png -------------------------------------------------------------------------------- /h_gvp_experiments_bin_clf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | ########### 6 | ## Propedia 7 | ########### 8 | python train.py --accelerator gpu \ 9 | --max_epochs 500 \ 10 | --precision 16 \ 11 | --num_layers 3 \ 12 | --node_h_dim 200 32 \ 13 | --edge_h_dim 64 2 \ 14 | --dataset_name Propedia \ 15 | --input_type complex \ 16 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 17 | --residual \ 18 | --num_workers 8 \ 19 | --bs 32 \ 20 | --lr 1e-3 \ 21 | --early_stopping_patience 10 \ 22 | --residue_featurizer_name MACCS \ 23 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_GVP_MACCS 24 | 25 | 26 | python train.py --accelerator gpu \ 27 | --devices 4 \ 28 | --max_epochs 2 \ 29 | --precision 16 \ 30 | --dataset_name Propedia \ 31 | --input_type complex \ 32 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 33 | --residual \ 34 | --num_workers 16 \ 35 | --bs 16 \ 36 | --lr 1e-3 \ 37 | --early_stopping_patience 10 \ 38 | --residue_featurizer_name MACCS \ 39 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_GVP_MACCS 40 | 41 | 42 | python train.py --accelerator gpu \ 43 | --devices 4 \ 44 | --max_epochs 500 \ 45 | --precision 16 \ 46 | --dataset_name Propedia \ 47 | --input_type complex \ 48 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 49 | --data_suffix small \ 50 | --residual \ 51 | --num_workers 4 \ 52 | --bs 16 \ 53 | --lr 1e-3 \ 54 | --early_stopping_patience 10 \ 55 | --residue_featurizer_name MACCS \ 56 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_small_GVP_MACCS \ 57 | --persistent_workers True 58 | 59 | 60 | # selfdock: refine the positive complexes 61 | python train.py --accelerator gpu \ 62 | --devices 4 \ 63 | --max_epochs 500 \ 64 | --precision 16 \ 65 | --dataset_name Propedia \ 66 | --input_type complex \ 67 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 68 | --data_suffix small_selfdock \ 69 | --residual \ 70 | --num_workers 4 \ 71 | --bs 16 \ 72 | --lr 1e-3 \ 73 | --early_stopping_patience 10 \ 74 | --residue_featurizer_name MACCS \ 75 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_small_GVP_MACCS \ 76 | --persistent_workers True 77 | 78 | # selfdock + noise 79 | python train.py --accelerator gpu \ 80 | --devices 4 \ 81 | --max_epochs 500 \ 82 | --precision 16 \ 83 | --dataset_name Propedia \ 84 | --input_type complex \ 85 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 86 | --data_suffix small_selfdock \ 87 | --residual \ 88 | --num_workers 4 \ 89 | --bs 16 \ 90 | --lr 1e-3 \ 91 | --early_stopping_patience 10 \ 92 | --residue_featurizer_name MACCS \ 93 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_small_GVP_MACCS \ 94 | --add_noise 0.02 95 | 96 | # crystal + noise 97 | python train.py --accelerator gpu \ 98 | --devices 4 \ 99 | --max_epochs 500 \ 100 | --precision 16 \ 101 | --dataset_name Propedia \ 102 | --input_type complex \ 103 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 104 | --data_suffix small \ 105 | --residual \ 106 | --num_workers 4 \ 107 | --bs 16 \ 108 | --lr 1e-3 \ 109 | 
--early_stopping_patience 10 \ 110 | --residue_featurizer_name MACCS \ 111 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_small_GVP_MACCS \ 112 | --add_noise 0.02 113 | 114 | # less persistent workers to save CPU RAM 115 | # -> still may run out of CPU RAM 116 | python train.py --accelerator gpu \ 117 | --devices 4 \ 118 | --max_epochs 500 \ 119 | --precision 16 \ 120 | --num_layers 3 \ 121 | --node_h_dim 200 32 \ 122 | --edge_h_dim 64 2 \ 123 | --dataset_name Propedia \ 124 | --input_type complex \ 125 | --data_dir /home/ec2-user/SageMaker/efs/data/Propedia \ 126 | --residual \ 127 | --num_workers 8 \ 128 | --bs 16 \ 129 | --lr 1e-3 \ 130 | --early_stopping_patience 10 \ 131 | --residue_featurizer_name MACCS \ 132 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/Propedia_GVP_MACCS \ 133 | --persistent_workers True 134 | 135 | ########### 136 | ## PDBBind binary classification 137 | ########### 138 | python train.py --accelerator gpu \ 139 | --devices 1 \ 140 | --max_epochs 500 \ 141 | --precision 16 \ 142 | --num_layers 3 \ 143 | --node_h_dim 200 32 \ 144 | --edge_h_dim 64 2 \ 145 | --dataset_name PDBBind \ 146 | --input_type complex \ 147 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019_processed/scoring \ 148 | --residual \ 149 | --num_workers 4 \ 150 | --bs 32 \ 151 | --lr 1e-4 \ 152 | --early_stopping_patience 10 \ 153 | --residue_featurizer_name MACCS \ 154 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_bin_GVP_MACCS \ 155 | --persistent_workers True \ 156 | --binary_cutoff 6.7 157 | 158 | python train.py --accelerator gpu \ 159 | --devices 1 \ 160 | --max_epochs 500 \ 161 | --precision 16 \ 162 | --num_layers 3 \ 163 | --node_h_dim 200 32 \ 164 | --edge_h_dim 64 2 \ 165 | --dataset_name PDBBind \ 166 | --input_type complex \ 167 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019_processed/scoring \ 168 | --residual \ 169 | --num_workers 4 \ 170 | --bs 32 \ 171 | --lr 1e-4 \ 172 | --early_stopping_patience 10 \ 173 | --residue_featurizer_name MolT5-small \ 174 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_bin_GVP_MolT5_small \ 175 | --persistent_workers True \ 176 | --binary_cutoff 6.7 177 | 178 | ########### 179 | ## ProtCID 180 | ########### 181 | python train.py --accelerator gpu \ 182 | --devices 1 \ 183 | --max_epochs 500 \ 184 | --precision 16 \ 185 | --dataset_name Propedia \ 186 | --input_type complex \ 187 | --data_dir /home/ec2-user/SageMaker/efs/data/ProtCID/JaredJanssen_Benchmark \ 188 | --data_suffix small \ 189 | --residual \ 190 | --num_workers 8 \ 191 | --bs 16 \ 192 | --lr 1e-3 \ 193 | --early_stopping_patience 10 \ 194 | --residue_featurizer_name MACCS \ 195 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_small_GVP_MACCS \ 196 | --persistent_workers True 197 | 198 | python train.py --accelerator gpu \ 199 | --devices 1 \ 200 | --max_epochs 500 \ 201 | --precision 16 \ 202 | --dataset_name Propedia \ 203 | --input_type complex \ 204 | --data_dir /home/ec2-user/SageMaker/efs/data/ProtCID/JaredJanssen_Benchmark \ 205 | --data_suffix small \ 206 | --residual \ 207 | --num_workers 4 \ 208 | --bs 16 \ 209 | --lr 1e-3 \ 210 | --early_stopping_patience 10 \ 211 | --residue_featurizer_name MolT5-small \ 212 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/ProtCID_small_GVP_MolT5_small \ 213 | --persistent_workers True 
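Note: the PDBBind runs above pass --binary_cutoff 6.7 to turn the continuous affinity targets into binary labels. Below is a minimal sketch of that labeling rule, mirroring the comparison done in PIGNetComplexDataset._preprocess in ppi/data.py; the helper name is illustrative and not part of the repository, and id_to_y appears to hold pK-style affinities (the regression setup instead scales the same value by -1.36).

def binarize_affinity(pk_value: float, binary_cutoff: float = 6.7) -> int:
    # A complex counts as a positive when its affinity reaches the cutoff,
    # i.e. the same comparison as `id_to_y[key] >= binary_cutoff` in ppi/data.py.
    return int(pk_value >= binary_cutoff)

# Example: binarize_affinity(7.4) -> 1, binarize_affinity(5.2) -> 0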
-------------------------------------------------------------------------------- /integration_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # SPDX-License-Identifier: MIT-0 4 | 5 | # Run this on Amazon DLAMI or SageMaker 6 | source ~/anaconda3/etc/profile.d/conda.sh 7 | conda activate pytorch_p38 8 | 9 | # global variables shared across all tests: 10 | n_gpus=4 11 | pdbbind_data=/home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring 12 | residue_featurizer_name=gin-supervised-contextpred-mean 13 | # residue_featurizer_name=MACCS 14 | 15 | # row2: pretrained GNN GVP None 16 | python train.py --accelerator gpu \ 17 | --model_name gvp \ 18 | --devices $n_gpus \ 19 | --fast_dev_run $n_gpus \ 20 | --precision 16 \ 21 | --dataset_name PDBBind \ 22 | --input_type complex \ 23 | --residual \ 24 | --residue_featurizer_name $residue_featurizer_name \ 25 | --data_dir $pdbbind_data 26 | 27 | # row3: pretrained GNN MS-GVP None 28 | python train.py --accelerator gpu \ 29 | --model_name multistage-gvp \ 30 | --devices $n_gpus \ 31 | --fast_dev_run $n_gpus \ 32 | --precision 16 \ 33 | --dataset_name PDBBind \ 34 | --input_type multistage-hetero \ 35 | --residual \ 36 | --residue_featurizer_name $residue_featurizer_name \ 37 | --data_dir $pdbbind_data 38 | 39 | # row4: pretrained GNN joint training GVP None 40 | python train.py --accelerator gpu \ 41 | --model_name hgvp \ 42 | --devices $n_gpus \ 43 | --fast_dev_run $n_gpus \ 44 | --precision 32 \ 45 | --dataset_name PDBBind \ 46 | --input_type complex \ 47 | --residual \ 48 | --residue_featurizer_name $residue_featurizer_name-grad \ 49 | --data_dir $pdbbind_data 50 | 51 | # row5: pretrained GNN joint training MS-GVP None 52 | python train.py --accelerator gpu \ 53 | --model_name multistage-hgvp \ 54 | --devices $n_gpus \ 55 | --fast_dev_run $n_gpus \ 56 | --precision 32 \ 57 | --dataset_name PDBBind \ 58 | --input_type multistage-hetero \ 59 | --is_hetero \ 60 | --residual \ 61 | --residue_featurizer_name $residue_featurizer_name-grad \ 62 | --data_dir $pdbbind_data 63 | 64 | # row6: pretrained GNN GVP E_int 65 | python train.py --accelerator gpu \ 66 | --model_name gvp \ 67 | --devices $n_gpus \ 68 | --fast_dev_run $n_gpus \ 69 | --precision 16 \ 70 | --dataset_name PDBBind \ 71 | --input_type complex \ 72 | --residual \ 73 | --residue_featurizer_name $residue_featurizer_name \ 74 | --use_energy_decoder \ 75 | --is_hetero \ 76 | --data_dir $pdbbind_data 77 | 78 | # row7: pretrained GNN MS-GVP E_int 79 | python train.py --accelerator gpu \ 80 | --model_name multistage-gvp \ 81 | --devices $n_gpus \ 82 | --fast_dev_run $n_gpus \ 83 | --precision 16 \ 84 | --dataset_name PDBBind \ 85 | --input_type multistage-hetero \ 86 | --residual \ 87 | --residue_featurizer_name $residue_featurizer_name \ 88 | --use_energy_decoder \ 89 | --is_hetero \ 90 | --data_dir $pdbbind_data 91 | 92 | # row8: pretrained GNN joint training GVP E_int 93 | python train.py --accelerator gpu \ 94 | --model_name hgvp \ 95 | --devices $n_gpus \ 96 | --fast_dev_run $n_gpus \ 97 | --precision 32 \ 98 | --dataset_name PDBBind \ 99 | --input_type complex \ 100 | --residual \ 101 | --residue_featurizer_name $residue_featurizer_name-grad \ 102 | --use_energy_decoder \ 103 | --is_hetero \ 104 | --data_dir $pdbbind_data 105 | 106 | # row9: pretrained GNN joint training MS-GVP E_int 107 | python train.py --accelerator gpu \ 108 | --model_name 
multistage-hgvp \ 109 | --devices $n_gpus \ 110 | --fast_dev_run $n_gpus \ 111 | --precision 32 \ 112 | --dataset_name PDBBind \ 113 | --input_type multistage-hetero \ 114 | --residual \ 115 | --residue_featurizer_name $residue_featurizer_name-grad \ 116 | --use_energy_decoder \ 117 | --is_hetero \ 118 | --data_dir $pdbbind_data \ 119 | --bs 8 120 | -------------------------------------------------------------------------------- /multistage_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | # multistage-physical 6 | for input_type in physical hetero geometric; do 7 | CUDA_VISIBLE_DEVICES=2 python train.py --accelerator gpu \ 8 | --devices 1 \ 9 | --max_epochs 500 \ 10 | --precision 16 \ 11 | --protein_num_layers 3 \ 12 | --ligand_num_layers 3 \ 13 | --complex_num_layers 3 \ 14 | --protein_node_h_dim 200 32 \ 15 | --protein_edge_h_dim 64 2 \ 16 | --ligand_node_h_dim 200 32 \ 17 | --ligand_edge_h_dim 64 2 \ 18 | --complex_node_h_dim 200 32 \ 19 | --complex_edge_h_dim 64 2 \ 20 | --dataset_name PDBBind \ 21 | --input_type multistage-$input_type \ 22 | --model_name gvp-multistage \ 23 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring \ 24 | --residual \ 25 | --num_workers 8 \ 26 | --lr 1e-4 \ 27 | --bs 128 \ 28 | --early_stopping_patience 10 \ 29 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_MSGVP_$input_type 30 | done 31 | 32 | python evaluate_casf2016.py --model_name gvp-multistage \ 33 | --input_type multistage-$input_type \ 34 | --num_workers 8 \ 35 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/casf2016 \ 36 | --checkpoint_path /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_MSGVP_$input_type/lightning_logs/version_2 37 | 38 | 39 | CUDA_VISIBLE_DEVICES=1,2 python train.py --accelerator gpu \ 40 | --devices 1 \ 41 | --max_epochs 500 \ 42 | --precision 16 \ 43 | --protein_num_layers 3 \ 44 | --ligand_num_layers 3 \ 45 | --complex_num_layers 3 \ 46 | --protein_node_h_dim 200 32 \ 47 | --protein_edge_h_dim 64 2 \ 48 | --ligand_node_h_dim 200 32 \ 49 | --ligand_edge_h_dim 64 2 \ 50 | --complex_node_h_dim 200 32 \ 51 | --complex_edge_h_dim 64 2 \ 52 | --dataset_name PDBBind \ 53 | --input_type multistage-physical \ 54 | --model_name gvp-multistage \ 55 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring \ 56 | --residual \ 57 | --num_workers 8 \ 58 | --lr 1e-4 \ 59 | --bs 128 \ 60 | --early_stopping_patience 10 \ 61 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_MSGVP_physical -------------------------------------------------------------------------------- /ppi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/eggnet-equivariant-graph-of-graph-neural-network/87ee428c8a79171f2d5331e1cae7c6ac82d84dd8/ppi/__init__.py -------------------------------------------------------------------------------- /ppi/data.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | Pytorch dataset classes from PPI prediction. 
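Provides PDBComplexDataset and PDBBigraphComplexDataset for Propedia/ProtCID complex PDB files, and PIGNetComplexDataset plus its hetero-bigraph variants for the PIGNet-preprocessed PDBBind pickles.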
5 | """ 6 | from ppi.data_utils.pignet_featurizers import mol_to_feature 7 | from rdkit import Chem 8 | import os 9 | import pickle 10 | from typing import Any, Dict, List 11 | 12 | import numpy as np 13 | import pandas as pd 14 | from sklearn.metrics import pairwise_distances 15 | import torch 16 | import torch.utils.data as data 17 | import dgl 18 | from tqdm import tqdm 19 | 20 | import numpy as np 21 | import pickle 22 | from Bio.PDB import PDBParser, MMCIFParser 23 | 24 | # custom modules 25 | from ppi.data_utils import ( 26 | remove_nan_residues, 27 | mol_to_pdb_structure, 28 | residue_to_mol, 29 | parse_structure, 30 | ) 31 | 32 | 33 | def check_dimension(tensors: List[Any]) -> Any: 34 | size = [] 35 | for tensor in tensors: 36 | if isinstance(tensor, np.ndarray): 37 | size.append(tensor.shape) 38 | else: 39 | size.append(0) 40 | size = np.asarray(size) 41 | 42 | return np.max(size, 0) 43 | 44 | 45 | def collate_tensor(tensor: Any, max_tensor: Any, batch_idx: int) -> Any: 46 | if isinstance(tensor, np.ndarray): 47 | dims = tensor.shape 48 | slice_list = tuple([slice(0, dim) for dim in dims]) 49 | slice_list = [slice(batch_idx, batch_idx + 1), *slice_list] 50 | max_tensor[tuple(slice_list)] = tensor 51 | elif isinstance(tensor, str): 52 | max_tensor[batch_idx] = tensor 53 | else: 54 | max_tensor[batch_idx] = tensor 55 | 56 | return max_tensor 57 | 58 | 59 | def tensor_collate_fn(batch: List[Any]) -> Dict[str, Any]: 60 | batch_items = [it for e in batch for it in e.items()] 61 | dim_dict = dict() 62 | total_key, total_value = list(zip(*batch_items)) 63 | batch_size = len(batch) 64 | n_element = int(len(batch_items) / batch_size) 65 | total_key = total_key[0:n_element] 66 | for i, k in enumerate(total_key): 67 | value_list = [ 68 | v for j, v in enumerate(total_value) if j % n_element == i 69 | ] 70 | if isinstance(value_list[0], np.ndarray): 71 | dim_dict[k] = np.zeros( 72 | np.array([batch_size, *check_dimension(value_list)]) 73 | ) 74 | elif isinstance(value_list[0], str): 75 | dim_dict[k] = ["" for _ in range(batch_size)] 76 | else: 77 | dim_dict[k] = np.zeros((batch_size,)) 78 | 79 | ret_dict = {} 80 | for j in range(batch_size): 81 | if batch[j] is None: 82 | continue 83 | for key, value in dim_dict.items(): 84 | value = collate_tensor(batch[j][key], value, j) 85 | if not isinstance(value, list): 86 | value = torch.from_numpy(value).float() 87 | ret_dict[key] = value 88 | 89 | return ret_dict 90 | 91 | 92 | class BasePPIDataset(data.Dataset): 93 | """Dataset for the Base Protein Graph.""" 94 | 95 | def __init__(self, preprocess=False): 96 | self.processed_data = pd.Series([None] * len(self)) 97 | if preprocess: 98 | print("Preprocessing data...") 99 | self._preprocess_all() 100 | 101 | def __getitem__(self, i): 102 | if self.processed_data[i] is None: 103 | # if not processed, process this instance and update 104 | self.processed_data[i] = self._preprocess(i) 105 | return self.processed_data[i] 106 | 107 | def _preprocess(self, complex): 108 | raise NotImplementedError 109 | 110 | def _preprocess_all(self): 111 | """Preprocess all the records in `data_list` with `_preprocess""" 112 | for i in tqdm(range(len(self.processed_data))): 113 | self.processed_data[i] = self._preprocess(i) 114 | 115 | 116 | class PDBComplexDataset(BasePPIDataset): 117 | """ 118 | To work with Propedia and ProtCID data, where each individual sample is a 119 | PDB complex file. 
120 | """ 121 | 122 | def __init__( 123 | self, 124 | meta_df: pd.DataFrame, 125 | path_to_data_files: str, 126 | featurizer: object, 127 | compute_energy=False, 128 | intra_mol_energy=False, 129 | **kwargs 130 | ): 131 | self.meta_df = meta_df 132 | self.path = path_to_data_files 133 | self.pdb_parser = PDBParser( 134 | QUIET=True, 135 | PERMISSIVE=True, 136 | ) 137 | self.cif_parser = MMCIFParser(QUIET=True) 138 | self.featurizer = featurizer 139 | self.compute_energy = compute_energy 140 | self.intra_mol_energy = intra_mol_energy 141 | super(PDBComplexDataset, self).__init__(**kwargs) 142 | 143 | def __len__(self) -> int: 144 | return self.meta_df.shape[0] 145 | 146 | def _preprocess(self, idx: int) -> Dict[str, Any]: 147 | row = self.meta_df.iloc[idx] 148 | structure = parse_structure( 149 | self.pdb_parser, 150 | self.cif_parser, 151 | name=str(idx), 152 | file_path=os.path.join(self.path, row["pdb_file"]), 153 | ) 154 | for chain in structure.get_chains(): 155 | if chain.id == row["receptor_chain_id"]: 156 | protein = chain 157 | elif chain.id == row["ligand_chain_id"]: 158 | ligand = chain 159 | sample = self.featurizer.featurize( 160 | {"ligand": ligand, "protein": protein} 161 | ) 162 | sample["target"] = row["label"] 163 | if self.compute_energy: 164 | ligand_mol = residue_to_mol(ligand, sanitize=False) 165 | protein_mol = residue_to_mol(protein, sanitize=False) 166 | physics = mol_to_feature( 167 | ligand_mol, protein_mol, compute_full=self.intra_mol_energy 168 | ) 169 | sample["physics"] = physics 170 | return sample 171 | 172 | @property 173 | def pos_weight(self) -> torch.Tensor: 174 | """To compute the weight of the positive class, assuming binary 175 | classification""" 176 | class_sizes = self.meta_df["label"].value_counts() 177 | pos_weights = np.mean(class_sizes) / class_sizes 178 | pos_weights = torch.from_numpy(pos_weights.values.astype(np.float32)) 179 | return pos_weights[1] / pos_weights[0] 180 | 181 | def collate_fn(self, samples): 182 | """Collating protein complex graphs and graph-level targets.""" 183 | graphs = [] 184 | smiles_strings = [] 185 | g_targets = [] 186 | physics = [] 187 | for rec in samples: 188 | graphs.append(rec["graph"]) 189 | g_targets.append(rec["target"]) 190 | if "smiles_strings" in rec: 191 | smiles_strings.extend(rec["smiles_strings"]) 192 | if self.compute_energy: 193 | physics.append(rec["physics"]) 194 | res = { 195 | "graph": dgl.batch(graphs), 196 | "g_targets": torch.tensor(g_targets) 197 | .to(torch.float32) 198 | .unsqueeze(-1), 199 | "smiles_strings": smiles_strings, 200 | } 201 | if self.compute_energy: 202 | res["sample"] = tensor_collate_fn(physics) 203 | return res 204 | 205 | 206 | class PIGNetComplexDataset(data.Dataset): 207 | """ 208 | To work with preprocessed pickles sourced from PDBBind dataset by the 209 | PIGNet paper. 
210 | Modified from https://github.com/ACE-KAIST/PIGNet/blob/main/dataset.py 211 | """ 212 | 213 | def __init__( 214 | self, 215 | keys: List[str], 216 | data_dir: str, 217 | id_to_y: Dict[str, float], 218 | featurizer: object, 219 | compute_energy=False, 220 | intra_mol_energy=False, 221 | binary_cutoff=None, 222 | ): 223 | self.keys = np.array(keys).astype(np.unicode_) 224 | self.data_dir = data_dir 225 | self.id_to_y = pd.Series(id_to_y) 226 | self.featurizer = featurizer 227 | self.processed_data = pd.Series([None] * len(self)) 228 | self.compute_energy = compute_energy 229 | self.intra_mol_energy = intra_mol_energy 230 | self.binary_cutoff = binary_cutoff 231 | 232 | def __len__(self) -> int: 233 | return len(self.keys) 234 | 235 | def __getitem__(self, idx: int) -> Dict[str, Any]: 236 | if self.processed_data[idx] is None: 237 | self.processed_data[idx] = self._preprocess(idx) 238 | return self.processed_data[idx] 239 | 240 | def _preprocess_all(self): 241 | """Preprocess all the records in `data_list` with `_preprocess""" 242 | for i in tqdm(range(len(self))): 243 | self.processed_data[i] = self._preprocess(i) 244 | 245 | def _preprocess(self, idx: int) -> Dict[str, Any]: 246 | key = self.keys[idx] 247 | with open(os.path.join(self.data_dir, "data", key), "rb") as f: 248 | m1, _, m2, _ = pickle.load(f) 249 | 250 | if type(m2) is Chem.rdchem.Mol: 251 | protein_mol = m2 252 | protein_pdb = mol_to_pdb_structure(m2) 253 | else: 254 | protein_pdb = m2 255 | protein_mol = None 256 | 257 | sample = self.featurizer.featurize( 258 | { 259 | "ligand": m1, 260 | "protein": protein_pdb, 261 | } 262 | ) 263 | if self.binary_cutoff is None: 264 | sample["affinity"] = self.id_to_y[key] * -1.36 265 | else: 266 | # convert to a binary classification problem: 267 | sample["affinity"] = self.id_to_y[key] >= self.binary_cutoff 268 | sample["key"] = key 269 | if self.compute_energy: 270 | if protein_mol is None: 271 | protein_mol = residue_to_mol(protein_pdb, sanitize=False) 272 | physics = mol_to_feature( 273 | m1, protein_mol, compute_full=self.intra_mol_energy 274 | ) 275 | sample["physics"] = physics 276 | return sample 277 | 278 | @property 279 | def pos_weight(self) -> torch.Tensor: 280 | """To compute the weight of the positive class, assuming binary 281 | classification""" 282 | if self.binary_cutoff is None: 283 | return None 284 | else: 285 | affinities = self.id_to_y.loc[self.keys] > self.binary_cutoff 286 | class_sizes = affinities.astype(int).value_counts() 287 | pos_weights = np.mean(class_sizes) / class_sizes 288 | pos_weights = torch.from_numpy( 289 | pos_weights.values.astype(np.float32) 290 | ) 291 | return pos_weights[1] / pos_weights[0] 292 | 293 | def collate_fn(self, samples): 294 | """Collating protein complex graphs and graph-level targets.""" 295 | graphs = [] 296 | smiles_strings = [] 297 | g_targets = [] 298 | physics = [] 299 | for rec in samples: 300 | graphs.append(rec["graph"]) 301 | g_targets.append(rec["affinity"]) 302 | if "smiles_strings" in rec: 303 | smiles_strings.extend(rec["smiles_strings"]) 304 | if self.compute_energy: 305 | physics.append(rec["physics"]) 306 | res = { 307 | "graph": dgl.batch(graphs), 308 | "g_targets": torch.tensor(g_targets) 309 | .to(torch.float32) 310 | .unsqueeze(-1), 311 | "smiles_strings": smiles_strings, 312 | } 313 | if self.compute_energy: 314 | res["sample"] = tensor_collate_fn(physics) 315 | return res 316 | 317 | 318 | class PIGNetHeteroBigraphComplexDataset(data.Dataset): 319 | """ 320 | To work with preprocessed pickles sourced 
from PDBBind dataset by the 321 | PIGNet paper. 322 | Modified from https://github.com/ACE-KAIST/PIGNet/blob/main/dataset.py 323 | """ 324 | 325 | def __init__( 326 | self, 327 | keys: List[str], 328 | data_dir: str, 329 | id_to_y: Dict[str, float], 330 | featurizer: object, 331 | ): 332 | self.keys = keys 333 | self.data_dir = data_dir 334 | self.id_to_y = id_to_y 335 | self.featurizer = featurizer 336 | 337 | def __len__(self) -> int: 338 | return len(self.keys) 339 | 340 | def __getitem__(self, idx: int) -> Dict[str, Any]: 341 | key = self.keys[idx] 342 | with open(os.path.join(self.data_dir, "data", key), "rb") as f: 343 | m1, _, m2, _ = pickle.load(f) 344 | if type(m2) is Chem.rdchem.Mol: 345 | m2 = mol_to_pdb_structure(m2) 346 | 347 | if self.featurizer.residue_featurizer: 348 | ( 349 | protein_graph, 350 | ligand_graph, 351 | complex_graph, 352 | ) = self.featurizer.featurize( 353 | { 354 | "ligand": m1, 355 | "protein": m2, 356 | } 357 | ) 358 | sample = { 359 | "protein_graph": protein_graph, 360 | "ligand_graph": ligand_graph, 361 | "complex_graph": complex_graph, 362 | } 363 | else: 364 | ( 365 | protein_graph, 366 | ligand_graph, 367 | complex_graph, 368 | protein_smiles_strings, 369 | ligand_smiles, 370 | ) = self.featurizer.featurize( 371 | { 372 | "ligand": m1, 373 | "protein": m2, 374 | } 375 | ) 376 | sample = { 377 | "protein_graph": protein_graph, 378 | "ligand_graph": ligand_graph, 379 | "complex_graph": complex_graph, 380 | "protein_smiles_strings": protein_smiles_strings, 381 | "ligand_smiles_strings": None, 382 | "ligand_smiles": ligand_smiles, 383 | } 384 | sample["affinity"] = self.id_to_y[key] * -1.36 385 | sample["key"] = key 386 | return sample 387 | 388 | def collate_fn(self, samples): 389 | """Collating protein complex graphs and graph-level targets.""" 390 | ( 391 | protein_graphs, 392 | ligand_graphs, 393 | complex_graphs, 394 | protein_smiles_strings, 395 | ligand_smiles, 396 | ) = ([], [], [], [], []) 397 | g_targets = [] 398 | for rec in samples: 399 | protein_graphs.append(rec["protein_graph"]) 400 | ligand_graphs.append(rec["ligand_graph"]) 401 | complex_graphs.append(rec["complex_graph"]) 402 | g_targets.append(rec["affinity"]) 403 | if "protein_smiles_strings" in rec: 404 | protein_smiles_strings.extend(rec["protein_smiles_strings"]) 405 | if "ligand_smiles" in rec: 406 | ligand_smiles.append(rec["ligand_smiles"]) 407 | return { 408 | "protein_graph": dgl.batch(protein_graphs), 409 | "ligand_graph": dgl.batch(ligand_graphs), 410 | "complex_graph": dgl.batch(complex_graphs), 411 | "g_targets": torch.tensor(g_targets).unsqueeze(-1), 412 | "protein_smiles_strings": protein_smiles_strings, 413 | "ligand_smiles_strings": None, 414 | "ligand_smiles": ligand_smiles, 415 | } 416 | 417 | 418 | class PIGNetHeteroBigraphComplexDatasetForEnergyModel(data.Dataset): 419 | """ 420 | To work with preprocessed pickles sourced from PDBBind dataset by the 421 | PIGNet paper. 
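Besides the protein, ligand and complex graphs, featurize() here also returns the physics feature dict and an atom_to_residue mapping, intended for the energy-decoder models.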
422 | Modified from https://github.com/ACE-KAIST/PIGNet/blob/main/dataset.py 423 | """ 424 | 425 | def __init__( 426 | self, 427 | keys: List[str], 428 | data_dir: str, 429 | id_to_y: Dict[str, float], 430 | featurizer: object, 431 | ): 432 | self.keys = keys 433 | self.data_dir = data_dir 434 | self.id_to_y = id_to_y 435 | self.featurizer = featurizer 436 | 437 | def __len__(self) -> int: 438 | return len(self.keys) 439 | 440 | def __getitem__(self, idx: int) -> Dict[str, Any]: 441 | key = self.keys[idx] 442 | with open(os.path.join(self.data_dir, "data", key), "rb") as f: 443 | m1, _, m2, _ = pickle.load(f) 444 | 445 | if type(m2) is Chem.rdchem.Mol: 446 | protein_atoms = m2 447 | protein_residues = mol_to_pdb_structure(m2) 448 | else: 449 | protein_residues = m2 450 | protein_atoms = residue_to_mol(m2) 451 | 452 | if self.featurizer.residue_featurizer: 453 | ( 454 | protein_graph, 455 | ligand_graph, 456 | complex_graph, 457 | physics, 458 | atom_to_residue, 459 | ) = self.featurizer.featurize( 460 | { 461 | "ligand": m1, 462 | "protein_atoms": protein_atoms, 463 | "protein_residues": protein_residues, 464 | } 465 | ) 466 | sample = { 467 | "protein_graph": protein_graph, 468 | "ligand_graph": ligand_graph, 469 | "complex_graph": complex_graph, 470 | "sample": physics, 471 | "atom_to_residue": atom_to_residue, 472 | } 473 | else: 474 | ( 475 | protein_graph, 476 | ligand_graph, 477 | complex_graph, 478 | physics, 479 | atom_to_residue, 480 | smiles_strings, 481 | ligand_smiles, 482 | ) = self.featurizer.featurize( 483 | { 484 | "ligand": m1, 485 | "protein_atoms": protein_atoms, 486 | "protein_residues": protein_residues, 487 | } 488 | ) 489 | sample = { 490 | "protein_graph": protein_graph, 491 | "ligand_graph": ligand_graph, 492 | "complex_graph": complex_graph, 493 | "sample": physics, 494 | "atom_to_residue": atom_to_residue, 495 | "protein_smiles_strings": smiles_strings, 496 | "ligand_smiles_strings": None, 497 | "ligand_smiles": ligand_smiles, 498 | } 499 | sample["affinity"] = self.id_to_y[key] * -1.36 500 | sample["key"] = key 501 | return sample 502 | 503 | def collate_fn(self, samples): 504 | """Collating protein complex graphs and graph-level targets.""" 505 | ( 506 | protein_graphs, 507 | ligand_graphs, 508 | complex_graphs, 509 | physics, 510 | atom_to_residues, 511 | protein_smiles_strings, 512 | ligand_smiles, 513 | ) = ([], [], [], [], [], [], []) 514 | g_targets = [] 515 | for rec in samples: 516 | protein_graphs.append(rec["protein_graph"]) 517 | ligand_graphs.append(rec["ligand_graph"]) 518 | complex_graphs.append(rec["complex_graph"]) 519 | physics.append(rec["sample"]) 520 | atom_to_residues.append(rec["atom_to_residue"]) 521 | g_targets.append(rec["affinity"]) 522 | if "protein_smiles_strings" in rec: 523 | protein_smiles_strings.extend(rec["protein_smiles_strings"]) 524 | if "ligand_smiles" in rec: 525 | ligand_smiles.append(rec["ligand_smiles"]) 526 | return { 527 | "protein_graph": dgl.batch(protein_graphs), 528 | "ligand_graph": dgl.batch(ligand_graphs), 529 | "complex_graph": dgl.batch(complex_graphs), 530 | "sample": tensor_collate_fn(physics), 531 | "atom_to_residue": atom_to_residues, 532 | "g_targets": torch.tensor(g_targets).unsqueeze(-1), 533 | "protein_smiles_strings": protein_smiles_strings, 534 | "ligand_smiles_strings": None, 535 | "ligand_smiles": ligand_smiles, 536 | } 537 | 538 | 539 | class PDBBigraphComplexDataset(BasePPIDataset): 540 | """ 541 | To work with Propedia and ProtCID data, where each individual sample is a 542 | PDB complex file. 
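Unlike PDBComplexDataset, the featurizer here produces separate protein, ligand and complex graphs, which collate_fn batches individually.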
543 | """ 544 | 545 | def __init__( 546 | self, 547 | meta_df: pd.DataFrame, 548 | path_to_data_files: str, 549 | featurizer: object, 550 | **kwargs 551 | ): 552 | self.meta_df = meta_df 553 | self.path = path_to_data_files 554 | self.pdb_parser = PDBParser( 555 | QUIET=True, 556 | PERMISSIVE=True, 557 | ) 558 | self.cif_parser = MMCIFParser(QUIET=True) 559 | self.featurizer = featurizer 560 | super(PDBBigraphComplexDataset, self).__init__(**kwargs) 561 | 562 | def __len__(self) -> int: 563 | return self.meta_df.shape[0] 564 | 565 | def _preprocess(self, idx: int) -> Dict[str, Any]: 566 | row = self.meta_df.iloc[idx] 567 | structure = parse_structure( 568 | self.pdb_parser, 569 | self.cif_parser, 570 | name=str(idx), 571 | file_path=os.path.join(self.path, row["pdb_file"]), 572 | ) 573 | for chain in structure.get_chains(): 574 | if chain.id == row["receptor_chain_id"]: 575 | protein = chain 576 | elif chain.id == row["ligand_chain_id"]: 577 | ligand = chain 578 | sample = self.featurizer.featurize( 579 | {"ligand": ligand, "protein": protein} 580 | ) 581 | sample["target"] = row["label"] 582 | return sample 583 | 584 | @property 585 | def pos_weight(self) -> torch.Tensor: 586 | """To compute the weight of the positive class, assuming binary 587 | classification""" 588 | class_sizes = self.meta_df["label"].value_counts() 589 | pos_weights = np.mean(class_sizes) / class_sizes 590 | pos_weights = torch.from_numpy(pos_weights.values.astype(np.float32)) 591 | return pos_weights[1] / pos_weights[0] 592 | 593 | def collate_fn(self, samples): 594 | """Collating protein complex graphs and graph-level targets.""" 595 | protein_graphs = [] 596 | protein_smiles_strings = [] 597 | ligand_graphs = [] 598 | ligand_smiles_strings = [] 599 | complex_graphs = [] 600 | g_targets = [] 601 | for rec in samples: 602 | protein_graphs.append(rec["protein_graph"]) 603 | ligand_graphs.append(rec["ligand_graph"]) 604 | complex_graphs.append(rec["complex_graph"]) 605 | g_targets.append(rec["target"]) 606 | if "protein_smiles_strings" in rec: 607 | protein_smiles_strings.extend(rec["protein_smiles_strings"]) 608 | if "ligand_smiles_strings" in rec: 609 | ligand_smiles_strings.extend(rec["ligand_smiles_strings"]) 610 | return { 611 | "protein_graph": dgl.batch(protein_graphs), 612 | "ligand_graph": dgl.batch(ligand_graphs), 613 | "complex_graph": dgl.batch(complex_graphs), 614 | "g_targets": torch.tensor(g_targets) 615 | .to(torch.float32) 616 | .unsqueeze(-1), 617 | "protein_smiles_strings": protein_smiles_strings, 618 | "ligand_smiles_strings": ligand_smiles_strings, 619 | "ligand_smiles": None, 620 | } 621 | -------------------------------------------------------------------------------- /ppi/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .contact_map_utils import * 2 | from .xpdb import SloppyStructureBuilder 3 | from .polypeptide_featurizers import * 4 | from .residue_featurizers import * 5 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/crawl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jul 15 15:26:57 2020 4 | 5 | @author: lenovo 6 | 4650""" 7 | import requests 8 | from lxml import etree 9 | import csv 10 | url="http://huanglab.phys.hust.edu.cn/pepbdb/browse.php" 11 | pep_ids=[] 12 | for i in range(1,266): 13 | surl=url+"?pagenum="+str(i) 14 | html=requests.get(surl).text 15 | 
selector=etree.HTML(html) 16 | for j in range(2,52): 17 | pdbid=selector.xpath("/html/body/div[2]/table/tr["+str(j)+"]/td[1]/a/text()")[0] 18 | pdbid=str.lower(pdbid) 19 | peptideid=selector.xpath("/html/body/div[2]/table/tr["+str(j)+"]/td[2]/text()")[0] 20 | pep_id=pdbid+'_'+peptideid[1] 21 | pep_ids.append(pep_id) 22 | surl=url+"?pagenum=266" 23 | html=requests.get(surl).text 24 | selector=etree.HTML(html) 25 | for j in range(2,51): 26 | pdbid=selector.xpath("/html/body/div[2]/table/tr["+str(j)+"]/td[1]/a/text()")[0] 27 | pdbid=str.lower(pdbid) 28 | peptideid=selector.xpath("/html/body/div[2]/table/tr["+str(j)+"]/td[2]/text()")[0] 29 | pep_id=pdbid+'_'+peptideid[1] 30 | pep_ids.append(pep_id) 31 | print(len(pep_ids)) 32 | with open('crawl_results.csv','w',newline="",encoding="utf-8-sig") as f: 33 | writer=csv.writer(f) 34 | writer.writerow(["Peptide ID","Interacting peptide residues","Peptide sequence","Interacting receptor residues","Receptor sequence(s)"]) 35 | for pep_id in pep_ids[5000:]: 36 | if pep_id in ['6mk1_Z']: 37 | continue 38 | print(pep_id) 39 | row=[] 40 | url= "http://huanglab.phys.hust.edu.cn/pepbdb/db/"+pep_id+"/" 41 | html=requests.get(url).text 42 | selector=etree.HTML(html) 43 | ipr=selector.xpath("/html/body/div[2]/table[1]/tr/td/text()")[0] 44 | ps=selector.xpath("/html/body/div[2]/table[3]/tr/td/text()") 45 | irr=selector.xpath("/html/body/div[2]/table[2]/tr/td/text()") 46 | rs=selector.xpath("/html/body/div[2]/table[4]/tr/td/text()") 47 | irrdict={} 48 | for item in irr: 49 | item=item.split(': ') 50 | irrdict[item[0]]=item[1] 51 | rsdict={} 52 | string='' 53 | for item in rs: 54 | string+=item 55 | seq_list=string.split('>') 56 | seq_list.remove('') 57 | 58 | rsdict={} 59 | for seq in seq_list: 60 | rsdict[seq[0]]=seq[1:] 61 | 62 | for pid in irrdict: 63 | irr=pid+': '+irrdict[pid] 64 | rs=pid+': '+rsdict[pid] 65 | pepid=pep_id+'_'+pid 66 | row.append(pepid) 67 | row.append(ipr) 68 | psfinal=ps[0].strip('>')+': '+ps[1] 69 | row.append(psfinal) 70 | row.append(irr) 71 | row.append(rs) 72 | writer.writerow(row) 73 | row.clear() 74 | print("down") 75 | 76 | 77 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/query-mapping.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | pepbdb_inputdir = '/home/ec2-user/SageMaker/efs/data/CAMP/paper/' #tar -xvzf'd version of pepbdb-20200318 3 | 4 | querys=[] 5 | import csv 6 | import os 7 | def check_abnormal_aa(peptide_seq): 8 | len_seq = len(peptide_seq) 9 | cnt = 0 10 | standard_aa = ['G','A','P','V','L','I','M','F','Y','W','S','T','C','N','Q','K','H','R','D','E'] 11 | for i in peptide_seq: 12 | if i in standard_aa : 13 | cnt = cnt+1 14 | score = float(cnt)/len_seq 15 | return score 16 | def delete_duplicate(seq): 17 | seqsort=[] 18 | for i in seq: 19 | if i not in seqsort: 20 | seqsort.append(i) 21 | seqstr={} 22 | for s in seqsort: 23 | seqstr[s[:-1]]=s[-1:] 24 | return seqstr 25 | aa_dict = {'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', \ 26 | 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S', \ 27 | 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V', 'SEC': 'U', 'PLY': 'O'} 28 | csvfile=open('./crawl_results.csv','r') 29 | reader=csv.reader(csvfile) 30 | residue_dict={} 31 | seq_dict={} 32 | 33 | for item in reader: 34 | if reader.line_num==1: 35 | continue 36 | qid=item[0] 37 | querys.append(qid) 38 | 
pep_index=item[1].split(': ') #prot_index=item[3].split(': ') 39 | residue_dict[item[0]]=pep_index[1] 40 | seq_dict[item[0]]=item[2].split(': ')[1]#seq_dict[item[0]]=item[4].split(': ')[1] 41 | whole_dict={} 42 | for pid in querys: 43 | sequence=[] 44 | 45 | pdbid=pid[:6] 46 | chain=pid[5] #chain=pid[7] 47 | address=os.path.join(pepbdb_inputdir, 'pepbdb/', pdbid, 'peptide.pdb')#receptor.pdb 48 | if os.path.isfile(address)==False: 49 | continue 50 | with open(address,'r') as f: 51 | for line in f: 52 | line=line.split() 53 | if 'HETATM' in line[0] and len(line[0])>6: 54 | if line[3]==chain: 55 | index=line[4] 56 | elif line[3][0]==chain: 57 | index=line[3][1:] 58 | else: 59 | continue 60 | amino = line[2] 61 | if amino in aa_dict: 62 | amino=aa_dict[amino] 63 | else: 64 | amino='X' 65 | sequence.append(index+amino) 66 | else: 67 | if line[0]=='TER': 68 | continue 69 | amino=line[3] 70 | if line[4]==chain: 71 | index=line[5] 72 | elif line[4][0]==chain: 73 | index=line[4][1:] 74 | else: 75 | continue 76 | if amino in aa_dict: 77 | amino=aa_dict[amino] 78 | else: 79 | amino='X' 80 | sequence.append(index+amino) 81 | stringdict=delete_duplicate(sequence) 82 | whole_dict[pid]=stringdict 83 | starts={} 84 | sorted_lists={} 85 | pass_lists=[] 86 | print("Total queryes {0}".format(len(querys))) 87 | querys = [q for q in querys if q in whole_dict.keys()] 88 | print("Queries with pdb data {0}, filtered to those with pdb data locally {1}".format(len(whole_dict.keys()), len(querys))) 89 | 90 | with open('step2/peptide-mapping.txt','w') as f2:#prot-mapping.txt 91 | for pdbid in querys: 92 | if pdbid not in residue_dict: 93 | print(pdbid) 94 | continue 95 | residue=residue_dict[pdbid].split(', ') 96 | output='' 97 | outputseq='' 98 | real_sequ=whole_dict[pdbid] 99 | query_sequ=seq_dict[pdbid] 100 | for i in real_sequ: 101 | output+=real_sequ[i] 102 | index=output.find(query_sequ) 103 | if index==-1: 104 | continue 105 | #print(pdbid) 106 | flag=0 107 | for i in real_sequ: 108 | if flag==index: 109 | start=i 110 | #print(i) 111 | break 112 | else: 113 | flag+=1 114 | new_dict={} 115 | sorted_list=list(real_sequ.keys()) 116 | sorted_list=sorted_list[flag:flag+len(query_sequ)] 117 | sorted_lists[pdbid]=sorted_list 118 | for i in sorted_list: 119 | new_dict[i]=real_sequ[i] 120 | for i in new_dict: 121 | if i in residue: 122 | outputseq+='1' 123 | else: 124 | outputseq+='0' 125 | f2.write(pdbid+'\t'+query_sequ+'\t'+outputseq+'\n') 126 | with open('step2/query_peptide_sequence_index.txt', 'w') as f3: 127 | for pdbid in sorted_lists: 128 | sequ_index='' 129 | sorted_list=sorted_lists[pdbid] 130 | for i in sorted_list: 131 | sequ_index+=i+',' 132 | f3.write(pdbid+'\t'+sequ_index+'\n') 133 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/step1_pdb_process.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import subprocess 4 | 5 | def check_abnormal_aa(peptide_seq): 6 | len_seq = len(peptide_seq) 7 | cnt = 0 8 | standard_aa = ['G','A','P','V','L','I','M','F','Y','W','S','T','C','N','Q','K','H','R','D','E'] 9 | for i in peptide_seq: 10 | if i in standard_aa : 11 | cnt = cnt+1 12 | score = float(cnt)/len_seq 13 | return score 14 | 15 | def lower_chain(input_str): 16 | chain_list = list(input_str) 17 | output_list = [] 18 | 19 | for item in chain_list: 20 | if item.isalpha() : 21 | a=item.lower() 22 | else : 23 | a=item 24 | output_list.append(a) 25 | output_str = 
''.join(output_list) 26 | return output_str 27 | 28 | 29 | # Step 0: parse the fasta file downloaded from the RCSB PDB 30 | # INPUT : pdb_seqres.txt 31 | # OUTPUT: pdb_pep_chain, pdbid_all_fasta 32 | raw_str='' 33 | inputdir='/home/ec2-user/SageMaker/efs/data/CAMP/' 34 | with open(inputdir + 'pdb_seqres.txt','r') as f: 35 | for line in f.readlines(): 36 | raw_str = raw_str+line.replace('\n','###').replace('->', 'to').replace('<1>', '-1-').replace('<2>', '-2-') 37 | raw_list = raw_str.split('>') 38 | del raw_list[0] 39 | 40 | PDB_id_lst = [x.split('_')[0] for x in raw_list] 41 | PDB_chain_lst = [x.split('_')[1].split(' ')[0].lower() for x in raw_list] 42 | PDB_type_lst = [x.split('mol:')[1].split(' ')[0] for x in raw_list] 43 | PDB_seq_lst = [x.split('###')[1] for x in raw_list] 44 | PDB_seq_len_lst = [len(x) for x in PDB_seq_lst] 45 | df_fasta_raw =pd.DataFrame(list(zip(PDB_type_lst, PDB_seq_len_lst,PDB_seq_lst,PDB_id_lst,PDB_chain_lst)),\ 46 | columns=['PDB_type','PDB_seq_len','PDB_seq','PDB_id','chain']) 47 | df_fasta = df_fasta_raw[(df_fasta_raw.PDB_seq_len<=50)&(df_fasta_raw.PDB_type=='protein')] 48 | df_fasta_raw.to_csv('step1/pdbid_all_fasta', encoding='utf-8', index=False, sep = '\t') 49 | df_fasta.to_csv('step1/pdb_pep_chain', encoding='utf-8', index=False, sep = '\t') 50 | 51 | print('Step 0 is finished by generating two files : pdb_pep_chain & pdbid_all_fasta!') 52 | # Step1 : Load all PDB ids that might contain peptide interaction and plip prediction results 53 | # INPUT : pdb_pep_chain from Step 0 & analyzed file generated by PLIP (placed under ./peptide_result/). There is an example of PLIP result file called example_PLIP_result.txt) 54 | # OUTPUT: plip_predict_result 55 | 56 | def load(pdb_pep_dataset,plip_result_filename): #pdb_pep_chain #plip_predict_result 57 | df_fasta_pep = pd.read_csv(pdb_pep_dataset,sep='\t',header = 0) 58 | df_fasta_pep=df_fasta_pep.reset_index(drop=True) 59 | df_predict = pd.DataFrame(columns=['pdb_id','pep_chain','predicted_chain']) 60 | import random 61 | list_indices = list(range(df_fasta_pep.shape[0])) 62 | #random.shuffle(list_indices) 63 | outdir = '/home/ec2-user/SageMaker/efs/data/CAMP/processed/step1-peptide_result/' 64 | for i in list_indices[0:100]: 65 | pdb_id = df_fasta_pep['PDB_id'][i] 66 | chain = df_fasta_pep['chain'][i] 67 | #result_file_name = './peptide_result/'+pdb_id + '_'+chain+'_result.txt' 68 | result_file_name = outdir +pdb_id + '_'+chain+'_result.txt' 69 | try: 70 | for line in open(result_file_name): 71 | if line.startswith('Interacting chain(s):'): 72 | df_predict.loc[i] = [pdb_id,chain, str(line).replace('\n','')\ 73 | .replace('\r','')\ 74 | .replace('Interacting chain(s):','') 75 | .lower()] 76 | if i % 5000 == 0: 77 | print('already finished files',i) 78 | except : 79 | pass # COMMENT THIS IF YOU WANT TO GET MORE PREDICTIONS 80 | print('Found no file for',pdb_id); print("retrieving from PLIP tool") 81 | # Add --nohydro to output more determinative predictions 82 | dockercmd = '/usr/bin/docker run --rm -v ${PWD}/peptide_result/:/peptide_result/ -w /peptide_result/ -u $(id -u ${USER}):$(id -g ${USER}) pharmai/plip:latest -i ' + pdb_id + ' -yv -t --name ' + pdb_id + '_'+chain+'_result' 83 | p = subprocess.run(dockercmd, shell=True) 84 | p = subprocess.run('rm ${PWD}/peptide_result/*.pdb', shell=True) 85 | p = subprocess.run('rm ${PWD}/peptide_result/*.pse', shell=True) 86 | #print(i,pdb_id,line) 87 | print('finish loading!') 88 | print('-----------------------------------------------------') 89 | print(df_predict.info()) 
90 | print('Left with {0}'.format(df_predict.shape)) 91 | df_predict['predicted_chain_num'] = df_predict.predicted_chain.apply(lambda x :len(x.replace(' ','')) ) 92 | df_predict = df_predict.loc[df_predict.predicted_chain_num>0] 93 | df_predict = df_predict.drop('predicted_chain_num',axis = 1) 94 | df_predict['predicted_chain'] = df_predict.predicted_chain.apply(lambda x :\ 95 | x.replace(' ','')) 96 | df_predict['pep_chain'] = df_predict.pep_chain.apply(lambda x :\ 97 | x.replace(' ','')) 98 | df_predict = df_predict.reset_index(drop=True) 99 | print('finish removing PDB ids without any interaction') 100 | print('Left with {0}'.format(df_predict.shape)) 101 | print('-----------------------------------------------------') 102 | df_predict.predicted_chain = df_predict.predicted_chain.apply(lambda x: x.split(',')) 103 | lst_col = 'predicted_chain' 104 | df1 = pd.DataFrame({col:np.repeat(df_predict[col].values, df_predict[lst_col].str.len()) 105 | for col in df_predict.columns.difference([lst_col]) 106 | }).assign(**{lst_col:np.concatenate(df_predict[lst_col].values)})[df_predict.columns.tolist()] 107 | df_predict = df1 108 | # save organized data formatted like (pdb,pep_chain,predicted_prot_chain) 109 | file_name = plip_result_filename 110 | df_predict.to_csv(file_name, encoding='utf-8', index=False, sep='\t') 111 | print('finish exploding comma-seperated predicted chain, successfully saved records:',df_predict.shape[0]) 112 | 113 | print('Step 1 is finished by generating the PLIP prediction file : plip_predict_result. ') 114 | 115 | return df_predict 116 | 117 | # Step 2: Get fasta sequence of the predicted interacting chains 118 | # INPUT: pdbid_all_fasta from Step 0 119 | # OUTPUT: - 120 | def load_all_fasta(all_fasta_file,input_dataset): # pdbid_all_fasta # df_predict 121 | df_fasta = pd.read_csv(all_fasta_file,sep = '\t', header = 0) 122 | df_fasta_protein = df_fasta.loc[df_fasta.PDB_type=='protein'] 123 | 124 | #df_fasta_protein['PDB_id'] = df_fasta_protein.PDB_id_chain.apply(lambda x: x.split('_')[0]) 125 | #df_fasta_protein['chain'] = df_fasta_protein.PDB_id_chain.apply(lambda x: x.split('_')[1].lower()) 126 | df_fasta_vocabulary = df_fasta_protein[['PDB_id','chain','PDB_seq']] 127 | 128 | df_predict_det = pd.merge(input_dataset,df_fasta_vocabulary,how='left',\ 129 | left_on = ['pdb_id','pep_chain'],right_on = ['PDB_id','chain']) 130 | 131 | df_predict_det1 = pd.merge(df_predict_det,df_fasta_vocabulary,how='left',\ 132 | left_on = ['pdb_id','predicted_chain'],right_on = ['PDB_id','chain']) 133 | df_predict_det1 =df_predict_det1.drop(['PDB_id_x','chain_x','PDB_id_y','chain_y'],axis =1) 134 | df_predict_det1.columns = ['pdb_id','pep_chain','predicted_chain','pep_seq','prot_seq'] 135 | df_predict_det1['pep_seq_len'] = df_predict_det1.pep_seq.apply(lambda x: len(x)) 136 | df_predict_det1['prot_seq_len'] = df_predict_det1.prot_seq.apply(lambda x: len(x)) 137 | 138 | 139 | # check sequence length(peptide<=50 & protein >50) 140 | df_predict_det1 = df_predict_det1.loc[(df_predict_det1.pep_seq_len <= 50) & (df_predict_det1.prot_seq_len > 50)] 141 | 142 | 143 | # remove records with more than 20% AA is abnormal 144 | df_predict_det1['peptide_seq_score'] = df_predict_det1.pep_seq.apply(lambda x: check_abnormal_aa(x)) 145 | df_predict_det1 = df_predict_det1[df_predict_det1.peptide_seq_score >= 0.8] 146 | 147 | print('finish removing sequences without too many non-standard residues') 148 | print('-----------------------------------------------------') 149 | 150 | 151 | return df_predict_det1 152 
| 153 | # Step 3: Map Uniprot ID for PDB complex by protein-chain & PDB id 154 | # INPUT: data from Step 2 & pdb_chain_uniprot.tsv from SIFT 155 | # OUTPUT: UniProt_ID_list ( all IDs are the searching query on https://www.uniprot.org/uploadlists/ for unified sequence) 156 | def map_uniprot_chain(input_dataset,pdb_chain_uniprot_file): #df_predict_det1 #pdb_chain_uniprot.tsv 157 | df_sifts = pd.read_csv(pdb_chain_uniprot_file, sep = '\t', header = 0, comment='#') 158 | df_sifts = df_sifts[['PDB','CHAIN','SP_PRIMARY']] 159 | df_sifts_keep = df_sifts[df_sifts['CHAIN'] != df_sifts['CHAIN']] 160 | df_sifts = df_sifts[df_sifts['CHAIN'] == df_sifts['CHAIN']] 161 | df_sifts['CHAIN'] = df_sifts.CHAIN.apply(lambda x: lower_chain(x)) 162 | 163 | df_predict_det2 = pd.merge(input_dataset,df_sifts, how = 'left', \ 164 | left_on = ['pdb_id','predicted_chain'],right_on = ['PDB','CHAIN']) 165 | df_predict_det2 = df_predict_det2.drop(['PDB','CHAIN'],axis = 1) 166 | 167 | # subset records that don't have a matched protein chain Uniprot 168 | df_predict_det2_no_uni = df_predict_det2[df_predict_det2.SP_PRIMARY != df_predict_det2.SP_PRIMARY] 169 | df_predict_det2_no_uni = df_predict_det2_no_uni.reset_index(drop = True) 170 | 171 | df_predict_det2_no_uni = df_predict_det2_no_uni.drop(['prot_seq_len','peptide_seq_score'],axis = 1) 172 | df_predict_det2_no_uni = df_predict_det2_no_uni[['pdb_id','pep_chain','predicted_chain','pep_seq',\ 173 | 'pep_seq_len','SP_PRIMARY','prot_seq']] 174 | df_predict_det2_no_uni.rename(columns = {'prot_seq': 'Sequence'}, inplace=True) 175 | 176 | # focus on records with Uniprot Ids 177 | df_predict_det3 = df_predict_det2[df_predict_det2.SP_PRIMARY == df_predict_det2.SP_PRIMARY] 178 | 179 | # save matched uniport ID for retrieving information from Uniprot Website 180 | df_uni_id = df_predict_det3[['SP_PRIMARY']] 181 | df_uni_id.drop_duplicates(inplace = True) 182 | file_name = "step1/pdb_chain_uniprot-filtIDs.tsv" 183 | df_uni_id.to_csv(file_name, encoding = 'utf-8', index = False, sep = '\t') 184 | 185 | return df_predict_det2_no_uni,df_predict_det3 186 | 187 | 188 | # Step 4: Load Uniport sequences and family information & filter out MHC families 189 | # INPUT: the data from Step 3 & uniprot2seq from UniProt Website (a tab separated file with fields including Uniprot_id,Uniprot Sequence,Protein_name,Protein_families) 190 | # OUTPUT: interacted peptide-protein pairs from PDB (a '#' separated file with fields including pdb_id,pep_chain,prot_chain,pep_seq,Uniprot_id,prot_seq,protein_families) 191 | 192 | def load_uni_seq(input_dataset,uniprot2seq_file=None): 193 | df_uni2seq = pd.read_csv(uniprot2seq_file,sep = '\t',header = 0) 194 | df_uni2seq = df_uni2seq.drop('uniprot',axis = 1) 195 | df_uni2seq = df_uni2seq.drop_duplicates(['Uniprot_id','Sequence'],keep = 'first') 196 | df_uni2seq = df_uni2seq.fillna('Unknown_from_uniprot') 197 | 198 | # join by uniprot id 199 | df_predict_det4 = pd.merge(input_dataset,df_uni2seq,how = 'left',left_on = ['SP_PRIMARY'],right_on = ['Uniprot_id']) 200 | df_predict_det4 = df_predict_det4.drop(['Uniprot_id','Protein_name','prot_seq','prot_seq_len','peptide_seq_score'],axis = 1) 201 | df_predict_det4 = df_predict_det4.drop_duplicates(['pdb_id','pep_seq','SP_PRIMARY','Sequence'],keep = 'first') 202 | 203 | # filter out MHC 204 | df_predict_det4["MHC_flag"] = df_predict_det4.Protein_families.apply(lambda x: x.lower().find('mhc')) 205 | df_mhc = df_predict_det4.loc[df_predict_det4.MHC_flag!=-1][["pdb_id","Protein_families"]] 206 | df_mhc.columns = 
['pdb_id_mhc','prot_family_mhc'] 207 | 208 | # join by PDB id only(if a pdb contains mhc proteins,remove all records of the PDB id) 209 | df_predict_det5 = pd.merge(df_predict_det4, df_mhc,left_on = ['pdb_id'], right_on = ['pdb_id_mhc'], how='left') 210 | df_predict_det5 = df_predict_det5.loc[df_predict_det5.pdb_id_mhc!=df_predict_det5.pdb_id_mhc] 211 | df_predict_det5 = df_predict_det5.drop(['pdb_id_mhc','prot_family_mhc','MHC_flag'],axis =1) 212 | df_predict_det5.drop_duplicates(inplace=True) 213 | 214 | df_predict_det2_no_uni = df_predict_det2_no_uni.drop(['prot_seq_len','peptide_seq_score'],axis = 1) 215 | df_predict_det2_no_uni = df_predict_det2_no_uni[['pdb_id','pep_chain','predicted_chain','pep_seq','pep_seq_len','SP_PRIMARY','prot_seq']] 216 | df_predict_det2_no_uni.rename(columns={'prot_seq': 'Sequence'}, inplace=True) 217 | df_predict_det2_no_uni['Protein_families'] = pd.Series(['Unknown Uniprot_ids' for x in range(df_predict_det2_no_uni.shape[0])]) 218 | df_predict_det6 = pd.concat([df_predict_det2_no_uni,df_predict_det5],ignore_index=True) 219 | df_predict_det6['plip_prot_chain'] = df_predict_det6.predicted_chain.apply(lambda x:\ 220 | x.upper()) 221 | df_predict_det6 = df_predict_det6.drop_duplicates(['pep_seq','Sequence'],keep='first') 222 | df_predict_det6 = df_predict_det6.reset_index(drop=True) #8184 223 | df_predict_det6['prot_seq_len'] = df_predict_det6.Sequence.apply(lambda x: len(str(x))) 224 | df_predict_det6 = df_predict_det6[df_predict_det6.prot_seq_len<=5000] 225 | 226 | df_pdb_pairs = df_predict_det6[['pdb_id','pep_chain','plip_prot_chain','pep_seq','SP_PRIMARY','Sequence','Protein_families']] 227 | df_pdb_pairs.columns = ['pdb_id','pep_chain','prot_chain','pep_seq','Uniprot_id','prot_seq','protein_families'] 228 | file_name = 'train_pairs_pdb' 229 | df_pdb_pairs.to_csv(file_name, encoding = 'utf-8', index = False, sep = '#') 230 | 231 | 232 | return df_pdb_pairs 233 | 234 | 235 | df_predict = load('step1/pdb_pep_chain','step1/plip_predict_result') 236 | print(df_predict.head()) 237 | df_predict_det1 = load_all_fasta('step1/pdbid_all_fasta',df_predict) 238 | df_predict_det2_no_uni,df_predict_det3 = map_uniprot_chain(df_predict_det1,'step1/pdb_chain_uniprot.tsv') 239 | 240 | print(df_predict_det3.head()) 241 | print(df_predict_det2_no_uni.shape, df_predict_det3.shape) 242 | file_name = "step1/pdb_chain_uniprot-processed.tsv" 243 | df_predict_det3.to_csv(file_name, encoding = 'utf-8', index = False, sep = '\t') 244 | 245 | #df_pdb_pairs = load_uni_seq(df_predict_det3,'step1/uniprot2seq') 246 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/step2_pepBDB_pep_bindingsites.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pandas as pd 3 | 4 | """ 5 | BEFORE THIS STEP: (Jasleen added): 6 | 7 | 1a. Prepare pepbdb-2020/pepbdb directory from .tgz download at http://huanglab.phys.hust.edu.cn/pepbdb/db/download/ [DONE] - /home/ec2-user/SageMaker/efs/data/CAMP/paper/pepbdb 8 | 9 | 1b. Run crawl.py to generate crawl_results.csv [ONGOING] 10 | 11 | 2. 
Next, run query-mapping.py that uses inputs crawl_results.csv, and the peptide.pdb files in pepbdb-20200318/pepbdb/{pdbid} 12 | - This script outputs peptide-mapping.txt, query_peptide_sequence_index.txt 13 | 14 | """ 15 | # Step 1: According to the "PDB ID-Peptide Chain-Protein Chain" obtained in "step1_pdb_process.py" , retrieve the interacting information with following fields: 16 | # ("Peptide ID","Interacting peptide residues","Peptide sequence","Interacting receptor residues","Receptor sequence(s)") a 17 | # nd downloading the corresponding "peptide.pdb" files (please put under ./pepbdb-2020/pepbdb/$pdb_id$/peptide.pdb) 18 | 19 | # Step 2: To map the peptide sequences from PepBDB to the peptide sequences from the peptide sequences from the RCSB PDB() generated in "step1_pdb_process.py"). 20 | 21 | # Generate query (PepBDB version) sequence file called "query_peptide.fasta" & target (RSCB PDB) fasta sequence files called "target_peptide.fasta" for peptides 22 | # We use scripts under ./smith-waterman-src/ to align two versions of peptide sequences. The output is "alignment_result.txt" 23 | #python query_mapping.py #to get peptide sequence vectors (the output is "peptide-mapping.txt ") 24 | #python target_mapping.py #to get target sequence vector 25 | 26 | # Step 3: Loading and mapping labels of binding residues for peptide sequences 27 | # load peptide-protein pairs & pepBDB files (target : PDB fasta, query : pepBDB) 28 | df_train = pd.read_csv('step1/pdb_chain_uniprot-processed.tsv', header=0, sep='\t', comment='#') # The output of "step1_pdb_process.py" 29 | 30 | df_zy_pep = pd.read_csv('step2/peptide-mapping.txt',header=None,sep='\t') 31 | df_zy_pep.columns= ['bdb_id','bdb_pep_seq','pep_binding_vec'] 32 | df_zy_pep['pdb_id'] = df_zy_pep.bdb_id.apply(lambda x: x.split('_')[0]) 33 | df_zy_pep['pep_chain'] = df_zy_pep.bdb_id.apply(lambda x: x.split('_')[1].lower()) 34 | df_zy_pep['prot_chain'] = df_zy_pep.bdb_id.apply(lambda x: x.split('_')[2].upper()) 35 | df_zy_pep.drop_duplicates(['bdb_id'],inplace=True) 36 | 37 | # Since we did not run uniprot based MHC filtering step in Step 1, 38 | # We don't have the protein chain column defined. Instead, we do that 39 | # modification here (same as 'plip_prot_chain' in skipped step). 
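# The merge below joins on (pdb_id, pep_chain, prot_chain); both sides use a lower-cased
# peptide chain id and an upper-cased protein chain id, matching the bdb_id fields parsed
# out of step2/peptide-mapping.txt above.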
40 | df_train['prot_chain'] = df_train.predicted_chain.apply(lambda x:\ 41 | x.upper()) 42 | df_join = pd.merge(df_train, df_zy_pep, how='left', left_on=['pdb_id','pep_chain','prot_chain'],right_on=['pdb_id','pep_chain','prot_chain']) 43 | #df_v1 = df_join[['pdb_id','pep_chain','prot_chain','pep_seq','SP_PRIMARY','prot_seq','Protein_families','pep_binding_vec']] 44 | df_v1 = df_join[['pdb_id','pep_chain','prot_chain','pep_seq','SP_PRIMARY','prot_seq','pep_binding_vec']] 45 | print(df_v1.shape) 46 | 47 | # impute records that don't have bs information with -99999 48 | def extract_inter_idx(pep_seq,binding_vec): 49 | if binding_vec==binding_vec: 50 | if len(binding_vec) != len(pep_seq): 51 | print('Error length') 52 | return '-99999' 53 | else: 54 | binding_lst = [] 55 | for idx in range(len(binding_vec)): 56 | if binding_vec[idx]=='1': 57 | binding_lst.append(idx) 58 | binding_str = ','.join(str(e) for e in binding_lst) 59 | return binding_str 60 | else: 61 | return '-99999' 62 | 63 | df_v1['binding_idx'] = df_v1.apply(lambda x: extract_inter_idx(x.pep_seq,x.pep_binding_vec),axis=1) 64 | #df_part_pair = df_part_all[['pep_seq','prot_seq','binding_idx']] 65 | df_part_pair = df_v1[['pep_seq','prot_seq','binding_idx']] 66 | df_pos_bs = pd.merge(df_v1,df_part_pair,how='left',left_on=['pep_seq','prot_seq'],right_on=['pep_seq','prot_seq']).drop_duplicates().reset_index() 67 | df_pos_bs.to_csv('step2/pdb_pairs_bindingsites', encoding = 'utf-8', index = False, sep = ',') 68 | 69 | ofile = open('step2/peptide_sequences.fasta', 'w') 70 | for i in range(df_pos_bs.shape[0]): 71 | ofile.write('>' + df_pos_bs.loc[i, 'pdb_id'] + '_' + df_pos_bs.loc[i, 'pep_chain'] + '\n' + df_pos_bs.loc[i, 'pep_seq'] + '\n') 72 | ofile.close() 73 | 74 | ofile = open('step2/prot_sequences.fasta', 'w') 75 | for i in range(df_pos_bs.shape[0]): 76 | ofile.write('>' + df_pos_bs.loc[i, 'pdb_id'] + '_' + df_pos_bs.loc[i, 'prot_chain'] + '\n' + df_pos_bs.loc[i, 'prot_seq'] + '\n') 77 | ofile.close() 78 | 79 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/step3_iupred2a.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script takes in the sequence fasta files generated from Step2 and processes each sequence in each fasta file through the iupred tool. Note that the input files have to be of naming convention step2/[prot,peptide]_sequences.fasta, and that the results will be stored in the subfolder step3, in the current working directory. 
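The output pickle (step3/intrinsic_dict_<querytype>.pkl) maps each sequence string to a (sequence_length, 3) array whose columns are the IUPred2A long score, short score and ANCHOR score.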
3 | """ 4 | 5 | import pandas as pd 6 | import numpy as np 7 | import os 8 | import subprocess 9 | import pickle 10 | 11 | iupredtool="python3 /home/ec2-user/SageMaker/efs/data/CAMP/tools/iupred2a/iupred2a.py " 12 | querytype = 'prot' # prot or peptide 13 | 14 | fasta_file = './step2/' + querytype + '_sequences.fasta' 15 | 16 | outdir = os.path.abspath('./step3/') 17 | output_intrinsic_dict = outdir + '/intrinsic_dict_' + querytype + '.pkl' 18 | 19 | if not os.path.exists(os.path.abspath(outdir + '/tmp/')): 20 | os.makedirs(os.path.abspath(outdir + '/tmp')) 21 | 22 | def get_iupred_rawscores(fasta_file, outdir, ind, iupredtool=iupredtool): 23 | ## This command runs iupred2a for each sequence 24 | ## (each sequence must be a separate fasta file, so we create temp subset files from main fasta) 25 | ## This is because iupred2a parses all sequences in a fasta as one long sequence -_- 26 | # First we get the input ids and sequences 27 | with open(fasta_file, 'r') as f: 28 | id_list = [] 29 | seq_list = [] 30 | for line in f.readlines(): 31 | line = line.strip() 32 | if line.startswith('>'): 33 | id_list.append(line) 34 | else: 35 | if(len(line) > 0): 36 | seq_list.append(line) 37 | # Then run iupred for each sequence 38 | longseq_pred = {} 39 | for ix in range(len(id_list)): 40 | # Write out each sequence as a fasta file 41 | outfile_seq = os.path.abspath(outdir + '/tmp/' + id_list[ix].strip('>') + '.fasta') 42 | with open(outfile_seq, 'w') as outfile: 43 | for elem in [id_list[ix], seq_list[ix]]: 44 | outfile.write(elem + '\n') 45 | # Run that sequence through Intrinsic Disorder prediction (long, short) 46 | process = subprocess.Popen([iupredtool + outfile_seq + ' ' + ind], stdout=subprocess.PIPE,stderr=subprocess.PIPE, shell=True) 47 | process.wait() 48 | # Store the results 49 | tmpscores = [] 50 | for line in process.stdout.readlines(): 51 | line_list = line.decode("utf-8").strip() 52 | if (len(line_list)>0 and line_list[0]!='#'): 53 | tmpscores.append(line_list) 54 | longseq_pred[id_list[ix]] = tmpscores 55 | return longseq_pred 56 | 57 | # Load Intrinsic disorder 58 | # dict : {sequence: Intrinsic Disorder Matrix} 59 | # Intrinsic Disorder Matrix : (sequence length ,3) , last dimension :(long , short, ANCHOR score) 60 | def load_fasta(fasta_filename): 61 | raw_fasta_list = [] 62 | with open(fasta_filename,'r') as f: 63 | for line in f.readlines(): 64 | line_list = line.strip() 65 | raw_fasta_list.append(line_list) 66 | fasta_id_list = [x for x in raw_fasta_list if x[0]=='>'] 67 | fasta_sequence_list = [x for x in raw_fasta_list if x[0]!='>'] 68 | fasta_seq_len_list = [len(x) for x in fasta_sequence_list] 69 | print(len(fasta_id_list),len(fasta_sequence_list),len(fasta_seq_len_list)) 70 | fasta_dict = {} 71 | for i in range(len(fasta_id_list)): 72 | fasta_dict[fasta_id_list[i]]=(fasta_sequence_list[i],fasta_seq_len_list[i]) 73 | return fasta_dict 74 | 75 | def extract_intrinsic_disorder(fasta_filename, ind, outdir=outdir, iupredtool=iupredtool): 76 | fasta_dict = load_fasta(fasta_filename) 77 | raw_result_dict = get_iupred_rawscores(fasta_filename, outdir, ind, iupredtool=iupredtool) 78 | intrinsic_id_list = list(raw_result_dict.keys()) 79 | raw_score_dict = {} 80 | for idx in range(len(intrinsic_id_list)): 81 | prot_id = intrinsic_id_list[idx] 82 | seq_len = fasta_dict[prot_id][1] 83 | individual_score_list = [x.split('\t') for x in raw_result_dict[prot_id]] 84 | individual_score_list = [x[2:] for x in individual_score_list] 85 | individual_score_array = 
np.array(individual_score_list,dtype='float')
 86 |         raw_score_dict[prot_id] = individual_score_array
 87 |     print(len(fasta_dict.keys()),len(raw_score_dict.keys()))
 88 |     return fasta_dict, raw_score_dict
 89 | 
 90 | # long & short
 91 | # the input fasta file used in IUPred2A
 92 | fasta_dict_long, raw_score_dict_long = extract_intrinsic_disorder(fasta_filename=fasta_file, ind='long')
 93 | fasta_dict_short, raw_score_dict_short = extract_intrinsic_disorder(fasta_filename=fasta_file, ind='short')
 94 | 
 95 | Intrinsic_score_long = {}
 96 | for key in fasta_dict_long.keys():
 97 |     sequence = fasta_dict_long[key][0]
 98 |     seq_len = fasta_dict_long[key][1]
 99 |     Intrinsic = raw_score_dict_long[key]
100 |     if Intrinsic.shape[0]!= seq_len:
101 |         print('Error!')
102 |     Intrinsic_score_long[sequence] = Intrinsic
103 | 
104 | 
105 | Intrinsic_score_short = {}
106 | for key in fasta_dict_short.keys():
107 |     sequence = fasta_dict_short[key][0]
108 |     seq_len = fasta_dict_short[key][1]
109 |     Intrinsic = raw_score_dict_short[key]
110 |     if Intrinsic.shape[0]!= seq_len:
111 |         print('Error!')
112 |     Intrinsic_score_short[sequence] = Intrinsic
113 | 
114 | Intrinsic_score = {}
115 | for seq in Intrinsic_score_short.keys():
116 |     long_Intrinsic = Intrinsic_score_long[seq][:,0]
117 |     short_Intrinsic = Intrinsic_score_short[seq]
118 |     concat_Intrinsic = np.column_stack((long_Intrinsic,short_Intrinsic))
119 |     Intrinsic_score[seq] = concat_Intrinsic
120 | 
121 | 
122 | with open(output_intrinsic_dict,'wb') as f:  # output path defined above
123 |     pickle.dump(Intrinsic_score,f)
124 | 
125 | 
-------------------------------------------------------------------------------- /ppi/data_utils/camp/step3_pssm.py: --------------------------------------------------------------------------------
1 | """
2 | This script will
3 | a) generate .fasta and .pssm files for each individual protein sequence.
4 | b) process the pssms into a dict with the sequence IDs.
5 | 
6 | # May first need to run perl /home/ec2-user/SageMaker/efs/data/CAMP/tools/ncbi-blast-2.13.0+/bin/update_blastdb.pl swissprot.fa
7 | 
8 | This script will run the following psiblast command for each sequence, to generate the PSSM:
9 | $psiblast -db swissprot -query $outdir/tmp/1a61_i.fasta -num_iterations 3 -evalue 0.001 -out_ascii_pssm test.pssm
10 | """
11 | 
12 | import pandas as pd
13 | import numpy as np
14 | import os
15 | import subprocess
16 | import pickle
17 | 
18 | ### Generate Protein PSSM Files
19 | psiblast="/home/ec2-user/SageMaker/efs/data/CAMP/tools/ncbi-blast-2.13.0+/bin/psiblast "
20 | blastdb='export BLASTDB=/home/ec2-user/SageMaker/efs/data/CAMP/tools/ncbi-blast-2.13.0+/bin/; '
21 | swissprotdb='swissprot'
22 | psiblast_opts = '-db ' + swissprotdb + ' -num_iterations 3 -evalue 0.001 '
23 | 
24 | querytype = 'prot' # prot or peptide
25 | 
26 | fasta_file = './step2/' + querytype + '_sequences.fasta'
27 | 
28 | outdir = os.path.abspath('./step3/')
29 | output_pssm_dict = outdir + '/pssm_' + querytype + '.pkl'
30 | 
31 | if not os.path.exists(os.path.abspath(outdir + '/tmp/')):
32 |     os.makedirs(os.path.abspath(outdir + '/tmp'))
33 | 
34 | def get_pssm(fasta_file, outdir, psiblast=psiblast, psiblast_opts=psiblast_opts):
35 |     ## This command runs psiblast for each sequence
36 |     ## (each sequence must be a separate fasta file, so we create temp subset files from main fasta)
37 |     # First we get the input ids and sequences
38 |     with open(fasta_file, 'r') as f:
39 |         id_list = []
40 |         seq_list = []
41 |         for line in f.readlines():
42 |             line = line.strip()
43 |             if line.startswith('>'):
44 |                 id_list.append(line)
45 |             else:
46 |                 if(len(line) > 0):
47 |                     seq_list.append(line)
48 |     # Then run psiblast for each sequence
49 |     longseq_pred = {}
50 |     for ix in range(len(id_list)):
51 |         # Write out each sequence as a fasta file
52 |         outfile_seq = os.path.abspath(outdir + '/tmp/' + id_list[ix].strip('>') + '.fasta')
53 |         outfile_pssm = os.path.abspath(outdir + '/tmp/' + id_list[ix].strip('>') + '.pssm')
54 |         if not os.path.exists(outfile_pssm):
55 |             with open(outfile_seq, 'w') as outfile:
56 |                 for elem in [id_list[ix], seq_list[ix]]:
57 |                     outfile.write(elem + '\n')
58 |             # Run that sequence through psiblast to generate its PSSM
59 |             process = subprocess.Popen([
60 |                 blastdb +
61 |                 psiblast + psiblast_opts + '-query ' + outfile_seq + ' -out_ascii_pssm ' + outfile_pssm], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
62 |             stdout, stderr = process.communicate()
63 |     return id_list
64 | 
65 | proteins_list = get_pssm(fasta_file = fasta_file, outdir = outdir)
66 | 
67 | ### Load Protein PSSM Files
68 | # prot_pssm_dict : key is protein sequence, value is protein PSSM Matrix
69 | prot_pssm_dict_all={}
70 | prot_pssm_dict={}
71 | protein_num = len(proteins_list)  # total number of protein sequences
72 | inputs_dir = os.path.abspath(outdir + '/tmp/')
73 | for protid in proteins_list:
74 |     filename_pssm = protid.strip('>') + '.pssm'  # need to name each individual fasta and pssm file with the same prefix
75 |     filename_fasta = protid.strip('>') + '.fasta'
76 |     prot_key = protid.strip('>')
77 |     pssm_line_list= []
78 | 
79 |     with open(inputs_dir+'/'+filename_fasta,'r') as f:  # directory to store fasta files (single file of each protein)
80 |         for line in f.readlines():
81 |             prot_seq = line.strip()
82 | 
83 |     with open(inputs_dir+'/'+filename_pssm,'r') as f:  # directory to store pssm files (single file of
each protein)
 84 |         for line in f.readlines()[3:-6]:
 85 |             line_list = line.strip().split(' ')
 86 |             line_list = [x for x in line_list if x!=''][2:22]
 87 |             line_list = [int(x) for x in line_list]
 88 |             if len(line_list)!=20:
 89 |                 print('Error line:')
 90 |                 print(line_list)
 91 |             pssm_line_list.append(line_list)
 92 |     pssm_array = np.array(pssm_line_list)
 93 |     if pssm_array.shape[1]!=20:
 94 |         print('Error!')
 95 |         print(filename_pssm)
 96 |     else:
 97 |         prot_pssm_dict_all[prot_key] = (prot_seq,pssm_array)
 98 |         prot_pssm_dict[prot_seq]=pssm_array
 99 | 
100 | with open(output_pssm_dict,'wb') as f:  # output path defined above
101 |     pickle.dump(prot_pssm_dict,f)
102 | 
-------------------------------------------------------------------------------- /ppi/data_utils/camp/step3_ss.py: --------------------------------------------------------------------------------
 1 | """
 2 | fasta='step2/prot_sequences.fasta' # peptide_sequences.fasta
 3 | outpath='step2/ss/prot_ssp'
 4 | scratchtool="/home/ec2-user/SageMaker/efs/data/CAMP/tools/SCRATCH-1D_2.0/bin/run_scratch1d_predictors.sh --input_fasta $fasta --output_prefix $outpath "
 5 | 
 6 | fasta='step2/peptide_sequences.fasta' # peptide_sequences.fasta
 7 | outpath='step2/ss/peptide_ssp'
 8 | scratchtool="/home/ec2-user/SageMaker/efs/data/CAMP/tools/SCRATCH-1D_2.0/bin/run_scratch1d_predictors.sh --input_fasta $fasta --output_prefix $outpath "
 9 | """
10 | import pandas as pd
11 | import os, subprocess
12 | 
13 | #Secondary Structure
14 | seqtype = 'prot' # prot or peptide
15 | output_ss_filename = 'step3/' + seqtype + '_seq_ss.txt'
16 | 
17 | # Generate secondary structure predictions first
18 | scratchtool="/home/ec2-user/SageMaker/efs/data/CAMP/tools/SCRATCH-1D_2.0/bin/run_scratch1d_predictors.sh "
19 | 
20 | querytype = 'prot' #prot or peptide
21 | input_fasta = './step2/' + querytype + '_sequences.fasta'
22 | 
23 | outdir = os.path.abspath('./step2/ss/')
24 | outpath = outdir + '/' + querytype + '_ssp'
25 | process = subprocess.Popen([scratchtool + ' --input_fasta ' + input_fasta + ' --output_prefix ' + outpath], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
26 | stdout, stderr = process.communicate()
27 | 
28 | # load predicted ss features for sequences in the dataset
29 | def aa_ss_concat(aa,ss):
30 |     if len(aa)!= len(ss):
31 |         return 'string length error!'
31 | else: 32 | new_str = '' 33 | for i in range(len(aa)): 34 | concat_str = aa[i]+ss[i]+',' 35 | new_str = new_str+concat_str 36 | final_str = new_str[:-1] 37 | return final_str 38 | 39 | 40 | #df_org = pd.read_csv('./ss/seq_data.out.ss',sep='#',header = None) #the generated file by SCRATCH1D SSPro 41 | df_org = pd.read_csv('./step2/ss/' + seqtype + '_ssp.ss3',sep='#',header = None) #the generated file by SCRATCH1D SSPro 42 | df_org.columns = ['col_1'] 43 | 44 | # subset sequence dataframe and sse dataframe 45 | df_seqid = df_org.iloc[::4, ] # .iloc[seq_idx] 46 | df_seqid.columns = ['seq_id'] 47 | df_seqid.loc[:, 'seq_id'] = df_seqid['seq_id'].str.replace('>', '') 48 | df_seq = df_org.iloc[1::4, ] # .iloc[seq_idx] 49 | df_seq.columns = ['seq'] 50 | df_ss = df_org.iloc[2::4, ] 51 | df_ss.columns = ['seq_ss'] 52 | 53 | df_seqid = df_seqid.reset_index(drop=True) 54 | df_seq = df_seq.reset_index(drop=True) 55 | df_ss = df_ss.reset_index(drop=True) 56 | 57 | # join sequence & sse together 58 | df_seq_ss = pd.merge(df_seqid, df_ss,left_index=True, right_index=True) 59 | 60 | df_output_ss = pd.merge(df_seq_ss, df_seq,left_index=True, right_index=True) 61 | df_output_ss['concat_seq'] = df_output_ss.apply(lambda x: aa_ss_concat(x['seq'],x['seq_ss']),axis=1) 62 | df_output_ss.to_csv(output_ss_filename, encoding = 'utf-8', index = False, sep = '\t') # 'output_ss_filename' is the name of the output tsv you like 63 | 64 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/step4_agg_files.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script combines all the necessary files and organizes them as needed by preprocess_features.py 3 | """ 4 | import pandas as pd 5 | ss_prot = './step3/prot_seq_ss.txt' 6 | ss_peptide = './step3/peptide_seq_ss.txt' 7 | 8 | # Format the peptide-protein data like 9 | # seq, pep, label, pep_ss, seq_ss 10 | prot_df = pd.read_csv(ss_prot, sep='\t') 11 | pep_df = pd.read_csv(ss_peptide, sep='\t') 12 | prot_df.columns = [x + '_prot' for x in prot_df.columns.tolist()] 13 | pep_df.columns = [x + '_pep' for x in pep_df.columns.tolist()] 14 | 15 | pairs_df = pd.read_csv("step2/pdb_pairs_bindingsites", sep=",") 16 | mappings = pairs_df[['pdb_id', 'pep_chain', 'prot_chain']] 17 | mappings['seq_id_pep'] = mappings['pdb_id'] + '_' + mappings['pep_chain'] 18 | mappings['seq_id_prot'] = mappings['pdb_id'] + '_' + mappings['prot_chain'] 19 | 20 | merged_df = pd.concat([pd.concat([prot_df, mappings], axis=1, join='inner'), pep_df], axis=1, join='inner') 21 | merged_df['label'] = merged_df['pdb_id'] + "_" + merged_df['pep_chain'] + "_" + merged_df['prot_chain'] 22 | #out_df = merged_df[['seq_prot', 'seq_pep', 'label', 'seq_ss_pep', 'seq_ss_prot']] 23 | out_df = merged_df[['seq_prot', 'seq_pep', 'label', 'concat_seq_pep', 'concat_seq_prot']] 24 | out_df.to_csv('test_filename', encoding = 'utf-8', index = False, sep = '\t') 25 | 26 | out_df = merged_df[['seq_prot', 'seq_pep', 'concat_seq_pep', 'concat_seq_prot']] 27 | out_df.to_csv('test_data.tsv', encoding = 'utf-8', index = False, sep = '\t') 28 | 29 | ## Also copy over the pssm dict (protein) and the intrinsic disorder dicst (protein and peptide) 30 | import shutil 31 | shutil.copy2('./step3/pssm_prot.pkl', './dense_feature_dict/Protein_pssm_dict') 32 | shutil.copy2('./step3/intrinsic_dict_prot.pkl', './dense_feature_dict/Protein_Intrinsic_dict') 33 | shutil.copy2('./step3/intrinsic_dict_peptide.pkl', 
'./dense_feature_dict/Peptide_Intrinsic_dict_v3') 34 | -------------------------------------------------------------------------------- /ppi/data_utils/camp/step5_PreProcessFeatures.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import sys 4 | import pickle 5 | import math 6 | 7 | amino_acid_set = { "A": 1, "C": 2, "E": 3, "D": 4, "G": 5, "F": 6, "I": 7, "H": 8, "K": 9, "M": 10, "L": 11, 8 | "N": 12, "Q": 13, "P": 14, "S": 15, "R": 16, "T": 17, "W": 18, "V": 19, "Y": 20, "X": 21 } 9 | 10 | amino_acid_num = 21 11 | 12 | ss_set = {"C": 1, "H": 2, "E": 3} 13 | ss_number = 3 14 | 15 | physicochemical_set={'A': 1, 'C': 3, 'B': 7, 'E': 5, 'D': 5, 'G': 2, 'F': 1, 16 | 'I': 1, 'H': 6, 'K': 6, 'M': 1, 'L': 1, 'O': 7, 'N': 4, 17 | 'Q': 4, 'P': 1, 'S': 4, 'R': 6, 'U': 7, 'T': 4, 'W': 2, 18 | 'V': 1, 'Y': 4, 'X': 7, 'Z': 7} 19 | 20 | residue_list = list(amino_acid_set.keys()) 21 | ss_list = list(ss_set.keys()) 22 | 23 | 24 | new_key_list = [] 25 | for i in residue_list: 26 | for j in ss_list: 27 | str_1 = str(i)+str(j) 28 | new_key_list.append(str_1) 29 | 30 | new_value_list = [x+1 for x in list(range(amino_acid_num*ss_number))] 31 | 32 | seq_ss_set = dict(zip(new_key_list,new_value_list)) 33 | seq_ss_number = amino_acid_num*ss_number #75 34 | 35 | def label_sequence(line, pad_prot_len, res_ind): 36 | X = np.zeros(pad_prot_len) 37 | 38 | for i, res in enumerate(line[:pad_prot_len]): 39 | X[i] = res_ind[res] 40 | 41 | return X 42 | 43 | def label_seq_ss(line, pad_prot_len, res_ind): 44 | line = line.strip().split(',') 45 | X = np.zeros(pad_prot_len) 46 | for i ,res in enumerate(line[:pad_prot_len]): 47 | X[i] = res_ind[res] 48 | return X 49 | 50 | 51 | def sigmoid(x): 52 | return 1 / (1 + math.exp(-x)) 53 | 54 | sigmoid_array=np.vectorize(sigmoid) 55 | 56 | def padding_sigmoid_pssm(x,N): 57 | x = sigmoid_array(x) 58 | padding_array = np.zeros([N,x.shape[1]]) 59 | if x.shape[0]>=N: # sequence is longer than N 60 | padding_array[:N,:x.shape[1]] = x[:N,:] 61 | else: 62 | padding_array[:x.shape[0],:x.shape[1]] = x 63 | return padding_array 64 | 65 | def padding_intrinsic_disorder(x,N): 66 | padding_array = np.zeros([N,x.shape[1]]) 67 | if x.shape[0]>=N: # sequence is longer than N 68 | padding_array[:N,:x.shape[1]] = x[:N,:] 69 | else: 70 | padding_array[:x.shape[0],:x.shape[1]] = x 71 | return padding_array 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | input_file = sys.argv[1] 77 | f = open(input_file) 78 | pep_set = set() 79 | prot_set = set() 80 | pep_ss_set = set() 81 | prot_ss_set = set() 82 | 83 | for line in f.readlines()[1:]: # if the file has headers and pay attention to the columns (whether have peptide binding site labels) 84 | prot, pep, label, pep_ss, prot_ss = line.strip().split('\t') 85 | pep_set.add(pep) 86 | prot_set.add(prot) 87 | pep_ss_set.add(pep_ss) 88 | prot_ss_set.add(prot_ss) 89 | 90 | f.close() 91 | pep_len = [len(pep) for pep in pep_set] 92 | prot_len = [len(seq) for seq in prot_set] 93 | pep_ss_len = [len(pep_ss) for pep_ss in pep_ss_set] 94 | prot_ss_len = [len(seq_ss) for seq_ss in prot_ss_set] 95 | 96 | pep_len.sort() 97 | prot_len.sort() 98 | pep_ss_len.sort() 99 | prot_ss_len.sort() 100 | pad_pep_len = 50 101 | pad_prot_len = prot_len[int(0.8*len(prot_len))-1] 102 | print('num of peptides', len(pep_len), 'pad_pep_len', pad_pep_len) 103 | print('prot_set', len(prot_len), 'pad_prot_len', pad_prot_len) 104 | print('num of peptide ss', len(pep_ss_len), 'pad_pep_len', pad_pep_len) 105 | 
print('prot_ss_set', len(prot_ss_len), 'pad_prot_len', pad_prot_len) 106 | np.save('./preprocessing/pad_pep_len',pad_pep_len) 107 | np.save('./preprocessing/pad_prot_len',pad_prot_len) 108 | np.save('./preprocessing/pad_pep_len',pad_pep_len) 109 | np.save('./preprocessing/_pad_prot_len',pad_prot_len) 110 | 111 | 112 | # load raw dense features, the directory dense_feature_dict and proprocessing need to be created first. 113 | with open('./dense_feature_dict/Protein_pssm_dict', 'rb') as f: # value: (sequence_length, 20) without sigmoid 114 | protein_pssm_dict = pickle.load(f) 115 | 116 | with open('./dense_feature_dict/Protein_Intrinsic_dict', 'rb') as f: # value: (sequence_length, 3): long, short, anchor 117 | protein_intrinsic_dict = pickle.load(f) 118 | 119 | with open('./dense_feature_dict/Peptide_Intrinsic_dict_v3', 'rb') as f: # value: (sequence_length, 3): long, short, anchor 120 | peptide_intrinsic_dict = pickle.load(f) 121 | 122 | peptide_feature_dict = {} 123 | protein_feature_dict = {} 124 | 125 | peptide_ss_feature_dict = {} 126 | protein_ss_feature_dict = {} 127 | 128 | peptide_2_feature_dict = {} 129 | protein_2_feature_dict = {} 130 | 131 | peptide_dense_feature_dict = {} 132 | protein_dense_feature_dict = {} 133 | 134 | protein_intrinsic_feature_dict = {} 135 | f = open(input_file) 136 | for line in f.readlines()[1:]: 137 | prot, pep, label, pep_ss, prot_ss = line.strip().split('\t') 138 | if pep not in peptide_feature_dict: 139 | feature = label_sequence(pep, pad_pep_len, amino_acid_set) 140 | peptide_feature_dict[pep] = feature 141 | if prot not in protein_feature_dict: 142 | feature = label_sequence(prot, pad_prot_len, amino_acid_set) 143 | protein_feature_dict[prot] = feature 144 | if pep_ss not in peptide_ss_feature_dict: 145 | feature = label_seq_ss(pep_ss, pad_pep_len, seq_ss_set) 146 | peptide_ss_feature_dict[pep_ss] = feature 147 | if prot_ss not in protein_ss_feature_dict: 148 | feature = label_seq_ss(prot_ss, pad_prot_len, seq_ss_set) 149 | protein_ss_feature_dict[prot_ss] = feature 150 | if pep not in peptide_2_feature_dict: 151 | feature = label_sequence(pep, pad_pep_len, physicochemical_set) 152 | peptide_2_feature_dict[pep] = feature 153 | if prot not in protein_2_feature_dict: 154 | feature = label_sequence(prot, pad_prot_len, physicochemical_set) 155 | protein_2_feature_dict[prot] = feature 156 | if pep not in peptide_dense_feature_dict: 157 | feature = padding_intrinsic_disorder(peptide_intrinsic_dict[pep], pad_pep_len) 158 | peptide_dense_feature_dict[pep] = feature 159 | if prot not in protein_dense_feature_dict: 160 | feature_pssm = padding_sigmoid_pssm(protein_pssm_dict[prot], pad_prot_len) 161 | feature_intrinsic = padding_intrinsic_disorder(protein_intrinsic_dict[prot], pad_prot_len) 162 | feature_dense = np.concatenate((feature_pssm, feature_intrinsic), axis=1) 163 | protein_dense_feature_dict[prot] = feature_dense 164 | if prot not in protein_intrinsic_feature_dict: 165 | feature_intrinsic = padding_intrinsic_disorder(protein_intrinsic_dict[prot], pad_prot_len) 166 | protein_intrinsic_feature_dict[prot] = feature_intrinsic 167 | 168 | f.close() 169 | 170 | with open('./preprocessing/peptide_feature_dict','wb') as f: 171 | pickle.dump(peptide_feature_dict,f) 172 | with open('./preprocessing/protein_feature_dict','wb') as f: 173 | pickle.dump(protein_feature_dict,f) 174 | with open('./preprocessing/peptide_ss_feature_dict','wb') as f: 175 | pickle.dump(peptide_ss_feature_dict,f) 176 | with open('./preprocessing/protein_ss_feature_dict','wb') as f: 177 
| pickle.dump(protein_ss_feature_dict,f) 178 | with open('./preprocessing/peptide_2_feature_dict','wb') as f: 179 | pickle.dump(peptide_2_feature_dict,f) 180 | with open('./preprocessing/protein_2_feature_dict','wb') as f: 181 | pickle.dump(protein_2_feature_dict,f) 182 | with open('./preprocessing/peptide_dense_feature_dict','wb') as f: 183 | pickle.dump(peptide_dense_feature_dict,f) 184 | with open('./preprocessing/protein_dense_feature_dict','wb') as f: 185 | pickle.dump(protein_dense_feature_dict,f) 186 | 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /ppi/data_utils/contact_map_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | """ 5 | Helpers for parsing protein structure files and generating contact maps. 6 | """ 7 | 8 | import gzip 9 | import boto3 10 | import numpy as np 11 | import pandas as pd 12 | from io import StringIO 13 | from Bio.PDB.Polypeptide import three_to_one, is_aa 14 | from Bio.PDB import PDBParser 15 | from Bio.PDB.PDBIO import PDBIO 16 | from Bio.PDB.Entity import Entity as PDBEntity 17 | from rdkit import Chem 18 | from tqdm import tqdm 19 | from .xpdb import SloppyStructureBuilder 20 | 21 | 22 | def gunzip_to_ram(gzip_file_path): 23 | """ 24 | gunzip a gzip file and decode it to a io.StringIO object. 25 | 26 | Args: 27 | gzip_file_path: String. Gunzip filepath. 28 | 29 | Returns: 30 | io.StringIO object. 31 | """ 32 | content = [] 33 | with gzip.open(gzip_file_path, "rb") as f: 34 | for line in f: 35 | content.append(line.decode("utf-8")) 36 | 37 | temp_fp = StringIO("".join(content)) 38 | return temp_fp 39 | 40 | 41 | def _parse_structure(parser, name, file_path): 42 | """Parse a .pdb or .cif file into a structure object. 43 | The file can be gzipped. 44 | 45 | Args: 46 | parser: a Bio.PDB.PDBParser or Bio.PDB.MMCIFParser instance. 47 | name: String. name of protein 48 | file_path: String. Filpath of the pdb or cif file to be read. 49 | 50 | Retruns: 51 | a Bio.PDB.Structure object representing the protein structure. 52 | 53 | """ 54 | if pd.isnull(file_path): 55 | return None 56 | if file_path.endswith(".gz"): 57 | structure = parser.get_structure(name, gunzip_to_ram(file_path)) 58 | else: # not gzipped 59 | structure = parser.get_structure(name, file_path) 60 | return structure 61 | 62 | 63 | parse_pdb_structure = _parse_structure # for backward compatiblity 64 | 65 | 66 | def parse_structure(pdb_parser, cif_parser, name, file_path): 67 | """Parse a .pdb file or .cif file into a structure object. 68 | The file can be gzipped. 69 | 70 | Args: 71 | pdb_parser: a Bio.PDB.PDBParser instance 72 | cif_parser: Bio.PDB.MMCIFParser instance 73 | name: String. name of protein 74 | file_path: String. Filpath of the pdb or cif file to be read. 75 | 76 | Return: 77 | a Bio.PDB.Structure object representing the protein structure. 78 | """ 79 | if file_path.rstrip(".gz").endswith("pdb"): 80 | return _parse_structure(pdb_parser, name, file_path) 81 | else: 82 | return _parse_structure(cif_parser, name, file_path) 83 | 84 | 85 | def three_to_one_standard(res): 86 | """Encode non-standard AA to X. 87 | 88 | Args: 89 | res: a Bio.PDB.Residue object representing the residue. 90 | 91 | Return: 92 | String. One letter code of the residue. 
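    Example (illustrative; in practice this is called with a residue name
    string, as in chain_to_coords below): three_to_one_standard("ALA")
    gives "A", while a non-standard residue name such as "MSE" maps to "X".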
93 | """ 94 | if not is_aa(res, standard=True): 95 | return "X" 96 | return three_to_one(res) 97 | 98 | 99 | def get_atom_coords(residue, target_atoms=["N", "CA", "C", "O"]): 100 | """Extract the coordinates of the target_atoms from an AA residue. 101 | Handles exception where residue doesn't contain certain atoms 102 | by setting coordinates to np.nan 103 | 104 | Args: 105 | residue: a Bio.PDB.Residue object. 106 | target_atoms: Target atoms which residues will be resturned. 107 | 108 | Returns: 109 | np arrays with target atoms 3D coordinates in the order of target atoms. 110 | """ 111 | atom_coords = [] 112 | for atom in target_atoms: 113 | try: 114 | coord = residue[atom].coord 115 | except KeyError: 116 | coord = [np.nan] * 3 117 | atom_coords.append(coord) 118 | return np.asarray(atom_coords) 119 | 120 | 121 | def chain_to_coords( 122 | chain, target_atoms=["N", "CA", "C", "O"], name="", residue_smiles=False 123 | ): 124 | """Convert a PDB chain in to coordinates of target atoms from all 125 | AAs 126 | 127 | Args: 128 | chain: a Bio.PDB.Chain object 129 | target_atoms: Target atoms which residues will be resturned. 130 | name: String. Name of the protein. 131 | residue_smiles: bool. Whether to get a list of smiles strings for the residues 132 | Returns: 133 | Dictonary containing protein sequence `seq`, 3D coordinates `coord` and name `name`. 134 | 135 | """ 136 | output = {} 137 | # get AA sequence in the pdb structure 138 | pdb_seq = "".join( 139 | [ 140 | three_to_one_standard(res.get_resname()) 141 | for res in chain.get_residues() 142 | if is_aa(res) 143 | ] 144 | ) 145 | if len(pdb_seq) <= 1: 146 | # has no or only 1 AA in the chain 147 | return None 148 | output["seq"] = pdb_seq 149 | if residue_smiles: 150 | residues = [] 151 | for res in chain.get_residues(): 152 | if is_aa(res): 153 | mol = residue_to_mol(res) 154 | residues.append(Chem.MolToSmiles(mol)) 155 | output["residues"] = residues 156 | # get the atom coords 157 | coords = np.asarray( 158 | [ 159 | get_atom_coords(res, target_atoms=target_atoms) 160 | for res in chain.get_residues() 161 | if is_aa(res) 162 | ] 163 | ) 164 | output["coords"] = coords.tolist() 165 | output["name"] = "{}-{}".format(name, chain.id) 166 | return output 167 | 168 | 169 | def read_file_from_s3(bucket: str, prefix: str): 170 | s3 = boto3.resource("s3") 171 | obj = s3.Object(bucket, prefix) 172 | return obj.get()["Body"] 173 | 174 | 175 | def extract_coords( 176 | structure, target_atoms=["N", "CA", "C", "O"], residue_smiles=False 177 | ): 178 | """ 179 | Extract the atomic coordinates for all the chains. 180 | """ 181 | records = {} 182 | for chain in structure.get_chains(): 183 | record = chain_to_coords( 184 | chain, 185 | name=structure.id, 186 | target_atoms=target_atoms, 187 | residue_smiles=residue_smiles, 188 | ) 189 | if record is not None: 190 | records[chain.id] = record 191 | return records 192 | 193 | 194 | def parse_pdb_ids(pdb_ids: list, residue_smiles=False) -> dict: 195 | """ 196 | Parse a list of PDB ids to structures by first retrieving 197 | PDB files from AWS OpenData Registry, then parse to structure objects. 
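    The result is keyed by PDB id and then by chain id, each record carrying the
    fields produced by chain_to_coords, roughly (identifiers are illustrative):
        {"1abc": {"A": {"seq": "...", "coords": [...], "name": "1abc-A"}}}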
198 | """ 199 | PDB_BUCKET_NAME = "pdbsnapshots" 200 | pdb_parser = PDBParser( 201 | QUIET=True, 202 | PERMISSIVE=True, 203 | structure_builder=SloppyStructureBuilder(), 204 | ) 205 | parsed_structures = {} 206 | for pdb_id in tqdm(pdb_ids): 207 | try: 208 | pdb_file = read_file_from_s3( 209 | PDB_BUCKET_NAME, 210 | f"20220103/pub/pdb/data/structures/all/pdb/pdb{pdb_id.lower()}.ent.gz", 211 | ) 212 | except Exception as e: 213 | print(pdb_id, "caused the following error:") 214 | print(e) 215 | else: 216 | structure = pdb_parser.get_structure( 217 | pdb_id, gunzip_to_ram(pdb_file) 218 | ) 219 | rec = extract_coords(structure, residue_smiles=residue_smiles) 220 | parsed_structures[pdb_id] = rec 221 | return parsed_structures 222 | 223 | 224 | def remove_nan_residues(rec: dict) -> dict: 225 | """ 226 | Remove the residues from a parsed protein chain where coordinates contains nan's 227 | """ 228 | if len(rec["coords"]) == 0: 229 | return None 230 | coords = np.asarray(rec["coords"]) # shape: (n_residues, 4, 3) 231 | mask = np.isfinite(coords.sum(axis=(1, 2))) 232 | if mask.sum() == 0: 233 | # all residues coordinates are nan's 234 | return None 235 | if mask.sum() < coords.shape[0]: 236 | rec["seq"] = "".join(np.asarray(list(rec["seq"]))[mask]) 237 | rec["coords"] = coords[mask].tolist() 238 | return rec 239 | 240 | 241 | def residue_to_mol(residue: PDBEntity, **kwargs) -> Chem.rdchem.Mol: 242 | """Convert a parsed Biopython PDB object (Residue, Chain, Structure) to a 243 | rdkit Mol object""" 244 | # Write the PDB object into PDB string 245 | stream = StringIO() 246 | pdbio = PDBIO() 247 | pdbio.set_structure(residue) 248 | pdbio.save(stream) 249 | # Parse the PDB string with rdkit 250 | mol = Chem.MolFromPDBBlock(stream.getvalue(), **kwargs) 251 | return mol 252 | 253 | 254 | def mol_to_pdb_structure( 255 | mol: Chem.rdchem.Mol, pdb_parser=None, protein_id="" 256 | ) -> PDBEntity: 257 | """ 258 | Convert a rdkit Mol object to a Biopython PDB Structure object 259 | """ 260 | # Write the Mol object into PDB string 261 | stream = StringIO() 262 | stream.write(Chem.MolToPDBBlock(mol)) 263 | stream.seek(0) 264 | # parse the stream into a PDB Structure object 265 | if not pdb_parser: 266 | pdb_parser = PDBParser( 267 | QUIET=True, 268 | PERMISSIVE=True, 269 | structure_builder=SloppyStructureBuilder(), 270 | ) 271 | structure = pdb_parser.get_structure(protein_id, stream) 272 | return structure 273 | -------------------------------------------------------------------------------- /ppi/data_utils/pignet_featurizers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for computing physically relevant properties of molecules 3 | Source: https://github.com/ACE-KAIST/PIGNet/blob/main/dataset.py 4 | """ 5 | 6 | import random 7 | from typing import Any, Dict, List 8 | 9 | import numpy as np 10 | from rdkit import Chem, RDLogger 11 | from rdkit.Chem import Atom, Mol 12 | from rdkit.Chem.rdMolDescriptors import CalcNumRotatableBonds 13 | 14 | RDLogger.DisableLog("rdApp.*") 15 | random.seed(0) 16 | 17 | INTERACTION_TYPES = [ 18 | # "saltbridge", 19 | "hbonds", 20 | # "pication", 21 | # "pistack", 22 | # "halogen", 23 | # "waterbridge", 24 | "hydrophobic", 25 | "metal_complexes", 26 | ] 27 | pt = """ 28 | H,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,HE 29 | LI,BE,1,1,1,1,1,1,1,1,1,1,B,C,N,O,F,NE 30 | NA,MG,1,1,1,1,1,1,1,1,1,1,AL,SI,P,S,CL,AR 31 | K,CA,SC,TI,V,CR,MN,FE,CO,NI,CU,ZN,GA,GE,AS,SE,BR,KR 32 | RB,SR,Y,ZR,NB,MO,TC,RU,RH,PD,AG,CD,IN,SN,SB,TE,I,XE 33 | 
CS,BA,LU,HF,TA,W,RE,OS,IR,PT,AU,HG,TL,PB,BI,PO,AT,RN 34 | """ 35 | PERIODIC_TABLE = dict() 36 | for i, per in enumerate(pt.split()): 37 | for j, ele in enumerate(per.split(",")): 38 | PERIODIC_TABLE[ele] = (i, j) 39 | PERIODS = [0, 1, 2, 3, 4, 5] 40 | GROUPS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 41 | SYMBOLS = ["C", "N", "O", "S", "F", "P", "Cl", "Br", "X"] 42 | DEGREES = [0, 1, 2, 3, 4, 5] 43 | HYBRIDIZATIONS = [ 44 | Chem.rdchem.HybridizationType.S, 45 | Chem.rdchem.HybridizationType.SP, 46 | Chem.rdchem.HybridizationType.SP2, 47 | Chem.rdchem.HybridizationType.SP3, 48 | Chem.rdchem.HybridizationType.SP3D, 49 | Chem.rdchem.HybridizationType.SP3D2, 50 | Chem.rdchem.HybridizationType.UNSPECIFIED, 51 | ] 52 | FORMALCHARGES = [-2, -1, 0, 1, 2, 3, 4] 53 | METALS = ("Zn", "Mn", "Co", "Mg", "Ni", "Fe", "Ca", "Cu") 54 | HYDROPHOBICS = ("F", "CL", "BR", "I") 55 | VDWRADII = { 56 | 6: 1.90, 57 | 7: 1.8, 58 | 8: 1.7, 59 | 16: 2.0, 60 | 15: 2.1, 61 | 9: 1.5, 62 | 17: 1.8, 63 | 35: 2.0, 64 | 53: 2.2, 65 | 30: 1.2, 66 | 25: 1.2, 67 | 26: 1.2, 68 | 27: 1.2, 69 | 12: 1.2, 70 | 28: 1.2, 71 | 20: 1.2, 72 | 29: 1.2, 73 | } 74 | HBOND_DONOR_INDICES = ["[!#6;!H0]"] 75 | HBOND_ACCEPPTOR_SMARTS = [ 76 | "[$([!#6;+0]);!$([F,Cl,Br,I]);!$([o,s,nX3]);!$([Nv5,Pv5,Sv4,Sv6])]" 77 | ] 78 | 79 | 80 | def get_period_group(atom: Atom) -> List[bool]: 81 | period, group = PERIODIC_TABLE[atom.GetSymbol().upper()] 82 | return one_of_k_encoding(period, PERIODS) + one_of_k_encoding( 83 | group, GROUPS 84 | ) 85 | 86 | 87 | def one_of_k_encoding(x: Any, allowable_set: List[Any]) -> List[bool]: 88 | if x not in allowable_set: 89 | raise Exception( 90 | "input {0} not in allowable set{1}:".format(x, allowable_set) 91 | ) 92 | return list(map(lambda s: x == s, allowable_set)) 93 | 94 | 95 | def one_of_k_encoding_unk(x: Any, allowable_set: List[Any]) -> List[bool]: 96 | """Maps inputs not in the allowable set to the last element.""" 97 | if x not in allowable_set: 98 | x = allowable_set[-1] 99 | return list(map(lambda s: x == s, allowable_set)) 100 | 101 | 102 | def atom_feature(mol: Mol, atom_index: int) -> np.ndarray: 103 | atom = mol.GetAtomWithIdx(atom_index) 104 | return np.array( 105 | one_of_k_encoding_unk(atom.GetSymbol(), SYMBOLS) 106 | + one_of_k_encoding_unk(atom.GetDegree(), DEGREES) 107 | + one_of_k_encoding_unk(atom.GetHybridization(), HYBRIDIZATIONS) 108 | + one_of_k_encoding_unk(atom.GetFormalCharge(), FORMALCHARGES) 109 | + get_period_group(atom) 110 | + [atom.GetIsAromatic()] 111 | ) # (9, 6, 7, 7, 24, 1) --> total 54 112 | 113 | 114 | def get_atom_feature(mol: Mol) -> np.ndarray: 115 | natoms = mol.GetNumAtoms() 116 | H = [] 117 | for idx in range(natoms): 118 | H.append(atom_feature(mol, idx)) 119 | H = np.array(H) 120 | return H 121 | 122 | 123 | def get_vdw_radius(atom: Atom) -> float: 124 | atomic_num = atom.GetAtomicNum() 125 | if VDWRADII.get(atomic_num): 126 | return VDWRADII[atomic_num] 127 | return Chem.GetPeriodicTable().GetRvdw(atomic_num) 128 | 129 | 130 | def get_hydrophobic_atom(mol: Mol) -> np.ndarray: 131 | natoms = mol.GetNumAtoms() 132 | hydrophobic_indice = np.zeros((natoms,)) 133 | for atom_idx in range(natoms): 134 | atom = mol.GetAtomWithIdx(atom_idx) 135 | symbol = atom.GetSymbol() 136 | if symbol.upper() in HYDROPHOBICS: 137 | hydrophobic_indice[atom_idx] = 1 138 | elif symbol.upper() in ["C"]: 139 | neighbors = [x.GetSymbol() for x in atom.GetNeighbors()] 140 | neighbors_wo_c = list(set(neighbors) - set(["C"])) 141 | if len(neighbors_wo_c) == 0: 142 | 
hydrophobic_indice[atom_idx] = 1
143 |     return hydrophobic_indice
144 | 
145 | 
146 | def get_A_hydrophobic(ligand_mol: Mol, target_mol: Mol) -> np.ndarray:
147 |     ligand_indice = get_hydrophobic_atom(ligand_mol)
148 |     target_indice = get_hydrophobic_atom(target_mol)
149 |     return np.outer(ligand_indice, target_indice)
150 | 
151 | 
152 | def get_hbond_atom_indices(mol: Mol, smarts_list: List[str]) -> np.ndarray:
153 |     indice = []
154 |     for smarts in smarts_list:
155 |         smarts = Chem.MolFromSmarts(smarts)
156 |         indice += [idx[0] for idx in mol.GetSubstructMatches(smarts)]
157 |     indice = np.array(indice)
158 |     return indice
159 | 
160 | 
161 | def get_A_hbond(ligand_mol: Mol, target_mol: Mol) -> np.ndarray:
162 |     ligand_h_acc_indice = get_hbond_atom_indices(
163 |         ligand_mol, HBOND_ACCEPPTOR_SMARTS
164 |     )
165 |     target_h_acc_indice = get_hbond_atom_indices(
166 |         target_mol, HBOND_ACCEPPTOR_SMARTS
167 |     )
168 |     ligand_h_donor_indice = get_hbond_atom_indices(
169 |         ligand_mol, HBOND_DONOR_INDICES
170 |     )
171 |     target_h_donor_indice = get_hbond_atom_indices(
172 |         target_mol, HBOND_DONOR_INDICES
173 |     )
174 | 
175 |     hbond_indice = np.zeros(
176 |         (ligand_mol.GetNumAtoms(), target_mol.GetNumAtoms())
177 |     )
178 |     for i in ligand_h_acc_indice:
179 |         for j in target_h_donor_indice:
180 |             hbond_indice[i, j] = 1
181 |     for i in ligand_h_donor_indice:
182 |         for j in target_h_acc_indice:
183 |             hbond_indice[i, j] = 1
184 |     return hbond_indice
185 | 
186 | 
187 | def get_A_metal_complexes(ligand_mol: Mol, target_mol: Mol) -> np.ndarray:
188 |     ligand_h_acc_indice = get_hbond_atom_indices(
189 |         ligand_mol, HBOND_ACCEPPTOR_SMARTS
190 |     )
191 |     target_h_acc_indice = get_hbond_atom_indices(
192 |         target_mol, HBOND_ACCEPPTOR_SMARTS
193 |     )
194 |     ligand_metal_indice = np.array(
195 |         [
196 |             idx
197 |             for idx in range(ligand_mol.GetNumAtoms())
198 |             if ligand_mol.GetAtomWithIdx(idx).GetSymbol() in METALS
199 |         ]
200 |     )
201 |     target_metal_indice = np.array(
202 |         [
203 |             idx
204 |             for idx in range(target_mol.GetNumAtoms())
205 |             if target_mol.GetAtomWithIdx(idx).GetSymbol() in METALS
206 |         ]
207 |     )
208 | 
209 |     metal_indice = np.zeros(
210 |         (ligand_mol.GetNumAtoms(), target_mol.GetNumAtoms())
211 |     )
212 |     for ligand_idx in ligand_h_acc_indice:
213 |         for target_idx in target_metal_indice:
214 |             metal_indice[ligand_idx, target_idx] = 1
215 |     for ligand_idx in ligand_metal_indice:
216 |         for target_idx in target_h_acc_indice:
217 |             metal_indice[ligand_idx, target_idx] = 1
218 |     return metal_indice
219 | 
220 | 
221 | def get_interaction_indices(ligand_mol: Mol, target_mol: Mol) -> np.array:
222 |     interaction_indice = np.zeros(
223 |         (
224 |             len(INTERACTION_TYPES),
225 |             ligand_mol.GetNumAtoms(),
226 |             target_mol.GetNumAtoms(),
227 |         )
228 |     )
229 |     interaction_indice[0] = get_A_hbond(ligand_mol, target_mol)
230 |     interaction_indice[1] = get_A_metal_complexes(ligand_mol, target_mol)
231 |     interaction_indice[2] = get_A_hydrophobic(ligand_mol, target_mol)
232 |     return interaction_indice
233 | 
234 | 
235 | def mol_to_feature(
236 |     ligand_mol: Mol, target_mol: Mol, compute_full: bool = False
237 | ) -> Dict[str, Any]:
238 |     """
239 |     Args:
240 |         compute_full: if True, compute components for
241 |             both inter and intra energies.
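    Returns (as assembled in the body below):
        dict with keys "interaction_indice", "ligand_pos", "target_pos", "rotor",
        "ligand_vdw_radii", "target_vdw_radii", "ligand_non_metal" and
        "target_non_metal"; when compute_full is True it additionally carries
        "ligand_interaction_indice", "target_interaction_indice" and "rotor_target".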
242 | """ 243 | # Remove hydrogens 244 | ligand_mol = Chem.RemoveHs(ligand_mol) 245 | target_mol = Chem.RemoveHs(target_mol) 246 | 247 | # prepare ligand 248 | ligand_pos = np.array(ligand_mol.GetConformers()[0].GetPositions()) 249 | 250 | # prepare protein 251 | target_pos = np.array(target_mol.GetConformers()[0].GetPositions()) 252 | 253 | interaction_indice = get_interaction_indices(ligand_mol, target_mol) 254 | 255 | # count rotatable bonds 256 | rotor = CalcNumRotatableBonds(ligand_mol) 257 | 258 | # no metal 259 | ligand_non_metal = np.array( 260 | [ 261 | 1 if atom.GetSymbol() not in METALS else 0 262 | for atom in ligand_mol.GetAtoms() 263 | ] 264 | ) 265 | target_non_metal = np.array( 266 | [ 267 | 1 if atom.GetSymbol() not in METALS else 0 268 | for atom in target_mol.GetAtoms() 269 | ] 270 | ) 271 | # vdw radius 272 | ligand_vdw_radii = np.array( 273 | [get_vdw_radius(atom) for atom in ligand_mol.GetAtoms()] 274 | ) 275 | target_vdw_radii = np.array( 276 | [get_vdw_radius(atom) for atom in target_mol.GetAtoms()] 277 | ) 278 | 279 | sample = { 280 | "interaction_indice": interaction_indice, 281 | "ligand_pos": ligand_pos, 282 | "target_pos": target_pos, 283 | "rotor": rotor, 284 | "ligand_vdw_radii": ligand_vdw_radii, 285 | "target_vdw_radii": target_vdw_radii, 286 | "ligand_non_metal": ligand_non_metal, 287 | "target_non_metal": target_non_metal, 288 | } 289 | if compute_full: 290 | sample["ligand_interaction_indice"] = get_interaction_indices( 291 | ligand_mol, ligand_mol 292 | ) 293 | sample["target_interaction_indice"] = get_interaction_indices( 294 | target_mol, target_mol 295 | ) 296 | sample["rotor_target"] = CalcNumRotatableBonds(target_mol) 297 | return sample 298 | -------------------------------------------------------------------------------- /ppi/data_utils/residue_featurizers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for featurizing a small molecule (amino acid residue) from their structures. 
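The typical entry point is get_residue_featurizer defined at the bottom of this module; a minimal usage sketch (the SMILES string is only an example):

    featurizer = get_residue_featurizer("MACCS")
    vec = featurizer.featurize("CC(=O)O")  # 167-d MACCS fingerprint as a torch tensor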
3 | """ 4 | import dgl 5 | from typing import Union, List 6 | from transformers import T5Tokenizer, T5EncoderModel 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | from dgllife.utils import ( 11 | mol_to_bigraph, 12 | PretrainAtomFeaturizer, 13 | PretrainBondFeaturizer, 14 | ) 15 | from rdkit import Chem 16 | from rdkit.Chem import MACCSkeys 17 | from rdkit.Chem import AllChem 18 | from dgllife.model import load_pretrained 19 | from dgl.nn.pytorch.glob import ( 20 | GlobalAttentionPooling, 21 | SumPooling, 22 | AvgPooling, 23 | MaxPooling, 24 | Set2Set, 25 | ) 26 | 27 | 28 | class BaseResidueFeaturizer(object): 29 | """A simple base class with caching""" 30 | 31 | def __init__(self): 32 | self.cache = {} 33 | 34 | def featurize(self, smiles: str) -> torch.tensor: 35 | if smiles not in self.cache: 36 | self.cache[smiles] = self._featurize(smiles) 37 | return self.cache[smiles] 38 | 39 | def _featurize(self, smiles: str) -> torch.tensor: 40 | raise NotImplementedError 41 | 42 | 43 | class FingerprintFeaturizer(BaseResidueFeaturizer): 44 | """ 45 | https://www.rdkit.org/docs/GettingStartedInPython.html#list-of-available-fingerprints 46 | """ 47 | 48 | def __init__(self, fingerprint_type): 49 | self.fingerprint_type = fingerprint_type 50 | super(FingerprintFeaturizer, self).__init__() 51 | 52 | def _featurize(self, smiles: str) -> torch.tensor: 53 | if self.fingerprint_type == "dummy": 54 | return torch.zeros(167) 55 | mol = Chem.MolFromSmiles(smiles) 56 | if self.fingerprint_type == "MACCS": 57 | fps = MACCSkeys.GenMACCSKeys(mol) 58 | elif self.fingerprint_type == "Morgan": 59 | fps = AllChem.GetMorganFingerprintAsBitVect( 60 | mol, 2, useFeatures=True, nBits=1024 61 | ) 62 | else: 63 | raise NotImplementedError 64 | # convert ExplicitBitVect to uint vector: 65 | fps = fps.ToBitString().encode() 66 | fps_vec = torch.from_numpy(np.frombuffer(fps, "u1") - ord("0")) 67 | return fps_vec 68 | 69 | 70 | class GINFeaturizer(BaseResidueFeaturizer, nn.Module): 71 | """ 72 | Convert a molecule to atom graph, then apply pretrained GNN 73 | to featurize the graph as a vector. 
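    A minimal sketch of how get_residue_featurizer below wires this up
    (the pretrained model name and readout are one valid combination, not the only one):

        gin_model = load_pretrained("gin_supervised_contextpred")
        featurizer = GINFeaturizer(gin_model, readout="attention")
        vec = featurizer.featurize("CC(=O)O")  # graph-level embedding for the molecule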
74 | """ 75 | 76 | def __init__( 77 | self, gin_model, readout="attention", requires_grad=False, device="cpu" 78 | ): 79 | nn.Module.__init__(self) 80 | BaseResidueFeaturizer.__init__(self) 81 | self.device = device 82 | self.gin_model = gin_model 83 | self.requires_grad = requires_grad 84 | 85 | self.emb_dim = self.gin_model.node_embeddings[0].embedding_dim 86 | 87 | if readout == "sum": 88 | self.readout = SumPooling() 89 | elif readout == "mean": 90 | self.readout = AvgPooling() 91 | elif readout == "max": 92 | self.readout = MaxPooling() 93 | elif readout == "attention": 94 | if gin_model.JK == "concat": 95 | self.readout = GlobalAttentionPooling( 96 | gate_nn=nn.Linear( 97 | (self.gin_model.num_layers + 1) * self.emb_dim, 1 98 | ) 99 | ) 100 | else: 101 | self.readout = GlobalAttentionPooling( 102 | gate_nn=nn.Linear(self.emb_dim, 1) 103 | ) 104 | elif readout == "set2set": 105 | self.readout = Set2Set() 106 | else: 107 | raise ValueError( 108 | "Expect readout to be 'sum', 'mean', " 109 | "'max', 'attention' or 'set2set', got {}".format(readout) 110 | ) 111 | 112 | def _featurize( 113 | self, smiles: Union[str, List[str]], device="cpu" 114 | ) -> torch.tensor: 115 | self.gin_model = self.gin_model.to(device) 116 | if not self.requires_grad: 117 | self.gin_model.eval() 118 | self.readout = self.readout.to(device) 119 | graphs = [] 120 | if isinstance(smiles, str): 121 | mol = Chem.MolFromSmiles(smiles) 122 | g = mol_to_bigraph( 123 | mol, 124 | add_self_loop=True, 125 | node_featurizer=PretrainAtomFeaturizer(), 126 | edge_featurizer=PretrainBondFeaturizer(), 127 | canonical_atom_order=False, 128 | ) 129 | g = g.to(device) 130 | nfeats = [ 131 | g.ndata.pop("atomic_number").to(device), 132 | g.ndata.pop("chirality_type").to(device), 133 | ] 134 | efeats = [ 135 | g.edata.pop("bond_type").to(device), 136 | g.edata.pop("bond_direction_type").to(device), 137 | ] 138 | if not self.requires_grad: 139 | with torch.no_grad(): 140 | node_feats = self.gin_model(g, nfeats, efeats) 141 | graph_feats = self.readout(g, node_feats) 142 | else: 143 | node_feats = self.gin_model(g, nfeats, efeats) 144 | graph_feats = self.readout(g, node_feats) 145 | output_vec_graph = graph_feats.squeeze(0) 146 | return output_vec_graph 147 | else: 148 | for smi in smiles: 149 | mol = Chem.MolFromSmiles(smi) 150 | graph = mol_to_bigraph( 151 | mol, 152 | add_self_loop=True, 153 | node_featurizer=PretrainAtomFeaturizer(), 154 | edge_featurizer=PretrainBondFeaturizer(), 155 | canonical_atom_order=False, 156 | ) 157 | graphs.append(graph) 158 | bg = dgl.batch(graphs) 159 | bg = bg.to(device) 160 | nfeats = [ 161 | bg.ndata.pop("atomic_number").to(device), 162 | bg.ndata.pop("chirality_type").to(device), 163 | ] 164 | efeats = [ 165 | bg.edata.pop("bond_type").to(device), 166 | bg.edata.pop("bond_direction_type").to(device), 167 | ] 168 | if not self.requires_grad: 169 | with torch.no_grad(): 170 | node_feats = self.gin_model(bg, nfeats, efeats) 171 | graph_feats = self.readout(bg, node_feats) 172 | else: 173 | node_feats = self.gin_model(bg, nfeats, efeats) 174 | graph_feats = self.readout(bg, node_feats) 175 | return graph_feats 176 | 177 | def forward(self, smiles: str, device="cpu") -> torch.tensor: 178 | """Expose this method when we want to unfreeze the network, 179 | training jointly with higher level GNN""" 180 | assert self.requires_grad 181 | return self._featurize(smiles, device=device) 182 | 183 | @property 184 | def output_size(self) -> int: 185 | return self.emb_dim 186 | 187 | 188 | class 
MolT5Featurizer(BaseResidueFeaturizer, nn.Module): 189 | """ 190 | Use MolT5 encodings as residue features. 191 | """ 192 | 193 | def __init__( 194 | self, 195 | model_size="small", 196 | model_max_length=512, 197 | requires_grad=False, 198 | ): 199 | """ 200 | Args: 201 | model_size: one of ('small', 'base', 'large') 202 | """ 203 | nn.Module.__init__(self) 204 | BaseResidueFeaturizer.__init__(self) 205 | self.tokenizer = T5Tokenizer.from_pretrained( 206 | "laituan245/molt5-%s" % model_size, 207 | model_max_length=model_max_length, 208 | ) 209 | self.model = T5EncoderModel.from_pretrained( 210 | "laituan245/molt5-%s" % model_size 211 | ) 212 | self.requires_grad = requires_grad 213 | 214 | def _featurize(self, smiles: Union[str, List[str]]) -> torch.tensor: 215 | input_ids = self.tokenizer( 216 | smiles, return_tensors="pt", padding=True 217 | ).input_ids 218 | input_ids = input_ids.to(self.model.device) 219 | if not self.requires_grad: 220 | with torch.no_grad(): 221 | outputs = self.model(input_ids) 222 | else: 223 | outputs = self.model(input_ids) 224 | 225 | # n_smiles_strings = 1 if type(smiles) is str else len(smiles) 226 | # shape: [n_smiles_strings, input_ids.shape[1], model_max_length] 227 | last_hidden_states = outputs.last_hidden_state 228 | 229 | # average over positions: 230 | return last_hidden_states.mean(axis=1).squeeze(0) 231 | 232 | def forward(self, smiles: str, device="cpu") -> torch.tensor: 233 | """Expose this method when we want to unfreeze the network, 234 | training jointly with higher level GNN""" 235 | assert self.requires_grad 236 | return self._featurize(smiles) 237 | 238 | @property 239 | def output_size(self) -> int: 240 | return self.model.config.d_model 241 | 242 | 243 | def get_residue_featurizer(name="", device="cpu"): 244 | """ 245 | Handles initializing the residue featurizer. 246 | """ 247 | fingerprint_names = ("MACCS", "Morgan", "dummy") 248 | gin_names = ( 249 | "gin_supervised_contextpred", 250 | "gin_supervised_infomax", 251 | "gin_supervised_edgepred", 252 | "gin_supervised_masking", 253 | ) 254 | if name in fingerprint_names: 255 | residue_featurizer = FingerprintFeaturizer(name) 256 | elif name.lower().startswith("molt5"): 257 | model_size = "small" 258 | if "-" in name: 259 | model_size = name.split("-")[1] 260 | requires_grad = True if "grad" in name else False 261 | residue_featurizer = MolT5Featurizer( 262 | model_size=model_size, requires_grad=requires_grad 263 | ) 264 | elif name.lower().startswith("gin"): 265 | requires_grad = True if "grad" in name else False 266 | name_split = name.split("-") 267 | readout = name_split[3] 268 | name = "_".join(name_split[0:3]) 269 | name = name.lower() 270 | print(name) 271 | print(device) 272 | assert name in gin_names 273 | gin_model = load_pretrained(name) 274 | gin_model = gin_model.to(device) 275 | residue_featurizer = GINFeaturizer( 276 | gin_model=gin_model, 277 | readout=readout, 278 | requires_grad=requires_grad, 279 | device=device, 280 | ) 281 | else: 282 | raise NotImplementedError 283 | return residue_featurizer 284 | -------------------------------------------------------------------------------- /ppi/data_utils/xpdb.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | """ 5 | PDB parsers for large files with Biopython. 
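A typical use (as in contact_map_utils) is to hand the sloppy builder to Biopython's parser so that files exceeding 9,999 residues or 99,999 atoms still parse; a minimal sketch (the file name is illustrative):

    parser = PDBParser(QUIET=True, PERMISSIVE=True, structure_builder=SloppyStructureBuilder())
    structure = parser.get_structure("my_protein", "large_model.pdb")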
6 | Modified from https://biopython.org/wiki/Reading_large_PDB_files 7 | """ 8 | 9 | import sys 10 | import Bio.PDB 11 | import Bio.PDB.StructureBuilder 12 | from Bio.PDB.Residue import Residue 13 | 14 | 15 | class SloppyStructureBuilder(Bio.PDB.StructureBuilder.StructureBuilder): 16 | """Cope with resSeq < 10,000 limitation by just incrementing internally.""" 17 | 18 | def __init__(self, verbose=False): 19 | Bio.PDB.StructureBuilder.StructureBuilder.__init__(self) 20 | self.max_resseq = -1 21 | self.verbose = verbose 22 | 23 | def init_residue(self, resname, field, resseq, icode): 24 | """Initiate a new Residue object. 25 | 26 | Arguments: 27 | resname: string, e.g. "ASN" 28 | field: hetero flag, "W" for waters, "H" for hetero residues, otherwise blanc. 29 | resseq: int, sequence identifier 30 | icode: string, insertion code 31 | 32 | Return: 33 | None 34 | """ 35 | if field != " ": 36 | if field == "H": 37 | # The hetero field consists of 38 | # H_ + the residue name (e.g. H_FUC) 39 | field = "H_" + resname 40 | res_id = (field, resseq, icode) 41 | 42 | if resseq > self.max_resseq: 43 | self.max_resseq = resseq 44 | 45 | if field == " ": 46 | fudged_resseq = False 47 | while self.chain.has_id(res_id) or resseq == 0: 48 | # There already is a residue with the id (field, resseq, icode) 49 | # resseq == 0 catches already wrapped residue numbers which 50 | # do not trigger the has_id() test. 51 | # 52 | # Be sloppy and just increment... 53 | # (This code will not leave gaps in resids... I think) 54 | # 55 | # XXX: shouldn't we also do this for hetero atoms and water?? 56 | self.max_resseq += 1 57 | resseq = self.max_resseq 58 | res_id = (field, resseq, icode) # use max_resseq! 59 | fudged_resseq = True 60 | 61 | if fudged_resseq and self.verbose: 62 | sys.stderr.write( 63 | "Residues are wrapping (Residue " 64 | + "('%s', %i, '%s') at line %i)." 65 | % (field, resseq, icode, self.line_counter) 66 | + ".... assigning new resid %d.\n" % self.max_resseq 67 | ) 68 | residue = Residue(res_id, resname, self.segid) 69 | self.chain.add(residue) 70 | self.residue = residue 71 | return None 72 | 73 | 74 | class SloppyPDBIO(Bio.PDB.PDBIO): 75 | """PDBIO class that can deal with large pdb files as used in MD simulations 76 | 77 | - resSeq simply wrap and are printed modulo 10,000. 78 | - atom numbers wrap at 99,999 and are printed modulo 100,000 79 | 80 | """ 81 | 82 | # The format string is derived from the PDB format as used in PDBIO.py 83 | # (has to be copied to the class because of the package layout it is not 84 | # externally accessible) 85 | _ATOM_FORMAT_STRING = ( 86 | "%s%5i %-4s%c%3s %c%4i%c " 87 | + "%8.3f%8.3f%8.3f%6.2f%6.2f %4s%2s%2s\n" 88 | ) 89 | 90 | def _get_atom_line( 91 | self, 92 | atom, 93 | hetfield, 94 | segid, 95 | atom_number, 96 | resname, 97 | resseq, 98 | icode, 99 | chain_id, 100 | element=" ", 101 | charge=" ", 102 | ): 103 | """Returns an ATOM string that is guaranteed to fit the ATOM format. 
104 | 105 | - Resid (resseq) is wrapped (modulo 10,000) to fit into %4i (4I) format 106 | - Atom number (atom_number) is wrapped (modulo 100,000) to fit into 107 | %5i (5I) format 108 | 109 | Args: #TODO 110 | atom: 111 | hetfield: 112 | segid: 113 | atom_number: 114 | resname: 115 | resseq: 116 | icode: 117 | chain_id: 118 | element: 119 | charge: 120 | 121 | Returns: 122 | #TODO 123 | """ 124 | if hetfield != " ": 125 | record_type = "HETATM" 126 | else: 127 | record_type = "ATOM " 128 | name = atom.get_fullname() 129 | altloc = atom.get_altloc() 130 | x, y, z = atom.get_coord() 131 | bfactor = atom.get_bfactor() 132 | occupancy = atom.get_occupancy() 133 | args = ( 134 | record_type, 135 | atom_number % 100000, 136 | name, 137 | altloc, 138 | resname, 139 | chain_id, 140 | resseq % 10000, 141 | icode, 142 | x, 143 | y, 144 | z, 145 | occupancy, 146 | bfactor, 147 | segid, 148 | element, 149 | charge, 150 | ) 151 | return self._ATOM_FORMAT_STRING % args 152 | -------------------------------------------------------------------------------- /ppi/gvp.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | DGL implementation of GVP and GVP-GNN (without the autoregressive functionality) 5 | modified from source: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/__init__.py 6 | """ 7 | import functools 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | 13 | def tuple_sum(*args): 14 | """ 15 | Sums any number of tuples (s, V) elementwise. 16 | """ 17 | return tuple(map(sum, zip(*args))) 18 | 19 | 20 | def tuple_cat(*args, dim=-1): 21 | """ 22 | Concatenates any number of tuples (s, V) elementwise. 23 | 24 | :param dim: dimension along which to concatenate when viewed 25 | as the `dim` index for the scalar-channel tensors. 26 | This means that `dim=-1` will be applied as 27 | `dim=-2` for the vector-channel tensors. 28 | """ 29 | dim %= len(args[0][0].shape) 30 | s_args, v_args = list(zip(*args)) 31 | return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim) 32 | 33 | 34 | def tuple_index(x, idx): 35 | """ 36 | Indexes into a tuple (s, V) along the first dimension. 37 | 38 | :param idx: any object which can be used to index into a `torch.Tensor` 39 | """ 40 | return x[0][idx], x[1][idx] 41 | 42 | 43 | def randn(n, dims, device="cpu"): 44 | """ 45 | Returns random tuples (s, V) drawn elementwise from a normal distribution. 46 | 47 | :param n: number of data points 48 | :param dims: tuple of dimensions (n_scalar, n_vector) 49 | 50 | :return: (s, V) with s.shape = (n, n_scalar) and 51 | V.shape = (n, n_vector, 3) 52 | """ 53 | return torch.randn(n, dims[0], device=device), torch.randn( 54 | n, dims[1], 3, device=device 55 | ) 56 | 57 | 58 | def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): 59 | """ 60 | L2 norm of tensor clamped above a minimum value `eps`. 61 | 62 | :param sqrt: if `False`, returns the square of the L2 norm 63 | """ 64 | out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps) 65 | return torch.sqrt(out) if sqrt else out 66 | 67 | 68 | class GVP(nn.Module): 69 | """ 70 | Geometric Vector Perceptron. See manuscript and README.md 71 | for more details. 
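    Schematically (mirroring the forward pass below): vector inputs are projected
    as V_h = W_h V; their norms are concatenated with the scalar inputs to give the
    new scalars s' = scalar_act(W_s [s, ||V_h||]); and the new vectors V' = W_v V_h
    are scaled either by a sigmoid gate computed from the scalar channel (when
    vector_gate is True) or by vector_act of their own norms.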
72 | 73 | :param in_dims: tuple (n_scalar, n_vector) 74 | :param out_dims: tuple (n_scalar, n_vector) 75 | :param h_dim: intermediate number of vector channels, optional 76 | :param activations: tuple of functions (scalar_act, vector_act) 77 | :param vector_gate: whether to use vector gating. 78 | (vector_act will be used as sigma^+ in vector gating if `True`) 79 | """ 80 | 81 | def __init__( 82 | self, 83 | in_dims, 84 | out_dims, 85 | h_dim=None, 86 | activations=(F.relu, torch.sigmoid), 87 | vector_gate=False, 88 | ): 89 | super(GVP, self).__init__() 90 | self.si, self.vi = in_dims 91 | self.so, self.vo = out_dims 92 | self.vector_gate = vector_gate 93 | if self.vi: 94 | self.h_dim = h_dim or max(self.vi, self.vo) 95 | self.wh = nn.Linear(self.vi, self.h_dim, bias=False) 96 | self.ws = nn.Linear(self.h_dim + self.si, self.so) 97 | if self.vo: 98 | self.wv = nn.Linear(self.h_dim, self.vo, bias=False) 99 | if self.vector_gate: 100 | self.wsv = nn.Linear(self.so, self.vo) 101 | else: 102 | self.ws = nn.Linear(self.si, self.so) 103 | 104 | self.scalar_act, self.vector_act = activations 105 | self.dummy_param = nn.Parameter(torch.empty(0)) 106 | 107 | def forward(self, x): 108 | """ 109 | :param x: tuple (s, V) of `torch.Tensor`, 110 | or (if vectors_in is 0), a single `torch.Tensor` 111 | :return: tuple (s, V) of `torch.Tensor`, 112 | or (if vectors_out is 0), a single `torch.Tensor` 113 | """ 114 | if self.vi: 115 | s, v = x 116 | v = torch.transpose(v, -1, -2) 117 | vh = self.wh(v) 118 | vn = _norm_no_nan(vh, axis=-2) 119 | s = self.ws(torch.cat([s, vn], -1)) 120 | if self.vo: 121 | v = self.wv(vh) 122 | v = torch.transpose(v, -1, -2) 123 | if self.vector_gate: 124 | if self.vector_act: 125 | gate = self.wsv(self.vector_act(s)) 126 | else: 127 | gate = self.wsv(s) 128 | v = v * torch.sigmoid(gate).unsqueeze(-1) 129 | elif self.vector_act: 130 | v = v * self.vector_act( 131 | _norm_no_nan(v, axis=-1, keepdims=True) 132 | ) 133 | else: 134 | s = self.ws(x) 135 | if self.vo: 136 | v = torch.zeros( 137 | s.shape[0], self.vo, 3, device=self.dummy_param.device 138 | ) 139 | if self.scalar_act: 140 | s = self.scalar_act(s) 141 | 142 | return (s, v) if self.vo else s 143 | 144 | 145 | class _VDropout(nn.Module): 146 | """ 147 | Vector channel dropout where the elements of each 148 | vector channel are dropped together. 149 | """ 150 | 151 | def __init__(self, drop_rate): 152 | super(_VDropout, self).__init__() 153 | self.drop_rate = drop_rate 154 | self.dummy_param = nn.Parameter(torch.empty(0)) 155 | 156 | def forward(self, x): 157 | """ 158 | :param x: `torch.Tensor` corresponding to vector channels 159 | """ 160 | device = self.dummy_param.device 161 | if not self.training: 162 | return x 163 | mask = torch.bernoulli( 164 | (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device) 165 | ).unsqueeze(-1) 166 | x = mask * x / (1 - self.drop_rate) 167 | return x 168 | 169 | 170 | class Dropout(nn.Module): 171 | """ 172 | Combined dropout for tuples (s, V). 173 | Takes tuples (s, V) as input and as output. 
174 | """ 175 | 176 | def __init__(self, drop_rate): 177 | super(Dropout, self).__init__() 178 | self.sdropout = nn.Dropout(drop_rate) 179 | self.vdropout = _VDropout(drop_rate) 180 | 181 | def forward(self, x): 182 | """ 183 | :param x: tuple (s, V) of `torch.Tensor`, 184 | or single `torch.Tensor` 185 | (will be assumed to be scalar channels) 186 | """ 187 | if type(x) is torch.Tensor: 188 | return self.sdropout(x) 189 | s, v = x 190 | return self.sdropout(s), self.vdropout(v) 191 | 192 | 193 | class LayerNorm(nn.Module): 194 | """ 195 | Combined LayerNorm for tuples (s, V). 196 | Takes tuples (s, V) as input and as output. 197 | """ 198 | 199 | def __init__(self, dims): 200 | super(LayerNorm, self).__init__() 201 | self.s, self.v = dims 202 | self.scalar_norm = nn.LayerNorm(self.s) 203 | 204 | def forward(self, x): 205 | """ 206 | :param x: tuple (s, V) of `torch.Tensor`, 207 | or single `torch.Tensor` 208 | (will be assumed to be scalar channels) 209 | """ 210 | if not self.v: 211 | return self.scalar_norm(x) 212 | s, v = x 213 | vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False) 214 | vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) 215 | return self.scalar_norm(s), v / vn 216 | 217 | 218 | class GVPConv(nn.Module): 219 | """ 220 | Graph convolution / message passing with Geometric Vector Perceptrons. 221 | Takes in a graph with node and edge embeddings, 222 | and returns new node embeddings. 223 | 224 | This does NOT do residual updates and pointwise feedforward layers 225 | ---see `GVPConvLayer`. 226 | 227 | :param in_dims: input node embedding dimensions (n_scalar, n_vector) 228 | :param out_dims: output node embedding dimensions (n_scalar, n_vector) 229 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) 230 | :param n_layers: number of GVPs in the message function 231 | :param module_list: preconstructed message function, overrides n_layers 232 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs 233 | :param vector_gate: whether to use vector gating. 
234 | (vector_act will be used as sigma^+ in vector gating if `True`) 235 | """ 236 | 237 | def __init__( 238 | self, 239 | in_dims, 240 | out_dims, 241 | edge_dims, 242 | n_layers=3, 243 | module_list=None, 244 | activations=(F.relu, torch.sigmoid), 245 | vector_gate=False, 246 | ): 247 | super(GVPConv, self).__init__() 248 | self.si, self.vi = in_dims 249 | self.so, self.vo = out_dims 250 | self.se, self.ve = edge_dims 251 | 252 | GVP_ = functools.partial( 253 | GVP, activations=activations, vector_gate=vector_gate 254 | ) 255 | 256 | module_list = module_list or [] 257 | if not module_list: 258 | if n_layers == 1: 259 | module_list.append( 260 | GVP_( 261 | (2 * self.si + self.se, 2 * self.vi + self.ve), 262 | (self.so, self.vo), 263 | activations=(None, None), 264 | ) 265 | ) 266 | else: 267 | module_list.append( 268 | GVP_( 269 | (2 * self.si + self.se, 2 * self.vi + self.ve), 270 | out_dims, 271 | ) 272 | ) 273 | for i in range(n_layers - 2): 274 | module_list.append(GVP_(out_dims, out_dims)) 275 | module_list.append( 276 | GVP_(out_dims, out_dims, activations=(None, None)) 277 | ) 278 | self.message_func = nn.Sequential(*module_list) 279 | 280 | def forward(self, g): 281 | g.update_all( 282 | message_func=self.message_udf, reduce_func=self.reduce_udf 283 | ) 284 | return g.ndata["node_s_agg"], g.ndata["node_v_agg"] 285 | 286 | def message(self, s_i, v_i, s_j, v_j, edge_attr): 287 | message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) 288 | message = self.message_func(message) 289 | return message 290 | 291 | def message_udf(self, edges): 292 | """ 293 | message function for GVP-GNN 294 | :param edges: EdgeBatch 295 | :return dict[str, tensor]: s_m: scalar message; s_v: vector message 296 | """ 297 | s_i, v_i = edges.src["node_s"], edges.src["node_v"] 298 | s_j, v_j = edges.dst["node_s"], edges.dst["node_v"] 299 | edge_attr = edges.data["edge_s"], edges.data["edge_v"] 300 | 301 | s_m, v_m = self.message(s_i, v_i, s_j, v_j, edge_attr) 302 | return {"s_m": s_m, "v_m": v_m} 303 | 304 | def reduce_udf(self, nodes): 305 | """ 306 | reduce function for GVP-GNN 307 | :param nodes: NodeBatch 308 | """ 309 | s_m, v_m = nodes.mailbox["s_m"], nodes.mailbox["v_m"] 310 | 311 | return { 312 | "node_s_agg": torch.mean(s_m, dim=1), 313 | "node_v_agg": torch.mean(v_m, dim=1), 314 | } 315 | 316 | 317 | class GVPConvLayer(nn.Module): 318 | """ 319 | Full graph convolution / message passing layer with 320 | Geometric Vector Perceptrons. Residually updates node embeddings with 321 | aggregated incoming messages, applies a pointwise feedforward 322 | network to node embeddings, and returns updated node embeddings. 323 | 324 | To only compute the aggregated messages, see `GVPConv`. 325 | 326 | :param node_dims: node embedding dimensions (n_scalar, n_vector) 327 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) 328 | :param n_message: number of GVPs to use in message function 329 | :param n_feedforward: number of GVPs to use in feedforward function 330 | :param drop_rate: drop probability in all dropout layers 331 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs 332 | :param vector_gate: whether to use vector gating. 
333 | (vector_act will be used as sigma^+ in vector gating if `True`) 334 | """ 335 | 336 | def __init__( 337 | self, 338 | node_dims, 339 | edge_dims, 340 | n_message=3, 341 | n_feedforward=2, 342 | drop_rate=0.1, 343 | activations=(F.relu, torch.sigmoid), 344 | vector_gate=False, 345 | ): 346 | 347 | super(GVPConvLayer, self).__init__() 348 | self.conv = GVPConv( 349 | node_dims, 350 | node_dims, 351 | edge_dims, 352 | n_message, 353 | activations=activations, 354 | vector_gate=vector_gate, 355 | ) 356 | GVP_ = functools.partial( 357 | GVP, activations=activations, vector_gate=vector_gate 358 | ) 359 | self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)]) 360 | self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)]) 361 | 362 | ff_func = [] 363 | if n_feedforward == 1: 364 | ff_func.append( 365 | GVP_(node_dims, node_dims, activations=(None, None)) 366 | ) 367 | else: 368 | hid_dims = 4 * node_dims[0], 2 * node_dims[1] 369 | ff_func.append(GVP_(node_dims, hid_dims)) 370 | for i in range(n_feedforward - 2): 371 | ff_func.append(GVP_(hid_dims, hid_dims)) 372 | ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None))) 373 | self.ff_func = nn.Sequential(*ff_func) 374 | 375 | def forward(self, g): 376 | """ 377 | :param g: dgl.graph 378 | """ 379 | 380 | dh = self.conv(g) 381 | 382 | x = g.ndata["node_s"], g.ndata["node_v"] 383 | x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) 384 | 385 | dh = self.ff_func(x) 386 | x = self.norm[1](tuple_sum(x, self.dropout[1](dh))) 387 | 388 | return x 389 | -------------------------------------------------------------------------------- /ppi/transfer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | Utils for transfer learning 5 | """ 6 | from torch.nn.parameter import Parameter 7 | 8 | 9 | def load_state_dict_to_model(model, state_dict): 10 | """Initialize a model with parameters in `state_dict` (in place) 11 | from a pretrained model with a slightly different architecture. 12 | Args: 13 | model: Torch model 14 | state_dict: Dictionary containing weights for each layer of the `model` 15 | Returns: 16 | None. The input `model` is updated in place with the weights from `state_dict` whose names match its own layers. 17 | """ 18 | own_state = model.state_dict() 19 | print("model own state keys:", len(own_state)) 20 | print("state_dict keys:", len(state_dict)) 21 | keys_loaded = 0 22 | for name, param in state_dict.items(): 23 | if name not in own_state: 24 | continue 25 | if isinstance(param, Parameter): 26 | # backwards compatibility for serialized parameters 27 | param = param.data 28 | own_state[name].copy_(param) 29 | keys_loaded += 1 30 | print("keys loaded into model:", keys_loaded) 31 | -------------------------------------------------------------------------------- /preprocess_diffdock_output.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | """ 4 | # Convert DiffDock's output into the same format as the CASF-2016 Docking data. 5 | 6 | The output from DiffDock is a PDB file, which contains a single chain 7 | representing the pose of the small molecule ligand.
8 | 9 | Example usage: 10 | python preprocess_diffdock_output.py \ 11 | --data_dir /home/ec2-user/SageMaker/efs/data/DiffDockData/inference1 \ 12 | --pdb_data_dir /home/ec2-user/SageMaker/efs/data/DiffDockData/PDBBind_processed \ 13 | --thres 6 \ 14 | --output_dir /home/ec2-user/SageMaker/efs/data/DiffDockData/inference1_processed_t6 15 | 16 | In this pipeline, we perform the following steps to convert the outputs from 17 | DiffDock to a format compatible with our affinity prediction model: 18 | 19 | 1. parse the ligand pose PDB file 20 | 2. combine the ligand with protein PDB into the same coordinate system 21 | 3. subset the protein chain(s) to only include the residues around the ligand 22 | -> pocket-ligand structure 23 | 4. save the pocket-ligand structures into pickles; save the RMSD values into 24 | text files 25 | 26 | After these steps, we should be able to run inference using our affinity 27 | prediction model in a setting similar to the `evaluate_casf2016.py` function 28 | `evaluate_docking`. 29 | """ 30 | import os 31 | import pickle 32 | import argparse 33 | from tqdm import tqdm 34 | from rdkit import Chem 35 | import pandas as pd 36 | import numpy as np 37 | from scipy.spatial.distance import cdist 38 | from Bio.PDB import PDBParser, Select, PDBIO 39 | from Bio.PDB.Polypeptide import is_aa 40 | 41 | from ppi.data_utils import parse_pdb_structure 42 | 43 | 44 | def get_calpha_coords(residue): 45 | try: 46 | return residue["CA"].coord 47 | except KeyError: 48 | return [np.nan] * 3 49 | 50 | 51 | def get_contact_residues(ligand_mol, chain, thres=10): 52 | # get the residue IDs 53 | res_ids = np.asarray( 54 | [res.id[1] for res in chain.get_residues() if is_aa(res)] 55 | ) 56 | # ligand coords 57 | coords1 = ligand_mol.GetConformers()[0].GetPositions() 58 | # extract the C-alpha coordinates of all AA residues 59 | coords2 = np.asarray( 60 | [get_calpha_coords(res) for res in chain.get_residues() if is_aa(res)] 61 | ) 62 | # calculate interchain distance 63 | dist = cdist(coords1, coords2) 64 | dist_bool = dist <= thres 65 | 66 | res_keep = res_ids[dist_bool.sum(axis=0) > 0] 67 | return res_keep 68 | 69 | 70 | def get_contact_residues_across_chains(ligand_mol, protein, thres=10): 71 | d_chain_residues = {} 72 | for chain in protein.get_chains(): 73 | res_keep = get_contact_residues(ligand_mol, chain, thres=thres) 74 | d_chain_residues[chain.id] = res_keep 75 | return d_chain_residues 76 | 77 | 78 | def subset_pdb_structure(structure, d_chain_residues, outfile): 79 | # to subset the protein structure 80 | class ResSelect(Select): 81 | def accept_residue(self, res): 82 | if res.id[1] in d_chain_residues.get(res.parent.id, set()): 83 | return True 84 | else: 85 | return False 86 | 87 | io = PDBIO() 88 | # set the structure as the entire protein 89 | io.set_structure(structure) 90 | # subset and save the pocket into PDB file 91 | io.save(outfile, ResSelect()) 92 | return 93 | 94 | 95 | def process_one(row, pdb_parser, args): 96 | # 1. Parse ligand poses from PDB files 97 | ligand_mol = Chem.MolFromPDBFile( 98 | os.path.join(args.data_dir, row["pdb_file"]), sanitize=False 99 | ) 100 | # 2. combine the ligand with protein PDB into the same coordinate system 101 | # parse the corresponding protein 102 | protein = parse_pdb_structure( 103 | pdb_parser, 104 | row.pdb_id, 105 | os.path.join( 106 | args.pdb_data_dir, 107 | row["pdb_id"], 108 | f"{row['pdb_id']}_protein_processed.pdb", 109 | ), 110 | ) 111 | # 3.
subset the protein chain(s) to only include the residues around the 112 | # ligand 113 | d_chain_residues = get_contact_residues_across_chains( 114 | ligand_mol, protein, thres=args.thres 115 | ) 116 | subset_pdb_structure( 117 | protein, 118 | d_chain_residues, 119 | os.path.join(args.output_dir, f"{row['pdb_id']}_chopped.pdb"), 120 | ) 121 | protein_pocket = parse_pdb_structure( 122 | pdb_parser, 123 | row.pdb_id, 124 | os.path.join(args.output_dir, f"{row['pdb_id']}_chopped.pdb"), 125 | ) 126 | # 4. write to pickle 127 | output = (ligand_mol, None, protein_pocket, None) 128 | output_file = os.path.join(args.output_dir, "data", row["file_id"]) 129 | pickle.dump(output, open(output_file, "wb")) 130 | return 131 | 132 | 133 | def main(args): 134 | os.makedirs(args.output_dir, exist_ok=True) 135 | sub_dirs = ["decoys_docking_rmsd", "data", "keys"] 136 | for sub_dir in sub_dirs: 137 | os.makedirs(os.path.join(args.output_dir, sub_dir), exist_ok=True) 138 | 139 | # 0. parse metadata from DiffDock output files 140 | meta_df = [] 141 | for pdb_file in os.listdir(args.data_dir): 142 | if pdb_file.endswith(".pdb"): 143 | row = { 144 | "id": pdb_file[:-4], 145 | "pdb_file": pdb_file, 146 | "pdb_id": pdb_file.split("_")[0], 147 | "rank": int(pdb_file.split("_")[1]), 148 | "rmsd": float(pdb_file.split("_")[2]), 149 | "confidence": float(pdb_file.split("_")[3][:-4]), 150 | } 151 | row["file_id"] = row["pdb_id"] + "_" + row["id"].split("_")[1] 152 | meta_df.append(row) 153 | 154 | meta_df = pd.DataFrame(meta_df).set_index("id", verify_integrity=True) 155 | print(meta_df.shape) 156 | 157 | pdb_parser = PDBParser( 158 | QUIET=True, 159 | PERMISSIVE=True, 160 | ) 161 | for _, row in tqdm(meta_df.iterrows(), total=meta_df.shape[0]): 162 | process_one(row, pdb_parser, args) 163 | 164 | # Write RMSD files 165 | for pdb_id, sub_df in meta_df.groupby("pdb_id"): 166 | out_rmsd_filename = f"{pdb_id}_rmsd.dat" 167 | sub_df[["file_id", "rmsd"]].to_csv( 168 | os.path.join( 169 | args.output_dir, "decoys_docking_rmsd", out_rmsd_filename 170 | ), 171 | sep="\t", 172 | index=False, 173 | ) 174 | # Write keys and pdb_to_affinity.txt 175 | keys = list(meta_df["file_id"]) 176 | pickle.dump( 177 | keys, open(os.path.join(args.output_dir, "keys/test_keys.pkl"), "wb") 178 | ) 179 | pdb_to_affinity = meta_df[["file_id"]].copy() 180 | pdb_to_affinity.loc[:, "affinity"] = 0 181 | pdb_to_affinity.to_csv( 182 | os.path.join(args.output_dir, "pdb_to_affinity.txt"), 183 | sep="\t", 184 | index=False, 185 | header=False, 186 | ) 187 | return 188 | 189 | 190 | if __name__ == "__main__": 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument( 193 | "--data_dir", 194 | type=str, 195 | required=True, 196 | help="Directory containing the output ligand pose PDB files from DiffDock", 197 | ) 198 | parser.add_argument( 199 | "--pdb_data_dir", 200 | type=str, 201 | required=True, 202 | help="Directory containing the original protein PDB files", 203 | ) 204 | parser.add_argument( 205 | "--output_dir", 206 | type=str, 207 | required=True, 208 | help="Output directory", 209 | ) 210 | parser.add_argument( 211 | "--thres", 212 | type=int, 213 | required=True, 214 | default=10, 215 | help="Threshold for identifying contact residues", 216 | ) 217 | 218 | args = parser.parse_args() 219 | main(args) 220 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | biopython==1.79 2 | dgl-cu111==0.9.1 3 | dgllife==0.2.9 4 |
pytorch-lightning==1.7.3 5 | rdkit-pypi==2022.3.5 6 | sentencepiece==0.1.97 7 | torch==1.10.0 8 | torchmetrics==0.9.3 9 | transformers==4.21.2 10 | setuptools==59.5.0 11 | seaborn 12 | matplotlib 13 | biopython 14 | boto3 15 | 16 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Setting up conda env for this project 3 | # Note: this requires pytorch_p38 in SageMaker or DLAMI 4 | source ~/anaconda3/etc/profile.d/conda.sh 5 | conda activate pytorch_p38 6 | 7 | pip install dgl-cu111 dglgo -f https://data.dgl.ai/wheels/repo.html 8 | pip install -r requirements.txt 9 | -------------------------------------------------------------------------------- /test.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/eggnet-equivariant-graph-of-graph-neural-network/87ee428c8a79171f2d5331e1cae7c6ac82d84dd8/test.txt -------------------------------------------------------------------------------- /test_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/anaconda3/etc/profile.d/conda.sh 3 | conda activate pytorch_p38 4 | 5 | # PDBBind dataset 6 | python train.py --accelerator gpu \ 7 | --devices -1 \ 8 | --max_epochs 1 \ 9 | --precision 32 \ 10 | --stage1_num_layers 3 \ 11 | --stage1_node_h_dim 200 32 \ 12 | --stage1_edge_h_dim 64 2 \ 13 | --stage2_num_layers 3 \ 14 | --stage2_node_h_dim 200 32 \ 15 | --stage2_edge_h_dim 64 2 \ 16 | --dataset_name PDBBind \ 17 | --input_type multistage-hetero \ 18 | --model_name multistage-hgvp \ 19 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019_processed/scoring \ 20 | --residual \ 21 | --num_workers 8 \ 22 | --lr 1e-4 \ 23 | --bs 2 \ 24 | --early_stopping_patience 50 \ 25 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/brandry/PDBBind_MS-HGVP_hetero_energy \ 26 | --residue_featurizer_name MolT5-small-grad \ 27 | --is_hetero \ 28 | --use_energy_decoder \ 29 | --loss_der1_ratio=10.0 \ 30 | --loss_der2_ratio=10.0 \ 31 | --min_loss_der2=-20.0 \ 32 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/brandry/PDBBind_MSGVP_hetero_energy \ 33 | --use_energy_decoder \ 34 | --is_hetero 35 | 36 | # small PDBBind 37 | python train.py --accelerator gpu \ 38 | --devices -1 \ 39 | --max_epochs 1000 \ 40 | --precision 16 \ 41 | --dataset_name PDBBind \ 42 | --input_type multistage-hetero \ 43 | --model_name multistage-gvp \ 44 | --residue_featurizer_name MolT5-small \ 45 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring 46 | 47 | ## with energy 48 | python train.py --accelerator gpu \ 49 | --devices -1 \ 50 | --max_epochs 1000 \ 51 | --precision 16 \ 52 | --dataset_name PDBBind \ 53 | --input_type multistage-hetero \ 54 | --model_name multistage-gvp \ 55 | --residue_featurizer_name MolT5-small \ 56 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring \ 57 | --use_energy_decoder \ 58 | --is_hetero \ 59 | --num_workers 8 60 | 61 | ## ssGVP with energy 62 | python train.py --accelerator gpu \ 63 | --devices -1 \ 64 | --max_epochs 1000 \ 65 | --precision 16 \ 66 | --dataset_name PDBBind \ 67 | --input_type complex \ 68 | --model_name gvp \ 69 | --residue_featurizer_name MolT5-small \ 70 | --data_dir /home/ec2-user/SageMaker/efs/data/PIGNet/data/pdbbind_v2019/scoring \ 71 | --use_energy_decoder \ 72 
| --is_hetero \ 73 | --num_workers 8 \ 74 | --persistent_workers True 75 | 76 | 77 | # intact PDBBind 78 | python train.py --accelerator gpu \ 79 | --devices -1 \ 80 | --max_epochs 1000 \ 81 | --precision 16 \ 82 | --dataset_name PDBBind \ 83 | --input_type complex \ 84 | --model_name gvp \ 85 | --residue_featurizer_name MolT5-small \ 86 | --data_dir /home/ec2-user/SageMaker/efs/data/PDBBind/pdbbind_v2019/scoring 87 | 88 | python train.py --accelerator gpu \ 89 | --devices -1 \ 90 | --max_epochs 1000 \ 91 | --precision 16 \ 92 | --dataset_name PDBBind \ 93 | --input_type multistage-hetero \ 94 | --model_name multistage-gvp \ 95 | --residue_featurizer_name MolT5-small \ 96 | --data_dir /home/ec2-user/SageMaker/efs/data/PDBBind/pdbbind_v2019/scoring \ 97 | --default_root_dir /home/ec2-user/SageMaker/efs/model_logs/zichen/PDBBind_intact_MSGVP_hetero \ 98 | --bs 16 \ 99 | --num_workers 8 \ 100 | --persistent_workers True 101 | 102 | python train.py --accelerator gpu \ 103 | --devices -1 \ 104 | --max_epochs 1000 \ 105 | --precision 16 \ 106 | --dataset_name PDBBind \ 107 | --input_type multistage-hetero \ 108 | --model_name multistage-gvp \ 109 | --residue_featurizer_name MolT5-small \ 110 | --data_dir /home/ec2-user/SageMaker/efs/data/PDBBind/pdbbind_v2019/scoring \ 111 | --bs 16 \ 112 | --num_workers 8 \ 113 | --persistent_workers True \ 114 | --use_energy_decoder \ 115 | --is_hetero 116 | 117 | python train.py --accelerator gpu \ 118 | --devices -1 \ 119 | --max_epochs 1000 \ 120 | --precision 16 \ 121 | --dataset_name PDBBind \ 122 | --input_type complex \ 123 | --model_name gvp \ 124 | --residue_featurizer_name MACCS \ 125 | --data_dir /home/ec2-user/SageMaker/efs/data/PDBBind/pdbbind_v2019/scoring \ 126 | --use_energy_decoder \ 127 | --is_hetero \ 128 | --num_workers 0 \ 129 | --bs 2 130 | --------------------------------------------------------------------------------
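A minimal usage sketch for the DGL-based GVP modules defined in ppi/gvp.py: the toy ring graph, feature sizes, and random tensors below are illustrative assumptions rather than anything prescribed by the repository, and the import assumes the repository root is on PYTHONPATH; only the tensor layouts (scalars of shape [N, n_scalar], vectors of shape [N, n_vector, 3]) and the node_s / node_v / edge_s / edge_v field names follow the code above.

import dgl
import torch

from ppi.gvp import GVP, GVPConvLayer

num_nodes, num_edges = 10, 10
node_dims, edge_dims = (64, 16), (32, 1)  # (n_scalar, n_vector), arbitrary example sizes

# A directed ring graph, so every node receives at least one message.
src = torch.arange(num_nodes)
dst = (src + 1) % num_nodes
g = dgl.graph((src, dst))

# Node and edge features in the layout expected by GVPConv / GVPConvLayer.
g.ndata["node_s"] = torch.randn(num_nodes, node_dims[0])
g.ndata["node_v"] = torch.randn(num_nodes, node_dims[1], 3)
g.edata["edge_s"] = torch.randn(num_edges, edge_dims[0])
g.edata["edge_v"] = torch.randn(num_edges, edge_dims[1], 3)

# A standalone GVP maps an (s, V) tuple to an (s, V) tuple.
gvp = GVP(node_dims, (32, 8))
s_out, v_out = gvp((g.ndata["node_s"], g.ndata["node_v"]))  # shapes [10, 32] and [10, 8, 3]

# A full GVPConvLayer residually updates node embeddings from aggregated messages.
layer = GVPConvLayer(node_dims, edge_dims, drop_rate=0.1)
node_s, node_v = layer(g)  # shapes [10, 64] and [10, 16, 3]

To warm-start such modules from a checkpoint whose architecture only partially overlaps, load_state_dict_to_model in ppi/transfer.py copies just the parameters whose names match the target model's own state dict.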