├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── models │ └── conditional │ │ ├── config.json │ │ ├── logs │ │ └── lightning_logs │ │ │ └── version_0 │ │ │ ├── hparams.yaml │ │ │ └── metrics.csv │ │ ├── models │ │ └── best_by_valid │ │ │ └── epoch=805-step=93496.ckpt │ │ ├── training_args.json │ │ ├── training_mean_angles.json │ │ ├── training_mean_distances.json │ │ ├── training_mean_offset.json │ │ └── training_std_angles.json └── overview.png ├── configs ├── conditional.json ├── minimal_conditional.json ├── minimal_unconditional.json └── wandb │ └── wandb.json ├── data ├── cremp │ ├── test.csv │ ├── test │ │ └── .gitkeep │ ├── train.csv │ └── train │ │ └── .gitkeep └── tetrapeptides │ ├── E.C.T.S.pickle │ ├── E.V.E.D.pickle │ ├── E.V.S.S.pickle │ ├── E.V.T.C.pickle │ ├── F.A.S.T.pickle │ ├── F.C.C.T.pickle │ ├── F.F.C.K.pickle │ ├── F.F.F.F.pickle │ ├── F.G.A.T.pickle │ ├── F.K.T.L.pickle │ ├── F.T.K.E.pickle │ ├── F.V.T.T.pickle │ ├── H.A.K.S.pickle │ ├── H.H.K.G.pickle │ ├── H.H.L.H.pickle │ ├── H.K.K.S.pickle │ ├── H.L.D.T.pickle │ ├── H.L.K.S.pickle │ ├── H.S.G.A.pickle │ ├── H.T.G.S.pickle │ ├── K.E.E.V.pickle │ ├── K.K.S.A.pickle │ ├── K.T.T.E.pickle │ ├── K.V.G.G.pickle │ ├── L.D.C.E.pickle │ ├── L.V.E.D.pickle │ ├── T.C.T.A.pickle │ ├── T.S.V.C.pickle │ ├── T.V.S.D.pickle │ ├── V.G.D.C.pickle │ ├── W.A.E.L.pickle │ ├── W.F.V.E.pickle │ ├── W.G.T.Y.pickle │ ├── W.S.F.H.pickle │ ├── W.V.E.E.pickle │ ├── W.V.V.C.pickle │ ├── Y.A.K.V.pickle │ ├── Y.D.S.Y.pickle │ ├── Y.E.A.D.pickle │ ├── Y.E.E.D.pickle │ ├── Y.E.T.D.pickle │ ├── Y.G.H.Y.pickle │ ├── Y.G.S.K.pickle │ ├── Y.K.A.D.pickle │ ├── Y.K.F.E.pickle │ ├── Y.S.C.A.pickle │ ├── Y.T.D.T.pickle │ ├── Y.T.V.K.pickle │ ├── Y.Y.F.E.pickle │ └── Y.Y.S.Y.pickle ├── environment.yaml ├── pyproject.toml ├── ringer ├── __init__.py ├── compute_metrics.py ├── data │ ├── __init__.py │ ├── macrocycle.py │ └── noised.py ├── eval.py ├── models │ ├── __init__.py │ ├── bert_for_diffusion.py │ └── components │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── embeddings.py │ │ └── output.py ├── reconstruct.py ├── sidechain_reconstruction │ ├── __init__.py │ ├── data │ │ └── reconstruction_data.pickle │ ├── reconstruction.py │ └── transforms.py ├── train.py └── utils │ ├── __init__.py │ ├── chem.py │ ├── data │ └── amino_acids.csv │ ├── data_loading.py │ ├── evaluation.py │ ├── featurization.py │ ├── internal_coords.py │ ├── losses.py │ ├── peptides.py │ ├── plotting.py │ ├── reconstruction.py │ ├── sampling.py │ ├── utils.py │ └── variance_schedules.py ├── scripts ├── aggregate_metrics.py ├── compute_metrics_single.py └── reconstruct_single.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # User-added 2 | ringer/data/cache*.pickle 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | ### VisualStudioCode 134 | .vscode/* 135 | !.vscode/settings.json 136 | !.vscode/tasks.json 137 | !.vscode/launch.json 138 | !.vscode/extensions.json 139 | *.code-workspace 140 | **/.vscode 141 | 142 | # JetBrains 143 | .idea/ 144 | 145 | # Data & Models 146 | *.h5 147 | *.tar 148 | *.tar.gz 149 | 150 | # Lightning-Hydra-Template 151 | configs/local/default.yaml 152 | /data/ 153 | /logs/ 154 | .env 155 | 156 | # Aim logging 157 | .aim 158 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 23.3.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "99"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python docstring formatting 35 | - repo: https://github.com/myint/docformatter 36 | rev: v1.6.2 37 | hooks: 38 | - id: docformatter 39 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 40 | 41 | # python check (PEP8), 
programming errors and code complexity 42 | - repo: https://github.com/PyCQA/flake8 43 | rev: 6.0.0 44 | hooks: 45 | - id: flake8 46 | args: 47 | [ 48 | "--extend-ignore", 49 | "E203,E402,E501,F401,F841", 50 | "--exclude", 51 | "logs/*,data/*", 52 | ] 53 | 54 | # yaml formatting 55 | - repo: https://github.com/pre-commit/mirrors-prettier 56 | rev: v3.0.0-alpha.8-for-vscode 57 | hooks: 58 | - id: prettier 59 | types: [yaml] 60 | exclude: "environment.yaml" 61 | 62 | # md formatting 63 | - repo: https://github.com/executablebooks/mdformat 64 | rev: 0.7.16 65 | hooks: 66 | - id: mdformat 67 | args: ["--number"] 68 | additional_dependencies: 69 | - mdformat-gfm 70 | - mdformat-tables 71 | - mdformat_frontmatter 72 | # - mdformat-toc 73 | # - mdformat-black 74 | 75 | # word spelling linter 76 | - repo: https://github.com/codespell-project/codespell 77 | rev: v2.2.4 78 | hooks: 79 | - id: codespell 80 | args: 81 | - --skip=logs/**,data/**,*.ipynb 82 | # - --ignore-words-list=abc,def 83 | 84 | # jupyter notebook cell output clearing 85 | - repo: https://github.com/kynan/nbstripout 86 | rev: 0.6.1 87 | hooks: 88 | - id: nbstripout 89 | 90 | # jupyter notebook linting 91 | - repo: https://github.com/nbQA-dev/nbQA 92 | rev: 1.7.0 93 | hooks: 94 | - id: nbqa-black 95 | args: ["--line-length=99"] 96 | - id: nbqa-isort 97 | args: ["--profile=black"] 98 | - id: nbqa-flake8 99 | args: 100 | [ 101 | "--extend-ignore=E203,E402,E501,F401,F841", 102 | "--exclude=logs/*,data/*", 103 | ] 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023, Genentech, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RINGER 2 | 3 | This repository is the official implementation of [RINGER: Rapid Conformer Generation for Macrocycles with Sequence-Conditioned Internal Coordinate Diffusion](https://arxiv.org/abs/2305.19800). 4 | 5 | ![cover](assets/overview.png) 6 | 7 | ## Requirements 8 | 9 | To install requirements: 10 | 11 | ```setup 12 | conda env create -f environment.yaml 13 | conda activate ringer 14 | pip install -e . 15 | ``` 16 | 17 | ## Data 18 | 19 | Download and extract the CREMP pickle.tar.gz from [here](https://doi.org/10.5281/zenodo.7931444). 
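A minimal sketch of the partitioning step described next (assuming the extracted archive yields one `<sequence>.pickle` file per macrocycle in a `pickle` directory and that the sequence names are in the first column of the CSVs; both are assumptions, so adjust paths and column names as needed):

```python
import shutil
from pathlib import Path

import pandas as pd

src = Path("pickle")  # directory extracted from pickle.tar.gz (assumed name)
for split in ("train", "test"):
    # Assumes the first CSV column holds the sequence names
    names = pd.read_csv(f"data/cremp/{split}.csv").iloc[:, 0]
    for name in names:
        shutil.copy(src / f"{name}.pickle", Path("data/cremp") / split / f"{name}.pickle")
```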
Use [train.csv](data/cremp/train.csv) and [test.csv](data/cremp/test.csv) to partition it into training and test data and put the corresponding files into [train](data/cremp/train) and [test](data/cremp/test). 20 | 21 | ## Training 22 | 23 | To train the full conditional model, run this command: 24 | 25 | ```train 26 | train conditional.json 27 | ``` 28 | 29 | The config file can be specified by an absolute path or by a path relative to the [configs](configs) folder. Similarly, within the config file, `data_dir` can be an absolute path or a path relative to the [data](data) folder. 30 | 31 | To log a training run with Weights & Biases, set up your configuration in [configs/wandb/wandb.json](configs/wandb/wandb.json) and set up logging using: 32 | 33 | ```train 34 | train conditional.json --wandb-run 35 | ``` 36 | 37 | ## Sampling 38 | 39 | The [pre-trained model](assets/models/conditional) is included in this repository. 40 | 41 | To generate samples for the CREMP test set, run: 42 | 43 | ```eval 44 | evaluate \ 45 | --model-dir assets/models/conditional \ 46 | --data-dir cremp/test \ 47 | --split-sizes 0.0 0.0 1.0 \ 48 | --sample-only 49 | ``` 50 | 51 | This creates a `sample` directory containing samples for all molecules in `sample/samples.pickle`. 52 | 53 | Run `evaluate --help` to see all options available for sampling and evaluation. 54 | 55 | ## Reconstruction 56 | 57 | The `evaluate` command can also be used to reconstruct backbones (not including side chains) and to compute evaluation metrics. However, it is not recommended to do so because `evaluate` does not parallelize well across molecules. 58 | 59 | Instead, reconstruction (including side chains) is done most effectively for each molecule individually using [scripts/reconstruct_single.py](scripts/reconstruct_single.py). Parallelization can then be efficiently achieved by submitting a batch job array using an HPC job scheduler (e.g., Slurm) and passing the job array index as the first argument to the script. To reconstruct molecule 0, run: 60 | 61 | ```shell 62 | python scripts/reconstruct_single.py 0 \ 63 | cremp/test \ 64 | sample/samples.pickle \ 65 | sample/reconstructed_mols \ 66 | assets/models/conditional/training_mean_distances.json 67 | ``` 68 | 69 | The script will run the optimization to reconstruct the ring coordinates, followed by a linear (NeRF) reconstruction of the side chains using the [conformer samples previously generated](#sampling), and save the resulting molecule in `sample/reconstructed_mols`. Note that even though we point the script to `cremp/test`, it only uses the atom identities and connectivity information from the test molecules; their geometries are entirely set during the reconstruction procedure. 70 | 71 | Run `python scripts/reconstruct_single.py --help` for an overview of other parameters available for reconstruction. 72 | 73 | ## Evaluation 74 | 75 | As with reconstruction, computing metrics is best done separately for each molecule using [scripts/compute_metrics_single.py](scripts/compute_metrics_single.py) followed by aggregation across molecules using [scripts/aggregate_metrics.py](scripts/aggregate_metrics.py). 
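A hedged sketch of driving this over the whole test set (serial here for clarity; a Slurm job array, as in the [reconstruction step](#reconstruction), parallelizes the same pattern, and each script's `--help` lists the available options):

```python
# Hypothetical serial driver over all test molecules; parallelize as needed.
import subprocess
from pathlib import Path

for ref in sorted(Path("cremp/test").glob("*.pickle")):
    subprocess.run(
        ["python", "scripts/compute_metrics_single.py", str(ref),
         f"sample/reconstructed_mols/{ref.name}"],
        check=True,
    )
# Afterwards, aggregate the per-molecule results with scripts/aggregate_metrics.py.
```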
For example, to compute metrics for the `H.A.S.V` macrocycle, run 76 | 77 | ```shell 78 | python scripts/compute_metrics_single.py \ 79 | cremp/test/H.A.S.V.pickle \ 80 | sample/reconstructed_mols/H.A.S.V.pickle 81 | ``` 82 | 83 | Run `python scripts/compute_metrics_single.py --help` and `python scripts/aggregate_metrics.py --help` for an overview of other parameters available for computing metrics. 84 | 85 | ## Contributing 86 | 87 | Install pre-commit hooks to use automated code formatting before committing changes. Make sure you're in the top-level directory and run: 88 | 89 | ```bash 90 | pre-commit install 91 | ``` 92 | 93 | After that, your code will be automatically reformatted on every new commit. 94 | 95 | To manually reformat all files in the project, use: 96 | 97 | ```bash 98 | pre-commit run -a 99 | ``` 100 | 101 | To update the hooks defined in [.pre-commit-config.yaml](.pre-commit-config.yaml), use: 102 | 103 | ```bash 104 | pre-commit autoupdate 105 | ``` 106 | 107 | ## License 108 | 109 | Licensed under the MIT License. See [LICENSE](LICENSE) for additional details. 110 | 111 | ## Citations 112 | 113 | For the code and/or model, please cite: 114 | 115 | ``` 116 | @misc{grambow2023ringer, 117 | title={{RINGER}: Rapid Conformer Generation for Macrocycles with Sequence-Conditioned Internal Coordinate Diffusion}, 118 | author={Colin A. Grambow and Hayley Weir and Nathaniel L. Diamant and Alex M. Tseng and Tommaso Biancalani and Gabriele Scalia and Kangway V. Chuang}, 119 | year={2023}, 120 | eprint={2305.19800}, 121 | archivePrefix={arXiv}, 122 | primaryClass={q-bio.BM} 123 | } 124 | ``` 125 | 126 | To cite the CREMP dataset, please use: 127 | 128 | ``` 129 | @article{grambow2024cremp, 130 | title = {{CREMP: Conformer-rotamer ensembles of macrocyclic peptides for machine learning}}, 131 | author = {Grambow, Colin A. and Weir, Hayley and Cunningham, Christian N. and Biancalani, Tommaso and Chuang, Kangway V.}, 132 | year = {2024}, 133 | journal = {Scientific Data}, 134 | doi = {10.1038/s41597-024-03698-y}, 135 | pages = {859}, 136 | number = {1}, 137 | volume = {11} 138 | } 139 | ``` 140 | 141 | You can also cite the CREMP Zenodo repository directly: 142 | 143 | ``` 144 | @dataset{grambow_colin_a_2023_7931444, 145 | author = {Grambow, Colin A. and 146 | Weir, Hayley and 147 | Cunningham, Christian N. 
and 148 | Biancalani, Tommaso and 149 | Chuang, Kangway V.}, 150 | title = {{CREMP: Conformer-Rotamer Ensembles of Macrocyclic 151 | Peptides for Machine Learning}}, 152 | month = may, 153 | year = 2023, 154 | publisher = {Zenodo}, 155 | version = {1.0.1}, 156 | doi = {10.5281/zenodo.7931444}, 157 | url = {https://doi.org/10.5281/zenodo.7931444} 158 | } 159 | ``` 160 | -------------------------------------------------------------------------------- /assets/models/conditional/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "classifier_dropout": null, 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 384, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 512, 9 | "layer_norm_eps": 1e-12, 10 | "max_position_embeddings": 18, 11 | "model_type": "bert", 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "pad_token_id": 0, 15 | "position_embedding_type": "cyclic_relative_key", 16 | "transformers_version": "4.11.3", 17 | "type_vocab_size": 2, 18 | "use_cache": false, 19 | "vocab_size": 30522 20 | } 21 | -------------------------------------------------------------------------------- /assets/models/conditional/logs/lightning_logs/version_0/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /assets/models/conditional/models/best_by_valid/epoch=805-step=93496.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/assets/models/conditional/models/best_by_valid/epoch=805-step=93496.ckpt -------------------------------------------------------------------------------- /assets/models/conditional/training_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "out_dir": "results", 3 | "data_dir": "cremp/train", 4 | "split_sizes": [ 5 | 0.9, 6 | 0.1, 7 | 0.0 8 | ], 9 | "internal_coordinates_definitions": "angles-sidechains", 10 | "use_atom_features": true, 11 | "atom_feature_fingerprint_radius": 3, 12 | "atom_feature_fingerprint_size": 32, 13 | "atom_feature_embed_size": 128, 14 | "max_conf": 30, 15 | "timesteps": 20, 16 | "variance_schedule": "cosine", 17 | "variance_scale": 1.0, 18 | "use_feat_mask": false, 19 | "mask_noise": false, 20 | "mask_noise_for_features": null, 21 | "time_encoding": "gaussian_fourier", 22 | "num_hidden_layers": 12, 23 | "hidden_size": 256, 24 | "intermediate_size": 512, 25 | "num_heads": 12, 26 | "position_embedding_type": "cyclic_relative_key", 27 | "dropout_p": 0.1, 28 | "decoder": "mlp", 29 | "batch_size": 8192, 30 | "loss": "smooth_l1", 31 | "l2_norm": 0.0, 32 | "l1_norm": 0.0, 33 | "circle_reg": 0.0, 34 | "gradient_clip": 1.0, 35 | "lr": 0.0005, 36 | "lr_scheduler": "LinearWarmup", 37 | "min_epochs": null, 38 | "max_epochs": 1000, 39 | "warmup_epochs": 10, 40 | "weights": null, 41 | "early_stop_patience": 0, 42 | "use_swa": false, 43 | "exhaustive_validation_t": false, 44 | "use_data_cache": true, 45 | "data_cache_dir": null, 46 | "unsafe_cache": false, 47 | "ngpu": -1, 48 | "write_validation_preds": false, 49 | "profile": false, 50 | "overwrite": false, 51 | "wandb_config": null, 52 | "ncpu": 12 53 | } -------------------------------------------------------------------------------- 
/assets/models/conditional/training_mean_angles.json: -------------------------------------------------------------------------------- 1 | { 2 | "N": 2.1367293302820793, 3 | "Calpha": 1.8861722331825885, 4 | "CO": 2.0218062240742642 5 | } -------------------------------------------------------------------------------- /assets/models/conditional/training_mean_distances.json: -------------------------------------------------------------------------------- 1 | { 2 | "N": 1.4554429785247358, 3 | "Calpha": 1.5352094074926335, 4 | "CO": 1.3386375832494648 5 | } -------------------------------------------------------------------------------- /assets/models/conditional/training_mean_offset.json: -------------------------------------------------------------------------------- 1 | { 2 | "angle": 2.014923074580997, 3 | "dihedral": -2.8527701661051834, 4 | "sc_a0": 1.9366028715164367, 5 | "sc_a1": 1.9501811611447346, 6 | "sc_a2": 2.044697662261966, 7 | "sc_a3": 2.007077917758493, 8 | "sc_a4": 1.966931025746925, 9 | "sc_chi0": -3.0972997779407567, 10 | "sc_chi1": -1.1837038449289232, 11 | "sc_chi2": 2.9365040076864912, 12 | "sc_chi3": -3.111979524726831, 13 | "sc_chi4": -1.8523933436797895 14 | } -------------------------------------------------------------------------------- /assets/models/conditional/training_std_angles.json: -------------------------------------------------------------------------------- 1 | { 2 | "N": 0.05673038546980757, 3 | "Calpha": 0.05973643014341084, 4 | "CO": 0.03822711803587237 5 | } -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/assets/overview.png -------------------------------------------------------------------------------- /configs/conditional.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "cremp/train", 3 | "split_sizes": [0.9, 0.1, 0.0], 4 | "internal_coordinates_definitions": "angles-sidechains", 5 | "use_atom_features": true, 6 | "atom_feature_fingerprint_radius": 3, 7 | "atom_feature_fingerprint_size": 32, 8 | "atom_feature_embed_size": 128, 9 | "max_conf": 30, 10 | "timesteps": 20, 11 | "variance_schedule": "cosine", 12 | "variance_scale": 1.0, 13 | "time_encoding": "gaussian_fourier", 14 | "num_hidden_layers": 12, 15 | "hidden_size": 256, 16 | "intermediate_size": 512, 17 | "num_heads": 12, 18 | "position_embedding_type": "cyclic_relative_key", 19 | "dropout_p": 0.1, 20 | "decoder": "mlp", 21 | "loss": "smooth_l1", 22 | "use_feat_mask": false, 23 | "l2_norm": 0.0, 24 | "l1_norm": 0.0, 25 | "circle_reg": 0.0, 26 | "gradient_clip": 1.0, 27 | "lr": 5e-4, 28 | "lr_scheduler": "LinearWarmup", 29 | "max_epochs": 1000, 30 | "warmup_epochs": 10, 31 | "early_stop_patience": 0, 32 | "use_swa": false, 33 | "batch_size": 8192 34 | } 35 | -------------------------------------------------------------------------------- /configs/minimal_conditional.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "tetrapeptides", 3 | "split_sizes": [0.8, 0.1, 0.1], 4 | "internal_coordinates_definitions": "angles-sidechains", 5 | "use_atom_features": true, 6 | "atom_feature_fingerprint_radius": 3, 7 | "atom_feature_fingerprint_size": 32, 8 | "atom_feature_embed_size": 24, 9 | "max_conf": 30, 10 | "timesteps": 50, 11 | "variance_schedule": "cosine", 
12 | "variance_scale": 1.0, 13 | "time_encoding": "gaussian_fourier", 14 | "num_hidden_layers": 3, 15 | "hidden_size": 24, 16 | "intermediate_size": 96, 17 | "num_heads": 3, 18 | "position_embedding_type": "cyclic_relative_key", 19 | "dropout_p": 0.1, 20 | "decoder": "mlp", 21 | "loss": "smooth_l1", 22 | "use_feat_mask": false, 23 | "l2_norm": 0.0, 24 | "l1_norm": 0.0, 25 | "circle_reg": 0.0, 26 | "gradient_clip": 1.0, 27 | "lr": 5e-5, 28 | "lr_scheduler": "LinearWarmup", 29 | "max_epochs": 2000, 30 | "warmup_epochs": 100, 31 | "early_stop_patience": 0, 32 | "use_swa": false, 33 | "batch_size": 64 34 | } 35 | -------------------------------------------------------------------------------- /configs/minimal_unconditional.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "tetrapeptides", 3 | "split_sizes": [0.8, 0.1, 0.1], 4 | "internal_coordinates_definitions": "angles", 5 | "use_atom_features": false, 6 | "max_conf": 30, 7 | "timesteps": 50, 8 | "variance_schedule": "cosine", 9 | "variance_scale": 1.0, 10 | "time_encoding": "gaussian_fourier", 11 | "num_hidden_layers": 3, 12 | "hidden_size": 24, 13 | "intermediate_size": 96, 14 | "num_heads": 3, 15 | "position_embedding_type": "cyclic_relative_key", 16 | "dropout_p": 0.1, 17 | "decoder": "mlp", 18 | "loss": "smooth_l1", 19 | "l2_norm": 0.0, 20 | "l1_norm": 0.0, 21 | "circle_reg": 0.0, 22 | "gradient_clip": 1.0, 23 | "lr": 5e-5, 24 | "lr_scheduler": "LinearWarmup", 25 | "max_epochs": 2000, 26 | "warmup_epochs": 100, 27 | "early_stop_patience": 0, 28 | "use_swa": false, 29 | "batch_size": 64 30 | } 31 | -------------------------------------------------------------------------------- /configs/wandb/wandb.json: -------------------------------------------------------------------------------- 1 | { 2 | "project": "", 3 | "entity": "" 4 | } 5 | -------------------------------------------------------------------------------- /data/cremp/test/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/cremp/test/.gitkeep -------------------------------------------------------------------------------- /data/cremp/train/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/cremp/train/.gitkeep -------------------------------------------------------------------------------- /data/tetrapeptides/E.C.T.S.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/E.C.T.S.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/E.V.E.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/E.V.E.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/E.V.S.S.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/E.V.S.S.pickle -------------------------------------------------------------------------------- 
/data/tetrapeptides/E.V.T.C.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/E.V.T.C.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.A.S.T.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.A.S.T.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.C.C.T.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.C.C.T.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.F.C.K.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.F.C.K.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.F.F.F.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.F.F.F.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.G.A.T.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.G.A.T.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.K.T.L.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.K.T.L.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.T.K.E.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.T.K.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/F.V.T.T.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/F.V.T.T.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.A.K.S.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.A.K.S.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.H.K.G.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.H.K.G.pickle 
-------------------------------------------------------------------------------- /data/tetrapeptides/H.H.L.H.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.H.L.H.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.K.K.S.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.K.K.S.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.L.D.T.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.L.D.T.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.L.K.S.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.L.K.S.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.S.G.A.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.S.G.A.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/H.T.G.S.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/H.T.G.S.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/K.E.E.V.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/K.E.E.V.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/K.K.S.A.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/K.K.S.A.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/K.T.T.E.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/K.T.T.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/K.V.G.G.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/K.V.G.G.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/L.D.C.E.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/L.D.C.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/L.V.E.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/L.V.E.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/T.C.T.A.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/T.C.T.A.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/T.S.V.C.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/T.S.V.C.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/T.V.S.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/T.V.S.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/V.G.D.C.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/V.G.D.C.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/W.A.E.L.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/W.A.E.L.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/W.F.V.E.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/W.F.V.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/W.G.T.Y.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/W.G.T.Y.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/W.S.F.H.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/W.S.F.H.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/W.V.E.E.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/W.V.E.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/W.V.V.C.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/W.V.V.C.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.A.K.V.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.A.K.V.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.D.S.Y.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.D.S.Y.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.E.A.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.E.A.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.E.E.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.E.E.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.E.T.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.E.T.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.G.H.Y.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.G.H.Y.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.G.S.K.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.G.S.K.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.K.A.D.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.K.A.D.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.K.F.E.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.K.F.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.S.C.A.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.S.C.A.pickle -------------------------------------------------------------------------------- 
/data/tetrapeptides/Y.T.D.T.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.T.D.T.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.T.V.K.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.T.V.K.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.Y.F.E.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.Y.F.E.pickle -------------------------------------------------------------------------------- /data/tetrapeptides/Y.Y.S.Y.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/data/tetrapeptides/Y.Y.S.Y.pickle -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ringer 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - huggingface 6 | dependencies: 7 | - python=3.8 8 | - pip 9 | - pytorch::pytorch=1.12 10 | - huggingface::transformers=4.11.3 11 | - conda-forge::pytorch-lightning=1.6.4 12 | - conda-forge::numpy 13 | - conda-forge::scipy 14 | - conda-forge::pandas 15 | - conda-forge::matplotlib 16 | - conda-forge::seaborn 17 | - conda-forge::astropy 18 | - conda-forge::rdkit>=2022.09.5 19 | - conda-forge::pre-commit 20 | - conda-forge::pytest 21 | - conda-forge::rich 22 | - conda-forge::shellingham 23 | - conda-forge::typer 24 | - conda-forge::wandb 25 | - pip: 26 | - ray 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /ringer/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | CONFIG_DIR = Path(__file__).resolve().parent.parent / "configs" 4 | DATA_DIR = Path(__file__).resolve().parent.parent / "data" 5 | -------------------------------------------------------------------------------- /ringer/compute_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | import multiprocessing 5 | import pickle 6 | from pathlib import Path 7 | from typing import Any, Union 8 | 9 | import typer 10 | from tqdm.contrib.concurrent import process_map 11 | 12 | from 
ringer.utils import evaluation 13 | 14 | 15 | def load_pickle(path: Union[str, Path]) -> Any: 16 | with open(path, "rb") as f: 17 | return pickle.load(f) 18 | 19 | 20 | def save_pickle(path: Union[str, Path], data: Any) -> None: 21 | with open(path, "wb") as f: 22 | pickle.dump(data, f) 23 | 24 | 25 | def compute_metrics( 26 | mol_dir: str, 27 | mol_opt_dir: str, 28 | out_dir: str, 29 | include_all_atom: bool = True, 30 | ncpu: int = multiprocessing.cpu_count(), 31 | ) -> None: 32 | output_dir = Path(out_dir) 33 | output_dir.mkdir(exist_ok=True) 34 | 35 | # Load data 36 | mols_avail = {path.name: load_pickle(path) for path in Path(mol_dir).glob("*.pickle")} 37 | mols_opt_avail = {path.name: load_pickle(path) for path in Path(mol_opt_dir).glob("*.pickle")} 38 | 39 | # Only keep mols that were generated and reorder so that mols correspond to mols_opt 40 | mols = {} 41 | mols_opt = {} 42 | for name, mol in mols_avail.items(): 43 | try: 44 | mol_opt = mols_opt_avail[name] 45 | except KeyError: 46 | logging.warning(f"Skipping '{name}', no generated mol found") 47 | else: 48 | mols[name] = mol 49 | mols_opt[name] = mol_opt 50 | 51 | # Evaluate 52 | metric_names = ["ring-rmsd", "ring-tfd"] 53 | if include_all_atom: 54 | metric_names.append("rmsd") 55 | cov_mat_evaluator = evaluation.CovMatEvaluator(metric_names) 56 | metrics = process_map(cov_mat_evaluator, mols_opt.values(), mols.values(), max_workers=ncpu) 57 | metrics = dict(zip(mols.keys(), metrics)) # Add names as keys 58 | 59 | # Simplify and aggregate results 60 | metrics = cov_mat_evaluator.stack_results(metrics) 61 | metrics_aggregated = cov_mat_evaluator.aggregate_results(metrics) 62 | 63 | metrics_path = output_dir / "metrics.pickle" 64 | metrics_aggregated_path = output_dir / "metrics_aggregated.pickle" 65 | save_pickle(metrics_path, metrics) 66 | save_pickle(metrics_aggregated_path, metrics_aggregated) 67 | 68 | 69 | def main() -> None: 70 | logging.basicConfig(level=logging.INFO) 71 | typer.run(compute_metrics) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /ringer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import macrocycle 2 | 3 | DATASET_CLASSES = { 4 | "distances-angles": macrocycle.MacrocycleInternalCoordinateDataset, 5 | "angles": macrocycle.MacrocycleAnglesDataset, 6 | "dihedrals": macrocycle.MacrocycleDihedralsDataset, 7 | "angles-sidechains": macrocycle.MacrocycleAnglesWithSideChainsDataset, 8 | } 9 | -------------------------------------------------------------------------------- /ringer/data/noised.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from rdkit import Chem 8 | from torch.utils.data import Dataset 9 | 10 | from ..utils import utils, variance_schedules 11 | from . import macrocycle 12 | 13 | 14 | class NoisedDataset(Dataset): 15 | """Class that produces noised outputs given a wrapped dataset. Wrapped dset should return a 16 | tensor from __getitem__ if dset_key is not specified; otherwise, returns a dictionary where the 17 | item to noise is under dset_key. 
18 | 19 | Angular features are wrapped to stay within [-pi, pi]. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | dset: macrocycle.MacrocycleInternalCoordinateDataset, 25 | dset_key: str = "angles", 26 | timesteps: int = 50, 27 | exhaustive_t: bool = False, 28 | beta_schedule: variance_schedules.SCHEDULES = "cosine", 29 | nonangular_variance: float = 1.0, 30 | angular_variance: float = 1.0, 31 | mask_noise: bool = False, 32 | mask_noise_for_features: Optional[List[str]] = None, 33 | ) -> None: 34 | super().__init__() 35 | 36 | self.dset = dset 37 | assert hasattr(dset, "feature_names") 38 | assert hasattr(dset, "feature_is_angular") 39 | self.dset_key = dset_key 40 | self.n_features = len(dset.feature_is_angular) 41 | 42 | self.nonangular_var_scale = nonangular_variance 43 | self.angular_var_scale = angular_variance 44 | 45 | self.timesteps = timesteps 46 | self.schedule = beta_schedule 47 | self.exhaustive_timesteps = exhaustive_t 48 | if self.exhaustive_timesteps: 49 | logging.info(f"Exhaustive timesteps for {dset}") 50 | 51 | betas = variance_schedules.get_variance_schedule(beta_schedule, timesteps) 52 | self.alpha_beta_terms = variance_schedules.compute_alphas(betas) 53 | 54 | # Whether to use feature mask to mask out (side-chain) noise 55 | self.mask_noise = mask_noise 56 | 57 | # List of feature names that we don't want to noise 58 | self.mask_noise_for_features = mask_noise_for_features 59 | 60 | @property 61 | def structures(self) -> Optional[Dict[str, Dict[str, pd.DataFrame]]]: 62 | return self.dset.structures 63 | 64 | @property 65 | def atom_features(self) -> Optional[Dict[str, Dict[str, Union[Chem.Mol, pd.DataFrame]]]]: 66 | return self.dset.atom_features 67 | 68 | @property 69 | def feature_names(self) -> Tuple[str, ...]: 70 | """Pass through feature names property of wrapped dset.""" 71 | return self.dset.feature_names 72 | 73 | @property 74 | def feature_is_angular(self) -> Tuple[bool, ...]: 75 | """Pass through feature is angular property of wrapped dset.""" 76 | return self.dset.feature_is_angular 77 | 78 | @property 79 | def pad(self) -> int: 80 | """Pass through the pad property of wrapped dset.""" 81 | return self.dset.pad 82 | 83 | @property 84 | def means(self) -> Optional[np.ndarray]: 85 | return self.dset.means 86 | 87 | @property 88 | def means_dict(self) -> Optional[Dict[str, float]]: 89 | return self.dset.means_dict 90 | 91 | @means.setter 92 | def means(self, means: Dict[str, float]) -> None: 93 | self.dset.means = means 94 | 95 | @property 96 | def all_lengths(self) -> List[int]: 97 | return self.dset.all_lengths 98 | 99 | def sample_length(self, *args, **kwargs) -> Union[int, List[int]]: 100 | return self.dset.sample_length(*args, **kwargs) 101 | 102 | def get_atom_features( 103 | self, *args, **kwargs 104 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[int]]]: 105 | return self.dset.get_atom_features(*args, **kwargs) 106 | 107 | def get_structure_as_dataframe(self, index: int) -> pd.DataFrame: 108 | return self.dset.get_structure_as_dataframe(index) 109 | 110 | def __str__(self) -> str: 111 | return f"NoisedAnglesDataset wrapping {self.dset} with {len(self)} examples with {self.schedule}-{self.timesteps} with variance scales {self.nonangular_var_scale} and {self.angular_var_scale}" 112 | 113 | def __len__(self) -> int: 114 | if not self.exhaustive_timesteps: 115 | return len(self.dset) 116 | else: 117 | return int(len(self.dset) * self.timesteps) 118 | 119 | def sample_noise(self, vals: torch.Tensor, uniform: bool = False) -> torch.Tensor:
"""Adaptively sample noise based on modulo. 121 | 122 | We scale only the variance because we want the noise to remain zero centered 123 | """ 124 | # Noise is always 0 centered 125 | if uniform: 126 | noise = torch.rand_like(vals) * 2 * np.pi - np.pi 127 | else: 128 | noise = torch.randn_like(vals) 129 | 130 | # Shapes of vals couled be (batch, seq, feat) or (seq, feat) 131 | # Therefore we need to index into last dimension consistently 132 | 133 | # Scale by provided variance scales based on angular or not 134 | if self.angular_var_scale != 1.0 or self.nonangular_var_scale != 1.0: 135 | for j in range(noise.shape[-1]): # Last dim = feature dim 136 | s = ( 137 | self.angular_var_scale 138 | if self.feature_is_angular[j] 139 | else self.nonangular_var_scale 140 | ) 141 | noise[..., j] *= s 142 | 143 | # Make sure that the noise doesn't run over the boundaries 144 | noise[..., self.feature_is_angular] = utils.modulo_with_wrapped_range( 145 | noise[..., self.feature_is_angular], -np.pi, np.pi 146 | ) 147 | 148 | return noise 149 | 150 | def __getitem__( 151 | self, 152 | index: int, 153 | use_t_val: Optional[int] = None, 154 | ignore_zero_center: bool = False, 155 | ) -> Dict[str, torch.Tensor]: 156 | """Gets the i-th item in the dataset and adds noise use_t_val is useful for manually 157 | querying specific timepoints.""" 158 | assert 0 <= index < len(self), f"Index {index} out of bounds for {len(self)}" 159 | # Handle cases where we exhaustively loop over t 160 | if self.exhaustive_timesteps: 161 | item_index = index // self.timesteps 162 | assert item_index < len(self) 163 | time_index = index % self.timesteps 164 | logging.debug(f"Exhaustive {index} -> item {item_index} at time {time_index}") 165 | assert ( 166 | item_index * self.timesteps + time_index == index 167 | ), f"Unexpected indices for {index} -- {item_index} {time_index}" 168 | item = self.dset.__getitem__(item_index, ignore_zero_center=ignore_zero_center) 169 | else: 170 | item = self.dset.__getitem__(index, ignore_zero_center=ignore_zero_center) 171 | 172 | # If wrapped dset returns a dictionary then we extract the item to noise 173 | if self.dset_key is not None: 174 | assert isinstance(item, dict) 175 | vals = item[self.dset_key].clone() 176 | else: 177 | vals = item.clone() 178 | assert isinstance( 179 | vals, torch.Tensor 180 | ), f"Using dset_key {self.dset_key} - expected tensor but got {type(vals)}" 181 | 182 | # Sample a random timepoint and add corresponding noise 183 | if use_t_val is not None: 184 | assert not self.exhaustive_timesteps, "Cannot use specific t in exhaustive mode" 185 | t_val = np.clip(np.array([use_t_val]), 0, self.timesteps - 1) 186 | t = torch.from_numpy(t_val).long() 187 | elif self.exhaustive_timesteps: 188 | t = torch.tensor([time_index]).long() # list to get correct shape 189 | else: 190 | t = torch.randint(0, self.timesteps, (1,)).long() 191 | 192 | # Get the values for alpha and beta 193 | sqrt_alphas_cumprod_t = self.alpha_beta_terms["sqrt_alphas_cumprod"][t.item()] 194 | sqrt_one_minus_alphas_cumprod_t = self.alpha_beta_terms["sqrt_one_minus_alphas_cumprod"][ 195 | t.item() 196 | ] 197 | # Noise is sampled within range of [-pi, pi], and optionally 198 | # shifted to [0, 2pi] by adding pi 199 | noise = self.sample_noise(vals) # Vals passed in only for shape 200 | 201 | if self.mask_noise and "feat_mask" in item: 202 | noise[~item["feat_mask"].bool()] = 0 203 | 204 | if self.mask_noise_for_features is not None: 205 | feature_names = np.array(self.feature_names) 206 | feat_mask_idxs = 
np.in1d(feature_names, self.mask_noise_for_features).nonzero()[0] 207 | noise[..., feat_mask_idxs] = 0 208 | 209 | # Add noise and ensure noised vals are still in range 210 | noised_vals = sqrt_alphas_cumprod_t * vals + sqrt_one_minus_alphas_cumprod_t * noise 211 | assert noised_vals.shape == vals.shape, f"Unexpected shape {noised_vals.shape}" 212 | # The underlying vals are already shifted, and noise is already shifted 213 | # All we need to do is ensure we stay on the corresponding manifold 214 | # Wrap around the correct range 215 | noised_vals[:, self.feature_is_angular] = utils.modulo_with_wrapped_range( 216 | noised_vals[:, self.feature_is_angular], -np.pi, np.pi 217 | ) 218 | 219 | retval = { 220 | "corrupted": noised_vals, 221 | "t": t, 222 | "known_noise": noise, 223 | "sqrt_alphas_cumprod_t": sqrt_alphas_cumprod_t, 224 | "sqrt_one_minus_alphas_cumprod_t": sqrt_one_minus_alphas_cumprod_t, 225 | } 226 | 227 | # Update dictionary if wrapped dset returns dicts, else just return 228 | if isinstance(item, dict): 229 | assert item.keys().isdisjoint(retval.keys()) 230 | item.update(retval) 231 | return item 232 | return retval 233 | -------------------------------------------------------------------------------- /ringer/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Import attention so that transformers.models.bert.modeling_bert gets patched 2 | from .components import attention 3 | -------------------------------------------------------------------------------- /ringer/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/ringer/models/components/__init__.py -------------------------------------------------------------------------------- /ringer/models/components/attention.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | from torch import nn 7 | from transformers.models.bert import configuration_bert, modeling_bert 8 | 9 | # Monkey-patch BertSelfAttention to enable cyclic relative positional encoding 10 | # This means that the output of BertSelfAttention.forward transforms equivariantly under 11 | # a cyclic shift of the input sequences 12 | if modeling_bert.BertSelfAttention.__name__ == "BertSelfAttention": 13 | 14 | class BertSelfAttentionWithCyclicEncoding(modeling_bert.BertSelfAttention): 15 | def __init__(self, config: configuration_bert.BertConfig, *args, **kwargs) -> None: 16 | super().__init__(config, *args, **kwargs) 17 | if self.position_embedding_type == "cyclic_relative_key": 18 | logging.info("Using cyclic positional encoding") 19 | assert not hasattr(self, "max_position_embeddings") 20 | assert not hasattr(self, "distance_embedding") 21 | self.max_position_embeddings = config.max_position_embeddings 22 | self.distance_embedding = nn.Embedding( 23 | 2 * config.max_position_embeddings - 1, self.attention_head_size 24 | ) 25 | 26 | def forward( 27 | self, 28 | hidden_states: torch.Tensor, 29 | attention_mask: Optional[torch.FloatTensor] = None, 30 | head_mask: Optional[torch.FloatTensor] = None, 31 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 32 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 33 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 34 | output_attentions: Optional[bool] = 
False, 35 | ) -> Tuple[torch.Tensor]: 36 | mixed_query_layer = self.query(hidden_states) 37 | 38 | # If this is instantiated as a cross-attention module, the keys 39 | # and values come from an encoder; the attention mask needs to be 40 | # such that the encoder's padding tokens are not attended to. 41 | is_cross_attention = encoder_hidden_states is not None 42 | 43 | if is_cross_attention and past_key_value is not None: 44 | # reuse k,v, cross_attentions 45 | key_layer = past_key_value[0] 46 | value_layer = past_key_value[1] 47 | attention_mask = encoder_attention_mask 48 | elif is_cross_attention: 49 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 50 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 51 | attention_mask = encoder_attention_mask 52 | elif past_key_value is not None: 53 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 54 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 55 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 56 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 57 | else: 58 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 59 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 60 | 61 | query_layer = self.transpose_for_scores(mixed_query_layer) 62 | 63 | use_cache = past_key_value is not None 64 | if self.is_decoder: 65 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 66 | # Further calls to cross_attention layer can then reuse all cross-attention 67 | # key/value_states (first "if" case) 68 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 69 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 70 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 71 | # if encoder bi-directional self-attention `past_key_value` is always `None` 72 | past_key_value = (key_layer, value_layer) 73 | 74 | # Take the dot product between "query" and "key" to get the raw attention scores. 
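        # With transpose_for_scores, query/key/value each have shape
        # (batch, heads, seq_len, head_size), so the matmul below yields raw
        # attention scores of shape (batch, heads, seq_len, seq_len).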
75 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 76 | 77 | if self.position_embedding_type in { 78 | "relative_key", 79 | "relative_key_query", 80 | "cyclic_relative_key", 81 | }: 82 | query_length, key_length = query_layer.shape[2], key_layer.shape[2] 83 | if use_cache: 84 | position_ids_l = torch.tensor( 85 | key_length - 1, dtype=torch.long, device=hidden_states.device 86 | ).view(-1, 1) 87 | else: 88 | position_ids_l = torch.arange( 89 | query_length, dtype=torch.long, device=hidden_states.device 90 | ).view(-1, 1) 91 | position_ids_r = torch.arange( 92 | key_length, dtype=torch.long, device=hidden_states.device 93 | ).view(1, -1) 94 | 95 | if self.position_embedding_type in { 96 | "relative_key", 97 | "relative_key_query", 98 | }: 99 | distance = position_ids_l - position_ids_r 100 | 101 | positional_embedding = self.distance_embedding( 102 | distance + self.max_position_embeddings - 1 103 | ) 104 | positional_embedding = positional_embedding.to( 105 | dtype=query_layer.dtype 106 | ) # fp16 compatibility 107 | 108 | if self.position_embedding_type == "relative_key": 109 | relative_position_scores = torch.einsum( 110 | "bhld,lrd->bhlr", query_layer, positional_embedding 111 | ) 112 | attention_scores = attention_scores + relative_position_scores 113 | elif self.position_embedding_type == "relative_key_query": 114 | relative_position_scores_query = torch.einsum( 115 | "bhld,lrd->bhlr", query_layer, positional_embedding 116 | ) 117 | relative_position_scores_key = torch.einsum( 118 | "bhrd,lrd->bhlr", key_layer, positional_embedding 119 | ) 120 | attention_scores = ( 121 | attention_scores 122 | + relative_position_scores_query 123 | + relative_position_scores_key 124 | ) 125 | elif self.position_embedding_type == "cyclic_relative_key": 126 | distance = position_ids_l - position_ids_r 127 | 128 | # Attention mask at this point is already expanded and has zeros in the locations that should be attended to 129 | seq_lengths = (attention_mask == 0).sum(-1) 130 | forward_distance = distance % seq_lengths 131 | reverse_distance = distance % -seq_lengths 132 | 133 | forward_positional_embedding = self.distance_embedding( 134 | forward_distance + self.max_position_embeddings - 1 135 | ) 136 | reverse_positional_embedding = self.distance_embedding( 137 | reverse_distance + self.max_position_embeddings - 1 138 | ) 139 | forward_positional_embedding = forward_positional_embedding.to( 140 | dtype=query_layer.dtype 141 | ) # fp16 compatibility 142 | reverse_positional_embedding = reverse_positional_embedding.to( 143 | dtype=query_layer.dtype 144 | ) # fp16 compatibility 145 | 146 | # Need batch dimension in einsum for positional embeddings because they are different for each sequence in the batch 147 | relative_position_scores_forward = torch.einsum( 148 | "bhld,blrd->bhlr", query_layer, forward_positional_embedding 149 | ) 150 | relative_position_scores_reverse = torch.einsum( 151 | "bhld,blrd->bhlr", query_layer, reverse_positional_embedding 152 | ) 153 | attention_scores = ( 154 | attention_scores 155 | + relative_position_scores_forward 156 | + relative_position_scores_reverse 157 | ) 158 | 159 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 160 | if attention_mask is not None: 161 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 162 | attention_scores = attention_scores + attention_mask 163 | 164 | # Normalize the attention scores to probabilities. 
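        # A minimal worked example of the cyclic wrapping above (illustrative,
        # for a single unpadded sequence of length 4):
        #
        #     import torch
        #     pos = torch.arange(4)
        #     distance = pos.view(-1, 1) - pos.view(1, -1)
        #     forward = distance % 4   # steps around the ring one way
        #     reverse = distance % -4  # steps around the ring the other way
        #     # forward[0] -> tensor([0, 3, 2, 1]); reverse[0] -> tensor([0, -1, -2, -3])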
165 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 166 | 167 | # This is actually dropping out entire tokens to attend to, which might 168 | # seem a bit unusual, but is taken from the original Transformer paper. 169 | attention_probs = self.dropout(attention_probs) 170 | 171 | # Mask heads if we want to 172 | if head_mask is not None: 173 | attention_probs = attention_probs * head_mask 174 | 175 | context_layer = torch.matmul(attention_probs, value_layer) 176 | 177 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 178 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 179 | context_layer = context_layer.view(new_context_layer_shape) 180 | 181 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 182 | 183 | if self.is_decoder: 184 | outputs = outputs + (past_key_value,) 185 | return outputs 186 | 187 | modeling_bert.BertSelfAttention = BertSelfAttentionWithCyclicEncoding # Patch 188 | -------------------------------------------------------------------------------- /ringer/models/components/embeddings.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn 6 | from transformers.models.bert import configuration_bert 7 | 8 | 9 | class GaussianFourierProjection(nn.Module): 10 | """Gaussian random features for encoding time steps. Built primarily for score-based models. 11 | 12 | Source: 13 | https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing#scrollTo=YyQtV7155Nht 14 | """ 15 | 16 | def __init__(self, embed_dim: int, scale: float = 2 * torch.pi) -> None: 17 | super().__init__() 18 | # Randomly sample weights during initialization. These weights are fixed 19 | # during optimization and are not trainable. 20 | w = torch.randn(embed_dim // 2) * scale 21 | assert not w.requires_grad 22 | self.register_buffer("W", w) 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | """ 26 | takes as input the time vector and returns the time encoding 27 | time (x): (batch_size, ) 28 | output : (batch_size, embed_dim) 29 | """ 30 | if x.ndim > 1: 31 | x = x.squeeze() 32 | elif x.ndim < 1: 33 | x = x.unsqueeze(0) 34 | x_proj = x[:, None] * self.W[None, :] * 2 * torch.pi 35 | embed = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 36 | return embed 37 | 38 | 39 | class SinusoidalPositionEmbeddings(nn.Module): 40 | """Positional embeddings.""" 41 | 42 | def __init__(self, dim: int) -> None: 43 | super().__init__() 44 | self.dim = dim 45 | 46 | def forward(self, time: torch.Tensor) -> torch.Tensor: 47 | device = time.device 48 | half_dim = self.dim // 2 49 | embeddings = math.log(10000) / (half_dim - 1) 50 | # half_dim shape 51 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 52 | # outer product (batch, 1) x (1, half_dim) -> (batch x half_dim) 53 | embeddings = time[:, None] * embeddings[None, :] 54 | # sin and cosine embeddings 55 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 56 | return embeddings 57 | 58 | 59 | class PositionalEncoding(nn.Module): 60 | """Positional embedding for BERT. 
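
    (A quick shape check for the time embeddings above, with illustrative
    sizes: GaussianFourierProjection(64)(torch.randn(8)) and
    SinusoidalPositionEmbeddings(64)(torch.arange(8)) each return an (8, 64)
    tensor; this class instead adds a precomputed table to its
    (batch, seq_len, d_model) input.)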
61 | 62 | Source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html 63 | """ 64 | 65 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None: 66 | super().__init__() 67 | self.dropout = nn.Dropout(p=dropout) 68 | 69 | position = torch.arange(max_len).unsqueeze(1) 70 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 71 | pe = torch.zeros(max_len, 1, d_model) 72 | pe[:, 0, 0::2] = torch.sin(position * div_term) 73 | pe[:, 0, 1::2] = torch.cos(position * div_term) 74 | self.register_buffer("pe", pe) 75 | 76 | def forward(self, x: torch.Tensor) -> torch.Tensor: 77 | """ 78 | Args: 79 | x: Tensor, shape [batch_size, seq_len, embedding_dim] 80 | """ 81 | assert len(x.shape) == 3 82 | orig_shape = x.shape 83 | # x is a tensor of shape (batch_size, seq_len, embedding_dim) 84 | # permute to be (seq_len, batch_size, embedding_dim) 85 | x = x.permute(1, 0, 2) 86 | x += self.pe[: x.size(0)] 87 | # permute back to (batch_size, seq_len, embedding_dim) 88 | x = x.permute(1, 0, 2) 89 | assert x.shape == orig_shape, f"{x.shape} != {orig_shape}" 90 | return self.dropout(x) 91 | 92 | 93 | class BertEmbeddings(nn.Module): 94 | """Adds in positional embeddings if using absolute embeddings, adds layer norm and dropout.""" 95 | 96 | def __init__( 97 | self, config: configuration_bert.BertConfig, use_atom_embeddings: bool = False 98 | ) -> None: 99 | super().__init__() 100 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 101 | if self.position_embedding_type == "absolute": 102 | self.position_embeddings = nn.Embedding( 103 | config.max_position_embeddings, config.hidden_size 104 | ) 105 | self.register_buffer( 106 | "position_ids", 107 | torch.arange(config.max_position_embeddings).expand((1, -1)), 108 | ) 109 | 110 | # Used to embed backbone atom information 111 | self.use_atom_embeddings = use_atom_embeddings 112 | if self.use_atom_embeddings: 113 | self.atom_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 114 | 115 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 116 | # any TensorFlow checkpoint file 117 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 118 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 119 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 120 | 121 | def forward( 122 | self, 123 | input_embeds: torch.Tensor, 124 | position_ids: torch.LongTensor, 125 | atom_ids: Optional[torch.LongTensor] = None, 126 | ) -> torch.Tensor: 127 | assert position_ids is not None, "`position_ids` must be defined" 128 | embeddings = input_embeds 129 | if self.position_embedding_type == "absolute": 130 | position_embeddings = self.position_embeddings(position_ids) 131 | embeddings += position_embeddings 132 | 133 | if self.use_atom_embeddings and atom_ids is not None: 134 | atom_embeddings = self.atom_embeddings(atom_ids) 135 | embeddings += atom_embeddings 136 | 137 | embeddings = self.LayerNorm(embeddings) 138 | embeddings = self.dropout(embeddings) 139 | return embeddings 140 | -------------------------------------------------------------------------------- /ringer/models/components/output.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.activations import get_activation 6 | 7 | 8 | class 
AnglesPredictor(nn.Module): 9 | """Predict angles from the embeddings. For BERT, the MLM task is done using an architecture 10 | like d_model -> dense -> d_model -> activation -> layernorm -> dense -> d_output. 11 | 12 | https://github.com/huggingface/transformers/blob/v4.21.1/src/transformers/models/bert/modeling_bert.py#L681 13 | 14 | activation should be given as nn.ReLU for example -- NOT nn.ReLU() 15 | """ 16 | 17 | def __init__( 18 | self, 19 | d_model: int, 20 | d_out: int = 4, 21 | activation: Union[str, nn.Module] = "gelu", 22 | eps: float = 1e-12, 23 | ) -> None: 24 | super().__init__() 25 | self.d_model = d_model 26 | self.d_out = d_out 27 | self.dense1 = nn.Linear(d_model, d_model) 28 | 29 | if isinstance(activation, str): 30 | self.dense1_act = get_activation(activation) 31 | else: 32 | self.dense1_act = activation() 33 | self.layer_norm = nn.LayerNorm(d_model, eps=eps) 34 | 35 | self.dense2 = nn.Linear(d_model, d_out) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | x = self.dense1(x) 39 | x = self.dense1_act(x) 40 | x = self.layer_norm(x) 41 | x = self.dense2(x) 42 | return x 43 | -------------------------------------------------------------------------------- /ringer/reconstruct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | import logging 5 | import multiprocessing 6 | import pickle 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union 9 | 10 | import pandas as pd 11 | import ray 12 | import typer 13 | from rdkit import Chem 14 | from tqdm import tqdm 15 | 16 | from ringer.utils import reconstruction 17 | 18 | 19 | def load_json(path: Union[str, Path]) -> Dict[str, Any]: 20 | with open(path) as f: 21 | return json.load(f) 22 | 23 | 24 | def load_pickle(path: Union[str, Path]) -> Any: 25 | with open(path, "rb") as f: 26 | return pickle.load(f) 27 | 28 | 29 | def save_pickle(path: Union[str, Path], data: Any) -> None: 30 | with open(path, "wb") as f: 31 | pickle.dump(data, f) 32 | 33 | 34 | @ray.remote 35 | def reconstruct_ring( 36 | fname: str, 37 | data: Dict[str, Any], 38 | bond_dist_dict: Dict[str, float], 39 | bond_angle_dict: Optional[Dict[str, float]] = None, 40 | bond_angle_dev_dict: Optional[Dict[str, float]] = None, 41 | opt_init: Literal["best_dists", "average"] = "best_dists", 42 | skip_opt: bool = False, 43 | max_conf: Optional[int] = None, 44 | return_unsuccessful: bool = False, 45 | mol_opt_dir: Optional[Union[str, Path]] = None, 46 | ) -> Tuple[str, Tuple[Chem.Mol, List[pd.DataFrame]]]: 47 | mol = data["mol"] 48 | structure = data["structure"] 49 | 50 | result = reconstruction.reconstruct_ring( 51 | mol=mol, 52 | structure=structure, 53 | bond_dist_dict=bond_dist_dict, 54 | bond_angle_dict=bond_angle_dict, 55 | bond_angle_dev_dict=bond_angle_dev_dict, 56 | opt_init=opt_init, 57 | skip_opt=skip_opt, 58 | max_conf=max_conf, 59 | return_unsuccessful=return_unsuccessful, 60 | ) 61 | 62 | if mol_opt_dir is not None: 63 | mol_opt = result[0] 64 | mol_opt_path = Path(mol_opt_dir) / Path(fname).name 65 | save_pickle(mol_opt_path, mol_opt) 66 | 67 | return fname, result 68 | 69 | 70 | def get_as_iterator(obj_ids): 71 | # Returns results as they're ready 72 | # Order is preserved within the IDs that are ready, 73 | # but not if this iterator is converted to a list 74 | while obj_ids: 75 | done, obj_ids = ray.wait(obj_ids) 76 | yield ray.get(done[0]) 77 | 78 | 79 | def reconstruct( 80 | mol_dir: str, 81 | 
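    # A hypothetical CLI invocation (reconstruct() is wrapped by typer.run in
    # main() below, so parameters without defaults become positional arguments
    # and the rest become kebab-case options):
    #
    #     python -m ringer.reconstruct mols/ samples.pickle out/ \
    #         training_mean_distances.json --max-conf 30 --ncpu 8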
structures_path: str, 82 | out_dir: str, 83 | mean_distances_path: str, 84 | mean_angles_path: Optional[str] = None, 85 | std_angles_path: Optional[str] = None, 86 | opt_init: str = "best_dists", 87 | skip_opt: bool = False, 88 | max_conf: Optional[int] = None, 89 | save_unsuccessful: bool = False, 90 | ncpu: int = multiprocessing.cpu_count(), 91 | ) -> None: 92 | output_dir = Path(out_dir) 93 | output_dir.mkdir(exist_ok=True) 94 | 95 | # Load data 96 | mols = {path.name: load_pickle(path) for path in Path(mol_dir).glob("*.pickle")} 97 | structures_dict = load_pickle(structures_path) 98 | mean_bond_distances = load_json(mean_distances_path) 99 | mean_bond_angles = None if mean_angles_path is None else load_json(mean_angles_path) 100 | std_bond_angles = None if std_angles_path is None else load_json(std_angles_path) 101 | 102 | # Get mols in correct order 103 | mols_and_structures = { 104 | fname: dict(mol=mols[Path(fname).name], structure=structure) 105 | for fname, structure in structures_dict.items() 106 | } 107 | 108 | mol_opt_dir = output_dir / "reconstructed_mols" 109 | mol_opt_dir.mkdir(exist_ok=True) 110 | 111 | # Reconstruct 112 | if skip_opt: 113 | logging.info("Skipping opt") 114 | ray.init(num_cpus=ncpu) 115 | result_ids = [ 116 | reconstruct_ring.remote( 117 | fname, 118 | mol_and_structure, 119 | mean_bond_distances, 120 | bond_angle_dict=mean_bond_angles, 121 | bond_angle_dev_dict=std_bond_angles, 122 | opt_init=opt_init, 123 | skip_opt=skip_opt, 124 | max_conf=max_conf, 125 | return_unsuccessful=save_unsuccessful, 126 | mol_opt_dir=mol_opt_dir, 127 | ) 128 | for fname, mol_and_structure in mols_and_structures.items() 129 | ] 130 | mols_and_coords_opt = dict(tqdm(get_as_iterator(result_ids), total=len(result_ids))) 131 | 132 | # Post-process and dump data 133 | def get_structure_from_coords(coords: List[pd.DataFrame], name: str) -> Dict[str, Any]: 134 | # Convert list of coords to structure 135 | # Concatenate making hierarchical index of sample_idx and atom_idx 136 | coords_stacked = pd.concat( 137 | coords, keys=range(len(coords)), names=["sample_idx", "atom_idx"] 138 | ) 139 | # Pivot so we can get all samples for each feature from the outermost column 140 | coords_pivoted = coords_stacked.unstack(level="atom_idx") 141 | structure = { 142 | feat_name: coords_pivoted[feat_name] for feat_name in coords_pivoted.columns.levels[0] 143 | } 144 | structure["atom_labels"] = structures_dict[name]["atom_labels"] 145 | return structure 146 | 147 | reconstructed_structures_dict = {} 148 | unsuccessful_results = {} 149 | for fname, result in mols_and_coords_opt.items(): 150 | coords_opt = result[1] 151 | 152 | structure = get_structure_from_coords(coords_opt, fname) 153 | reconstructed_structures_dict[fname] = structure 154 | 155 | if save_unsuccessful: 156 | result_objs = result[2] 157 | if result_objs: 158 | unsuccessful_results[fname] = result_objs 159 | 160 | save_pickle(output_dir / "samples_reconstructed.pickle", reconstructed_structures_dict) 161 | if unsuccessful_results: 162 | save_pickle(output_dir / "unsuccessful_opts.pickle", unsuccessful_results) 163 | 164 | 165 | def main() -> None: 166 | logging.basicConfig(level=logging.INFO) 167 | typer.run(reconstruct) 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /ringer/sidechain_reconstruction/__init__.py: -------------------------------------------------------------------------------- 1 | from .reconstruction import Macrocycle, 
Reconstructor, set_rdkit_geometries 2 | -------------------------------------------------------------------------------- /ringer/sidechain_reconstruction/data/reconstruction_data.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/ringer/sidechain_reconstruction/data/reconstruction_data.pickle -------------------------------------------------------------------------------- /ringer/sidechain_reconstruction/transforms.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | NUM_DIMENSIONS = 3 9 | Triplet = collections.namedtuple("Triplet", "a, b, c") 10 | 11 | DEFAULT_OFFSETS = (1, 2, 3) 12 | RINGER_OFFSETS = (0, 1, 1) 13 | NUM_INTERNALS = 3 14 | 15 | offsets = RINGER_OFFSETS 16 | 17 | 18 | def extract_bd_theta_np(positions: np.ndarray): 19 | """Extracts the bond distance and angle in radians for a given tetrad structure in numpy array. 20 | 21 | Args: 22 | positions (np.ndarray): An array representing the positions of atoms in a tetrad structure. 23 | The structure is given by (a,b,c,d), where d is a branching atom. 24 | Dimensions: K x B x 4 x 3. 25 | 26 | Returns: 27 | Tuple[np.ndarray, np.ndarray]: The calculated bond distance and angle in radians. 28 | """ 29 | a = positions[:, :, 0, :] 30 | b = positions[:, :, 1, :] 31 | c = positions[:, :, 2, :] 32 | d = positions[:, :, 3, :] 33 | 34 | # calculate the main vectors 35 | ba = a - b 36 | bc = c - b 37 | bd = d - b 38 | 39 | cross_product = np.cross(bc, ba, axis=-1) 40 | 41 | norm = np.linalg.norm(cross_product, axis=-1, keepdims=False) # Shape: (B, 1) 42 | bd_norm = np.linalg.norm(bd, axis=-1, keepdims=False) 43 | 44 | dot_prod = (cross_product * bd).sum(axis=-1) 45 | cos_angle = dot_prod / (norm * bd_norm) 46 | angle_radians = np.arccos(cos_angle) 47 | 48 | return bd_norm, angle_radians 49 | 50 | 51 | def extract_bd_theta(positions: torch.FloatTensor): 52 | """Extracts the bond distance and angle in radians for a given tetrad structure in torch 53 | tensor. 54 | 55 | Args: 56 | positions (torch.FloatTensor): A tensor representing the positions of atoms in a tetrad structure. 57 | The structure is given by (a,b,c,d), where d is a branching atom. 58 | Dimensions: K x B x 4 x 3. 59 | 60 | Returns: 61 | Tuple[torch.Tensor, torch.Tensor]: The calculated bond distance and angle in radians. 62 | """ 63 | a = positions[:, :, 0, :] 64 | b = positions[:, :, 1, :] 65 | c = positions[:, :, 2, :] 66 | d = positions[:, :, 3, :] 67 | 68 | # calculate the main vectors 69 | ba = a - b 70 | bc = c - b 71 | bd = d - b 72 | 73 | cross_product = torch.cross(bc, ba, dim=-1) 74 | norm = torch.norm(cross_product, dim=-1, keepdim=False) 75 | bd_norm = torch.norm(bd, dim=-1, keepdim=False) 76 | 77 | # calculate the angle between the normal vector and bd vector 78 | dot_prod = (cross_product * bd).sum(dim=-1) 79 | cos_angle = dot_prod / (norm * bd_norm) 80 | angle_radians = torch.arccos(cos_angle) 81 | 82 | return bd_norm, angle_radians 83 | 84 | 85 | class NeRF(object): 86 | """Natural Extension Reference Frame (NeRF). 87 | 88 | Constructs cartesian coordinates from internal coordinates. NeRF requires a dependency matrix 89 | that sequentially constructs the molecule. 
As such, it creates a set of distances, angles, and
90 |     dihedrals (the internal coordinates) that correspond to specific indices. Hence, for
91 |     (1,2,3,4), atom_id 4 will be placed with: a distance using (3,4), an angle about (2,3,4),
92 |     and the torsion with (1,2,3,4), which collectively define position 4. This is different from
93 |     RINGER, which for atom 2 would use bond distance (2,3), angle (1,2,3), and dihedral (1,2,3,4),
94 |     since these are all focused on atom_id 2.
95 |     """
96 | 
97 |     def __init__(self, float_type: torch.dtype = torch.double):
98 |         """Torch doubles are preferred for numerical stability, but floats can be used for
99 |         training if used in the loop."""
100 |         self.float_type = float_type
101 | 
102 |         if float_type == torch.double:
103 |             self.np_float_type = np.float64
104 |         elif float_type == torch.float:
105 |             self.np_float_type = np.float32
106 | 
107 |         self.init_matrix = np.array(
108 |             [[-np.sqrt(1.0 / 2.0), np.sqrt(3.0 / 2.0), 0], [-np.sqrt(2.0), 0, 0], [0, 0, 0]],
109 |             dtype=self.np_float_type,
110 |         )
111 | 
112 |     @staticmethod
113 |     def convert_to_points(bonds: torch.Tensor, angles: torch.Tensor, dihedrals: torch.Tensor):
114 |         """
115 |         Convert bonds, angles, and dihedrals to points.
116 |         -----------
117 |         bonds: torch.Tensor,
118 |             bond lengths (distance in angstroms)
119 |         angles: torch.Tensor,
120 |             bond angles theta, in radians
121 |         dihedrals: torch.Tensor,
122 |             bond dihedrals, in radians.
123 |         """
124 |         r_cos_theta = bonds * torch.cos(torch.pi - angles)
125 |         r_sin_theta = bonds * torch.sin(torch.pi - angles)
126 | 
127 |         points_x = r_cos_theta
128 |         points_y = r_sin_theta * torch.cos(dihedrals)
129 |         points_z = r_sin_theta * torch.sin(dihedrals)
130 | 
131 |         points = torch.stack([points_x, points_y, points_z])  # 3 x B x L
132 |         return points.permute(2, 1, 0)  # L x B x 3
133 | 
134 |     @staticmethod
135 |     def extend(point, last_three_points):
136 |         """
137 |         point: B x 3,
138 |             The coordinates for a single position.
139 |         last_three_points: Triplet (NamedTuple),
140 |             NamedTuple container holding the last three coordinates.
141 |         """
142 |         bc = F.normalize(last_three_points.c - last_three_points.b, dim=-1)
143 |         n = F.normalize(torch.cross(last_three_points.b - last_three_points.a, bc), dim=-1)
144 |         m = torch.stack([bc, torch.cross(n, bc), n]).permute(1, 2, 0)
145 | 
146 |         point = point.unsqueeze(2)  # Expand from B x 3 to B x 3 x 1 to enable bmm
147 | 
148 |         return torch.bmm(m, point).squeeze() + last_three_points.c
149 | 
150 |     @staticmethod
151 |     def validate_index(index: torch.LongTensor):
152 |         """Perform a simple validation that the index can correctly satisfy Cartesian generation
153 |         in sequential order.
154 | 
155 |         If this returns True, the indices describe a valid sequential traversal. 
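
        For example (illustrative), the default wrapped indexing from
        build_indices(length=6, size=4) yields rows
        [0, 1, 2, 3], [1, 2, 3, 4], ..., [5, 0, 1, 2],
        where each row's first three atoms place its fourth.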
156 |         """
157 |         index = index.clone()
158 |         new_index = torch.zeros(index.max() + NUM_DIMENSIONS + 1).type(torch.LongTensor)
159 | 
160 |         # We choose an arbitrary starting point and set the last three indices in new_index
161 |         for i in range(3):
162 |             index[i][: 3 - i] = torch.arange(i - 3, 0)
163 | 
164 |         # Because these are virtual nodes, they don't matter and we mark them as placed
165 |         new_index[-3:] = 1
166 | 
167 |         # Iterate to make sure each atom's three parents are already placed
168 |         for i in index:
169 |             # Assume that if the previous three are set, then we can set the fourth
170 |             if torch.all(new_index[i[:3]]):
171 |                 new_index[i[3]] = 1
172 |             else:
173 |                 print(f"index failed for {i}")
174 |                 return False
175 |         return True
176 | 
177 |     @staticmethod
178 |     def build_indices(length, size, offset=0):
179 |         a, b = torch.meshgrid(torch.arange(size), torch.arange(length), indexing="xy")
180 |         index = (a + b - offset) % length
181 |         return index
182 | 
183 |     def nerf(
184 |         self,
185 |         r_theta_phi: torch.Tensor,
186 |         quadruples: Optional[torch.LongTensor] = None,
187 |         validate_index: bool = True,
188 |     ) -> torch.Tensor:
189 |         """
190 |         Apply the Natural Extension Reference Frame method to a set of internal coords.
191 |         ----------
192 |         Args:
193 |             r_theta_phi: torch.FloatTensor,
194 |                 B x L x 3 Tensor of distances (r), angles (theta), and dihedrals (phi).
195 |             quadruples: torch.LongTensor,
196 |                 A tensor of indices corresponding to each set of internal coordinates.
197 |         Returns:
198 |             xyzs: torch.Tensor,
199 |                 B x L x 3 Tensor of x,y,z coordinates
200 |         """
201 |         if quadruples is None:
202 |             quadruples = self.build_indices(r_theta_phi.size(1), size=4)  # get quadruples
203 |         # Clone the indices
204 |         if validate_index:
205 |             self.validate_index(quadruples)
206 | 
207 |         quadruples = quadruples.clone()
208 | 
209 |         atom_indices = quadruples[
210 |             :, -1
211 |         ]  # the last column of the L x 4 index corresponds to the atom_ids
212 | 
213 |         # This creates three virtual atoms that overwrite existing ones for the start index
214 |         for i in range(3):
215 |             quadruples[i][: 3 - i] = torch.arange(i - 3, 0)
216 | 
217 |         batch_size = r_theta_phi.shape[0]
218 |         points = self.convert_to_points(
219 |             r_theta_phi[:, :, 0], r_theta_phi[:, :, 1], r_theta_phi[:, :, 2]
220 |         )
221 | 
222 |         # We create a new index to index into
223 |         new_index = torch.zeros(
224 |             quadruples.max() + NUM_DIMENSIONS + 1, batch_size, NUM_DIMENSIONS
225 |         ).type(torch.DoubleTensor)
226 |         # Index those points into the new_index to be looked up by the last col
227 |         new_index[quadruples[:, -1]] = points
228 | 
229 |         init_tensor = torch.from_numpy(self.init_matrix)
230 |         new_index[-3:, :, :] = torch.cat(
231 |             [row.repeat(1).view(1, 3) for row in init_tensor]
232 |         ).unsqueeze(1)
233 | 
234 |         coords_list = new_index.clone()
235 |         for quadruple in quadruples:
236 |             prev_three_coords = Triplet(
237 |                 coords_list[quadruple[0]], coords_list[quadruple[1]], coords_list[quadruple[2]]
238 |             )
239 |             coords_list[quadruple[3]] = self.extend(coords_list[quadruple[3]], prev_three_coords)
240 | 
241 |         # Return just the columns corresponding to the atom_indices
242 |         return (coords_list.permute(1, 0, 2))[:, atom_indices]
243 | 
244 |     def __call__(self, r_theta_phi: torch.Tensor, quadruples: Optional[torch.Tensor] = None):
245 |         """Call NeRF to perform the reconstruction."""
246 |         return self.nerf(r_theta_phi, quadruples)
247 | 
248 | 
249 | class TetraPlacer:
250 |     def calculate_bd(self, positions: torch.Tensor, bd_norm: torch.Tensor, theta: torch.Tensor):
251 |         """Calculates the 
branched atom distance for a given structure.
252 | 
253 |         Args:
254 |             positions (torch.Tensor): A tensor representing the positions of atoms.
255 |                 Dimensions: K x B x 3 x 3, with K as the batch size, B the number of structures in the batch, and 3x3 as the position vectors.
256 |             bd_norm (torch.Tensor): A tensor representing the norms corresponding to the branched atom distance. Vector based, size N x 1.
257 |             theta (torch.Tensor): A tensor representing the angle in radians. Vector based, size N x 1.
258 | 
259 |         Returns:
260 |             torch.Tensor: The calculated branched atom distance tensor.
261 |         """
262 |         a = positions[:, :, 0, :]
263 |         b = positions[:, :, 1, :]
264 |         c = positions[:, :, 2, :]
265 | 
266 |         # Calculate the main vectors
267 |         ba = a - b
268 |         bc = c - b
269 | 
270 |         centroid = torch.mean(torch.stack([a, c]), dim=0)
271 |         centroid_b = b - centroid
272 | 
273 |         a_prime = a + centroid_b
274 |         c_prime = c + centroid_b
275 | 
276 |         cross_product = torch.cross(bc, ba, dim=-1)
277 |         norm = torch.norm(cross_product, dim=-1, keepdim=True)
278 | 
279 |         normalized_cross_product = (cross_product / norm) * bd_norm.unsqueeze(-1) + b
280 | 
281 |         new_bd = self.rotate_ac(a_prime, normalized_cross_product, c_prime, theta, b_only=True)
282 | 
283 |         return new_bd
284 | 
285 |     def rotate_ac(self, a, b, c, theta, b_only: bool = True):
286 |         """Performs a batched rotation on K x B x 3 tensors.
287 | 
288 |         Args:
289 |             a, b, c (torch.Tensor): Input tensors of dimensions K x B x 3, representing initial position vectors for each batch.
290 |             theta (torch.Tensor): The angle of rotation in radians. Vector based, size N x 1.
291 |             b_only (bool): If True, only returns tensor b_rot, otherwise also returns a and c tensors. Default is True.
292 | 
293 |         Returns:
294 |             torch.Tensor: The rotated tensor b_rot or (if b_only=False) a, b_rot, c.
295 |         """
296 |         K, B, _ = a.shape
297 | 
298 |         # Expand theta for compatibility with K and B dimensions
299 |         theta = theta.unsqueeze(-1)  # Now shape (K, B, 1) for compatibility with broadcasting
300 | 
301 |         # Calculate normalized ac direction
302 |         ac = c - a
303 |         ac_norm = ac / torch.norm(ac, dim=2, keepdim=True)
304 | 
305 |         zeros = torch.zeros_like(theta)
306 |         k = ac_norm
307 | 
308 |         K_matrix = torch.cat(
309 |             [
310 |                 zeros,
311 |                 -k[..., 2:3],
312 |                 k[..., 1:2],
313 |                 k[..., 2:3],
314 |                 zeros,
315 |                 -k[..., 0:1],
316 |                 -k[..., 1:2],
317 |                 k[..., 0:1],
318 |                 zeros,
319 |             ],
320 |             dim=2,
321 |         ).reshape(
322 |             K * B, 3, 3
323 |         )  # Reshape for batch matrix multiplication
324 | 
325 |         # Calculate rotation matrix R and the three terms
326 |         R_1 = torch.eye(3, device=k.device).repeat(K * B, 1, 1)  # Repeat eye for each batch
327 |         R_2 = torch.sin(theta).reshape(K * B, 1, 1) * K_matrix
328 |         R_3 = (1 - torch.cos(theta).reshape(K * B, 1, 1)) * torch.bmm(K_matrix, K_matrix)
329 |         R = R_1 + R_2 + R_3
330 | 
331 |         # Translate b to the origin (relative to a), then rotate, then translate back
332 |         b_translated = (b - a).reshape(K * B, 3, 1)  # Reshape for batch matrix multiplication
333 |         b_rot = torch.bmm(R, b_translated).reshape(K, B, 3) + a
334 |         if b_only:
335 |             return b_rot
336 |         return a, b_rot, c  # Also returns a and c unchanged as a sanity check
337 | 
338 |     def add_branched_points(
339 |         self,
340 |         xyzs,
341 |         quad_indices: torch.LongTensor,
342 |         quad_bond_distances: torch.FloatTensor,
343 |         quad_bond_thetas: torch.FloatTensor,
344 |         copy: bool = True,
345 |     ):
346 |         """Adds a new branched point for each structure in the batch. 
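
        (Each branched point is placed with rotate_ac above, which applies the
        Rodrigues rotation formula R = I + sin(theta) * K + (1 - cos(theta)) * K @ K,
        where K is the skew-symmetric cross-product matrix of the normalized
        a->c axis, sweeping the new atom about that axis by theta.)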
347 | 348 | Args: 349 | xyzs (torch.Tensor): The current tensor of atom positions in each structure, with dimensions K x M x 3. 350 | quad_indices (torch.FloatTensor): The indices of quads in the structures, with dimensions N x 4. 351 | quad_bond_distances (torch.FloatTensor): The distances of quad bonds, with dimensions N x 1. 352 | quad_bond_thetas (torch.FloatTensor): Theta values for each quad bond, with dimensions N x 1. 353 | copy (bool): If True, creates a copy of `xyzs` to perform operations on, else operates in-place. Default is True. 354 | 355 | Returns: 356 | torch.Tensor: The tensor of atom positions after adding the branched points. 357 | """ 358 | if copy: 359 | xyzs = xyzs.clone() 360 | triples = quad_indices[:, :3] 361 | target_index = quad_indices[:, -1] 362 | 363 | positions = xyzs[:, triples, :] 364 | 365 | k = xyzs.shape[0] 366 | quad_bd = quad_bond_distances.repeat((k, 1)) 367 | quad_theta = quad_bond_thetas.repeat((k, 1)) 368 | 369 | # print(theta_mean.shape) 370 | target_xyz = self.calculate_bd(positions, quad_bd, quad_theta) 371 | xyzs[:, target_index, :] = target_xyz 372 | 373 | return xyzs 374 | 375 | 376 | class RigidTransform(object): 377 | """Implementation of Kabsch algorithm in PyTorch to handle batching. 378 | 379 | Does not handle reflections at the moment. 380 | """ 381 | 382 | def __init__(self): 383 | """R is the rotation matrix, t is the translation matrix.""" 384 | self.R = None 385 | self.t = None 386 | 387 | def fit(self, source: torch.Tensor, target: torch.Tensor) -> None: 388 | assert source.shape == target.shape 389 | 390 | # find mean row wise 391 | centroid_source = torch.mean(source, dim=1) 392 | centroid_target = torch.mean(target, dim=1) 393 | 394 | # Center the data along the centroid target 395 | source_m = source - centroid_target.unsqueeze(dim=1) 396 | target_m = target - centroid_target.unsqueeze(dim=1) 397 | 398 | H = torch.bmm(source_m.permute(0, 2, 1), target_m) 399 | U, S, Vt = torch.linalg.svd(H) 400 | 401 | R = torch.bmm(Vt.permute(0, 2, 1), U.permute(0, 2, 1)) 402 | t = -R @ centroid_source.unsqueeze(dim=2) + centroid_target.unsqueeze(dim=2) 403 | 404 | # Store these 405 | self.R = R 406 | self.t = t 407 | 408 | @staticmethod 409 | def get_reflections(R): 410 | """Not implemented at the moment.""" 411 | return torch.where(torch.linalg.det(R) < 0) 412 | 413 | def transform(self, source): 414 | return torch.bmm(source, self.R.permute(0, 2, 1)) + self.t.permute(0, 2, 1) 415 | 416 | def fit_transform(self, source: torch.Tensor, target: torch.Tensor): 417 | """Fit and transform the data.""" 418 | self.fit(source, target) 419 | return self.transform(source) 420 | 421 | @staticmethod 422 | def rmsd(source_tensor, target_tensor): 423 | """Calculate the RMSD between the source and the target tensors.""" 424 | assert source_tensor.shape[1] == target_tensor.shape[1] 425 | return ( 426 | ((source_tensor - target_tensor) ** 2).sum(-1).sum(-1) / source_tensor.shape[1] 427 | ) ** 0.5 428 | 429 | def __call__(self, source: torch.Tensor): 430 | transformed = self.transform(source) 431 | return transformed 432 | 433 | 434 | class InverseNeRF(object): 435 | """Construct the internal coordinates, r_theta_phi, matrix from xyz coordinates. 436 | 437 | Currently only works for backbones. 
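
    A minimal round-trip sketch (illustrative, assuming matching index-offset
    conventions; see reindex and convert_offsets below): for backbone
    coordinates xyz of shape (B, L, 3), r_theta_phi = InverseNeRF()(xyz)
    recovers internal coordinates, NeRF(torch.double)(r_theta_phi.double())
    rebuilds Cartesian coordinates up to a rigid-body transform, and
    RigidTransform above can align the two for an RMSD comparison.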
438 |     """
439 | 
440 |     def __init__(
441 |         self,
442 |         distances_offset: int = 1,  # 0 1
443 |         angles_offset: int = 2,  # 1 1
444 |         dihedrals_offset: int = 3,
445 |     ):  # 1 2
446 |         """
447 |         distances_offset: int = 1,
448 |             the distance from the previous atom_id to the current one.
449 |         angles_offset: int = 2,
450 |             the angle defining the atom, starting two atoms prior.
451 |         dihedrals_offset: int = 3,
452 |             get the dihedral corresponding to the atom two positions prior,
453 |             which requires finding the start of the quadruple 3 points prior.
454 |         """
455 |         self.distances_offset = distances_offset
456 |         self.angles_offset = angles_offset
457 |         self.dihedrals_offset = dihedrals_offset
458 | 
459 |     @staticmethod
460 |     def build_indices(length, size, offset: int = 0) -> torch.Tensor:
461 |         """Build wrapped indices for cycles only."""
462 |         a, b = torch.meshgrid(torch.arange(size), torch.arange(length), indexing="xy")
463 |         index = (a + b - offset) % length
464 |         return index
465 | 
466 |     def distances(self, positions, index: Optional[torch.Tensor] = None):
467 |         """Calculate the distances with tuples, starting with the offset used."""
468 |         if index is None:
469 |             index = self.build_indices(positions.size(1), 2, self.distances_offset)
470 |         doubles = positions[:, index, :]
471 | 
472 |         return torch.norm(doubles[..., 1, :] - doubles[..., 0, :], dim=-1)
473 | 
474 |     def angles(self, positions, index: Optional[torch.Tensor] = None):
475 |         """Calculate bond angles."""
476 |         if index is None:
477 |             index = self.build_indices(positions.size(1), 3, self.angles_offset)
478 |         triples = positions[:, index, :]
479 | 
480 |         a1 = triples[..., 1, :] - triples[..., 0, :]
481 |         a2 = triples[..., 2, :] - triples[..., 1, :]
482 | 
483 |         a1 = F.normalize(a1, dim=-1)
484 |         a2 = F.normalize(a2, dim=-1)
485 | 
486 |         rad = torch.pi - torch.arccos((a1 * a2).sum(dim=-1))
487 | 
488 |         return rad
489 | 
490 |     def dihedrals(self, positions, index: Optional[torch.Tensor] = None):
491 |         """Calculate dihedral angles from sets of quadruples."""
492 |         if index is None:
493 |             index = self.build_indices(
494 |                 positions.size(1), 4, self.dihedrals_offset  # dihedrals require 4 indices
495 |             )
496 |         quadruples = positions[:, index, :]
497 | 
498 |         a1 = quadruples[..., 1, :] - quadruples[..., 0, :]
499 |         a2 = quadruples[..., 2, :] - quadruples[..., 1, :]
500 |         a3 = quadruples[..., 3, :] - quadruples[..., 2, :]
501 | 
502 |         v1 = torch.cross(a1, a2, dim=-1)
503 |         v1 = F.normalize(v1, dim=-1)
504 |         v2 = torch.cross(a2, a3, dim=-1)
505 |         v2 = F.normalize(v2, dim=-1)
506 | 
507 |         sign = torch.sign((v1 * a3).sum(dim=-1))
508 |         rad = torch.arccos((v1 * v2).sum(-1) / ((v1**2).sum(-1) * (v2**2).sum(-1)) ** 0.5)
509 | 
510 |         rad = sign * rad
511 |         return rad
512 | 
513 |     def inverse_nerf(self, positions):
514 |         """Collect distances, angles, and dihedrals to generate the r_theta_phi matrix.
515 | 
516 |         Params:
517 |             positions: torch.FloatTensor,
518 |                 N x L x 3 matrix of xyz coordinates (positions).
519 |         Returns:
520 |             N x L x 3 matrix of r, theta, and phis. 
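
            With the default offsets, column i of the output is aligned with
            atom i; e.g. (illustrative) r[:, i] is the bond length between
            atoms i - 1 and i (mod L), since distances() uses the wrapped pairs
            from build_indices(L, 2, offset=1).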
521 | """ 522 | r = self.distances(positions) 523 | theta = self.angles(positions) 524 | phi = self.dihedrals(positions) 525 | r_theta_phi = torch.stack([r, theta, phi]).permute(1, 2, 0) 526 | return r_theta_phi 527 | 528 | def reindex(self, array, offset_differences, length: Optional[int] = None): 529 | """Reindex the array based on the differences in offsets.""" 530 | if length is None: 531 | length = array.size(1) 532 | new_index = self.build_reindex(offset_differences, length) 533 | return array[:, new_index, torch.arange(array.size(-1))[None, :]] 534 | 535 | def build_reindex(self, offsets, length): 536 | reindex = (torch.arange(length).reshape(-1, 1) + np.array(offsets)) % length 537 | return reindex 538 | 539 | @staticmethod 540 | def convert_offsets(source_offset, target_offset): 541 | """Convert the difference in offsets. 542 | 543 | b 544 | """ 545 | return tuple(source_offset[i] - target_offset[i] for i in range(NUM_INTERNALS)) 546 | 547 | def __call__(self, positions): 548 | return self.inverse_nerf(positions) 549 | -------------------------------------------------------------------------------- /ringer/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | import logging 5 | import multiprocessing 6 | import shutil 7 | from datetime import datetime 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Literal, Optional, Sequence, Union 10 | 11 | import pytorch_lightning as pl 12 | import torch 13 | import typer 14 | from pytorch_lightning.strategies.ddp import DDPStrategy 15 | from torch.utils.data import DataLoader 16 | from transformers.models.bert import configuration_bert 17 | 18 | import ringer 19 | from ringer.models import bert_for_diffusion 20 | from ringer.utils import data_loading, utils, variance_schedules 21 | 22 | POSITION_EMBEDDING_TYPES = Literal[ 23 | "absolute", "relative_key", "relative_key_query", "cyclic_relative_key" 24 | ] 25 | 26 | assert torch.cuda.is_available(), "Requires CUDA to train" 27 | torch.manual_seed(6489) 28 | torch.backends.cudnn.benchmark = False 29 | 30 | 31 | @pl.utilities.rank_zero_only 32 | def record_args(func_args: Dict[str, Any], results_dir: Path, overwrite: bool = False) -> None: 33 | # Create results directory 34 | if results_dir.exists(): 35 | if overwrite: 36 | logging.warning(f"Removing old results directory: {results_dir}") 37 | shutil.rmtree(results_dir) 38 | else: 39 | raise IOError(f"'{results_dir}' already exists") 40 | results_dir.mkdir() 41 | 42 | func_args_serializable = func_args.copy() 43 | for k, v in func_args_serializable.items(): 44 | if isinstance(v, Path): 45 | func_args_serializable[k] = str(v) 46 | 47 | with open(results_dir / "training_args.json", "w") as sink: 48 | logging.info(f"Writing training args to {sink.name}") 49 | json.dump(func_args_serializable, sink, indent=4) 50 | for k, v in func_args.items(): 51 | logging.info(f"Training argument: {k}={v}") 52 | 53 | 54 | def build_callbacks( 55 | out_dir: Union[str, Path], early_stop_patience: Optional[int] = None, swa: bool = False 56 | ) -> List[pl.Callback]: 57 | # Create the logging dir 58 | out_dir = Path(out_dir) 59 | best_validation_dir = out_dir / "models/best_by_valid" 60 | best_train_dir = out_dir / "models/best_by_train" 61 | (out_dir / "logs/lightning_logs").mkdir(parents=True, exist_ok=True) 62 | best_validation_dir.mkdir(parents=True, exist_ok=True) 63 | best_train_dir.mkdir(parents=True, exist_ok=True) 64 | 65 | callbacks = [ 66 | 
pl.callbacks.ModelCheckpoint(
67 |             monitor="val_loss",
68 |             dirpath=best_validation_dir,
69 |             save_top_k=5,
70 |             save_weights_only=True,
71 |             mode="min",
72 |         ),
73 |         pl.callbacks.ModelCheckpoint(
74 |             monitor="train_loss",
75 |             dirpath=best_train_dir,
76 |             save_top_k=5,
77 |             save_weights_only=True,
78 |             mode="min",
79 |         ),
80 |         pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
81 |     ]
82 | 
83 |     if early_stop_patience is not None and early_stop_patience > 0:
84 |         logging.info(f"Using early stopping with patience {early_stop_patience}")
85 |         callbacks.append(
86 |             pl.callbacks.early_stopping.EarlyStopping(
87 |                 monitor="val_loss",
88 |                 patience=early_stop_patience,
89 |                 verbose=True,
90 |                 mode="min",
91 |             )
92 |         )
93 | 
94 |     if swa:
95 |         # Stochastic weight averaging
96 |         callbacks.append(pl.callbacks.StochasticWeightAveraging())
97 |     logging.info(f"Model callbacks: {callbacks}")
98 | 
99 |     return callbacks
100 | 
101 | 
102 | def train(
103 |     # Output
104 |     out_dir: Union[str, Path] = "results",
105 |     # Data loading and noising process
106 |     data_dir: Union[
107 |         str, Path
108 |     ] = "",  # Directory containing pickle files, can be relative to ../data
109 |     split_sizes: Sequence[float] = (0.8, 0.1, 0.1),
110 |     internal_coordinates_definitions: data_loading.INTERNAL_COORDINATES_DEFINITIONS = "angles",
111 |     use_atom_features: bool = True,  # Condition model on atom sequence
112 |     atom_feature_fingerprint_radius: int = 3,  # Morgan fingerprint radius for atom side chains
113 |     atom_feature_fingerprint_size: int = 32,  # Morgan fingerprint size for atom side chains
114 |     atom_feature_embed_size: int = 24,  # Transform atom features to this size before concatenating with angle features
115 |     max_conf: int = 30,
116 |     timesteps: int = 50,
117 |     variance_schedule: variance_schedules.SCHEDULES = "cosine",
118 |     variance_scale: float = 1.0,
119 |     use_feat_mask: bool = True,
120 |     mask_noise: bool = False,
121 |     mask_noise_for_features: Optional[List[str]] = None,
122 |     # Model architecture
123 |     restart_dir: Optional[Union[str, Path]] = None,  # Restart from checkpoint
124 |     time_encoding: bert_for_diffusion.TIME_ENCODING = "gaussian_fourier",
125 |     num_hidden_layers: int = 3,
126 |     hidden_size: int = 24,
127 |     intermediate_size: int = 96,
128 |     num_heads: int = 3,
129 |     position_embedding_type: POSITION_EMBEDDING_TYPES = "cyclic_relative_key",
130 |     dropout_p: float = 0.1,
131 |     decoder: bert_for_diffusion.DECODER_HEAD = "mlp",
132 |     # Training strategy
133 |     batch_size: int = 64,
134 |     loss: bert_for_diffusion.LOSS_KEYS = "smooth_l1",
135 |     l2_norm: float = 0.0,  # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0
136 |     l1_norm: float = 0.0,
137 |     circle_reg: float = 0.0,
138 |     gradient_clip: float = 1.0,  # From BERT trainer
139 |     lr: float = 5e-5,  # Default lr for huggingface BERT trainer
140 |     lr_scheduler: bert_for_diffusion.LR_SCHEDULE = "LinearWarmup",
141 |     min_epochs: Optional[int] = None,
142 |     max_epochs: int = 2000,
143 |     warmup_epochs: int = 100,
144 |     weights: Optional[Dict[str, float]] = None,
145 |     early_stop_patience: int = 0,  # Set to 0 to disable early stopping
146 |     use_swa: bool = False,  # Stochastic weight averaging can improve training generalization
147 |     # Miscellaneous
148 |     exhaustive_validation_t: bool = False,  # Exhaustively enumerate t for validation/test
149 |     use_data_cache: bool = True,
150 |     data_cache_dir: Optional[Union[str, Path]] = None,
151 |     unsafe_cache: bool = False,
152 |     ncpu: int = multiprocessing.cpu_count(),
153 |     ngpu: int = -1,  # -1 
for all GPUs 154 | write_validation_preds: bool = False, # Write validation predictions to disk at each epoch 155 | profile: bool = False, 156 | overwrite: bool = False, # Overwrite results dir 157 | wandb_config: Optional[Dict[str, str]] = None, 158 | ) -> None: 159 | """Main training loop.""" 160 | # Record the args given to the function before we create more vars 161 | func_args = locals() 162 | 163 | assert data_dir 164 | data_dir = Path(data_dir) 165 | out_dir = Path(out_dir) 166 | record_args(func_args, out_dir, overwrite=overwrite) 167 | 168 | # Get datasets and wrap them in data_loaders 169 | dsets = data_loading.get_datasets( 170 | data_dir=data_dir, 171 | internal_coordinates_definitions=internal_coordinates_definitions, 172 | splits=["train", "validation"], 173 | split_sizes=split_sizes, 174 | use_atom_features=use_atom_features, 175 | atom_feature_fingerprint_radius=atom_feature_fingerprint_radius, 176 | atom_feature_fingerprint_size=atom_feature_fingerprint_size, 177 | max_conf=max_conf, 178 | timesteps=timesteps, 179 | weights=weights, 180 | variance_schedule=variance_schedule, 181 | variance_scale=variance_scale, 182 | mask_noise=mask_noise, 183 | mask_noise_for_features=mask_noise_for_features, 184 | exhaustive_t=exhaustive_validation_t, 185 | use_cache=use_data_cache, 186 | cache_dir=data_cache_dir, 187 | unsafe_cache=unsafe_cache, 188 | num_proc=ncpu, 189 | ) 190 | 191 | # Given total (effective) batch size, calculate batch size per GPU 192 | if torch.cuda.is_available(): 193 | device_count = torch.cuda.device_count() if ngpu == -1 else ngpu 194 | batch_size_per_device = max(int(batch_size / device_count), 1) 195 | pl.utilities.rank_zero_info( 196 | f"Given batch size: {batch_size} --> per-GPU batch size with {device_count} GPUs: {batch_size_per_device}" 197 | ) 198 | else: 199 | batch_size_per_device = batch_size 200 | 201 | data_loaders = { 202 | split: DataLoader( 203 | dataset=dset, 204 | batch_size=batch_size_per_device, 205 | shuffle=split == "train", 206 | num_workers=ncpu, 207 | pin_memory=True, 208 | ) 209 | for split, dset in dsets.items() 210 | } 211 | 212 | # Record the means in the output directory 213 | with open(out_dir / "training_mean_offset.json", "w") as sink: 214 | json.dump(dsets["train"].dset.means_dict, sink, indent=4) 215 | with open(out_dir / "training_mean_distances.json", "w") as sink: 216 | json.dump(dsets["train"].dset.atom_type_means["distance"], sink, indent=4) 217 | with open(out_dir / "training_mean_angles.json", "w") as sink: 218 | json.dump(dsets["train"].dset.atom_type_means["angle"], sink, indent=4) 219 | with open(out_dir / "training_std_angles.json", "w") as sink: 220 | json.dump(dsets["train"].dset.atom_type_stdevs["angle"], sink, indent=4) 221 | 222 | # Shape of the input is (batch_size, timesteps, features) 223 | sample_item = dsets["train"][0] 224 | sample_input = sample_item["corrupted"] 225 | model_n_inputs = sample_input.shape[-1] 226 | logging.info(f"Auto detected {model_n_inputs} inputs") 227 | 228 | if use_atom_features: 229 | sample_atom_features = sample_item["atom_features"] 230 | atom_feature_size = sample_atom_features.shape[-1] 231 | logging.info(f"Auto detected atom feature size: {atom_feature_size}") 232 | else: 233 | atom_feature_size = None 234 | 235 | logging.info(f"Using loss function: {loss}") 236 | config = configuration_bert.BertConfig( 237 | max_position_embeddings=dsets["train"].pad, 238 | num_attention_heads=num_heads, 239 | hidden_size=hidden_size, 240 | intermediate_size=intermediate_size, 241 | 
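        # BERT requires hidden_size to be divisible by num_attention_heads;
        # with the defaults above (hidden_size=24, num_heads=3), each head
        # attends in an 8-dimensional subspace.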
num_hidden_layers=num_hidden_layers,
242 |         position_embedding_type=position_embedding_type,
243 |         hidden_dropout_prob=dropout_p,
244 |         attention_probs_dropout_prob=dropout_p,
245 |         use_cache=False,
246 |     )
247 |     bert_kwargs = dict(
248 |         lr=lr,
249 |         loss=loss,
250 |         use_feat_mask=use_feat_mask,
251 |         l2=l2_norm,
252 |         l1=l1_norm,
253 |         circle_reg=circle_reg,
254 |         epochs=max_epochs,
255 |         warmup_epochs=warmup_epochs,
256 |         lr_scheduler=lr_scheduler,
257 |         write_preds_to_dir=out_dir / "validation_preds" if write_validation_preds else None,
258 |     )
259 |     if restart_dir is None:
260 |         model = bert_for_diffusion.BertForDiffusion(
261 |             config=config,
262 |             ft_is_angular=dsets["train"].feature_is_angular,
263 |             ft_names=dsets["train"].feature_names,
264 |             time_encoding=time_encoding,
265 |             decoder=decoder,
266 |             atom_feature_size=atom_feature_size,
267 |             atom_feature_embed_size=atom_feature_embed_size,
268 |             **bert_kwargs,
269 |         )
270 |     else:
271 |         model = bert_for_diffusion.BertForDiffusion.from_dir(
272 |             dir_name=restart_dir,
273 |             **bert_kwargs,
274 |         )
275 |     model.config.save_pretrained(out_dir)
276 | 
277 |     callbacks = build_callbacks(
278 |         out_dir=out_dir, early_stop_patience=early_stop_patience, swa=use_swa
279 |     )
280 | 
281 |     # Get accelerator and distributed strategy
282 |     accelerator = "cpu"
283 |     strategy = None
284 |     if torch.cuda.is_available():
285 |         accelerator = "cuda"
286 |         if torch.cuda.device_count() > 1:
287 |             # https://github.com/Lightning-AI/lightning/discussions/6761
288 |             strategy = DDPStrategy(find_unused_parameters=False)
289 | 
290 |     logging.info(f"Using {accelerator} with strategy {strategy}")
291 | 
292 |     loggers = [pl.loggers.CSVLogger(save_dir=out_dir / "logs")]
293 | 
294 |     # Set up WandB logging
295 |     if wandb_config is not None:
296 |         wandb_logger = pl.loggers.WandbLogger(**wandb_config)
297 |         if pl.utilities.rank_zero_only.rank == 0:
298 |             wandb_logger.experiment.config.update(func_args)
299 |         loggers.append(wandb_logger)
300 | 
301 |     trainer = pl.Trainer(
302 |         default_root_dir=out_dir,
303 |         gradient_clip_val=gradient_clip,
304 |         min_epochs=min_epochs,
305 |         max_epochs=max_epochs,
306 |         check_val_every_n_epoch=1,
307 |         callbacks=callbacks,
308 |         logger=loggers,
309 |         log_every_n_steps=min(50, len(data_loaders["train"])),  # Log >= once per epoch
310 |         accelerator=accelerator,
311 |         strategy=strategy,
312 |         gpus=ngpu,
313 |         enable_progress_bar=False,
314 |         move_metrics_to_cpu=False,  # Keep metrics on device to avoid extra transfers
315 |         profiler="simple" if profile else None,
316 |     )
317 |     trainer.fit(
318 |         model=model,
319 |         train_dataloaders=data_loaders["train"],
320 |         val_dataloaders=data_loaders["validation"],
321 |     )
322 | 
323 | 
324 | @utils.unwrap_typer_args
325 | def train_from_config(
326 |     config: str = typer.Argument(..., help="JSON file containing training parameters"),
327 |     out_dir: str = typer.Option("results", help="Directory to write model training outputs to"),
328 |     restart_dir: str = typer.Option(None, help="Directory to restart from"),
329 |     wandb_run: str = typer.Option(None, help="Run name for WandB logging"),
330 |     ncpu: int = typer.Option(multiprocessing.cpu_count(), help="Number of workers"),
331 |     ngpu: int = typer.Option(-1, help="Number of GPUs to use (-1 for all)"),
332 |     unsafe_cache: bool = typer.Option(
333 |         False, help="Don't check data filenames and cache hashes before loading data"
334 |     ),
335 |     profile: bool = False,
336 |     overwrite: bool = typer.Option(False, help="Overwrite output directory"),
337 
| ) -> None: 338 | curr_time = datetime.now().strftime("%y%m%d_%H%M%S") 339 | logging.basicConfig( 340 | level=logging.INFO, 341 | handlers=[ 342 | logging.FileHandler(f"training_{curr_time}.log"), 343 | logging.StreamHandler(), 344 | ], 345 | ) 346 | 347 | config_path = Path(config) 348 | if not config_path.exists(): 349 | # Assume it's a path relative to the configs folder in the top-level directory 350 | config_path = ringer.CONFIG_DIR / config_path 351 | if not config_path.exists(): 352 | raise ValueError(f"Config '{config_path}' doesn't exist") 353 | 354 | with open(config_path) as source: 355 | config_args = json.load(source) 356 | 357 | wandb_config = None 358 | if wandb_run is not None: 359 | wandb_config_path = ringer.CONFIG_DIR / "wandb/wandb.json" 360 | with open(wandb_config_path) as source: 361 | wandb_config = json.load(source) 362 | wandb_config["name"] = wandb_run 363 | 364 | config_args = utils.update_dict_nonnull( 365 | config_args, 366 | { 367 | "out_dir": out_dir, 368 | "restart_dir": restart_dir, 369 | "overwrite": overwrite, 370 | "ncpu": ncpu, 371 | "ngpu": ngpu, 372 | "unsafe_cache": unsafe_cache, 373 | "profile": profile, 374 | "wandb_config": wandb_config, 375 | }, 376 | ) 377 | 378 | train(**config_args) 379 | 380 | 381 | def main() -> None: 382 | typer.run(train_from_config) 383 | 384 | 385 | if __name__ == "__main__": 386 | main() 387 | -------------------------------------------------------------------------------- /ringer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/ringer/6ec3b5ef6b586a74fc7613e0c40ff887840dddd3/ringer/utils/__init__.py -------------------------------------------------------------------------------- /ringer/utils/chem.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Set, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from rdkit import Chem 6 | from rdkit.Chem import AllChem 7 | 8 | from . 
import utils 9 | 10 | # Carboxyl before nitrogen corresponds to N-to-C direction of peptide 11 | PEPTIDE_PATTERN = Chem.MolFromSmarts("[OX1]=[C;R][N;R]") 12 | 13 | 14 | def get_macrocycle_idxs( 15 | mol: Chem.Mol, min_size: int = 9, n_to_c: bool = True 16 | ) -> Optional[List[int]]: 17 | sssr = Chem.GetSymmSSSR(mol) 18 | if len(sssr) > 0: 19 | largest_ring = max(sssr, key=len) 20 | if len(largest_ring) >= min_size: 21 | idxs = list(largest_ring) 22 | if n_to_c: 23 | return macrocycle_idxs_in_n_to_c_direction(mol, idxs) 24 | return idxs 25 | return None 26 | 27 | 28 | def macrocycle_idxs_in_n_to_c_direction(mol: Chem.Mol, macrocycle_idxs: List[int]) -> List[int]: 29 | # Obtain carbon and nitrogen idxs in peptide bonds in the molecule 30 | matches = mol.GetSubstructMatches(PEPTIDE_PATTERN) 31 | if not matches: 32 | raise ValueError("Did not match any peptide bonds") 33 | 34 | # We match 3 atoms each time (O, C, N), just need C and N in the ring 35 | carbon_and_nitrogen_idxs = {match[1:] for match in matches} 36 | 37 | for atom_idx_pair in utils.get_overlapping_sublists(macrocycle_idxs, 2): 38 | # If the directionality of atom idxs is already in N to C direction, then pairs of these 39 | # atom indices should already be in the set of matched atoms, otherwise, we need to flip 40 | # the direction 41 | if tuple(atom_idx_pair) in carbon_and_nitrogen_idxs: 42 | break 43 | else: 44 | macrocycle_idxs = macrocycle_idxs[::-1] # Flip direction 45 | 46 | # Always start at a nitrogen 47 | nitrogen_idx = next(iter(carbon_and_nitrogen_idxs))[1] # Random nitrogen 48 | nitrogen_loc = macrocycle_idxs.index(nitrogen_idx) 49 | macrocycle_idxs = macrocycle_idxs[nitrogen_loc:] + macrocycle_idxs[:nitrogen_loc] 50 | 51 | return macrocycle_idxs 52 | 53 | 54 | def extract_macrocycle(mol: Chem.Mol) -> Chem.Mol: 55 | macrocycle_idxs = get_macrocycle_idxs(mol) 56 | if macrocycle_idxs is None: 57 | raise ValueError(f"No macrocycle detected in '{Chem.MolToSmiles(Chem.RemoveHs(mol))}'") 58 | 59 | macrocycle_idxs = set(macrocycle_idxs) 60 | to_remove = sorted( 61 | (atom.GetIdx() for atom in mol.GetAtoms() if atom.GetIdx() not in macrocycle_idxs), 62 | reverse=True, 63 | ) 64 | 65 | rwmol = Chem.RWMol(mol) 66 | for idx in to_remove: 67 | rwmol.RemoveAtom(idx) 68 | 69 | new_mol = rwmol.GetMol() 70 | return new_mol 71 | 72 | 73 | def combine_mols(mols: List[Chem.Mol]) -> Chem.Mol: 74 | """Combine multiple molecules with one conformer each into one molecule with multiple 75 | conformers. 76 | 77 | Args: 78 | mols: List of molecules. 79 | 80 | Returns: 81 | Combined molecule. 82 | """ 83 | new_mol = Chem.Mol(mols[0], quickCopy=True) 84 | for mol in mols: 85 | conf = Chem.Conformer(mol.GetConformer()) 86 | new_mol.AddConformer(conf, assignId=True) 87 | return new_mol 88 | 89 | 90 | def set_atom_positions( 91 | mol: Chem.Mol, 92 | xyzs: Union[np.ndarray, pd.DataFrame, List[np.ndarray], List[pd.DataFrame]], 93 | atom_idxs: Optional[List[int]] = None, 94 | ) -> Chem.Mol: 95 | """Set atom positions of a molecule. 96 | 97 | Args: 98 | mol: Molecule. 99 | xyzs: An array of coordinates; a dataframe with 'x', 'y', and 'z' columns (and optionally an index of atom indices); a list of arrays; or a list of dataframes. 100 | atom_idxs: Atom indices to set atom positions for. Not required if dataframe(s) contain(s) atom indices. 101 | 102 | Returns: 103 | A copy of the molecule with one conformer for each set of coordinates. 
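
        A minimal usage sketch (hypothetical values): for a DataFrame xyz whose
        index holds atom indices and whose columns include 'x', 'y', and 'z',
        set_atom_positions(mol, xyz) returns a copy of mol carrying one new
        conformer with the listed atoms moved to the given coordinates.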
104 | """ 105 | # If multiple xyzs are provided, make one conformer for each one 106 | if isinstance(xyzs, (np.ndarray, pd.DataFrame)): 107 | xyzs = [xyzs] 108 | 109 | if atom_idxs is None: 110 | assert all(isinstance(xyz, pd.DataFrame) for xyz in xyzs) 111 | 112 | # The positions that don't get set will be the same as in the first conformer of the given mol 113 | dummy_conf = mol.GetConformer() 114 | mol = Chem.Mol(mol, quickCopy=True) # Don't copy conformers 115 | 116 | for xyz in xyzs: 117 | if isinstance(xyz, pd.DataFrame): 118 | atom_idxs = xyz.index.tolist() 119 | xyz = xyz[["x", "y", "z"]].to_numpy() 120 | 121 | xyz = xyz.reshape(-1, 3) 122 | 123 | # Set only the positions at the provided indices 124 | conf = Chem.Conformer(dummy_conf) 125 | for atom_idx, pos in zip(atom_idxs, xyz): 126 | conf.SetAtomPosition(atom_idx, [float(p) for p in pos]) 127 | 128 | mol.AddConformer(conf, assignId=True) 129 | 130 | return mol 131 | 132 | 133 | def dfs( 134 | root_atom_idx: int, 135 | mol: Chem.Mol, 136 | max_depth: float = float("inf"), 137 | blocked_idxs: Optional[List[int]] = None, 138 | include_hydrogens: bool = True, 139 | ) -> List[int]: 140 | """Traverse molecular graph with depth-first search from given root atom index. 141 | 142 | Args: 143 | root_atom_idx: Root atom index. 144 | mol: Molecule. 145 | max_depth: Only traverse to this maximum depth. 146 | blocked_idxs: Don't traverse across these indices. Defaults to None. 147 | include_hydrogens: Include hydrogen atom indices in returned list. 148 | 149 | Returns: 150 | List of traversed atom indices in DFS order. 151 | """ 152 | root_atom = mol.GetAtomWithIdx(root_atom_idx) 153 | if blocked_idxs is not None: 154 | blocked_idxs = set(blocked_idxs) 155 | return _dfs( 156 | root_atom, 157 | max_depth=max_depth, 158 | blocked_idxs=blocked_idxs, 159 | include_hydrogens=include_hydrogens, 160 | ) 161 | 162 | 163 | def _dfs( 164 | atom: Chem.Atom, # Start from atom so we don't have to get it from index each time 165 | depth: int = 0, 166 | max_depth: float = float("inf"), 167 | blocked_idxs: Optional[Set[int]] = None, 168 | include_hydrogens: bool = True, 169 | visited: Optional[Set[int]] = None, 170 | traversal: Optional[List[int]] = None, 171 | ) -> List[int]: 172 | if visited is None: 173 | visited = set() 174 | if traversal is None: 175 | traversal = [] 176 | 177 | if include_hydrogens or atom.GetAtomicNum() != 1: 178 | atom_idx = atom.GetIdx() 179 | visited.add(atom_idx) 180 | traversal.append(atom_idx) 181 | 182 | if depth < max_depth: 183 | for atom_nei in atom.GetNeighbors(): 184 | atom_nei_idx = atom_nei.GetIdx() 185 | if atom_nei_idx not in visited: 186 | if blocked_idxs is None or atom_nei_idx not in blocked_idxs: 187 | _dfs( 188 | atom_nei, 189 | depth=depth + 1, 190 | max_depth=max_depth, 191 | blocked_idxs=blocked_idxs, 192 | include_hydrogens=include_hydrogens, 193 | visited=visited, 194 | traversal=traversal, 195 | ) 196 | 197 | return traversal 198 | -------------------------------------------------------------------------------- /ringer/utils/data/amino_acids.csv: -------------------------------------------------------------------------------- 1 | aa,smiles,residue_smiles,alpha_carbon_stereo,n-methylation 2 | A,C[C@H](N)C(=O)O,C[C@H](N)C=O,L,False 3 | R,N=C(N)NCCC[C@H](N)C(=O)O,N=C(N)NCCC[C@H](N)C=O,L,False 4 | N,NC(=O)C[C@H](N)C(=O)O,NC(=O)C[C@H](N)C=O,L,False 5 | D,N[C@@H](CC(=O)O)C(=O)O,N[C@H](C=O)CC(=O)O,L,False 6 | C,N[C@@H](CS)C(=O)O,N[C@H](C=O)CS,L,False 7 | Q,NC(=O)CC[C@H](N)C(=O)O,NC(=O)CC[C@H](N)C=O,L,False 8 |
E,N[C@@H](CCC(=O)O)C(=O)O,N[C@H](C=O)CCC(=O)O,L,False 9 | G,NCC(=O)O,NCC=O,,False 10 | H,N[C@@H](Cc1c[nH]cn1)C(=O)O,N[C@H](C=O)Cc1c[nH]cn1,L,False 11 | I,CC[C@H](C)[C@H](N)C(=O)O,CC[C@H](C)[C@H](N)C=O,L,False 12 | L,CC(C)C[C@H](N)C(=O)O,CC(C)C[C@H](N)C=O,L,False 13 | K,NCCCC[C@H](N)C(=O)O,NCCCC[C@H](N)C=O,L,False 14 | M,CSCC[C@H](N)C(=O)O,CSCC[C@H](N)C=O,L,False 15 | F,N[C@@H](Cc1ccccc1)C(=O)O,N[C@H](C=O)Cc1ccccc1,L,False 16 | P,O=C(O)[C@@H]1CCCN1,O=C[C@@H]1CCCN1,L,False 17 | S,N[C@@H](CO)C(=O)O,N[C@H](C=O)CO,L,False 18 | T,C[C@@H](O)[C@H](N)C(=O)O,C[C@@H](O)[C@H](N)C=O,L,False 19 | W,N[C@@H](Cc1c[nH]c2ccccc12)C(=O)O,N[C@H](C=O)Cc1c[nH]c2ccccc12,L,False 20 | Y,N[C@@H](Cc1ccc(O)cc1)C(=O)O,N[C@H](C=O)Cc1ccc(O)cc1,L,False 21 | V,CC(C)[C@H](N)C(=O)O,CC(C)[C@H](N)C=O,L,False 22 | a,C[C@@H](N)C(=O)O,C[C@@H](N)C=O,D,False 23 | r,N=C(N)NCCC[C@@H](N)C(=O)O,N=C(N)NCCC[C@@H](N)C=O,D,False 24 | n,NC(=O)C[C@@H](N)C(=O)O,NC(=O)C[C@@H](N)C=O,D,False 25 | d,N[C@H](CC(=O)O)C(=O)O,N[C@@H](C=O)CC(=O)O,D,False 26 | c,N[C@H](CS)C(=O)O,N[C@@H](C=O)CS,D,False 27 | q,NC(=O)CC[C@@H](N)C(=O)O,NC(=O)CC[C@@H](N)C=O,D,False 28 | e,N[C@H](CCC(=O)O)C(=O)O,N[C@@H](C=O)CCC(=O)O,D,False 29 | h,N[C@H](Cc1c[nH]cn1)C(=O)O,N[C@@H](C=O)Cc1c[nH]cn1,D,False 30 | i,CC[C@@H](C)[C@@H](N)C(=O)O,CC[C@@H](C)[C@@H](N)C=O,D,False 31 | l,CC(C)C[C@@H](N)C(=O)O,CC(C)C[C@@H](N)C=O,D,False 32 | k,NCCCC[C@@H](N)C(=O)O,NCCCC[C@@H](N)C=O,D,False 33 | m,CSCC[C@@H](N)C(=O)O,CSCC[C@@H](N)C=O,D,False 34 | f,N[C@H](Cc1ccccc1)C(=O)O,N[C@@H](C=O)Cc1ccccc1,D,False 35 | p,O=C(O)[C@H]1CCCN1,O=C[C@H]1CCCN1,D,False 36 | s,N[C@H](CO)C(=O)O,N[C@@H](C=O)CO,D,False 37 | t,C[C@H](O)[C@@H](N)C(=O)O,C[C@H](O)[C@@H](N)C=O,D,False 38 | w,N[C@H](Cc1c[nH]c2ccccc12)C(=O)O,N[C@@H](C=O)Cc1c[nH]c2ccccc12,D,False 39 | y,N[C@H](Cc1ccc(O)cc1)C(=O)O,N[C@@H](C=O)Cc1ccc(O)cc1,D,False 40 | v,CC(C)[C@@H](N)C(=O)O,CC(C)[C@@H](N)C=O,D,False 41 | MeA,CN[C@@H](C)C(=O)O,CN[C@@H](C)C=O,L,True 42 | MeR,CN[C@@H](CCCNC(=N)N)C(=O)O,CN[C@H](C=O)CCCNC(=N)N,L,True 43 | MeN,CN[C@@H](CC(N)=O)C(=O)O,CN[C@H](C=O)CC(N)=O,L,True 44 | MeD,CN[C@@H](CC(=O)O)C(=O)O,CN[C@H](C=O)CC(=O)O,L,True 45 | MeC,CN[C@@H](CS)C(=O)O,CN[C@H](C=O)CS,L,True 46 | MeQ,CN[C@@H](CCC(N)=O)C(=O)O,CN[C@H](C=O)CCC(N)=O,L,True 47 | MeE,CN[C@@H](CCC(=O)O)C(=O)O,CN[C@H](C=O)CCC(=O)O,L,True 48 | MeG,CNCC(=O)O,CNCC=O,,True 49 | MeH,CN[C@@H](Cc1c[nH]cn1)C(=O)O,CN[C@H](C=O)Cc1c[nH]cn1,L,True 50 | MeI,CC[C@H](C)[C@H](NC)C(=O)O,CC[C@H](C)[C@@H](C=O)NC,L,True 51 | MeL,CN[C@@H](CC(C)C)C(=O)O,CN[C@H](C=O)CC(C)C,L,True 52 | MeK,CN[C@@H](CCCCN)C(=O)O,CN[C@H](C=O)CCCCN,L,True 53 | MeM,CN[C@@H](CCSC)C(=O)O,CN[C@H](C=O)CCSC,L,True 54 | MeF,CN[C@@H](Cc1ccccc1)C(=O)O,CN[C@H](C=O)Cc1ccccc1,L,True 55 | MeS,CN[C@@H](CO)C(=O)O,CN[C@H](C=O)CO,L,True 56 | MeT,CN[C@H](C(=O)O)[C@@H](C)O,CN[C@H](C=O)[C@@H](C)O,L,True 57 | MeW,CN[C@@H](Cc1c[nH]c2ccccc12)C(=O)O,CN[C@H](C=O)Cc1c[nH]c2ccccc12,L,True 58 | MeY,CN[C@@H](Cc1ccc(O)cc1)C(=O)O,CN[C@H](C=O)Cc1ccc(O)cc1,L,True 59 | MeV,CN[C@H](C(=O)O)C(C)C,CN[C@H](C=O)C(C)C,L,True 60 | Mea,CN[C@H](C)C(=O)O,CN[C@H](C)C=O,D,True 61 | Mer,CN[C@H](CCCNC(=N)N)C(=O)O,CN[C@@H](C=O)CCCNC(=N)N,D,True 62 | Men,CN[C@H](CC(N)=O)C(=O)O,CN[C@@H](C=O)CC(N)=O,D,True 63 | Med,CN[C@H](CC(=O)O)C(=O)O,CN[C@@H](C=O)CC(=O)O,D,True 64 | Mec,CN[C@H](CS)C(=O)O,CN[C@@H](C=O)CS,D,True 65 | Meq,CN[C@H](CCC(N)=O)C(=O)O,CN[C@@H](C=O)CCC(N)=O,D,True 66 | Mee,CN[C@H](CCC(=O)O)C(=O)O,CN[C@@H](C=O)CCC(=O)O,D,True 67 | Meh,CN[C@H](Cc1c[nH]cn1)C(=O)O,CN[C@@H](C=O)Cc1c[nH]cn1,D,True 68 | Mei,CC[C@@H](C)[C@@H](NC)C(=O)O,CC[C@@H](C)[C@H](C=O)NC,D,True 69 
| Mel,CN[C@H](CC(C)C)C(=O)O,CN[C@@H](C=O)CC(C)C,D,True 70 | Mek,CN[C@H](CCCCN)C(=O)O,CN[C@@H](C=O)CCCCN,D,True 71 | Mem,CN[C@H](CCSC)C(=O)O,CN[C@@H](C=O)CCSC,D,True 72 | Mef,CN[C@H](Cc1ccccc1)C(=O)O,CN[C@@H](C=O)Cc1ccccc1,D,True 73 | Mes,CN[C@H](CO)C(=O)O,CN[C@@H](C=O)CO,D,True 74 | Met,CN[C@@H](C(=O)O)[C@H](C)O,CN[C@@H](C=O)[C@H](C)O,D,True 75 | Mew,CN[C@H](Cc1c[nH]c2ccccc12)C(=O)O,CN[C@@H](C=O)Cc1c[nH]c2ccccc12,D,True 76 | Mey,CN[C@H](Cc1ccc(O)cc1)C(=O)O,CN[C@@H](C=O)Cc1ccc(O)cc1,D,True 77 | Mev,CN[C@@H](C(=O)O)C(C)C,CN[C@@H](C=O)C(C)C,D,True 78 | -------------------------------------------------------------------------------- /ringer/utils/data_loading.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | from pathlib import Path 4 | from typing import Dict, List, Literal, Optional, Sequence, Union 5 | 6 | import numpy as np 7 | 8 | import ringer 9 | 10 | from .. import data 11 | from ..data import noised 12 | from . import variance_schedules 13 | 14 | INTERNAL_COORDINATES_DEFINITIONS = Literal[ 15 | "distances-angles", "angles", "dihedrals", "angles-sidechains" 16 | ] 17 | 18 | 19 | def get_datasets( 20 | data_dir: Union[str, Path], 21 | internal_coordinates_definitions: INTERNAL_COORDINATES_DEFINITIONS = "angles", 22 | splits: Sequence[str] = ("train", "validation", "test"), 23 | split_sizes: Sequence[float] = (0.8, 0.1, 0.1), 24 | use_atom_features: bool = True, 25 | atom_feature_fingerprint_radius: int = 3, 26 | atom_feature_fingerprint_size: int = 32, 27 | max_conf: Union[int, str] = 30, 28 | timesteps: int = 50, 29 | weights: Optional[Dict[str, float]] = None, 30 | variance_schedule: variance_schedules.SCHEDULES = "cosine", 31 | variance_scale: float = np.pi, 32 | mask_noise: bool = False, 33 | mask_noise_for_features: Optional[List[str]] = None, 34 | exhaustive_t: bool = False, 35 | use_cache: bool = True, 36 | cache_dir: Optional[Union[str, Path]] = None, 37 | unsafe_cache: bool = False, 38 | num_proc: int = multiprocessing.cpu_count(), 39 | sample_seed: int = 42, 40 | ) -> Dict[str, noised.NoisedDataset]: 41 | """Get the dataset objects to use for train/valid/test. 
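A typical call might look like the following (an illustrative sketch only; the directory name and argument values are assumptions, not defaults taken from a config): `dsets = get_datasets("cremp", internal_coordinates_definitions="angles", splits=("train", "validation"), timesteps=50)` returns a dict mapping each split name to a noised dataset, e.g. `dsets["train"]`.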
42 | 43 | Note, these need to be wrapped in data loaders later 44 | """ 45 | data_dir = Path(data_dir) 46 | if not data_dir.exists(): 47 | # Assume it's a path relative to the data folder in the top-level directory 48 | data_dir = ringer.DATA_DIR / data_dir 49 | if not data_dir.exists(): 50 | raise ValueError(f"Data directory '{data_dir}' doesn't exist") 51 | 52 | clean_dset_class = data.DATASET_CLASSES[internal_coordinates_definitions] 53 | logging.info(f"Clean dataset class: {clean_dset_class}") 54 | 55 | logging.info(f"Creating data splits: {splits}") 56 | clean_dset_kwargs = dict( 57 | data_dir=data_dir, 58 | use_atom_features=use_atom_features, 59 | fingerprint_radius=atom_feature_fingerprint_radius, 60 | fingerprint_size=atom_feature_fingerprint_size, 61 | split_sizes=split_sizes, 62 | num_conf=max_conf, 63 | all_confs_in_test=True, 64 | weights=weights, 65 | zero_center=True, 66 | use_cache=use_cache, 67 | unsafe_cache=unsafe_cache, 68 | num_proc=num_proc, 69 | sample_seed=sample_seed, 70 | ) 71 | if cache_dir is not None: 72 | clean_dset_kwargs["cache_dir"] = cache_dir 73 | clean_dsets = {split: clean_dset_class(split=split, **clean_dset_kwargs) for split in splits} 74 | 75 | # Set the validation set means to the training set means 76 | if len(clean_dsets) > 1 and clean_dsets["train"].means_dict is not None: 77 | logging.info(f"Updating validation/test means to {clean_dsets['train'].means}") 78 | for split, dset in clean_dsets.items(): 79 | if split != "train": 80 | dset.means = clean_dsets["train"].means_dict 81 | 82 | logging.info(f"Using {noised.NoisedDataset} for noise") 83 | noised_dsets = { 84 | split: noised.NoisedDataset( 85 | dset=dset, 86 | dset_key="angles", 87 | timesteps=timesteps, 88 | exhaustive_t=(split != "train") and exhaustive_t, 89 | beta_schedule=variance_schedule, 90 | nonangular_variance=1.0, 91 | angular_variance=variance_scale, 92 | mask_noise=mask_noise, 93 | mask_noise_for_features=mask_noise_for_features, 94 | ) 95 | for split, dset in clean_dsets.items() 96 | } 97 | for split, dset in noised_dsets.items(): 98 | logging.info(f"{split}: {dset}") 99 | 100 | return noised_dsets 101 | -------------------------------------------------------------------------------- /ringer/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from collections import defaultdict 3 | from functools import partial 4 | from typing import Dict, List, Optional, Sequence, Tuple, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from rdkit import Chem 9 | from rdkit.Chem import AllChem 10 | 11 | from . 
import chem, internal_coords 12 | 13 | 14 | def compute_cov_mat_metrics( 15 | confusion_mat: np.ndarray, 16 | thresholds: np.ndarray = np.arange(0.05, 1.55, 0.05), 17 | ) -> Dict[str, Union[Dict[str, float], Dict[str, np.ndarray]]]: 18 | # confusion_mat: num_ref x num_gen 19 | ref_min = confusion_mat.min(axis=1) 20 | gen_min = confusion_mat.min(axis=0) 21 | cov_thresh = ref_min.reshape(-1, 1) <= thresholds.reshape(1, -1) 22 | jnk_thresh = gen_min.reshape(-1, 1) <= thresholds.reshape(1, -1) 23 | 24 | cov_r = cov_thresh.mean(axis=0) 25 | mat_r = ref_min.mean() 26 | cov_p = jnk_thresh.mean(axis=0) 27 | mat_p = gen_min.mean() 28 | 29 | cov = {"threshold": thresholds, "cov-r": cov_r, "cov-p": cov_p} 30 | mat = {"mat-r": mat_r, "mat-p": mat_p} 31 | 32 | return {"cov": cov, "mat": mat} 33 | 34 | 35 | def get_atom_map(probe_mol: Chem.Mol, ref_mol: Chem.Mol) -> List[Dict[int, int]]: 36 | ref_mol_idxs = [a.GetIdx() for a in ref_mol.GetAtoms()] 37 | matches = probe_mol.GetSubstructMatches( 38 | ref_mol, uniquify=False 39 | ) # Don't uniquify to account for symmetry 40 | atom_map = [dict(zip(match, ref_mol_idxs)) for match in matches] 41 | return atom_map 42 | 43 | 44 | def compute_rmsd_matrix( 45 | probe_mol: Chem.Mol, 46 | ref_mol: Chem.Mol, 47 | conf_ids_probe: Optional[List[int]] = None, 48 | conf_ids_ref: Optional[List[int]] = None, 49 | ncpu: int = 1, 50 | ) -> np.ndarray: 51 | probe_mol = Chem.RemoveHs(probe_mol) 52 | ref_mol = Chem.RemoveHs(ref_mol) 53 | 54 | # Precompute atom map for faster RMSD calculation 55 | atom_map = get_atom_map(probe_mol, ref_mol) 56 | atom_map = [list(map_.items()) for map_ in atom_map] 57 | 58 | rmsd_mat = compute_rmsd_matrix_from_map( 59 | probe_mol, 60 | ref_mol, 61 | atom_map, 62 | conf_ids_probe=conf_ids_probe, 63 | conf_ids_ref=conf_ids_ref, 64 | ncpu=ncpu, 65 | ) 66 | return rmsd_mat # [num_ref x num_probe] 67 | 68 | 69 | def compute_ring_rmsd_matrix( 70 | probe_mol: Chem.Mol, 71 | ref_mol: Chem.Mol, 72 | conf_ids_probe: Optional[List[int]] = None, 73 | conf_ids_ref: Optional[List[int]] = None, 74 | ncpu: int = 1, 75 | ) -> np.ndarray: 76 | probe_mol = Chem.RemoveHs(probe_mol) 77 | ref_mol = Chem.RemoveHs(ref_mol) 78 | 79 | # Precompute atom map for faster RMSD calculation 80 | atom_map = get_atom_map(probe_mol, ref_mol) 81 | 82 | # Subset to macrocycle indices 83 | probe_macrocycle_idxs = chem.get_macrocycle_idxs(probe_mol, n_to_c=False) 84 | if probe_macrocycle_idxs is None: 85 | raise ValueError( 86 | f"Macrocycle indices could not be determined for '{Chem.MolToSmiles(probe_mol)}'" 87 | ) 88 | atom_map = list(set(tuple((k, map_[k]) for k in probe_macrocycle_idxs) for map_ in atom_map)) 89 | 90 | # Check to make sure we have all macrocycle indices in the ref mol 91 | ref_macrocycle_idxs = chem.get_macrocycle_idxs(ref_mol, n_to_c=False) 92 | if ref_macrocycle_idxs is None: 93 | raise ValueError( 94 | f"Macrocycle indices could not be determined for '{Chem.MolToSmiles(ref_mol)}'" 95 | ) 96 | ref_macrocycle_idxs = set(ref_macrocycle_idxs) 97 | for map_ in atom_map: 98 | if set(ref_idx for _, ref_idx in map_) != ref_macrocycle_idxs: 99 | raise ValueError("Inconsistent macrocycle indices") 100 | 101 | rmsd_mat = compute_rmsd_matrix_from_map( 102 | probe_mol, 103 | ref_mol, 104 | atom_map, 105 | conf_ids_probe=conf_ids_probe, 106 | conf_ids_ref=conf_ids_ref, 107 | ncpu=ncpu, 108 | ) 109 | return rmsd_mat # [num_ref x num_probe] 110 | 111 | 112 | def compute_rmsd_matrix_from_map( 113 | probe_mol: Chem.Mol, 114 | ref_mol: Chem.Mol, 115 | atom_map: 
List[Sequence[Tuple[int, int]]], 116 | conf_ids_probe: Optional[List[int]] = None, 117 | conf_ids_ref: Optional[List[int]] = None, 118 | ncpu: int = 1, 119 | ) -> np.ndarray: 120 | if conf_ids_ref is None: 121 | conf_ids_ref = [conf.GetId() for conf in ref_mol.GetConformers()] 122 | if conf_ids_probe is None: 123 | conf_ids_probe = [conf.GetId() for conf in probe_mol.GetConformers()] 124 | 125 | num_ref = len(conf_ids_ref) 126 | num_probe = len(conf_ids_probe) 127 | 128 | _get_best_rms_args = [] 129 | for i, ref_id in enumerate(conf_ids_ref): 130 | for j, probe_id in enumerate(conf_ids_probe): 131 | _get_best_rms_args.append((i, j, probe_id, ref_id)) 132 | 133 | pfunc = partial(_get_best_rms, probe_mol=probe_mol, ref_mol=ref_mol, atom_map=atom_map) 134 | with multiprocessing.Pool(ncpu) as pool: 135 | rmsds_with_idxs = pool.map(pfunc, _get_best_rms_args) 136 | 137 | rmsd_mat = np.empty((num_ref, num_probe)) 138 | for i, j, rmsd in rmsds_with_idxs: 139 | rmsd_mat[i, j] = rmsd 140 | 141 | return rmsd_mat # [num_ref x num_probe] 142 | 143 | 144 | def _get_best_rms( 145 | args: Tuple[int, int, int, int], 146 | probe_mol: Chem.Mol, 147 | ref_mol: Chem.Mol, 148 | atom_map: List[Sequence[Tuple[int, int]]], 149 | ) -> Tuple[int, int, float]: 150 | i, j, probe_id, ref_id = args 151 | rmsd = AllChem.GetBestRMS(probe_mol, ref_mol, prbId=probe_id, refId=ref_id, map=atom_map) 152 | return i, j, rmsd 153 | 154 | 155 | def compute_ring_tfd_matrix(probe_mol: Chem.Mol, ref_mol: Chem.Mol, ncpu: int = 1) -> np.ndarray: 156 | """Compute ring torsion fingerprint deviation as in 157 | https://doi.org/10.1021/acs.jcim.0c00025.""" 158 | probe_mol = Chem.RemoveHs(probe_mol) 159 | ref_mol = Chem.RemoveHs(ref_mol) 160 | 161 | # Precompute atom map for faster RMSD calculation 162 | ref_mol_idxs = [a.GetIdx() for a in ref_mol.GetAtoms()] 163 | matches = probe_mol.GetSubstructMatches( 164 | ref_mol, uniquify=False 165 | ) # Don't uniquify to account for symmetry 166 | atom_map = [dict(zip(match, ref_mol_idxs)) for match in matches] 167 | 168 | # Subset to macrocycle indices 169 | probe_macrocycle_idxs = chem.get_macrocycle_idxs(probe_mol, n_to_c=False) 170 | if probe_macrocycle_idxs is None: 171 | raise ValueError( 172 | f"Macrocycle indices could not be determined for '{Chem.MolToSmiles(probe_mol)}'" 173 | ) 174 | atom_map = list(set(tuple((k, map_[k]) for k in probe_macrocycle_idxs) for map_ in atom_map)) 175 | 176 | num_torsions = len(probe_macrocycle_idxs) 177 | num_conf_probe = probe_mol.GetNumConformers() 178 | num_conf_ref = ref_mol.GetNumConformers() 179 | 180 | probe_torsions = internal_coords.get_macrocycle_dihedrals(probe_mol, probe_macrocycle_idxs) 181 | probe_torsions = probe_torsions.to_numpy() 182 | probe_torsions_tiled = np.tile(probe_torsions[np.newaxis, ...], (num_conf_ref, 1, 1)) 183 | 184 | ref_macrocycle_idxs_check = chem.get_macrocycle_idxs(ref_mol, n_to_c=False) 185 | if ref_macrocycle_idxs_check is None: 186 | raise ValueError( 187 | f"Macrocycle indices could not be determined for '{Chem.MolToSmiles(ref_mol)}'" 188 | ) 189 | ref_macrocycle_idxs_check = set(ref_macrocycle_idxs_check) 190 | 191 | ref_macrocycle_idxs_list = [] 192 | for map_ in atom_map: 193 | ref_macrocycle_idxs = [ref_idx for _, ref_idx in map_] # Same order as in probe 194 | if set(ref_macrocycle_idxs) != ref_macrocycle_idxs_check: 195 | raise ValueError("Inconsistent macrocycle indices") 196 | ref_macrocycle_idxs_list.append(ref_macrocycle_idxs) 197 | 198 | # There could be multiple maps, so select the minimum TFD for each one 
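# Each symmetry-equivalent atom map below yields one [num_ref x num_probe] TFD matrix; the maps are evaluated in parallel and reduced with an elementwise minimum so that each conformer pair is scored by its most favorable atom assignment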
199 | pfunc = partial( 200 | _get_ring_tfd_for_ref_idxs, 201 | ref_mol=ref_mol, 202 | num_conf_probe=num_conf_probe, 203 | probe_torsions_tiled=probe_torsions_tiled, 204 | num_torsions=num_torsions, 205 | ) 206 | with multiprocessing.Pool(ncpu) as pool: 207 | tfds = pool.map(pfunc, ref_macrocycle_idxs_list) 208 | tfd = np.minimum.reduce(tfds) 209 | 210 | return tfd # [num_ref x num_probe] 211 | 212 | 213 | def _get_ring_tfd_for_ref_idxs( 214 | ref_macrocycle_idxs: List[int], 215 | ref_mol: Chem.Mol, 216 | num_conf_probe: int, 217 | probe_torsions_tiled: np.ndarray, 218 | num_torsions: int, 219 | ) -> np.ndarray: 220 | ref_torsions = internal_coords.get_macrocycle_dihedrals(ref_mol, ref_macrocycle_idxs) 221 | ref_torsions = ref_torsions.to_numpy() 222 | 223 | # Compute deviations between all pairs of conformers 224 | ref_torsions_tiled = np.tile(ref_torsions[:, np.newaxis, :], (1, num_conf_probe, 1)) 225 | torsion_deviation = probe_torsions_tiled - ref_torsions_tiled 226 | 227 | # Wrap deviation around [-pi, pi] range 228 | torsion_deviation = (torsion_deviation + np.pi) % (2 * np.pi) - np.pi 229 | 230 | # Scale by max deviation, sum across torsions, and normalize 231 | tfd = np.sum(np.abs(torsion_deviation) / np.pi, axis=-1) / num_torsions 232 | 233 | return tfd 234 | 235 | 236 | class CovMatEvaluator: 237 | confusion_mat_funcs = { 238 | "rmsd": compute_rmsd_matrix, # heavy atoms 239 | "ring-rmsd": compute_ring_rmsd_matrix, 240 | "ring-tfd": compute_ring_tfd_matrix, 241 | } 242 | thresholds = { 243 | "rmsd": np.arange(0, 2.51, 0.01), 244 | "ring-rmsd": np.arange(0, 1.26, 0.01), 245 | "ring-tfd": np.arange(0, 1.01, 0.01), # Can't be larger than 1 246 | } 247 | 248 | def __init__(self, metrics: Sequence[str] = ("ring-rmsd", "ring-tfd")) -> None: 249 | for name in metrics: 250 | if name not in self.confusion_mat_funcs: 251 | raise NotImplementedError(f"Metric '{name}' is not implemented") 252 | self.metric_names = metrics 253 | 254 | def __call__( 255 | self, probe_mol: Chem.Mol, ref_mol: Chem.Mol, ncpu: int = 1 256 | ) -> Dict[str, Dict[str, Union[Dict[str, float], Dict[str, np.ndarray]]]]: 257 | metrics = {} 258 | for metric_name in self.metric_names: 259 | confusion_mat_func = self.confusion_mat_funcs[metric_name] 260 | confusion_mat = confusion_mat_func(probe_mol, ref_mol, ncpu=ncpu) 261 | metric = compute_cov_mat_metrics( 262 | confusion_mat, thresholds=self.thresholds[metric_name] 263 | ) 264 | metrics[metric_name] = metric 265 | return metrics 266 | 267 | @staticmethod 268 | def stack_results( 269 | metrics_dict: Dict[ 270 | str, Dict[str, Dict[str, Union[Dict[str, float], Dict[str, np.ndarray]]]] 271 | ] 272 | ) -> Dict[str, Dict[str, pd.DataFrame]]: 273 | # Stack all COV/MAT for each metric 274 | cov = defaultdict(dict) 275 | mat = defaultdict(list) 276 | 277 | for fname, results in metrics_dict.items(): 278 | # results = {'metric1': {'cov': ..., 'mat': ...}, ...} 279 | for metric_name, result in results.items(): 280 | # result = {'cov': {'threshold': ..., 'cov-r': ..., 'cov-p': ...}, 'mat': {'mat-r': ..., 'mat-p': ...}} 281 | cov[metric_name][fname] = pd.DataFrame(result["cov"]).set_index("threshold") 282 | mat[metric_name].append(pd.DataFrame(result["mat"], index=[fname])) 283 | 284 | cov = {metric_name: pd.concat(results) for metric_name, results in cov.items()} 285 | mat = {metric_name: pd.concat(results) for metric_name, results in mat.items()} 286 | 287 | # {'metric1': {'cov': pd.DataFrame, 'mat': pd.DataFrame}, ...} 288 | results = { 289 | metric_name: {"cov": 
cov[metric_name], "mat": mat[metric_name]} 290 | for metric_name in cov.keys() 291 | } 292 | 293 | return results 294 | 295 | @staticmethod 296 | def aggregate_results_for_metric(results: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]: 297 | return { 298 | "cov": results["cov"].groupby(level="threshold").agg(func=["mean", "median"]), 299 | "mat": results["mat"].agg(func=["mean", "median"]), 300 | } 301 | 302 | @classmethod 303 | def aggregate_results( 304 | cls, results: Dict[str, Dict[str, pd.DataFrame]] 305 | ) -> Dict[str, Dict[str, pd.DataFrame]]: 306 | return { 307 | metric_name: cls.aggregate_results_for_metric(r) for metric_name, r in results.items() 308 | } 309 | -------------------------------------------------------------------------------- /ringer/utils/featurization.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | from typing import Any, List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from rdkit import Chem 8 | from rdkit.Chem import rdFingerprintGenerator 9 | from rdkit.Chem.rdchem import ChiralType, HybridizationType 10 | 11 | from . import chem, peptides 12 | 13 | ATOMIC_NUMS = list(range(1, 100)) 14 | PEPTIDE_CHIRAL_TAGS = { 15 | "L": 1, 16 | "D": -1, 17 | None: 0, 18 | } 19 | CHIRAL_TAGS = { 20 | ChiralType.CHI_TETRAHEDRAL_CW: -1, 21 | ChiralType.CHI_TETRAHEDRAL_CCW: 1, 22 | ChiralType.CHI_UNSPECIFIED: 0, 23 | ChiralType.CHI_OTHER: 0, 24 | } 25 | HYBRIDIZATION_TYPES = [ 26 | HybridizationType.SP, 27 | HybridizationType.SP2, 28 | HybridizationType.SP3, 29 | HybridizationType.SP3D, 30 | HybridizationType.SP3D2, 31 | ] 32 | DEGREES = [0, 1, 2, 3, 4, 5] 33 | VALENCES = [0, 1, 2, 3, 4, 5, 6] 34 | NUM_HYDROGENS = [0, 1, 2, 3, 4] 35 | FORMAL_CHARGES = [-2, -1, 0, 1, 2] 36 | RING_SIZES = [3, 4, 5, 6, 7, 8] 37 | NUM_RINGS = [0, 1, 2, 3] 38 | 39 | ATOMIC_NUM_FEATURE_NAMES = [f"anum{anum}" for anum in ATOMIC_NUMS] + ["anumUNK"] 40 | CHIRAL_TAG_FEATURE_NAME = "chiraltag" 41 | AROMATICITY_FEATURE_NAME = "aromatic" 42 | HYBRIDIZATION_TYPE_FEATURE_NAMES = [f"hybrid{ht}" for ht in HYBRIDIZATION_TYPES] + ["hybridUNK"] 43 | DEGREE_FEATURE_NAMES = [f"degree{d}" for d in DEGREES] + ["degreeUNK"] 44 | VALENCE_FEATURE_NAMES = [f"valence{v}" for v in VALENCES] + ["valenceUNK"] 45 | NUM_HYDROGEN_FEATURE_NAMES = [f"numh{nh}" for nh in NUM_HYDROGENS] + ["numhUNK"] 46 | FORMAL_CHARGE_FEATURE_NAMES = [f"charge{c}" for c in FORMAL_CHARGES] + ["chargeUNK"] 47 | RING_SIZE_FEATURE_NAMES = [f"ringsize{rs}" for rs in RING_SIZES] # Don't need "unknown" name 48 | NUM_RING_FEATURE_NAMES = [f"numring{nr}" for nr in NUM_RINGS] + ["numringUNK"] 49 | 50 | 51 | def one_k_encoding(value: Any, choices: List[Any], include_unknown: bool = True) -> List[int]: 52 | """Create a one-hot encoding with an extra category for uncommon values. 53 | 54 | Args: 55 | value: The value whose position in the encoding should be set to 1. 56 | choices: A list of possible values. 57 | include_unknown: Add the extra category for uncommon values. 58 | 59 | Returns: 60 | A one-hot encoding of the `value` in a list of length len(`choices`) + 1 (or len(`choices`) if `include_unknown` is False). 61 | If `value` is not in `choices`, then the final element in the encoding is 1 62 | (if `include_unknown` is True).
63 | """ 64 | encoding = [0] * (len(choices) + include_unknown) 65 | try: 66 | idx = choices.index(value) 67 | except ValueError: 68 | if include_unknown: 69 | idx = -1 70 | else: 71 | raise ValueError( 72 | f"Cannot encode '{value}' because it is not in the list of possible values {choices}" 73 | ) 74 | encoding[idx] = 1 75 | 76 | return encoding 77 | 78 | 79 | def featurize_macrocycle_atoms( 80 | mol: Chem.Mol, 81 | macrocycle_idxs: Optional[List[int]] = None, 82 | use_peptide_stereo: bool = True, 83 | residues_in_mol: Optional[List[str]] = None, 84 | include_side_chain_fingerprint: bool = True, 85 | radius: int = 3, 86 | size: int = 2048, 87 | ) -> pd.DataFrame: 88 | """Create a sequence of features for each atom in `macrocycle_idxs`. 89 | 90 | Args: 91 | mol: Macrocycle molecule. 92 | macrocycle_idxs: Atom indices for atoms in the macrocycle. 93 | use_peptide_stereo: Use L/D chiral tags instead of RDKit tags. 94 | residues_in_mol: Residues the mol is composed of. Speeds up determining L/D tags. 95 | include_side_chain_fingerprint: Add Morgan count fingerprints. 96 | radius: Morgan fingerprint radius. 97 | size: Morgan fingerprint size. 98 | 99 | Returns: 100 | DataFrame where each row is an atom in the macrocycle and each column is a feature. 101 | """ 102 | if macrocycle_idxs is None: 103 | macrocycle_idxs = chem.get_macrocycle_idxs(mol) 104 | if macrocycle_idxs is None: 105 | raise ValueError( 106 | f"Couldn't get macrocycle indices for '{Chem.MolToSmiles(Chem.RemoveHs(mol))}'" 107 | ) 108 | 109 | atom_features = {} 110 | ring_info = mol.GetRingInfo() 111 | morgan_fingerprint_generator = rdFingerprintGenerator.GetMorganGenerator( 112 | radius=radius, fpSize=size, includeChirality=True 113 | ) 114 | fingerprint_feature_names = [f"fp{i}" for i in range(size)] 115 | 116 | if use_peptide_stereo: 117 | residues = peptides.get_residues( 118 | mol, residues_in_mol=residues_in_mol, macrocycle_idxs=macrocycle_idxs 119 | ) 120 | atom_to_residue = { 121 | atom_idx: symbol for atom_idxs, symbol in residues.items() for atom_idx in atom_idxs 122 | } 123 | 124 | for atom_idx in macrocycle_idxs: 125 | atom_feature_dict = {} 126 | atom = mol.GetAtomWithIdx(atom_idx) 127 | 128 | atomic_num_onehot = one_k_encoding(atom.GetAtomicNum(), ATOMIC_NUMS) 129 | atom_feature_dict.update(dict(zip(ATOMIC_NUM_FEATURE_NAMES, atomic_num_onehot))) 130 | 131 | chiral_feature = CHIRAL_TAGS[atom.GetChiralTag()] 132 | if use_peptide_stereo: 133 | # Only label an atom with the residue L/D tag if the atom is a chiral center 134 | if chiral_feature != 0: 135 | chiral_feature = PEPTIDE_CHIRAL_TAGS[ 136 | peptides.get_amino_acid_stereo(atom_to_residue[atom_idx]) 137 | ] 138 | atom_feature_dict[CHIRAL_TAG_FEATURE_NAME] = chiral_feature 139 | 140 | atom_feature_dict[AROMATICITY_FEATURE_NAME] = 1 if atom.GetIsAromatic() else 0 141 | 142 | hybridization_onehot = one_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPES) 143 | atom_feature_dict.update(dict(zip(HYBRIDIZATION_TYPE_FEATURE_NAMES, hybridization_onehot))) 144 | 145 | degree_onehot = one_k_encoding(atom.GetTotalDegree(), DEGREES) 146 | atom_feature_dict.update(dict(zip(DEGREE_FEATURE_NAMES, degree_onehot))) 147 | 148 | valence_onehot = one_k_encoding(atom.GetTotalValence(), VALENCES) 149 | atom_feature_dict.update(dict(zip(VALENCE_FEATURE_NAMES, valence_onehot))) 150 | 151 | num_hydrogen_onehot = one_k_encoding( 152 | atom.GetTotalNumHs(includeNeighbors=True), NUM_HYDROGENS 153 | ) 154 | atom_feature_dict.update(dict(zip(NUM_HYDROGEN_FEATURE_NAMES, 
num_hydrogen_onehot))) 155 | 156 | charge_onehot = one_k_encoding(atom.GetFormalCharge(), FORMAL_CHARGES) 157 | atom_feature_dict.update(dict(zip(FORMAL_CHARGE_FEATURE_NAMES, charge_onehot))) 158 | 159 | in_ring_sizes = [int(ring_info.IsAtomInRingOfSize(atom_idx, size)) for size in RING_SIZES] 160 | atom_feature_dict.update(dict(zip(RING_SIZE_FEATURE_NAMES, in_ring_sizes))) 161 | 162 | num_rings_onehot = one_k_encoding(int(ring_info.NumAtomRings(atom_idx)), NUM_RINGS) 163 | atom_feature_dict.update(dict(zip(NUM_RING_FEATURE_NAMES, num_rings_onehot))) 164 | 165 | if include_side_chain_fingerprint: 166 | # Fingerprint includes atom in ring that side chain starts at 167 | side_chain_idxs = chem.dfs(atom_idx, mol, blocked_idxs=macrocycle_idxs) 168 | fingerprint = morgan_fingerprint_generator.GetCountFingerprintAsNumPy( 169 | mol, fromAtoms=side_chain_idxs 170 | ) 171 | fingerprint = np.asarray(fingerprint.astype(np.int64), dtype=int) 172 | atom_feature_dict.update(dict(zip(fingerprint_feature_names, fingerprint))) 173 | 174 | atom_features[atom_idx] = atom_feature_dict 175 | 176 | atom_features = pd.DataFrame(atom_features).T 177 | atom_features.index.name = "atom_idx" 178 | 179 | return atom_features 180 | 181 | 182 | def featurize_macrocycle_atoms_from_file( 183 | path: Union[str, Path], 184 | use_peptide_stereo: bool = True, 185 | residues_in_mol: Optional[List[str]] = None, 186 | include_side_chain_fingerprint: bool = True, 187 | radius: int = 3, 188 | size: int = 2048, 189 | return_mol: bool = False, 190 | ) -> Union[pd.DataFrame, Tuple[Chem.Mol, pd.DataFrame]]: 191 | with open(path, "rb") as f: 192 | ensemble_data = pickle.load(f) 193 | mol = ensemble_data["rd_mol"] 194 | 195 | features = featurize_macrocycle_atoms( 196 | mol, 197 | use_peptide_stereo=use_peptide_stereo, 198 | residues_in_mol=residues_in_mol, 199 | include_side_chain_fingerprint=include_side_chain_fingerprint, 200 | radius=radius, 201 | size=size, 202 | ) 203 | 204 | if return_mol: 205 | return mol, features 206 | return features 207 | -------------------------------------------------------------------------------- /ringer/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import utils 4 | 5 | 6 | def radian_l1_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 7 | """Computes the loss between input and target. 8 | 9 | >>> radian_l1_loss(torch.tensor(0.1), 2 * torch.pi) 10 | tensor(0.1000) 11 | >>> radian_l1_loss(torch.tensor(0.1), torch.tensor(2 * torch.pi - 0.1)) 12 | tensor(0.2000) 13 | """ 14 | # https://stackoverflow.com/questions/1878907/how-can-i-find-the-difference-between-two-angles 15 | target = target % (2 * torch.pi) 16 | input = input % (2 * torch.pi) 17 | d = target - input 18 | d = (d + torch.pi) % (2 * torch.pi) - torch.pi 19 | retval = torch.abs(d) 20 | return torch.mean(retval) 21 | 22 | 23 | def radian_smooth_l1_loss( 24 | input: torch.Tensor, 25 | target: torch.Tensor, 26 | beta: float = 1.0, 27 | circle_penalty: float = 0.0, 28 | ) -> torch.Tensor: 29 | """Smooth radian L1 loss.
30 | 31 | if the abs(delta) < beta --> 0.5 * delta^2 / beta 32 | else --> abs(delta) - 0.5 * beta 33 | 34 | See: 35 | https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#smooth_l1_loss 36 | >>> radian_smooth_l1_loss(torch.tensor(-17.0466), torch.tensor(-1.3888), beta=0.1) 37 | tensor(3.0414) 38 | """ 39 | assert target.shape == input.shape, f"Mismatched shapes: {input.shape} != {target.shape}" 40 | assert beta > 0 41 | d = target - input 42 | d = utils.modulo_with_wrapped_range(d, -torch.pi, torch.pi) 43 | 44 | abs_d = torch.abs(d) 45 | retval = torch.where(abs_d < beta, 0.5 * (d**2) / beta, abs_d - 0.5 * beta) 46 | assert torch.all(retval >= 0), f"Got negative loss terms: {torch.min(retval)}" 47 | retval = torch.mean(retval) 48 | 49 | # Regularize on "turns" around the circle 50 | if circle_penalty > 0: 51 | retval += circle_penalty * torch.mean( 52 | torch.div(torch.abs(input), torch.pi, rounding_mode="trunc") 53 | ) 54 | 55 | return retval 56 | -------------------------------------------------------------------------------- /ringer/utils/peptides.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, FrozenSet, List, Optional 3 | 4 | import pandas as pd 5 | from rdkit import Chem 6 | 7 | from . import chem 8 | 9 | AMINO_ACID_DATA_PATH = Path(__file__).resolve().parent / "data/amino_acids.csv" 10 | AMINO_ACID_DATA = pd.read_csv(AMINO_ACID_DATA_PATH, index_col="aa") 11 | AMINO_ACID_DATA["residue_mol"] = AMINO_ACID_DATA["residue_smiles"].map(Chem.MolFromSmiles) 12 | 13 | RING_PEPTIDE_BOND_PATTERN = Chem.MolFromSmarts("[C;R:0](=[OX1:1])[C;R:2][N;R:3]") 14 | 15 | GENERIC_AMINO_ACID_SMARTS = "[$([CX3](=[OX1]))][NX3,NX4+][$([CX4H]([CX3](=[OX1])[O,N]))][*]" 16 | 17 | # These don't match all atoms in the side chains, but only the ones we're interested in 18 | # extracting internal coordinates for 19 | SIDE_CHAIN_TORSIONS_SMARTS_DICT = { 20 | "alanine": "[CH3X4]", 21 | "asparagine": "[CH2X4][$([CX3](=[OX1])[NX3H2])][NX3H2]", 22 | "aspartic acid": "[CH2X4][$([CX3](=[OX1])[OH0-,OH])][OH0-,OH]", 23 | "cysteine": "[CH2X4][SX2H,SX1H0-]", 24 | "glutamic acid": "[CH2X4][CH2X4][$([CX3](=[OX1])[OH0-,OH])][OH0-,OH]", 25 | "glutamine": "[CH2X4][CH2X4][$([CX3](=[OX1])[NX3H2])][NX3H2]", 26 | "histidine": "[CH2X4][$([#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1)]:[#6X3H]", 27 | "isoleucine": "[$([CHX4]([CH3X4])[CH2X4][CH3X4])][CH2X4][CH3X4]", 28 | "leucine": "[CH2X4][$([CHX4]([CH3X4])[CH3X4])][CH3X4]", 29 | "lysine": "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", 30 | "phenylalanine": "[CH2X4][$([cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1)][cX3H]", 31 | "serine": "[CH2X4][OX2H]", 32 | "threonine": "[$([CHX4]([OX2H])[CH3X4])][CH3X4]", 33 | "tryptophan": "[CH2X4][$([cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12)][cX3H0]", 34 | "tyrosine": "[CH2X4][$([cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1)][cX3H]", 35 | "valine": "[$([CHX4]([CH3X4])[CH3X4])][CH3X4]", 36 | } 37 | AMINO_ACID_TORSIONS_SMARTS_DICT = { 38 | name: GENERIC_AMINO_ACID_SMARTS.replace("[*]", smarts) 39 | for name, smarts in SIDE_CHAIN_TORSIONS_SMARTS_DICT.items() 40 | } 41 | # Handle proline separately because it doesn't fit the generic amino acid template. 
42 | # Make sure we only match three backbone atoms and the beta carbon because we only want 43 | # to model the torsion coming out of the ring 44 | AMINO_ACID_TORSIONS_SMARTS_DICT[ 45 | "proline" 46 | ] = "[$([CX3](=[OX1]))][$([$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N])][$([CX4H]1[CH2][CH2][CH2][$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1)][$([CX4H2]1[CH2][CH2][$([NX3H,NX4H2+]),$([NX3](C)(C)(C))][CX4H]1)]" 47 | AMINO_ACID_TORSIONS_PATTERNS = { 48 | name: Chem.MolFromSmarts(smarts) for name, smarts in AMINO_ACID_TORSIONS_SMARTS_DICT.items() 49 | } 50 | # These will be matched twice; we only want to keep one match 51 | AMINO_ACIDS_WITH_SYMMETRY = {"leucine", "phenylalanine", "tyrosine", "valine"} 52 | 53 | 54 | def get_amino_acid_stereo(symbol: str) -> Optional[str]: 55 | # None for glycine 56 | stereo = AMINO_ACID_DATA.loc[symbol]["alpha_carbon_stereo"] 57 | return stereo if isinstance(stereo, str) else None 58 | 59 | 60 | def get_residues( 61 | mol: Chem.Mol, 62 | residues_in_mol: Optional[List[str]] = None, 63 | macrocycle_idxs: Optional[List[int]] = None, 64 | ) -> Dict[FrozenSet[int], str]: 65 | """ 66 | Find the residues in a molecule by matching to a known dataset of amino acids. 67 | Note: This function isn't inherently restricted to macrocycles and would require only a 68 | little tweaking to work for general peptides. 69 | 70 | Args: 71 | mol: Macrocycle molecule. 72 | residues_in_mol: If known, this list of residues speeds up the matching process. 73 | macrocycle_idxs: Atom indices for atoms in the macrocycle backbone. 74 | 75 | Returns: 76 | Mapping from atom indices in a residue to its residue label. 77 | """ 78 | if macrocycle_idxs is None: 79 | macrocycle_idxs = chem.get_macrocycle_idxs(mol) 80 | if macrocycle_idxs is None: 81 | raise ValueError( 82 | f"Couldn't get macrocycle indices for '{Chem.MolToSmiles(Chem.RemoveHs(mol))}'" 83 | ) 84 | 85 | # Note: An alternative to the below algorithm would be to first find all the atom indices in 86 | # each residue by running DFS from each backbone atom in the residue, extract the residues as 87 | # new mols, and then match them to the known residues (e.g., by SMILES). 88 | # If the residues in the molecule are not known a priori, such an approach would be much 89 | # faster. However, since we generally assume that the residue information is available, we use 90 | # the algorithm below, which avoids the more tedious step of extracting each residue as a new mol.
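# The approach used here instead: (1) find the backbone atoms of each residue via the ring peptide-bond pattern, (2) substructure-match each candidate residue SMILES against the mol (chirality-aware) to collect possible residue atom sets, and (3) recover every residue's full atom set by DFS from its backbone atoms (blocked at the macrocycle) and look that set up among the matches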
91 | 92 | backbone_idxs = mol.GetSubstructMatches(RING_PEPTIDE_BOND_PATTERN) 93 | if residues_in_mol is None: 94 | potential_residues = AMINO_ACID_DATA.index 95 | else: 96 | potential_residues = residues_in_mol 97 | 98 | potential_residue_idxs = {} 99 | for residue in set(potential_residues): 100 | residue_data = AMINO_ACID_DATA.loc[residue] 101 | # This might match a partial side chain, e.g., glycine will match all side chains 102 | # Will match charged and uncharged side chains if the residue SMILES does not have charges 103 | # Using chirality for this match is very important to distinguish L- and D-amino acids 104 | residue_matches = mol.GetSubstructMatches(residue_data["residue_mol"], useChirality=True) 105 | potential_residue_idxs.update({frozenset(match): residue for match in residue_matches}) 106 | 107 | # Because we might have partial matches, we need to find all atom indices in each residue in 108 | # order to compare to the matched residues and get the residue label 109 | residue_idxs = [ 110 | frozenset( 111 | side_chain_idx 112 | for atom_idx in atom_idxs 113 | for side_chain_idx in chem.dfs( 114 | atom_idx, mol, blocked_idxs=macrocycle_idxs, include_hydrogens=False 115 | ) 116 | ) 117 | for atom_idxs in backbone_idxs 118 | ] 119 | 120 | residue_dict = {} 121 | for atom_idxs in residue_idxs: 122 | try: 123 | residue = potential_residue_idxs[atom_idxs] 124 | except KeyError: 125 | raise Exception( 126 | f"Cannot determine residue for backbone indices '{list(atom_idxs)}' of '{Chem.MolToSmiles(Chem.RemoveHs(mol))}'" 127 | ) 128 | else: 129 | residue_dict[atom_idxs] = residue 130 | 131 | return residue_dict 132 | 133 | 134 | def get_side_chain_torsion_idxs(mol: Chem.Mol) -> Dict[int, List[int]]: 135 | """Get the indices of atoms in the side chains that we want to calculate internal coordinates 136 | for. 137 | 138 | Args: 139 | mol: Molecule. 140 | 141 | Returns: 142 | Mapping from alpha-carbon atom index to its side-chain indices. 143 | """ 144 | side_chain_torsion_idxs = {} 145 | 146 | for amino_acid_name, pattern in AMINO_ACID_TORSIONS_PATTERNS.items(): 147 | matches = mol.GetSubstructMatches(pattern) 148 | if matches: 149 | if amino_acid_name in AMINO_ACIDS_WITH_SYMMETRY: 150 | # Take every 2nd match 151 | assert len(matches) % 2 == 0 152 | matches = matches[::2] 153 | 154 | for match in matches: 155 | # Alpha carbon is 3rd matched atom 156 | alpha_carbon = match[2] 157 | assert alpha_carbon not in side_chain_torsion_idxs 158 | side_chain_torsion_idxs[alpha_carbon] = list(match) 159 | 160 | return side_chain_torsion_idxs 161 | -------------------------------------------------------------------------------- /ringer/utils/plotting.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union 3 | 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import seaborn as sns 9 | from astropy.visualization import LogStretch 10 | from astropy.visualization.mpl_normalize import ImageNormalize 11 | from matplotlib.offsetbox import AnchoredText 12 | from scipy import interpolate 13 | 14 | from . 
import utils 15 | 16 | 17 | def plot_ramachandran( 18 | data: pd.DataFrame, 19 | x: str = "phi", 20 | y: str = "psi", 21 | col: Optional[str] = None, 22 | col_wrap: Optional[int] = None, 23 | col_order: Optional[List[str]] = None, 24 | row: Optional[str] = None, 25 | row_order: Optional[List[str]] = None, 26 | height: int = 4, 27 | plot_type: Literal["density_scatter", "hexbin", "sns.scatterplot"] = "density_scatter", 28 | remove_axis_text: bool = False, 29 | path: Optional[Union[str, Path]] = None, 30 | **kwargs, 31 | ) -> sns.FacetGrid: 32 | g = sns.FacetGrid( 33 | data, 34 | col=col, 35 | col_wrap=col_wrap, 36 | col_order=col_order, 37 | row=row, 38 | row_order=row_order, 39 | sharex=True, 40 | sharey=True, 41 | height=height, 42 | aspect=1, 43 | despine=False, 44 | margin_titles=row is not None, 45 | gridspec_kws=dict(wspace=0.05, hspace=0.05), 46 | ) 47 | name, *attrs = plot_type.split(".") 48 | plot_func = globals()[name] 49 | for attr in attrs: 50 | plot_func = getattr(plot_func, attr) 51 | g.map(plot_func, x, y, **kwargs) 52 | format_facet_grid_dihedral_axis(g, which="x") 53 | format_facet_grid_dihedral_axis(g, which="y") 54 | g.set_xlabels(r"$\phi$") 55 | g.set_ylabels(r"$\psi$") 56 | g.tick_params(axis="both", labelsize=8) 57 | if row is None: 58 | g.set_titles(template="{col_name}") 59 | else: 60 | g.set_titles(col_template="{col_name}", row_template="{row_name}") 61 | for ax in g.axes.flat: 62 | ax.set_aspect("equal") 63 | if remove_axis_text: 64 | ax.axes.get_xaxis().set_visible(False) 65 | ax.axes.get_yaxis().set_visible(False) 66 | if path is not None: 67 | g.figure.savefig(path, dpi=300, bbox_inches="tight", pad_inches=0.01) 68 | return g 69 | 70 | 71 | def plot_distributions( 72 | data: pd.DataFrame, 73 | x: str, 74 | col: Optional[str] = None, 75 | col_wrap: Optional[int] = None, 76 | row: Optional[str] = None, 77 | hue: Optional[str] = None, 78 | binwidth: float = 2 * np.pi / 100, 79 | sharex: bool = True, 80 | sharey: bool = False, 81 | height: int = 3, 82 | format_as_dihedral_axis: bool = True, 83 | xlabel: Optional[str] = None, 84 | latex: bool = True, 85 | path: Optional[Union[str, Path]] = None, 86 | add_kl_div: bool = True, 87 | **kwargs, 88 | ) -> sns.FacetGrid: 89 | g = sns.FacetGrid( 90 | data, 91 | hue=hue, 92 | col=col, 93 | col_wrap=col_wrap, 94 | row=row, 95 | sharex=sharex, 96 | sharey=sharey, 97 | height=height, 98 | margin_titles=row is not None, 99 | **kwargs, 100 | ) 101 | g.map(sns.histplot, x, stat="density", binwidth=binwidth, linewidth=0.2) 102 | if row is None: 103 | g.set_titles("") 104 | else: 105 | g.set_titles(col_template="", row_template="{row_name}") 106 | if format_as_dihedral_axis: 107 | format_facet_grid_dihedral_axis(g) 108 | if xlabel is None: 109 | for label, ax in g.axes_dict.items(): 110 | if row is None: 111 | if latex: 112 | ax.set_xlabel(rf"$\{label}$") 113 | else: 114 | ax.set_xlabel(label) 115 | else: 116 | if latex: 117 | ax.set_xlabel(rf"$\{label[1]}$") 118 | else: 119 | ax.set_xlabel(label[1]) 120 | else: 121 | g.set_xlabels(label=xlabel) 122 | g.set_ylabels(label="Density") 123 | g.tick_params(axis="y", labelsize=8) 124 | g.add_legend(title="", loc="upper left") 125 | if add_kl_div: 126 | 127 | def add_kldiv_to_plot(ax, kl_div): 128 | at = AnchoredText( 129 | f"KL = {kl_div:.4f}", prop=dict(size=10), frameon=True, loc="upper right" 130 | ) 131 | at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2") 132 | at.patch.set(alpha=0.8, edgecolor=(0.8, 0.8, 0.8, 0.8)) 133 | ax.add_artist(at) 134 | 135 | if col is None and row is 
None: 136 | kl_div = utils.compute_kl_divergence_from_dataframe(data, x) 137 | add_kldiv_to_plot(g.ax, kl_div[x]) 138 | else: 139 | by = list(filter(None, [row, col])) 140 | kl_divs = data.groupby(by).apply( 141 | lambda d: utils.compute_kl_divergence_from_dataframe(d, x) 142 | ) 143 | for label, ax in g.axes_dict.items(): 144 | if kl_divs is not None: 145 | kl_div = kl_divs.loc[label, x] 146 | add_kldiv_to_plot(ax, kl_div) 147 | g.figure.tight_layout(pad=0.01, h_pad=0.1, w_pad=0.1) 148 | if path is not None: 149 | g.figure.savefig(path, dpi=600, bbox_inches="tight", pad_inches=0.01) 150 | return g 151 | 152 | 153 | def plot_coverage( 154 | data: pd.DataFrame, 155 | x: str = "threshold", 156 | y: str = "cov", 157 | hue: str = "cov-type", 158 | col: str = "src", 159 | col_order: Sequence[str] = ("RMSD", "TFD"), 160 | path: Optional[Union[str, Path]] = None, 161 | **kwargs, 162 | ) -> sns.FacetGrid: 163 | g = sns.FacetGrid( 164 | data, 165 | hue=hue, 166 | col=col, 167 | col_order=col_order, 168 | sharex=False, 169 | height=4, 170 | despine=False, 171 | xlim=(0, data[x].max()), 172 | ylim=(0, 100), 173 | legend_out=False, 174 | **kwargs, 175 | ) 176 | g.map(sns.lineplot, x, y) 177 | xlabels = {"RMSD": "Threshold (Å)", "TFD": "Threshold"} 178 | for label, ax in g.axes_dict.items(): 179 | ax.set_xlabel(xlabels[label]) 180 | if label == "TFD": 181 | ax.set_xlim(0, 1) 182 | g.set_ylabels("Coverage (%)") 183 | g.set_titles(template="{col_name}") 184 | g.add_legend(title="") 185 | if path is not None: 186 | g.figure.savefig(path, dpi=600, bbox_inches="tight", pad_inches=0.01) 187 | return g 188 | 189 | 190 | def plot_ramachandran_plots( 191 | phi_psi_data: pd.DataFrame, 192 | plot_dir: Optional[Union[str, Path]] = None, 193 | name: str = "ramachandran", 194 | ext: str = ".png", 195 | col_order: Sequence[str] = ("Test", "Sampled"), 196 | as_rows: bool = False, 197 | residues: bool = False, 198 | ) -> None: 199 | if plot_dir is not None: 200 | plot_dir = Path(plot_dir) 201 | fname = f"{name}_residues{ext}" if residues else f"{name}{ext}" 202 | ramachandran_path = plot_dir / fname 203 | else: 204 | ramachandran_path = None 205 | 206 | col = "src" 207 | col_order = list(col_order) 208 | row = "num_residues" if residues else None 209 | row_order = ["4 residues", "5 residues", "6 residues"] if residues else None 210 | if as_rows: 211 | row, col = col, row 212 | row_order, col_order = col_order, row_order 213 | 214 | plot_ramachandran( 215 | phi_psi_data, 216 | col=col, 217 | col_order=col_order, 218 | row=row, 219 | row_order=row_order, 220 | height=3, 221 | s=1, 222 | edgecolors="none", 223 | cmap="magma", 224 | path=ramachandran_path, 225 | ) 226 | 227 | 228 | def plot_angle_and_dihedral_distributions( 229 | data: pd.DataFrame, 230 | plot_dir: Optional[Union[str, Path]] = None, 231 | ext: str = ".png", 232 | residues: bool = False, 233 | ) -> None: 234 | def get_path(name: str) -> Optional[Path]: 235 | if plot_dir is None: 236 | return None 237 | if residues: 238 | name += "_residues" 239 | fname = f"{name}{ext}" 240 | return Path(plot_dir) / fname 241 | 242 | if "angle" in data.columns: 243 | plot_distributions( 244 | data=data, 245 | x="angle", 246 | row="num_residues" if residues else None, 247 | row_order=["4 residues", "5 residues", "6 residues"] if residues else None, 248 | hue="src", 249 | binwidth=np.pi / 200, 250 | format_as_dihedral_axis=False, 251 | xlim=(1.5, 2.5), 252 | xlabel="Bond angle", 253 | height=2 if residues else 3, 254 | aspect=1.5, 255 | legend_out=False, 256 | despine=False, 
257 | path=get_path("angle_dist"), 258 | ) 259 | plot_distributions( 260 | data=data, 261 | x="angle", 262 | row="num_residues" if residues else None, 263 | row_order=["4 residues", "5 residues", "6 residues"] if residues else None, 264 | col="angle_label", 265 | hue="src", 266 | height=2 if residues else 3, 267 | aspect=1.4 if residues else 1, 268 | binwidth=np.pi / 200, 269 | format_as_dihedral_axis=False, 270 | xlim=(1.5, 2.5), 271 | legend_out=False, 272 | despine=False, 273 | path=get_path("angles_dists"), 274 | ) 275 | 276 | if "dihedral" in data.columns: 277 | plot_distributions( 278 | data=data, 279 | x="dihedral", 280 | row="num_residues" if residues else None, 281 | row_order=["4 residues", "5 residues", "6 residues"] if residues else None, 282 | hue="src", 283 | binwidth=2 * np.pi / 60, 284 | xlabel="Dihedral angle", 285 | height=2 if residues else 3, 286 | aspect=1.5, 287 | legend_out=False, 288 | despine=False, 289 | path=get_path("dihedral_dist"), 290 | ) 291 | plot_distributions( 292 | data=data, 293 | x="dihedral", 294 | row="num_residues" if residues else None, 295 | row_order=["4 residues", "5 residues", "6 residues"] if residues else None, 296 | col="dihedral_label", 297 | hue="src", 298 | binwidth=2 * np.pi / 60, 299 | height=2 if residues else 3, 300 | aspect=1.4 if residues else 1, 301 | legend_out=False, 302 | despine=False, 303 | path=get_path("dihedrals_dists"), 304 | ) 305 | 306 | 307 | def plot_side_chain_distributions( 308 | data: pd.DataFrame, 309 | plot_dir: Optional[Union[str, Path]] = None, 310 | ext: str = ".png", 311 | ) -> None: 312 | def get_path(name: str) -> Optional[Path]: 313 | if plot_dir is None: 314 | return None 315 | fname = f"{name}{ext}" 316 | return Path(plot_dir) / fname 317 | 318 | plot_distributions( 319 | data=data[data["feature"].str.startswith("sc_a")], 320 | x="value", 321 | col="feature", 322 | col_wrap=5, 323 | hue="src", 324 | height=2, 325 | aspect=1, 326 | binwidth=np.pi / 200, 327 | format_as_dihedral_axis=False, 328 | xlabel="value", 329 | latex=False, 330 | xlim=(1.5, 2.5), 331 | legend_out=False, 332 | despine=False, 333 | path=get_path("sidechain_angle_dists"), 334 | ) 335 | plot_distributions( 336 | data=data[data["feature"].str.startswith("sc_chi")], 337 | x="value", 338 | col="feature", 339 | col_wrap=5, 340 | hue="src", 341 | height=2, 342 | aspect=1, 343 | binwidth=2 * np.pi / 60, 344 | format_as_dihedral_axis=True, 345 | latex=False, 346 | legend_out=False, 347 | despine=False, 348 | path=get_path("sidechain_dihedral_dists"), 349 | ) 350 | 351 | 352 | def format_facet_grid_dihedral_axis(g: sns.FacetGrid, which: Literal["x", "y"] = "x") -> None: 353 | lim = (-np.pi, np.pi) 354 | ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi] 355 | ticklabels = [r"-$\pi$", r"-$\pi/2$", 0, r"$\pi/2$", r"$\pi$"] 356 | g.set(**{f"{which}lim": lim, f"{which}ticks": ticks, f"{which}ticklabels": ticklabels}) 357 | 358 | 359 | def hexbin( 360 | x: np.ndarray, 361 | y: np.ndarray, 362 | color: Optional[Union[str, Tuple[float, float, float]]] = None, 363 | gridsize: int = 50, 364 | bins: Union[Literal["log"], int, Sequence[float]] = "log", 365 | **kwargs, 366 | ): 367 | cmap = kwargs.pop("cmap", None) 368 | if cmap is None and color is not None: 369 | cmap = sns.light_palette(color, as_cmap=True) 370 | plt.hexbin(x, y, gridsize=gridsize, bins=bins, cmap=cmap, **kwargs) 371 | 372 | 373 | def density_scatter( 374 | x: Union[np.ndarray, pd.Series], 375 | y: Union[np.ndarray, pd.Series], 376 | ax: Optional[mpl.axes.Axes] = None, 377 | bins: 
Tuple[int, int] = (500, 500), 378 | color: Optional[Union[str, Tuple[float, float, float]]] = None, 379 | norm: ImageNormalize = ImageNormalize(vmin=0, vmax=1, stretch=LogStretch()), 380 | **kwargs, 381 | ) -> mpl.axes.Axes: 382 | # https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib 383 | if ax is None: 384 | ax = plt.gca() 385 | 386 | x = np.asarray(x) 387 | y = np.asarray(y) 388 | 389 | hist, xedges, yedges = np.histogram2d(x, y, bins=bins, density=True) 390 | points = (0.5 * (xedges[1:] + xedges[:-1]), 0.5 * (yedges[1:] + yedges[:-1])) 391 | points, hist = augment_with_periodic_bc(points, hist, domain=(2 * np.pi, 2 * np.pi)) 392 | z = interpolate.interpn( 393 | points, 394 | hist, 395 | np.vstack([x, y]).T, 396 | method="splinef2d", 397 | bounds_error=False, 398 | fill_value=0, 399 | ) 400 | 401 | # Sort the points by density, so that the densest points are plotted last 402 | idx = z.argsort() 403 | x, y, z = x[idx], y[idx], z[idx] 404 | 405 | cmap = kwargs.pop("cmap", None) 406 | if cmap is None and color is not None: 407 | cmap = sns.light_palette(color, as_cmap=True) 408 | 409 | ax.scatter(x, y, c=z, cmap=cmap, norm=norm, **kwargs) 410 | return ax 411 | 412 | 413 | def augment_with_periodic_bc( 414 | points: Tuple[np.ndarray, ...], 415 | values: np.ndarray, 416 | domain: Optional[Union[float, Sequence[float]]] = None, 417 | ) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]: 418 | """Augment the data to create periodic boundary conditions. 419 | 420 | Parameters 421 | ---------- 422 | points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, ) 423 | The points defining the regular grid in n dimensions. 424 | values : array_like, shape (m1, ..., mn, ...) 425 | The data on the regular grid in n dimensions. 426 | domain : float or None or array_like of shape (n, ) 427 | The size of the domain along each of the n dimensions 428 | or a uniform domain size along all dimensions if a 429 | scalar. Using None specifies aperiodic boundary conditions. 430 | 431 | Returns 432 | ------- 433 | points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, ) 434 | The points defining the regular grid in n dimensions with 435 | periodic boundary conditions. 436 | values : array_like, shape (m1, ..., mn, ...) 437 | The data on the regular grid in n dimensions with periodic 438 | boundary conditions. 
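Examples -------- A minimal sketch in one dimension with period 2*pi (grid and values chosen arbitrarily for illustration): >>> pts = (np.array([0.0, np.pi]),) >>> vals = np.array([1.0, 2.0]) >>> new_pts, new_vals = augment_with_periodic_bc(pts, vals, domain=2 * np.pi) >>> new_pts[0].shape, new_vals.shape ((6,), (6,))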
439 | """ 440 | # https://stackoverflow.com/questions/25087111/2d-interpolation-with-periodic-boundary-conditions 441 | # Validate the domain argument 442 | n = len(points) 443 | if np.ndim(domain) == 0: 444 | domain = [domain] * n 445 | if np.shape(domain) != (n,): 446 | raise ValueError("`domain` must be a scalar or have the same " "length as `points`") 447 | 448 | # Pre- and append repeated points 449 | points = [ 450 | x if d is None else np.concatenate([x - d, x, x + d]) for x, d in zip(points, domain) 451 | ] 452 | 453 | # Tile the values as necessary 454 | reps = [1 if d is None else 3 for d in domain] 455 | values = np.tile(values, reps) 456 | 457 | return points, values 458 | -------------------------------------------------------------------------------- /ringer/utils/reconstruction.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union 3 | 4 | import pandas as pd 5 | from rdkit import Chem 6 | from scipy.optimize import OptimizeResult 7 | from tqdm.contrib.concurrent import process_map 8 | 9 | from . import chem, internal_coords 10 | 11 | 12 | def reconstruct_ring( 13 | mol: Chem.Mol, 14 | structure: Dict[str, Any], 15 | bond_dist_dict: Dict[str, float], 16 | bond_angle_dict: Optional[Dict[str, float]] = None, 17 | bond_angle_dev_dict: Optional[Dict[str, float]] = None, 18 | angles_as_constraints: bool = False, 19 | opt_init: Literal["best_dists", "average"] = "best_dists", 20 | skip_opt: bool = False, 21 | max_conf: Optional[int] = None, 22 | return_unsuccessful: bool = False, 23 | ncpu: int = 1, 24 | ) -> Union[ 25 | Tuple[Chem.Mol, List[pd.DataFrame]], 26 | Tuple[Chem.Mol, List[pd.DataFrame], Dict[int, OptimizeResult]], 27 | ]: 28 | """Reconstruct Cartesian coordinates and recover consistent sets of redundant internal 29 | coordinates. 30 | 31 | Args: 32 | mol: Molecule containing connectivity and at least one conformer. 33 | structure: Inconsistent internal coordinates and atom labels for several samples/conformers. 34 | bond_dist_dict: Bond distances that match the atom labels in structure, usually from the training data. 35 | bond_angle_dict: If structure doesn't contain bond angles, use these as constraints. 36 | bond_angle_dev_dict: If structure doesn't contain bond angles, use these as maximum deviations for the bond angle constraints. 37 | angles_as_constraints: Use the bond angles as constraints instead of targets (default if structure does not contain bond angles). 38 | opt_init: Initialization method for the optimization. 39 | skip_opt: Just set the coordinates in sequence and don't run the optimization. 40 | max_conf: Reconstruct at most this many conformers. 41 | return_unsuccessful: Return the results objects of unsuccessful optimizations. 42 | ncpu: Number of processes to use. 43 | 44 | Returns: 45 | Reconstructed mol with one conformer for each row in the structure dataframes containing 46 | new Cartesian coordinates of ring atoms and list of reconstructed coordinates for each 47 | conformer. 
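A typical call might look like this (an illustrative sketch; `structure` comes from upstream sampling, and `mean_distances`/`mean_angles` are assumed names for training-set statistics): `mol_opt, coords = reconstruct_ring(mol, structure, bond_dist_dict=mean_distances, bond_angle_dict=mean_angles, ncpu=4)` yields one new conformer per sampled set of internal coordinates.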
48 | """ 49 | angle_df = structure.get("angle") 50 | if angle_df is None: 51 | angles_as_constraints = True 52 | if angles_as_constraints and bond_angle_dict is None: 53 | raise ValueError("Must provide bond angles") 54 | dihedral_df = structure["dihedral"] 55 | 56 | # Set up the optimization class 57 | ring_idxs = dihedral_df.columns.tolist() 58 | ring_internal_coords = internal_coords.RingInternalCoordinates(ring_idxs) 59 | 60 | # When reconstructing mol, use mean bond distances from training data 61 | bond_dists = pd.Series( 62 | data=(bond_dist_dict[label] for label in structure["atom_labels"]), index=ring_idxs 63 | ) 64 | 65 | bond_angle_devs = None 66 | if angles_as_constraints: 67 | bond_angles = pd.Series( 68 | data=(bond_angle_dict[label] for label in structure["atom_labels"]), index=ring_idxs 69 | ) 70 | if bond_angle_dev_dict is not None: 71 | bond_angle_devs = pd.Series( 72 | data=(bond_angle_dev_dict[label] for label in structure["atom_labels"]), 73 | index=ring_idxs, 74 | ) 75 | 76 | # Obtain Cartesian coordinates and consistent angles 77 | # Each row in the dataframes contains the internal coordinates of a sample/conformer 78 | pfunc = partial( 79 | _to_cartesian_helper, 80 | mol=mol, 81 | ring_internal_coords=ring_internal_coords, 82 | distance_vals=bond_dists, 83 | angles_as_constraints=angles_as_constraints, 84 | angle_vals_max_devs=bond_angle_devs, 85 | opt_init=opt_init, 86 | skip_opt=skip_opt, 87 | print_warning=False, 88 | return_result_obj=True, 89 | ) 90 | inputs_list = [] 91 | for conf_idx, dihedrals in dihedral_df.iterrows(): 92 | if not angles_as_constraints: 93 | bond_angles = angle_df.loc[conf_idx] 94 | inputs_list.append((bond_angles, dihedrals)) 95 | if max_conf is not None: 96 | inputs_list = inputs_list[:max_conf] 97 | 98 | chunksize, extra = divmod(len(inputs_list), ncpu * 4) 99 | if extra: 100 | chunksize += 1 101 | results = process_map(pfunc, inputs_list, max_workers=ncpu, chunksize=chunksize) 102 | 103 | coords_opt = [result[0] for result in results] 104 | unsuccessful_results = { 105 | conf_idx: result[1] 106 | for conf_idx, result in zip(dihedral_df.index, results) 107 | if not result[1].success 108 | } 109 | 110 | # Make a new mol where the ring atoms contain the new coordinates 111 | mol_opt = chem.set_atom_positions(mol, coords_opt) 112 | 113 | if return_unsuccessful: 114 | return mol_opt, coords_opt, unsuccessful_results 115 | return mol_opt, coords_opt 116 | 117 | 118 | def _to_cartesian_helper( 119 | inputs: Tuple[pd.Series, pd.Series], 120 | ring_internal_coords: internal_coords.RingInternalCoordinates, 121 | **kwargs: Any, 122 | ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, OptimizeResult]]: 123 | """Helper function for to_cartesian.""" 124 | bond_angles, dihedrals = inputs 125 | return ring_internal_coords.to_cartesian( 126 | angle_vals_target=bond_angles, 127 | dihedral_vals_target=dihedrals, 128 | **kwargs, 129 | ) 130 | -------------------------------------------------------------------------------- /ringer/utils/sampling.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Optional, Sequence, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch import nn 8 | from tqdm.auto import tqdm 9 | 10 | from ..data import noised 11 | from . 
import internal_coords, utils, variance_schedules 12 | 13 | 14 | @torch.no_grad() 15 | def p_sample( 16 | model: nn.Module, 17 | x: torch.Tensor, 18 | t: torch.Tensor, 19 | seq_lengths: Sequence[int], 20 | t_index: Union[int, torch.Tensor], 21 | betas: torch.Tensor, 22 | atom_ids: Optional[torch.Tensor] = None, 23 | atom_features: Optional[torch.Tensor] = None, 24 | ) -> torch.Tensor: 25 | """Sample the given timestep. 26 | 27 | Note that this _may_ fall off the manifold if we just feed the output back into itself 28 | repeatedly, so we need to perform modulo on it (see p_sample_loop) 29 | """ 30 | # Calculate alphas and betas 31 | alpha_beta_values = variance_schedules.compute_alphas(betas) 32 | sqrt_recip_alphas = 1.0 / torch.sqrt(alpha_beta_values["alphas"]) 33 | 34 | # Select based on time 35 | t_unique = torch.unique(t) 36 | assert len(t_unique) == 1, f"Got multiple values for t: {t_unique}" 37 | t_index = t_unique.item() 38 | sqrt_recip_alphas_t = sqrt_recip_alphas[t_index] 39 | betas_t = betas[t_index] 40 | sqrt_one_minus_alphas_cumprod_t = alpha_beta_values["sqrt_one_minus_alphas_cumprod"][t_index] 41 | 42 | # Create the attention mask 43 | attn_mask = torch.zeros(x.shape[:2], device=x.device) 44 | for i, seq_length in enumerate(seq_lengths): 45 | attn_mask[i, :seq_length] = 1.0 46 | 47 | # Use model to predict the mean 48 | model_mean = sqrt_recip_alphas_t * ( 49 | x 50 | - betas_t 51 | * model(x, t, attention_mask=attn_mask, atom_ids=atom_ids, atom_features=atom_features) 52 | / sqrt_one_minus_alphas_cumprod_t 53 | ) 54 | 55 | if t_index == 0: 56 | return model_mean 57 | else: 58 | posterior_variance_t = alpha_beta_values["posterior_variance"][t_index] 59 | noise = torch.randn_like(x) 60 | return model_mean + torch.sqrt(posterior_variance_t) * noise 61 | 62 | 63 | @torch.no_grad() 64 | def p_sample_loop( 65 | model: nn.Module, 66 | seq_lengths: Sequence[int], 67 | noise: torch.Tensor, 68 | timesteps: int, 69 | betas: torch.Tensor, 70 | atom_ids: Optional[torch.Tensor] = None, 71 | atom_features: Optional[torch.Tensor] = None, 72 | is_angle: Union[bool, Sequence[bool]] = (False, True, True), 73 | disable_pbar: bool = False, 74 | ) -> torch.Tensor: 75 | """Returns a tensor of shape [timesteps x batch_size x seq_len x num_feat]""" 76 | device = next(model.parameters()).device 77 | b = noise.shape[0] 78 | x = noise.to(device) 79 | # Report metrics on starting noise 80 | # amin and amax support reducing on multiple dimensions 81 | logging.info( 82 | f"Starting from noise {noise.shape} with angularity {is_angle} and range {torch.amin(x, dim=(0, 1))} - {torch.amax(x, dim=(0, 1))} using {device}" 83 | ) 84 | 85 | outputs = [] 86 | for i in tqdm( 87 | reversed(range(0, timesteps)), desc="Time step", total=timesteps, disable=disable_pbar 88 | ): 89 | # Shape is (batch, seq_len, num_output) 90 | x = p_sample( 91 | model=model, 92 | x=x, 93 | t=torch.full((b,), i, device=device, dtype=torch.long), # Time vector 94 | seq_lengths=seq_lengths, 95 | t_index=i, 96 | betas=betas, 97 | atom_ids=None if atom_ids is None else atom_ids.to(device), 98 | atom_features=None if atom_features is None else atom_features.to(device), 99 | ) 100 | 101 | # Wrap if angular 102 | if isinstance(is_angle, bool): 103 | if is_angle: 104 | x = utils.modulo_with_wrapped_range(x, range_min=-torch.pi, range_max=torch.pi) 105 | else: 106 | assert len(is_angle) == x.shape[-1] 107 | x[..., is_angle] = utils.modulo_with_wrapped_range( 108 | x[..., is_angle], range_min=-torch.pi, range_max=torch.pi 109 | ) 110 | 111 | 
outputs.append(x.cpu()) 112 | 113 | return torch.stack(outputs) 114 | 115 | 116 | def sample_batch( 117 | model: nn.Module, 118 | dset: noised.NoisedDataset, 119 | seq_lengths: Sequence[int], 120 | atom_ids: Optional[torch.Tensor] = None, 121 | atom_features: Optional[torch.Tensor] = None, 122 | uniform: bool = False, 123 | final_timepoint_only: bool = True, 124 | disable_pbar: bool = False, 125 | ) -> List[np.ndarray]: 126 | noise = dset.sample_noise( 127 | torch.zeros((len(seq_lengths), dset.pad, model.n_inputs), dtype=torch.float32), 128 | uniform=uniform, 129 | ) 130 | 131 | samples = p_sample_loop( 132 | model=model, 133 | seq_lengths=seq_lengths, 134 | noise=noise, 135 | timesteps=dset.timesteps, 136 | betas=dset.alpha_beta_terms["betas"], 137 | atom_ids=atom_ids, 138 | atom_features=atom_features, 139 | is_angle=dset.feature_is_angular, 140 | disable_pbar=disable_pbar, 141 | ) # [timesteps x batch_size x seq_len x num_feat] 142 | 143 | if final_timepoint_only: 144 | samples = samples[-1] 145 | 146 | # Assumes dset.means contains the training data means 147 | means = dset.means 148 | if means is not None: 149 | logging.info(f"Shifting predicted values by original offset: {means}") 150 | samples += means 151 | # Wrap because shifting could have gone beyond boundary 152 | samples[..., dset.feature_is_angular] = utils.modulo_with_wrapped_range( 153 | samples[..., dset.feature_is_angular], range_min=-torch.pi, range_max=torch.pi 154 | ) 155 | 156 | # Trim each element in the batch to its sequence length 157 | trimmed_samples = [ 158 | samples[..., i, :seq_len, :].numpy() for i, seq_len in enumerate(seq_lengths) 159 | ] 160 | 161 | return trimmed_samples 162 | 163 | 164 | def sample_unconditional_from_lengths( 165 | model: nn.Module, 166 | dset: noised.NoisedDataset, 167 | seq_lengths: Sequence[int], 168 | uniform: bool = False, 169 | batch_size: int = 65536, 170 | final_timepoint_only: bool = True, 171 | disable_pbar: bool = False, 172 | ) -> Union[List[np.ndarray], List[pd.DataFrame]]: 173 | """Run reverse diffusion for unconditional macrocycle backbone generation. 174 | 175 | Args: 176 | model: Model. 177 | dset: Only needed for its means, sample_noise, timesteps, alpha_beta_terms, feature_is_angular, and pad attributes. 178 | seq_lengths: Generate one sample for each sequence length provided. 179 | uniform: Sample uniformly instead of from a wrapped normal. 180 | batch_size: Batch size. 181 | final_timepoint_only: Only return the sample at the final (non-noisy) timepoint. 182 | disable_pbar: Don't display a progress bar. 
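
    Example:
        A minimal sketch (``model`` and ``dset`` are placeholders for a trained model and its
        training dataset); each length must be a multiple of 3, i.e., three backbone atoms
        per residue:

            samples = sample_unconditional_from_lengths(model, dset, seq_lengths=[12, 15, 18])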
183 |     """
184 |     samples = []
185 |     atom_id_lists = []
186 |     chunks = [(i, i + batch_size) for i in range(0, len(seq_lengths), batch_size)]
187 | 
188 |     logging.info(f"Sampling {len(seq_lengths)} items in batches of size {batch_size}")
189 |     for idx_start, idx_end in chunks:
190 |         seq_lengths_batch = seq_lengths[idx_start:idx_end]
191 | 
192 |         # Need backbone atom labels
193 |         # Technically, we don't need to do this separately for each sequence because the attention
194 |         # mask will take care of the extraneous ones
195 |         atom_ids_batch = torch.zeros(len(seq_lengths_batch), dset.pad, dtype=torch.long)
196 |         for i, seq_length in enumerate(seq_lengths_batch):
197 |             assert seq_length % 3 == 0
198 |             atom_id_list = internal_coords.BACKBONE_ATOM_IDS * (seq_length // 3)
199 |             atom_ids_batch[i, :seq_length] = torch.tensor(atom_id_list, dtype=torch.long)
200 |             atom_id_lists.append(atom_id_list)
201 | 
202 |         samples_batch = sample_batch(
203 |             model=model,
204 |             dset=dset,
205 |             seq_lengths=seq_lengths_batch,
206 |             atom_ids=atom_ids_batch,
207 |             uniform=uniform,
208 |             final_timepoint_only=final_timepoint_only,
209 |             disable_pbar=disable_pbar,
210 |         )
211 |         samples.extend(samples_batch)
212 | 
213 |     # Label predictions
214 |     if final_timepoint_only:
215 |         samples = [pd.DataFrame(data=sample, columns=dset.feature_names) for sample in samples]
216 |         for sample, atom_ids in zip(samples, atom_id_lists):
217 |             sample["atom_label"] = [
218 |                 internal_coords.BACKBONE_ATOM_ID_TO_LABEL[atom_id] for atom_id in atom_ids
219 |             ]
220 | 
221 |     return samples
222 | 
223 | 
224 | def sample_unconditional(
225 |     model: nn.Module,
226 |     dset: noised.NoisedDataset,
227 |     num_samples: int = 1,
228 |     uniform: bool = False,
229 |     batch_size: int = 65536,
230 |     final_timepoint_only: bool = True,
231 |     disable_pbar: bool = False,
232 | ) -> Union[List[np.ndarray], List[pd.DataFrame]]:
233 |     """Sample num_samples samples by first sampling num_samples lengths using dset and then passing
234 |     these to sample_unconditional_from_lengths."""
235 |     seq_lengths = dset.sample_length(n=num_samples)
236 |     if isinstance(seq_lengths, int):
237 |         seq_lengths = [seq_lengths]
238 |     return sample_unconditional_from_lengths(
239 |         model,
240 |         dset,
241 |         seq_lengths,
242 |         uniform=uniform,
243 |         batch_size=batch_size,
244 |         final_timepoint_only=final_timepoint_only,
245 |         disable_pbar=disable_pbar,
246 |     )
247 | 
248 | 
249 | def sample_conditional(
250 |     model: nn.Module,
251 |     dset: noised.NoisedDataset,
252 |     samples_multiplier: int = 2,
253 |     samples_per_mol: Optional[int] = None,
254 |     uniform: bool = False,
255 |     batch_size: int = 65536,
256 |     final_timepoint_only: bool = True,
257 |     disable_pbar: bool = False,
258 | ) -> Union[
259 |     Dict[str, Dict[str, pd.DataFrame]], Dict[str, Dict[str, Union[List[int], List[np.ndarray]]]]
260 | ]:
261 |     """Run reverse diffusion on a set of macrocycles conditioned on atom sequence.
262 | 
263 |     Args:
264 |         model: Model.
265 |         dset: Dataset to generate samples for (must contain atom features).
266 |         samples_multiplier: For each molecule in the dataset, generate samples_multiplier * num_conformers samples.
267 |         samples_per_mol: Override samples_multiplier to generate exactly this many samples per molecule.
268 |         uniform: Sample uniformly instead of from a wrapped normal.
269 |         batch_size: Batch size.
270 |         final_timepoint_only: Only return the sample at the final (non-noisy) timepoint.
271 |         disable_pbar: Don't display a progress bar.
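
    Example:
        A minimal sketch (``model``, ``dset``, and ``fname`` are placeholders; with the
        default ``final_timepoint_only=True``, each feature is returned as a DataFrame with
        one row per generated sample):

            samples_dict = sample_conditional(model, dset, samples_per_mol=8)
            dihedrals = samples_dict[fname]["dihedral"]  # shape: (8, ring_size)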
272 | """ 273 | if dset.atom_features is None: 274 | raise ValueError("Dataset must have atom features") 275 | 276 | num_samples_per_mol = [] # Per mol 277 | all_atom_idxs = [] # Per mol 278 | all_fnames = [] # Per mol 279 | seq_lengths = [] # Per conformer 280 | atom_features = [] # Per conformer 281 | 282 | for fname, structure in dset.structures.items(): 283 | if samples_per_mol is None: 284 | num_conf = len(structure[dset.feature_names[0]]) 285 | num_to_sample = samples_multiplier * num_conf 286 | else: 287 | num_to_sample = samples_per_mol 288 | 289 | atom_features_padded, atom_idxs = dset.get_atom_features(fname, pad=True, return_idxs=True) 290 | atom_features_padded_repeated = atom_features_padded.expand(num_to_sample, -1, -1) 291 | seq_length = len(atom_idxs) 292 | 293 | num_samples_per_mol.append(num_to_sample) 294 | all_atom_idxs.append(atom_idxs) 295 | all_fnames.append(fname) 296 | seq_lengths.extend(num_to_sample * [seq_length]) 297 | atom_features.append(atom_features_padded_repeated) 298 | 299 | atom_features = torch.cat(atom_features) 300 | 301 | samples = [] 302 | chunks = [(i, i + batch_size) for i in range(0, len(seq_lengths), batch_size)] 303 | 304 | logging.info(f"Sampling {len(seq_lengths)} items in batches of size {batch_size}") 305 | for idx_start, idx_end in chunks: 306 | samples_batch = sample_batch( 307 | model=model, 308 | dset=dset, 309 | seq_lengths=seq_lengths[idx_start:idx_end], 310 | atom_features=atom_features[idx_start:idx_end], 311 | uniform=uniform, 312 | final_timepoint_only=final_timepoint_only, 313 | disable_pbar=disable_pbar, 314 | ) 315 | samples.extend(samples_batch) 316 | 317 | # samples is a flat list, need to map it back to mols 318 | mol_chunks = [0] + np.cumsum(num_samples_per_mol).tolist() 319 | samples_dict = {} 320 | 321 | # Aggregate samples for each molecule 322 | for mol_idx, (idx_start, idx_end) in enumerate(zip(mol_chunks, mol_chunks[1:])): 323 | # samples_mol is num_samples * [timesteps x seq_len x num_feat] 324 | # or num_samples * [seq_len x num_feat] 325 | samples_mol = samples[idx_start:idx_end] 326 | samples_mol = np.stack(samples_mol) # [num_samples x ...] 
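        # np.stack is safe here because every sample generated for a given mol shares the
        # same (trimmed) sequence length, so all arrays in samples_mol have equal shapes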
327 | 
328 |         fname = all_fnames[mol_idx]
329 |         structure = dset.structures[fname]
330 |         atom_idxs = all_atom_idxs[mol_idx]
331 | 
332 |         if final_timepoint_only:  # Return as dataframes
333 |             samples_mol_dict = {"atom_labels": structure["atom_labels"]}
334 |             for feat_idx, feature_name in enumerate(dset.feature_names):
335 |                 df = pd.DataFrame(data=samples_mol[..., feat_idx], columns=atom_idxs)
336 |                 df.index.name = "sample_idx"
337 |                 feat_missing = structure[feature_name].iloc[0].isna()
338 |                 feat_missing_cols = feat_missing[feat_missing].index
339 |                 df[feat_missing_cols] = np.nan
340 |                 samples_mol_dict[feature_name] = df
341 |             samples_dict[fname] = samples_mol_dict
342 |         else:  # Return as arrays
343 |             samples_mol_dict = {"atom_idxs": atom_idxs, "atom_labels": structure["atom_labels"]}
344 |             for feat_idx, feature_name in enumerate(dset.feature_names):
345 |                 # [num_samples x timesteps x seq_len]
346 |                 samples_mol_dict[feature_name] = samples_mol[..., feat_idx]
347 |             samples_dict[fname] = samples_mol_dict
348 | 
349 |     return samples_dict
350 | 
--------------------------------------------------------------------------------
/ringer/utils/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import hashlib
3 | import logging
4 | from pathlib import Path
5 | from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
6 | 
7 | import numpy as np
8 | import pandas as pd
9 | from scipy import stats
10 | from typer.models import ParameterInfo
11 | 
12 | 
13 | def get_wrapped_overlapping_sublists(list_: List[Any], size: int) -> Iterator[List[Any]]:
14 |     for idx in range(len(list_)):
15 |         idxs = [idx]
16 |         for offset in range(1, size):
17 |             # Wrap past end of list
18 |             idxs.append((idx + offset) % len(list_))
19 |         yield [list_[i] for i in idxs]
20 | 
21 | 
22 | def get_overlapping_sublists(
23 |     list_: List[Any], size: int, wrap: bool = True
24 | ) -> Iterator[List[Any]]:
25 |     if wrap:
26 |         for item in get_wrapped_overlapping_sublists(list_, size):
27 |             yield item
28 |     else:
29 |         for i in range(len(list_) - size + 1):
30 |             yield list_[i : i + size]
31 | 
32 | 
33 | def compute_kl_divergence(p: np.ndarray, q: np.ndarray, nbins: int = 100) -> float:
34 |     min_val = min(np.min(p), np.min(q))
35 |     max_val = max(np.max(p), np.max(q))
36 |     bins = np.linspace(min_val, max_val, nbins + 1)
37 |     p_hist, _ = np.histogram(p, bins=bins)
38 |     q_hist, _ = np.histogram(q, bins=bins)
39 |     # Handle zero-counts
40 |     p_hist[p_hist == 0] = 1
41 |     q_hist[q_hist == 0] = 1
42 |     return stats.entropy(p_hist, q_hist)
43 | 
44 | 
45 | def compute_kl_divergence_from_dataframe(
46 |     df: pd.DataFrame,
47 |     *data_cols: str,
48 |     key_col: str = "src",
49 |     pkey: str = "Test",
50 |     qkey: str = "Sampled",
51 |     nbins: int = 100,
52 | ) -> pd.Series:
53 |     dfp = df[df[key_col] == pkey]
54 |     dfq = df[df[key_col] == qkey]
55 |     return pd.Series(
56 |         {col: compute_kl_divergence(dfp[col], dfq[col], nbins=nbins) for col in data_cols}
57 |     )
58 | 
59 | 
60 | def tolerant_comparison_check(values, cmp: Literal[">=", "<="], v):
61 |     """Compares values in a way that is tolerant of numerical precision.
62 | 
63 |     >>> tolerant_comparison_check(-3.1415927410125732, ">=", -np.pi)
64 |     True
65 |     """
66 |     if cmp == ">=":  # v is a lower bound
67 |         minval = np.nanmin(values)
68 |         diff = minval - v
69 |         if np.isclose(diff, 0, atol=1e-5):
70 |             return True  # Passes
71 |         return diff > 0
72 |     elif cmp == "<=":
73 |         maxval = np.nanmax(values)
74 |         diff = maxval - v
75 |         if np.isclose(diff, 0, atol=1e-5):
76 |             return True
77 |         return diff < 0
78 |     else:
79 |         raise ValueError(f"Illegal comparator: {cmp}")
80 | 
81 | 
82 | def modulo_with_wrapped_range(vals, range_min: float = -np.pi, range_max: float = np.pi):
83 |     """Modulo with wrapped range -- capable of handling a range with a negative min.
84 | 
85 |     >>> modulo_with_wrapped_range(3, -2, 2)
86 |     -1
87 |     """
88 |     assert range_min <= 0.0
89 |     assert range_min < range_max
90 | 
91 |     # Modulo after we shift values
92 |     top_end = range_max - range_min
93 |     # Shift the values to be in the range [0, top_end)
94 |     vals_shifted = vals - range_min
95 |     # Perform modulo
96 |     vals_shifted_mod = vals_shifted % top_end
97 |     # Shift back down
98 |     retval = vals_shifted_mod + range_min
99 | 
100 |     return retval
101 | 
102 | 
103 | def wrapped_mean(x: np.ndarray, axis: Optional[int] = None) -> Union[float, np.ndarray]:
104 |     """Wrap the mean function about [-pi, pi]"""
105 |     # https://rosettacode.org/wiki/Averages/Mean_angle
106 |     sin_x = np.sin(x)
107 |     cos_x = np.cos(x)
108 | 
109 |     retval = np.arctan2(np.nanmean(sin_x, axis=axis), np.nanmean(cos_x, axis=axis))
110 |     return retval
111 | 
112 | 
113 | def update_dict_nonnull(d: Dict[str, Any], vals: Dict[str, Any]) -> Dict[str, Any]:
114 |     """Update a dictionary with values from another dictionary.
115 | 
116 |     >>> update_dict_nonnull({'a': 1, 'b': 2}, {'b': 3, 'c': 4})
117 |     {'a': 1, 'b': 3, 'c': 4}
118 |     """
119 |     for k, v in vals.items():
120 |         if k in d:
121 |             if d[k] != v and v is not None:
122 |                 logging.info(f"Replacing key {k} original value {d[k]} with {v}")
123 |                 d[k] = v
124 |         else:
125 |             d[k] = v
126 |     return d
127 | 
128 | 
129 | def md5_all_py_files(dir_name: Union[str, Path]) -> str:
130 |     """Create a single md5 sum for all given files."""
131 |     # https://stackoverflow.com/questions/36099331/how-to-grab-all-files-in-a-folder-and-get-their-md5-hash-in-python
132 |     dir_name = Path(dir_name)
133 |     fnames = dir_name.glob("*.py")
134 |     hash_md5 = hashlib.md5()
135 |     for fname in sorted(fnames):
136 |         with open(fname, "rb") as f:
137 |             for chunk in iter(lambda: f.read(2**20), b""):
138 |                 hash_md5.update(chunk)
139 |     return hash_md5.hexdigest()
140 | 
141 | 
142 | def unwrap_typer_args(func: Callable):
143 |     # https://github.com/tiangolo/typer/issues/279#issuecomment-841875218
144 |     @functools.wraps(func)
145 |     def wrapper(*args, **kwargs):
146 |         default_values = func.__defaults__
147 |         patched_defaults = tuple(
148 |             value.default if isinstance(value, ParameterInfo) else value
149 |             for value in default_values
150 |         )
151 |         func.__defaults__ = patched_defaults
152 | 
153 |         return func(*args, **kwargs)
154 | 
155 |     return wrapper
156 | 
--------------------------------------------------------------------------------
/ringer/utils/variance_schedules.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, Literal
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | 
7 | SCHEDULES = Literal["linear", "cosine", "quadratic"]
8 | 
9 | 
10 | def cosine_beta_schedule(timesteps: int, s: float = 8e-3) -> torch.Tensor:
11 |     """Cosine scheduling https://arxiv.org/pdf/2102.09672.pdf."""
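    # alphas_cumprod follows a squared-cosine decay in t; the betas are recovered from the
    # successive ratios alphas_cumprod[t] / alphas_cumprod[t - 1] and clipped for numerical
    # stability, following Nichol & Dhariwal (2021)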
12 |     steps = timesteps + 1
13 |     x = torch.linspace(0, timesteps, steps)
14 |     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
15 |     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
16 |     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
17 |     return torch.clip(betas, 0.0001, 0.9999)
18 | 
19 | 
20 | def linear_beta_schedule(timesteps: int, beta_start=1e-4, beta_end=0.02) -> torch.Tensor:
21 |     return torch.linspace(beta_start, beta_end, timesteps)
22 | 
23 | 
24 | def quadratic_beta_schedule(timesteps: int, beta_start=1e-4, beta_end=0.02) -> torch.Tensor:
25 |     betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps)
26 |     return betas**2
27 | 
28 | 
29 | def compute_alphas(betas: torch.Tensor) -> Dict[str, torch.Tensor]:
30 |     """Compute the alphas from the betas."""
31 |     alphas = 1.0 - betas
32 |     alphas_cumprod = torch.cumprod(alphas, dim=0)
33 |     alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
34 |     posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
35 |     sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
36 |     sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
37 |     return {
38 |         "betas": betas,
39 |         "alphas": alphas,
40 |         "alphas_cumprod": alphas_cumprod,
41 |         "sqrt_alphas_cumprod": sqrt_alphas_cumprod,
42 |         "sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod,
43 |         "posterior_variance": posterior_variance,
44 |     }
45 | 
46 | 
47 | def get_variance_schedule(keyword: SCHEDULES, timesteps: int, **kwargs) -> torch.Tensor:
48 |     """Easy interface for getting a variance schedule based on keyword and number of timesteps."""
49 |     logging.info(f"Getting {keyword} variance schedule with {timesteps} timesteps")
50 |     if keyword == "cosine":
51 |         return cosine_beta_schedule(timesteps, **kwargs)
52 |     elif keyword == "linear":
53 |         return linear_beta_schedule(timesteps, **kwargs)
54 |     elif keyword == "quadratic":
55 |         return quadratic_beta_schedule(timesteps, **kwargs)
56 |     else:
57 |         raise ValueError(f"Unrecognized variance schedule: {keyword}")
58 | 
--------------------------------------------------------------------------------
/scripts/aggregate_metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import logging
4 | import pickle
5 | from pathlib import Path
6 | from typing import Any, Union
7 | 
8 | import typer
9 | 
10 | from ringer.utils import evaluation
11 | 
12 | 
13 | def load_pickle(path: Union[str, Path]) -> Any:
14 |     with open(path, "rb") as f:
15 |         return pickle.load(f)
16 | 
17 | 
18 | def save_pickle(path: Union[str, Path], data: Any) -> None:
19 |     with open(path, "wb") as f:
20 |         pickle.dump(data, f)
21 | 
22 | 
23 | def aggregate_metrics(
24 |     mol_metrics_dir: str = "metrics",
25 |     out_dir: str = ".",
26 | ) -> None:
27 |     metrics_dir = Path(mol_metrics_dir)
28 |     output_dir = Path(out_dir)
29 |     assert output_dir.exists()
30 | 
31 |     metrics = {path.name: load_pickle(path) for path in metrics_dir.glob("*.pickle")}
32 | 
33 |     # Simplify and aggregate results
34 |     metrics = evaluation.CovMatEvaluator.stack_results(metrics)
35 |     metrics_aggregated = evaluation.CovMatEvaluator.aggregate_results(metrics)
36 | 
37 |     metrics_path = output_dir / "metrics.pickle"
38 |     metrics_aggregated_path = output_dir / "metrics_aggregated.pickle"
39 |     save_pickle(metrics_path, metrics)
40 |     save_pickle(metrics_aggregated_path, metrics_aggregated)
41 | 
42 | 
43 | def main() ->
None: 44 | logging.basicConfig(level=logging.INFO) 45 | typer.run(aggregate_metrics) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /scripts/compute_metrics_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | import pickle 5 | from pathlib import Path 6 | from typing import Any, Optional, Union 7 | 8 | import typer 9 | from rdkit import Chem 10 | 11 | from ringer.utils import evaluation 12 | 13 | 14 | def load_mol(path: Union[str, Path]) -> Chem.Mol: 15 | with open(path, "rb") as f: 16 | mol = pickle.load(f) 17 | if isinstance(mol, dict): 18 | mol = mol["rd_mol"] 19 | return mol 20 | 21 | 22 | def save_pickle(path: Union[str, Path], data: Any) -> None: 23 | with open(path, "wb") as f: 24 | pickle.dump(data, f) 25 | 26 | 27 | def remove_confs(mol: Chem.Mol, confs_to_keep: int = 1) -> Chem.Mol: 28 | new_mol = Chem.Mol(mol) 29 | conf_ids = [conf.GetId() for conf in new_mol.GetConformers()] 30 | conf_ids_to_remove = conf_ids[confs_to_keep:] 31 | for conf_id in conf_ids_to_remove: 32 | new_mol.RemoveConformer(conf_id) 33 | return new_mol 34 | 35 | 36 | def compute_metrics( 37 | mol_true_path: str, 38 | mol_reconstructed_path: str, 39 | max_true_confs: Optional[int] = None, 40 | out_dir: str = "metrics", 41 | include_all_atom: bool = True, 42 | ncpu: int = 1, 43 | ) -> None: 44 | output_dir = Path(out_dir) 45 | output_dir.mkdir(parents=True, exist_ok=True) 46 | 47 | mol_path = Path(mol_true_path) 48 | mol_opt_path = Path(mol_reconstructed_path) 49 | assert mol_path.name == mol_opt_path.name 50 | 51 | if not mol_opt_path.exists(): 52 | raise IOError(f"'{mol_opt_path}' is missing") 53 | 54 | # Load data 55 | mol = load_mol(mol_path) 56 | mol_opt = load_mol(mol_opt_path) 57 | 58 | if max_true_confs is not None: 59 | mol = remove_confs(mol, max_true_confs) 60 | 61 | # Evaluate 62 | metric_names = ["ring-rmsd", "ring-tfd"] 63 | if include_all_atom: 64 | metric_names.append("rmsd") 65 | cov_mat_evaluator = evaluation.CovMatEvaluator(metric_names) 66 | metrics = cov_mat_evaluator(mol_opt, mol, ncpu=ncpu) 67 | 68 | metrics_path = output_dir / mol_path.name 69 | save_pickle(metrics_path, metrics) 70 | 71 | 72 | def main() -> None: 73 | logging.basicConfig(level=logging.INFO) 74 | typer.run(compute_metrics) 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /scripts/reconstruct_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | import logging 5 | import pickle 6 | import time 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional, Union 9 | 10 | import pandas as pd 11 | import typer 12 | from rdkit import Chem 13 | 14 | import ringer 15 | from ringer.sidechain_reconstruction import ( 16 | Macrocycle, 17 | Reconstructor, 18 | set_rdkit_geometries, 19 | ) 20 | from ringer.utils import reconstruction 21 | 22 | RECONSTRUCTION_DATA_PATH = ( 23 | Path(ringer.__file__).resolve().parent 24 | / "sidechain_reconstruction/data/reconstruction_data.pickle" 25 | ) 26 | 27 | 28 | def load_json(path: Union[str, Path]) -> Dict[str, Any]: 29 | with open(path) as f: 30 | return json.load(f) 31 | 32 | 33 | def load_pickle(path: Union[str, Path]) -> Any: 34 | with open(path, "rb") as f: 35 | return pickle.load(f) 36 | 37 | 38 | def 
save_pickle(path: Union[str, Path], data: Any) -> None: 39 | with open(path, "wb") as f: 40 | pickle.dump(data, f) 41 | 42 | 43 | def merge_samples( 44 | sample_backbone: Dict[str, Any], sample_sidechain: Dict[str, Any] 45 | ) -> Dict[str, Any]: 46 | """Merge samples from two sources corresponding to reconstructed backbones and the original 47 | sidechain predictions that do not get modified during reconstruction.""" 48 | merged_dict = {} 49 | merged_dict["angle"] = sample_backbone["angle"] 50 | merged_dict["dihedral"] = sample_backbone["dihedral"] 51 | 52 | sidechain_internals = [ 53 | "sc_a0", 54 | "sc_a1", 55 | "sc_a2", 56 | "sc_a3", 57 | "sc_a4", 58 | "sc_chi0", 59 | "sc_chi1", 60 | "sc_chi2", 61 | "sc_chi3", 62 | "sc_chi4", 63 | ] 64 | 65 | for sc_key in sidechain_internals: 66 | merged_dict[sc_key] = sample_sidechain[sc_key] 67 | return merged_dict 68 | 69 | 70 | def reconstruct( 71 | idx: int, 72 | mol_dir: str, 73 | structures_path: str, 74 | out_dir: str, 75 | mean_distances_path: str, 76 | reconstruct_sidechains: bool = True, 77 | mean_angles_path: Optional[str] = None, 78 | std_angles_path: Optional[str] = None, 79 | angles_as_constraints: bool = False, 80 | opt_init: str = "average", 81 | skip_opt: bool = False, 82 | max_conf: Optional[int] = None, 83 | ncpu: int = 1, 84 | ) -> None: 85 | mol_opt_dir = Path(out_dir) 86 | mol_opt_dir.mkdir(parents=True, exist_ok=True) 87 | 88 | # Load data 89 | structures_dict = load_pickle(structures_path) 90 | mean_bond_distances = load_json(mean_distances_path) 91 | mean_bond_angles = None if mean_angles_path is None else load_json(mean_angles_path) 92 | std_bond_angles = None if std_angles_path is None else load_json(std_angles_path) 93 | 94 | # Load mol 95 | fname = list(structures_dict.keys())[idx] 96 | mol_name = Path(fname).name 97 | mol_path = Path(mol_dir) / mol_name 98 | mol = load_pickle(mol_path) 99 | if isinstance(mol, dict): 100 | mol = mol["rd_mol"] 101 | structure = structures_dict[fname] 102 | 103 | # Reconstruct 104 | logging.info(f"Reconstructing {mol_name}") 105 | if skip_opt: 106 | logging.info("Skipping opt") 107 | 108 | start_time = time.time() 109 | result = reconstruction.reconstruct_ring( 110 | mol=mol, 111 | structure=structure, 112 | bond_dist_dict=mean_bond_distances, 113 | bond_angle_dict=mean_bond_angles, 114 | bond_angle_dev_dict=std_bond_angles, 115 | angles_as_constraints=angles_as_constraints, 116 | opt_init=opt_init, 117 | skip_opt=skip_opt, 118 | max_conf=max_conf, 119 | return_unsuccessful=False, # Don't save unsuccessful optimizations for now 120 | ncpu=ncpu, 121 | ) 122 | end_time = time.time() 123 | logging.info(f"Reconstruction took {end_time - start_time:.2f} seconds") 124 | 125 | mol_opt = result[0] 126 | 127 | # Post-process and dump data 128 | def get_structure_from_coords(coords: List[pd.DataFrame]) -> Dict[str, Any]: 129 | # Convert list of coords to structure 130 | # Concatenate making hierarchical index of sample_idx and atom_idx 131 | coords_stacked = pd.concat( 132 | coords, keys=range(len(coords)), names=["sample_idx", "atom_idx"] 133 | ) 134 | # Pivot so we can get all samples for each feature from the outermost column 135 | coords_pivoted = coords_stacked.unstack(level="atom_idx") 136 | new_structure = { 137 | feat_name: coords_pivoted[feat_name] for feat_name in coords_pivoted.columns.levels[0] 138 | } 139 | new_structure["atom_labels"] = structure["atom_labels"] 140 | return new_structure 141 | 142 | if reconstruct_sidechains: 143 | logging.info("Reconstructing sidechains") 144 | 145 | 
with open(RECONSTRUCTION_DATA_PATH, "rb") as f: 146 | reconstruction_config = pickle.load(f) 147 | 148 | coords_opt = result[1] 149 | structure_opt = get_structure_from_coords(coords_opt) 150 | 151 | # Merge with original sidechain predictions 152 | sample = merge_samples(structure_opt, structure) 153 | 154 | mol_opt_no_h = Chem.RemoveHs(mol_opt) 155 | mc = Macrocycle( 156 | mol_opt_no_h, 157 | reconstruction_config, 158 | coords=False, 159 | copy=True, 160 | verify=True, 161 | ) 162 | reconstructor = Reconstructor(mc) 163 | 164 | internals_tensor = reconstructor.parse_internals(sample) 165 | index_tensor = reconstructor.stacked_tuples 166 | 167 | positions = reconstructor.reconstruct(internals_tensor, index_tensor) 168 | mol_opt = set_rdkit_geometries(mol_opt_no_h, positions, copy=True) 169 | 170 | mol_opt_path = mol_opt_dir / mol_name 171 | save_pickle(mol_opt_path, mol_opt) 172 | 173 | 174 | def main() -> None: 175 | logging.basicConfig(level=logging.INFO) 176 | typer.run(reconstruct) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | with open("README.md") as source: 6 | long_description = source.read() 7 | 8 | setup( 9 | name="ringer", 10 | version="1.1.0", 11 | description="Rapid conformer generation for macrocycles with internal coordinate diffusion", 12 | author="Colin Grambow, Hayley Weir, Kangway Chuang", 13 | author_email="grambow.colin@gene.com", 14 | url="https://github.com/Genentech/ringer", 15 | install_requires=[], 16 | packages=find_packages(), 17 | entry_points={ 18 | "console_scripts": [ 19 | "train = ringer.train:main", 20 | "evaluate = ringer.eval:main", 21 | ] 22 | }, 23 | ) 24 | --------------------------------------------------------------------------------
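
A short, runnable sketch (not part of the repository) that exercises the pure utility
functions shown above; the imports mirror the module layout, and the printed values are
illustrative only:

import numpy as np

from ringer.utils import utils, variance_schedules

# Build a cosine variance schedule and derive the standard DDPM terms from it
betas = variance_schedules.get_variance_schedule("cosine", timesteps=50)
terms = variance_schedules.compute_alphas(betas)
print(terms["alphas_cumprod"][-1])  # approaches 0: nearly pure noise at the last timestep

# Wrap angles that fall just outside (-pi, pi] back into the principal range
angles = np.array([3.5, -3.5])
print(utils.modulo_with_wrapped_range(angles))  # approximately [-2.783, 2.783]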