├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── dips │ └── model_0.ckpt ├── configs ├── callbacks │ ├── default.yaml │ └── paramfreezer.yaml ├── config.yaml ├── datamodule │ ├── default.yaml │ ├── docking_datamodule.yaml │ ├── pinder_datamodule.yaml │ └── ppi_mlsb_datamodule.yaml ├── inference.yaml ├── logger │ └── wandb.yaml ├── mode │ └── default.yaml ├── model │ ├── DFMDock.yaml │ ├── DFMDock_guide.yaml │ ├── force_model.yaml │ └── score_model_mlsb.yaml └── trainer │ └── default.yaml ├── data └── db5_test │ ├── 1AVX.pt │ ├── 1H1V.pt │ ├── 1HCF.pt │ ├── 1IRA.pt │ ├── 1JIW.pt │ ├── 1JPS.pt │ ├── 1MLC.pt │ ├── 1N2C.pt │ ├── 1NW9.pt │ ├── 1QA9.pt │ ├── 1VFB.pt │ ├── 1ZHI.pt │ ├── 2A1A.pt │ ├── 2A9K.pt │ ├── 2AYO.pt │ ├── 2SIC.pt │ ├── 2SNI.pt │ ├── 2VDB.pt │ ├── 3SZK.pt │ ├── 4POU.pt │ ├── 5C7X.pt │ ├── 5HGG.pt │ ├── 5JMO.pt │ ├── 6B0S.pt │ ├── 7CEI.pt │ └── test.txt ├── environment.yml ├── requirements.txt ├── setup.py ├── src ├── __init__.py ├── data │ └── gen_dips_attn.py ├── datasets │ ├── __init__.py │ ├── docking_dataset.py │ ├── pinder_dataset.py │ ├── ppi_dataset.py │ ├── ppi_mlsb_dataset.py │ └── submit_cpu.sh ├── inference.py ├── inference_base.py ├── inference_mlsb.py ├── inference_single.py ├── models │ ├── DFMDock.py │ ├── __init__.py │ ├── egnn.py │ ├── egnn_net.py │ ├── score_model.py │ ├── score_model_mlsb.py │ ├── score_net.py │ └── score_net_mlsb.py ├── run.py ├── train.py └── utils │ ├── __init__.py │ ├── coords6d.py │ ├── crop.py │ ├── frame.py │ ├── geometry.py │ ├── loss.py │ ├── metrics.py │ ├── pdb.py │ ├── r3_diffuser.py │ ├── residue_constants.py │ ├── so3_diffuser.py │ └── utils.py ├── tests ├── test_biotite.py ├── test_gLM2.py └── test_pinder.py └── weights └── pinder_0.ckpt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | slogs/ 165 | outputs/ 166 | scripts/ 167 | checkpoints/ 168 | lightning_logs/ 169 | tests/ 170 | debug/ 171 | *.pdb 172 | *.csv 173 | *.ipynb 174 | *.png 175 | *.yaml 176 | inference_gpu.sh 177 | submit_inference.sh 178 | DFMDock_guide.py 179 | inference_guide.py 180 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Gray Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFMDock 2 | DFMDock (Denoising Force Matching Dock) is a diffusion model that unifies sampling and ranking within a single framework. 3 | 4 | ## Setup 5 | 6 | ### 1. Clone the Repository 7 | 8 | ```bash 9 | git clone https://github.com/Graylab/DFMDock.git 10 | cd DFMDock 11 | ``` 12 | 13 | ### 2. Create and Activate Conda Environment 14 | 15 | Run the following commands to create and activate the Conda environment: 16 | 17 | ```bash 18 | conda env create -f environment.yml 19 | conda activate DFMDock 20 | ``` 21 | 22 | ### 3. Install the Project in Editable Mode 23 | 24 | To install the project in editable mode, run the following command: 25 | 26 | ```bash 27 | pip install -e .
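# Optional sanity check (a hypothetical command, not part of the repo):
# setup.py registers the packages under src/ (models, datasets, utils),
# so after the editable install they should import from anywhere:
python -c "import models, datasets, utils; print('DFMDock packages importable')"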
28 | ``` 29 | 30 | 31 | ## Usage 32 | 33 | To run inference on your own PDB files, use the following command: 34 | 35 | ```bash 36 | python src/inference_single.py path_to_input_pdb_1 path_to_input_pdb_2 37 | ``` 38 | 39 | ## Citing this work 40 | 41 | ```bibtex 42 | @article{chu2024unified, 43 | title={Unified Sampling and Ranking for Protein Docking with DFMDock}, 44 | author={Chu, Lee-Shin and Sarma, Sudeep and Gray, Jeffrey J}, 45 | journal={bioRxiv}, 46 | pages={2024--09}, 47 | year={2024}, 48 | publisher={Cold Spring Harbor Laboratory} 49 | } 50 | ``` 51 | 52 | 53 | -------------------------------------------------------------------------------- /checkpoints/dips/model_0.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/checkpoints/dips/model_0.ckpt -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/loss" # name of the logged metric which determines when model is improving 4 | mode: "min" # "max" means higher metric value is better, can also be "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionally always save model from last epoch 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False -------------------------------------------------------------------------------- /configs/callbacks/paramfreezer.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/loss" # name of the logged metric which determines when model is improving 4 | mode: "min" # "max" means higher metric value is better, can also be "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionally always save model from last epoch 7 | every_n_train_steps: 5000 8 | verbose: False 9 | dirpath: "checkpoints/" 10 | filename: "epoch_{epoch:03d}" 11 | auto_insert_metric_name: False 12 | 13 | ParamFreezer: 14 | _target_: src.callbacks.paramfreezer.ParamFreezer 15 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default.yaml 7 | - model: default.yaml # NOTE: no configs/model/default.yaml ships with the repo; pick one of DFMDock.yaml, DFMDock_guide.yaml, force_model.yaml, or score_model_mlsb.yaml 8 | - datamodule: default.yaml 9 | - callbacks: default.yaml 10 | - logger: wandb # set logger here or use command line (e.g.
`python run.py logger=wandb`) 11 | 12 | - mode: default.yaml 13 | 14 | - experiment: null 15 | - hparams_search: null 16 | 17 | # enable color logging 18 | - override hydra/hydra_logging: colorlog 19 | - override hydra/job_logging: colorlog 20 | 21 | # path to original working directory 22 | # hydra hijacks working directory by changing it to the current log directory, 23 | # so it's useful to have this path as a special variable 24 | # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 25 | work_dir: ${hydra:runtime.cwd} 26 | 27 | # path to folder with data 28 | data_dir: ${work_dir}/data 29 | 30 | # pretty print config at the start of the run using Rich library 31 | print_config: True 32 | 33 | # disable python warnings if they annoy you 34 | ignore_warnings: True 35 | 36 | # evaluate on test set, using best model weights achieved during training 37 | # lightning chooses best weights based on metric specified in checkpoint callback 38 | test_after_training: False 39 | 40 | # seed for random number generators in pytorch, numpy and python.random 41 | seed: 0 42 | 43 | # name of the run is accessed by loggers 44 | # should be used along with experiment mode 45 | name: null 46 | 47 | # path to ckpt 48 | ckpt_path: null 49 | -------------------------------------------------------------------------------- /configs/datamodule/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: rigid_docking.datasets.dips_datamodule.DipsDataModule 2 | 3 | data_dir: /home/lchu11/scr4_jgray21/lchu11/Docking-dev/data/dips/pt_files 4 | train_list: /home/lchu11/scr4_jgray21/lchu11/Docking-dev/data/dips_equidock/train_list_rev.txt 5 | val_list: /home/lchu11/scr4_jgray21/lchu11/Docking-dev/data/dips_equidock/val_list_rev.txt 6 | test_list: /home/lchu11/scr4_jgray21/lchu11/Docking-dev/data/dips_equidock/test_list.txt 7 | batch_size: 1 8 | num_workers: 6 9 | pin_memory: True 10 | -------------------------------------------------------------------------------- /configs/datamodule/docking_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.docking_dataset.DockingDataModule 2 | 3 | train_set: dips_train 4 | val_set: dips_val 5 | batch_size: 1 6 | use_esm: True 7 | num_workers: 12 8 | pin_memory: False 9 | -------------------------------------------------------------------------------- /configs/datamodule/pinder_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.pinder_dataset.PinderDataModule 2 | 3 | batch_size: 1 4 | use_esm: True 5 | num_workers: 12 6 | pin_memory: True 7 | -------------------------------------------------------------------------------- /configs/datamodule/ppi_mlsb_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: datasets.ppi_mlsb_dataset.PPIDataModule 2 | 3 | train_dataset: dips_train_hetero 4 | val_dataset: dips_val_hetero 5 | use_esm: True 6 | crop_size: 1500 7 | batch_size: 1 8 | num_workers: 12 9 | pin_memory: True 10 | -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | ckpt: ./weights/weight_0.ckpt 3 | dataset: db5_test 4 | test_all: True 5 | out_pdb: False 6 | out_trj: False 7 | get_gt_energy: False 8 | ode: False 9 | out_pdb_dir: ./pdbs/ 10 | out_trj_dir: ./trjs/ 11 | 
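# Usage note: this config likely drives src/inference.py through Hydra (hydra-core
# is a dependency), in which case keys under `data` can be overridden on the
# command line, e.g.
#   python src/inference.py data.num_samples=5 data.out_pdb=True
# (hypothetical invocation). The repo ships checkpoints/dips/model_0.ckpt and
# weights/pinder_0.ckpt; `ckpt` above may need to point at one of these.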
out_csv_dir: ./csv_files/ 12 | out_csv: test.csv 13 | num_steps: 40 14 | num_samples: 1 15 | tr_noise_scale: 0.5 16 | rot_noise_scale: 0.5 17 | perturb_tr: True 18 | perturb_rot: True 19 | use_esm: True 20 | use_clash_force: False 21 | use_interface: False 22 | use_tm_score: False 23 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: "generative_model" 6 | name: ${name} 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment! 10 | # entity: "" # set to name of your wandb team 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /configs/mode/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default running mode 4 | 5 | default_mode: True 6 | 7 | hydra: 8 | # default output paths for all file logs 9 | run: 10 | dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 11 | sweep: 12 | dir: logs/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/model/DFMDock.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.DFMDock.DFMDock 2 | 3 | model: 4 | lm_embed_dim: 1301 # 1280 (ESM) + 21 (One-Hot) 5 | positional_embed_dim: 67 # 66 (relpos) + 1 (sym) 6 | spatial_embed_dim: 100 # 40 (dist) + 24 (phi) + 24 (psi) + 12 (omega) 7 | node_dim: 256 8 | edge_dim: 128 9 | inner_dim: 128 10 | depth: 6 11 | dropout: 0.1 12 | cut_off: 20.0 13 | normalize: True 14 | agg: 'mean' 15 | 16 | diffuser: 17 | r3: 18 | min_sigma: 0.1 19 | max_sigma: 30.0 20 | schedule: VE 21 | so3: 22 | num_omega: 1000 23 | num_sigma: 1000 24 | min_sigma: 0.1 25 | max_sigma: 1.5 26 | schedule: logarithmic 27 | cache_dir: .cache/ 28 | use_cached_score: False 29 | 30 | experiment: 31 | lr: 1e-4 32 | weight_decay: 0.0 33 | crop_size: 1200 34 | perturb_tr: True 35 | perturb_rot: True 36 | separate_energy_loss: True 37 | separate_tr_loss: True 38 | separate_rot_loss: True 39 | grad_energy: False 40 | use_contrastive_loss: False 41 | use_confidence_loss: False 42 | use_dist_loss: False 43 | use_interface_loss: False 44 | 45 | -------------------------------------------------------------------------------- /configs/model/DFMDock_guide.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.DFMDock_guide.DFMDock 2 | 3 | model: 4 | lm_embed_dim: 1302 # 1280 (ESM) + 21 (One-Hot) + 1 (interface) 5 | positional_embed_dim: 68 # 66 (Residue) + 1 (sym) + 1 (contact) 6 | spatial_embed_dim: 100 # 40 (dist) + 24 (phi) + 24 (psi) + 12 (omega) 7 | node_dim: 256 8 | edge_dim: 128 9 | inner_dim: 128 10 | depth: 6 11 | dropout: 0.1 12 | cut_off: 20.0 13 | normalize: True 14 | agg: 'mean' 15 | 16 | diffuser: 17 | r3: 18 | min_sigma: 0.1 19 | max_sigma: 30.0 20 | schedule: VE 21 | so3: 22 | num_omega: 1000 23 | num_sigma: 1000 24 | min_sigma: 0.1 25 | max_sigma: 1.5 26 | schedule: logarithmic 27 | cache_dir: .cache/ 28 | use_cached_score: False 29 | 30 | experiment: 31 | lr: 1e-4 32 | weight_decay: 0.0 33 | crop_size: 1200 34 | 
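# Relative to configs/model/DFMDock.yaml, this guided variant enables grad_energy
# and the auxiliary losses below (contrastive, dist, interface, confidence), and
# extends the inputs with the interface/contact flags noted in the dim comments
# above (an observation from comparing the two configs).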
perturb_tr: True 35 | perturb_rot: True 36 | separate_energy_loss: True 37 | separate_tr_loss: True 38 | separate_rot_loss: True 39 | grad_energy: True 40 | use_contrastive_loss: True 41 | use_dist_loss: True 42 | use_interface_loss: True 43 | use_confidence_loss: True 44 | 45 | -------------------------------------------------------------------------------- /configs/model/force_model.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.force_model.Force_Model 2 | 3 | model: 4 | lm_embed_dim: 1301 # 1280 (ESM) + 21 (One-Hot) 5 | positional_embed_dim: 66 # 66 (Residue) 6 | spatial_embed_dim: 25 # 16 (dist) + 3 (direct) + 6 (orient) 7 | node_dim: 256 8 | edge_dim: 128 9 | inner_dim: 128 10 | depth: 6 11 | dropout: 0.1 12 | cut_off: 20.0 13 | normalize: True 14 | 15 | diffuser: 16 | r3: 17 | min_sigma: 0.1 18 | max_sigma: 30.0 19 | schedule: VE 20 | so3: 21 | num_omega: 1000 22 | num_sigma: 1000 23 | min_sigma: 0.1 24 | max_sigma: 1.5 25 | schedule: logarithmic 26 | cache_dir: .cache/ 27 | use_cached_score: False 28 | 29 | experiment: 30 | lr: 1e-4 31 | weight_decay: 0.0 32 | perturb_tr: True 33 | perturb_rot: True 34 | separate_tr_loss: True 35 | separate_rot_loss: True 36 | use_interface_loss: True 37 | use_contrastive_loss: False 38 | 39 | -------------------------------------------------------------------------------- /configs/model/score_model_mlsb.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.score_model_mlsb.Score_Model 2 | 3 | model: 4 | lm_embed_dim: 1301 # 1280 (ESM) + 21 (One-Hot) 5 | positional_embed_dim: 66 # 66 (Residue) 6 | spatial_embed_dim: 100 # 40 (dist) + 24 (phi) + 24 (psi) + 12 (omega) 7 | node_dim: 256 8 | edge_dim: 128 9 | inner_dim: 128 10 | depth: 6 11 | dropout: 0.1 12 | cut_off: 20.0 13 | normalize: True 14 | 15 | diffuser: 16 | r3: 17 | min_sigma: 0.1 18 | max_sigma: 30.0 19 | schedule: VE 20 | so3: 21 | num_omega: 1000 22 | num_sigma: 1000 23 | min_sigma: 0.1 24 | max_sigma: 1.5 25 | schedule: logarithmic 26 | cache_dir: .cache/ 27 | use_cached_score: False 28 | 29 | experiment: 30 | lr: 1e-4 31 | weight_decay: 0.0 32 | perturb_tr: True 33 | perturb_rot: True 34 | separate_energy_loss: True 35 | separate_tr_loss: True 36 | separate_rot_loss: True 37 | use_interface_loss: True 38 | grad_energy: False 39 | use_contrastive_loss: False 40 | 41 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | accelerator: auto 4 | min_epochs: 1 5 | max_epochs: 10 6 | num_sanity_val_steps: 1 7 | gradient_clip_val: 0.0 8 | check_val_every_n_epoch: 1 9 | #precision: bf16-mixed 10 | -------------------------------------------------------------------------------- /data/db5_test/1AVX.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1AVX.pt -------------------------------------------------------------------------------- /data/db5_test/1H1V.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1H1V.pt -------------------------------------------------------------------------------- /data/db5_test/1HCF.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1HCF.pt -------------------------------------------------------------------------------- /data/db5_test/1IRA.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1IRA.pt -------------------------------------------------------------------------------- /data/db5_test/1JIW.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1JIW.pt -------------------------------------------------------------------------------- /data/db5_test/1JPS.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1JPS.pt -------------------------------------------------------------------------------- /data/db5_test/1MLC.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1MLC.pt -------------------------------------------------------------------------------- /data/db5_test/1N2C.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1N2C.pt -------------------------------------------------------------------------------- /data/db5_test/1NW9.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1NW9.pt -------------------------------------------------------------------------------- /data/db5_test/1QA9.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1QA9.pt -------------------------------------------------------------------------------- /data/db5_test/1VFB.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1VFB.pt -------------------------------------------------------------------------------- /data/db5_test/1ZHI.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/1ZHI.pt -------------------------------------------------------------------------------- /data/db5_test/2A1A.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/2A1A.pt -------------------------------------------------------------------------------- /data/db5_test/2A9K.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/2A9K.pt 
-------------------------------------------------------------------------------- /data/db5_test/2AYO.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/2AYO.pt -------------------------------------------------------------------------------- /data/db5_test/2SIC.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/2SIC.pt -------------------------------------------------------------------------------- /data/db5_test/2SNI.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/2SNI.pt -------------------------------------------------------------------------------- /data/db5_test/2VDB.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/2VDB.pt -------------------------------------------------------------------------------- /data/db5_test/3SZK.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/3SZK.pt -------------------------------------------------------------------------------- /data/db5_test/4POU.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/4POU.pt -------------------------------------------------------------------------------- /data/db5_test/5C7X.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/5C7X.pt -------------------------------------------------------------------------------- /data/db5_test/5HGG.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/5HGG.pt -------------------------------------------------------------------------------- /data/db5_test/5JMO.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/5JMO.pt -------------------------------------------------------------------------------- /data/db5_test/6B0S.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/6B0S.pt -------------------------------------------------------------------------------- /data/db5_test/7CEI.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/data/db5_test/7CEI.pt -------------------------------------------------------------------------------- /data/db5_test/test.txt: -------------------------------------------------------------------------------- 1 | 1AVX 2 | 1H1V 3 | 1HCF 4 | 1IRA 5 | 1JIW 6 | 1JPS 7 | 1MLC 8 | 1N2C 9 | 
1NW9 10 | 1QA9 11 | 1VFB 12 | 1ZHI 13 | 2A1A 14 | 2A9K 15 | 2AYO 16 | 2SIC 17 | 2SNI 18 | 2VDB 19 | 3SZK 20 | 4POU 21 | 5C7X 22 | 5HGG 23 | 5JMO 24 | 6B0S 25 | 7CEI 26 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DFMDock 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.10 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - bzip2=1.0.8=h4bc722e_7 9 | - ca-certificates=2024.8.30=hbcca054_0 10 | - ld_impl_linux-64=2.40=hf3520f5_7 11 | - libffi=3.4.2=h7f98852_5 12 | - libgcc=14.1.0=h77fa898_1 13 | - libgcc-ng=14.1.0=h69a702a_1 14 | - libgomp=14.1.0=h77fa898_1 15 | - libnsl=2.0.1=hd590300_0 16 | - libsqlite=3.46.1=hadc24fc_0 17 | - libuuid=2.38.1=h0b41bf4_0 18 | - libxcrypt=4.4.36=hd590300_1 19 | - libzlib=1.3.1=h4ab18f5_1 20 | - ncurses=6.5=he02047a_1 21 | - openssl=3.3.2=hb9d3cd8_0 22 | - pip=24.2=pyh8b19718_1 23 | - python=3.10.14=hd12c33a_0_cpython 24 | - readline=8.2=h8228510_1 25 | - setuptools=74.1.2=pyhd8ed1ab_0 26 | - tk=8.6.13=noxft_h4845f30_101 27 | - tzdata=2024a=h8827d51_1 28 | - wheel=0.44.0=pyhd8ed1ab_0 29 | - xz=5.2.6=h166bdaf_0 30 | - pip: 31 | - aiohappyeyeballs==2.4.0 32 | - aiohttp==3.10.5 33 | - aiosignal==1.3.1 34 | - antlr4-python3-runtime==4.9.3 35 | - async-timeout==4.0.3 36 | - attrs==24.2.0 37 | - biotite==1.0.1 38 | - biotraj==1.2.1 39 | - certifi==2024.8.30 40 | - charset-normalizer==3.3.2 41 | - dm-tree==0.1.8 42 | - einops==0.8.0 43 | - fair-esm==2.0.0 44 | - filelock==3.13.1 45 | - frozenlist==1.4.1 46 | - fsspec==2024.2.0 47 | - hydra-core==1.3.2 48 | - idna==3.10 49 | - jinja2==3.1.3 50 | - lightning==2.4.0 51 | - lightning-utilities==0.11.7 52 | - markupsafe==2.1.5 53 | - mpmath==1.3.0 54 | - msgpack==1.1.0 55 | - multidict==6.1.0 56 | - networkx==3.2.1 57 | - numpy==1.26.3 58 | - nvidia-cublas-cu11==11.11.3.6 59 | - nvidia-cuda-cupti-cu11==11.8.87 60 | - nvidia-cuda-nvrtc-cu11==11.8.89 61 | - nvidia-cuda-runtime-cu11==11.8.89 62 | - nvidia-cudnn-cu11==9.1.0.70 63 | - nvidia-cufft-cu11==10.9.0.58 64 | - nvidia-curand-cu11==10.3.0.86 65 | - nvidia-cusolver-cu11==11.4.1.48 66 | - nvidia-cusparse-cu11==11.7.5.86 67 | - nvidia-nccl-cu11==2.20.5 68 | - nvidia-nvtx-cu11==11.8.86 69 | - omegaconf==2.3.0 70 | - packaging==24.1 71 | - pillow==10.2.0 72 | - psutil==6.0.0 73 | - pyparsing==3.1.4 74 | - pytorch-lightning==2.4.0 75 | - pyyaml==6.0.2 76 | - requests==2.32.3 77 | - scipy==1.14.1 78 | - sympy==1.12 79 | - torch-geometric==2.6.0 80 | - torchmetrics==1.4.2 81 | - tqdm==4.66.5 82 | - triton==3.0.0 83 | - typing-extensions==4.9.0 84 | - urllib3==2.2.3 85 | - yarl==1.11.1 86 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.0 2 | aiohttp==3.10.5 3 | aiosignal==1.3.1 4 | antlr4-python3-runtime==4.9.3 5 | async-timeout==4.0.3 6 | attrs==24.2.0 7 | biotite==1.0.1 8 | biotraj==1.2.1 9 | certifi==2024.8.30 10 | charset-normalizer==3.3.2 11 | dm-tree==0.1.8 12 | einops==0.8.0 13 | fair-esm==2.0.0 14 | filelock==3.13.1 15 | frozenlist==1.4.1 16 | fsspec==2024.2.0 17 | hydra-core==1.3.2 18 | idna==3.10 19 | Jinja2==3.1.3 20 | lightning==2.4.0 21 | lightning-utilities==0.11.7 22 | MarkupSafe==2.1.5 23 | mpmath==1.3.0 24 | msgpack==1.1.0 25 | multidict==6.1.0 26 | networkx==3.2.1 27 | numpy==1.26.3 28 | nvidia-cublas-cu11==11.11.3.6 29 | 
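# Note: the +cu118 wheels below (torch, torchaudio, torchvision) are not hosted
# on PyPI; install against the PyTorch CUDA 11.8 index, e.g.
#   pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118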
nvidia-cuda-cupti-cu11==11.8.87 30 | nvidia-cuda-nvrtc-cu11==11.8.89 31 | nvidia-cuda-runtime-cu11==11.8.89 32 | nvidia-cudnn-cu11==9.1.0.70 33 | nvidia-cufft-cu11==10.9.0.58 34 | nvidia-curand-cu11==10.3.0.86 35 | nvidia-cusolver-cu11==11.4.1.48 36 | nvidia-cusparse-cu11==11.7.5.86 37 | nvidia-nccl-cu11==2.20.5 38 | nvidia-nvtx-cu11==11.8.86 39 | omegaconf==2.3.0 40 | packaging==24.1 41 | pillow==10.2.0 42 | psutil==6.0.0 43 | pyparsing==3.1.4 44 | pytorch-lightning==2.4.0 45 | PyYAML==6.0.2 46 | requests==2.32.3 47 | scipy==1.14.1 48 | sympy==1.12 49 | torch==2.4.1+cu118 50 | torch-geometric==2.6.0 51 | torchaudio==2.4.1+cu118 52 | torchmetrics==1.4.2 53 | torchvision==0.19.1+cu118 54 | tqdm==4.66.5 55 | triton==3.0.0 56 | typing_extensions==4.9.0 57 | urllib3==2.2.3 58 | yarl==1.11.1 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='DFMDock', 5 | version='0.1', 6 | packages=find_packages(where="src"), 7 | package_dir={"": "src"}, 8 | ) 9 | 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/src/__init__.py -------------------------------------------------------------------------------- /src/data/gen_dips_attn.py: -------------------------------------------------------------------------------- 1 | import esm 2 | import os 3 | import torch 4 | import os.path as path 5 | from collections import defaultdict 6 | from tqdm import tqdm 7 | from src.utils.use_dill import get_data 8 | from Bio.Data.IUPACData import protein_letters_3to1 9 | 10 | 11 | def get_esm_attn(seq_prim, batch_converter, esm_model, device): 12 | # Use ESM-1b format. 
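    # Shape sketch (assuming esm2_t33_650M_UR50D, 33 layers x 20 heads, as
    # loaded in __main__ below):
    #   results["attentions"]: [1, 33, 20, L+2, L+2]
    #   after squeeze(0) and trimming BOS/EOS tokens: [33, 20, L, L]
    #   after permute(2, 3, 0, 1).flatten(2, 3): [L, L, 660]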
13 | # The length of tokens is: 14 | # L (sequence length) + 2 (start and end tokens) 15 | seq = [ 16 | ("seq", seq_prim) 17 | ] 18 | out = batch_converter(seq) 19 | with torch.no_grad(): 20 | results = esm_model(out[-1].to(device), repr_layers=[33], return_contacts=True) 21 | attn = results["attentions"].squeeze(0)[:, :, 1:-1, 1:-1] 22 | output = attn.permute(2, 3, 0, 1).flatten(2, 3) 23 | 24 | return output 25 | 26 | 27 | if __name__ == '__main__': 28 | data_dir = "/home/lchu11/scr4_jgray21/lchu11/data/dips/pairs_pruned" 29 | data_list = "/home/lchu11/scr4_jgray21/lchu11/data/dips/pairs_pruned/pairs-postprocessed.txt" 30 | save_dir = "/home/lchu11/scr4_jgray21/lchu11/data/dips/pt_attn" 31 | 32 | os.makedirs(save_dir, exist_ok=True) 33 | 34 | with open(data_list, 'r') as f: 35 | lines = f.readlines() 36 | file_list = [line.strip() for line in lines] 37 | 38 | # Load esm 39 | esm_model, alphabet = esm.pretrained.load_model_and_alphabet('/home/lchu11/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt') 40 | batch_converter = alphabet.get_batch_converter() 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | esm_model = esm_model.to(device).eval() 43 | 44 | # save 45 | for _id in tqdm(file_list): 46 | pdb_file = path.join(data_dir, _id) 47 | split_string = _id.split('/') 48 | _id = split_string[0] + '_' + split_string[1].rsplit('.', 1)[0] 49 | 50 | # Get data from files 51 | data = get_data(pdb_file) 52 | 53 | # Convert res from 3 to 1 54 | aa_code = defaultdict(lambda: "") 55 | aa_code.update( 56 | {k.upper():v for k,v in protein_letters_3to1.items()}) 57 | 58 | seq1 = "".join(aa_code[s] for s in data['receptor']['res']) 59 | seq2 = "".join(aa_code[s] for s in data['ligand']['res']) 60 | seq = seq1 + seq2 61 | 62 | 63 | # ESM embedding 64 | attn = get_esm_attn(seq, batch_converter, esm_model, device).half() 65 | 66 | torch.save(attn, path.join(save_dir, _id+'.pt')) 67 | break # NOTE: this 'break' stops after the first entry (likely a debug leftover); remove it to process every pair in file_list -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/docking_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import esm 4 | import random 5 | import torch 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from typing import Optional 9 | from torch.utils.data import DataLoader, Dataset 10 | from scipy.spatial.transform import Rotation 11 | from utils import residue_constants 12 | 13 | #---------------------------------------------------------------------------- 14 | # Helper functions 15 | 16 | def random_rotation(rec_pos, lig_pos): 17 | rot = torch.from_numpy(Rotation.random().as_matrix()).float() 18 | pos = torch.cat([rec_pos, lig_pos], dim=0) 19 | cen = pos.mean(dim=(0, 1)) 20 | pos = (pos - cen) @ rot.T 21 | rec_pos_out = pos[:rec_pos.size(0)] 22 | lig_pos_out = pos[rec_pos.size(0):] 23 | return rec_pos_out, lig_pos_out 24 | 25 | def get_esm_attn(seq_prim, batch_converter, esm_model, device): 26 | # Use ESM-1b format.
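    # (Same logic as src/data/gen_dips_attn.get_esm_attn, except the attention
    # tensor is moved to .cpu() before reshaping.)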
27 | # The length of tokens is: 28 | # L (sequence length) + 2 (start and end tokens) 29 | seq = [ 30 | ("seq", seq_prim) 31 | ] 32 | out = batch_converter(seq) 33 | with torch.no_grad(): 34 | results = esm_model(out[-1].to(device), repr_layers=[33], return_contacts=True) 35 | attn = results["attentions"].squeeze(0)[:, :, 1:-1, 1:-1].cpu() 36 | output = attn.permute(2, 3, 0, 1).flatten(2, 3) 37 | 38 | return output 39 | 40 | #---------------------------------------------------------------------------- 41 | # Dataset class 42 | 43 | class DockingDataset(Dataset): 44 | def __init__( 45 | self, 46 | dataset: str, 47 | training: bool = True, 48 | use_esm: bool = True, 49 | ): 50 | self.dataset = dataset 51 | self.training = training 52 | self.use_esm = use_esm 53 | 54 | if dataset == 'dips_train': 55 | self.data_dir = "/scratch4/jgray21/lchu11/data/dips/pt_clean" 56 | self.data_list = "/scratch4/jgray21/lchu11/data/dips/data_list/diffdock-pp/train.txt" 57 | 58 | elif dataset == 'dips_val': 59 | self.data_dir = "/scratch4/jgray21/lchu11/data/dips/pt_clean" 60 | self.data_list = "/scratch4/jgray21/lchu11/data/dips/data_list/diffdock-pp/val.txt" 61 | 62 | elif dataset == 'dips_testing': 63 | self.data_dir = "/scratch4/jgray21/lchu11/data/dips/pt_clean" 64 | self.data_list = "/scratch4/jgray21/lchu11/data/dips/data_list/diffdock-pp/testing.txt" 65 | 66 | elif dataset == 'dips_train_hetero': 67 | self.data_dir = "/scratch4/jgray21/lchu11/data/pt/dips_bb" 68 | self.data_list = "/scratch4/jgray21/lchu11/data/dips/data_list/diffdock-pp/dips_train_hetero.txt" 69 | 70 | elif dataset == 'dips_val_hetero': 71 | self.data_dir = "/scratch4/jgray21/lchu11/data/pt/dips_bb" 72 | self.data_list = "/scratch4/jgray21/lchu11/data/dips/data_list/diffdock-pp/dips_val_hetero.txt" 73 | 74 | with open(self.data_list, 'r') as f: 75 | lines = f.readlines() 76 | self.file_list = [line.strip() for line in lines] 77 | 78 | def __getitem__(self, idx: int): 79 | # Get info from file_list 80 | if self.dataset[:4] == 'dips': 81 | _id = self.file_list[idx] 82 | split_string = _id.split('/') 83 | _id = split_string[0] + '_' + split_string[1].rsplit('.', 1)[0] 84 | data = torch.load(os.path.join(self.data_dir, _id+'.pt')) 85 | else: 86 | _id = self.file_list[idx] 87 | data = torch.load(os.path.join(self.data_dir, _id+'.pt')) 88 | 89 | rec_x = data['receptor'].x 90 | rec_seq = data['receptor'].seq 91 | rec_pos = data['receptor'].pos.float() 92 | lig_x = data['ligand'].x 93 | lig_seq = data['ligand'].seq 94 | lig_pos = data['ligand'].pos.float() 95 | 96 | # One-Hot embeddings 97 | rec_onehot = torch.from_numpy(residue_constants.sequence_to_onehot( 98 | sequence=rec_seq, 99 | mapping=residue_constants.restype_order_with_x, 100 | map_unknown_to_x=True, 101 | )).float() 102 | 103 | lig_onehot = torch.from_numpy(residue_constants.sequence_to_onehot( 104 | sequence=lig_seq, 105 | mapping=residue_constants.restype_order_with_x, 106 | map_unknown_to_x=True, 107 | )).float() 108 | 109 | # ESM embeddings 110 | if self.use_esm: 111 | rec_x = torch.cat([rec_x, rec_onehot], dim=-1) 112 | lig_x = torch.cat([lig_x, lig_onehot], dim=-1) 113 | else: 114 | rec_x = rec_onehot 115 | lig_x = lig_onehot 116 | 117 | # Shuffle and Crop for training 118 | if self.training: 119 | # Shuffle the order of rec and lig 120 | vars_list = [(rec_x, rec_seq, rec_pos), (lig_x, lig_seq, lig_pos)] 121 | if random.random() > 0.5: 122 | rec_x, rec_seq, rec_pos = vars_list[1] 123 | lig_x, lig_seq, lig_pos = vars_list[0] 124 | 125 | # Random rotation augmentation 126 | 
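            # (a single rotation applied jointly to both chains about their
            # combined centroid: the relative pose is unchanged, only the
            # global frame is randomized)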
rec_pos, lig_pos = random_rotation(rec_pos, lig_pos) 127 | 128 | # is homomer 129 | is_homomer = rec_seq == lig_seq 130 | 131 | # Output 132 | output = { 133 | 'id': _id, 134 | 'rec_seq': rec_seq, 135 | 'lig_seq': lig_seq, 136 | 'rec_x': rec_x, 137 | 'lig_x': lig_x, 138 | 'rec_pos': rec_pos, 139 | 'lig_pos': lig_pos, 140 | 'is_homomer': is_homomer, 141 | } 142 | 143 | return {key: value for key, value in output.items()} 144 | 145 | def __len__(self): 146 | return len(self.file_list) 147 | 148 | 149 | #---------------------------------------------------------------------------- 150 | # DataModule class 151 | 152 | class DockingDataModule(pl.LightningDataModule): 153 | def __init__( 154 | self, 155 | train_set: str = 'dips_train', 156 | val_set: str = 'dips_val', 157 | batch_size: int = 1, 158 | use_esm: bool = True, 159 | **kwargs 160 | ): 161 | super().__init__() 162 | self.train_set = train_set 163 | self.val_set = val_set 164 | self.batch_size = batch_size 165 | self.use_esm = use_esm 166 | self.num_workers = kwargs['num_workers'] 167 | self.pin_memory = kwargs['pin_memory'] 168 | 169 | self.data_train: Optional[Dataset] = None 170 | self.data_val: Optional[Dataset] = None 171 | 172 | def prepare_data(self): 173 | pass 174 | 175 | def setup(self, stage: Optional[str] = None): 176 | self.data_train = DockingDataset( 177 | dataset=self.train_set, 178 | use_esm=self.use_esm, 179 | ) 180 | self.data_val = DockingDataset( 181 | dataset=self.val_set, 182 | use_esm=self.use_esm, 183 | ) 184 | 185 | def train_dataloader(self): 186 | return DataLoader( 187 | dataset=self.data_train, 188 | batch_size=self.batch_size, 189 | num_workers=self.num_workers, 190 | pin_memory=self.pin_memory, 191 | shuffle=True, 192 | ) 193 | 194 | def val_dataloader(self): 195 | return DataLoader( 196 | dataset=self.data_val, 197 | batch_size=self.batch_size, 198 | num_workers=self.num_workers, 199 | pin_memory=self.pin_memory, 200 | shuffle=False, 201 | ) 202 | 203 | 204 | if __name__ == '__main__': 205 | dataset = DockingDataset(dataset='dips_train') 206 | print(dataset[0]) 207 | """ 208 | dataloader = DataLoader(dataset, batch_size=1, num_workers=6) 209 | 210 | with open('dips_train_size.txt', 'w') as f: 211 | for batch in dataloader: 212 | n = batch['rec_x'].size(1) + batch['lig_x'].size(1) 213 | f.write(str(n) + '\n') 214 | 215 | """ 216 | 217 | -------------------------------------------------------------------------------- /src/datasets/pinder_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import random 4 | import torch 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | import gzip 8 | import pickle 9 | from pathlib import Path 10 | from typing import Optional 11 | from torch.utils.data import DataLoader, Dataset 12 | from scipy.spatial.transform import Rotation 13 | from pinder.core.index.utils import get_index 14 | from utils import residue_constants 15 | 16 | 17 | #---------------------------------------------------------------------------- 18 | # Helper functions 19 | 20 | def load_dict_data(file_path): 21 | with gzip.open(file_path, 'rb') as f: 22 | dict_data = pickle.load(f) 23 | return dict_data 24 | 25 | def random_rotation(rec_pos, lig_pos): 26 | rot = torch.from_numpy(Rotation.random().as_matrix()).float() 27 | pos = torch.cat([rec_pos, lig_pos], dim=0) 28 | cen = pos.mean(dim=(0, 1)) 29 | pos = (pos - cen) @ rot.T 30 | rec_pos_out = pos[:rec_pos.size(0)] 31 | lig_pos_out = pos[rec_pos.size(0):] 32 | 
return rec_pos_out, lig_pos_out 33 | 34 | #---------------------------------------------------------------------------- 35 | # Dataset class 36 | 37 | class PinderDataset(Dataset): 38 | def __init__( 39 | self, 40 | data_dir, 41 | test_split: str = 'pinder_s', 42 | training: bool = True, 43 | use_esm: bool = False, 44 | ): 45 | self.training = training 46 | self.use_esm = use_esm 47 | 48 | # Load the dictionary data 49 | self.data_dir = data_dir 50 | if training: 51 | self.data_list = [f.name.split('.')[0] for f in Path(self.data_dir).iterdir()] 52 | else: 53 | pindex = get_index() 54 | self.data_list = list(pindex.query(f'{test_split} == True').id) 55 | 56 | if self.use_esm: 57 | self.h5f = h5py.File('/scratch16/jgray21/lchu11/data/h5_files/pinder_combined.h5', 'r') 58 | 59 | def __getitem__(self, idx: int): 60 | data = load_dict_data(os.path.join(self.data_dir, f'{self.data_list[idx]}.pkl.gz')) 61 | 62 | _id = data['id'] 63 | rec_seq = data['rec_seq'] 64 | lig_seq = data['lig_seq'] 65 | rec_pos = torch.from_numpy(data['rec_pos']).float() 66 | lig_pos = torch.from_numpy(data['lig_pos']).float() 67 | 68 | # One-Hot embeddings 69 | rec_x = torch.from_numpy(residue_constants.sequence_to_onehot( 70 | sequence=rec_seq, 71 | mapping=residue_constants.restype_order_with_x, 72 | map_unknown_to_x=True, 73 | )).float() 74 | 75 | lig_x = torch.from_numpy(residue_constants.sequence_to_onehot( 76 | sequence=lig_seq, 77 | mapping=residue_constants.restype_order_with_x, 78 | map_unknown_to_x=True, 79 | )).float() 80 | 81 | # ESM embeddings 82 | if self.use_esm: 83 | group = self.h5f[_id] 84 | rec_esm = torch.tensor(group['rec_esm'][:]) 85 | lig_esm = torch.tensor(group['lig_esm'][:]) 86 | 87 | rec_x = torch.cat([rec_esm, rec_x], dim=-1) 88 | lig_x = torch.cat([lig_esm, lig_x], dim=-1) 89 | 90 | if self.training: 91 | # shuffle the order of rec and lig 92 | vars_list = [(rec_x, rec_pos), (lig_x, lig_pos)] 93 | random.shuffle(vars_list) 94 | rec_x, rec_pos = vars_list[0] 95 | lig_x, lig_pos = vars_list[1] 96 | 97 | 98 | # random rotation augmentation 99 | rec_pos, lig_pos = random_rotation(rec_pos, lig_pos) 100 | 101 | # is homomer 102 | is_homomer = rec_seq == lig_seq 103 | 104 | # Output 105 | output = { 106 | 'id': _id, 107 | 'rec_seq': rec_seq, 108 | 'lig_seq': lig_seq, 109 | 'rec_x': rec_x, 110 | 'lig_x': lig_x, 111 | 'rec_pos': rec_pos, 112 | 'lig_pos': lig_pos, 113 | 'is_homomer': is_homomer, 114 | } 115 | 116 | return {key: value for key, value in output.items()} 117 | 118 | def __len__(self): 119 | return len(self.data_list) 120 | 121 | 122 | #---------------------------------------------------------------------------- 123 | # DataModule class 124 | 125 | class PinderDataModule(pl.LightningDataModule): 126 | def __init__( 127 | self, 128 | batch_size: int = 1, 129 | use_esm: bool = True, 130 | **kwargs 131 | ): 132 | super().__init__() 133 | self.batch_size = batch_size 134 | self.use_esm = use_esm 135 | self.num_workers = kwargs['num_workers'] 136 | self.pin_memory = kwargs['pin_memory'] 137 | 138 | self.data_train: Optional[Dataset] = None 139 | self.data_val: Optional[Dataset] = None 140 | 141 | def prepare_data(self): 142 | pass 143 | 144 | def setup(self, stage: Optional[str] = None): 145 | self.data_train = PinderDataset( 146 | data_dir='/scratch4/jgray21/lchu11/data/pinder/train', 147 | use_esm=self.use_esm, 148 | ) 149 | self.data_val = PinderDataset( 150 | data_dir='/scratch4/jgray21/lchu11/data/pinder/val', 151 | use_esm=self.use_esm, 152 | ) 153 | 154 | def train_dataloader(self): 155 
| return DataLoader( 156 | dataset=self.data_train, 157 | batch_size=self.batch_size, 158 | num_workers=self.num_workers, 159 | pin_memory=self.pin_memory, 160 | shuffle=True, 161 | ) 162 | 163 | def val_dataloader(self): 164 | return DataLoader( 165 | dataset=self.data_val, 166 | batch_size=self.batch_size, 167 | num_workers=self.num_workers, 168 | pin_memory=self.pin_memory, 169 | shuffle=False, 170 | ) 171 | 172 | #---------------------------------------------------------------------------- 173 | # Testing 174 | 175 | if __name__ == '__main__': 176 | dataset = PinderDataset( 177 | data_dir='/scratch4/jgray21/lchu11/data/pinder/train', 178 | test_split='pinder_s', 179 | training=True, 180 | use_esm=True, 181 | ) 182 | print(dataset[0]) 183 | -------------------------------------------------------------------------------- /src/datasets/ppi_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import random 4 | import torch 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | import warnings 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from typing import Optional 11 | from torch.utils.data import DataLoader, Dataset 12 | from torch_geometric.data import HeteroData 13 | from scipy.spatial.transform import Rotation 14 | from utils import residue_constants 15 | 16 | #---------------------------------------------------------------------------- 17 | # Helper functions 18 | 19 | def get_interface_residues(coords, asym_id, interface_threshold=10.0): 20 | coord_diff = coords[..., None, :, :] - coords[..., None, :, :, :] 21 | pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) 22 | diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float() 23 | mask = diff_chain_mask[..., None].bool() 24 | min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1) 25 | valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1) 26 | interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0] 27 | 28 | return interface_residues_idxs 29 | 30 | def get_spatial_crop_idx(coords, asym_id, crop_size=256, interface_threshold=10.0): 31 | interface_residues = get_interface_residues(coords, asym_id, interface_threshold=interface_threshold) 32 | 33 | if not torch.any(interface_residues): 34 | return get_contiguous_crop_idx(asym_id, crop_size) 35 | 36 | target_res_idx = randint(lower=0, upper=interface_residues.shape[-1] - 1) 37 | target_res = interface_residues[target_res_idx] 38 | 39 | ca_positions = coords[..., 1, :] 40 | coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :] 41 | ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) 42 | to_target_distances = ca_pairwise_dists[target_res] 43 | 44 | break_tie = ( 45 | torch.arange( 46 | 0, to_target_distances.shape[-1] 47 | ).float() 48 | * 1e-3 49 | ) 50 | to_target_distances += break_tie 51 | ret = torch.argsort(to_target_distances)[:crop_size] 52 | return ret.sort().values 53 | 54 | def get_contiguous_crop_idx(asym_id, crop_size): 55 | unique_asym_ids, chain_idxs, chain_lens = asym_id.unique(dim=-1, 56 | return_inverse=True, 57 | return_counts=True) 58 | 59 | shuffle_idx = torch.randperm(chain_lens.shape[-1]) 60 | 61 | 62 | _, idx_sorted = torch.sort(chain_idxs, stable=True) 63 | cum_sum = chain_lens.cumsum(dim=0) 64 | cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]), dim=0) 65 | asym_offsets = idx_sorted[cum_sum] 66 | 67 | num_budget = 
crop_size 68 | num_remaining = len(chain_idxs) 69 | 70 | crop_idxs = [] 71 | for i, idx in enumerate(shuffle_idx): 72 | chain_len = int(chain_lens[idx]) 73 | num_remaining -= chain_len 74 | 75 | if i == 0: 76 | crop_size_max = min(num_budget - 50, chain_len) 77 | crop_size_min = min(chain_len, 50) 78 | else: 79 | crop_size_max = min(num_budget, chain_len) 80 | crop_size_min = min(chain_len, max(50, num_budget - num_remaining)) 81 | 82 | chain_crop_size = randint(lower=crop_size_min, 83 | upper=crop_size_max) 84 | 85 | num_budget -= chain_crop_size 86 | 87 | chain_start = randint(lower=0, 88 | upper=chain_len - chain_crop_size) 89 | 90 | asym_offset = asym_offsets[idx] 91 | crop_idxs.append( 92 | torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size) 93 | ) 94 | 95 | return torch.concat(crop_idxs).sort().values 96 | 97 | def randint(lower, upper): 98 | return int(torch.randint( 99 | lower, 100 | upper + 1, 101 | (1,), 102 | )[0]) 103 | 104 | 105 | def get_interface_residue_tensors(set1, set2, threshold=8.0): 106 | n1_len = set1.shape[0] 107 | n2_len = set2.shape[0] 108 | 109 | # Calculate the Euclidean distance between each pair of points from the two sets 110 | dists = torch.cdist(set1, set2) 111 | 112 | # Find the indices where the distance is less than the threshold 113 | close_points = dists < threshold 114 | 115 | # Create indicator tensors initialized to 0 116 | indicator_set1 = torch.zeros((n1_len, 1), dtype=torch.float32) 117 | indicator_set2 = torch.zeros((n2_len, 1), dtype=torch.float32) 118 | 119 | # Set the corresponding indices to 1 where the points are close 120 | indicator_set1[torch.any(close_points, dim=1)] = 1.0 121 | indicator_set2[torch.any(close_points, dim=0)] = 1.0 122 | 123 | return indicator_set1, indicator_set2 124 | 125 | def get_sampled_contact_matrix(set1, set2, threshold=8.0, num_samples=None): 126 | """ 127 | Constructs a contact matrix for two sets of residues with 1 indicating sampled contact pairs. 128 | 129 | :param set1: PyTorch tensor of shape [n1, 3] for residues in set 1 130 | :param set2: PyTorch tensor of shape [n2, 3] for residues in set 2 131 | :param threshold: Distance threshold to define contact residues 132 | :param num_samples: Number of contact pairs to sample. If None, use all valid contacts. 
133 | :return: PyTorch tensor of shape [(n1+n2), (n1+n2)] representing the contact matrix with sampled contact pairs 134 | """ 135 | n1 = set1.size(0) 136 | n2 = set2.size(0) 137 | 138 | # Compute the pairwise distances between set1 and set2 139 | dists = torch.cdist(set1, set2) 140 | 141 | # Find pairs where distances are less than or equal to the threshold 142 | contact_pairs = (dists <= threshold) 143 | 144 | # Get indices of valid contact pairs 145 | contact_indices = contact_pairs.nonzero(as_tuple=False) 146 | 147 | # Initialize the contact matrix with zeros 148 | contact_matrix = torch.zeros((n1 + n2, n1 + n2)) 149 | 150 | # Determine the number of samples 151 | if num_samples is None or num_samples > contact_indices.size(0): 152 | num_samples = contact_indices.size(0) 153 | 154 | if num_samples > 0: 155 | # Sample contact indices uniformly 156 | sampled_indices = contact_indices[torch.randint(0, contact_indices.size(0), (num_samples,))] 157 | 158 | # Fill in the contact matrix for the sampled contacts 159 | contact_matrix[sampled_indices[:, 0], sampled_indices[:, 1] + n1] = 1.0 160 | contact_matrix[sampled_indices[:, 1] + n1, sampled_indices[:, 0]] = 1.0 161 | 162 | return contact_matrix 163 | 164 | def one_hot(x, v_bins): 165 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) 166 | diffs = x[..., None] - reshaped_bins 167 | am = torch.argmin(torch.abs(diffs), dim=-1) 168 | return F.one_hot(am, num_classes=len(v_bins)).float() 169 | 170 | def relpos(res_id, asym_id, use_chain_relative=True): 171 | max_relative_idx = 32 172 | pos = res_id 173 | asym_id_same = (asym_id[..., None] == asym_id[..., None, :]) 174 | offset = pos[..., None] - pos[..., None, :] 175 | 176 | clipped_offset = torch.clamp( 177 | offset + max_relative_idx, 0, 2 * max_relative_idx 178 | ) 179 | 180 | rel_feats = [] 181 | if use_chain_relative: 182 | final_offset = torch.where( 183 | asym_id_same, 184 | clipped_offset, 185 | (2 * max_relative_idx + 1) * 186 | torch.ones_like(clipped_offset) 187 | ) 188 | 189 | boundaries = torch.arange( 190 | start=0, end=2 * max_relative_idx + 2 191 | ) 192 | rel_pos = one_hot( 193 | final_offset, 194 | boundaries, 195 | ) 196 | 197 | rel_feats.append(rel_pos) 198 | 199 | else: 200 | boundaries = torch.arange( 201 | start=0, end=2 * max_relative_idx + 1 202 | ) 203 | rel_pos = one_hot( 204 | clipped_offset, boundaries, 205 | ) 206 | rel_feats.append(rel_pos) 207 | 208 | rel_feat = torch.cat(rel_feats, dim=-1).float() 209 | 210 | return rel_feat 211 | 212 | def random_rotation(rec_pos, lig_pos): 213 | rot = torch.from_numpy(Rotation.random().as_matrix()).float() 214 | pos = torch.cat([rec_pos, lig_pos], dim=0) 215 | cen = pos[..., 1, :].mean(dim=0) 216 | pos = (pos - cen) @ rot.T 217 | rec_pos_out = pos[:rec_pos.size(0)] 218 | lig_pos_out = pos[rec_pos.size(0):] 219 | return rec_pos_out, lig_pos_out 220 | 221 | #---------------------------------------------------------------------------- 222 | # Dataset class 223 | 224 | class PPIDataset(Dataset): 225 | def __init__( 226 | self, 227 | dataset: str, 228 | training: bool = True, 229 | use_interface: bool = False, 230 | use_esm: bool = True, 231 | crop_size: int = 1500, 232 | ): 233 | self.dataset = dataset 234 | self.training = training 235 | self.use_interface = use_interface 236 | self.use_esm = use_esm 237 | self.crop_size = crop_size 238 | 239 | parent_dir = Path(__file__).resolve().parents[2] 240 | 241 | if dataset == 'db5_test': 242 | self.data_dir = f"{parent_dir}/data/db5_test" 243 | self.data_list = 
f"{parent_dir}/data/db5_test/test.txt" 244 | 245 | with open(self.data_list, 'r') as f: 246 | lines = f.readlines() 247 | self.file_list = [line.strip() for line in lines] 248 | 249 | def __getitem__(self, idx: int): 250 | # Get info from file_list 251 | if self.dataset[:4] == 'dips': 252 | _id = self.file_list[idx] 253 | split_string = _id.split('/') 254 | _id = split_string[0] + '_' + split_string[1].rsplit('.', 1)[0] 255 | data = torch.load(os.path.join(self.data_dir, _id+'.pt')) 256 | else: 257 | _id = self.file_list[idx] 258 | data = torch.load(os.path.join(self.data_dir, _id+'.pt')) 259 | 260 | rec_x = data['receptor'].x 261 | rec_seq = data['receptor'].seq 262 | rec_pos = data['receptor'].pos 263 | lig_x = data['ligand'].x 264 | lig_seq = data['ligand'].seq 265 | lig_pos = data['ligand'].pos 266 | 267 | # One-Hot embeddings 268 | rec_onehot = torch.from_numpy(residue_constants.sequence_to_onehot( 269 | sequence=rec_seq, 270 | mapping=residue_constants.restype_order_with_x, 271 | map_unknown_to_x=True, 272 | )).float() 273 | 274 | lig_onehot = torch.from_numpy(residue_constants.sequence_to_onehot( 275 | sequence=lig_seq, 276 | mapping=residue_constants.restype_order_with_x, 277 | map_unknown_to_x=True, 278 | )).float() 279 | 280 | # ESM embeddings 281 | if self.use_esm: 282 | rec_x = torch.cat([rec_x, rec_onehot], dim=-1) 283 | lig_x = torch.cat([lig_x, lig_onehot], dim=-1) 284 | else: 285 | rec_x = rec_onehot 286 | lig_x = lig_onehot 287 | 288 | # Shuffle and Crop for training 289 | if self.training: 290 | # Shuffle the order of rec and lig 291 | vars_list = [(rec_x, rec_seq, rec_pos), (lig_x, lig_seq, lig_pos)] 292 | random.shuffle(vars_list) 293 | rec_x, rec_seq, rec_pos = vars_list[0] 294 | lig_x, lig_seq, lig_pos = vars_list[1] 295 | 296 | # Crop to crop_size 297 | rec_x, lig_x, rec_pos, lig_pos, res_id, asym_id = self.crop_to_size(rec_x, lig_x, rec_pos, lig_pos) 298 | else: 299 | # get res_id and asym_id 300 | n = rec_x.size(0) + lig_x.size(0) 301 | res_id = torch.arange(n).long() 302 | asym_id = torch.zeros(n).long() 303 | asym_id[rec_x.size(0):] = 1 304 | 305 | # Positional embeddings 306 | position_matrix = relpos(res_id, asym_id) 307 | 308 | # Random rotation augmentation 309 | rec_pos, lig_pos = random_rotation(rec_pos, lig_pos) 310 | 311 | # Interface residues 312 | rec_ires, lig_ires = get_interface_residue_tensors(rec_pos[..., 1, :], lig_pos[..., 1, :]) 313 | ires = torch.cat([rec_ires, lig_ires], dim=0) 314 | 315 | # Output 316 | output = { 317 | 'id': _id, 318 | 'rec_seq': rec_seq, 319 | 'lig_seq': lig_seq, 320 | 'rec_x': rec_x, 321 | 'lig_x': lig_x, 322 | 'rec_pos': rec_pos, 323 | 'lig_pos': lig_pos, 324 | 'position_matrix': position_matrix, 325 | 'ires': ires, 326 | } 327 | 328 | return {key: value for key, value in output.items()} 329 | 330 | def __len__(self): 331 | return len(self.file_list) 332 | 333 | def crop_to_size(self, rec_x, lig_x, rec_pos, lig_pos): 334 | 335 | n = rec_x.size(0) + lig_x.size(0) 336 | res_id = torch.arange(n).long() 337 | asym_id = torch.zeros(n).long() 338 | asym_id[rec_x.size(0):] = 1 339 | 340 | x = torch.cat([rec_x, lig_x], dim=0) 341 | pos = torch.cat([rec_pos, lig_pos], dim=0) 342 | 343 | #use_spatial_crop = random.random() < 0.5 344 | use_spatial_crop = True 345 | num_res = asym_id.size(0) 346 | 347 | if num_res <= self.crop_size: 348 | crop_idxs = torch.arange(num_res) 349 | elif use_spatial_crop: 350 | crop_idxs = get_spatial_crop_idx(pos, asym_id, crop_size=self.crop_size) 351 | else: 352 | crop_idxs = 
get_contiguous_crop_idx(asym_id, crop_size=self.crop_size) 353 | 354 | res_id = torch.index_select(res_id, 0, crop_idxs) 355 | asym_id = torch.index_select(asym_id, 0, crop_idxs) 356 | x = torch.index_select(x, 0, crop_idxs) 357 | pos = torch.index_select(pos, 0, crop_idxs) 358 | 359 | sep = asym_id.tolist().index(1) 360 | rec_x = x[:sep] 361 | lig_x = x[sep:] 362 | rec_pos = pos[:sep] 363 | lig_pos = pos[sep:] 364 | 365 | return rec_x, lig_x, rec_pos, lig_pos, res_id, asym_id 366 | 367 | 368 | if __name__ == '__main__': 369 | dataset = PPIDataset(dataset='db5_test') 370 | -------------------------------------------------------------------------------- /src/datasets/submit_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --ntasks-per-node=48 4 | #SBATCH --partition=parallel 5 | #SBATCH --account=jgray21 6 | #SBATCH --time=12:00:00 7 | #SBATCH --output=slogs/%j.out 8 | 9 | #### execute code 10 | python docking_dataset.py 11 | -------------------------------------------------------------------------------- /src/inference_mlsb.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore", category=FutureWarning) 3 | 4 | # import packages 5 | import os 6 | import csv 7 | import torch 8 | import numpy as np 9 | import hydra 10 | from omegaconf import DictConfig, OmegaConf 11 | from dataclasses import dataclass 12 | from tqdm import tqdm 13 | from torch.utils import data 14 | from scipy.spatial.transform import Rotation 15 | from models.score_model_mlsb import Score_Model 16 | from datasets.ppi_dataset import PPIDataset 17 | from utils.geometry import axis_angle_to_matrix, matrix_to_axis_angle 18 | from utils.pdb import save_PDB, place_fourth_atom 19 | from utils.metrics import compute_metrics 20 | 21 | #---------------------------------------------------------------------------- 22 | # Data class for pose 23 | 24 | @dataclass 25 | class pose(): 26 | _id: str 27 | rec_seq: str 28 | lig_seq: str 29 | rec_pos: torch.FloatTensor 30 | lig_pos: torch.FloatTensor 31 | index: str = None 32 | 33 | #---------------------------------------------------------------------------- 34 | # Helper functions 35 | 36 | def sample_sphere(radius=1): 37 | """Samples a random point on the surface of a sphere. 38 | 39 | Args: 40 | radius: The radius of the sphere. 41 | 42 | Returns: 43 | A 3D NumPy array representing the sampled point. 44 | """ 45 | 46 | # Generate two random numbers in the range [0, 1). 47 | u = np.random.rand() 48 | v = np.random.rand() 49 | 50 | # Compute the azimuthal and polar angles. 51 | theta = 2 * np.pi * u 52 | phi = np.arccos(2 * v - 1) 53 | 54 | # Compute the Cartesian coordinates of the sampled point. 55 | x = radius * np.sin(phi) * np.cos(theta) 56 | y = radius * np.sin(phi) * np.sin(theta) 57 | z = radius * np.cos(phi) 58 | 59 | return np.array([x, y, z]) 60 | 61 | def rot_compose(r1, r2): 62 | R1 = axis_angle_to_matrix(r1) 63 | R2 = axis_angle_to_matrix(r2) 64 | R = torch.einsum('b i j, b j k -> b i k', R2, R1) 65 | r = matrix_to_axis_angle(R) 66 | return r 67 | 68 | def get_full_coords(coords): 69 | #get full coords 70 | N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] 71 | # Infer CB coordinates. 
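# b and c span the local N-CA-C backbone geometry and a is normal to that plane;
# the fixed coefficients below place an idealized (virtual) C-beta atom relative
# to CA (the standard rigid-geometry CB reconstruction).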
72 | b = CA - N 73 | c = C - CA 74 | a = b.cross(c, dim=-1) 75 | CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA 76 | 77 | O = place_fourth_atom(torch.roll(N, -1, 0), 78 | CA, C, 79 | torch.tensor(1.231), 80 | torch.tensor(2.108), 81 | torch.tensor(-3.142)) 82 | full_coords = torch.stack( 83 | [N, CA, C, O, CB], dim=1) 84 | 85 | return full_coords 86 | 87 | #---------------------------------------------------------------------------- 88 | # Sampler 89 | 90 | class Sampler: 91 | def __init__( 92 | self, 93 | conf: DictConfig, 94 | ): 95 | self.data_conf = conf.data 96 | self.perturb_tr = self.data_conf.perturb_tr 97 | self.perturb_rot = self.data_conf.perturb_rot 98 | 99 | # set device 100 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 101 | 102 | # load models 103 | self.model = Score_Model.load_from_checkpoint( 104 | self.data_conf.ckpt, 105 | map_location=self.device, 106 | ) 107 | self.model.eval() 108 | self.model.to(self.device) 109 | 110 | # get testset 111 | testset = PPIDataset( 112 | dataset=self.data_conf.dataset, 113 | training=False, 114 | use_esm=self.data_conf.use_esm, 115 | ) 116 | 117 | # load dataset 118 | if self.data_conf.test_all: 119 | self.test_dataloader = data.DataLoader(testset, batch_size=1, num_workers=6) 120 | else: 121 | # get subset 122 | subset_indices = [0] 123 | subset = data.Subset(testset, subset_indices) 124 | self.test_dataloader = data.DataLoader(subset, batch_size=1, num_workers=6) 125 | 126 | def get_metrics(self, pred, label): 127 | metrics = compute_metrics(pred, label) 128 | return metrics 129 | 130 | def save_trj(self, pred): 131 | # create output directory if it does not exist 132 | if not os.path.exists(self.data_conf.out_trj_dir): 133 | os.makedirs(self.data_conf.out_trj_dir) 134 | 135 | # set output file path 136 | out_pdb = os.path.join(self.data_conf.out_trj_dir, pred._id + '_p' + pred.index + '.pdb') 137 | 138 | # remove stale output file 139 | if os.path.exists(out_pdb): 140 | os.remove(out_pdb) 141 | 142 | seq1 = pred.rec_seq 143 | seq2 = pred.lig_seq 144 | 145 | for i, (x1, x2) in enumerate(zip(pred.rec_pos, pred.lig_pos)): 146 | coords = torch.cat([x1, x2], dim=0) 147 | coords = get_full_coords(coords) 148 | 149 | # get total length 150 | total_len = coords.size(0) 151 | 152 | # check sequence length 153 | assert len(seq1) + len(seq2) == total_len 154 | 155 | # write model record and coordinates 156 | with open(out_pdb, 'a') as f: 157 | f.write("MODEL " + str(i) + "\n") 158 | 159 | save_PDB(out_pdb=out_pdb, coords=coords, seq=seq1+seq2, delim=len(seq1)-1) 160 | 161 | def save_pdb(self, pred): 162 | # create output directory if it does not exist 163 | if not os.path.exists(self.data_conf.out_pdb_dir): 164 | os.makedirs(self.data_conf.out_pdb_dir) 165 | 166 | # set output file path 167 | out_pdb = os.path.join(self.data_conf.out_pdb_dir, pred._id + '_p' + pred.index + '.pdb') 168 | 169 | # remove stale output file 170 | if os.path.exists(out_pdb): 171 | os.remove(out_pdb) 172 | 173 | seq1 = pred.rec_seq 174 | seq2 = pred.lig_seq 175 | 176 | coords = torch.cat([pred.rec_pos[-1], pred.lig_pos[-1]], dim=0) 177 | coords = get_full_coords(coords) 178 | 179 | # get total length 180 | total_len = coords.size(0) 181 | 182 | # check sequence length 183 | assert len(seq1) + len(seq2) == total_len 184 | 185 | # write pdb 186 | save_PDB(out_pdb=out_pdb, coords=coords, seq=seq1+seq2, delim=len(seq1)-1) 187 | 188 | def run_sampling(self): 189 | metrics_list = [] 190 | transforms_list = [] 191 | for batch in tqdm(self.test_dataloader): 192 | # get batch from testset loader 193 | _id = 
batch['id'][0] 194 | rec_seq = batch['rec_seq'][0] 195 | lig_seq = batch['lig_seq'][0] 196 | rec_x = batch['rec_x'].to(self.device).squeeze(0) 197 | lig_x = batch['lig_x'].to(self.device).squeeze(0) 198 | rec_pos = batch['rec_pos'].to(self.device).squeeze(0) 199 | lig_pos = batch['lig_pos'].to(self.device).squeeze(0) 200 | position_matrix = batch['position_matrix'].to(self.device).squeeze(0) 201 | 202 | batch = { 203 | "rec_x": rec_x, 204 | "lig_x": lig_x, 205 | "rec_pos": rec_pos.clone().detach(), 206 | "lig_pos": lig_pos.clone().detach(), 207 | "position_matrix": position_matrix, 208 | } 209 | 210 | # get ground truth pose 211 | label = pose( 212 | _id=_id, 213 | rec_seq=rec_seq, 214 | lig_seq=lig_seq, 215 | rec_pos=rec_pos, 216 | lig_pos=lig_pos 217 | ) 218 | 219 | if self.data_conf.get_gt_energy: 220 | batch["t"] = torch.zeros(1, device=self.device) + 1e-5 221 | output = self.model(batch) 222 | 223 | metrics = {'id': _id} 224 | metrics.update(self.get_metrics([rec_pos, lig_pos], [rec_pos, lig_pos])) 225 | metrics.update({'energy': output["energy"].item()}) 226 | metrics.update({'num_clashes': output["num_clashes"].item()}) 227 | metrics_list.append(metrics) 228 | 229 | else: 230 | # run reverse-diffusion sampling 231 | for i in range(self.data_conf.num_samples): 232 | _rec_pos, _lig_pos, energy, num_clashes = self.Euler_Maruyama_sampler( 233 | batch=batch, 234 | batch_size=1, 235 | eps=1e-3, 236 | ode=self.data_conf.ode, 237 | ) 238 | 239 | # get predicted pose 240 | pred = pose( 241 | _id=_id, 242 | rec_seq=rec_seq, 243 | lig_seq=lig_seq, 244 | rec_pos=_rec_pos, 245 | lig_pos=_lig_pos, 246 | index=str(i) 247 | ) 248 | 249 | # get metrics 250 | metrics = {'id': _id, 'index': str(i)} 251 | metrics.update(self.get_metrics([_rec_pos[-1], _lig_pos[-1]], [rec_pos, lig_pos])) 252 | metrics.update({'energy': energy.item()}) 253 | metrics.update({'num_clashes': num_clashes.item()}) 254 | metrics_list.append(metrics) 255 | 256 | if self.data_conf.out_trj: 257 | self.save_trj(pred) 258 | 259 | if self.data_conf.out_pdb: 260 | self.save_pdb(pred) 261 | 262 | return metrics_list 263 | 264 | def Euler_Maruyama_sampler( 265 | self, 266 | batch, 267 | batch_size=1, 268 | eps=1e-3, 269 | ode=False, 270 | ): 271 | # trajectory buffers for receptor/ligand coordinates 272 | rec_trj = [] 273 | lig_trj = [] 274 | 275 | # initialize time steps 276 | t = torch.ones(batch_size, device=self.device) 277 | time_steps = torch.linspace(1., eps, self.data_conf.num_steps, device=self.device) 278 | dt = time_steps[0] - time_steps[1] 279 | 280 | # get initial pose 281 | rec_pos = batch["rec_pos"] 282 | lig_pos = batch["lig_pos"] 283 | 284 | # randomly initialize coordinates 285 | rec_pos, lig_pos, rot_update, tr_update = self.randomize_pose(rec_pos, lig_pos) 286 | 287 | # save initial coordinates 288 | rec_trj.append(rec_pos) 289 | lig_trj.append(lig_pos) 290 | 291 | # run reverse sde 292 | with torch.no_grad(): 293 | for i, time_step in enumerate(tqdm(time_steps)): 294 | # get current time step 295 | is_last = i == time_steps.size(0) - 1 296 | t = torch.ones(batch_size, device=self.device) * time_step 297 | 298 | batch["t"] = t 299 | batch["rec_pos"] = rec_pos.clone().detach() 300 | batch["lig_pos"] = lig_pos.clone().detach() 301 | 302 | # get predictions 303 | output = self.model(batch) 304 | 305 | if not is_last: 306 | tr_noise_scale = self.data_conf.tr_noise_scale 307 | rot_noise_scale = self.data_conf.rot_noise_scale 308 | else: 309 | tr_noise_scale = 0.0 310 | rot_noise_scale = 0.0 311 | 312 | if self.perturb_rot: 313 | rot = 
self.model.so3_diffuser.torch_reverse( 314 | score_t=output["rot_score"].detach(), 315 | t=t.item(), 316 | dt=dt, 317 | noise_scale=rot_noise_scale, 318 | ode=ode, 319 | ) 320 | else: 321 | rot = torch.zeros((1, 3), device=self.device) 322 | 323 | if self.perturb_tr: 324 | tr = self.model.r3_diffuser.torch_reverse( 325 | score_t=output["tr_score"].detach(), 326 | t=t.item(), 327 | dt=dt, 328 | noise_scale=tr_noise_scale, 329 | ode=ode, 330 | ) 331 | else: 332 | tr = torch.zeros((1, 3), device=self.device) 333 | 334 | lig_pos = self.modify_coords(lig_pos, rot, tr) 335 | 336 | # optional clash-relief force 337 | if self.data_conf.use_clash_force: 338 | clash_force = self.clash_force(rec_pos.clone().detach(), lig_pos.clone().detach()) 339 | lig_pos = lig_pos + clash_force 340 | 341 | if is_last: 342 | batch["rec_pos"] = rec_pos.clone().detach() 343 | batch["lig_pos"] = lig_pos.clone().detach() 344 | output = self.model(batch) 345 | 346 | # save coordinates 347 | rec_trj.append(rec_pos) 348 | lig_trj.append(lig_pos) 349 | 350 | return rec_trj, lig_trj, output["energy"], output["num_clashes"] 351 | 352 | def randomize_pose(self, x1, x2): 353 | # get center of mass 354 | c1 = torch.mean(x1[..., 1, :], dim=0) 355 | c2 = torch.mean(x2[..., 1, :], dim=0) 356 | 357 | # get rotation update 358 | rot_update = torch.from_numpy(Rotation.random().as_matrix()).float().to(self.device) 359 | 360 | # get translation update 361 | tr_update = torch.normal(0.0, 30.0, size=(1, 3), device=self.device) 362 | #tr_update = torch.from_numpy(sample_sphere(radius=50.0)).float().to(self.device) 363 | 364 | # move to origin 365 | x1 = x1 - c1 366 | x2 = x2 - c2 367 | 368 | # init rotation 369 | if self.perturb_rot: 370 | x2 = x2 @ rot_update.T 371 | 372 | # init translation 373 | if self.perturb_tr: 374 | x2 = x2 + tr_update 375 | 376 | # convert to axis angle 377 | rot_update = matrix_to_axis_angle(rot_update.unsqueeze(0)) 378 | 379 | return x1, x2, rot_update, tr_update 380 | 381 | def modify_coords(self, x, rot, tr): 382 | center = torch.mean(x[..., 1, :], dim=0, keepdim=True) 383 | rot = axis_angle_to_matrix(rot).squeeze() 384 | # update rotation 385 | if self.perturb_rot: 386 | x = (x - center) @ rot.T + center 387 | # update translation 388 | if self.perturb_tr: 389 | x = x + tr 390 | 391 | return x 392 | 393 | def clash_force(self, rec_pos, lig_pos): 394 | rec_pos = rec_pos.view(-1, 3) 395 | lig_pos = lig_pos.view(-1, 3) 396 | 397 | with torch.set_grad_enabled(True): 398 | lig_pos.requires_grad_(True) 399 | # get distance matrix 400 | D = torch.norm((rec_pos[:, None, :] - lig_pos[None, :, :]), dim=-1) 401 | 402 | def rep_fn(x): 403 | x0, p, w_rep = 4, 1.5, 5 404 | rep = torch.where(x < x0, (torch.abs(x0 - x) ** p) / (p * x * (p - 1)), torch.tensor(0.0, device=x.device)) 405 | return - w_rep * torch.sum(rep) 406 | 407 | rep = rep_fn(D) 408 | 409 | force = torch.autograd.grad(rep, lig_pos, retain_graph=False)[0] 410 | 411 | return force.mean(dim=0).detach() 412 | 413 | #---------------------------------------------------------------------------- 414 | # Main 415 | @hydra.main(version_base=None, config_path="/scratch4/jgray21/lchu11/graylab_repos/DFMDock/configs", config_name="inference") 416 | def main(config: DictConfig): 417 | # Print the entire configuration 418 | print(OmegaConf.to_yaml(config)) 419 | 420 | torch.manual_seed(0) 421 | sampler = Sampler(config) 422 | 423 | output_dir = config.data.out_csv_dir 424 | if not os.path.exists(output_dir): 425 | os.makedirs(output_dir) 426 | 427 | # set output file path 428 | output_filename = 
os.path.join(output_dir, config.data.out_csv) 429 | 430 | with open(output_filename, "w", newline="") as csvfile: 431 | results = sampler.run_sampling() 432 | 433 | # Write header row to CSV file 434 | header = list(results[0].keys()) 435 | writer = csv.DictWriter(csvfile, fieldnames=header) 436 | writer.writeheader() 437 | 438 | for row in results: 439 | writer.writerow(row) 440 | 441 | if __name__ == "__main__": 442 | main() 443 | -------------------------------------------------------------------------------- /src/inference_single.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from inference_base import inference 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description="Process two required PDB files.") 6 | parser.add_argument("pdb_1", type=str, help="Path to the first PDB file") 7 | parser.add_argument("pdb_2", type=str, help="Path to the second PDB file") 8 | return parser.parse_args() 9 | 10 | if __name__ == "__main__": 11 | args = parse_args() 12 | inference(args.pdb_1, args.pdb_2) 13 | -------------------------------------------------------------------------------- /src/models/DFMDock.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import hydra 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | import numpy as np 8 | from torch.utils import data 9 | from torch_geometric.loader import DataLoader 10 | from omegaconf import DictConfig 11 | from datasets.docking_dataset import DockingDataset 12 | from models.egnn_net import EGNN_Net 13 | from utils.so3_diffuser import SO3Diffuser 14 | from utils.r3_diffuser import R3Diffuser 15 | from utils.geometry import axis_angle_to_matrix 16 | from utils.crop import get_crop_idxs, get_crop, get_position_matrix 17 | from utils.loss import distogram_loss 18 | 19 | #---------------------------------------------------------------------------- 20 | # Main wrapper for training the model 21 | 22 | class DFMDock(pl.LightningModule): 23 | def __init__( 24 | self, 25 | model, 26 | diffuser, 27 | experiment, 28 | ): 29 | super().__init__() 30 | self.save_hyperparameters() 31 | self.lr = experiment.lr 32 | self.weight_decay = experiment.weight_decay 33 | 34 | # crop size 35 | self.crop_size = experiment.crop_size 36 | 37 | # confidence model 38 | self.use_confidence_loss = experiment.use_confidence_loss 39 | 40 | # dist model 41 | self.use_dist_loss = experiment.use_dist_loss 42 | 43 | # interface residue model 44 | self.use_interface_loss = experiment.use_interface_loss 45 | 46 | # energy 47 | self.grad_energy = experiment.grad_energy 48 | self.separate_energy_loss = experiment.separate_energy_loss 49 | self.use_contrastive_loss = experiment.use_contrastive_loss 50 | 51 | # translation 52 | self.perturb_tr = experiment.perturb_tr 53 | self.separate_tr_loss = experiment.separate_tr_loss 54 | 55 | # rotation 56 | self.perturb_rot = experiment.perturb_rot 57 | self.separate_rot_loss = experiment.separate_rot_loss 58 | 59 | # diffuser 60 | if self.perturb_tr: 61 | self.r3_diffuser = R3Diffuser(diffuser.r3) 62 | if self.perturb_rot: 63 | self.so3_diffuser = SO3Diffuser(diffuser.so3) 64 | 65 | # net 66 | self.net = EGNN_Net(model) 67 | 68 | def forward(self, batch): 69 | # move lig center to origin 70 | self.move_to_lig_center(batch) 71 | 72 | # predict 73 | outputs = self.net(batch, predict=True) 74 | 75 | return outputs 76 | 77 | def loss_fn(self, batch, 
eps=1e-5): 78 | with torch.no_grad(): 79 | # uniformly sample a timestep 80 | t = torch.rand(1, device=self.device) * (1. - eps) + eps 81 | batch["t"] = t 82 | 83 | # sample perturbation for translation and rotation 84 | if self.perturb_tr: 85 | tr_score_scale = self.r3_diffuser.score_scaling(t.item()) 86 | tr_update, tr_score_gt = self.r3_diffuser.forward_marginal(t.item()) 87 | tr_update = torch.from_numpy(tr_update).float().to(self.device) 88 | tr_score_gt = torch.from_numpy(tr_score_gt).float().to(self.device) 89 | else: 90 | tr_update = np.zeros(3) 91 | tr_update = torch.from_numpy(tr_update).float().to(self.device) 92 | 93 | if self.perturb_rot: 94 | rot_score_scale = self.so3_diffuser.score_scaling(t.item()) 95 | rot_update, rot_score_gt = self.so3_diffuser.forward_marginal(t.item()) 96 | rot_update = torch.from_numpy(rot_update).float().to(self.device) 97 | rot_score_gt = torch.from_numpy(rot_score_gt).float().to(self.device) 98 | else: 99 | rot_update = np.zeros(3) 100 | rot_update = torch.from_numpy(rot_update).float().to(self.device) 101 | 102 | # save gt state 103 | batch_gt = copy.deepcopy(batch) 104 | 105 | # get crop_idxs 106 | crop_idxs = get_crop_idxs(batch_gt, crop_size=self.crop_size) 107 | 108 | # pre crop 109 | batch = get_crop(batch, crop_idxs) 110 | batch_gt = get_crop(batch_gt, crop_idxs) 111 | 112 | # noised pose 113 | batch["lig_pos"] = self.modify_coords(batch["lig_pos"], rot_update, tr_update) 114 | 115 | # get LRMSD 116 | l_rmsd = get_rmsd(batch["lig_pos"][..., 1, :], batch_gt["lig_pos"][..., 1, :]) 117 | 118 | # move lig center to origin 119 | self.move_to_lig_center(batch) 120 | self.move_to_lig_center(batch_gt) 121 | 122 | # post crop 123 | #batch = get_crop(batch, crop_idxs) 124 | #batch_gt = get_crop(batch_gt, crop_idxs) 125 | 126 | # predict score based on the current state 127 | if self.grad_energy: 128 | outputs = self.net(batch) 129 | 130 | # grab some outputs 131 | tr_score = outputs["tr_score"] 132 | rot_score = outputs["rot_score"] 133 | f = outputs["f"] 134 | dedx = outputs["dedx"] 135 | energy_noised = outputs["energy"] 136 | 137 | # energy conservation loss 138 | if self.separate_energy_loss: 139 | f_angle = torch.norm(f, dim=-1, keepdim=True) 140 | f_axis = f / (f_angle + 1e-6) 141 | 142 | dedx_angle = torch.norm(dedx, dim=-1, keepdim=True) 143 | dedx_axis = dedx / (dedx_angle + 1e-6) 144 | 145 | ec_axis_loss = torch.mean((f_axis - dedx_axis)**2) 146 | ec_angle_loss = torch.mean((f_angle - dedx_angle)**2) 147 | ec_loss = 0.5 * (ec_axis_loss + ec_angle_loss) 148 | 149 | else: 150 | ec_loss = torch.mean((dedx - f)**2) 151 | else: 152 | outputs = self.net(batch, predict=True) 153 | 154 | # grab some outputs 155 | tr_score = outputs["tr_score"] 156 | rot_score = outputs["rot_score"] 157 | energy_noised = outputs["energy"] 158 | 159 | # energy conservation loss 160 | ec_loss = torch.tensor(0.0, device=self.device) 161 | 162 | mse_loss_fn = nn.MSELoss() 163 | # translation loss 164 | if self.perturb_tr: 165 | if self.separate_tr_loss: 166 | gt_tr_angle = torch.norm(tr_score_gt, dim=-1, keepdim=True) 167 | gt_tr_axis = tr_score_gt / (gt_tr_angle + 1e-6) 168 | 169 | pred_tr_angle = torch.norm(tr_score, dim=-1, keepdim=True) 170 | pred_tr_axis = tr_score / (pred_tr_angle + 1e-6) 171 | 172 | tr_axis_loss = torch.mean((pred_tr_axis - gt_tr_axis)**2) 173 | tr_angle_loss = torch.mean((pred_tr_angle - gt_tr_angle)**2 / tr_score_scale**2) 174 | tr_loss = 0.5 * (tr_axis_loss + tr_angle_loss) 175 | 176 | else: 177 | tr_loss = torch.mean((tr_score - 
tr_score_gt)**2 / tr_score_scale**2) 178 | else: 179 | tr_loss = torch.tensor(0.0, device=self.device) 180 | 181 | # rotation loss 182 | if self.perturb_rot: 183 | if self.separate_rot_loss: 184 | gt_rot_angle = torch.norm(rot_score_gt, dim=-1, keepdim=True) 185 | gt_rot_axis = rot_score_gt / (gt_rot_angle + 1e-6) 186 | 187 | pred_rot_angle = torch.norm(rot_score, dim=-1, keepdim=True) 188 | pred_rot_axis = rot_score / (pred_rot_angle + 1e-6) 189 | 190 | rot_axis_loss = torch.mean((pred_rot_axis - gt_rot_axis)**2) 191 | rot_angle_loss = torch.mean((pred_rot_angle - gt_rot_angle)**2 / rot_score_scale**2) 192 | rot_loss = 0.5 * (rot_axis_loss + rot_angle_loss) 193 | 194 | else: 195 | rot_loss = torch.mean((rot_score - rot_score_gt)**2 / rot_score_scale**2) 196 | else: 197 | rot_loss = torch.tensor(0.0, device=self.device) 198 | 199 | # contrastive loss 200 | # modified from https://github.com/yilundu/ired_code_release/blob/main/diffusion_lib/denoising_diffusion_pytorch_1d.py 201 | if self.use_contrastive_loss: 202 | energy_gt = self.net(batch_gt, return_energy=True) 203 | energy_stack = torch.stack([energy_gt, energy_noised], dim=-1) 204 | target = torch.zeros([], device=energy_stack.device) 205 | el_loss = F.cross_entropy(-1 * energy_stack, target.long(), reduction='none') 206 | else: 207 | el_loss = torch.tensor(0.0, device=self.device) 208 | 209 | bce_logits_loss = nn.BCEWithLogitsLoss() 210 | # distogram loss 211 | if self.use_dist_loss: 212 | gt_dist = torch.norm((batch_gt["rec_pos"][:, None, 1, :] - batch_gt["lig_pos"][None, :, 1, :]), dim=-1, keepdim=True) 213 | dist_loss = distogram_loss(outputs["dist_logits"], gt_dist) 214 | else: 215 | dist_loss = torch.tensor(0.0, device=self.device) 216 | 217 | # interface loss 218 | if self.use_interface_loss: 219 | gt_ires = get_interface_residue_tensors(batch_gt["rec_pos"][:, 1, :], batch_gt["lig_pos"][:, 1, :]) 220 | ires_loss = bce_logits_loss(outputs["ires_logits"], gt_ires) 221 | else: 222 | ires_loss = torch.tensor(0.0, device=self.device) 223 | 224 | # confidence loss 225 | if self.use_confidence_loss: 226 | label = (l_rmsd < 5.0).float() 227 | conf_loss = bce_logits_loss(outputs["confidence_logits"], label) 228 | else: 229 | conf_loss = torch.tensor(0.0, device=self.device) 230 | 231 | # total losses 232 | loss = tr_loss + rot_loss + 0.1 * (ec_loss + el_loss + conf_loss + dist_loss + ires_loss) 233 | losses = { 234 | "tr_loss": tr_loss, 235 | "rot_loss": rot_loss, 236 | "ec_loss": ec_loss, 237 | "el_loss": el_loss, 238 | "dist_loss": dist_loss, 239 | "ires_loss": ires_loss, 240 | "conf_loss": conf_loss, 241 | "loss": loss, 242 | } 243 | 244 | return losses 245 | 246 | def modify_coords(self, lig_pos, rot_update, tr_update): 247 | cen = lig_pos.mean(dim=(0, 1)) 248 | rot = axis_angle_to_matrix(rot_update.squeeze()) 249 | tr = tr_update.squeeze() 250 | lig_pos = (lig_pos - cen) @ rot.T + cen 251 | lig_pos = lig_pos + tr 252 | return lig_pos 253 | 254 | def move_to_lig_center(self, batch): 255 | center = batch["lig_pos"].mean(dim=(0, 1)) 256 | batch["rec_pos"] = batch["rec_pos"] - center 257 | batch["lig_pos"] = batch["lig_pos"] - center 258 | 259 | def step(self, batch, batch_idx): 260 | rec_x = batch['rec_x'].squeeze(0) 261 | lig_x = batch['lig_x'].squeeze(0) 262 | rec_pos = batch['rec_pos'].squeeze(0) 263 | lig_pos = batch['lig_pos'].squeeze(0) 264 | is_homomer = batch['is_homomer'] 265 | 266 | # wrap to a batch 267 | batch = { 268 | "rec_x": rec_x, 269 | "lig_x": lig_x, 270 | "rec_pos": rec_pos, 271 | "lig_pos": lig_pos, 272 | 
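# homomer flag loaded by the dataset, carried through with the tensors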
"is_homomer": is_homomer, 273 | } 274 | 275 | # get losses 276 | losses = self.loss_fn(batch) 277 | return losses 278 | 279 | def training_step(self, batch, batch_idx): 280 | losses = self.step(batch, batch_idx) 281 | for loss_name, indiv_loss in losses.items(): 282 | self.log( 283 | f"train/{loss_name}", 284 | indiv_loss, 285 | batch_size=1, 286 | ) 287 | return losses["loss"] 288 | 289 | def on_validation_model_eval(self, *args, **kwargs): 290 | super().on_validation_model_eval(*args, **kwargs) 291 | torch.set_grad_enabled(True) 292 | 293 | def on_validation_model_train(self, *args, **kwargs): 294 | super().on_validation_model_train(*args, **kwargs) 295 | torch.set_grad_enabled(True) 296 | 297 | def validation_step(self, batch, batch_idx): 298 | losses = self.step(batch, batch_idx) 299 | for loss_name, indiv_loss in losses.items(): 300 | self.log( 301 | f"val/{loss_name}", 302 | indiv_loss, 303 | batch_size=1, 304 | ) 305 | return losses["loss"] 306 | 307 | def test_step(self, batch, batch_idx): 308 | losses = self.step(batch, batch_idx) 309 | for loss_name, indiv_loss in losses.items(): 310 | self.log( 311 | f"test/{loss_name}", 312 | indiv_loss, 313 | batch_size=1, 314 | ) 315 | return losses["loss"] 316 | 317 | def configure_optimizers(self): 318 | optimizer = torch.optim.AdamW( 319 | filter(lambda p: p.requires_grad, self.parameters()), 320 | lr=self.lr, 321 | weight_decay=self.weight_decay 322 | ) 323 | return optimizer 324 | 325 | # helper functions 326 | 327 | def get_interface_residue_tensors(set1, set2, threshold=8.0): 328 | device = set1.device 329 | n1_len = set1.shape[0] 330 | n2_len = set2.shape[0] 331 | 332 | # Calculate the Euclidean distance between each pair of points from the two sets 333 | dists = torch.cdist(set1, set2) 334 | 335 | # Find the indices where the distance is less than the threshold 336 | close_points = dists < threshold 337 | 338 | # Create indicator tensors initialized to 0 339 | indicator_set1 = torch.zeros((n1_len, 1), device=device) 340 | indicator_set2 = torch.zeros((n2_len, 1), device=device) 341 | 342 | # Set the corresponding indices to 1 where the points are close 343 | indicator_set1[torch.any(close_points, dim=1)] = 1.0 344 | indicator_set2[torch.any(close_points, dim=0)] = 1.0 345 | 346 | return torch.cat([indicator_set1, indicator_set2], dim=0) 347 | 348 | def get_rmsd(pred, label): 349 | rmsd = torch.sqrt(torch.mean(torch.sum((pred - label) ** 2.0, dim=-1))) 350 | return rmsd 351 | 352 | #---------------------------------------------------------------------------- 353 | # Testing run 354 | 355 | @hydra.main(version_base=None, config_path="/scratch4/jgray21/lchu11/graylab_repos/DFMDock/configs/model", config_name="DFMDock.yaml") 356 | def main(conf: DictConfig): 357 | dataset = DockingDataset( 358 | dataset='dips_train', 359 | training=True, 360 | use_esm=True, 361 | ) 362 | 363 | subset_indices = [0] 364 | subset = data.Subset(dataset, subset_indices) 365 | 366 | #load dataset 367 | dataloader = DataLoader(subset) 368 | 369 | model = DFMDock( 370 | model=conf.model, 371 | diffuser=conf.diffuser, 372 | experiment=conf.experiment 373 | ) 374 | trainer = pl.Trainer(accelerator='cpu', devices=1, max_epochs=10, inference_mode=False) 375 | trainer.validate(model, dataloader) 376 | 377 | if __name__ == '__main__': 378 | main() 379 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/egnn.py: -------------------------------------------------------------------------------- 1 | # The code is adapted from: 2 | # https://github.com/vgsatorras/egnn 3 | 4 | import torch 5 | from torch import nn 6 | from torch_geometric.nn.norm import GraphNorm 7 | 8 | #---------------------------------------------------------------------------- 9 | # Helper functions 10 | 11 | def unsorted_segment_sum(data, segment_ids, num_segments): 12 | result_shape = (num_segments, data.size(1)) 13 | result = data.new_full(result_shape, 0) # Init empty result tensor. 14 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 15 | result.scatter_add_(0, segment_ids, data) 16 | return result 17 | 18 | 19 | def unsorted_segment_mean(data, segment_ids, num_segments): 20 | result_shape = (num_segments, data.size(1)) 21 | segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) 22 | result = data.new_full(result_shape, 0) # Init empty result tensor. 23 | count = data.new_full(result_shape, 0) 24 | result.scatter_add_(0, segment_ids, data) 25 | count.scatter_add_(0, segment_ids, torch.ones_like(data)) 26 | return result / count.clamp(min=1) 27 | 28 | #---------------------------------------------------------------------------- 29 | # nn Modules 30 | 31 | class E_GCL(nn.Module): 32 | """ 33 | E(n) Equivariant Convolutional Layer 34 | """ 35 | 36 | 37 | def __init__( 38 | self, 39 | input_nf, 40 | output_nf, 41 | hidden_nf, 42 | edges_in_d=0, 43 | act_fn=nn.SiLU(), 44 | residual=True, 45 | attention=False, 46 | normalize=False, 47 | coords_agg='mean', 48 | tanh=False, 49 | update_coords=False, 50 | dropout=0.0, 51 | coord_weights_clamp_value=None, 52 | ): 53 | super(E_GCL, self).__init__() 54 | input_edge = input_nf * 2 55 | self.residual = residual 56 | self.attention = attention 57 | self.normalize = normalize 58 | self.coords_agg = coords_agg 59 | self.tanh = tanh 60 | self.epsilon = 1e-8 61 | self.update_coords = update_coords 62 | self.coord_weights_clamp_value = coord_weights_clamp_value 63 | 64 | edge_coords_nf = 1 65 | 66 | self.edge_mlp = nn.Sequential( 67 | nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf), 68 | act_fn, 69 | nn.Linear(hidden_nf, hidden_nf), 70 | act_fn) 71 | 72 | self.node_mlp = nn.Sequential( 73 | nn.Linear(hidden_nf + input_nf, hidden_nf), 74 | GraphNorm(hidden_nf), 75 | act_fn, 76 | nn.Linear(hidden_nf, output_nf)) 77 | 78 | if self.update_coords: 79 | layer = nn.Linear(hidden_nf, 1, bias=False) 80 | torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) 81 | 82 | coord_mlp = [] 83 | coord_mlp.append(nn.Linear(hidden_nf, hidden_nf)) 84 | coord_mlp.append(act_fn) 85 | coord_mlp.append(layer) 86 | if self.tanh: 87 | coord_mlp.append(nn.Tanh()) 88 | self.coord_mlp = nn.Sequential(*coord_mlp) 89 | 90 | if self.attention: 91 | self.att_mlp = nn.Sequential( 92 | nn.Linear(hidden_nf, 1), 93 | nn.Sigmoid()) 94 | 95 | def edge_model(self, source, target, radial, edge_attr): 96 | if edge_attr is None: # no extra edge features provided
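# per-edge message input: sender node features, receiver node features, and the squared inter-node distance (radial)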
97 | out = torch.cat([source, target, radial], dim=1) 98 | else: 99 | out = torch.cat([source, target, radial, edge_attr], dim=1) 100 | out = self.edge_mlp(out) 101 | if self.attention: 102 | att_val = self.att_mlp(out) 103 | out = out * att_val 104 | return out 105 | 106 | def node_model(self, x, edge_index, edge_attr, node_attr): 107 | row, col = edge_index 108 | agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0)) 109 | if node_attr is not None: 110 | agg = torch.cat([x, agg, node_attr], dim=1) 111 | else: 112 | agg = torch.cat([x, agg], dim=1) 113 | out = self.node_mlp(agg) 114 | if self.residual: 115 | out = x + out 116 | return out, agg 117 | 118 | def coord_model(self, coord, edge_index, coord_diff, edge_feat, lig_mask): 119 | row, col = edge_index 120 | coord_weights = self.coord_mlp(edge_feat) 121 | 122 | if self.coord_weights_clamp_value is not None: 123 | clamp_value = self.coord_weights_clamp_value 124 | coord_weights.clamp_(min=-clamp_value, max=clamp_value) 125 | 126 | trans = coord_diff * coord_weights 127 | if self.coords_agg == 'sum': 128 | agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) 129 | elif self.coords_agg == 'mean': 130 | agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) 131 | else: 132 | raise ValueError('Wrong coords_agg parameter: %s' % self.coords_agg) 133 | if lig_mask is not None: 134 | coord = coord + agg * lig_mask[:, None] 135 | else: 136 | coord = coord + agg 137 | return coord 138 | 139 | def coord2radial(self, edge_index, coord): 140 | row, col = edge_index 141 | coord_diff = coord[row] - coord[col] 142 | radial = torch.sum(coord_diff**2, 1).unsqueeze(1) 143 | 144 | if self.normalize: 145 | norm = torch.sqrt(radial + self.epsilon) 146 | coord_diff = coord_diff / (norm + 1.0) 147 | 148 | return radial, coord_diff 149 | 150 | def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, lig_mask=None): 151 | row, col = edge_index 152 | radial, coord_diff = self.coord2radial(edge_index, coord) 153 | 154 | edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) 155 | if self.update_coords: 156 | coord = self.coord_model(coord, edge_index, coord_diff, edge_feat, lig_mask) 157 | h, agg = self.node_model(h, edge_index, edge_feat, node_attr) 158 | 159 | return h, coord, edge_attr 160 | -------------------------------------------------------------------------------- /src/models/score_model.py: -------------------------------------------------------------------------------- 1 | import esm 2 | import copy 3 | import hydra 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | import numpy as np 9 | import random 10 | from torch.utils import data 11 | from torch_geometric.loader import DataLoader 12 | from omegaconf import DictConfig 13 | from models.score_net import Score_Net 14 | from utils.so3_diffuser import SO3Diffuser 15 | from utils.r3_diffuser import R3Diffuser 16 | from utils.geometry import axis_angle_to_matrix 17 | from datasets.ppi_dataset import PPIDataset 18 | 19 | #---------------------------------------------------------------------------- 20 | # Main wrapper for training the model 21 | 22 | class Score_Model(pl.LightningModule): 23 | def __init__( 24 | self, 25 | model, 26 | diffuser, 27 | experiment, 28 | ): 29 | super().__init__() 30 | self.save_hyperparameters() 31 | self.lr = experiment.lr 32 | self.weight_decay = experiment.weight_decay 33 | 34 | # energy 35 | self.grad_energy = experiment.grad_energy 36 | 
self.separate_energy_loss = experiment.separate_energy_loss 37 | 38 | # translation 39 | self.perturb_tr = experiment.perturb_tr 40 | self.separate_tr_loss = experiment.separate_tr_loss 41 | 42 | # rotation 43 | self.perturb_rot = experiment.perturb_rot 44 | self.separate_rot_loss = experiment.separate_rot_loss 45 | 46 | # interface 47 | self.use_interface_loss = experiment.use_interface_loss 48 | 49 | # contrastive 50 | self.use_contrastive_loss = experiment.use_contrastive_loss 51 | 52 | # diffuser 53 | if self.perturb_tr: 54 | self.r3_diffuser = R3Diffuser(diffuser.r3) 55 | if self.perturb_rot: 56 | self.so3_diffuser = SO3Diffuser(diffuser.so3) 57 | 58 | # net 59 | self.net = Score_Net(model) 60 | 61 | def forward(self, batch): 62 | # grab some input 63 | rec_pos = batch["rec_pos"] 64 | lig_pos = batch["lig_pos"] 65 | 66 | # move lig center to origin 67 | center = lig_pos[..., 1, :].mean(dim=0) 68 | rec_pos -= center 69 | lig_pos -= center 70 | 71 | # push to batch 72 | batch["rec_pos"] = rec_pos 73 | batch["lig_pos"] = lig_pos 74 | 75 | # predict 76 | outputs = self.net(batch, predict=True) 77 | 78 | return outputs 79 | 80 | def loss_fn(self, batch, eps=1e-5): 81 | # grab some input 82 | rec_pos = batch["rec_pos"] 83 | lig_pos = batch["lig_pos"] 84 | 85 | with torch.no_grad(): 86 | # uniformly sample a timestep 87 | t = torch.rand(1, device=self.device) * (1. - eps) + eps 88 | batch["t"] = t 89 | 90 | # sample perturbation for translation and rotation 91 | if self.perturb_tr: 92 | tr_score_scale = self.r3_diffuser.score_scaling(t.item()) 93 | tr_update, tr_score_gt = self.r3_diffuser.forward_marginal(t.item()) 94 | tr_update = torch.from_numpy(tr_update).float().to(self.device) 95 | tr_score_gt = torch.from_numpy(tr_score_gt).float().to(self.device) 96 | else: 97 | tr_update = np.zeros(3) 98 | tr_update = torch.from_numpy(tr_update).float().to(self.device) 99 | 100 | if self.perturb_rot: 101 | rot_score_scale = self.so3_diffuser.score_scaling(t.item()) 102 | rot_update, rot_score_gt = self.so3_diffuser.forward_marginal(t.item()) 103 | rot_update = torch.from_numpy(rot_update).float().to(self.device) 104 | rot_score_gt = torch.from_numpy(rot_score_gt).float().to(self.device) 105 | else: 106 | rot_update = np.zeros(3) 107 | rot_update = torch.from_numpy(rot_update).float().to(self.device) 108 | 109 | # save gt state 110 | batch_gt = copy.deepcopy(batch) 111 | 112 | # update poses 113 | lig_pos = self.modify_coords(lig_pos, rot_update, tr_update) 114 | 115 | # get LRMSD 116 | #l_rmsd = get_rmsd(lig_pos[..., 1, :], batch_gt["lig_pos"][..., 1, :]) 117 | 118 | # move lig center to origin 119 | center = lig_pos[..., 1, :].mean(dim=0) 120 | rec_pos -= center 121 | lig_pos -= center 122 | 123 | # save noised state 124 | batch["rec_pos"] = rec_pos 125 | batch["lig_pos"] = lig_pos 126 | 127 | # predict score based on the current state 128 | if self.grad_energy: 129 | outputs = self.net(batch) 130 | 131 | # grab some outputs 132 | tr_score = outputs["tr_score"] 133 | rot_score = outputs["rot_score"] 134 | f = outputs["f"] 135 | dedx = outputs["dedx"] 136 | energy_noised = outputs["energy"] 137 | 138 | # energy conservation loss 139 | if self.separate_energy_loss: 140 | f_angle = torch.norm(f, dim=-1, keepdim=True) 141 | f_axis = f / (f_angle + 1e-6) 142 | 143 | dedx_angle = torch.norm(dedx, dim=-1, keepdim=True) 144 | dedx_axis = dedx / (dedx_angle + 1e-6) 145 | 146 | ec_axis_loss = torch.mean((f_axis - dedx_axis)**2) 147 | ec_angle_loss = torch.mean((f_angle - dedx_angle)**2) 148 | ec_loss = 0.5 
* (ec_axis_loss + ec_angle_loss) 149 | 150 | else: 151 | ec_loss = torch.mean((dedx - f)**2) 152 | else: 153 | outputs = self.net(batch, predict=True) 154 | 155 | # grab some outputs 156 | tr_score = outputs["tr_score"] 157 | rot_score = outputs["rot_score"] 158 | energy_noised = outputs["energy"] 159 | 160 | # energy conservation loss 161 | ec_loss = torch.tensor(0.0, device=self.device) 162 | 163 | 164 | # translation loss 165 | if self.perturb_tr: 166 | if self.separate_tr_loss: 167 | gt_tr_angle = torch.norm(tr_score_gt, dim=-1, keepdim=True) 168 | gt_tr_axis = tr_score_gt / (gt_tr_angle + 1e-6) 169 | 170 | pred_tr_angle = torch.norm(tr_score, dim=-1, keepdim=True) 171 | pred_tr_axis = tr_score / (pred_tr_angle + 1e-6) 172 | 173 | tr_axis_loss = torch.mean((pred_tr_axis - gt_tr_axis)**2) 174 | tr_angle_loss = torch.mean((pred_tr_angle - gt_tr_angle)**2 / tr_score_scale**2) 175 | tr_loss = 0.5 * (tr_axis_loss + tr_angle_loss) 176 | 177 | else: 178 | tr_loss = torch.mean((tr_score - tr_score_gt)**2 / tr_score_scale**2) 179 | else: 180 | tr_loss = torch.tensor(0.0, device=self.device) 181 | 182 | # rotation loss 183 | if self.perturb_rot: 184 | if self.separate_rot_loss: 185 | gt_rot_angle = torch.norm(rot_score_gt, dim=-1, keepdim=True) 186 | gt_rot_axis = rot_score_gt / (gt_rot_angle + 1e-6) 187 | 188 | pred_rot_angle = torch.norm(rot_score, dim=-1, keepdim=True) 189 | pred_rot_axis = rot_score / (pred_rot_angle + 1e-6) 190 | 191 | rot_axis_loss = torch.mean((pred_rot_axis - gt_rot_axis)**2) 192 | rot_angle_loss = torch.mean((pred_rot_angle - gt_rot_angle)**2 / rot_score_scale**2) 193 | rot_loss = 0.5 * (rot_axis_loss + rot_angle_loss) 194 | 195 | else: 196 | rot_loss = torch.mean((rot_score - rot_score_gt)**2 / rot_score_scale**2) 197 | else: 198 | rot_loss = torch.tensor(0.0, device=self.device) 199 | 200 | # interface loss 201 | bce_logits_loss = nn.BCEWithLogitsLoss() 202 | if self.use_interface_loss: 203 | ires_loss = bce_logits_loss(outputs['ires'], batch['ires']) 204 | else: 205 | ires_loss = torch.tensor(0.0, device=self.device) 206 | 207 | # contrastive loss 208 | # modified from https://github.com/yilundu/ired_code_release/blob/main/diffusion_lib/denoising_diffusion_pytorch_1d.py 209 | if self.use_contrastive_loss: 210 | energy_gt = self.net(batch_gt, return_energy=True) 211 | energy_stack = torch.stack([energy_gt, energy_noised], dim=-1) 212 | target = torch.zeros([], device=energy_stack.device) 213 | el_loss = F.cross_entropy(-1 * energy_stack, target.long(), reduction='none') 214 | else: 215 | el_loss = torch.tensor(0.0, device=self.device) 216 | 217 | # total losses 218 | loss = tr_loss + rot_loss + ec_loss + el_loss + ires_loss 219 | losses = {"tr_loss": tr_loss, "rot_loss": rot_loss, "ec_loss": ec_loss, "el_loss": el_loss, "ires_loss": ires_loss, "loss": loss} 220 | 221 | return losses 222 | 223 | def modify_coords(self, lig_pos, rot_update, tr_update): 224 | cen = lig_pos[..., 1, :].mean(dim=0) 225 | rot = axis_angle_to_matrix(rot_update.squeeze()) 226 | tr = tr_update.squeeze() 227 | lig_pos = (lig_pos - cen) @ rot.T + cen 228 | lig_pos = lig_pos + tr 229 | return lig_pos 230 | 231 | def step(self, batch, batch_idx): 232 | rec_x = batch['rec_x'].squeeze(0) 233 | lig_x = batch['lig_x'].squeeze(0) 234 | rec_pos = batch['rec_pos'].squeeze(0) 235 | lig_pos = batch['lig_pos'].squeeze(0) 236 | position_matrix = batch['position_matrix'].squeeze(0) 237 | ires = batch['ires'].squeeze(0) 238 | 239 | # wrap to a batch 240 | batch = { 241 | "rec_x": rec_x, 242 | "lig_x": lig_x, 
243 | "rec_pos": rec_pos, 244 | "lig_pos": lig_pos, 245 | "position_matrix": position_matrix, 246 | "ires": ires, 247 | } 248 | 249 | # get losses 250 | losses = self.loss_fn(batch) 251 | return losses 252 | 253 | def get_esm_rep(self, out): 254 | with torch.no_grad(): 255 | results = self.esm_model(out, repr_layers = [33]) 256 | rep = results["representations"][33] 257 | return rep[0, :, :] 258 | 259 | def training_step(self, batch, batch_idx): 260 | losses = self.step(batch, batch_idx) 261 | for loss_name, indiv_loss in losses.items(): 262 | self.log( 263 | f"train/{loss_name}", 264 | indiv_loss, 265 | batch_size=1, 266 | ) 267 | return losses["loss"] 268 | 269 | def on_validation_model_eval(self, *args, **kwargs): 270 | super().on_validation_model_eval(*args, **kwargs) 271 | torch.set_grad_enabled(True) 272 | 273 | def on_validation_model_train(self, *args, **kwargs): 274 | super().on_validation_model_train(*args, **kwargs) 275 | torch.set_grad_enabled(True) 276 | 277 | def validation_step(self, batch, batch_idx): 278 | losses = self.step(batch, batch_idx) 279 | for loss_name, indiv_loss in losses.items(): 280 | self.log( 281 | f"val/{loss_name}", 282 | indiv_loss, 283 | batch_size=1, 284 | ) 285 | return losses["loss"] 286 | 287 | def test_step(self, batch, batch_idx): 288 | losses = self.step(batch, batch_idx) 289 | for loss_name, indiv_loss in losses.items(): 290 | self.log( 291 | f"test/{loss_name}", 292 | indiv_loss, 293 | batch_size=1, 294 | ) 295 | return losses["loss"] 296 | 297 | def configure_optimizers(self): 298 | optimizer = torch.optim.AdamW( 299 | filter(lambda p: p.requires_grad, self.parameters()), 300 | lr=self.lr, 301 | weight_decay=self.weight_decay 302 | ) 303 | return optimizer 304 | 305 | #---------------------------------------------------------------------------- 306 | # Helpers 307 | 308 | def get_rmsd(pred, label): 309 | rmsd = torch.sqrt(torch.mean(torch.sum((pred - label) ** 2.0, dim=-1))) 310 | return rmsd 311 | 312 | #---------------------------------------------------------------------------- 313 | # Testing run 314 | 315 | @hydra.main(version_base=None, config_path="/scratch4/jgray21/lchu11/graylab_repos/DFMDock/configs/model", config_name="score_model.yaml") 316 | def main(conf: DictConfig): 317 | dataset = PPIDataset( 318 | dataset='pinder_train', 319 | crop_size=500, 320 | ) 321 | 322 | subset_indices = [0] 323 | subset = data.Subset(dataset, subset_indices) 324 | 325 | #load dataset 326 | dataloader = DataLoader(subset) 327 | 328 | model = Score_Model( 329 | model=conf.model, 330 | diffuser=conf.diffuser, 331 | experiment=conf.experiment 332 | ) 333 | trainer = pl.Trainer(accelerator='cpu', devices=1, max_epochs=10, inference_mode=False) 334 | trainer.validate(model, dataloader) 335 | 336 | if __name__ == '__main__': 337 | main() 338 | -------------------------------------------------------------------------------- /src/models/score_model_mlsb.py: -------------------------------------------------------------------------------- 1 | import esm 2 | import copy 3 | import hydra 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | import numpy as np 9 | import random 10 | from torch.utils import data 11 | from torch_geometric.loader import DataLoader 12 | from omegaconf import DictConfig 13 | from models.score_net_mlsb import Score_Net 14 | from utils.so3_diffuser import SO3Diffuser 15 | from utils.r3_diffuser import R3Diffuser 16 | from utils.geometry import axis_angle_to_matrix 17 | from 
datasets.ppi_mlsb_dataset import PPIDataset 18 | 19 | #---------------------------------------------------------------------------- 20 | # Main wrapper for training the model 21 | 22 | class Score_Model(pl.LightningModule): 23 | def __init__( 24 | self, 25 | model, 26 | diffuser, 27 | experiment, 28 | ): 29 | super().__init__() 30 | self.save_hyperparameters() 31 | self.lr = experiment.lr 32 | self.weight_decay = experiment.weight_decay 33 | 34 | # energy 35 | self.grad_energy = experiment.grad_energy 36 | self.separate_energy_loss = experiment.separate_energy_loss 37 | 38 | # translation 39 | self.perturb_tr = experiment.perturb_tr 40 | self.separate_tr_loss = experiment.separate_tr_loss 41 | 42 | # rotation 43 | self.perturb_rot = experiment.perturb_rot 44 | self.separate_rot_loss = experiment.separate_rot_loss 45 | 46 | # interface 47 | self.use_interface_loss = experiment.use_interface_loss 48 | 49 | # contrastive 50 | self.use_contrastive_loss = experiment.use_contrastive_loss 51 | 52 | # diffuser 53 | if self.perturb_tr: 54 | self.r3_diffuser = R3Diffuser(diffuser.r3) 55 | if self.perturb_rot: 56 | self.so3_diffuser = SO3Diffuser(diffuser.so3) 57 | 58 | # net 59 | self.net = Score_Net(model) 60 | 61 | def forward(self, batch): 62 | outputs = self.net(batch, predict=True) 63 | return outputs 64 | 65 | def loss_fn(self, batch, eps=1e-5): 66 | with torch.no_grad(): 67 | # uniformly sample a timestep 68 | t = torch.rand(1, device=self.device) * (1. - eps) + eps 69 | batch["t"] = t 70 | 71 | # sample perturbation for translation and rotation 72 | if self.perturb_tr: 73 | tr_score_scale = self.r3_diffuser.score_scaling(t.item()) 74 | tr_update, tr_score_gt = self.r3_diffuser.forward_marginal(t.item()) 75 | tr_update = torch.from_numpy(tr_update).float().to(self.device) 76 | tr_score_gt = torch.from_numpy(tr_score_gt).float().to(self.device) 77 | else: 78 | tr_update = np.zeros(3) 79 | tr_update = torch.from_numpy(tr_update).float().to(self.device) 80 | 81 | if self.perturb_rot: 82 | rot_score_scale = self.so3_diffuser.score_scaling(t.item()) 83 | rot_update, rot_score_gt = self.so3_diffuser.forward_marginal(t.item()) 84 | rot_update = torch.from_numpy(rot_update).float().to(self.device) 85 | rot_score_gt = torch.from_numpy(rot_score_gt).float().to(self.device) 86 | else: 87 | rot_update = np.zeros(3) 88 | rot_update = torch.from_numpy(rot_update).float().to(self.device) 89 | 90 | # save gt state 91 | batch_gt = copy.deepcopy(batch) 92 | 93 | # update poses 94 | batch["lig_pos"] = self.modify_coords(batch["lig_pos"], rot_update, tr_update) 95 | 96 | 97 | # predict score based on the current state 98 | if self.grad_energy: 99 | outputs = self.net(batch) 100 | 101 | # grab some outputs 102 | tr_score = outputs["tr_score"] 103 | rot_score = outputs["rot_score"] 104 | f = outputs["f"] 105 | dedx = outputs["dedx"] 106 | energy_noised = outputs["energy"] 107 | 108 | # energy conservation loss 109 | if self.separate_energy_loss: 110 | f_angle = torch.norm(f, dim=-1, keepdim=True) 111 | f_axis = f / (f_angle + 1e-6) 112 | 113 | dedx_angle = torch.norm(dedx, dim=-1, keepdim=True) 114 | dedx_axis = dedx / (dedx_angle + 1e-6) 115 | 116 | ec_axis_loss = torch.mean((f_axis - dedx_axis)**2) 117 | ec_angle_loss = torch.mean((f_angle - dedx_angle)**2) 118 | ec_loss = 0.5 * (ec_axis_loss + ec_angle_loss) 119 | 120 | else: 121 | ec_loss = torch.mean((dedx - f)**2) 122 | else: 123 | outputs = self.net(batch, predict=True) 124 | 125 | # grab some outputs 126 | tr_score = outputs["tr_score"] 127 | rot_score = 
outputs["rot_score"] 128 | energy_noised = outputs["energy"] 129 | 130 | # energy conservation loss 131 | ec_loss = torch.tensor(0.0, device=self.device) 132 | 133 | 134 | # translation loss 135 | if self.perturb_tr: 136 | if self.separate_tr_loss: 137 | gt_tr_angle = torch.norm(tr_score_gt, dim=-1, keepdim=True) 138 | gt_tr_axis = tr_score_gt / (gt_tr_angle + 1e-6) 139 | 140 | pred_tr_angle = torch.norm(tr_score, dim=-1, keepdim=True) 141 | pred_tr_axis = tr_score / (pred_tr_angle + 1e-6) 142 | 143 | tr_axis_loss = torch.mean((pred_tr_axis - gt_tr_axis)**2) 144 | tr_angle_loss = torch.mean((pred_tr_angle - gt_tr_angle)**2 / tr_score_scale**2) 145 | tr_loss = 0.5 * (tr_axis_loss + tr_angle_loss) 146 | 147 | else: 148 | tr_loss = torch.mean((tr_score - tr_score_gt)**2 / tr_score_scale**2) 149 | else: 150 | tr_loss = torch.tensor(0.0, device=self.device) 151 | 152 | # rotation loss 153 | if self.perturb_rot: 154 | if self.separate_rot_loss: 155 | gt_rot_angle = torch.norm(rot_score_gt, dim=-1, keepdim=True) 156 | gt_rot_axis = rot_score_gt / (gt_rot_angle + 1e-6) 157 | 158 | pred_rot_angle = torch.norm(rot_score, dim=-1, keepdim=True) 159 | pred_rot_axis = rot_score / (pred_rot_angle + 1e-6) 160 | 161 | rot_axis_loss = torch.mean((pred_rot_axis - gt_rot_axis)**2) 162 | rot_angle_loss = torch.mean((pred_rot_angle - gt_rot_angle)**2 / rot_score_scale**2) 163 | rot_loss = 0.5 * (rot_axis_loss + rot_angle_loss) 164 | 165 | else: 166 | rot_loss = torch.mean((rot_score - rot_score_gt)**2 / rot_score_scale**2) 167 | else: 168 | rot_loss = torch.tensor(0.0, device=self.device) 169 | 170 | # interface loss 171 | bce_logits_loss = nn.BCEWithLogitsLoss() 172 | if self.use_interface_loss: 173 | ires_loss = bce_logits_loss(outputs['ires'], batch['ires']) 174 | else: 175 | ires_loss = torch.tensor(0.0, device=self.device) 176 | 177 | # contrastive loss 178 | # modified from https://github.com/yilundu/ired_code_release/blob/main/diffusion_lib/denoising_diffusion_pytorch_1d.py 179 | if self.use_contrastive_loss: 180 | energy_gt = self.net(batch_gt, return_energy=True) 181 | energy_stack = torch.stack([energy_gt, energy_noised], dim=-1) 182 | target = torch.zeros([], device=energy_stack.device) 183 | el_loss = F.cross_entropy(-1 * energy_stack, target.long(), reduction='none') 184 | else: 185 | el_loss = torch.tensor(0.0, device=self.device) 186 | 187 | # total losses 188 | loss = tr_loss + rot_loss + ec_loss + el_loss + ires_loss 189 | losses = {"tr_loss": tr_loss, "rot_loss": rot_loss, "ec_loss": ec_loss, "el_loss": el_loss, "ires_loss": ires_loss, "loss": loss} 190 | 191 | return losses 192 | 193 | def modify_coords(self, lig_pos, rot_update, tr_update): 194 | cen = lig_pos[..., 1, :].mean(dim=0) 195 | rot = axis_angle_to_matrix(rot_update.squeeze()) 196 | tr = tr_update.squeeze() 197 | lig_pos = (lig_pos - cen) @ rot.T + cen 198 | lig_pos = lig_pos + tr 199 | return lig_pos 200 | 201 | def step(self, batch, batch_idx): 202 | rec_x = batch['rec_x'].squeeze(0) 203 | lig_x = batch['lig_x'].squeeze(0) 204 | rec_pos = batch['rec_pos'].squeeze(0) 205 | lig_pos = batch['lig_pos'].squeeze(0) 206 | position_matrix = batch['position_matrix'].squeeze(0) 207 | ires = batch['ires'].squeeze(0) 208 | 209 | # wrap to a batch 210 | batch = { 211 | "rec_x": rec_x, 212 | "lig_x": lig_x, 213 | "rec_pos": rec_pos, 214 | "lig_pos": lig_pos, 215 | "position_matrix": position_matrix, 216 | "ires": ires, 217 | } 218 | 219 | # get losses 220 | losses = self.loss_fn(batch) 221 | return losses 222 | 223 | def get_esm_rep(self, 
out): 224 | with torch.no_grad(): 225 | results = self.esm_model(out, repr_layers = [33]) 226 | rep = results["representations"][33] 227 | return rep[0, :, :] 228 | 229 | def training_step(self, batch, batch_idx): 230 | losses = self.step(batch, batch_idx) 231 | for loss_name, indiv_loss in losses.items(): 232 | self.log( 233 | f"train/{loss_name}", 234 | indiv_loss, 235 | batch_size=1, 236 | ) 237 | return losses["loss"] 238 | 239 | def on_validation_model_eval(self, *args, **kwargs): 240 | super().on_validation_model_eval(*args, **kwargs) 241 | torch.set_grad_enabled(True) 242 | 243 | def on_validation_model_train(self, *args, **kwargs): 244 | super().on_validation_model_train(*args, **kwargs) 245 | torch.set_grad_enabled(True) 246 | 247 | def validation_step(self, batch, batch_idx): 248 | losses = self.step(batch, batch_idx) 249 | for loss_name, indiv_loss in losses.items(): 250 | self.log( 251 | f"val/{loss_name}", 252 | indiv_loss, 253 | batch_size=1, 254 | ) 255 | return losses["loss"] 256 | 257 | def test_step(self, batch, batch_idx): 258 | losses = self.step(batch, batch_idx) 259 | for loss_name, indiv_loss in losses.items(): 260 | self.log( 261 | f"test/{loss_name}", 262 | indiv_loss, 263 | batch_size=1, 264 | ) 265 | return losses["loss"] 266 | 267 | def configure_optimizers(self): 268 | optimizer = torch.optim.AdamW( 269 | filter(lambda p: p.requires_grad, self.parameters()), 270 | lr=self.lr, 271 | weight_decay=self.weight_decay 272 | ) 273 | return optimizer 274 | 275 | #---------------------------------------------------------------------------- 276 | # Helpers 277 | 278 | def get_rmsd(pred, label): 279 | rmsd = torch.sqrt(torch.mean(torch.sum((pred - label) ** 2.0, dim=-1))) 280 | return rmsd 281 | 282 | #---------------------------------------------------------------------------- 283 | # Testing run 284 | 285 | @hydra.main(version_base=None, config_path="/scratch4/jgray21/lchu11/graylab_repos/DFMDock/configs/model", config_name="score_model_mlsb.yaml") 286 | def main(conf: DictConfig): 287 | dataset = PPIDataset( 288 | dataset='dips_train', 289 | crop_size=500, 290 | ) 291 | 292 | subset_indices = [0] 293 | subset = data.Subset(dataset, subset_indices) 294 | 295 | #load dataset 296 | dataloader = DataLoader(subset) 297 | 298 | model = Score_Model( 299 | model=conf.model, 300 | diffuser=conf.diffuser, 301 | experiment=conf.experiment 302 | ) 303 | trainer = pl.Trainer(accelerator='cpu', devices=1, max_epochs=10, inference_mode=False) 304 | trainer.validate(model, dataloader) 305 | 306 | if __name__ == '__main__': 307 | main() 308 | -------------------------------------------------------------------------------- /src/models/score_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from dataclasses import dataclass 6 | from einops import repeat 7 | from models.egnn import E_GCL 8 | from utils.coords6d import get_coords6d 9 | 10 | #---------------------------------------------------------------------------- 11 | # Data class for model config 12 | 13 | @dataclass 14 | class ModelConfig: 15 | lm_embed_dim: int 16 | positional_embed_dim: int 17 | spatial_embed_dim: int 18 | contact_embed_dim: int 19 | node_dim: int 20 | edge_dim: int 21 | inner_dim: int 22 | depth: int 23 | dropout: float = 0.0 24 | cut_off: float = 30.0 25 | normalize: bool = False 26 | 27 | #---------------------------------------------------------------------------- 28 | 
# Helper functions 29 | 30 | def get_spatial_matrix(coord): 31 | dist, omega, theta, phi = get_coords6d(coord) 32 | 33 | mask = dist < 22.0 34 | 35 | num_dist_bins = 40 36 | num_omega_bins = 24 37 | num_theta_bins = 24 38 | num_phi_bins = 12 39 | dist_bin = get_bins(dist, 3.25, 50.75, num_dist_bins) 40 | omega_bin = get_bins(omega, -180.0, 180.0, num_omega_bins) 41 | theta_bin = get_bins(theta, -180.0, 180.0, num_theta_bins) 42 | phi_bin = get_bins(phi, 0.0, 180.0, num_phi_bins) 43 | 44 | def mask_mat(mat, num_bins): 45 | mat[~mask] = 0 46 | mat.fill_diagonal_(0) 47 | return mat 48 | 49 | omega_bin = mask_mat(omega_bin, num_omega_bins) 50 | theta_bin = mask_mat(theta_bin, num_theta_bins) 51 | phi_bin = mask_mat(phi_bin, num_phi_bins) 52 | 53 | # to onehot 54 | dist = F.one_hot(dist_bin, num_classes=num_dist_bins).float() 55 | omega = F.one_hot(omega_bin, num_classes=num_omega_bins).float() 56 | theta = F.one_hot(theta_bin, num_classes=num_theta_bins).float() 57 | phi = F.one_hot(phi_bin, num_classes=num_phi_bins).float() 58 | 59 | return torch.cat([dist, omega, theta, phi], dim=-1) 60 | 61 | def get_bins(x, min_bin, max_bin, num_bins): 62 | # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. 63 | boundaries = torch.linspace( 64 | min_bin, 65 | max_bin, 66 | num_bins - 1, 67 | device=x.device, 68 | ) 69 | bins = torch.sum(x.unsqueeze(-1) > boundaries, dim=-1) # [..., L, L] 70 | return bins 71 | 72 | def get_clashes(distance): 73 | return torch.sum(distance <= 3.0) 74 | 75 | def sample_indices(matrix, num_samples): 76 | n, m = matrix.shape 77 | # Generate random permutations of indices for each row 78 | permuted_indices = torch.argsort(torch.rand(n, m, device=matrix.device), dim=1) 79 | 80 | # Select the first num_samples indices from each permutation 81 | sampled_indices = permuted_indices[:, :num_samples] 82 | 83 | return sampled_indices 84 | 85 | def get_knn_and_sample(points, knn=20, sample_size=40, epsilon=1e-10): 86 | device = points.device 87 | n_points = points.size(0) 88 | 89 | if n_points < knn: 90 | knn = n_points 91 | sample_size = 0 92 | 93 | if n_points < knn + sample_size: 94 | sample_size = n_points - knn 95 | 96 | # Step 1: Compute pairwise distances 97 | dist_matrix = torch.cdist(points, points) 98 | 99 | # Step 2: Find the 20 nearest neighbors (including the point itself) 100 | _, knn_indices = torch.topk(dist_matrix, k=knn, largest=False) 101 | 102 | if sample_size > 0: 103 | # Step 3: Create a mask for the non-knn points 104 | mask = torch.ones(n_points, n_points, dtype=torch.bool, device=device) 105 | mask.scatter_(1, knn_indices, False) 106 | 107 | # Select the non-knn distances and compute inverse cubic distances 108 | non_knn_distances = dist_matrix[mask].view(n_points, -1) 109 | 110 | # Replace zero distances with a small value to avoid division by zero 111 | non_knn_distances = torch.where(non_knn_distances < epsilon, torch.tensor(epsilon, device=device), non_knn_distances) 112 | 113 | inv_cubic_distances = 1 / torch.pow(non_knn_distances, 3) 114 | 115 | # Normalize the inverse cubic distances to get probabilities 116 | probabilities = inv_cubic_distances / inv_cubic_distances.sum(dim=1, keepdim=True) 117 | 118 | # Ensure there are no NaNs or negative values 119 | probabilities = torch.nan_to_num(probabilities, nan=0.0, posinf=0.0, neginf=0.0) 120 | probabilities = torch.clamp(probabilities, min=0) 121 | 122 | # Normalize again to ensure it's a proper probability distribution 123 | probabilities /= probabilities.sum(dim=1, keepdim=True) 124 | 125 | 
# Generate a tensor of indices excluding knn_indices 126 | all_indices = torch.arange(n_points, device=device).expand(n_points, n_points) 127 | non_knn_indices = all_indices[mask].view(n_points, -1) 128 | 129 | # Step 4: Sample 40 indices based on the probability distribution 130 | sample_indices = torch.multinomial(probabilities, sample_size, replacement=False) 131 | sampled_points_indices = non_knn_indices.gather(1, sample_indices) 132 | else: 133 | sampled_points_indices = None 134 | 135 | return knn_indices, sampled_points_indices 136 | 137 | #---------------------------------------------------------------------------- 138 | # Edge seletion functions 139 | 140 | def get_knn_and_sample_graph(x, e, knn=20, sample_size=40): 141 | knn_indices, sampled_points_indices = get_knn_and_sample(x, knn=knn, sample_size=sample_size) 142 | if sampled_points_indices is not None: 143 | indices = torch.cat([knn_indices, sampled_points_indices], dim=-1) 144 | else: 145 | indices = knn_indices 146 | n_points, n_samples = indices.shape 147 | 148 | # edge src and dst 149 | edge_src = torch.arange(start=0, end=n_points, device=x.device)[..., None].repeat(1, n_samples).reshape(-1) 150 | edge_dst = indices.reshape(-1) 151 | 152 | # combine graphs 153 | edge_index = [edge_src, edge_dst] 154 | edge_indices = torch.stack(edge_index, dim=1) 155 | edge_attr = e[edge_indices[:, 0], edge_indices[:, 1]] 156 | 157 | return edge_index, edge_attr 158 | 159 | #---------------------------------------------------------------------------- 160 | # nn Modules 161 | 162 | class GaussianFourierProjection(nn.Module): 163 | """Gaussian random features for encoding time steps.""" 164 | def __init__(self, embed_dim, scale=1.): 165 | super().__init__() 166 | # Randomly sample weights during initialization. These weights are fixed 167 | # during optimization and are not trainable. 
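        # The forward pass below maps a scalar t to
        # [sin(2*pi*t*W), cos(2*pi*t*W)]: embed_dim/2 fixed random frequencies,
        # each contributing a sin/cos pair, giving an embed_dim-dim encoding.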
168 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 169 | 170 | def forward(self, x): 171 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 172 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 173 | 174 | 175 | class EGNNLayer(nn.Module): 176 | def __init__( 177 | self, 178 | node_dim, 179 | edge_dim=0, 180 | act_fn=nn.SiLU(), 181 | residual=True, 182 | attention=False, 183 | normalize=False, 184 | tanh=False, 185 | update_coords=False, 186 | coord_weights_clamp_value=2.0, 187 | dropout=0.0, 188 | ): 189 | super(EGNNLayer, self).__init__() 190 | self.egcl = E_GCL( 191 | input_nf=node_dim, 192 | output_nf=node_dim, 193 | hidden_nf=node_dim, 194 | edges_in_d=edge_dim, 195 | act_fn=act_fn, 196 | residual=residual, 197 | attention=attention, 198 | normalize=normalize, 199 | tanh=tanh, 200 | update_coords=update_coords, 201 | coord_weights_clamp_value=coord_weights_clamp_value, 202 | dropout=dropout, 203 | ) 204 | 205 | def forward(self, h, x, edges, edge_attr=None, lig_mask=None): 206 | h, x, edge_attr = self.egcl(h, edges, x, edge_attr=edge_attr, lig_mask=lig_mask) 207 | return h, x, edge_attr 208 | 209 | 210 | class EGNN(nn.Module): 211 | def __init__( 212 | self, 213 | node_dim, 214 | edge_dim=0, 215 | act_fn=nn.SiLU(), 216 | depth=4, 217 | residual=True, 218 | attention=False, 219 | normalize=False, 220 | tanh=False, 221 | dropout=0.0, 222 | ): 223 | super(EGNN, self).__init__() 224 | self.depth = depth 225 | for i in range(depth): 226 | is_last = i == depth - 1 227 | self.add_module("EGNN_%d" % i, EGNNLayer( 228 | node_dim=node_dim, 229 | edge_dim=edge_dim, 230 | act_fn=act_fn, 231 | residual=residual, 232 | attention=attention, 233 | normalize=normalize, 234 | tanh=tanh, 235 | dropout=dropout, 236 | update_coords=is_last, 237 | ) 238 | ) 239 | 240 | def forward(self, h, x, edges, edge_attr=None, lig_mask=None): 241 | for i in range(self.depth): 242 | h, x, edge_attr = self._modules["EGNN_%d" % i](h, x, edges, edge_attr=edge_attr, lig_mask=lig_mask) 243 | return h, x, edge_attr 244 | 245 | 246 | #---------------------------------------------------------------------------- 247 | # Main score network 248 | 249 | class Score_Net(nn.Module): 250 | """EGNN backbone for translation and rotation scores""" 251 | def __init__( 252 | self, 253 | conf, 254 | ): 255 | super().__init__() 256 | lm_embed_dim = conf.lm_embed_dim 257 | spatial_embed_dim = conf.spatial_embed_dim 258 | positional_embed_dim = conf.positional_embed_dim 259 | node_dim = conf.node_dim 260 | edge_dim = conf.edge_dim 261 | inner_dim = conf.inner_dim 262 | depth = conf.depth 263 | dropout = conf.dropout 264 | normalize = conf.normalize 265 | 266 | self.cut_off = conf.cut_off 267 | 268 | # single init embedding 269 | self.single_embed = nn.Linear(lm_embed_dim, node_dim, bias=False) 270 | 271 | # pair init embedding 272 | self.spatial_embed = nn.Linear(spatial_embed_dim, edge_dim, bias=False) 273 | self.positional_embed = nn.Linear(positional_embed_dim, edge_dim, bias=False) 274 | 275 | # denoising score network 276 | self.network = EGNN( 277 | node_dim=node_dim, 278 | edge_dim=edge_dim, 279 | act_fn=nn.SiLU(), 280 | depth=depth, 281 | residual=True, 282 | attention=True, 283 | normalize=normalize, 284 | tanh=False, 285 | dropout=dropout, 286 | ) 287 | 288 | # energy head 289 | self.to_energy = nn.Sequential( 290 | nn.Linear(2*node_dim, node_dim, bias=False), 291 | nn.LayerNorm(node_dim), 292 | nn.SiLU(), 293 | nn.Linear(node_dim, 1, bias=False), 294 | ) 295 | 296 | # interface 
residue head 297 | self.to_ires = nn.Sequential( 298 | nn.Linear(node_dim, 2*node_dim), 299 | nn.SiLU(), 300 | nn.Linear(2*node_dim, 2*node_dim), 301 | nn.SiLU(), 302 | nn.Linear(2*node_dim, 1), 303 | ) 304 | 305 | # timestep embedding 306 | self.t_embed = nn.Sequential( 307 | GaussianFourierProjection(embed_dim=inner_dim), 308 | nn.Linear(inner_dim, inner_dim, bias=False), 309 | nn.Sigmoid(), 310 | ) 311 | 312 | # tr_scale mlp 313 | self.tr_scale = nn.Sequential( 314 | nn.Linear(inner_dim + 1, inner_dim, bias=False), 315 | nn.LayerNorm(inner_dim), 316 | nn.Dropout(dropout), 317 | nn.SiLU(), 318 | nn.Linear(inner_dim, 1, bias=False), 319 | nn.Softplus() 320 | ) 321 | 322 | # rot_scale mlp 323 | self.rot_scale = nn.Sequential( 324 | nn.Linear(inner_dim + 1, inner_dim, bias=False), 325 | nn.LayerNorm(inner_dim), 326 | nn.Dropout(dropout), 327 | nn.SiLU(), 328 | nn.Linear(inner_dim, 1, bias=False), 329 | nn.Softplus() 330 | ) 331 | 332 | self.apply(self._init_weights) 333 | 334 | def _init_weights(self, module): 335 | if isinstance(module, nn.Linear): 336 | module.weight.data.normal_(mean=0.0, std=0.02) 337 | if module.bias is not None: 338 | module.bias.data.zero_() 339 | elif isinstance(module, nn.LayerNorm): 340 | module.bias.data.zero_() 341 | module.weight.data.fill_(1.0) 342 | 343 | def forward(self, batch, predict=False, return_energy=False): 344 | # get inputs 345 | rec_x = batch["rec_x"] 346 | lig_x = batch["lig_x"] 347 | rec_pos = batch["rec_pos"] 348 | lig_pos = batch["lig_pos"] 349 | t = batch["t"] 350 | position_matrix = batch["position_matrix"] 351 | 352 | # get the current complex pose 353 | lig_pos.requires_grad_() 354 | pos = torch.cat([rec_pos, lig_pos], dim=0) 355 | 356 | # get ca distance matrix 357 | D = torch.norm((rec_pos[:, None, 1, :] - lig_pos[None, :, 1, :]), dim=-1) 358 | 359 | # node feature embedding 360 | x = torch.cat([rec_x, lig_x], dim=0) 361 | node = self.single_embed(x) # [n, c] 362 | 363 | # edge feature embedding 364 | spatial_matrix = get_spatial_matrix(pos) 365 | edge = self.spatial_embed(spatial_matrix) + self.positional_embed(position_matrix) 366 | 367 | # sample edge_index and get edge_attr 368 | edge_index, edge_attr = get_knn_and_sample_graph(pos[..., 1, :], edge) 369 | 370 | # get ligand mask 371 | lig_mask = torch.zeros(x.size(0), device=x.device) 372 | lig_mask[rec_x.size(0):] = 1.0 373 | 374 | # main network 375 | node_out, pos_out, _ = self.network(node, pos[..., 1, :], edge_index, edge_attr, lig_mask) # [R+L, H] 376 | 377 | # interface residue 378 | ires = self.to_ires(node_out) 379 | 380 | # energy 381 | h_rec = repeat(node_out[:rec_pos.size(0)], 'n h -> n m h', m=lig_pos.size(0)) 382 | h_lig = repeat(node_out[rec_pos.size(0):], 'm h -> n m h', n=rec_pos.size(0)) 383 | energy = self.to_energy(torch.cat([h_rec, h_lig], dim=-1)).squeeze(-1) # [R, L] 384 | mask_2D = (D < self.cut_off).float() # [R, L] 385 | energy = (energy * mask_2D).sum() / (mask_2D.sum() + 1e-6) 386 | 387 | if return_energy: 388 | return energy 389 | 390 | # force 391 | lig_pos_curr = pos_out[rec_pos.size(0):] 392 | r = lig_pos[..., 1, :].detach() 393 | f = lig_pos_curr - r # f / kT 394 | 395 | # translation 396 | tr_pred = f.mean(dim=0, keepdim=True) 397 | 398 | # rotation 399 | rot_pred = torch.cross(r, f, dim=-1).mean(dim=0, keepdim=True) 400 | 401 | # scale 402 | t = self.t_embed(t) 403 | tr_norm = torch.linalg.vector_norm(tr_pred, keepdim=True) 404 | tr_score = tr_pred / (tr_norm + 1e-6) * self.tr_scale(torch.cat([tr_norm, t], dim=-1)) 405 | rot_norm = 
torch.linalg.vector_norm(rot_pred, keepdim=True) 406 | rot_score = rot_pred / (rot_norm + 1e-6) * self.rot_scale(torch.cat([rot_norm, t], dim=-1)) 407 | 408 | if predict: 409 | num_clashes = get_clashes(D) 410 | 411 | outputs = { 412 | "tr_score": tr_score, 413 | "rot_score": rot_score, 414 | "energy": energy, 415 | "f": f, 416 | "num_clashes": num_clashes, 417 | "ires": ires, 418 | } 419 | 420 | return outputs 421 | 422 | # dedx 423 | dedx = torch.autograd.grad( 424 | outputs=energy, 425 | inputs=lig_pos, 426 | grad_outputs=torch.ones_like(energy), 427 | create_graph=self.training, 428 | retain_graph=self.training, 429 | only_inputs=True, 430 | allow_unused=True, 431 | )[0] 432 | 433 | dedx = -dedx[..., 1, :] # F / kT 434 | 435 | outputs = { 436 | "tr_score": tr_score, 437 | "rot_score": rot_score, 438 | "energy": energy, 439 | "f": f, 440 | "dedx": dedx, 441 | "ires": ires, 442 | } 443 | 444 | return outputs 445 | 446 | #---------------------------------------------------------------------------- 447 | # Testing 448 | 449 | if __name__ == '__main__': 450 | conf = ModelConfig( 451 | lm_embed_dim=1280, 452 | positional_embed_dim=68, 453 | spatial_embed_dim=100, 454 | contact_embed_dim=1, 455 | node_dim=24, 456 | edge_dim=12, 457 | inner_dim=24, 458 | depth=2, 459 | ) 460 | 461 | model = Score_Net(conf) 462 | 463 | rec_x = torch.randn(40, 1280) 464 | lig_x = torch.randn(5, 1280) 465 | rec_pos = torch.randn(40, 3, 3) 466 | lig_pos = torch.randn(5, 3, 3) 467 | t = torch.tensor([0.5]) 468 | contact_matrix = torch.zeros(45, 45) 469 | position_matrix = torch.zeros(45, 45, 68) 470 | 471 | batch = { 472 | "rec_x": rec_x, 473 | "lig_x": lig_x, 474 | "rec_pos": rec_pos, 475 | "lig_pos": lig_pos, 476 | "t": t, 477 | "contact_matrix": contact_matrix, 478 | "position_matrix": position_matrix, 479 | } 480 | 481 | out = model(batch) 482 | print(out) 483 | -------------------------------------------------------------------------------- /src/models/score_net_mlsb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from dataclasses import dataclass 6 | from einops import repeat 7 | from models.egnn import E_GCL 8 | from utils.coords6d import get_coords6d 9 | 10 | #---------------------------------------------------------------------------- 11 | # Data class for model config 12 | 13 | @dataclass 14 | class ModelConfig: 15 | lm_embed_dim: int 16 | positional_embed_dim: int 17 | spatial_embed_dim: int 18 | contact_embed_dim: int 19 | node_dim: int 20 | edge_dim: int 21 | inner_dim: int 22 | depth: int 23 | dropout: float = 0.0 24 | cut_off: float = 30.0 25 | normalize: bool = False 26 | 27 | #---------------------------------------------------------------------------- 28 | # Helper functions 29 | 30 | def get_spatial_matrix(coord): 31 | dist, omega, theta, phi = get_coords6d(coord) 32 | 33 | mask = dist < 22.0 34 | 35 | num_dist_bins = 40 36 | num_omega_bins = 24 37 | num_theta_bins = 24 38 | num_phi_bins = 12 39 | dist_bin = get_bins(dist, 3.25, 50.75, num_dist_bins) 40 | omega_bin = get_bins(omega, -180.0, 180.0, num_omega_bins) 41 | theta_bin = get_bins(theta, -180.0, 180.0, num_theta_bins) 42 | phi_bin = get_bins(phi, 0.0, 180.0, num_phi_bins) 43 | 44 | def mask_mat(mat, num_bins): 45 | mat[~mask] = 0 46 | mat.fill_diagonal_(0) 47 | return mat 48 | 49 | omega_bin = mask_mat(omega_bin, num_omega_bins) 50 | theta_bin = mask_mat(theta_bin, num_theta_bins) 51 | phi_bin = 
mask_mat(phi_bin, num_phi_bins) 52 | 53 | # to onehot 54 | dist = F.one_hot(dist_bin, num_classes=num_dist_bins).float() 55 | omega = F.one_hot(omega_bin, num_classes=num_omega_bins).float() 56 | theta = F.one_hot(theta_bin, num_classes=num_theta_bins).float() 57 | phi = F.one_hot(phi_bin, num_classes=num_phi_bins).float() 58 | 59 | return torch.cat([dist, omega, theta, phi], dim=-1) 60 | 61 | def get_bins(x, min_bin, max_bin, num_bins): 62 | # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. 63 | boundaries = torch.linspace( 64 | min_bin, 65 | max_bin, 66 | num_bins - 1, 67 | device=x.device, 68 | ) 69 | bins = torch.sum(x.unsqueeze(-1) > boundaries, dim=-1) # [..., L, L] 70 | return bins 71 | 72 | def get_clashes(distance): 73 | return torch.sum(distance <= 3.0) 74 | 75 | def sample_indices(matrix, num_samples): 76 | n, m = matrix.shape 77 | # Generate random permutations of indices for each row 78 | permuted_indices = torch.argsort(torch.rand(n, m, device=matrix.device), dim=1) 79 | 80 | # Select the first num_samples indices from each permutation 81 | sampled_indices = permuted_indices[:, :num_samples] 82 | 83 | return sampled_indices 84 | 85 | def get_knn_and_sample(points, knn=20, sample_size=40, epsilon=1e-10): 86 | device = points.device 87 | n_points = points.size(0) 88 | 89 | if n_points < knn: 90 | knn = n_points 91 | sample_size = 0 92 | 93 | if n_points < knn + sample_size: 94 | sample_size = n_points - knn 95 | 96 | # Step 1: Compute pairwise distances 97 | dist_matrix = torch.cdist(points, points) 98 | 99 | # Step 2: Find the 20 nearest neighbors (including the point itself) 100 | _, knn_indices = torch.topk(dist_matrix, k=knn, largest=False) 101 | 102 | if sample_size > 0: 103 | # Step 3: Create a mask for the non-knn points 104 | mask = torch.ones(n_points, n_points, dtype=torch.bool, device=device) 105 | mask.scatter_(1, knn_indices, False) 106 | 107 | # Select the non-knn distances and compute inverse cubic distances 108 | non_knn_distances = dist_matrix[mask].view(n_points, -1) 109 | 110 | # Replace zero distances with a small value to avoid division by zero 111 | non_knn_distances = torch.where(non_knn_distances < epsilon, torch.tensor(epsilon, device=device), non_knn_distances) 112 | 113 | inv_cubic_distances = 1 / torch.pow(non_knn_distances, 3) 114 | 115 | # Normalize the inverse cubic distances to get probabilities 116 | probabilities = inv_cubic_distances / inv_cubic_distances.sum(dim=1, keepdim=True) 117 | 118 | # Ensure there are no NaNs or negative values 119 | probabilities = torch.nan_to_num(probabilities, nan=0.0, posinf=0.0, neginf=0.0) 120 | probabilities = torch.clamp(probabilities, min=0) 121 | 122 | # Normalize again to ensure it's a proper probability distribution 123 | probabilities /= probabilities.sum(dim=1, keepdim=True) 124 | 125 | # Generate a tensor of indices excluding knn_indices 126 | all_indices = torch.arange(n_points, device=device).expand(n_points, n_points) 127 | non_knn_indices = all_indices[mask].view(n_points, -1) 128 | 129 | # Step 4: Sample 40 indices based on the probability distribution 130 | sample_indices = torch.multinomial(probabilities, sample_size, replacement=False) 131 | sampled_points_indices = non_knn_indices.gather(1, sample_indices) 132 | else: 133 | sampled_points_indices = None 134 | 135 | return knn_indices, sampled_points_indices 136 | 137 | #---------------------------------------------------------------------------- 138 | # Edge seletion functions 139 | 140 | def get_knn_and_sample_graph(x, e, 
knn=20, sample_size=40): 141 | knn_indices, sampled_points_indices = get_knn_and_sample(x, knn=knn, sample_size=sample_size) 142 | if sampled_points_indices is not None: 143 | indices = torch.cat([knn_indices, sampled_points_indices], dim=-1) 144 | else: 145 | indices = knn_indices 146 | n_points, n_samples = indices.shape 147 | 148 | # edge src and dst 149 | edge_src = torch.arange(start=0, end=n_points, device=x.device)[..., None].repeat(1, n_samples).reshape(-1) 150 | edge_dst = indices.reshape(-1) 151 | 152 | # combine graphs 153 | edge_index = [edge_src, edge_dst] 154 | edge_indices = torch.stack(edge_index, dim=1) 155 | edge_attr = e[edge_indices[:, 0], edge_indices[:, 1]] 156 | 157 | return edge_index, edge_attr 158 | 159 | #---------------------------------------------------------------------------- 160 | # nn Modules 161 | 162 | class GaussianFourierProjection(nn.Module): 163 | """Gaussian random features for encoding time steps.""" 164 | def __init__(self, embed_dim, scale=1.): 165 | super().__init__() 166 | # Randomly sample weights during initialization. These weights are fixed 167 | # during optimization and are not trainable. 168 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 169 | 170 | def forward(self, x): 171 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 172 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 173 | 174 | 175 | class EGNNLayer(nn.Module): 176 | def __init__( 177 | self, 178 | node_dim, 179 | edge_dim=0, 180 | act_fn=nn.SiLU(), 181 | residual=True, 182 | attention=False, 183 | normalize=False, 184 | tanh=False, 185 | update_coords=False, 186 | coord_weights_clamp_value=2.0, 187 | dropout=0.0, 188 | ): 189 | super(EGNNLayer, self).__init__() 190 | self.egcl = E_GCL( 191 | input_nf=node_dim, 192 | output_nf=node_dim, 193 | hidden_nf=node_dim, 194 | edges_in_d=edge_dim, 195 | act_fn=act_fn, 196 | residual=residual, 197 | attention=attention, 198 | normalize=normalize, 199 | tanh=tanh, 200 | update_coords=update_coords, 201 | coord_weights_clamp_value=coord_weights_clamp_value, 202 | dropout=dropout, 203 | ) 204 | 205 | def forward(self, h, x, edges, edge_attr=None, lig_mask=None): 206 | h, x, edge_attr = self.egcl(h, edges, x, edge_attr=edge_attr, lig_mask=lig_mask) 207 | return h, x, edge_attr 208 | 209 | 210 | class EGNN(nn.Module): 211 | def __init__( 212 | self, 213 | node_dim, 214 | edge_dim=0, 215 | act_fn=nn.SiLU(), 216 | depth=4, 217 | residual=True, 218 | attention=False, 219 | normalize=False, 220 | tanh=False, 221 | dropout=0.0, 222 | ): 223 | super(EGNN, self).__init__() 224 | self.depth = depth 225 | for i in range(depth): 226 | is_last = i == depth - 1 227 | self.add_module("EGNN_%d" % i, EGNNLayer( 228 | node_dim=node_dim, 229 | edge_dim=edge_dim, 230 | act_fn=act_fn, 231 | residual=residual, 232 | attention=attention, 233 | normalize=normalize, 234 | tanh=tanh, 235 | dropout=dropout, 236 | update_coords=is_last, 237 | ) 238 | ) 239 | 240 | def forward(self, h, x, edges, edge_attr=None, lig_mask=None): 241 | for i in range(self.depth): 242 | h, x, edge_attr = self._modules["EGNN_%d" % i](h, x, edges, edge_attr=edge_attr, lig_mask=lig_mask) 243 | return h, x, edge_attr 244 | 245 | 246 | #---------------------------------------------------------------------------- 247 | # Main score network 248 | 249 | class Score_Net(nn.Module): 250 | """EGNN backbone for translation and rotation scores""" 251 | def __init__( 252 | self, 253 | conf, 254 | ): 255 | super().__init__() 256 | lm_embed_dim = 
conf.lm_embed_dim 257 | spatial_embed_dim = conf.spatial_embed_dim 258 | positional_embed_dim = conf.positional_embed_dim 259 | node_dim = conf.node_dim 260 | edge_dim = conf.edge_dim 261 | inner_dim = conf.inner_dim 262 | depth = conf.depth 263 | dropout = conf.dropout 264 | normalize = conf.normalize 265 | 266 | self.cut_off = conf.cut_off 267 | 268 | # single init embedding 269 | self.single_embed = nn.Linear(lm_embed_dim, node_dim, bias=False) 270 | 271 | # pair init embedding 272 | self.spatial_embed = nn.Linear(spatial_embed_dim, edge_dim, bias=False) 273 | self.positional_embed = nn.Linear(positional_embed_dim, edge_dim, bias=False) 274 | 275 | # denoising score network 276 | self.network = EGNN( 277 | node_dim=node_dim, 278 | edge_dim=edge_dim, 279 | act_fn=nn.SiLU(), 280 | depth=depth, 281 | residual=True, 282 | attention=True, 283 | normalize=normalize, 284 | tanh=False, 285 | dropout=dropout, 286 | ) 287 | 288 | # energy head 289 | self.to_energy = nn.Sequential( 290 | nn.Linear(2*node_dim, node_dim, bias=False), 291 | nn.LayerNorm(node_dim), 292 | nn.SiLU(), 293 | nn.Linear(node_dim, 1, bias=False), 294 | ) 295 | 296 | # interface residue head 297 | self.to_ires = nn.Sequential( 298 | nn.Linear(node_dim, 2*node_dim), 299 | nn.SiLU(), 300 | nn.Linear(2*node_dim, 2*node_dim), 301 | nn.SiLU(), 302 | nn.Linear(2*node_dim, 1), 303 | ) 304 | 305 | # timestep embedding 306 | self.t_embed = nn.Sequential( 307 | GaussianFourierProjection(embed_dim=inner_dim), 308 | nn.Linear(inner_dim, inner_dim, bias=False), 309 | nn.Sigmoid(), 310 | ) 311 | 312 | # tr_scale mlp 313 | self.tr_scale = nn.Sequential( 314 | nn.Linear(inner_dim + 1, inner_dim, bias=False), 315 | nn.LayerNorm(inner_dim), 316 | nn.Dropout(dropout), 317 | nn.SiLU(), 318 | nn.Linear(inner_dim, 1, bias=False), 319 | nn.Softplus() 320 | ) 321 | 322 | # rot_scale mlp 323 | self.rot_scale = nn.Sequential( 324 | nn.Linear(inner_dim + 1, inner_dim, bias=False), 325 | nn.LayerNorm(inner_dim), 326 | nn.Dropout(dropout), 327 | nn.SiLU(), 328 | nn.Linear(inner_dim, 1, bias=False), 329 | nn.Softplus() 330 | ) 331 | 332 | self.apply(self._init_weights) 333 | 334 | def _init_weights(self, module): 335 | if isinstance(module, nn.Linear): 336 | module.weight.data.normal_(mean=0.0, std=0.02) 337 | if module.bias is not None: 338 | module.bias.data.zero_() 339 | elif isinstance(module, nn.LayerNorm): 340 | module.bias.data.zero_() 341 | module.weight.data.fill_(1.0) 342 | 343 | def forward(self, batch, predict=False, return_energy=False): 344 | # get inputs 345 | rec_x = batch["rec_x"] 346 | lig_x = batch["lig_x"] 347 | rec_pos = batch["rec_pos"] 348 | lig_pos = batch["lig_pos"] 349 | t = batch["t"] 350 | position_matrix = batch["position_matrix"] 351 | 352 | # move to center 353 | center = lig_pos[..., 1, :].mean(dim=0) 354 | rec_pos = rec_pos - center 355 | lig_pos = lig_pos - center 356 | 357 | # get the current complex pose 358 | lig_pos.requires_grad_() 359 | pos = torch.cat([rec_pos, lig_pos], dim=0) 360 | 361 | # get ca distance matrix 362 | D = torch.norm((rec_pos[:, None, 1, :] - lig_pos[None, :, 1, :]), dim=-1) 363 | 364 | # node feature embedding 365 | x = torch.cat([rec_x, lig_x], dim=0) 366 | node = self.single_embed(x) # [n, c] 367 | 368 | # edge feature embedding 369 | spatial_matrix = get_spatial_matrix(pos) 370 | edge = self.spatial_embed(spatial_matrix) + self.positional_embed(position_matrix) 371 | 372 | # sample edge_index and get edge_attr 373 | edge_index, edge_attr = get_knn_and_sample_graph(pos[..., 1, :], edge) 374 | 
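        # Each CA node is connected to its 20 nearest neighbors plus 40 extra
        # partners sampled with probability ~ 1/d^3 (see get_knn_and_sample),
        # so message passing sees both local contacts and long-range residues.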
375 | # get ligand mask 376 | lig_mask = torch.zeros(x.size(0), device=x.device) 377 | lig_mask[rec_x.size(0):] = 1.0 378 | 379 | # main network 380 | node_out, pos_out, _ = self.network(node, pos[..., 1, :], edge_index, edge_attr, lig_mask) # [R+L, H] 381 | 382 | # interface residue 383 | ires = self.to_ires(node_out) 384 | 385 | # energy 386 | h_rec = repeat(node_out[:rec_pos.size(0)], 'n h -> n m h', m=lig_pos.size(0)) 387 | h_lig = repeat(node_out[rec_pos.size(0):], 'm h -> n m h', n=rec_pos.size(0)) 388 | energy = self.to_energy(torch.cat([h_rec, h_lig], dim=-1)).squeeze(-1) # [R, L] 389 | mask_2D = (D < self.cut_off).float() # [R, L] 390 | energy = (energy * mask_2D).sum() / (mask_2D.sum() + 1e-6) 391 | 392 | if return_energy: 393 | return energy 394 | 395 | # force 396 | lig_pos_curr = pos_out[rec_pos.size(0):] 397 | r = lig_pos[..., 1, :].detach() 398 | f = lig_pos_curr - r # f / kT 399 | 400 | # translation 401 | tr_pred = f.mean(dim=0, keepdim=True) 402 | 403 | # rotation 404 | rot_pred = torch.cross(r, f, dim=-1).mean(dim=0, keepdim=True) 405 | 406 | # scale 407 | t = self.t_embed(t) 408 | tr_norm = torch.linalg.vector_norm(tr_pred, keepdim=True) 409 | tr_score = tr_pred / (tr_norm + 1e-6) * self.tr_scale(torch.cat([tr_norm, t], dim=-1)) 410 | rot_norm = torch.linalg.vector_norm(rot_pred, keepdim=True) 411 | rot_score = rot_pred / (rot_norm + 1e-6) * self.rot_scale(torch.cat([rot_norm, t], dim=-1)) 412 | 413 | if predict: 414 | num_clashes = get_clashes(D) 415 | 416 | outputs = { 417 | "tr_score": tr_score, 418 | "rot_score": rot_score, 419 | "energy": energy, 420 | "f": f, 421 | "num_clashes": num_clashes, 422 | "ires": ires, 423 | } 424 | 425 | return outputs 426 | 427 | # dedx 428 | dedx = torch.autograd.grad( 429 | outputs=energy, 430 | inputs=lig_pos, 431 | grad_outputs=torch.ones_like(energy), 432 | create_graph=self.training, 433 | retain_graph=self.training, 434 | only_inputs=True, 435 | allow_unused=True, 436 | )[0] 437 | 438 | dedx = -dedx[..., 1, :] # F / kT 439 | 440 | outputs = { 441 | "tr_score": tr_score, 442 | "rot_score": rot_score, 443 | "energy": energy, 444 | "f": f, 445 | "dedx": dedx, 446 | "ires": ires, 447 | } 448 | 449 | return outputs 450 | 451 | #---------------------------------------------------------------------------- 452 | # Testing 453 | 454 | if __name__ == '__main__': 455 | conf = ModelConfig( 456 | lm_embed_dim=1280, 457 | positional_embed_dim=68, 458 | spatial_embed_dim=100, 459 | contact_embed_dim=1, 460 | node_dim=24, 461 | edge_dim=12, 462 | inner_dim=24, 463 | depth=2, 464 | ) 465 | 466 | model = Score_Net(conf) 467 | 468 | rec_x = torch.randn(40, 1280) 469 | lig_x = torch.randn(5, 1280) 470 | rec_pos = torch.randn(40, 3, 3) 471 | lig_pos = torch.randn(5, 3, 3) 472 | t = torch.tensor([0.5]) 473 | contact_matrix = torch.zeros(45, 45) 474 | position_matrix = torch.zeros(45, 45, 68) 475 | 476 | batch = { 477 | "rec_x": rec_x, 478 | "lig_x": lig_x, 479 | "rec_pos": rec_pos, 480 | "lig_pos": lig_pos, 481 | "t": t, 482 | "contact_matrix": contact_matrix, 483 | "position_matrix": position_matrix, 484 | } 485 | 486 | out = model(batch) 487 | print(out) 488 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import os 4 | from omegaconf import DictConfig 5 | 6 | os.environ["WANDB__SERVICE_WAIT"] = "300" 7 | 8 | @hydra.main(version_base="1.1", config_path="../configs/", 
config_name="config.yaml") 9 | def main(config: DictConfig): 10 | torch.manual_seed(0) 11 | 12 | # Imports can be nested inside @hydra.main to optimize tab completion 13 | # https://github.com/facebookresearch/hydra/issues/934 14 | from train import train 15 | from utils import utils 16 | 17 | # A couple of optional utilities: 18 | # - disabling python warnings 19 | # - forcing debug-friendly configuration 20 | # - verifying experiment name is set when running in experiment mode 21 | # You can safely get rid of this line if you don't want those 22 | utils.extras(config) 23 | 24 | # Pretty print config using Rich library 25 | if config.get("print_config"): 26 | utils.print_config(config, resolve=True) 27 | 28 | # Train model 29 | return train(config) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from typing import List, Optional 3 | from omegaconf import DictConfig 4 | from pytorch_lightning import ( 5 | Callback, 6 | LightningDataModule, 7 | LightningModule, 8 | Trainer, 9 | seed_everything, 10 | ) 11 | from pytorch_lightning.loggers import Logger 12 | from utils import utils 13 | 14 | 15 | log = utils.get_logger(__name__) 16 | 17 | 18 | def train(config: DictConfig) -> Optional[float]: 19 | """Contains training pipeline. 20 | Instantiates all PyTorch Lightning objects from config. 21 | Args: 22 | config (DictConfig): Configuration composed by Hydra. 23 | Returns: 24 | Optional[float]: Metric score for hyperparameter optimization. 25 | """ 26 | 27 | # Set seed for random number generators in pytorch, numpy and python.random 28 | if config.get("seed"): 29 | seed_everything(config.seed, workers=True) 30 | 31 | # Init lightning datamodule 32 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 33 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 34 | 35 | # Init lightning model 36 | log.info(f"Instantiating model <{config.model._target_}>") 37 | model: LightningModule = hydra.utils.instantiate(config.model) 38 | 39 | # Init lightning callbacks 40 | callbacks: List[Callback] = [] 41 | if "callbacks" in config: 42 | for _, cb_conf in config.callbacks.items(): 43 | if "_target_" in cb_conf: 44 | log.info(f"Instantiating callback <{cb_conf._target_}>") 45 | callbacks.append(hydra.utils.instantiate(cb_conf)) 46 | 47 | # Init lightning loggers 48 | logger: List[Logger] = [] 49 | if "logger" in config: 50 | for _, lg_conf in config.logger.items(): 51 | if "_target_" in lg_conf: 52 | log.info(f"Instantiating logger <{lg_conf._target_}>") 53 | logger.append(hydra.utils.instantiate(lg_conf)) 54 | 55 | # Init lightning trainer 56 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 57 | trainer: Trainer = hydra.utils.instantiate( 58 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 59 | ) 60 | 61 | # Send some parameters from config to all lightning loggers 62 | log.info("Logging hyperparameters!") 63 | utils.log_hyperparameters( 64 | config=config, 65 | model=model, 66 | datamodule=datamodule, 67 | trainer=trainer, 68 | callbacks=callbacks, 69 | logger=logger, 70 | ) 71 | 72 | # Train the model 73 | log.info("Starting training!") 74 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path) 75 | 76 | # Get metric score for hyperparameter optimization 77 | score = 
trainer.callback_metrics.get(config.get("optimized_metric")) 78 | 79 | # Test the model 80 | if config.get("test_after_training") and not config.trainer.get("fast_dev_run"): 81 | log.info("Starting testing!") 82 | trainer.test(model=model, datamodule=datamodule, ckpt_path="best") 83 | 84 | # Make sure everything closed properly 85 | log.info("Finalizing!") 86 | utils.finish( 87 | config=config, 88 | model=model, 89 | datamodule=datamodule, 90 | trainer=trainer, 91 | callbacks=callbacks, 92 | logger=logger, 93 | ) 94 | 95 | # Print path to best checkpoint 96 | if not config.trainer.get("fast_dev_run"): 97 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}") 98 | 99 | # Return metric score for hyperparameter optimization 100 | return score 101 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/coords6d.py: -------------------------------------------------------------------------------- 1 | ### 2 | # Modified from https://github.com/RosettaCommons/trRosetta2/blob/main/trRosetta/coords6d.py 3 | ### 4 | 5 | import math 6 | import torch 7 | from einops import repeat 8 | 9 | 10 | def calc_dist(a_coords, b_coords): 11 | assert a_coords.shape == b_coords.shape 12 | mat_shape = list(a_coords.shape) 13 | mat_shape.insert(-1, mat_shape[-2]) 14 | 15 | a_coords = a_coords.unsqueeze(-3).expand(mat_shape) 16 | b_coords = b_coords.unsqueeze(-2).expand(mat_shape) 17 | 18 | dist_mat = (a_coords - b_coords).norm(dim=-1) 19 | 20 | return dist_mat 21 | 22 | 23 | def calc_dihedral(a_coords, 24 | b_coords, 25 | c_coords, 26 | d_coords, 27 | convert_to_degree=True): 28 | b1 = a_coords - b_coords 29 | b2 = b_coords - c_coords 30 | b3 = c_coords - d_coords 31 | 32 | n1 = torch.linalg.cross(b1, b2) 33 | n1 = torch.div(n1, n1.norm(dim=-1, keepdim=True)) 34 | n2 = torch.linalg.cross(b2, b3) 35 | n2 = torch.div(n2, n2.norm(dim=-1, keepdim=True)) 36 | m1 = torch.linalg.cross(n1, torch.div(b2, b2.norm(dim=-1, keepdim=True))) 37 | 38 | dihedral = torch.atan2((m1 * n2).sum(-1), (n1 * n2).sum(-1)) 39 | 40 | if convert_to_degree: 41 | dihedral = dihedral * 180 / math.pi 42 | 43 | return dihedral 44 | 45 | 46 | def calc_planar(a_coords, b_coords, c_coords, convert_to_degree=True): 47 | v1 = a_coords - b_coords 48 | v2 = c_coords - b_coords 49 | 50 | a = (v1 * v2).sum(-1) 51 | b = v1.norm(dim=-1) * v2.norm(dim=-1) 52 | 53 | planar = torch.acos(a / b) 54 | 55 | if convert_to_degree: 56 | planar = planar * 180 / math.pi 57 | 58 | return planar 59 | 60 | 61 | # get 6d coordinates from x,y,z coords of N,Ca,C atoms 62 | def get_coords6d(xyz, use_Cb=False): 63 | 64 | n = xyz.shape[0] 65 | 66 | # three anchor atoms 67 | N = xyz[..., 0, :] 68 | Ca = xyz[..., 1, :] 69 | C = xyz[..., 2, :] 70 | 71 | # recreate Cb given N,Ca,C 72 | b = Ca - N 73 | c = C - Ca 74 | a = torch.cross(b, c, dim=-1) 75 | Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca 76 | 77 | if use_Cb: 78 | dist = calc_dist(Cb, Cb) 79 | else: 80 | dist = calc_dist(Ca, Ca) 81 | 82 | 83 | omega = calc_dihedral( 84 | repeat(Ca, 'r i -> r c i', c=n), 85 | repeat(Cb, 'r i -> r c i', c=n), 86 | repeat(Cb, 'c i -> r c i', r=n), 87 | repeat(Ca, 'c i -> r c i', r=n), 88 | ) 89 | 90 | theta = calc_dihedral( 91 | 
repeat(N, 'r i -> r c i', c=n), 92 | repeat(Ca, 'r i -> r c i', c=n), 93 | repeat(Cb, 'r i -> r c i', c=n), 94 | repeat(Cb, 'c i -> r c i', r=n), 95 | ) 96 | phi = calc_planar( 97 | repeat(Ca, 'r i -> r c i', c=n), 98 | repeat(Cb, 'r i -> r c i', c=n), 99 | repeat(Cb, 'c i -> r c i', r=n), 100 | ) 101 | 102 | 103 | return dist, omega, theta, phi 104 | 105 | if __name__ == '__main__': 106 | coords = torch.randn(10, 3, 3) 107 | out = get_coords6d(coords) 108 | print(out) 109 | 110 | -------------------------------------------------------------------------------- /src/utils/crop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def one_hot(x, v_bins): 4 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) 5 | diffs = x[..., None] - reshaped_bins 6 | am = torch.argmin(torch.abs(diffs), dim=-1) 7 | return torch.nn.functional.one_hot(am, num_classes=len(v_bins)).float() 8 | 9 | def relpos(res_id, asym_id, use_chain_relative=True): 10 | max_relative_idx = 32 11 | pos = res_id 12 | asym_id_same = (asym_id[..., None] == asym_id[..., None, :]) 13 | offset = pos[..., None] - pos[..., None, :] 14 | 15 | clipped_offset = torch.clamp( 16 | offset + max_relative_idx, 0, 2 * max_relative_idx 17 | ) 18 | 19 | rel_feats = [] 20 | if use_chain_relative: 21 | final_offset = torch.where( 22 | asym_id_same, 23 | clipped_offset, 24 | (2 * max_relative_idx + 1) * 25 | torch.ones_like(clipped_offset) 26 | ) 27 | 28 | boundaries = torch.arange( 29 | start=0, end=2 * max_relative_idx + 2, device=res_id.device 30 | ) 31 | rel_pos = one_hot( 32 | final_offset, 33 | boundaries, 34 | ) 35 | 36 | rel_feats.append(rel_pos) 37 | 38 | else: 39 | boundaries = torch.arange( 40 | start=0, end=2 * max_relative_idx + 1, device=res_id.device 41 | ) 42 | rel_pos = one_hot( 43 | clipped_offset, boundaries, 44 | ) 45 | rel_feats.append(rel_pos) 46 | 47 | rel_feat = torch.cat(rel_feats, dim=-1).float() 48 | 49 | return rel_feat 50 | 51 | def get_interface_residues(coords, asym_id, interface_threshold=10.0): 52 | coord_diff = coords[..., None, :, :] - coords[..., None, :, :, :] 53 | pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) 54 | diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float() 55 | mask = diff_chain_mask[..., None].bool() 56 | min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1) 57 | valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1) 58 | interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0] 59 | 60 | return interface_residues_idxs 61 | 62 | def get_spatial_crop_idx(coords, asym_id, crop_size=256, interface_threshold=10.0): 63 | interface_residues = get_interface_residues(coords, asym_id, interface_threshold=interface_threshold) 64 | 65 | if not torch.any(interface_residues): 66 | return get_contiguous_crop_idx(asym_id, crop_size) 67 | 68 | target_res_idx = randint(lower=0, upper=interface_residues.shape[-1] - 1) 69 | target_res = interface_residues[target_res_idx] 70 | 71 | ca_positions = coords[..., 1, :] 72 | coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :] 73 | ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) 74 | to_target_distances = ca_pairwise_dists[target_res] 75 | 76 | break_tie = ( 77 | torch.arange( 78 | 0, to_target_distances.shape[-1], device=coords.device 79 | ).float() 80 | * 1e-3 81 | ) 82 | to_target_distances += break_tie 83 | ret = 
torch.argsort(to_target_distances)[:crop_size] 84 | return ret.sort().values 85 | 86 | def get_contiguous_crop_idx(asym_id, crop_size): 87 | unique_asym_ids, chain_idxs, chain_lens = asym_id.unique(dim=-1, 88 | return_inverse=True, 89 | return_counts=True) 90 | 91 | shuffle_idx = torch.randperm(chain_lens.shape[-1]) 92 | 93 | 94 | _, idx_sorted = torch.sort(chain_idxs, stable=True) 95 | cum_sum = chain_lens.cumsum(dim=0) 96 | cum_sum = torch.cat((torch.tensor([0], device=cum_sum.device), cum_sum[:-1]), dim=0) 97 | asym_offsets = idx_sorted[cum_sum] 98 | 99 | num_budget = crop_size 100 | num_remaining = len(chain_idxs) 101 | 102 | crop_idxs = [] 103 | for i, idx in enumerate(shuffle_idx): 104 | chain_len = int(chain_lens[idx]) 105 | num_remaining -= chain_len 106 | 107 | if i == 0: 108 | crop_size_max = min(num_budget - 50, chain_len) 109 | crop_size_min = min(chain_len, 50) 110 | else: 111 | crop_size_max = min(num_budget, chain_len) 112 | crop_size_min = min(chain_len, max(50, num_budget - num_remaining)) 113 | 114 | chain_crop_size = randint(lower=crop_size_min, 115 | upper=crop_size_max) 116 | 117 | num_budget -= chain_crop_size 118 | 119 | chain_start = randint(lower=0, 120 | upper=chain_len - chain_crop_size) 121 | 122 | asym_offset = asym_offsets[idx] 123 | crop_idxs.append( 124 | torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size) 125 | ) 126 | 127 | return torch.concat(crop_idxs).sort().values 128 | 129 | def randint(lower, upper): 130 | return int(torch.randint( 131 | lower, 132 | upper + 1, 133 | (1,), 134 | )[0]) 135 | 136 | def get_crop_idxs(batch, crop_size): 137 | rec_pos = batch["rec_pos"] 138 | lig_pos = batch["lig_pos"] 139 | n = rec_pos.size(0) + lig_pos.size(0) 140 | pos = torch.cat([rec_pos, lig_pos], dim=0) 141 | asym_id = torch.zeros(n, device=pos.device).long() 142 | asym_id[rec_pos.size(0):] = 1 143 | 144 | use_spatial_crop = True 145 | num_res = asym_id.size(0) 146 | 147 | if num_res <= crop_size: 148 | crop_idxs = torch.arange(num_res) 149 | elif use_spatial_crop: 150 | crop_idxs = get_spatial_crop_idx(pos, asym_id, crop_size=crop_size) 151 | else: 152 | crop_idxs = get_contiguous_crop_idx(asym_id, crop_size=crop_size) 153 | 154 | crop_idxs = crop_idxs.to(pos.device) 155 | 156 | return crop_idxs 157 | 158 | def get_crop(batch, crop_idxs): 159 | rec_x = batch["rec_x"] 160 | lig_x = batch["lig_x"] 161 | rec_pos = batch["rec_pos"] 162 | lig_pos = batch["lig_pos"] 163 | 164 | n = rec_x.size(0) + lig_x.size(0) 165 | x = torch.cat([rec_x, lig_x], dim=0) 166 | pos = torch.cat([rec_pos, lig_pos], dim=0) 167 | res_id = torch.arange(n, device=x.device).long() 168 | asym_id = torch.zeros(n, device=x.device).long() 169 | asym_id[rec_x.size(0):] = 1 170 | 171 | res_id = torch.index_select(res_id, 0, crop_idxs) 172 | asym_id = torch.index_select(asym_id, 0, crop_idxs) 173 | x = torch.index_select(x, 0, crop_idxs) 174 | pos = torch.index_select(pos, 0, crop_idxs) 175 | 176 | sep = asym_id.tolist().index(1) 177 | rec_x = x[:sep] 178 | lig_x = x[sep:] 179 | rec_pos = pos[:sep] 180 | lig_pos = pos[sep:] 181 | 182 | # Positional embeddings 183 | position_matrix = relpos(res_id, asym_id).to(x.device) 184 | 185 | batch["rec_x"] = rec_x 186 | batch["lig_x"] = lig_x 187 | batch["rec_pos"] = rec_pos 188 | batch["lig_pos"] = lig_pos 189 | batch["position_matrix"] = position_matrix 190 | 191 | return batch 192 | 193 | def get_position_matrix(batch): 194 | rec_x = batch["rec_x"] 195 | lig_x = batch["lig_x"] 196 | x = torch.cat([rec_x, lig_x], dim=0) 197 | 
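    # Residues are indexed 0..N-1 across the concatenated complex; asym_id is
    # 0 for receptor and 1 for ligand, so relpos can assign inter-chain pairs
    # to the dedicated "different chain" offset bucket.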
198 | res_id = torch.arange(x.size(0), device=x.device).long() 199 | asym_id = torch.zeros(x.size(0), device=x.device).long() 200 | asym_id[rec_x.size(0):] = 1 201 | 202 | # Positional embeddings 203 | position_matrix = relpos(res_id, asym_id).to(x.device) 204 | 205 | batch["position_matrix"] = position_matrix 206 | 207 | return batch 208 | -------------------------------------------------------------------------------- /src/utils/frame.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import repeat 4 | from utils.geometry import matrix_to_rotation_6d 5 | 6 | 7 | def get_rotat(coords): 8 | # Get backbone coordinates. 9 | n_coords = coords[:, 0, :] 10 | ca_coords = coords[:, 1, :] 11 | c_coords = coords[:, 2, :] 12 | 13 | # Gram-Schmidt process. 14 | v1 = c_coords - ca_coords 15 | v2 = n_coords - ca_coords 16 | e1 = F.normalize(v1) 17 | u2 = v2 - e1 * (torch.einsum('b i, b i -> b', e1, v2).unsqueeze(-1)) 18 | e2 = F.normalize(u2) 19 | e3 = torch.cross(e1, e2, dim=-1) 20 | 21 | # Get rotations. 22 | rotations=torch.stack([e1, e2, e3], dim=-1) 23 | return rotations 24 | 25 | def get_trans(coords): 26 | return coords[:, 1, :] 27 | 28 | def get_pair_dist(trans): 29 | vec = repeat(trans, 'i c -> i j c', j=trans.size(0)) - repeat(trans, 'j c -> i j c', i=trans.size(0)) 30 | dist = torch.norm(vec, dim=-1, keepdim=True) 31 | dist = rbf(dist, 2.0, 22.0, n_bins=16) 32 | return dist 33 | 34 | def get_pair_direct(trans, rotat): 35 | vec = repeat(trans, 'i c -> i j c', j=trans.size(0)) - repeat(trans, 'j c -> i j c', i=trans.size(0)) 36 | direct = F.normalize(vec, dim=-1) 37 | rotat = repeat(rotat, 'r i j -> r c i j', c=rotat.size(0)) 38 | direct = torch.einsum('r c i j, r c j -> r c i', rotat.transpose(-1, -2), direct) 39 | return direct 40 | 41 | def get_pair_orient(trans, rotat): 42 | rotat_i = repeat(rotat, 'r i j -> r c i j', c=rotat.size(0)) 43 | rotat_j = repeat(rotat, 'c i j -> r c i j', r=rotat.size(0)) 44 | orient = torch.einsum('r c i j, r c j k -> r c i k', rotat_i.transpose(-1, -2), rotat_j) 45 | orient = matrix_to_rotation_6d(orient) 46 | return orient 47 | 48 | def get_pairs(trans, rotat): 49 | dists = get_pair_dist(trans) 50 | direct = get_pair_direct(trans, rotat) 51 | orient = get_pair_orient(trans, rotat) 52 | pair = torch.cat([dists, direct, orient], dim=-1) 53 | return pair 54 | 55 | def rbf(values, v_min, v_max, n_bins=16): 56 | """ 57 | Returns RBF encodings in a new dimension at the end. 58 | """ 59 | rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device) 60 | rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1]) 61 | rbf_std = (v_max - v_min) / n_bins 62 | #v_expand = torch.unsqueeze(values, -1) 63 | z = ((values.unsqueeze(-1) - rbf_centers) / rbf_std).squeeze(-2) 64 | return torch.exp(-z ** 2) 65 | -------------------------------------------------------------------------------- /src/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 8 | """ 9 | Returns torch.sqrt(torch.max(0, x)) 10 | but with a zero subgradient where x is 0. 
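    (This gives a finite, zero gradient at x == 0, which matters when it is
    applied to q_abs in matrix_to_quaternion below.)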
11 | """ 12 | ret = torch.zeros_like(x) 13 | positive_mask = x > 0 14 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 15 | return ret 16 | 17 | 18 | def quaternion_to_matrix(quaternions): 19 | """ 20 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html 21 | Convert rotations given as quaternions to rotation matrices. 22 | Args: 23 | quaternions: quaternions with real part first, 24 | as tensor of shape (..., 4). 25 | Returns: 26 | Rotation matrices as tensor of shape (..., 3, 3). 27 | """ 28 | r, i, j, k = torch.unbind(quaternions, -1) 29 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 30 | 31 | o = torch.stack( 32 | ( 33 | 1 - two_s * (j * j + k * k), 34 | two_s * (i * j - k * r), 35 | two_s * (i * k + j * r), 36 | two_s * (i * j + k * r), 37 | 1 - two_s * (i * i + k * k), 38 | two_s * (j * k - i * r), 39 | two_s * (i * k - j * r), 40 | two_s * (j * k + i * r), 41 | 1 - two_s * (i * i + j * j), 42 | ), 43 | -1, 44 | ) 45 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 46 | 47 | 48 | def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: 49 | """ 50 | Convert rotations given as rotation matrices to axis/angle. 51 | 52 | Args: 53 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 54 | 55 | Returns: 56 | Rotations given as a vector in axis angle form, as a tensor 57 | of shape (..., 3), where the magnitude is the angle 58 | turned anticlockwise in radians around the vector's 59 | direction. 60 | """ 61 | return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) 62 | 63 | 64 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 65 | """ 66 | Convert rotations given as rotation matrices to quaternions. 67 | 68 | Args: 69 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 70 | 71 | Returns: 72 | quaternions with real part first, as tensor of shape (..., 4). 73 | """ 74 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 75 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 76 | 77 | batch_dim = matrix.shape[:-2] 78 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 79 | matrix.reshape(batch_dim + (9,)), dim=-1 80 | ) 81 | 82 | q_abs = _sqrt_positive_part( 83 | torch.stack( 84 | [ 85 | 1.0 + m00 + m11 + m22, 86 | 1.0 + m00 - m11 - m22, 87 | 1.0 - m00 + m11 - m22, 88 | 1.0 - m00 - m11 + m22, 89 | ], 90 | dim=-1, 91 | ) 92 | ) 93 | 94 | # we produce the desired quaternion multiplied by each of r, i, j, k 95 | quat_by_rijk = torch.stack( 96 | [ 97 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 98 | # `int`. 99 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 100 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 101 | # `int`. 102 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 103 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 104 | # `int`. 105 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 106 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 107 | # `int`. 108 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 109 | ], 110 | dim=-2, 111 | ) 112 | 113 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 114 | # the candidate won't be picked. 
115 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 116 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 117 | 118 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 119 | # forall i; we pick the best-conditioned one (with the largest denominator) 120 | 121 | return quat_candidates[ 122 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 123 | ].reshape(batch_dim + (4,)) 124 | 125 | 126 | def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: 127 | """ 128 | Convert rotations given as quaternions to axis/angle. 129 | 130 | Args: 131 | quaternions: quaternions with real part first, 132 | as tensor of shape (..., 4). 133 | 134 | Returns: 135 | Rotations given as a vector in axis angle form, as a tensor 136 | of shape (..., 3), where the magnitude is the angle 137 | turned anticlockwise in radians around the vector's 138 | direction. 139 | """ 140 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) 141 | half_angles = torch.atan2(norms, quaternions[..., :1]) 142 | angles = 2 * half_angles 143 | eps = 1e-6 144 | small_angles = angles.abs() < eps 145 | sin_half_angles_over_angles = torch.empty_like(angles) 146 | sin_half_angles_over_angles[~small_angles] = ( 147 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 148 | ) 149 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 150 | # so sin(x/2)/x is about 1/2 - (x*x)/48 151 | sin_half_angles_over_angles[small_angles] = ( 152 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 153 | ) 154 | return quaternions[..., 1:] / sin_half_angles_over_angles 155 | 156 | 157 | def axis_angle_to_quaternion(axis_angle): 158 | """ 159 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html 160 | Convert rotations given as axis/angle to quaternions. 161 | Args: 162 | axis_angle: Rotations given as a vector in axis angle form, 163 | as a tensor of shape (..., 3), where the magnitude is 164 | the angle turned anticlockwise in radians around the 165 | vector's direction. 166 | Returns: 167 | quaternions with real part first, as tensor of shape (..., 4). 168 | """ 169 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 170 | half_angles = 0.5 * angles 171 | eps = 1e-6 172 | small_angles = angles.abs() < eps 173 | sin_half_angles_over_angles = torch.empty_like(angles) 174 | sin_half_angles_over_angles[~small_angles] = ( 175 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 176 | ) 177 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 178 | # so sin(x/2)/x is about 1/2 - (x*x)/48 179 | sin_half_angles_over_angles[small_angles] = ( 180 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 181 | ) 182 | quaternions = torch.cat( 183 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 184 | ) 185 | return quaternions 186 | 187 | 188 | def axis_angle_to_matrix(axis_angle): 189 | """ 190 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html 191 | Convert rotations given as axis/angle to rotation matrices. 192 | Args: 193 | axis_angle: Rotations given as a vector in axis angle form, 194 | as a tensor of shape (..., 3), where the magnitude is 195 | the angle turned anticlockwise in radians around the 196 | vector's direction. 197 | Returns: 198 | Rotation matrices as tensor of shape (..., 3, 3). 
199 | """ 200 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) 201 | 202 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: 203 | """ 204 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 205 | using Gram--Schmidt orthogonalization per Section B of [1]. 206 | Args: 207 | d6: 6D rotation representation, of size (*, 6) 208 | 209 | Returns: 210 | batch of rotation matrices of size (*, 3, 3) 211 | 212 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 213 | On the Continuity of Rotation Representations in Neural Networks. 214 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 215 | Retrieved from http://arxiv.org/abs/1812.07035 216 | """ 217 | 218 | a1, a2 = d6[..., :3], d6[..., 3:] 219 | b1 = F.normalize(a1, dim=-1) 220 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 221 | b2 = F.normalize(b2, dim=-1) 222 | b3 = torch.cross(b1, b2, dim=-1) 223 | return torch.stack((b1, b2, b3), dim=-2) 224 | 225 | def rigid_transform_Kabsch_3D_torch(A, B): 226 | # R = 3x3 rotation matrix, t = 3x1 column vector 227 | # This already takes residue identity into account. 228 | 229 | assert A.shape[1] == B.shape[1] 230 | num_rows, num_cols = A.shape 231 | if num_rows != 3: 232 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 233 | num_rows, num_cols = B.shape 234 | if num_rows != 3: 235 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 236 | 237 | 238 | # find mean column wise: 3 x 1 239 | centroid_A = torch.mean(A, axis=1, keepdims=True) 240 | centroid_B = torch.mean(B, axis=1, keepdims=True) 241 | 242 | # subtract mean 243 | Am = A - centroid_A 244 | Bm = B - centroid_B 245 | 246 | H = Am @ Bm.T 247 | 248 | # find rotation 249 | U, S, Vt = torch.linalg.svd(H) 250 | 251 | R = Vt.T @ U.T 252 | # special reflection case 253 | if torch.linalg.det(R) < 0: 254 | # print("det(R) < R, reflection detected!, correcting for it ...") 255 | SS = torch.diag(torch.tensor([1.,1.,-1.], device=A.device)) 256 | R = (Vt.T @ SS) @ U.T 257 | assert math.fabs(torch.linalg.det(R) - 1) < 3e-3 # note I had to change this error bound to be higher 258 | 259 | t = -R @ centroid_A + centroid_B 260 | return R, t 261 | 262 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: 263 | """ 264 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1] 265 | by dropping the last row. Note that 6D representation is not unique. 266 | Args: 267 | matrix: batch of rotation matrices of size (*, 3, 3) 268 | 269 | Returns: 270 | 6D rotation representation, of size (*, 6) 271 | 272 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 273 | On the Continuity of Rotation Representations in Neural Networks. 274 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 275 | Retrieved from http://arxiv.org/abs/1812.07035 276 | """ 277 | batch_dim = matrix.size()[:-2] 278 | return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) 279 | 280 | def axis_angle_to_rotation_6d(axis_angle): 281 | return matrix_to_rotation_6d(quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))) 282 | 283 | def rotation_6d_to_axis_angle(rotation_6d): 284 | return matrix_to_axis_angle(rotation_6d_to_matrix(rotation_6d)) 285 | 286 | def vector_to_skew_matrix(vectors: torch.Tensor) -> torch.Tensor: 287 | """ 288 | Map a vector into the corresponding skew matrix so(3) basis. 
289 | ``` 290 | [ 0 -z y] 291 | [x,y,z] -> [ z 0 -x] 292 | [ -y x 0] 293 | ``` 294 | 295 | Args: 296 | vectors (torch.Tensor): Batch of vectors to be mapped to skew matrices. 297 | 298 | Returns: 299 | torch.Tensor: Vectors in skew matrix representation. 300 | """ 301 | # Generate empty skew matrices. 302 | skew_matrices = torch.zeros((*vectors.shape, 3), device=vectors.device, dtype=vectors.dtype) 303 | 304 | # Populate positive values. 305 | skew_matrices[..., 2, 1] = vectors[..., 0] 306 | skew_matrices[..., 0, 2] = vectors[..., 1] 307 | skew_matrices[..., 1, 0] = vectors[..., 2] 308 | 309 | # Generate skew symmetry. 310 | skew_matrices = skew_matrices - skew_matrices.transpose(-2, -1) 311 | 312 | return skew_matrices 313 | 314 | def skew_matrix_to_vector(skew_matrices: torch.Tensor) -> torch.Tensor: 315 | """ 316 | Extract a rotation vector from the so(3) skew matrix basis. 317 | 318 | Args: 319 | skew_matrices (torch.Tensor): Skew matrices. 320 | 321 | Returns: 322 | torch.Tensor: Rotation vectors corresponding to skew matrices. 323 | """ 324 | vectors = torch.zeros_like(skew_matrices[..., 0]) 325 | vectors[..., 0] = skew_matrices[..., 2, 1] 326 | vectors[..., 1] = skew_matrices[..., 0, 2] 327 | vectors[..., 2] = skew_matrices[..., 1, 0] 328 | return vectors 329 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def softmax_cross_entropy(logits, labels): 5 | loss = -1 * torch.sum( 6 | labels * torch.nn.functional.log_softmax(logits, dim=-1), 7 | dim=-1, 8 | ) 9 | return loss 10 | 11 | def _calculate_bin_centers(boundaries: torch.Tensor): 12 | step = boundaries[1] - boundaries[0] 13 | bin_centers = boundaries + step / 2 14 | bin_centers = torch.cat( 15 | [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0 16 | ) 17 | return bin_centers 18 | 19 | def compute_tm( 20 | logits: torch.Tensor, 21 | max_bin: int = 31, 22 | no_bins: int = 64, 23 | eps: float = 1e-8, 24 | ) -> torch.Tensor: 25 | boundaries = torch.linspace( 26 | 0, max_bin, steps=(no_bins - 1), device=logits.device 27 | ) 28 | 29 | bin_centers = _calculate_bin_centers(boundaries) 30 | clipped_n = max(logits.size(0) + logits.size(1), 19) 31 | 32 | d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 33 | 34 | probs = torch.nn.functional.softmax(logits, dim=-1) 35 | 36 | tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) 37 | predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) 38 | 39 | max_sum = max(torch.mean(predicted_tm_term, dim=0).max(), torch.mean(predicted_tm_term, dim=1).max()) 40 | 41 | return max_sum 42 | 43 | def get_tm_loss( 44 | logits, 45 | sq_diff, 46 | max_bin=31, 47 | no_bins=64, 48 | ): 49 | sq_diff = sq_diff.detach() 50 | 51 | boundaries = torch.linspace( 52 | 0, max_bin, steps=(no_bins - 1), device=logits.device 53 | ) 54 | boundaries = boundaries ** 2 55 | true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1) 56 | 57 | errors = softmax_cross_entropy( 58 | logits, torch.nn.functional.one_hot(true_bins, no_bins) 59 | ) 60 | 61 | loss = torch.mean(errors) 62 | 63 | return loss 64 | 65 | def distogram_loss( 66 | logits, 67 | dists, 68 | min_bin=3.25, 69 | max_bin=50.75, 70 | no_bins=64, 71 | eps=1e-6, 72 | **kwargs, 73 | ) -> torch.Tensor: 74 | """ 75 | """ 76 | boundaries = torch.linspace( 77 | min_bin, 78 | max_bin, 79 | no_bins - 1, 80 | device=logits.device, 81 | ) 82 | boundaries = boundaries ** 2 83 | 84 | true_bins = 
torch.sum(dists ** 2 > boundaries, dim=-1) 85 | 86 | errors = softmax_cross_entropy( 87 | logits, 88 | torch.nn.functional.one_hot(true_bins, no_bins), 89 | ) 90 | 91 | loss = torch.mean(errors) 92 | 93 | return loss -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compute_metrics(model, native): 4 | # get inputs 5 | model_rec = model[0].squeeze() 6 | model_lig = model[1].squeeze() 7 | native_rec = native[0].squeeze() 8 | native_lig = native[1].squeeze() 9 | 10 | # calc metrics 11 | c_rmsd = get_c_rmsd(model_rec, model_lig, native_rec, native_lig) 12 | i_rmsd = get_i_rmsd(model_rec, model_lig, native_rec, native_lig) 13 | l_rmsd = get_l_rmsd(model_rec, model_lig, native_rec, native_lig) 14 | fnat = get_fnat(model_rec, model_lig, native_rec, native_lig) 15 | DockQ = get_DockQ(i_rmsd, l_rmsd, fnat) 16 | return {'c_rmsd': c_rmsd, 'i_rmsd': i_rmsd, 'l_rmsd': l_rmsd, 'fnat': fnat, 'DockQ': DockQ} 17 | 18 | def get_interface_res(x1, x2, cutoff=10.0): 19 | # Calculate pairwise distances 20 | dist = x1[..., None, :, None, :] - x2[..., None, :, None, :, :] 21 | dist = (dist ** 2).sum(dim=-1).sqrt().flatten(start_dim=-2) 22 | 23 | # Find minimum distance between each pair of residues 24 | min_dist, _ = torch.min(dist, dim=-1) 25 | 26 | # Find index < cutoff 27 | index = torch.where(min_dist < cutoff) 28 | res1 = torch.unique(index[0]) 29 | res2 = torch.unique(index[1]) 30 | return res1, res2 31 | 32 | def get_c_rmsd(model_rec, model_lig, native_rec, native_lig): 33 | pred = torch.cat([model_rec, model_lig], dim=0).flatten(end_dim=1) 34 | label = torch.cat([native_rec, native_lig], dim=0).flatten(end_dim=1) 35 | R, t = find_rigid_alignment(pred, label) 36 | pred = (R.mm(pred.T)).T + t 37 | return get_rmsd(pred, label).item() 38 | 39 | def get_i_rmsd(model_rec, model_lig, native_rec, native_lig): 40 | res1, res2 = get_interface_res(native_rec, native_lig, cutoff=10.0) 41 | pred = torch.cat([model_rec[res1], model_lig[res2]], dim=0).flatten(end_dim=1) 42 | label = torch.cat([native_rec[res1], native_lig[res2]], dim=0).flatten(end_dim=1) 43 | R, t = find_rigid_alignment(pred, label) 44 | pred = (R.mm(pred.T)).T + t 45 | return get_rmsd(pred, label).item() 46 | 47 | def get_l_rmsd(model_rec, model_lig, native_rec, native_lig): 48 | model_rec = model_rec.flatten(end_dim=1) 49 | model_lig = model_lig.flatten(end_dim=1) 50 | native_rec = native_rec.flatten(end_dim=1) 51 | native_lig = native_lig.flatten(end_dim=1) 52 | R, t = find_rigid_alignment(model_rec, native_rec) 53 | model_lig = (R.mm(model_lig.T)).T + t 54 | return get_rmsd(model_lig, native_lig).item() 55 | 56 | def get_fnat(model_rec, model_lig, native_rec, native_lig, cutoff=5.5): 57 | ligand_receptor_distance = get_dist(native_rec, native_lig) 58 | positive_tuple = torch.where(ligand_receptor_distance < cutoff) 59 | active_receptor = positive_tuple[0] 60 | active_ligand = positive_tuple[1] 61 | assert len(active_ligand) == len(active_receptor) 62 | ligand_receptor_distance_pred = get_dist(model_rec, model_lig) 63 | selected_elements = ligand_receptor_distance_pred[active_receptor, active_ligand] 64 | count = torch.count_nonzero(selected_elements < cutoff) 65 | Fnat = round(count.item() / (len(active_ligand) + 1e-6), 6) 66 | return Fnat 67 | 68 | def get_DockQ(i_rmsd, l_rmsd, fnat): 69 | i_rmsd_scaled = 1.0 / (1.0 + (i_rmsd/1.5)**2) 70 | l_rmsd_scaled = 1.0 / (1.0 + (l_rmsd/8.5)**2) 
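    # Added note: this matches the DockQ definition of Basu & Wallner (2016),
    # DockQ = (Fnat + scaled(i_rmsd, 1.5) + scaled(l_rmsd, 8.5)) / 3,
    # where scaled(r, d) = 1 / (1 + (r / d)^2) maps an RMSD in Angstroms into (0, 1].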
71 | return (fnat + i_rmsd_scaled + l_rmsd_scaled) / 3 72 | 73 | def get_dist(x1, x2): 74 | # Calculate pairwise distances 75 | dist = x1[..., None, :, None, :] - x2[..., None, :, None, :, :] 76 | dist = (dist ** 2).sum(dim=-1).sqrt().flatten(start_dim=-2) 77 | 78 | # Find minimum distance between each pair of residues 79 | min_dist, _ = torch.min(dist, dim=-1) 80 | 81 | return min_dist 82 | 83 | def get_rmsd(pred, label): 84 | rmsd = torch.sqrt(torch.mean(torch.sum((pred - label) ** 2.0, dim=-1))) 85 | return rmsd 86 | 87 | def find_rigid_alignment(A, B): 88 | """ 89 | align A to B 90 | See: https://en.wikipedia.org/wiki/Kabsch_algorithm 91 | 2-D or 3-D registration with known correspondences. 92 | Registration occurs in the zero centered coordinate system, and then 93 | must be transported back. 94 | Args: 95 | - A: Torch tensor of shape (N,D) -- Point Cloud to Align (source) 96 | - B: Torch tensor of shape (N,D) -- Reference Point Cloud (target) 97 | Returns: 98 | - R: optimal rotation 99 | - t: optimal translation 100 | Test on rotation + translation and on rotation + translation + reflection 101 | """ 102 | a_mean = A.mean(axis=0) 103 | b_mean = B.mean(axis=0) 104 | A_c = A - a_mean 105 | B_c = B - b_mean 106 | # Covariance matrix 107 | H = A_c.T.mm(B_c) 108 | U, S, Vt = torch.linalg.svd(H) 109 | # Rotation matrix 110 | R = Vt.T.mm(U.T) 111 | 112 | # special reflection case 113 | if torch.linalg.det(R) < 0: 114 | # print("det(R) < R, reflection detected!, correcting for it ...") 115 | SS = torch.diag(torch.tensor([1.,1.,-1.], device=R.device)) 116 | R = (Vt.T @ SS) @ U.T 117 | 118 | # Translation vector 119 | t = b_mean[None, :] - R.mm(a_mean[None, :].T).T 120 | t = t.T 121 | return R, t.squeeze() 122 | 123 | -------------------------------------------------------------------------------- /src/utils/pdb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | _aa_1_3_dict = { 5 | 'A': 'ALA', 6 | 'C': 'CYS', 7 | 'D': 'ASP', 8 | 'E': 'GLU', 9 | 'F': 'PHE', 10 | 'G': 'GLY', 11 | 'H': 'HIS', 12 | 'I': 'ILE', 13 | 'K': 'LYS', 14 | 'L': 'LEU', 15 | 'M': 'MET', 16 | 'N': 'ASN', 17 | 'P': 'PRO', 18 | 'Q': 'GLN', 19 | 'R': 'ARG', 20 | 'S': 'SER', 21 | 'T': 'THR', 22 | 'V': 'VAL', 23 | 'W': 'TRP', 24 | 'Y': 'TYR', 25 | '-': 'GAP', 26 | 'X': 'URI', 27 | } 28 | 29 | 30 | 31 | def place_fourth_atom(a_coord: torch.Tensor, b_coord: torch.Tensor, 32 | c_coord: torch.Tensor, length: torch.Tensor, 33 | planar: torch.Tensor, 34 | dihedral: torch.Tensor) -> torch.Tensor: 35 | """ 36 | Given 3 coords + a length + a planar angle + a dihedral angle, compute a fourth coord 37 | """ 38 | bc_vec = b_coord - c_coord 39 | bc_vec = bc_vec / bc_vec.norm(dim=-1, keepdim=True) 40 | 41 | n_vec = (b_coord - a_coord).expand(bc_vec.shape).cross(bc_vec, dim=-1) 42 | n_vec = n_vec / n_vec.norm(dim=-1, keepdim=True) 43 | 44 | m_vec = [bc_vec, n_vec.cross(bc_vec, dim=-1), n_vec] 45 | d_vec = [ 46 | length * torch.cos(planar), 47 | length * torch.sin(planar) * torch.cos(dihedral), 48 | -length * torch.sin(planar) * torch.sin(dihedral) 49 | ] 50 | 51 | d_coord = c_coord + sum([m * d for m, d in zip(m_vec, d_vec)]) 52 | return d_coord 53 | 54 | 55 | def save_PDB(out_pdb: str, 56 | coords: torch.Tensor, 57 | seq: str, 58 | b_factors: torch.Tensor = None, 59 | delim: int = None) -> None: 60 | """ 61 | Write set of N, CA, C, O, CB coords to PDB file 62 | """ 63 | 64 | if type(delim) == type(None): 65 | delim = -1 66 | 67 | if b_factors is None: 68 | b_factors = 
torch.zeros(coords.size(0), device=coords.device) 69 | 70 | atoms = ['N', 'CA', 'C', 'O', 'CB'] 71 | 72 | with open(out_pdb, "a") as f: 73 | k = 0 74 | for r, residue in enumerate(coords): 75 | AA = _aa_1_3_dict[seq[r]] 76 | for a, atom in enumerate(residue): 77 | if AA == "GLY" and atoms[a] == "CB": continue 78 | x, y, z = atom 79 | f.write( 80 | "ATOM %5d %-2s %3s %s%4d %8.3f%8.3f%8.3f %4.2f %4.2f\n" 81 | % (k + 1, atoms[a], AA, "A" if r <= delim else "B", r + 1, 82 | x, y, z, 1, b_factors[r])) 83 | k += 1 84 | f.close() 85 | 86 | 87 | def save_PDB_3(out_pdb: str, 88 | coords: torch.Tensor, 89 | seq: str, 90 | delim: int = None) -> None: 91 | """ 92 | Write set of N, CA, C coords to PDB file 93 | """ 94 | 95 | if type(delim) == type(None): 96 | delim = -1 97 | 98 | atoms = ['N', 'CA', 'C'] 99 | 100 | with open(out_pdb, "w") as f: 101 | k = 0 102 | for r, residue in enumerate(coords): 103 | AA = _aa_1_3_dict[seq[r]] 104 | for a, atom in enumerate(residue): 105 | x, y, z = atom 106 | f.write( 107 | "ATOM %5d %-2s %3s %s%4d %8.3f%8.3f%8.3f %4.2f\n" 108 | % (k + 1, atoms[a], AA, "A" if r <= delim else "B", r + 1, 109 | x, y, z, 1)) 110 | k += 1 111 | f.close() -------------------------------------------------------------------------------- /src/utils/r3_diffuser.py: -------------------------------------------------------------------------------- 1 | # The code is adopted from: 2 | # https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/r3_diffuser.py 3 | 4 | import numpy as np 5 | import torch 6 | 7 | #---------------------------------------------------------------------------- 8 | # Helper functions 9 | 10 | move_to_np = lambda x: x.cpu().detach().numpy() 11 | 12 | #---------------------------------------------------------------------------- 13 | # VE-SDE diffuser class for diffusion in 3D translational (R(3)) space 14 | 15 | class R3Diffuser: 16 | def __init__(self, conf): 17 | self.min_sigma = conf.min_sigma 18 | self.max_sigma = conf.max_sigma 19 | 20 | def sigma(self, t): 21 | return self.min_sigma * (self.max_sigma / self.min_sigma) ** t 22 | 23 | def diffusion_coef(self, t): 24 | return self.sigma(t) * np.sqrt(2 * (np.log(self.max_sigma) - np.log(self.min_sigma))) 25 | 26 | def torch_score(self, tr_t, t): 27 | return -tr_t / self.sigma(t)**2 28 | 29 | def score_scaling(self, t: float): 30 | return 1 / self.sigma(t) 31 | 32 | def forward_marginal(self, t: float): 33 | if not np.isscalar(t): 34 | raise ValueError(f'{t} must be a scalar.') 35 | z = np.random.randn(1, 3) 36 | tr_t = self.sigma(t) * z 37 | tr_score = self.torch_score(tr_t, t) 38 | return tr_t, tr_score 39 | 40 | def torch_reverse( 41 | self, 42 | score_t: torch.tensor, 43 | dt: torch.tensor, 44 | t: float, 45 | noise_scale: float=1.0, 46 | ode: bool=False, 47 | ): 48 | if not np.isscalar(t): raise ValueError(f'{t} must be a scalar.') 49 | g_t = self.diffusion_coef(t) 50 | if not ode: 51 | z = noise_scale * torch.randn(1, 3, device=score_t.device) 52 | perturb = (g_t ** 2) * score_t * dt + g_t * torch.sqrt(dt) * z 53 | else: 54 | perturb = 0.5 * (g_t ** 2) * score_t * dt 55 | return perturb.float() 56 | 57 | -------------------------------------------------------------------------------- /src/utils/so3_diffuser.py: -------------------------------------------------------------------------------- 1 | # The code is adopted from: 2 | # https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/so3_diffuser.py 3 | 4 | import numpy as np 5 | import os 6 | import logging 7 | import torch 8 | from scipy.spatial.transform 
import Rotation 9 | 10 | #---------------------------------------------------------------------------- 11 | # Helper functions 12 | 13 | move_to_np = lambda x: x.cpu().detach().numpy() 14 | 15 | def rotvec_to_matrix(rotvec): 16 | return Rotation.from_rotvec(rotvec).as_matrix() 17 | 18 | def matrix_to_rotvec(mat): 19 | return Rotation.from_matrix(mat).as_rotvec() 20 | 21 | def compose_rotvec(r1, r2): 22 | """Compose two Euler rotation vectors.""" 23 | R1 = rotvec_to_matrix(r1) 24 | R2 = rotvec_to_matrix(r2) 25 | cR = np.einsum('...ij,...jk->...ik', R1, R2) 26 | return matrix_to_rotvec(cR) 27 | 28 | def igso3_expansion(omega, eps, L=1000, use_torch=False): 29 | """Truncated sum of IGSO(3) distribution. 30 | 31 | This function approximates the power series in equation 5 of 32 | "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL 33 | ALIGNMENT" 34 | Leach et al. 2022 35 | 36 | This expression diverges from the expression in Leach in that here, eps = 37 | sqrt(2) * eps_leach, if eps_leach were the scale parameter of the IGSO(3). 38 | 39 | With this reparameterization, IGSO(3) agrees with the Brownian motion on 40 | SO(3) with t=eps^2. 41 | 42 | Args: 43 | omega: rotation of Euler vector (i.e. the angle of rotation) 44 | eps: std of IGSO(3). 45 | L: Truncation level 46 | use_torch: set true to use torch tensors, otherwise use numpy arrays. 47 | """ 48 | 49 | lib = torch if use_torch else np 50 | ls = lib.arange(L) 51 | if use_torch: 52 | ls = ls.to(omega.device) 53 | if len(omega.shape) == 2: 54 | # Used during predicted score calculation. 55 | ls = ls[None, None] # [1, 1, L] 56 | omega = omega[..., None] # [num_batch, num_res, 1] 57 | eps = eps[..., None] 58 | elif len(omega.shape) == 1: 59 | # Used during cache computation. 60 | ls = ls[None] # [1, L] 61 | omega = omega[..., None] # [num_batch, 1] 62 | else: 63 | raise ValueError("Omega must be 1D or 2D.") 64 | p = (2*ls + 1) * lib.exp(-ls*(ls+1)*eps**2/2) * lib.sin(omega*(ls+1/2)) / lib.sin(omega/2) 65 | if use_torch: 66 | return p.sum(dim=-1) 67 | else: 68 | return p.sum(axis=-1) 69 | 70 | def density(expansion, omega, marginal=True): 71 | """IGSO(3) density. 72 | 73 | Args: 74 | expansion: truncated approximation of the power series in the IGSO(3) 75 | density. 76 | omega: length of an Euler vector (i.e. angle of rotation) 77 | marginal: set true to give marginal density over the angle of rotation, 78 | otherwise include normalization to give density on SO(3) or a 79 | rotation with angle omega. 80 | """ 81 | if marginal: 82 | # if marginal, density over [0, pi], else over SO(3) 83 | return expansion * (1-np.cos(omega))/np.pi 84 | else: 85 | # the constant factor doesn't affect any actual calculations though 86 | return expansion / 8 / np.pi**2 87 | 88 | def score(exp, omega, eps, L=1000, use_torch=False): # score of density over SO(3) 89 | """score uses the quotient rule to compute the scaling factor for the score 90 | of the IGSO(3) density. 91 | 92 | This function is used within the Diffuser class when computing the score 93 | as an element of the tangent space of SO(3). 94 | 95 | This uses the quotient rule of calculus, and takes the derivative of the 96 | log: 97 | d hi(x)/lo(x) = (lo(x) d hi(x)/dx - hi(x) d lo(x)/dx) / lo(x)^2 98 | and 99 | d log expansion(x) / dx = (d expansion(x)/ dx) / expansion(x) 100 | 101 | Args: 102 | exp: truncated expansion of the power series in the IGSO(3) density 103 | omega: length of an Euler vector (i.e.
angle of rotation) 104 | eps: scale parameter for IGSO(3) -- as in expansion() this scaling 105 | differs from that in Leach by a factor of sqrt(2). 106 | L: truncation level 107 | use_torch: set true to use torch tensors, otherwise use numpy arrays. 108 | 109 | Returns: 110 | The d/d omega log IGSO3(omega; eps)/(1-cos(omega)) 111 | 112 | """ 113 | 114 | lib = torch if use_torch else np 115 | ls = lib.arange(L) 116 | if use_torch: 117 | ls = ls.to(omega.device) 118 | ls = ls[None] 119 | if len(omega.shape) == 2: 120 | ls = ls[None] 121 | elif len(omega.shape) > 2: 122 | raise ValueError("Omega must be 1D or 2D.") 123 | omega = omega[..., None] 124 | eps = eps[..., None] 125 | hi = lib.sin(omega * (ls + 1 / 2)) 126 | dhi = (ls + 1 / 2) * lib.cos(omega * (ls + 1 / 2)) 127 | lo = lib.sin(omega / 2) 128 | dlo = 1 / 2 * lib.cos(omega / 2) 129 | dSigma = (2 * ls + 1) * lib.exp(-ls * (ls + 1) * eps**2/2) * (lo * dhi - hi * dlo) / lo ** 2 130 | if use_torch: 131 | dSigma = dSigma.sum(dim=-1) 132 | else: 133 | dSigma = dSigma.sum(axis=-1) 134 | return dSigma / (exp + 1e-4) 135 | 136 | #---------------------------------------------------------------------------- 137 | # VE-SDE diffuser class for diffusion in 3D rotational (SO(3)) space 138 | 139 | class SO3Diffuser: 140 | def __init__(self, so3_conf): 141 | self.schedule = so3_conf.schedule 142 | 143 | self.min_sigma = so3_conf.min_sigma 144 | self.max_sigma = so3_conf.max_sigma 145 | 146 | self.num_sigma = so3_conf.num_sigma 147 | self.use_cached_score = so3_conf.use_cached_score 148 | self._log = logging.getLogger(__name__) 149 | 150 | # Discretize omegas for calculating CDFs. Skip omega=0. 151 | self.discrete_omega = np.linspace(0, np.pi, so3_conf.num_omega+1)[1:] 152 | 153 | # Precompute IGSO3 values. 154 | replace_period = lambda x: str(x).replace('.', '_') 155 | cache_dir = os.path.join( 156 | so3_conf.cache_dir, 157 | f'eps_{so3_conf.num_sigma}_omega_{so3_conf.num_omega}_min_sigma_{replace_period(so3_conf.min_sigma)}_max_sigma_{replace_period(so3_conf.max_sigma)}_schedule_{so3_conf.schedule}' 158 | ) 159 | 160 | # If cache directory doesn't exist, create it 161 | if not os.path.isdir(cache_dir): 162 | os.makedirs(cache_dir) 163 | pdf_cache = os.path.join(cache_dir, 'pdf_vals.npy') 164 | cdf_cache = os.path.join(cache_dir, 'cdf_vals.npy') 165 | score_norms_cache = os.path.join(cache_dir, 'score_norms.npy') 166 | 167 | if os.path.exists(pdf_cache) and os.path.exists(cdf_cache) and os.path.exists(score_norms_cache): 168 | self._log.info(f'Using cached IGSO3 in {cache_dir}') 169 | self._pdf = np.load(pdf_cache) 170 | self._cdf = np.load(cdf_cache) 171 | self._score_norms = np.load(score_norms_cache) 172 | else: 173 | self._log.info(f'Computing IGSO3. Saving in {cache_dir}') 174 | # compute the expansion of the power series 175 | exp_vals = np.asarray( 176 | [igso3_expansion(self.discrete_omega, sigma) for sigma in self.discrete_sigma]) 177 | # Compute the pdf and cdf values for the marginal distribution of the angle 178 | # of rotation (which is needed for sampling) 179 | self._pdf = np.asarray( 180 | [density(x, self.discrete_omega, marginal=True) for x in exp_vals]) 181 | self._cdf = np.asarray( 182 | [pdf.cumsum() / so3_conf.num_omega * np.pi for pdf in self._pdf]) 183 | 184 | # Compute the norms of the scores. These are used to scale the rotation axis when 185 | # computing the score as a vector.
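            # (Added note) Row i of _score_norms holds d/d(omega) log IGSO3(omega; sigma_i)
            # over the discrete omega grid; torch_score() below looks these values up and
            # multiplies by the unit rotation axis to form the full 3D score vector.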
186 | self._score_norms = np.asarray( 187 | [score(exp_vals[i], self.discrete_omega, x) for i, x in enumerate(self.discrete_sigma)]) 188 | 189 | # Cache the precomputed values 190 | np.save(pdf_cache, self._pdf) 191 | np.save(cdf_cache, self._cdf) 192 | np.save(score_norms_cache, self._score_norms) 193 | 194 | self._score_scaling = np.sqrt(np.abs( 195 | np.sum( 196 | self._score_norms**2 * self._pdf, axis=-1) / np.sum( 197 | self._pdf, axis=-1) 198 | )) / np.sqrt(3) 199 | 200 | @property 201 | def discrete_sigma(self): 202 | return self.sigma( 203 | np.linspace(0.0, 1.0, self.num_sigma) 204 | ) 205 | 206 | def sigma_idx(self, sigma: np.ndarray): 207 | """Calculates the index for discretized sigma during IGSO(3) initialization.""" 208 | return np.digitize(sigma, self.discrete_sigma) - 1 209 | 210 | def sigma(self, t: np.ndarray): 211 | """Extract sigma(t) corresponding to the chosen sigma schedule.""" 212 | if np.any(t < 0) or np.any(t > 1): 213 | raise ValueError(f'Invalid t={t}') 214 | if self.schedule == 'logarithmic': 215 | return np.log(t * np.exp(self.max_sigma) + (1 - t) * np.exp(self.min_sigma)) 216 | else: 217 | raise ValueError(f'Unrecognized schedule {self.schedule}') 218 | 219 | def diffusion_coef(self, t): 220 | """Compute diffusion coefficient (g_t).""" 221 | if self.schedule == 'logarithmic': 222 | g_t = np.sqrt( 223 | 2 * (np.exp(self.max_sigma) - np.exp(self.min_sigma)) * self.sigma(t) / np.exp(self.sigma(t)) 224 | ) 225 | else: 226 | raise ValueError(f'Unrecognized schedule {self.schedule}') 227 | return g_t 228 | 229 | def t_to_idx(self, t: np.ndarray): 230 | """Helper function to go from time t to corresponding sigma_idx.""" 231 | return self.sigma_idx(self.sigma(t)) 232 | 233 | def sample_igso3( 234 | self, 235 | t: float, 236 | n_samples: int=1): 237 | """Uses the inverse cdf to sample an angle of rotation from IGSO(3). 238 | 239 | Args: 240 | t: continuous time in [0, 1]. 241 | n_samples: number of samples to draw. 242 | 243 | Returns: 244 | [n_samples] angles of rotation. 245 | """ 246 | if not np.isscalar(t): 247 | raise ValueError(f'{t} must be a scalar.') 248 | x = np.random.rand(n_samples) 249 | return np.interp(x, self._cdf[self.t_to_idx(t)], self.discrete_omega) 250 | 251 | def sample( 252 | self, 253 | t: float, 254 | n_samples: int=1): 255 | """Generates rotation vector(s) from IGSO(3). 256 | 257 | Args: 258 | t: continuous time in [0, 1]. 259 | n_samples: number of samples to generate. 260 | 261 | Returns: 262 | [n_samples, 3] axis-angle rotation vectors sampled from IGSO(3). 263 | """ 264 | x = np.random.randn(n_samples, 3) 265 | x /= np.linalg.norm(x, axis=-1, keepdims=True) 266 | return x * self.sample_igso3(t, n_samples=n_samples)[:, None] 267 | 268 | def score( 269 | self, 270 | vec: np.ndarray, 271 | t: float, 272 | eps: float=1e-6 273 | ): 274 | """Computes the score of IGSO(3) density as a rotation vector. 275 | 276 | Args: 277 | vec: [..., 3] array of axis-angle rotation vectors. 278 | t: continuous time in [0, 1]. 279 | 280 | Returns: 281 | [..., 3] score vector in the direction of the sampled vector with 282 | magnitude given by _score_norms. 283 | """ 284 | if not np.isscalar(t): 285 | raise ValueError(f'{t} must be a scalar.') 286 | torch_score = self.torch_score(torch.tensor(vec), torch.tensor(t)[None]) 287 | return torch_score.numpy() 288 | 289 | def torch_score( 290 | self, 291 | vec: torch.tensor, 292 | t: torch.tensor, 293 | eps: float=1e-6, 294 | ): 295 | """Computes the score of IGSO(3) density as a rotation vector.
296 | 297 | Same as score function but uses pytorch and performs a look-up. 298 | 299 | Args: 300 | vec: [..., 3] array of axis-angle rotation vectors. 301 | t: continuous time in [0, 1]. 302 | 303 | Returns: 304 | [..., 3] score vector in the direction of the sampled vector with 305 | magnitude given by _score_norms. 306 | """ 307 | omega = torch.linalg.norm(vec, dim=-1) + eps 308 | if self.use_cached_score: 309 | score_norms_t = self._score_norms[self.t_to_idx(move_to_np(t))] 310 | score_norms_t = torch.tensor(score_norms_t).to(vec.device) 311 | omega_idx = torch.bucketize( 312 | omega, torch.tensor(self.discrete_omega[:-1]).to(vec.device)) 313 | omega_scores_t = torch.gather( 314 | score_norms_t, 1, omega_idx) 315 | else: 316 | sigma = self.discrete_sigma[self.t_to_idx(move_to_np(t))] 317 | sigma = torch.tensor(sigma).to(vec.device) 318 | omega_vals = igso3_expansion(omega, sigma[:, None], use_torch=True) 319 | omega_scores_t = score(omega_vals, omega, sigma[:, None], use_torch=True) 320 | return omega_scores_t * vec / (omega[..., None] + eps) 321 | 322 | def score_scaling(self, t: np.ndarray): 323 | """Calculates scaling used for scores during training.""" 324 | return self._score_scaling[self.t_to_idx(t)] 325 | 326 | def forward_marginal(self, t: float, rot_0: np.ndarray = np.zeros((1, 3))): 327 | """Samples from the forward diffusion process at time index t. 328 | 329 | Args: 330 | rot_0: [..., 3] initial rotations. 331 | t: continuous time in [0, 1]. 332 | 333 | Returns: 334 | rot_t: [..., 3] noised rotation vectors. 335 | rot_score: [..., 3] score of rot_t as a rotation vector. 336 | """ 337 | sampled_rots = self.sample(t, n_samples=1) 338 | rot_score = self.score(sampled_rots, t).reshape(rot_0.shape) 339 | 340 | # Right multiply. 341 | rot_t = compose_rotvec(rot_0, sampled_rots).reshape(rot_0.shape) 342 | return rot_t, rot_score 343 | 344 | def torch_reverse( 345 | self, 346 | score_t: torch.tensor, 347 | dt: torch.tensor, 348 | t: float, 349 | noise_scale: float=1.0, 350 | ode: bool=False, 351 | ): 352 | """Simulates the reverse SDE for 1 step using the Geodesic random walk. 353 | 354 | Args: 355 | score_t: [..., 3] rotation score at time t. 356 | t: continuous time in [0, 1]. 357 | dt: continuous step size in [0, 1]. 358 | 359 | Returns: 360 | [..., 3] rotation vector at next step.
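        Note (added): the stochastic update below is the Euler-Maruyama step
        g(t)^2 * score_t * dt + g(t) * sqrt(dt) * z with z ~ N(0, I) scaled by
        noise_scale; with ode=True the deterministic probability-flow step
        0.5 * g(t)^2 * score_t * dt is taken instead.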
361 | """ 362 | if not np.isscalar(t): raise ValueError(f'{t} must be a scalar.') 363 | g_t = self.diffusion_coef(t) 364 | if not ode: 365 | z = noise_scale * torch.randn(1, 3, device=score_t.device) 366 | perturb = (g_t ** 2) * score_t * dt + g_t * torch.sqrt(dt) * z 367 | else: 368 | perturb = 0.5 * (g_t ** 2) * score_t * dt 369 | return perturb.float() 370 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import List, Sequence 4 | 5 | import pytorch_lightning as pl 6 | import rich.syntax 7 | import rich.tree 8 | from omegaconf import DictConfig, OmegaConf 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | 12 | def get_logger(name=__name__) -> logging.Logger: 13 | """Initializes multi-GPU-friendly python command line logger.""" 14 | 15 | logger = logging.getLogger(name) 16 | 17 | # this ensures all logging levels get marked with the rank zero decorator 18 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 19 | for level in ( 20 | "debug", 21 | "info", 22 | "warning", 23 | "error", 24 | "exception", 25 | "fatal", 26 | "critical", 27 | ): 28 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 29 | 30 | return logger 31 | 32 | 33 | def extras(config: DictConfig) -> None: 34 | """A couple of optional utilities, controlled by main config file: 35 | - disabling warnings 36 | - forcing debug friendly configuration 37 | - verifying experiment name is set when running in experiment mode 38 | Modifies DictConfig in place. 39 | Args: 40 | config (DictConfig): Configuration composed by Hydra. 41 | """ 42 | 43 | log = get_logger(__name__) 44 | 45 | # disable python warnings if config.ignore_warnings is set 46 | if config.get("ignore_warnings"): 47 | log.info("Disabling python warnings! ") 48 | warnings.filterwarnings("ignore") 49 | 50 | # verify experiment name is set when running in experiment mode 51 | if config.get("experiment_mode") and not config.get("name"): 52 | log.info( 53 | "Running in experiment mode without the experiment name specified! " 54 | "Use `python run.py mode=exp name=experiment_name`" 55 | ) 56 | log.info("Exiting...") 57 | exit() 58 | 59 | # force debugger friendly configuration if config.trainer.fast_dev_run is set 60 | # debuggers don't like GPUs and multiprocessing 61 | if config.trainer.get("fast_dev_run"): 62 | log.info("Forcing debugger friendly configuration! ") 63 | if config.trainer.get("gpus"): 64 | config.trainer.gpus = 0 65 | if config.datamodule.get("pin_memory"): 66 | config.datamodule.pin_memory = False 67 | if config.datamodule.get("num_workers"): 68 | config.datamodule.num_workers = 0 69 | 70 | 71 | @rank_zero_only 72 | def print_config( 73 | config: DictConfig, 74 | fields: Sequence[str] = ( 75 | "trainer", 76 | "model", 77 | "datamodule", 78 | "callbacks", 79 | "logger", 80 | "test_after_training", 81 | "seed", 82 | "name", 83 | ), 84 | resolve: bool = True, 85 | ) -> None: 86 | """Prints content of DictConfig using Rich library and its tree structure. 87 | Args: 88 | config (DictConfig): Configuration composed by Hydra. 89 | fields (Sequence[str], optional): Determines which main fields from config will 90 | be printed and in what order. 91 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
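    Example (added, hypothetical usage): print_config(config, fields=("trainer", "model"), resolve=True)
    prints only the trainer and model subtrees with interpolations resolved.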
92 | """ 93 | 94 | style = "dim" 95 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 96 | 97 | for field in fields: 98 | branch = tree.add(field, style=style, guide_style=style) 99 | 100 | config_section = config.get(field) 101 | branch_content = str(config_section) 102 | if isinstance(config_section, DictConfig): 103 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 104 | 105 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 106 | 107 | rich.print(tree) 108 | 109 | with open("config_tree.log", "w") as fp: 110 | rich.print(tree, file=fp) 111 | 112 | 113 | @rank_zero_only 114 | def log_hyperparameters( 115 | config: DictConfig, 116 | model: pl.LightningModule, 117 | datamodule: pl.LightningDataModule, 118 | trainer: pl.Trainer, 119 | callbacks: List[pl.Callback], 120 | logger: List[pl.loggers.Logger], 121 | ) -> None: 122 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 123 | Additionally saves: 124 | - number of model parameters 125 | """ 126 | 127 | hparams = {} 128 | 129 | # choose which parts of hydra config will be saved to loggers 130 | hparams["trainer"] = config["trainer"] 131 | hparams["model"] = config["model"] 132 | hparams["datamodule"] = config["datamodule"] 133 | 134 | if "seed" in config: 135 | hparams["seed"] = config["seed"] 136 | if "callbacks" in config: 137 | hparams["callbacks"] = config["callbacks"] 138 | 139 | # save number of model parameters 140 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 141 | hparams["model/params/trainable"] = sum( 142 | p.numel() for p in model.parameters() if p.requires_grad 143 | ) 144 | hparams["model/params/non_trainable"] = sum( 145 | p.numel() for p in model.parameters() if not p.requires_grad 146 | ) 147 | 148 | # send hparams to all loggers 149 | trainer.logger.log_hyperparams(hparams) 150 | 151 | 152 | def finish( 153 | config: DictConfig, 154 | model: pl.LightningModule, 155 | datamodule: pl.LightningDataModule, 156 | trainer: pl.Trainer, 157 | callbacks: List[pl.Callback], 158 | logger: List[pl.loggers.Logger], 159 | ) -> None: 160 | """Makes sure everything is closed properly.""" 161 | 162 | # without this sweeps with wandb logger might crash!
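    # (Added explanation) wandb keeps the active run as process-global state;
    # wandb.finish() flushes and closes that run so the next job in a Hydra sweep
    # can start a fresh run in the same process.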
163 | for lg in logger: 164 | if isinstance(lg, pl.loggers.wandb.WandbLogger): 165 | import wandb 166 | 167 | wandb.finish() -------------------------------------------------------------------------------- /tests/test_biotite.py: -------------------------------------------------------------------------------- 1 | import biotite.structure as struc 2 | import biotite.structure.io as strucio 3 | import numpy as np 4 | from utils.residue_constants import restype_3to1 5 | 6 | file_path = '1A2K_r_b.pdb' # Replace with your PDB file path 7 | 8 | # Load the structure from the PDB file 9 | structure = strucio.load_structure(file_path) 10 | 11 | # Get the atomic coordinates as a NumPy array 12 | coordinates = structure.coord # Coordinates are stored in the 'coord' attribute 13 | 14 | # Get the residue names (three-letter codes) 15 | numbering, resn = struc.get_residues(structure) 16 | seq_list = [restype_3to1.get(three, "X") for three in resn] 17 | seq = ''.join(seq_list) 18 | 19 | 20 | def get_bb_coords(structure): 21 | # Filter atoms by names 'N', 'CA', 'C' 22 | n_atoms = structure[structure.atom_name == "N"] 23 | ca_atoms = structure[structure.atom_name == "CA"] 24 | c_atoms = structure[structure.atom_name == "C"] 25 | print(len(n_atoms)) 26 | print(len(ca_atoms)) 27 | print(len(c_atoms)) 28 | 29 | # Ensure that the number of N, CA, and C atoms are the same and correspond to residues 30 | n_res = min(len(n_atoms), len(ca_atoms), len(c_atoms)) 31 | 32 | # Create an array of shape (n_res, 3, 3) to hold [N, CA, C] for each residue 33 | coords = np.zeros((n_res, 3, 3)) 34 | 35 | # Assign coordinates for N, CA, and C atoms in the correct order 36 | coords[:, 0, :] = n_atoms.coord[:n_res] # N 37 | coords[:, 1, :] = ca_atoms.coord[:n_res] # CA 38 | coords[:, 2, :] = c_atoms.coord[:n_res] # C 39 | 40 | return coords 41 | 42 | bb_coords = get_bb_coords(structure) 43 | -------------------------------------------------------------------------------- /tests/test_gLM2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel, AutoTokenizer 3 | 4 | model = AutoModel.from_pretrained('tattabio/gLM2_650M', torch_dtype=torch.bfloat16, trust_remote_code=True) 5 | tokenizer = AutoTokenizer.from_pretrained('tattabio/gLM2_650M', trust_remote_code=True) 6 | 7 | # A contig with two proteins and an inter-genic sequence. 8 | # NOTE: Nucleotides should always be lowercase, and prepended with `<+>`. 9 | sequence = "<+>MALTKVEKRNRIKRRVRGKISGTQASPRLSVYKSNK<+>aatttaaggaa<->MLGIDNIERVKPGGLELVDRLVAVNRVTKVTKGGRAFGFSAIVVVGNED" 10 | 11 | # Tokenize the sequence. 12 | encodings = tokenizer([sequence], return_tensors='pt') 13 | 14 | # Extract embeddings. 
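# NOTE (added): encodings.input_ids is moved to CUDA below, but the model was loaded
# on CPU; the model must live on the same device (e.g. model = model.cuda()) or this
# forward pass will raise a device-mismatch error.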
15 | with torch.no_grad(): 16 | embeddings = model(encodings.input_ids.cuda(), output_hidden_states=True).last_hidden_state 17 | 18 | -------------------------------------------------------------------------------- /tests/test_pinder.py: -------------------------------------------------------------------------------- 1 | from pinder.core.index.utils import get_index 2 | 3 | 4 | pindex = get_index() 5 | testset = pindex.query('pinder_s == True') 6 | df = testset[['id', 'holo_R_pdb', 'holo_L_pdb']] 7 | 8 | 9 | # Define the path string to prepend 10 | path_string = "/scratch16/jgray21/lchu11/data/pinder/2024-02/test_set_pdbs/" 11 | 12 | # Prepend the path string to the holo_R_pdb and holo_L_pdb columns 13 | df['holo_R_pdb'] = path_string + df['holo_R_pdb'].astype(str) 14 | df['holo_L_pdb'] = path_string + df['holo_L_pdb'].astype(str) 15 | 16 | # Display the updated DataFrame 17 | print(df) 18 | 19 | # Save the DataFrame to a CSV file without header and index 20 | df.to_csv('pinder_s.csv', header=False, index=False) 21 | -------------------------------------------------------------------------------- /weights/pinder_0.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graylab/DFMDock/e2fd49910b4d153259816b01d0b73dc2ebf4314e/weights/pinder_0.ckpt --------------------------------------------------------------------------------