├── .dockerignore ├── .env.example ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── code-quality-main.yaml │ ├── code-quality-pr.yaml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── citation.bib ├── configs ├── analysis │ ├── bust_analysis.yaml │ ├── inference_analysis.yaml │ └── molecule_analysis.yaml ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── ema.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── datamodule │ ├── dataloader_cfg │ │ ├── edm_geom_dataloader.yaml │ │ └── edm_qm9_dataloader.yaml │ ├── edm_geom.yaml │ └── edm_qm9.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── experiment │ ├── geom_mol_gen_ddpm.yaml │ ├── geom_mol_gen_ddpm_grid_search.yaml │ ├── qm9_mol_gen_conditional_ddpm.yaml │ ├── qm9_mol_gen_conditional_ddpm_grid_search.yaml │ ├── qm9_mol_gen_ddpm.yaml │ └── qm9_mol_gen_ddpm_grid_search.yaml ├── extras │ └── default.yaml ├── hparams_search │ ├── geom_optuna.yaml │ └── qm9_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── diffusion_cfg │ │ ├── geom_mol_gen_ddpm.yaml │ │ └── qm9_mol_gen_ddpm.yaml │ ├── geom_mol_gen_ddpm.yaml │ ├── layer_cfg │ │ ├── geom_mol_gen_ddpm_gcp_interaction_layer.yaml │ │ ├── mp_cfg │ │ │ ├── geom_mol_gen_ddpm_gcp_mp.yaml │ │ │ └── qm9_mol_gen_ddpm_gcp_mp.yaml │ │ └── qm9_mol_gen_ddpm_gcp_interaction_layer.yaml │ ├── model_cfg │ │ ├── geom_mol_gen_ddpm_gcp_model.yaml │ │ └── qm9_mol_gen_ddpm_gcp_model.yaml │ ├── module_cfg │ │ ├── geom_mol_gen_ddpm_gcp_module.yaml │ │ └── qm9_mol_gen_ddpm_gcp_module.yaml │ └── qm9_mol_gen_ddpm.yaml ├── mol_gen_eval.yaml ├── mol_gen_eval_conditional_qm9.yaml ├── mol_gen_eval_optimization_qm9.yaml ├── mol_gen_sample.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── environment.yaml ├── img ├── Bio-Diffusion.png ├── GCDM.png ├── GCDM_Alpha_Conditional_Sampling.gif └── GCDM_Sampled_Molecule_Trajectory.gif ├── notebooks └── .gitkeep ├── pyproject.toml ├── scripts ├── generate_geom_mol_gen_ddpm_grid_search_runs.py ├── generate_qm9_mol_gen_ddpm_grid_search_runs.py ├── geom_mol_gen_ddpm_grid_search_scripts │ └── launch_all_geom_mol_gen_ddpm_grid_search_jobs.bash ├── local │ └── .gitkeep ├── nautilus │ ├── data_transfer_pod_pvc_template.yaml │ ├── generate_data_transfer_pod_pvc_yaml.py │ ├── generate_geom_mol_gen_ddpm_grid_search_jobs.py │ ├── generate_gpu_job_yaml.py │ ├── generate_hm_gpu_job_yaml.py │ ├── generate_persistent_storage_yaml.py │ ├── generate_qm9_mol_gen_ddpm_grid_search_jobs.py │ ├── gpu_job_template.yaml │ ├── hm_gpu_job_template.yaml │ └── persistent_storage_template.yaml └── qm9_mol_gen_ddpm_grid_search_scripts │ └── launch_all_qm9_mol_gen_ddpm_grid_search_jobs.bash ├── setup.py ├── src ├── __init__.py ├── analysis │ ├── bust_analysis.py │ ├── inference_analysis.py │ ├── molecule_analysis.py │ ├── optimization_analysis.py │ └── qm_analysis.py ├── datamodules │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── edm │ │ │ ├── __init__.py │ │ │ ├── bond_analysis.py │ │ │ ├── build_geom_dataset.py │ │ │ ├── collate.py │ │ │ ├── constants.py │ │ │ ├── dataset.py │ │ │ ├── datasets_config.py │ │ │ ├── 
download.py │ │ │ ├── md17.py │ │ │ ├── process.py │ │ │ ├── qm9.py │ │ │ ├── rdkit_functions.py │ │ │ └── utils.py │ │ ├── edm_dataset.py │ │ ├── helper.py │ │ ├── protein_graph_dataset.py │ │ └── sampler.py │ └── edm_datamodule.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── egnn.py │ │ ├── gcpnet.py │ │ └── variational_diffusion.py │ ├── geom_mol_gen_ddpm.py │ └── qm9_mol_gen_ddpm.py ├── mol_gen_eval.py ├── mol_gen_eval_conditional_qm9.py ├── mol_gen_eval_optimization_qm9.py ├── mol_gen_sample.py ├── train.py └── utils │ ├── __init__.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── helpers ├── __init__.py ├── package_available.py ├── run_if.py └── run_sh_command.py ├── test_configs.py ├── test_eval.py ├── test_sweeps.py └── test_train.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # ignore the checkpoints directory 2 | checkpoints/ 3 | 4 | # ignore the data directory 5 | data/ 6 | 7 | # ignore the logs directory 8 | logs/ 9 | 10 | # ignore the git directory 11 | .git 12 | 13 | # ignore generated YAML files 14 | scripts/nautilus/data_transfer_pod_pvc.yaml 15 | scripts/nautilus/gpu_job.yaml 16 | scripts/nautilus/hm_gpu_job.yaml 17 | scripts/nautilus/persistent_storage.yaml -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | 18 | ## Did you have fun? 
19 | 20 | Make sure you had fun coding 🙃 21 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-main.yaml: -------------------------------------------------------------------------------- 1 | # Same as `code-quality-pr.yaml` but triggered on commit to main branch 2 | # and runs on all files (instead of only the changed ones) 3 | 4 | name: Code Quality Main 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | 10 | jobs: 11 | code-quality: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | 21 | - name: Run pre-commits 22 | uses: pre-commit/action@v2.0.3 23 | -------------------------------------------------------------------------------- /.github/workflows/code-quality-pr.yaml: -------------------------------------------------------------------------------- 1 | # This workflow finds which files were changed, prints them, 2 | # and runs `pre-commit` on those files. 3 | 4 | # Inspired by the sktime library: 5 | # https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml 6 | 7 | name: Code Quality PR 8 | 9 | on: 10 | pull_request: 11 | branches: [main, "release/*"] 12 | 13 | jobs: 14 | code-quality: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v2 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | 24 | - name: Find modified files 25 | id: file_changes 26 | uses: trilom/file-changes-action@v1.2.4 27 | with: 28 | output: " " 29 | 30 | - name: List modified files 31 | run: echo '${{ steps.file_changes.outputs.files}}' 32 | 33 | - name: Run pre-commits 34 | uses: pre-commit/action@v2.0.3 35 | with: 36 | extra_args: --files ${{ steps.file_changes.outputs.files}} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | jobs: 10 | run_tests: 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: ["ubuntu-latest", "macos-latest"] 17 | python-version: ["3.7", "3.8", "3.9", "3.10"] 18 | 19 | timeout-minutes: 10 20 | 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v3 24 | 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install dependencies 31 | run: | 32 | conda install -y -c conda-forge mamba 33 | mamba env create -f environment.yaml 34 | conda activate bio-diffusion 35 | python -m pip install --upgrade pip 36 | pip install -e . 
37 | pip install pytest 38 | pip install sh 39 | 40 | - name: List dependencies 41 | run: | 42 | python -m pip list 43 | 44 | - name: Run pytest 45 | run: | 46 | pytest -v 47 | 48 | run_tests_windows: 49 | runs-on: ${{ matrix.os }} 50 | 51 | strategy: 52 | fail-fast: false 53 | matrix: 54 | os: ["windows-latest"] 55 | python-version: ["3.7", "3.8", "3.9", "3.10"] 56 | 57 | timeout-minutes: 10 58 | 59 | steps: 60 | - name: Checkout 61 | uses: actions/checkout@v3 62 | 63 | - name: Set up Python ${{ matrix.python-version }} 64 | uses: actions/setup-python@v3 65 | with: 66 | python-version: ${{ matrix.python-version }} 67 | 68 | - name: Install dependencies 69 | run: | 70 | conda install -y -c conda-forge mamba 71 | mamba env create -f environment.yaml 72 | conda activate bio-diffusion 73 | python -m pip install --upgrade pip 74 | pip install -e . 75 | pip install pytest 76 | 77 | - name: List dependencies 78 | run: | 79 | python -m pip list 80 | 81 | - name: Run pytest 82 | run: | 83 | pytest -v 84 | 85 | # upload code coverage report 86 | code-coverage: 87 | runs-on: ubuntu-latest 88 | 89 | steps: 90 | - name: Checkout 91 | uses: actions/checkout@v2 92 | 93 | - name: Set up Python 3.10 94 | uses: actions/setup-python@v2 95 | with: 96 | python-version: "3.10" 97 | 98 | - name: Install dependencies 99 | run: | 100 | conda install -y -c conda-forge mamba 101 | mamba env create -f environment.yaml 102 | conda activate bio-diffusion 103 | python -m pip install --upgrade pip 104 | pip install -e . 105 | pip install pytest 106 | pip install pytest-cov[toml] 107 | pip install sh 108 | 109 | - name: Run tests and collect coverage 110 | run: pytest --cov src # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER 111 | 112 | - name: Upload coverage to Codecov 113 | uses: codecov/codecov-action@v3 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Bio-Diffusion 148 | .env 149 | .autoenv 150 | .hydra 151 | *.dot 152 | *.pdf 153 | fit-perf_logs.txt 154 | *.sdf 155 | checkpoints/GEOM/* 156 | checkpoints/QM9/* 157 | configs/local/*.yaml 158 | scripts/*_grid_search_runs.json 159 | scripts/*_grid_search_scripts*/gpu_job_*.yaml 160 | scripts/*_grid_search_scripts*/train_*.bash 161 | scripts/*_grid_search_scripts*/*.out 162 | scripts/*_grid_search_scripts*/*.done 163 | scripts/local/*.bash 164 | scripts/nautilus/data_transfer_pod_pvc.yaml 165 | scripts/nautilus/gpu_job.yaml 166 | scripts/nautilus/hm_gpu_job.yaml 167 | scripts/nautilus/persistent_storage.yaml 168 | 169 | bio-diffusion/ 170 | logs/ 171 | *optim_mols/ 172 | outputs/ 173 | epoch_*/ 174 | .cache/ 175 | 176 | data/EDM/GEOM* 177 | data/EDM/QM9* 178 | 179 | # NFS 180 | .nfs* 181 | 182 | # Git 183 | .git-credentials -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.3.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: 
check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 22.6.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "99"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.10.1 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python upgrading syntax to newer version 35 | - repo: https://github.com/asottile/pyupgrade 36 | rev: v2.32.1 37 | hooks: 38 | - id: pyupgrade 39 | args: [--py38-plus] 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: v1.4 44 | hooks: 45 | - id: docformatter 46 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 47 | 48 | # python check (PEP8), programming errors and code complexity 49 | - repo: https://github.com/PyCQA/flake8 50 | rev: 4.0.1 51 | hooks: 52 | - id: flake8 53 | args: 54 | [ 55 | "--extend-ignore", 56 | "E203,E402,E501,F401,F841", 57 | "--exclude", 58 | "logs/*,data/*", 59 | ] 60 | 61 | # python security linter 62 | - repo: https://github.com/PyCQA/bandit 63 | rev: "1.7.1" 64 | hooks: 65 | - id: bandit 66 | args: ["-s", "B101"] 67 | 68 | # yaml formatting 69 | - repo: https://github.com/pre-commit/mirrors-prettier 70 | rev: v2.7.1 71 | hooks: 72 | - id: prettier 73 | types: [yaml] 74 | 75 | # shell scripts linter 76 | - repo: https://github.com/shellcheck-py/shellcheck-py 77 | rev: v0.8.0.4 78 | hooks: 79 | - id: shellcheck 80 | 81 | # md formatting 82 | - repo: https://github.com/executablebooks/mdformat 83 | rev: 0.7.14 84 | hooks: 85 | - id: mdformat 86 | args: ["--number"] 87 | additional_dependencies: 88 | - mdformat-gfm 89 | - mdformat-tables 90 | - mdformat_frontmatter 91 | # - mdformat-toc 92 | # - mdformat-black 93 | 94 | # word spelling linter 95 | - repo: https://github.com/codespell-project/codespell 96 | rev: v2.1.0 97 | hooks: 98 | - id: codespell 99 | args: 100 | - --skip=logs/**,data/**,*.ipynb 101 | # - --ignore-words-list=abc,def 102 | 103 | # jupyter notebook cell output clearing 104 | - repo: https://github.com/kynan/nbstripout 105 | rev: 0.5.0 106 | hooks: 107 | - id: nbstripout 108 | 109 | # jupyter notebook linting 110 | - repo: https://github.com/nbQA-dev/nbQA 111 | rev: 1.4.0 112 | hooks: 113 | - id: nbqa-black 114 | args: ["--line-length=99"] 115 | - id: nbqa-isort 116 | args: ["--profile=black"] 117 | - id: nbqa-flake8 118 | args: 119 | [ 120 | "--extend-ignore=E203,E402,E501,F401,F841", 121 | "--exclude=logs/*,data/*", 122 | ] 123 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM nvcr.io/nvidia/pytorch:21.06-py3 2 | FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime 3 | 4 | LABEL authors="Colby T. 
Ford " 5 | 6 | ## Set environment variables 7 | ENV MPLCONFIGDIR /data/MPL_Config 8 | ENV TORCH_HOME /data/Torch_Home 9 | ENV TORCH_EXTENSIONS_DIR /data/Torch_Extensions 10 | ENV DEBIAN_FRONTEND noninteractive 11 | 12 | ## Install system requirements 13 | RUN apt update && \ 14 | apt-get install -y --reinstall \ 15 | ca-certificates && \ 16 | apt install -y \ 17 | git \ 18 | vim \ 19 | wget \ 20 | libxml2 \ 21 | libgl-dev \ 22 | libgl1 23 | 24 | ## Make directories 25 | RUN mkdir -p /software/ 26 | WORKDIR /software/ 27 | 28 | ## Install dependencies from Conda/Mamba 29 | COPY environment.yaml /software/environment.yaml 30 | RUN conda env create -f environment.yaml 31 | RUN conda init bash && \ 32 | echo "conda activate bio-diffusion" >> ~/.bashrc 33 | SHELL ["/bin/bash", "--login", "-c"] 34 | 35 | ## Install bio-diffusion 36 | RUN git clone https://github.com/BioinfoMachineLearning/bio-diffusion && \ 37 | cd bio-diffusion && \ 38 | pip install -e . 39 | WORKDIR /software/bio-diffusion/ 40 | 41 | CMD /bin/bash 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 BioinfoMachineLearning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- 24 | The QM9 dataset is accompanied by a 25 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. 26 | 27 | -------------------------------------------------------------------------------- 28 | The GEOM-Drugs dataset is accompanied by a 29 | Creative Commons Attribution 4.0 International License. 30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py 31 | 32 | debug: ## Enter debugging mode with pdb 33 | # 34 | # tips: 35 | # - use "import pdb; pdb.set_trace()" to set breakpoint 36 | # - use "h" to print all commands 37 | # - use "n" to execute the next line 38 | # - use "c" to run until the breakpoint is hit 39 | # - use "l" to print src code around current line, "ll" for full function code 40 | # - docs: https://docs.python.org/3/library/pdb.html 41 | # 42 | python -m pdb src/train.py debug=default 43 | -------------------------------------------------------------------------------- /citation.bib: -------------------------------------------------------------------------------- 1 | @article{morehead2024geometry, 2 | title={Geometry-complete diffusion for 3D molecule generation and optimization}, 3 | author={Morehead, Alex and Cheng, Jianlin}, 4 | journal={Communications Chemistry}, 5 | volume={7}, 6 | number={1}, 7 | pages={150}, 8 | year={2024}, 9 | publisher={Nature Publishing Group UK London} 10 | } 11 | -------------------------------------------------------------------------------- /configs/analysis/bust_analysis.yaml: -------------------------------------------------------------------------------- 1 | dataset: GEOM # NOTE: must be one of (`QM9`, `GEOM`) 2 | model_type: Unconditional # NOTE: must be one of (`Unconditional`, `Conditional`) 3 | sampling_index: 0 # NOTE: must be one of (`0`, `1`, `2`, `3`, `4`) 4 | method_1: gcdm # NOTE: must be one of (`gcdm`, `geoldm`) 5 | method_2: geoldm # NOTE: must be one of (`gcdm`, `geoldm`) 6 | property: '' # NOTE: if `model_type` is `Conditional`, must be one of (`_alpha`, `_gap`, `_homo`, `_lumo`, `_mu`, `_Cv`) 7 | bust_column_name: energy_ratio # column name in the bust results file 8 | method_1_bust_results_filepath: ${oc.env:PROJECT_ROOT}/output/${dataset}/${model_type}_analysis/${method_1}${property}_molecule_bust_results_${sampling_index}.csv # filepath to which bust results were saved 9 | method_2_bust_results_filepath: ${oc.env:PROJECT_ROOT}/output/${dataset}/${model_type}_analysis/${method_2}${property}_molecule_bust_results_${sampling_index}.csv # filepath to which bust results were saved 10 | bust_analysis_plot_filepath: ${oc.env:PROJECT_ROOT}/output/${dataset}/${model_type}_analysis/${method_1}_${method_2}${property}_bust_analysis_${sampling_index}.png # filepath to which bust analysis plot will be saved 11 | verbose: true # whether to print additional information 12 | -------------------------------------------------------------------------------- /configs/analysis/inference_analysis.yaml: -------------------------------------------------------------------------------- 1 | dataset: QM9 # NOTE: must be one of (`QM9`, `GEOM`) 2 | model_type: Unconditional # NOTE: must be one of (`Unconditional`, `Conditional`) 3 | method: ??? 
# NOTE: must be one of (`gcdm`, `geoldm`) 4 | property: '' # NOTE: if `model_type` is `Conditional`, must be one of (`_alpha`, `_gap`, `_homo`, `_lumo`, `_mu`, `_Cv`) 5 | bust_results_filepath: ${oc.env:PROJECT_ROOT}/output/${dataset}/${model_type}_analysis/${method}${property}_molecule_bust_results.csv # filepath to which bust results were saved 6 | -------------------------------------------------------------------------------- /configs/analysis/molecule_analysis.yaml: -------------------------------------------------------------------------------- 1 | dataset: QM9 # NOTE: must be one of (`QM9`, `GEOM`) 2 | model_type: Unconditional # NOTE: must be one of (`Unconditional`, `Conditional`) 3 | model_index: 1 # NOTE: must be one of (`1`, `2`, `3`) 4 | sampling_index: ??? # NOTE: must be one of (`0`, `1`, `2`, `3`, `4`) 5 | method: ??? # NOTE: must be one of (`gcdm`, `geoldm`) 6 | property: '' # NOTE: if `model_type` is `Conditional`, must be one of (`_alpha`, `_gap`, `_homo`, `_lumo`, `_mu`, `_Cv`) 7 | input_molecule_dir: ${oc.env:PROJECT_ROOT}/output/${dataset}/${model_type}/${method}_model_${model_index}${property}/ # directory containing input molecules 8 | bust_results_filepath: ${oc.env:PROJECT_ROOT}/output/${dataset}/${model_type}_analysis/${method}${property}_molecule_bust_results.csv # filepath to which to save bust results 9 | full_report: true # whether to generate a full report or not 10 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - early_stopping.yaml 4 | - ema.yaml 5 | - model_summary.yaml 6 | - rich_progress_bar.yaml 7 | - _self_ 8 | 9 | model_checkpoint: 10 | dirpath: ${paths.output_dir}/checkpoints 11 | filename: "epoch_{epoch:03d}" 12 | monitor: "val/loss" 13 | save_top_k: 3 14 | mode: "min" 15 | save_last: True 16 | auto_insert_metric_name: False 17 | 18 | early_stopping: 19 | monitor: "val/loss" 20 | patience: 5 21 | mode: "min" 22 | 23 | model_summary: 24 | max_depth: -1 25 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html 2 | 3 | # Monitor a metric and stop training when it stops improving. 4 | # Look at the above link for more detailed information. 5 | early_stopping: 6 | _target_: pytorch_lightning.callbacks.EarlyStopping 7 | monitor: ??? # quantity to be monitored, must be specified !!! 8 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 9 | patience: 5 # number of checks with no improvement after which training will be stopped 10 | verbose: False # verbosity mode 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | strict: True # whether to crash the training if monitor is not found in the validation metrics 13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 18 | -------------------------------------------------------------------------------- /configs/callbacks/ema.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py 2 | 3 | # Maintains an exponential moving average (EMA) of model weights. 4 | # Look at the above link for more detailed information regarding the original implementation. 5 | ema: 6 | _target_: src.utils.EMA 7 | decay: 0.9999 8 | apply_ema_every_n_steps: 1 9 | start_step: 0 10 | save_ema_weights_in_callback_state: true 11 | evaluate_ema_weights_instead: true 12 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html 2 | # Adapted from: https://github.com/NVIDIA/NeMo/blob/be0804f61e82dd0f63da7f9fe8a4d8388e330b18/nemo/utils/exp_manager.py#L744 3 | 4 | # Save the model periodically by monitoring a quantity. 5 | # Look at the above links for more detailed information. 
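
Aside on the `ema` callback configured above: the actual logic lives in `src.utils.EMA` (adapted from NVIDIA NeMo, per the file's header comment), but the core update it applies to a shadow copy of the model weights is the standard exponential-moving-average rule with the configured `decay` of 0.9999, applied every `apply_ema_every_n_steps`; `evaluate_ema_weights_instead: true` indicates the shadow weights are used at validation/test time. A minimal illustrative sketch (not the repository's implementation):

```python
# Illustrative sketch only -- the real callback is `src.utils.EMA`.
# One EMA step: ema <- decay * ema + (1 - decay) * current_weights.
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay: float = 0.9999):
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```
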
6 | model_checkpoint: 7 | _target_: src.utils.EMAModelCheckpoint 8 | dirpath: null # directory to save the model file 9 | filename: null # checkpoint filename 10 | monitor: null # name of the logged metric which determines when model is improving 11 | verbose: False # verbosity mode 12 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 13 | save_top_k: 3 # save k best models (determined by above metric) 14 | mode: "min" # "max" means higher metric value is better, can be also "min" 15 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 16 | save_weights_only: False # if True, then only the model’s weights will be saved 17 | every_n_train_steps: null # number of training steps between checkpoints 18 | train_time_interval: null # checkpoints are monitored at the specified time interval 19 | every_n_epochs: null # number of epochs between checkpoints 20 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 21 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html 2 | 3 | # Generates a summary of all layers in a LightningModule with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | model_summary: 6 | _target_: pytorch_lightning.callbacks.RichModelSummary 7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 8 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html 2 | 3 | # Create a progress bar with rich text formatting. 4 | # Look at the above link for more detailed information. 
5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | -------------------------------------------------------------------------------- /configs/datamodule/dataloader_cfg/edm_geom_dataloader.yaml: -------------------------------------------------------------------------------- 1 | dataset: GEOM 2 | data_dir: ${paths.data_dir}EDM 3 | smiles_filepath: ${.data_dir}/GEOM/GEOM_drugs_smiles.txt 4 | num_atom_types: 16 5 | num_x_dims: 3 6 | remove_h: false 7 | create_pyg_graphs: true 8 | num_train: -1 9 | num_valid: -1 10 | num_test: -1 11 | subtract_thermo: true 12 | filter_n_atoms: null 13 | include_charges: false 14 | filter_molecule_size: null 15 | sequential: false 16 | device: cpu 17 | force_download: false 18 | num_radials: 1 19 | batch_size: 64 20 | num_workers: 4 21 | shuffle: true 22 | drop_last: true 23 | pin_memory: false 24 | -------------------------------------------------------------------------------- /configs/datamodule/dataloader_cfg/edm_qm9_dataloader.yaml: -------------------------------------------------------------------------------- 1 | dataset: QM9 2 | data_dir: ${paths.data_dir}EDM 3 | smiles_filepath: ${.data_dir}/QM9/QM9_smiles.pickle 4 | num_atom_types: 5 5 | num_x_dims: 3 6 | remove_h: false 7 | create_pyg_graphs: true 8 | num_train: -1 9 | num_valid: -1 10 | num_test: -1 11 | subtract_thermo: true 12 | filter_n_atoms: null 13 | include_charges: true 14 | filter_molecule_size: null 15 | sequential: false 16 | device: cpu 17 | force_download: false 18 | num_radials: 1 19 | batch_size: 64 20 | num_workers: 4 21 | shuffle: true 22 | drop_last: true 23 | pin_memory: false 24 | -------------------------------------------------------------------------------- /configs/datamodule/edm_geom.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.edm_datamodule.EDMDataModule 2 | 3 | defaults: 4 | - _self_ 5 | - dataloader_cfg: edm_geom_dataloader.yaml 6 | -------------------------------------------------------------------------------- /configs/datamodule/edm_qm9.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.edm_datamodule.EDMDataModule 2 | 3 | defaults: 4 | - _self_ 5 | - dataloader_cfg: edm_qm9_dataloader.yaml 6 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | datamodule: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | 
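
Because `configs/debug/default.yaml` above is declared `# @package _global_`, selecting it overrides top-level keys of the composed training config (task name set to "debug", one epoch on CPU, anomaly detection enabled, callbacks and loggers disabled). A minimal sketch of inspecting that composition, assuming Hydra's standard compose API and that the snippet is run from the repository root (it is not part of the repo itself):

```python
# Minimal sketch (not part of the repo): compose `configs/train.yaml` with the
# `debug=default` override to see what the debug group changes. `return_hydra_config`
# is passed because the debug config also adjusts Hydra's own logging level.
from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train", overrides=["debug=default"], return_hydra_config=True)

print(cfg.task_name)               # "debug"
print(cfg.trainer.max_epochs)      # 1 -- a single full epoch on CPU
print(cfg.trainer.detect_anomaly)  # True
print(cfg.callbacks, cfg.logger)   # None None -- both disabled while debugging
```
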
-------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/experiment/geom_mol_gen_ddpm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=geom_mol_gen_ddpm 5 | 6 | defaults: 7 | - override /datamodule: edm_geom.yaml 8 | - override /model: geom_mol_gen_ddpm.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["molecule_generation", "geom"] 16 | 17 | seed: 42 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/loss" 22 | save_top_k: 3 23 | early_stopping: 24 | monitor: "val/loss" 25 | patience: 20 26 | 27 | trainer: 28 | min_epochs: 50 29 | max_epochs: 3000 30 | strategy: ddp_find_unused_parameters_false 31 | accelerator: gpu 32 | devices: 1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 35 | 36 | model: 37 | optimizer: 38 | lr: 1e-4 39 | weight_decay: 1e-12 40 | 41 | model_cfg: 42 | h_hidden_dim: 256 43 | chi_hidden_dim: 32 44 | e_hidden_dim: 16 45 | xi_hidden_dim: 8 46 | 47 | num_encoder_layers: 4 48 | num_decoder_layers: 3 49 | dropout: 0.0 50 | 51 | module_cfg: 52 | selected_GCP: 53 | # which version of the GCP module to use (e.g., GCP or GCP2) 54 | _target_: src.models.components.gcpnet.GCP2 55 | _partial_: true 56 | 57 | norm_x_diff: true 58 | 59 | scalar_gate: 0 60 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 61 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 62 | vector_frame_residual: false # note: for GCP2, this 
parameter is unused 63 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 64 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 65 | 66 | scalar_nonlinearity: silu 67 | vector_nonlinearity: silu 68 | 69 | nonlinearities: 70 | - ${..scalar_nonlinearity} 71 | - ${..vector_nonlinearity} 72 | 73 | bottleneck: 4 74 | 75 | vector_linear: true 76 | vector_identity: true 77 | 78 | default_vector_residual: false 79 | default_bottleneck: 4 80 | 81 | node_positions_weight: 1.0 82 | update_positions_with_vector_sum: false 83 | 84 | ablate_frame_updates: false 85 | ablate_scalars: false 86 | ablate_vectors: false 87 | 88 | clip_gradients: true 89 | 90 | layer_cfg: 91 | mp_cfg: 92 | edge_encoder: false 93 | edge_gate: false 94 | num_message_layers: 4 95 | message_residual: 0 96 | message_ff_multiplier: 1 97 | self_message: true 98 | use_residual_message_gcp: true 99 | 100 | pre_norm: false 101 | use_gcp_norm: false 102 | use_gcp_dropout: false 103 | use_scalar_message_attention: true 104 | num_feedforward_layers: 1 105 | dropout: 0.0 106 | 107 | nonlinearity_slope: 1e-2 108 | 109 | diffusion_cfg: 110 | ddpm_mode: unconditional 111 | dynamics_network: gcpnet 112 | num_timesteps: 1000 113 | norm_training_by_max_nodes: false 114 | 115 | datamodule: 116 | dataloader_cfg: 117 | num_train: -1 118 | num_val: -1 119 | num_test: -1 120 | batch_size: 64 121 | num_workers: 4 122 | 123 | logger: 124 | wandb: 125 | name: 04302023_GEOMMoleculeGenerationDDPM 126 | group: "GEOM" 127 | tags: ${tags} 128 | -------------------------------------------------------------------------------- /configs/experiment/geom_mol_gen_ddpm_grid_search.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=geom_mol_gen_ddpm_grid_search 5 | 6 | defaults: 7 | - override /datamodule: edm_geom.yaml 8 | - override /model: geom_mol_gen_ddpm.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["molecule_generation", "geom", "grid_search"] 16 | 17 | seed: 42 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/loss" 22 | save_top_k: 3 23 | early_stopping: 24 | monitor: "val/loss" 25 | patience: 20 26 | 27 | trainer: 28 | min_epochs: 50 29 | max_epochs: 3000 30 | strategy: ddp_find_unused_parameters_false 31 | accelerator: gpu 32 | devices: -1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 35 | 36 | paths: 37 | grid_search_script_dir: scripts/geom_mol_gen_ddpm_grid_search_scripts 38 | 39 | model: 40 | optimizer: 41 | lr: 1e-4 42 | weight_decay: 1e-12 43 | 44 | model_cfg: 45 | h_hidden_dim: 256 46 | chi_hidden_dim: 32 47 | e_hidden_dim: 16 48 | xi_hidden_dim: 8 49 | 50 | num_encoder_layers: 4 51 | num_decoder_layers: 3 52 | dropout: 0.0 53 | 54 | module_cfg: 55 | selected_GCP: 56 | # which version of the GCP module to use (e.g., GCP or GCP2) 57 | _target_: src.models.components.gcpnet.GCP2 58 | _partial_: true 59 | 60 | norm_x_diff: true 61 | 62 | scalar_gate: 0 63 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without 
directly using frame vectors 64 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 65 | vector_frame_residual: false # note: for GCP2, this parameter is unused 66 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 67 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 68 | 69 | scalar_nonlinearity: silu 70 | vector_nonlinearity: silu 71 | 72 | nonlinearities: 73 | - ${..scalar_nonlinearity} 74 | - ${..vector_nonlinearity} 75 | 76 | bottleneck: 4 77 | 78 | vector_linear: true 79 | vector_identity: true 80 | 81 | default_vector_residual: false 82 | default_bottleneck: 4 83 | 84 | node_positions_weight: 1.0 85 | update_positions_with_vector_sum: false 86 | 87 | ablate_frame_updates: false 88 | ablate_scalars: false 89 | ablate_vectors: false 90 | 91 | clip_gradients: true 92 | 93 | layer_cfg: 94 | mp_cfg: 95 | edge_encoder: false 96 | edge_gate: false 97 | num_message_layers: 4 98 | message_residual: 0 99 | message_ff_multiplier: 1 100 | self_message: true 101 | use_residual_message_gcp: true 102 | 103 | pre_norm: false 104 | use_gcp_norm: false 105 | use_gcp_dropout: false 106 | use_scalar_message_attention: true 107 | num_feedforward_layers: 1 108 | dropout: 0.0 109 | 110 | nonlinearity_slope: 1e-2 111 | 112 | diffusion_cfg: 113 | ddpm_mode: unconditional 114 | dynamics_network: gcpnet 115 | num_timesteps: 1000 116 | norm_training_by_max_nodes: false 117 | 118 | datamodule: 119 | dataloader_cfg: 120 | num_train: -1 121 | num_val: -1 122 | num_test: -1 123 | batch_size: 64 124 | num_workers: 4 125 | 126 | logger: 127 | wandb: 128 | name: 04302023_GEOMMoleculeGenerationDDPM 129 | group: "GEOM" 130 | tags: ${tags} 131 | 132 | train: true 133 | test: false -------------------------------------------------------------------------------- /configs/experiment/qm9_mol_gen_conditional_ddpm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=qm9_mol_gen_ddpm 5 | 6 | defaults: 7 | - override /datamodule: edm_qm9.yaml 8 | - override /model: qm9_mol_gen_ddpm.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["molecule_generation", "qm9", "conditioning"] 16 | 17 | seed: 42 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/loss" 22 | save_top_k: 3 23 | early_stopping: 24 | monitor: "val/loss" 25 | patience: 12 26 | 27 | trainer: 28 | min_epochs: 1600 29 | max_epochs: 5000 30 | strategy: ddp_find_unused_parameters_false 31 | accelerator: gpu 32 | devices: 1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 35 | 36 | model: 37 | optimizer: 38 | lr: 1e-4 39 | weight_decay: 1e-12 40 | 41 | model_cfg: 42 | h_hidden_dim: 256 43 | chi_hidden_dim: 32 44 | e_hidden_dim: 64 45 | xi_hidden_dim: 16 46 | 47 | num_encoder_layers: 9 48 | num_decoder_layers: 3 49 | dropout: 0.0 50 | 51 | module_cfg: 52 | selected_GCP: 53 | # which version of the GCP module to use (e.g., GCP or GCP2) 54 | _target_: src.models.components.gcpnet.GCP2 55 | _partial_: true 56 | 57 | norm_x_diff: true 58 | 59 | scalar_gate: 0 60 | vector_gate: true # note: 
For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 61 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 62 | vector_frame_residual: false # note: for GCP2, this parameter is unused 63 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 64 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 65 | 66 | scalar_nonlinearity: silu 67 | vector_nonlinearity: silu 68 | 69 | nonlinearities: 70 | - ${..scalar_nonlinearity} 71 | - ${..vector_nonlinearity} 72 | 73 | bottleneck: 4 74 | 75 | vector_linear: true 76 | vector_identity: true 77 | 78 | default_vector_residual: false 79 | default_bottleneck: 4 80 | 81 | node_positions_weight: 1.0 82 | update_positions_with_vector_sum: false 83 | 84 | ablate_frame_updates: false 85 | ablate_scalars: false 86 | ablate_vectors: false 87 | 88 | conditioning: [alpha] 89 | 90 | clip_gradients: true 91 | 92 | layer_cfg: 93 | mp_cfg: 94 | edge_encoder: false 95 | edge_gate: false 96 | num_message_layers: 4 97 | message_residual: 0 98 | message_ff_multiplier: 1 99 | self_message: true 100 | use_residual_message_gcp: true 101 | 102 | pre_norm: false 103 | use_gcp_norm: false 104 | use_gcp_dropout: false 105 | use_scalar_message_attention: true 106 | num_feedforward_layers: 1 107 | dropout: 0.0 108 | 109 | nonlinearity_slope: 1e-2 110 | 111 | diffusion_cfg: 112 | ddpm_mode: unconditional 113 | dynamics_network: gcpnet 114 | num_timesteps: 1000 115 | norm_training_by_max_nodes: false 116 | norm_values: [1.0, 8.0, 1.0] 117 | 118 | datamodule: 119 | dataloader_cfg: 120 | num_train: -1 121 | num_val: -1 122 | num_test: -1 123 | batch_size: 64 124 | num_workers: 4 125 | include_charges: false 126 | dataset: QM9_second_half 127 | 128 | logger: 129 | wandb: 130 | name: 02092023_QM9MoleculeGenerationDDPM 131 | group: "QM9" 132 | tags: ${tags} 133 | -------------------------------------------------------------------------------- /configs/experiment/qm9_mol_gen_conditional_ddpm_grid_search.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=qm9_mol_gen_ddpm 5 | 6 | defaults: 7 | - override /datamodule: edm_qm9.yaml 8 | - override /model: qm9_mol_gen_ddpm.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["molecule_generation", "qm9", "conditioning", "grid_search"] 16 | 17 | seed: 42 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/loss" 22 | save_top_k: 3 23 | early_stopping: 24 | monitor: "val/loss" 25 | patience: 12 26 | 27 | trainer: 28 | min_epochs: 1600 29 | max_epochs: 5000 30 | strategy: ddp_find_unused_parameters_false 31 | accelerator: gpu 32 | devices: -1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 35 | 36 | paths: 37 | grid_search_script_dir: scripts/qm9_mol_gen_ddpm_grid_search_scripts 38 | 39 | model: 40 | optimizer: 41 | lr: 1e-4 42 | weight_decay: 1e-12 43 | 44 | model_cfg: 45 | h_hidden_dim: 256 46 | chi_hidden_dim: 32 47 | e_hidden_dim: 64 48 | xi_hidden_dim: 16 49 | 50 | num_encoder_layers: 
9 51 | num_decoder_layers: 3 52 | dropout: 0.0 53 | 54 | module_cfg: 55 | selected_GCP: 56 | # which version of the GCP module to use (e.g., GCP or GCP2) 57 | _target_: src.models.components.gcpnet.GCP2 58 | _partial_: true 59 | 60 | norm_x_diff: true 61 | 62 | scalar_gate: 0 63 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 64 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 65 | vector_frame_residual: false # note: for GCP2, this parameter is unused 66 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 67 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 68 | 69 | scalar_nonlinearity: silu 70 | vector_nonlinearity: silu 71 | 72 | nonlinearities: 73 | - ${..scalar_nonlinearity} 74 | - ${..vector_nonlinearity} 75 | 76 | bottleneck: 4 77 | 78 | vector_linear: true 79 | vector_identity: true 80 | 81 | default_vector_residual: false 82 | default_bottleneck: 4 83 | 84 | node_positions_weight: 1.0 85 | update_positions_with_vector_sum: false 86 | 87 | ablate_frame_updates: false 88 | ablate_scalars: false 89 | ablate_vectors: false 90 | 91 | conditioning: [alpha] 92 | 93 | clip_gradients: true 94 | 95 | layer_cfg: 96 | mp_cfg: 97 | edge_encoder: false 98 | edge_gate: false 99 | num_message_layers: 4 100 | message_residual: 0 101 | message_ff_multiplier: 1 102 | self_message: true 103 | use_residual_message_gcp: true 104 | 105 | pre_norm: false 106 | use_gcp_norm: false 107 | use_gcp_dropout: false 108 | use_scalar_message_attention: true 109 | num_feedforward_layers: 1 110 | dropout: 0.0 111 | 112 | nonlinearity_slope: 1e-2 113 | 114 | diffusion_cfg: 115 | ddpm_mode: unconditional 116 | dynamics_network: gcpnet 117 | num_timesteps: 1000 118 | norm_training_by_max_nodes: false 119 | norm_values: [1.0, 8.0, 1.0] 120 | 121 | datamodule: 122 | dataloader_cfg: 123 | num_train: -1 124 | num_val: -1 125 | num_test: -1 126 | batch_size: 64 127 | num_workers: 4 128 | include_charges: false 129 | dataset: QM9_second_half 130 | 131 | logger: 132 | wandb: 133 | name: 02092023_QM9MoleculeGenerationDDPM 134 | group: "QM9" 135 | tags: ${tags} 136 | 137 | train: true 138 | test: false -------------------------------------------------------------------------------- /configs/experiment/qm9_mol_gen_ddpm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=qm9_mol_gen_ddpm 5 | 6 | defaults: 7 | - override /datamodule: edm_qm9.yaml 8 | - override /model: qm9_mol_gen_ddpm.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["molecule_generation", "qm9"] 16 | 17 | seed: 42 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/loss" 22 | save_top_k: 3 23 | early_stopping: 24 | monitor: "val/loss" 25 | patience: 10 26 | 27 | trainer: 28 | min_epochs: 1000 29 | max_epochs: 5000 30 | strategy: ddp_find_unused_parameters_false 31 | accelerator: gpu 32 | devices: 1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 35 | 36 | 
model: 37 | optimizer: 38 | lr: 1e-4 39 | weight_decay: 1e-12 40 | 41 | model_cfg: 42 | h_hidden_dim: 256 43 | chi_hidden_dim: 32 44 | e_hidden_dim: 64 45 | xi_hidden_dim: 16 46 | 47 | num_encoder_layers: 9 48 | num_decoder_layers: 3 49 | dropout: 0.0 50 | 51 | module_cfg: 52 | selected_GCP: 53 | # which version of the GCP module to use (e.g., GCP or GCP2) 54 | _target_: src.models.components.gcpnet.GCP2 55 | _partial_: true 56 | 57 | norm_x_diff: true 58 | 59 | scalar_gate: 0 60 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 61 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 62 | vector_frame_residual: false # note: for GCP2, this parameter is unused 63 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 64 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 65 | 66 | scalar_nonlinearity: silu 67 | vector_nonlinearity: silu 68 | 69 | nonlinearities: 70 | - ${..scalar_nonlinearity} 71 | - ${..vector_nonlinearity} 72 | 73 | bottleneck: 4 74 | 75 | vector_linear: true 76 | vector_identity: true 77 | 78 | default_vector_residual: false 79 | default_bottleneck: 4 80 | 81 | node_positions_weight: 1.0 82 | update_positions_with_vector_sum: false 83 | 84 | ablate_frame_updates: false 85 | ablate_scalars: false 86 | ablate_vectors: false 87 | 88 | clip_gradients: true 89 | 90 | layer_cfg: 91 | mp_cfg: 92 | edge_encoder: false 93 | edge_gate: false 94 | num_message_layers: 4 95 | message_residual: 0 96 | message_ff_multiplier: 1 97 | self_message: true 98 | use_residual_message_gcp: true 99 | 100 | pre_norm: false 101 | use_gcp_norm: false 102 | use_gcp_dropout: false 103 | use_scalar_message_attention: true 104 | num_feedforward_layers: 1 105 | dropout: 0.0 106 | 107 | nonlinearity_slope: 1e-2 108 | 109 | diffusion_cfg: 110 | ddpm_mode: unconditional 111 | dynamics_network: gcpnet 112 | num_timesteps: 1000 113 | norm_training_by_max_nodes: false 114 | 115 | datamodule: 116 | dataloader_cfg: 117 | num_train: -1 118 | num_val: -1 119 | num_test: -1 120 | batch_size: 64 121 | num_workers: 4 122 | 123 | logger: 124 | wandb: 125 | name: 02092023_QM9MoleculeGenerationDDPM 126 | group: "QM9" 127 | tags: ${tags} 128 | -------------------------------------------------------------------------------- /configs/experiment/qm9_mol_gen_ddpm_grid_search.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=qm9_mol_gen_ddpm_grid_search 5 | 6 | defaults: 7 | - override /datamodule: edm_qm9.yaml 8 | - override /model: qm9_mol_gen_ddpm.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["molecule_generation", "qm9", "grid_search"] 16 | 17 | seed: 42 18 | 19 | callbacks: 20 | model_checkpoint: 21 | monitor: "val/loss" 22 | save_top_k: 3 23 | early_stopping: 24 | monitor: "val/loss" 25 | patience: 10 26 | 27 | trainer: 28 | min_epochs: 1000 29 | max_epochs: 5000 30 | strategy: ddp_find_unused_parameters_false 31 | accelerator: gpu 32 | 
devices: -1 33 | num_nodes: 1 34 | accumulate_grad_batches: 1 35 | 36 | paths: 37 | grid_search_script_dir: scripts/qm9_mol_gen_ddpm_grid_search_scripts 38 | 39 | model: 40 | optimizer: 41 | lr: 1e-4 42 | weight_decay: 1e-12 43 | 44 | model_cfg: 45 | h_hidden_dim: 256 46 | chi_hidden_dim: 32 47 | e_hidden_dim: 64 48 | xi_hidden_dim: 16 49 | 50 | num_encoder_layers: 9 51 | num_decoder_layers: 3 52 | dropout: 0.0 53 | 54 | module_cfg: 55 | selected_GCP: 56 | # which version of the GCP module to use (e.g., GCP or GCP2) 57 | _target_: src.models.components.gcpnet.GCP2 58 | _partial_: true 59 | 60 | norm_x_diff: true 61 | 62 | scalar_gate: 0 63 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 64 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 65 | vector_frame_residual: false # note: for GCP2, this parameter is unused 66 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 67 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 68 | 69 | scalar_nonlinearity: silu 70 | vector_nonlinearity: silu 71 | 72 | nonlinearities: 73 | - ${..scalar_nonlinearity} 74 | - ${..vector_nonlinearity} 75 | 76 | bottleneck: 4 77 | 78 | vector_linear: true 79 | vector_identity: true 80 | 81 | default_vector_residual: false 82 | default_bottleneck: 4 83 | 84 | node_positions_weight: 1.0 85 | update_positions_with_vector_sum: false 86 | 87 | ablate_frame_updates: false 88 | ablate_scalars: false 89 | ablate_vectors: false 90 | 91 | clip_gradients: true 92 | 93 | layer_cfg: 94 | mp_cfg: 95 | edge_encoder: false 96 | edge_gate: false 97 | num_message_layers: 4 98 | message_residual: 0 99 | message_ff_multiplier: 1 100 | self_message: true 101 | use_residual_message_gcp: true 102 | 103 | pre_norm: false 104 | use_gcp_norm: false 105 | use_gcp_dropout: false 106 | use_scalar_message_attention: true 107 | num_feedforward_layers: 1 108 | dropout: 0.0 109 | 110 | nonlinearity_slope: 1e-2 111 | 112 | diffusion_cfg: 113 | ddpm_mode: unconditional 114 | dynamics_network: gcpnet 115 | num_timesteps: 1000 116 | norm_training_by_max_nodes: false 117 | 118 | datamodule: 119 | dataloader_cfg: 120 | num_train: -1 121 | num_val: -1 122 | num_test: -1 123 | batch_size: 64 124 | num_workers: 4 125 | 126 | logger: 127 | wandb: 128 | name: 02092023_QM9MoleculeGenerationDDPM 129 | group: "QM9" 130 | tags: ${tags} 131 | 132 | train: true 133 | test: false -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/geom_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=geom_optuna experiment=example 5 | 6 | defaults: 7 | - 
override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/loss_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: minimize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 42 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hparams_search/qm9_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=qm9_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/loss_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: minimize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 42 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "Bio-Diffusion" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | 
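# note: any logger config in this directory can be selected at runtime with a Hydra override, for example:
#
#   python train.py logger=csv
#   python train.py logger=wandb logger.wandb.name=my_run
#
# the `many_loggers.yaml` config below simply composes several of these configs at once via its `defaults` list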
-------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: Bio-Diffusion 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: false 8 | id: null # pass correct id to resume experiment! 
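# note: the grid-search launchers under `scripts/nautilus/` set this id explicitly
# (via `logger.wandb.id=<id generated with wandb.util.generate_id()>`) together with a fixed
# `ckpt_path`, so that an interrupted Kubernetes job can resume both its checkpoint and its WandB run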
9 | anonymous: null # enable anonymous logging 10 | project: "Bio-Diffusion" 11 | log_model: true # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | entity: "" # set to name of your wandb team 14 | # group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/diffusion_cfg/geom_mol_gen_ddpm.yaml: -------------------------------------------------------------------------------- 1 | ddpm_mode: unconditional # [unconditional, conditional, simple_conditional] 2 | dynamics_network: gcpnet # [gcpnet, egnn] 3 | diffusion_target: "atom_types_and_coords" # [atom_types_and_coords] 4 | num_timesteps: 1000 5 | parametrization: "eps" 6 | noise_schedule: "polynomial_2" # [cosine, polynomial_n, learned] 7 | noise_precision: 1e-5 8 | loss_type: "l2" # [l2, vlb] 9 | norm_values: [1.0, 4.0, 10.0] # [normalization_value_for_x, normalization_value_for_h_categorical, normalization_value_for_h_integer] 10 | norm_biases: [null, 0.0, 0.0] 11 | condition_on_time: true 12 | self_condition: false 13 | norm_training_by_max_nodes: false 14 | sample_during_training: true 15 | eval_epochs: 1 16 | visualize_sample_epochs: ${.eval_epochs} 17 | visualize_chain_epochs: ${.eval_epochs} 18 | num_eval_samples: 500 19 | eval_batch_size: 100 20 | num_visualization_samples: 5 21 | keep_frames: 100 -------------------------------------------------------------------------------- /configs/model/diffusion_cfg/qm9_mol_gen_ddpm.yaml: -------------------------------------------------------------------------------- 1 | ddpm_mode: unconditional # [unconditional, conditional, simple_conditional] 2 | dynamics_network: gcpnet # [gcpnet, egnn] 3 | diffusion_target: "atom_types_and_coords" # [atom_types_and_coords] 4 | num_timesteps: 1000 5 | parametrization: "eps" # [eps] 6 | noise_schedule: "polynomial_2" # [cosine, polynomial_n, learned] 7 | noise_precision: 1e-5 8 | loss_type: "l2" # [l2, vlb] 9 | norm_values: [1.0, 4.0, 10.0] # [normalization_value_for_x, normalization_value_for_h_categorical, normalization_value_for_h_integer] 10 | norm_biases: [null, 0.0, 0.0] 11 | condition_on_time: true 12 | self_condition: false 13 | norm_training_by_max_nodes: false 14 | sample_during_training: true 15 | eval_epochs: 20 16 | visualize_sample_epochs: ${.eval_epochs} 17 | visualize_chain_epochs: ${.eval_epochs} 18 | num_eval_samples: 1000 19 | eval_batch_size: 100 20 | num_visualization_samples: 5 21 | keep_frames: 100 -------------------------------------------------------------------------------- /configs/model/geom_mol_gen_ddpm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.geom_mol_gen_ddpm.GEOMMoleculeGenerationDDPM 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 1e-4 7 | weight_decay: 1e-12 8 | amsgrad: true 9 | 10 | scheduler: # note: leaving `scheduler` empty will result in a learning-rate scheduler not being used 11 | # _target_: torch.optim.lr_scheduler.StepLR 12 | # _partial_: true 13 | # step_size: ${...trainer.min_epochs} // 8 # note: using literal evalution manually until Hydra natively supports this functionality 14 | # gamma: 0.9 15 | # last_epoch: -1 16 | 17 | defaults: 18 | - model_cfg: geom_mol_gen_ddpm_gcp_model.yaml 19 | - module_cfg: geom_mol_gen_ddpm_gcp_module.yaml 20 | - layer_cfg: geom_mol_gen_ddpm_gcp_interaction_layer.yaml 21 | - diffusion_cfg: geom_mol_gen_ddpm.yaml 22 | 23 | seed: ${..seed} 
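# note: since the optimizer above is declared with `_partial_: true`, Hydra instantiates it as a
# functools.partial of `torch.optim.AdamW` rather than a ready-made optimizer; the LightningModule is
# then expected to bind its own parameters, roughly as in this minimal sketch (not the repository's exact code):
#
#   def configure_optimizers(self):
#       optimizer = self.hparams.optimizer(params=self.parameters())
#       return {"optimizer": optimizer}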
-------------------------------------------------------------------------------- /configs/model/layer_cfg/geom_mol_gen_ddpm_gcp_interaction_layer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mp_cfg: geom_mol_gen_ddpm_gcp_mp.yaml 3 | 4 | pre_norm: false 5 | use_gcp_norm: false 6 | use_gcp_dropout: false 7 | use_scalar_message_attention: true 8 | num_feedforward_layers: 1 9 | dropout: 0.0 10 | 11 | nonlinearity_slope: 1e-2 -------------------------------------------------------------------------------- /configs/model/layer_cfg/mp_cfg/geom_mol_gen_ddpm_gcp_mp.yaml: -------------------------------------------------------------------------------- 1 | edge_encoder: false 2 | edge_gate: false 3 | num_message_layers: 4 4 | message_residual: 0 5 | message_ff_multiplier: 1 6 | self_message: true 7 | use_residual_message_gcp: true -------------------------------------------------------------------------------- /configs/model/layer_cfg/mp_cfg/qm9_mol_gen_ddpm_gcp_mp.yaml: -------------------------------------------------------------------------------- 1 | edge_encoder: false 2 | edge_gate: false 3 | num_message_layers: 4 4 | message_residual: 0 5 | message_ff_multiplier: 1 6 | self_message: true 7 | use_residual_message_gcp: true -------------------------------------------------------------------------------- /configs/model/layer_cfg/qm9_mol_gen_ddpm_gcp_interaction_layer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - mp_cfg: qm9_mol_gen_ddpm_gcp_mp.yaml 3 | 4 | pre_norm: false 5 | use_gcp_norm: false 6 | use_gcp_dropout: false 7 | use_scalar_message_attention: true 8 | num_feedforward_layers: 1 9 | dropout: 0.0 10 | 11 | nonlinearity_slope: 1e-2 -------------------------------------------------------------------------------- /configs/model/model_cfg/geom_mol_gen_ddpm_gcp_model.yaml: -------------------------------------------------------------------------------- 1 | chi_input_dim: 2 2 | e_input_dim: 1 3 | xi_input_dim: 1 4 | 5 | h_hidden_dim: 256 6 | chi_hidden_dim: 32 7 | e_hidden_dim: 16 8 | xi_hidden_dim: 8 9 | 10 | num_encoder_layers: 4 11 | num_decoder_layers: 3 12 | dropout: 0.0 -------------------------------------------------------------------------------- /configs/model/model_cfg/qm9_mol_gen_ddpm_gcp_model.yaml: -------------------------------------------------------------------------------- 1 | chi_input_dim: 2 2 | e_input_dim: 1 3 | xi_input_dim: 1 4 | 5 | h_hidden_dim: 256 6 | chi_hidden_dim: 32 7 | e_hidden_dim: 64 8 | xi_hidden_dim: 16 9 | 10 | num_encoder_layers: 9 11 | num_decoder_layers: 3 12 | dropout: 0.0 -------------------------------------------------------------------------------- /configs/model/module_cfg/geom_mol_gen_ddpm_gcp_module.yaml: -------------------------------------------------------------------------------- 1 | selected_GCP: 2 | # which version of the GCP module to use (e.g., GCP or GCP2) 3 | _target_: src.models.components.gcpnet.GCP2 4 | _partial_: true 5 | 6 | norm_x_diff: true 7 | 8 | scalar_gate: 0 9 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 10 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 11 | vector_frame_residual: false # note: for GCP2, this parameter is unused 12 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating 
is applied instead 13 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 14 | 15 | scalar_nonlinearity: silu 16 | vector_nonlinearity: silu 17 | 18 | nonlinearities: 19 | - ${..scalar_nonlinearity} 20 | - ${..vector_nonlinearity} 21 | 22 | bottleneck: 4 23 | 24 | vector_linear: true 25 | vector_identity: true 26 | 27 | default_vector_residual: false 28 | default_bottleneck: 4 29 | 30 | node_positions_weight: 1.0 31 | update_positions_with_vector_sum: false 32 | 33 | ablate_frame_updates: false 34 | ablate_scalars: false 35 | ablate_vectors: false 36 | 37 | conditioning: [] # note: the GEOM-Drugs dataset currently does not support property conditioning 38 | 39 | clip_gradients: true 40 | log_grad_flow_steps: 500 # after how many steps to log gradient flow -------------------------------------------------------------------------------- /configs/model/module_cfg/qm9_mol_gen_ddpm_gcp_module.yaml: -------------------------------------------------------------------------------- 1 | selected_GCP: 2 | # which version of the GCP module to use (e.g., GCP or GCP2) 3 | _target_: src.models.components.gcpnet.GCP2 4 | _partial_: true 5 | 6 | norm_x_diff: true 7 | 8 | scalar_gate: 0 9 | vector_gate: true # note: For both GCP and GCP2, this parameter is used; For GCP2, this mimics updating vector features without directly using frame vectors 10 | vector_residual: false # note: For both GCP and GCP2, this parameter is used 11 | vector_frame_residual: false # note: for GCP2, this parameter is unused 12 | frame_gate: false # note: for GCP2, if this parameter and `vector_gate` are both set to `false`, row-wise vector self-gating is applied instead 13 | sigma_frame_gate: false # note: For GCP, this parameter overrides `frame_gate`; For GCP2, this parameter is unused and is replaced in functionality by `vector_gate` 14 | 15 | scalar_nonlinearity: silu 16 | vector_nonlinearity: silu 17 | 18 | nonlinearities: 19 | - ${..scalar_nonlinearity} 20 | - ${..vector_nonlinearity} 21 | 22 | bottleneck: 4 23 | 24 | vector_linear: true 25 | vector_identity: true 26 | 27 | default_vector_residual: false 28 | default_bottleneck: 4 29 | 30 | node_positions_weight: 1.0 31 | update_positions_with_vector_sum: false 32 | 33 | ablate_frame_updates: false 34 | ablate_scalars: false 35 | ablate_vectors: false 36 | 37 | conditioning: [] # note: multiple arguments can be passed here including: homo | onehot | lumo | num_atoms | etc. 
38 | # usage: `conditioning: [H_thermo, homo]`) 39 | 40 | clip_gradients: true 41 | log_grad_flow_steps: 500 # after how many steps to log gradient flow -------------------------------------------------------------------------------- /configs/model/qm9_mol_gen_ddpm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.qm9_mol_gen_ddpm.QM9MoleculeGenerationDDPM 2 | 3 | optimizer: 4 | _target_: torch.optim.AdamW 5 | _partial_: true 6 | lr: 1e-4 7 | weight_decay: 1e-12 8 | amsgrad: true 9 | 10 | scheduler: # note: leaving `scheduler` empty will result in a learning-rate scheduler not being used 11 | # _target_: torch.optim.lr_scheduler.StepLR 12 | # _partial_: true 13 | # step_size: ${...trainer.min_epochs} // 8 # note: using literal evalution manually until Hydra natively supports this functionality 14 | # gamma: 0.9 15 | # last_epoch: -1 16 | 17 | defaults: 18 | - model_cfg: qm9_mol_gen_ddpm_gcp_model.yaml 19 | - module_cfg: qm9_mol_gen_ddpm_gcp_module.yaml 20 | - layer_cfg: qm9_mol_gen_ddpm_gcp_interaction_layer.yaml 21 | - diffusion_cfg: qm9_mol_gen_ddpm.yaml 22 | -------------------------------------------------------------------------------- /configs/mol_gen_eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: edm_qm9.yaml # choose datamodule with `val_dataloader()` and `test_dataloader()` for evaluation 6 | - model: qm9_mol_gen_ddpm.yaml 7 | - callbacks: default.yaml 8 | - logger: null 9 | - trainer: default.yaml 10 | - paths: default.yaml 11 | - extras: default.yaml 12 | - hydra: default.yaml 13 | 14 | task_name: "mol_gen_eval" 15 | 16 | tags: ["dev"] 17 | 18 | # passing checkpoint path is necessary for sampling and evaluation 19 | ckpt_path: ??? 20 | 21 | # inference (i.e., sampling) and evaluation arguments 22 | seed: 42 23 | num_samples: 10000 24 | sampling_batch_size: 100 25 | num_timesteps: 1000 26 | num_test_passes: 5 27 | check_val_nll: true 28 | save_molecules: false 29 | output_dir: null 30 | -------------------------------------------------------------------------------- /configs/mol_gen_eval_conditional_qm9.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: edm_qm9.yaml # choose datamodule with `val_dataloader()` and `test_dataloader()` for evaluation 6 | - model: qm9_mol_gen_ddpm.yaml 7 | - callbacks: default.yaml 8 | - logger: null 9 | - trainer: default.yaml 10 | - paths: default.yaml 11 | - extras: default.yaml 12 | - hydra: default.yaml 13 | 14 | task_name: "mol_gen_eval_conditional_qm9" 15 | 16 | tags: ["dev"] 17 | 18 | # inference (i.e., sampling) and evaluation arguments 19 | seed: 42 20 | generator_model_filepath: ??? 
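# note: `generator_model_filepath` points at a trained conditional DDPM checkpoint; a typical run
# supplies it (and the property classifier directory defined just below) on the command line,
# e.g., with hypothetical checkpoint paths:
#
#   python src/mol_gen_eval_conditional_qm9.py \
#     generator_model_filepath=checkpoints/QM9/conditional_alpha_model.ckpt \
#     classifier_model_dir=checkpoints/QM9/property_classifiers/exp_class_alpha \
#     property=alpha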
21 | classifier_model_dir: "" 22 | property: "alpha" 23 | iterations: 100 24 | batch_size: 100 25 | debug_break: false 26 | sweep_property_values: false 27 | save_molecules: true 28 | num_sweeps: 10 29 | experiment_name: ${.property}-conditioning-${.seed} 30 | output_dir: "" 31 | -------------------------------------------------------------------------------- /configs/mol_gen_eval_optimization_qm9.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: edm_qm9.yaml # choose datamodule with `val_dataloader()` and `test_dataloader()` for evaluation 6 | - model: qm9_mol_gen_ddpm.yaml 7 | - callbacks: default.yaml 8 | - logger: null 9 | - trainer: default.yaml 10 | - paths: default.yaml 11 | - extras: default.yaml 12 | - hydra: default.yaml 13 | 14 | task_name: "mol_gen_eval_optimization_qm9" 15 | 16 | tags: ["dev"] 17 | 18 | # inference (i.e., sampling) and evaluation arguments 19 | seed: 42 20 | unconditional_generator_model_filepath: ??? 21 | conditional_generator_model_filepath: ??? 22 | classifier_model_dir: "" 23 | sampling_output_dir: "" 24 | num_samples: 1000 25 | num_timesteps: 10 26 | property: "alpha" 27 | iterations: 1 28 | num_optimization_timesteps: 100 29 | return_frames: 1 # note: set `return_frames > 1` to save sample 0's optimization GIF for each iteration 30 | debug_break: false 31 | save_molecules: false 32 | experiment_name: ${.property}-optimizing-${.seed} 33 | generate_molecules_only: false 34 | use_pregenerated_molecules: false 35 | -------------------------------------------------------------------------------- /configs/mol_gen_sample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: edm_qm9.yaml 6 | - model: qm9_mol_gen_ddpm.yaml 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | task_name: "mol_gen_sample" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for sampling 18 | ckpt_path: ??? 19 | 20 | # inference (i.e., sampling) arguments 21 | seed: ??? 22 | output_dir: "" 23 | num_samples: ??? 24 | num_nodes: 19 25 | all_frags: true 26 | sanitize: false 27 | sample_chain: false 28 | relax: false 29 | num_resamplings: 1 30 | jump_length: 1 31 | num_timesteps: ??? 
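# note: the `???` entries above are mandatory Hydra values that must be supplied at the command line;
# a sampling run might therefore look like the following (hypothetical checkpoint path and values):
#
#   python src/mol_gen_sample.py ckpt_path=checkpoints/QM9/model.ckpt seed=123 output_dir=outputs/samples num_samples=250 num_timesteps=1000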
32 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py`, `eval.py`, `mol_gen_sample.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | 20 | # path to grid search script directory 21 | grid_search_script_dir: 22 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - datamodule: edm_qm9.yaml 8 | - model: qm9_mol_gen_ddpm.yaml 9 | - callbacks: default.yaml 10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default.yaml 12 | - paths: default.yaml 13 | - extras: default.yaml 14 | - hydra: default.yaml 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # debugging config (enable through command line, e.g. 
`python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | # appending lists from command line is currently not supported :( 37 | # https://github.com/facebookresearch/hydra/issues/1547 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: False 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # seed for random number generators in pytorch, numpy and python.random 51 | seed: 42 52 | 53 | # whether to log gradients, parameters, or model topology to a compatible external logger (i.e., WandB) 54 | watch_model: false -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | strategy: ddp_find_unused_parameters_false 5 | 6 | accelerator: gpu 7 | devices: 3 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 50 # prevents early stopping 6 | max_epochs: 3000 7 | 8 | strategy: ddp_find_unused_parameters_false 9 | 10 | accelerator: gpu 11 | devices: 1 12 | num_nodes: 1 13 | sync_batchnorm: True 14 | 15 | # mixed precision for extra speed-up 16 | # precision: 16 17 | 18 | # number of sanity-check validation forward passes to run prior to model training 19 | num_sanity_val_steps: 0 20 | 21 | # perform a validation loop every N training epochs 22 | check_val_every_n_epoch: ${model.diffusion_cfg.eval_epochs} 23 | 24 | # gradient accumulation to simulate larger-than-GPU-memory batch sizes 25 | accumulate_grad_batches: 1 26 | 27 | # set True to ensure deterministic results 28 | # makes training slower but gives more reproducibility than just setting seeds 29 | deterministic: False 30 | 31 | # track and log the vector norm of each gradient 32 | # track_grad_norm: 2.0 33 | 34 | # profile code comprehensively 35 | profiler: 36 | # _target_: pytorch_lightning.profilers.PyTorchProfiler -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | 
devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /img/Bio-Diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/img/Bio-Diffusion.png -------------------------------------------------------------------------------- /img/GCDM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/img/GCDM.png -------------------------------------------------------------------------------- /img/GCDM_Alpha_Conditional_Sampling.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/img/GCDM_Alpha_Conditional_Sampling.gif -------------------------------------------------------------------------------- /img/GCDM_Sampled_Molecule_Trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/img/GCDM_Sampled_Molecule_Trajectory.gif -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/notebooks/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /scripts/generate_geom_mol_gen_ddpm_grid_search_runs.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import os 6 | import itertools 7 | import json 8 | 9 | 10 | # define constants # 11 | TASK = "geom_mol_gen_ddpm" # TODO: Ensure Is Correct Before Each Grid Search! 
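# note: this script only materializes the hyperparameter search space below as a JSON file;
# the companion `scripts/nautilus/generate_geom_mol_gen_ddpm_grid_search_jobs.py` script then reads
# that JSON and renders one Kubernetes job manifest per hyperparameter combination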
12 | SCRIPT_DIR = os.path.join("scripts") 13 | SEARCH_SPACE_FILEPATH = os.path.join(SCRIPT_DIR, f"{TASK}_grid_search_runs.json") 14 | 15 | 16 | def main(): 17 | # TODO: Ensure Is Correct Before Each Grid Search! 18 | search_space_dict = { 19 | "gcp_version": [2], 20 | "key_names": ["NEL NML LR WD DO CHD NT"], 21 | "model.model_cfg.num_encoder_layers": [4], 22 | "model.layer_cfg.mp_cfg.num_message_layers": [4], 23 | "model.optimizer.lr": [1e-4], 24 | "model.optimizer.weight_decay": [1e-12], 25 | "model.model_cfg.dropout": [0.0], 26 | "model.model_cfg.chi_hidden_dim": [16], 27 | "model.diffusion_cfg.num_timesteps": [1000] 28 | } 29 | 30 | # gather all combinations of hyperparameters while retaining field names for each chosen hyperparameter 31 | keys, values = zip(*search_space_dict.items()) 32 | hyperparameter_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)] 33 | 34 | # save search space to storage as JSON file 35 | with open(SEARCH_SPACE_FILEPATH, "w") as f: 36 | f.write(json.dumps(hyperparameter_dicts)) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /scripts/generate_qm9_mol_gen_ddpm_grid_search_runs.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import os 6 | import itertools 7 | import json 8 | 9 | 10 | # define constants # 11 | TASK = "qm9_mol_gen_ddpm" # TODO: Ensure Is Correct Before Each Grid Search! 12 | SCRIPT_DIR = os.path.join("scripts") 13 | SEARCH_SPACE_FILEPATH = os.path.join(SCRIPT_DIR, f"{TASK}_grid_search_runs.json") 14 | 15 | 16 | def main(): 17 | # TODO: Ensure Is Correct Before Each Grid Search! 
18 | search_space_dict = { 19 | "gcp_version": [2], 20 | "key_names": ["NEL NML LR WD DO CHD NT C"], 21 | "model.model_cfg.num_encoder_layers": [9], 22 | "model.layer_cfg.mp_cfg.num_message_layers": [4], 23 | "model.optimizer.lr": [1e-4], 24 | "model.optimizer.weight_decay": [1e-12], 25 | "model.model_cfg.dropout": [0.0], 26 | "model.model_cfg.chi_hidden_dim": [32], 27 | "model.diffusion_cfg.num_timesteps": [1000], 28 | "model.module_cfg.conditioning": ["[]"] 29 | } 30 | 31 | # gather all combinations of hyperparameters while retaining field names for each chosen hyperparameter 32 | keys, values = zip(*search_space_dict.items()) 33 | hyperparameter_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)] 34 | 35 | # save search space to storage as JSON file 36 | with open(SEARCH_SPACE_FILEPATH, "w") as f: 37 | f.write(json.dumps(hyperparameter_dicts)) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /scripts/geom_mol_gen_ddpm_grid_search_scripts/launch_all_geom_mol_gen_ddpm_grid_search_jobs.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for file in gpu_job_*.yaml; do 4 | kubectl apply -f "$file" 5 | sleep 5 6 | done -------------------------------------------------------------------------------- /scripts/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/scripts/local/.gitkeep -------------------------------------------------------------------------------- /scripts/nautilus/data_transfer_pod_pvc_template.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | name: $USER-data-transfer-pod-pvc # REPLACE $USER with your Nautilus username 5 | spec: 6 | containers: 7 | - name: $USER-data-transfer-pod-pvc # REPLACE $USER with your Nautilus username 8 | image: ubuntu:20.04 9 | command: ["sh", "-c", "echo 'I am a new pod for data transfers to one of my PVCs' && sleep infinity"] 10 | resources: 11 | limits: 12 | memory: 12Gi 13 | cpu: 2 14 | requests: 15 | memory: 10Gi 16 | cpu: 2 17 | volumeMounts: 18 | - mountPath: /data 19 | name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 20 | volumes: 21 | - name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 22 | persistentVolumeClaim: 23 | claimName: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 24 | affinity: 25 | nodeAffinity: 26 | requiredDuringSchedulingIgnoredDuringExecution: 27 | nodeSelectorTerms: 28 | - matchExpressions: 29 | - key: topology.kubernetes.io/region 30 | operator: In 31 | values: 32 | - us-central -------------------------------------------------------------------------------- /scripts/nautilus/generate_data_transfer_pod_pvc_yaml.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import getpass 6 | import os 7 | import yaml 8 | 9 | try: 10 | from yaml 
import CLoader as Loader, CDumper as Dumper 11 | except ImportError: 12 | from yaml import Loader, Dumper 13 | 14 | from src.utils.utils import replace_dict_str_values_unconditionally 15 | 16 | 17 | # define constants # 18 | USERNAME = getpass.getuser() # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 19 | SCRIPT_DIR = os.path.join("scripts", "nautilus") 20 | DATA_TRANSFER_POD_PVC_TEMPLATE_FILEPATH = os.path.join(SCRIPT_DIR, "data_transfer_pod_pvc_template.yaml") 21 | DATA_TRANSFER_POD_PVC_OUTPUT_FILEPATH = os.path.join(SCRIPT_DIR, "data_transfer_pod_pvc.yaml") 22 | 23 | 24 | def main(): 25 | with open(DATA_TRANSFER_POD_PVC_TEMPLATE_FILEPATH, "r") as f: 26 | yaml_dict = yaml.load(f, Loader) 27 | 28 | unconditional_yaml_dict = replace_dict_str_values_unconditionally( 29 | yaml_dict, 30 | unconditional_key_value_replacements={"$USER": USERNAME} 31 | ) 32 | 33 | with open(DATA_TRANSFER_POD_PVC_OUTPUT_FILEPATH, "w") as f: 34 | yaml.dump(unconditional_yaml_dict, f, Dumper) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /scripts/nautilus/generate_geom_mol_gen_ddpm_grid_search_jobs.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import ast 6 | import getpass 7 | import json 8 | import os 9 | import wandb 10 | import yaml 11 | from datetime import datetime 12 | from typing import Any, Dict, List, Tuple 13 | 14 | try: 15 | from yaml import CLoader as Loader, CDumper as Dumper 16 | except ImportError: 17 | from yaml import Loader, Dumper 18 | 19 | from src.utils.utils import replace_dict_str_values_conditionally, replace_dict_str_values_unconditionally 20 | 21 | from torchtyping import patch_typeguard 22 | from typeguard import typechecked 23 | 24 | patch_typeguard() # use before @typechecked 25 | 26 | 27 | # define constants # 28 | HIGH_MEMORY = True # whether to use high-memory (HM) mode 29 | HALT_FILE_EXTENSION = "done" # TODO: Update `src.models.HALT_FILE_EXTENSION` As Well Upon Making Changes Here! 30 | IMAGE_TAG = "bb558b48" # TODO: Ensure Is Correct! 31 | USERNAME = getpass.getuser() # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 32 | TIMESTAMP = datetime.now().strftime("%m%d%Y_%H_%M") 33 | 34 | # choose a base experiment to run 35 | TASK = "geom_mol_gen_ddpm" # TODO: Ensure Is Correct Before Each Grid Search! 36 | EXPERIMENT = f"{TASK.lower()}_grid_search" # TODO: Ensure Is Correct Before Each Grid Search! 
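# note: `EXPERIMENT` resolves to `geom_mol_gen_ddpm_grid_search`, i.e., the Hydra experiment config at
# `configs/experiment/geom_mol_gen_ddpm_grid_search.yaml`, which each generated job passes to `src/train.py`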
37 | TEMPLATE_RUN_NAME = f"{TIMESTAMP}_{TASK.lower()}" 38 | TEMPLATE_COMMAND_STR = f"cd /data/Repositories/Lab_Repositories/bio-diffusion && git pull origin main && /data/Repositories/Lab_Repositories/bio-diffusion/bio-diffusion/bin/python src/train.py experiment={EXPERIMENT}" 39 | TEMPLATE_COMMAND_STR += " trainer.devices=auto" # TODO: Remove Once No Longer Needed 40 | NUM_RUNS_PER_EXPERIMENT = {"qm9_mol_gen_ddpm": 3, "geom_mol_gen_ddpm": 1} 41 | 42 | # establish paths 43 | OUTPUT_SCRIPT_FILENAME_PREFIX = "gpu_job" 44 | SCRIPT_DIR = os.path.join("scripts") 45 | OUTPUT_SCRIPT_DIR = os.path.join(SCRIPT_DIR, f"{TASK}_grid_search_scripts") 46 | TEMPLATE_SCRIPT_FILEPATH = os.path.join( 47 | SCRIPT_DIR, 48 | "nautilus", 49 | "hm_gpu_job_template.yaml" 50 | if HIGH_MEMORY 51 | else "gpu_job_template.yaml" 52 | ) 53 | 54 | assert TASK in NUM_RUNS_PER_EXPERIMENT.keys(), f"The task {TASK} is not currently available." 55 | 56 | 57 | @typechecked 58 | def build_command_string( 59 | run: Dict[str, Any], 60 | items_to_show: List[Tuple[str, Any]], 61 | command_str: str = TEMPLATE_COMMAND_STR, 62 | run_name: str = TEMPLATE_RUN_NAME, 63 | run_id: str = wandb.util.generate_id() 64 | ) -> str: 65 | # substitute latest grid search parameter values into command string of latest script 66 | command_str += f" tags='[bio-diffusion, geom_mol_gen_ddpm, grid_search, nautilus]' logger=wandb logger.wandb.id={run_id} logger.wandb.name='{run_name}_GCPv{run['gcp_version']}" 67 | 68 | # install a unique WandB run name 69 | for s, (key, value) in zip(run["key_names"].split(), items_to_show): 70 | if s in ["C", "NV", "NB"]: 71 | # parse individual contexts to use for conditioning 72 | contexts = ast.literal_eval(value) 73 | command_str += f"_{s.strip()}:" 74 | for contextIndex, context in enumerate(contexts): 75 | command_str += f"{context}" 76 | command_str = ( 77 | command_str 78 | if contextIndex == len(contexts) - 1 79 | else command_str + "-" 80 | ) 81 | elif s == "N": 82 | # bypass listing combined nonlinearities due to their redundancy 83 | pass 84 | else: 85 | command_str += f"_{s.strip()}:{value}" 86 | command_str += "'" # ensure the WandB name ends in a single quote to avoid Hydra list parsing 87 | 88 | # establish directory in which to store and find checkpoints and other artifacts for run 89 | run_dir = os.path.join("logs", "train", "runs", run_id) 90 | ckpt_dir = os.path.join(run_dir, "checkpoints") 91 | ckpt_path = os.path.join(ckpt_dir, "last.ckpt") 92 | command_str += f" hydra.run.dir={run_dir}" 93 | command_str += f" ckpt_path={ckpt_path}" # define name of latest checkpoint for resuming model 94 | 95 | # manually specify version of GCP module to use 96 | command_str += f" model.module_cfg.selected_GCP._target_=src.models.components.gcpnet.GCP{run['gcp_version']}" 97 | 98 | # add each custom grid search argument 99 | for key, value in items_to_show: 100 | if key in ["model.module_cfg.conditioning", "model.diffusion_cfg.norm_values", "model.diffusion_cfg.norm_biases"]: 101 | # ensure that Hydra will be able to parse list of contexts to use for conditioning 102 | command_str += f" {key}='{value}'" 103 | elif key == "model.module_cfg.nonlinearities": 104 | # ensure that Hydra will be able to parse list of nonlinearities to use for training 105 | parsed_nonlinearities = [ 106 | nonlinearity 107 | if nonlinearity is not None and len(nonlinearity) > 0 108 | else "null" for nonlinearity in value 109 | ] 110 | command_str += f" {key}='{parsed_nonlinearities}'" 111 | else: 112 | command_str += f" {key}={value}" 113 
| 114 | return command_str 115 | 116 | 117 | def main(): 118 | # load search space from storage as JSON file 119 | search_space_filepath = os.path.join(SCRIPT_DIR, f"{TASK}_grid_search_runs.json") 120 | assert os.path.exists( 121 | search_space_filepath 122 | ), "JSON file describing grid search runs must be generated beforehand using `generate_grid_search_runs.py`" 123 | with open(search_space_filepath, "r") as f: 124 | grid_search_runs = json.load(f) 125 | 126 | # curate each grid search run 127 | grid_search_runs = [run for run in grid_search_runs for _ in range(NUM_RUNS_PER_EXPERIMENT[TASK])] 128 | for run_index, run in enumerate(grid_search_runs): 129 | # distinguish items to show in arguments list 130 | items_to_show = [(key, value) for (key, value) in run.items() if key not in ["gcp_version", "key_names"]] 131 | 132 | # build list of input arguments 133 | run_id = wandb.util.generate_id() 134 | cur_script_filename = f"{OUTPUT_SCRIPT_FILENAME_PREFIX}_{run_index}.yaml" 135 | command_str = build_command_string(run, items_to_show, run_id=run_id) 136 | 137 | # write out latest script as copy of template launcher script 138 | output_script_filepath = os.path.join( 139 | OUTPUT_SCRIPT_DIR, cur_script_filename 140 | ) 141 | with open(TEMPLATE_SCRIPT_FILEPATH, "r") as f: 142 | yaml_dict = yaml.load(f, Loader) 143 | unconditional_yaml_dict = replace_dict_str_values_unconditionally( 144 | yaml_dict, 145 | unconditional_key_value_replacements={ 146 | "$JOB_INDEX": f"-{run_index}", 147 | "$IMAGE_TAG": IMAGE_TAG, 148 | "$USER": USERNAME, 149 | "$EXPERIMENT": EXPERIMENT 150 | } 151 | ) 152 | conditional_yaml_dict = replace_dict_str_values_conditionally( 153 | unconditional_yaml_dict, 154 | conditional_key_value_replacements={"command": ["bash", "-c", command_str]} 155 | ) 156 | with open(output_script_filepath, "w") as f: 157 | yaml.dump(conditional_yaml_dict, f, Dumper) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /scripts/nautilus/generate_gpu_job_yaml.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import getpass 6 | import os 7 | import yaml 8 | 9 | try: 10 | from yaml import CLoader as Loader, CDumper as Dumper 11 | except ImportError: 12 | from yaml import Loader, Dumper 13 | 14 | from src.utils.utils import replace_dict_str_values_unconditionally 15 | 16 | 17 | # define constants # 18 | JOB_INDEX = "" # TODO: Ensure Is Correct! 19 | IMAGE_TAG = "bb558b48" # TODO: Ensure Is Correct! 20 | USERNAME = getpass.getuser() # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 21 | EXPERIMENT = "qm9_mol_gen_ddpm" # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 
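# note: unlike the grid-search job generators, this script renders `gpu_job_template.yaml` exactly once
# (into `scripts/nautilus/gpu_job.yaml`) for a single ad-hoc training job using the experiment named above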
22 | SCRIPT_DIR = os.path.join("scripts", "nautilus") 23 | GPU_JOB_TEMPLATE_FILEPATH = os.path.join(SCRIPT_DIR, "gpu_job_template.yaml") 24 | GPU_JOB_OUTPUT_FILEPATH = os.path.join(SCRIPT_DIR, "gpu_job.yaml") 25 | 26 | 27 | def main(): 28 | with open(GPU_JOB_TEMPLATE_FILEPATH, "r") as f: 29 | yaml_dict = yaml.load(f, Loader) 30 | 31 | unconditional_yaml_dict = replace_dict_str_values_unconditionally( 32 | yaml_dict, 33 | unconditional_key_value_replacements={ 34 | "$JOB_INDEX": JOB_INDEX, 35 | "$IMAGE_TAG": IMAGE_TAG, 36 | "$USER": USERNAME, 37 | "$EXPERIMENT": EXPERIMENT 38 | } 39 | ) 40 | 41 | with open(GPU_JOB_OUTPUT_FILEPATH, "w") as f: 42 | yaml.dump(unconditional_yaml_dict, f, Dumper) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/nautilus/generate_hm_gpu_job_yaml.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import getpass 6 | import os 7 | import yaml 8 | 9 | try: 10 | from yaml import CLoader as Loader, CDumper as Dumper 11 | except ImportError: 12 | from yaml import Loader, Dumper 13 | 14 | from src.utils.utils import replace_dict_str_values_unconditionally 15 | 16 | 17 | # define constants # 18 | JOB_INDEX = "" # TODO: Ensure Is Correct! 19 | IMAGE_TAG = "bb558b48" # TODO: Ensure Is Correct! 20 | USERNAME = getpass.getuser() # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 21 | EXPERIMENT = "qm9_mol_gen_ddpm" # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 
22 | SCRIPT_DIR = os.path.join("scripts", "nautilus") 23 | HM_GPU_JOB_TEMPLATE_FILEPATH = os.path.join(SCRIPT_DIR, "hm_gpu_job_template.yaml") 24 | HM_GPU_JOB_OUTPUT_FILEPATH = os.path.join(SCRIPT_DIR, "hm_gpu_job.yaml") 25 | 26 | 27 | def main(): 28 | with open(HM_GPU_JOB_TEMPLATE_FILEPATH, "r") as f: 29 | yaml_dict = yaml.load(f, Loader) 30 | 31 | unconditional_yaml_dict = replace_dict_str_values_unconditionally( 32 | yaml_dict, 33 | unconditional_key_value_replacements={ 34 | "$JOB_INDEX": JOB_INDEX, 35 | "$IMAGE_TAG": IMAGE_TAG, 36 | "$USER": USERNAME, 37 | "$EXPERIMENT": EXPERIMENT 38 | } 39 | ) 40 | 41 | with open(HM_GPU_JOB_OUTPUT_FILEPATH, "w") as f: 42 | yaml.dump(unconditional_yaml_dict, f, Dumper) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/nautilus/generate_persistent_storage_yaml.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import getpass 6 | import os 7 | import yaml 8 | 9 | try: 10 | from yaml import CLoader as Loader, CDumper as Dumper 11 | except ImportError: 12 | from yaml import Loader, Dumper 13 | 14 | from src.utils.utils import replace_dict_str_values_conditionally, replace_dict_str_values_unconditionally 15 | 16 | 17 | # define constants # 18 | STORAGE_SIZE = "1000Gi" # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 19 | USERNAME = getpass.getuser() # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 
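# note: the rendered `scripts/nautilus/persistent_storage.yaml` can then be applied with
# `kubectl apply -f scripts/nautilus/persistent_storage.yaml` to provision the PVC that, e.g.,
# the data-transfer pod template above mounts at `/data`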
20 | SCRIPT_DIR = os.path.join("scripts", "nautilus") 21 | PERSISTENT_STORAGE_TEMPLATE_FILEPATH = os.path.join(SCRIPT_DIR, "persistent_storage_template.yaml") 22 | PERSISTENT_STORAGE_OUTPUT_FILEPATH = os.path.join(SCRIPT_DIR, "persistent_storage.yaml") 23 | 24 | 25 | def main(): 26 | with open(PERSISTENT_STORAGE_TEMPLATE_FILEPATH, "r") as f: 27 | yaml_dict = yaml.load(f, Loader) 28 | 29 | unconditional_yaml_dict = replace_dict_str_values_unconditionally( 30 | yaml_dict, 31 | unconditional_key_value_replacements={"$USER": USERNAME} 32 | ) 33 | conditional_yaml_dict = replace_dict_str_values_conditionally( 34 | unconditional_yaml_dict, 35 | conditional_key_value_replacements={"storage": STORAGE_SIZE} 36 | ) 37 | 38 | with open(PERSISTENT_STORAGE_OUTPUT_FILEPATH, "w") as f: 39 | yaml.dump(conditional_yaml_dict, f, Dumper) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /scripts/nautilus/generate_qm9_mol_gen_ddpm_grid_search_jobs.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import ast 6 | import getpass 7 | import json 8 | import os 9 | import wandb 10 | import yaml 11 | from datetime import datetime 12 | from typing import Any, Dict, List, Tuple 13 | 14 | try: 15 | from yaml import CLoader as Loader, CDumper as Dumper 16 | except ImportError: 17 | from yaml import Loader, Dumper 18 | 19 | from src.utils.utils import replace_dict_str_values_conditionally, replace_dict_str_values_unconditionally 20 | 21 | from torchtyping import patch_typeguard 22 | from typeguard import typechecked 23 | 24 | patch_typeguard() # use before @typechecked 25 | 26 | 27 | # define constants # 28 | HIGH_MEMORY = False # whether to use high-memory (HM) mode 29 | HALT_FILE_EXTENSION = "done" # TODO: Update `src.models.HALT_FILE_EXTENSION` As Well Upon Making Changes Here! 30 | IMAGE_TAG = "bb558b48" # TODO: Ensure Is Correct! 31 | USERNAME = getpass.getuser() # TODO: Ensure Is Correct for Nautilus Before Each Grid Search! 32 | TIMESTAMP = datetime.now().strftime("%m%d%Y_%H_%M") 33 | 34 | # choose a base experiment to run 35 | TASK = "qm9_mol_gen_ddpm" # TODO: Ensure Is Correct Before Each Grid Search! 36 | EXPERIMENT = f"{TASK.lower()}_grid_search" # TODO: Ensure Is Correct Before Each Grid Search! 
37 | TEMPLATE_RUN_NAME = f"{TIMESTAMP}_{TASK.lower()}" 38 | TEMPLATE_COMMAND_STR = f"cd /data/Repositories/Lab_Repositories/bio-diffusion && git pull origin main && /data/Repositories/Lab_Repositories/bio-diffusion/bio-diffusion/bin/python src/train.py experiment={EXPERIMENT}" 39 | TEMPLATE_COMMAND_STR += " trainer.devices=auto" # TODO: Remove Once No Longer Needed 40 | NUM_RUNS_PER_EXPERIMENT = {"qm9_mol_gen_ddpm": 3, "geom_mol_gen_ddpm": 1} 41 | 42 | # establish paths 43 | OUTPUT_SCRIPT_FILENAME_PREFIX = "gpu_job" 44 | SCRIPT_DIR = os.path.join("scripts") 45 | OUTPUT_SCRIPT_DIR = os.path.join(SCRIPT_DIR, f"{TASK}_grid_search_scripts") 46 | TEMPLATE_SCRIPT_FILEPATH = os.path.join( 47 | SCRIPT_DIR, 48 | "nautilus", 49 | "hm_gpu_job_template.yaml" 50 | if HIGH_MEMORY 51 | else "gpu_job_template.yaml" 52 | ) 53 | 54 | assert TASK in NUM_RUNS_PER_EXPERIMENT.keys(), f"The task {TASK} is not currently available." 55 | 56 | 57 | @typechecked 58 | def build_command_string( 59 | run: Dict[str, Any], 60 | items_to_show: List[Tuple[str, Any]], 61 | command_str: str = TEMPLATE_COMMAND_STR, 62 | run_name: str = TEMPLATE_RUN_NAME, 63 | run_id: str = wandb.util.generate_id() 64 | ) -> str: 65 | # substitute latest grid search parameter values into command string of latest script 66 | command_str += f" tags='[bio-diffusion, qm9_mol_gen_ddpm, grid_search, nautilus]' logger=wandb logger.wandb.id={run_id} logger.wandb.name='{run_name}_GCPv{run['gcp_version']}" 67 | 68 | # install a unique WandB run name 69 | for s, (key, value) in zip(run["key_names"].split(), items_to_show): 70 | if s in ["C", "NV", "NB"]: 71 | # parse individual contexts to use for conditioning 72 | contexts = ast.literal_eval(value) 73 | command_str += f"_{s.strip()}:" 74 | for contextIndex, context in enumerate(contexts): 75 | command_str += f"{context}" 76 | command_str = ( 77 | command_str 78 | if contextIndex == len(contexts) - 1 79 | else command_str + "-" 80 | ) 81 | elif s == "N": 82 | # bypass listing combined nonlinearities due to their redundancy 83 | pass 84 | else: 85 | command_str += f"_{s.strip()}:{value}" 86 | command_str += "'" # ensure the WandB name ends in a single quote to avoid Hydra list parsing 87 | 88 | # establish directory in which to store and find checkpoints and other artifacts for run 89 | run_dir = os.path.join("logs", "train", "runs", run_id) 90 | ckpt_dir = os.path.join(run_dir, "checkpoints") 91 | ckpt_path = os.path.join(ckpt_dir, "last.ckpt") 92 | command_str += f" hydra.run.dir={run_dir}" 93 | command_str += f" ckpt_path={ckpt_path}" # define name of latest checkpoint for resuming model 94 | 95 | # manually specify version of GCP module to use 96 | command_str += f" model.module_cfg.selected_GCP._target_=src.models.components.gcpnet.GCP{run['gcp_version']}" 97 | 98 | # add each custom grid search argument 99 | for key, value in items_to_show: 100 | if key in ["model.module_cfg.conditioning", "model.diffusion_cfg.norm_values", "model.diffusion_cfg.norm_biases"]: 101 | # ensure that Hydra will be able to parse list of contexts to use for conditioning 102 | command_str += f" {key}='{value}'" 103 | elif key == "model.module_cfg.nonlinearities": 104 | # ensure that Hydra will be able to parse list of nonlinearities to use for training 105 | parsed_nonlinearities = [ 106 | nonlinearity 107 | if nonlinearity is not None and len(nonlinearity) > 0 108 | else "null" for nonlinearity in value 109 | ] 110 | command_str += f" {key}='{parsed_nonlinearities}'" 111 | else: 112 | command_str += f" {key}={value}" 113 | 
114 | return command_str 115 | 116 | 117 | def main(): 118 | # load search space from storage as JSON file 119 | search_space_filepath = os.path.join(SCRIPT_DIR, f"{TASK}_grid_search_runs.json") 120 | assert os.path.exists( 121 | search_space_filepath 122 | ), "JSON file describing grid search runs must be generated beforehand using `generate_grid_search_runs.py`" 123 | with open(search_space_filepath, "r") as f: 124 | grid_search_runs = json.load(f) 125 | 126 | # curate each grid search run 127 | grid_search_runs = [run for run in grid_search_runs for _ in range(NUM_RUNS_PER_EXPERIMENT[TASK])] 128 | for run_index, run in enumerate(grid_search_runs): 129 | # distinguish items to show in arguments list 130 | items_to_show = [(key, value) for (key, value) in run.items() if key not in ["gcp_version", "key_names"]] 131 | 132 | # build list of input arguments 133 | run_id = wandb.util.generate_id() 134 | cur_script_filename = f"{OUTPUT_SCRIPT_FILENAME_PREFIX}_{run_index}.yaml" 135 | command_str = build_command_string(run, items_to_show, run_id=run_id) 136 | 137 | # write out latest script as copy of template launcher script 138 | output_script_filepath = os.path.join( 139 | OUTPUT_SCRIPT_DIR, cur_script_filename 140 | ) 141 | with open(TEMPLATE_SCRIPT_FILEPATH, "r") as f: 142 | yaml_dict = yaml.load(f, Loader) 143 | unconditional_yaml_dict = replace_dict_str_values_unconditionally( 144 | yaml_dict, 145 | unconditional_key_value_replacements={ 146 | "$JOB_INDEX": f"-{run_index}", 147 | "$IMAGE_TAG": IMAGE_TAG, 148 | "$USER": USERNAME, 149 | "$EXPERIMENT": EXPERIMENT 150 | } 151 | ) 152 | conditional_yaml_dict = replace_dict_str_values_conditionally( 153 | unconditional_yaml_dict, 154 | conditional_key_value_replacements={"command": ["bash", "-c", command_str]} 155 | ) 156 | with open(output_script_filepath, "w") as f: 157 | yaml.dump(conditional_yaml_dict, f, Dumper) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /scripts/nautilus/gpu_job_template.yaml: -------------------------------------------------------------------------------- 1 | # batch/v1 tells it to use the JOB API 2 | apiVersion: batch/v1 3 | # we are running a Job, not a Pod 4 | kind: Job 5 | 6 | # set the name of the job 7 | metadata: 8 | name: train-test-bio-diffusion$JOB_INDEX 9 | 10 | spec: 11 | # how many times should the system 12 | # retry before calling it a failure 13 | backoffLimit: 0 14 | template: 15 | spec: 16 | # should we restart on failure 17 | restartPolicy: Never 18 | # what containers will we need 19 | containers: 20 | # the name of the container 21 | - name: bio-diffusion 22 | # the image: can be from any public facing registry such as your GitLab repository's container registry 23 | image: gitlab-registry.nrp-nautilus.io/bioinfomachinelearning/bio-diffusion:$IMAGE_TAG # replace `IMAGE_TAG` with tag for container of interest 24 | # the working dir when the container starts 25 | workingDir: /data/Repositories/Lab_Repositories/bio-diffusion 26 | # whether Kube should pull it 27 | imagePullPolicy: IfNotPresent 28 | # we need to expose the port 29 | # that will be used for DDP 30 | ports: 31 | - containerPort: 8880 32 | # setting of env variables 33 | env: 34 | # which interface to use 35 | - name: NCCL_SOCKET_IFNAME 36 | value: eth0 37 | # note: prints some INFO level 38 | # NCCL logs 39 | - name: NCCL_DEBUG 40 | value: INFO 41 | # the command to run when the container starts 42 | command: 43 | [ 44 | "bash", 45 | "-c", 
46 | "cd /data/Repositories/Lab_Repositories/bio-diffusion && git pull origin main && /data/Repositories/Lab_Repositories/bio-diffusion/bio-diffusion/bin/python src/train.py logger=wandb experiment=$EXPERIMENT", 47 | ] 48 | # define the resources for this container 49 | resources: 50 | # limits - the max given to the container 51 | limits: 52 | # RAM 53 | memory: 14Gi 54 | # cores 55 | cpu: 2 56 | # NVIDIA GPUs 57 | nvidia.com/gpu: 1 58 | # requests - what we'd like 59 | requests: 60 | # RAM 61 | memory: 12Gi 62 | # CPU Cores 63 | cpu: 2 64 | # GPUs 65 | nvidia.com/gpu: 1 66 | # what volumes we should mount 67 | volumeMounts: 68 | # note: my datasets PVC should mount to /data 69 | - mountPath: /data 70 | name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 71 | # IMPORTANT: we need SHM for DDP 72 | - mountPath: /dev/shm 73 | name: dshm 74 | # tell Kube where to find credentials with which to pull GitLab Docker containers 75 | imagePullSecrets: 76 | - name: regcred-bio-diffusion 77 | # tell Kube where to find the volumes we want to use 78 | volumes: 79 | # which PVC is my data 80 | - name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 81 | persistentVolumeClaim: 82 | claimName: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 83 | # setup shared memory as a RAM volume 84 | - name: dshm 85 | emptyDir: 86 | medium: Memory 87 | # tell Kube what type of GPUs we want 88 | affinity: 89 | nodeAffinity: 90 | requiredDuringSchedulingIgnoredDuringExecution: 91 | nodeSelectorTerms: 92 | - matchExpressions: 93 | - key: nvidia.com/gpu.product 94 | operator: In 95 | values: 96 | # note: here, we are asking for 24GB GPUs only 97 | - NVIDIA-A10 98 | - NVIDIA-GeForce-RTX-3090 99 | - NVIDIA-TITAN-RTX 100 | - NVIDIA-RTX-A5000 101 | - Quadro-RTX-6000 102 | -------------------------------------------------------------------------------- /scripts/nautilus/hm_gpu_job_template.yaml: -------------------------------------------------------------------------------- 1 | # batch/v1 tells it to use the JOB API 2 | apiVersion: batch/v1 3 | # we are running a Job, not a Pod 4 | kind: Job 5 | 6 | # set the name of the job 7 | metadata: 8 | name: train-test-bio-diffusion-hm$JOB_INDEX 9 | 10 | spec: 11 | # how many times should the system 12 | # retry before calling it a failure 13 | backoffLimit: 0 14 | template: 15 | spec: 16 | # should we restart on failure 17 | restartPolicy: Never 18 | # what containers will we need 19 | containers: 20 | # the name of the container 21 | - name: bio-diffusion 22 | # the image: can be from any public facing registry such as your GitLab repository's container registry 23 | image: gitlab-registry.nrp-nautilus.io/bioinfomachinelearning/bio-diffusion:$IMAGE_TAG # replace `IMAGE_TAG` with tag for container of interest 24 | # the working dir when the container starts 25 | workingDir: /data/Repositories/Lab_Repositories/bio-diffusion 26 | # whether Kube should pull it 27 | imagePullPolicy: IfNotPresent 28 | # we need to expose the port 29 | # that will be used for DDP 30 | ports: 31 | - containerPort: 8880 32 | # setting of env variables 33 | env: 34 | # which interface to use 35 | - name: NCCL_SOCKET_IFNAME 36 | value: eth0 37 | # note: prints some INFO level 38 | # NCCL logs 39 | - name: NCCL_DEBUG 40 | value: INFO 41 | # the command to run when the container starts 42 | command: 43 | [ 44 | "bash", 45 | "-c", 46 | "cd /data/Repositories/Lab_Repositories/bio-diffusion && git pull origin main && 
/data/Repositories/Lab_Repositories/bio-diffusion/bio-diffusion/bin/python src/train.py logger=wandb experiment=$EXPERIMENT", 47 | ] 48 | # define the resources for this container 49 | resources: 50 | # limits - the max given to the container 51 | limits: 52 | # RAM 53 | memory: 20Gi 54 | # cores 55 | cpu: 2 56 | # NVIDIA GPUs 57 | nvidia.com/gpu: 1 58 | # requests - what we'd like 59 | requests: 60 | # RAM 61 | memory: 18Gi 62 | # CPU Cores 63 | cpu: 2 64 | # GPUs 65 | nvidia.com/gpu: 1 66 | # what volumes we should mount 67 | volumeMounts: 68 | # note: my datasets PVC should mount to /data 69 | - mountPath: /data 70 | name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 71 | # IMPORTANT: we need SHM for DDP 72 | - mountPath: /dev/shm 73 | name: dshm 74 | # tell Kube where to find credentials with which to pull GitLab Docker containers 75 | imagePullSecrets: 76 | - name: regcred-bio-diffusion 77 | # tell Kube where to find the volumes we want to use 78 | volumes: 79 | # which PVC is my data 80 | - name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 81 | persistentVolumeClaim: 82 | claimName: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 83 | # setup shared memory as a RAM volume 84 | - name: dshm 85 | emptyDir: 86 | medium: Memory 87 | # tell Kube what type of GPUs we want 88 | affinity: 89 | nodeAffinity: 90 | requiredDuringSchedulingIgnoredDuringExecution: 91 | nodeSelectorTerms: 92 | - matchExpressions: 93 | - key: nvidia.com/gpu.product 94 | operator: In 95 | values: 96 | # note: here, we are asking for 48GB GPUs only 97 | - NVIDIA-A40 98 | - NVIDIA-RTX-A6000 99 | - Quadro-RTX-8000 100 | -------------------------------------------------------------------------------- /scripts/nautilus/persistent_storage_template.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolumeClaim 3 | metadata: 4 | name: $USER-bio-diffusion-pvc # REPLACE $USER with your Nautilus username 5 | spec: 6 | storageClassName: rook-cephfs-central 7 | accessModes: 8 | - ReadWriteMany 9 | resources: 10 | requests: 11 | storage: 1000Gi -------------------------------------------------------------------------------- /scripts/qm9_mol_gen_ddpm_grid_search_scripts/launch_all_qm9_mol_gen_ddpm_grid_search_jobs.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for file in gpu_job_*.yaml; do 4 | kubectl apply -f "$file" 5 | sleep 5 6 | done -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="Bio-Diffusion", 7 | version="0.0.1", 8 | description="A hub for deep diffusion networks designed to generate novel biological data", 9 | author="Alex Morehead", 10 | author_email="acmwhb@umsystem.edu", 11 | url="https://github.com/BioinfoMachineLearning/bio-diffusion", 12 | install_requires=["pytorch-lightning", "hydra-core"], 13 | packages=find_packages(), 14 | ) 15 | -------------------------------------------------------------------------------- /src/analysis/bust_analysis.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code 
curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import hydra 6 | import pyrootutils 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import seaborn as sns 11 | import matplotlib.pyplot as plt 12 | 13 | from omegaconf import DictConfig 14 | 15 | root = pyrootutils.setup_root( 16 | search_from=__file__, 17 | indicator=[".git", "pyproject.toml"], 18 | pythonpath=True, 19 | dotenv=True, 20 | ) 21 | 22 | from src.analysis.inference_analysis import calculate_mean_and_conf_int 23 | from src.utils.pylogger import get_pylogger 24 | log = get_pylogger(__name__) 25 | 26 | 27 | @hydra.main( 28 | version_base="1.3", 29 | config_path="../../configs/analysis", 30 | config_name="bust_analysis.yaml", 31 | ) 32 | def main(cfg: DictConfig): 33 | """Compare the bust results of generated molecules from two separate generative model checkpoints. 34 | 35 | :param cfg: Configuration dictionary from the hydra YAML file. 36 | """ 37 | method_1_bust_results = pd.read_csv(cfg.method_1_bust_results_filepath) 38 | method_2_bust_results = pd.read_csv(cfg.method_2_bust_results_filepath) 39 | 40 | assert cfg.bust_column_name in method_1_bust_results.columns, f"{cfg.bust_column_name} not found in {cfg.method_1_bust_results_filepath}" 41 | assert cfg.bust_column_name in method_2_bust_results.columns, f"{cfg.bust_column_name} not found in {cfg.method_2_bust_results_filepath}" 42 | 43 | # Add a source column to distinguish between datasets 44 | method_1_bust_results["source"] = cfg.method_1 45 | method_2_bust_results["source"] = cfg.method_2 46 | 47 | # Select only the requested column as well as the source column 48 | method_1_data = method_1_bust_results[[cfg.bust_column_name, "source"]] 49 | method_2_data = method_2_bust_results[[cfg.bust_column_name, "source"]] 50 | 51 | if cfg.verbose: 52 | method_1_column_mean, method_1_column_conf_int = calculate_mean_and_conf_int(method_1_data[cfg.bust_column_name][~np.isnan(method_1_data[cfg.bust_column_name])]) 53 | method_2_column_mean, method_2_column_conf_int = calculate_mean_and_conf_int(method_2_data[cfg.bust_column_name][~np.isnan(method_2_data[cfg.bust_column_name])]) 54 | log.info(f"Mean of {cfg.bust_column_name} for {cfg.method_1}: {method_1_column_mean} ± {(method_1_column_conf_int[1] - method_1_column_mean)}") 55 | log.info(f"Mean of {cfg.bust_column_name} for {cfg.method_2}: {method_2_column_mean} ± {(method_2_column_conf_int[1] - method_2_column_mean)}") 56 | 57 | # Combine the data 58 | combined_data = pd.concat([method_1_data, method_2_data], ignore_index=True) 59 | 60 | # Plotting 61 | ax = sns.boxplot(x="source", y=cfg.bust_column_name, data=combined_data) 62 | ax.set_ylim(0, 10) 63 | plt.xlabel("Method") 64 | plt.ylabel(f"{cfg.bust_column_name.title()}") 65 | plt.savefig(cfg.bust_analysis_plot_filepath, dpi=300) 66 | 67 | if cfg.verbose: 68 | log.info("Bust analysis completed") 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /src/analysis/molecule_analysis.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # 
------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import copy 6 | import glob 7 | import hydra 8 | import os 9 | import pyrootutils 10 | import subprocess 11 | 12 | import pandas as pd 13 | 14 | from omegaconf import DictConfig, open_dict 15 | from pathlib import Path 16 | from posebusters import PoseBusters 17 | from tqdm import tqdm 18 | from typing import Optional 19 | 20 | root = pyrootutils.setup_root( 21 | search_from=__file__, 22 | indicator=[".git", "pyproject.toml"], 23 | pythonpath=True, 24 | dotenv=True, 25 | ) 26 | 27 | from src.utils.pylogger import get_pylogger 28 | log = get_pylogger(__name__) 29 | 30 | 31 | def convert_xyz_to_sdf(input_xyz_filepath: str) -> Optional[str]: 32 | """Convert an XYZ file to an SDF file using OpenBabel. 33 | 34 | :param input_xyz_filepath: Input XYZ file path. 35 | :return: Output SDF file path. 36 | """ 37 | output_sdf_filepath = input_xyz_filepath.replace(".xyz", ".sdf") 38 | if not os.path.exists(output_sdf_filepath): 39 | subprocess.run( 40 | [ 41 | "obabel", 42 | input_xyz_filepath, 43 | "-O", 44 | output_sdf_filepath, 45 | ], 46 | check=True, 47 | ) 48 | return output_sdf_filepath if os.path.exists(output_sdf_filepath) else None 49 | 50 | 51 | def create_molecule_table(input_molecule_dir: str) -> pd.DataFrame: 52 | """Create a molecule table from the inference results of a trained model checkpoint. 53 | 54 | :param input_molecule_dir: Directory containing the generated molecules of a trained model checkpoint. 55 | :return: Molecule table as a Pandas DataFrame. 56 | """ 57 | inference_xyz_results = [str(item) for item in Path(input_molecule_dir).rglob("*.xyz")] 58 | inference_sdf_results = [str(item) for item in Path(input_molecule_dir).rglob("*.sdf")] 59 | if not inference_sdf_results or len(inference_sdf_results) != len(inference_xyz_results): 60 | inference_sdf_results = [ 61 | convert_xyz_to_sdf(item) for item in tqdm( 62 | inference_xyz_results, desc="Converting XYZ input files to SDF files" 63 | ) 64 | ] 65 | mol_table = pd.DataFrame( 66 | { 67 | "mol_pred": [item for item in inference_sdf_results if item is not None], 68 | "mol_true": None, 69 | "mol_cond": None, 70 | } 71 | ) 72 | return mol_table 73 | 74 | 75 | def run_unconditional_molecule_analysis(cfg: DictConfig): 76 | """ 77 | Run molecule analysis for an unconditional method. 78 | 79 | :param cfg: Configuration dictionary from the hydra YAML file. 80 | """ 81 | input_molecule_subdirs = os.listdir(cfg.input_molecule_dir) 82 | assert cfg.sampling_index is None or (cfg.sampling_index is not None and cfg.sampling_index < len(input_molecule_subdirs)), "The given sampling index is out of range." 
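# Assumed layout (illustrative, hypothetical names): `cfg.input_molecule_dir` contains one subdirectory per
# sampling round, each holding the generated molecules as XYZ (and/or SDF) files, e.g.:
#   <input_molecule_dir>/sampling_0/molecule_0.xyz, molecule_1.xyz, ...
#   <input_molecule_dir>/sampling_1/...
# Each subdirectory is scored separately below, with its PoseBusters table written to
# `<bust_results_filepath minus .csv>_<sampling_index>.csv`.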
83 | 84 | sampling_index = 0 85 | for item in input_molecule_subdirs: 86 | sampling_dir = os.path.join(cfg.input_molecule_dir, item) 87 | if os.path.isdir(sampling_dir): 88 | if cfg.sampling_index is not None and sampling_index != cfg.sampling_index: 89 | sampling_index += 1 90 | continue 91 | log.info(f"Processing sampling {sampling_index} corresponding to {sampling_dir}...") 92 | bust_results_filepath = cfg.bust_results_filepath.replace(".csv", f"_{sampling_index}.csv") 93 | mol_table = create_molecule_table(sampling_dir) 94 | buster = PoseBusters(config="mol", top_n=None) 95 | bust_results = buster.bust_table(mol_table, full_report=cfg.full_report) 96 | bust_results.to_csv(bust_results_filepath, index=False) 97 | log.info(f"PoseBusters results for sampling {sampling_index} saved to {bust_results_filepath}.") 98 | sampling_index += 1 99 | 100 | 101 | def run_conditional_molecule_analysis(cfg: DictConfig): 102 | """ 103 | Run molecule analysis for a property-conditional method. 104 | 105 | :param cfg: Configuration dictionary from the hydra YAML file. 106 | """ 107 | log.info(f"Processing sampling directory {cfg.input_molecule_dir}...") 108 | for subdir in glob.glob(os.path.join(cfg.input_molecule_dir, "*")): 109 | if not os.path.isdir(subdir): 110 | continue 111 | log.info(f"Processing sampling directory {subdir}...") 112 | seed = subdir.split("_")[-1] 113 | mol_table = create_molecule_table(subdir) 114 | buster = PoseBusters(config="mol", top_n=None) 115 | bust_results = buster.bust_table(mol_table, full_report=cfg.full_report) 116 | bust_results.to_csv(cfg.bust_results_filepath.replace(".csv", f"_seed_{seed}.csv"), index=False) 117 | log.info(f"PoseBusters results for sampling directory {cfg.input_molecule_dir} saved to {cfg.bust_results_filepath}.") 118 | 119 | 120 | @hydra.main( 121 | version_base="1.3", 122 | config_path="../../configs/analysis", 123 | config_name="molecule_analysis.yaml", 124 | ) 125 | def main(cfg: DictConfig): 126 | """Analyze the generated molecules from a trained model checkpoint. 127 | 128 | :param cfg: Configuration dictionary from the hydra YAML file. 129 | """ 130 | os.makedirs(Path(cfg.bust_results_filepath).parent, exist_ok=True) 131 | 132 | if cfg.model_type == "Unconditional": 133 | run_unconditional_molecule_analysis(cfg) 134 | elif cfg.model_type == "Conditional": 135 | with open_dict(cfg): 136 | input_molecule_dir = copy.deepcopy(cfg.input_molecule_dir) 137 | bust_results_filepath = copy.deepcopy(cfg.bust_results_filepath) 138 | cfg.property = cfg.property.replace("_", "") 139 | cfg.input_molecule_dir = input_molecule_dir 140 | cfg.bust_results_filepath = bust_results_filepath 141 | assert cfg.property in ["alpha", "gap", "homo", "lumo", "mu", "Cv"], "The given property is not supported." 
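# note: `run_conditional_molecule_analysis()` above recovers each sampling's seed from the trailing
# component of its subdirectory name (`subdir.split("_")[-1]`), so a directory named e.g.
# `property_alpha_seed_42` (hypothetical) yields a `..._seed_42.csv` bust results file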
142 | run_conditional_molecule_analysis(cfg) 143 | else: 144 | raise ValueError(f"Unsupported model type: {cfg.model_type}") 145 | 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /src/analysis/optimization_analysis.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from typing import Any 9 | 10 | 11 | def parse_ms_value(ms: Any) -> tuple[float, float]: 12 | """Parse the MS value and its error, return as a tuple.""" 13 | if isinstance(ms, float) or ms == "N/A": 14 | return (ms, 0) # No error for single float values or N/A 15 | ms_parts = ms.split("±") 16 | value = float(ms_parts[0].strip()) 17 | error = float(ms_parts[1].strip()) if len(ms_parts) > 1 else 0 18 | return (value, error) 19 | 20 | 21 | def format_ms_annotation(value: Any, error: float) -> str: 22 | """Format the MS annotation based on value and error.""" 23 | if value == "N/A": 24 | return "N/A" 25 | lower = value - error 26 | upper = value + error 27 | return r"$MS \in " + f"[{lower:.1f}\%, {upper:.1f}\%]" + "$" 28 | 29 | 30 | def main(): 31 | # Define plot data 32 | data = { 33 | "Initial Samples (Moderately Stable)": { 34 | r"$\alpha\ (Bohr^{3})$": {"value": "4.61 ± 0.2", "MS": 61.7}, 35 | r"$\Delta_{\epsilon}\ (meV)$": {"value": "1.26 ± 0.1", "MS": 61.7}, 36 | r"$\epsilon_{HOMO}\ (meV)$": {"value": "0.53 ± 0.0", "MS": 61.7}, 37 | r"$\epsilon_{LUMO}\ (meV)$": {"value": "1.25 ± 0.0", "MS": 61.7}, 38 | r"$\mu\ (D)$": {"value": "1.35 ± 0.1", "MS": 61.7}, 39 | r"$C_{v}\ (\frac{cal}{mol} K)$": {"value": "2.93 ± 0.1", "MS": 61.7}, 40 | }, 41 | "EDM-Opt (100 steps)": { 42 | r"$\alpha\ (Bohr^{3})$": {"value": "4.45 ± 0.6", "MS": "77.6 ± 2.1"}, 43 | r"$\Delta_{\epsilon}\ (meV)$": {"value": "0.98 ± 0.1", "MS": "80.0 ± 2.0"}, 44 | r"$\epsilon_{HOMO}\ (meV)$": {"value": "0.45 ± 0.0", "MS": "78.8 ± 1.0"}, 45 | r"$\epsilon_{LUMO}\ (meV)$": {"value": "0.91 ± 0.0", "MS": "83.4 ± 4.6"}, 46 | r"$\mu\ (D)$": {"value": "6e5 ± 6e5", "MS": "78.3 ± 2.9"}, 47 | r"$C_{v}\ (\frac{cal}{mol} K)$": {"value": "2.72 ± 2.6", "MS": "51.0 ± 109.7"}, 48 | }, 49 | "EDM-Opt (250 steps)": { 50 | r"$\alpha\ (Bohr^{3})$": {"value": "1e2 ± 5e2", "MS": "80.1 ± 2.1"}, 51 | r"$\Delta_{\epsilon}\ (meV)$": {"value": "1e3 ± 6e3", "MS": "83.7 ± 3.8"}, 52 | r"$\epsilon_{HOMO}\ (meV)$": {"value": "0.44 ± 0.0", "MS": "82.5 ± 1.3"}, 53 | r"$\epsilon_{LUMO}\ (meV)$": {"value": "0.91 ± 0.1", "MS": "84.7 ± 1.6"}, 54 | r"$\mu\ (D)$": {"value": "2e5 ± 8e5", "MS": "81.0 ± 5.8"}, 55 | r"$C_{v}\ (\frac{cal}{mol} K)$": {"value": "2.15 ± 0.1", "MS": "78.5 ± 3.4"}, 56 | }, 57 | "GCDM-Opt (100 steps)": { 58 | r"$\alpha\ (Bohr^{3})$": {"value": "3.29 ± 0.1", "MS": "86.2 ± 1.3"}, 59 | r"$\Delta_{\epsilon}\ (meV)$": {"value": "0.93 ± 0.0", "MS": "89.0 ± 1.9"}, 60 | r"$\epsilon_{HOMO}\ (meV)$": {"value": "0.43 ± 0.0", "MS": "91.6 ± 3.5"}, 61 | r"$\epsilon_{LUMO}\ (meV)$": {"value": "0.86 ± 0.0", "MS": "87.0 ± 1.7"}, 62 | r"$\mu\ (D)$": {"value": "1.08 ± 0.1", "MS": "89.9 ± 4.2"}, 63 | r"$C_{v}\ (\frac{cal}{mol} K)$": 
{"value": "1.81 ± 0.0", "MS": "87.6 ± 1.1"}, 64 | }, 65 | "GCDM-Opt (250 steps)": { 66 | r"$\alpha\ (Bohr^{3})$": {"value": "3.24 ± 0.2", "MS": "86.6 ± 1.9"}, 67 | r"$\Delta_{\epsilon}\ (meV)$": {"value": "0.93 ± 0.0", "MS": "89.7 ± 2.2"}, 68 | r"$\epsilon_{HOMO}\ (meV)$": {"value": "0.43 ± 0.0", "MS": "90.7 ± 0.0"}, 69 | r"$\epsilon_{LUMO}\ (meV)$": {"value": "0.85 ± 0.0", "MS": "88.6 ± 3.8"}, 70 | r"$\mu\ (D)$": {"value": "1.04 ± 0.0", "MS": "89.5 ± 2.6"}, 71 | r"$C_{v}\ (\frac{cal}{mol} K)$": {"value": "1.82 ± 0.1", "MS": "87.6 ± 2.3"}, 72 | }, 73 | } 74 | 75 | # Prepare data for plotting 76 | data_groups = {} 77 | for k, v in data.items(): 78 | values, errors, ms_values = {}, {}, {} 79 | for prop in v: 80 | raw_value = float(v[prop]["value"].split("±")[0].strip()) 81 | ms_value, ms_error = parse_ms_value(v[prop]["MS"]) 82 | if raw_value > 50: 83 | values[prop], errors[prop], ms_values[prop] = np.nan, 0, ("N/A", "N/A") 84 | else: 85 | values[prop] = raw_value 86 | errors[prop] = float(v[prop]["value"].split("±")[1].strip()) 87 | ms_values[prop] = (ms_value, ms_error) 88 | data_groups[k] = {"values": values, "errors": errors, "MS": ms_values} 89 | 90 | x_labels = list(next(iter(data_groups.values()))["values"].keys()) 91 | 92 | fig, ax = plt.subplots(figsize=(10, 8)) 93 | 94 | # Adjustments for improved readability 95 | width = 0.15 96 | group_gap = 0.5 # Increased for clearer separation 97 | n_groups = len(data_groups) 98 | total_width = n_groups * width + (n_groups - 1) * group_gap 99 | positions = np.arange(len(x_labels)) * (total_width + group_gap) # Adjusted calculation 100 | 101 | for i, (group, group_data) in enumerate(data_groups.items()): 102 | values = list(group_data["values"].values()) 103 | errors = list(group_data["errors"].values()) 104 | ms_values = list(group_data["MS"].values()) 105 | bar_positions = [pos + i * (width + group_gap) for pos in positions] 106 | bars = ax.barh(bar_positions, values, width, label=group, xerr=errors, capsize=2, alpha=0.8, edgecolor="black") 107 | 108 | for j, value in enumerate(values): 109 | if np.isnan(value): 110 | # Correctly place an 'x' symbol for missing values 111 | # Ensure the 'x' marker is at the correct y-axis position corresponding to the missing value's location 112 | ax.text(0, bar_positions[j], "x", color="red", va="center", ha="center", fontsize=12, weight="bold") 113 | 114 | for bar, (ms, error) in zip(bars, ms_values): 115 | if not isinstance(ms, str) or ms != "N/A": 116 | ms_annotation = format_ms_annotation(ms, error) 117 | ax.annotate(ms_annotation, (bar.get_width(), bar.get_y() + bar.get_height() / 2 + 0.35), 118 | textcoords="offset points", xytext=(5, 0), ha="left", 119 | fontsize=8, color="darkblue", weight="black") # Adjusted for readability 120 | 121 | ax.set_ylabel("Task") 122 | ax.set_xlabel("Property MAE / Molecule Stability (MS) %") 123 | ax.set_yticks([pos + total_width / 2 - width / 2 for pos in positions]) 124 | ax.set_yticklabels(x_labels, rotation=45, va="center") 125 | ax.grid(True, which='both', axis='x', linestyle='-.', linewidth=0.5) 126 | 127 | for pos in positions[1:]: 128 | ax.axhline(y=pos - group_gap / 2, color="black", linewidth=2) # Make separation lines clearer 129 | 130 | ax.legend(loc="best") 131 | ax.invert_yaxis() 132 | 133 | plt.tight_layout() 134 | plt.savefig("qm9_property_optimization_results.png", dpi=300) 135 | plt.show() 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- 
/src/analysis/qm_analysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def main(xyz_filepath: str, dataset: str, memory: str, num_threads: int, verbose: bool = True): 5 | if dataset == "qm9": 6 | import psi4 7 | 8 | # Set memory and number of threads 9 | psi4.set_memory(memory) 10 | psi4.set_num_threads(num_threads) 11 | 12 | # Set computation options 13 | psi4.set_options({ 14 | "basis": "6-31G(2df,p)", 15 | "scf_type": "pk", 16 | "e_convergence": 1e-8, 17 | "d_convergence": 1e-8, 18 | }) 19 | 20 | # Create Psi4 geometry from the XYZ contents 21 | with open(xyz_filepath, "r") as file: 22 | xyz_contents = file.read() 23 | molecule = psi4.geometry(xyz_contents) 24 | 25 | # Calculate polarizability 26 | energy = psi4.properties("B3LYP", properties=["dipole_polarizabilities"], molecule=molecule) 27 | 28 | # Print the final energy value 29 | if verbose: 30 | print(f"Final energy of molecule: {energy} (a.u.)") 31 | 32 | elif dataset == "drugs": 33 | import subprocess 34 | 35 | subprocess.run(["crest", xyz_filepath, "--single-point", "GFN2-xTB", "-T", str(num_threads), "-quick"], check=True) 36 | else: 37 | raise ValueError(f"Dataset '{dataset}' not recognized.") 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser(description="Perform quantum mechanical analysis on a molecule.") 42 | parser.add_argument("xyz_filepath", type=str, help="Path to the XYZ file containing the molecule.") 43 | parser.add_argument("--dataset", type=str, default="qm9", choices=["qm9", "drugs"], help="Name of the dataset for which to run QM calculations.") 44 | parser.add_argument("--memory", type=str, default="32 GB", help="Amount of memory to use.") 45 | parser.add_argument("--num_threads", type=int, default=4, help="Number of threads to use.") 46 | args = parser.parse_args() 47 | main(args.xyz_filepath, args.dataset, args.memory, args.num_threads) 48 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/src/datamodules/__init__.py -------------------------------------------------------------------------------- /src/datamodules/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/src/datamodules/components/__init__.py -------------------------------------------------------------------------------- /src/datamodules/components/edm/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://raw.githubusercontent.com/arneschneuing/DiffSBDD/ 3 | """ 4 | 5 | import torch 6 | 7 | import numpy as np 8 | 9 | import src.datamodules.components.edm.constants as edm_constants 10 | 11 | from typing import Any, Dict, List, Tuple, Union 12 | 13 | from torchtyping import TensorType, patch_typeguard 14 | from typeguard import typechecked 15 | 16 | from src.utils.pylogger import get_pylogger 17 | 18 | patch_typeguard() # use before @typechecked 19 | 20 | 21 | log = get_pylogger(__name__) 22 | 23 | 24 | @typechecked 25 | def get_bond_length_arrays(atom_mapping: Dict[str, int]) -> List[np.ndarray]: 26 | bond_arrays = [] 27 | for i in range(3): 28 | bond_dict = 
getattr(edm_constants, f'bonds{i + 1}') 29 | bond_array = np.zeros((len(atom_mapping), len(atom_mapping))) 30 | for a1 in atom_mapping.keys(): 31 | for a2 in atom_mapping.keys(): 32 | if a1 in bond_dict and a2 in bond_dict[a1]: 33 | bond_len = bond_dict[a1][a2] 34 | else: 35 | bond_len = 0 36 | bond_array[atom_mapping[a1], atom_mapping[a2]] = bond_len 37 | 38 | assert np.all(bond_array == bond_array.T) 39 | bond_arrays.append(bond_array) 40 | 41 | return bond_arrays 42 | 43 | 44 | @typechecked 45 | def get_bond_order( 46 | atom1: str, 47 | atom2: str, 48 | distance: Union[float, np.float32] 49 | ) -> int: 50 | distance = 100 * distance # We change the metric 51 | if atom1 in edm_constants.bonds3 and atom2 in edm_constants.bonds3[atom1] and distance < edm_constants.bonds3[atom1][atom2] + edm_constants.margin3: 52 | return 3 # triple bond 53 | if atom1 in edm_constants.bonds2 and atom2 in edm_constants.bonds2[atom1] and distance < edm_constants.bonds2[atom1][atom2] + edm_constants.margin2: 54 | return 2 # double bond 55 | if atom1 in edm_constants.bonds1 and atom2 in edm_constants.bonds1[atom1] and distance < edm_constants.bonds1[atom1][atom2] + edm_constants.margin1: 56 | return 1 # single bond 57 | return 0 # no bond 58 | 59 | 60 | @typechecked 61 | def get_bond_order_batch( 62 | atoms1: TensorType["num_pairwise_atoms"], 63 | atoms2: TensorType["num_pairwise_atoms"], 64 | distances: TensorType["num_pairwise_atoms"], 65 | dataset_info: Dict[str, Any], 66 | limit_bonds_to_one: bool = False 67 | ) -> TensorType["num_pairwise_atoms"]: 68 | distances = 100 * distances # note: we change the metric 69 | 70 | bonds1 = torch.tensor(dataset_info['bonds1']) 71 | bonds2 = torch.tensor(dataset_info['bonds2']) 72 | bonds3 = torch.tensor(dataset_info['bonds3']) 73 | 74 | bond_types = torch.zeros_like(atoms1) # note: `0` indicates no bond 75 | 76 | # single bond 77 | bond_types[distances < (bonds1[atoms1, atoms2] + edm_constants.margin1)] = 1 78 | # double bond (note: already assigned single bonds will be overwritten) 79 | bond_types[distances < (bonds2[atoms1, atoms2] + edm_constants.margin2)] = 2 80 | # triple bond 81 | bond_types[distances < (bonds3[atoms1, atoms2] + edm_constants.margin3)] = 3 82 | 83 | if limit_bonds_to_one: 84 | # e.g., for datasets such as GEOM-Drugs 85 | bond_types[bond_types > 1] = 1 86 | 87 | return bond_types 88 | 89 | 90 | @typechecked 91 | def check_molecular_stability( 92 | positions: TensorType["num_nodes", 3], 93 | atom_types: TensorType["num_nodes"], 94 | dataset_info: Dict[str, Any], 95 | verbose: bool = False 96 | ) -> Tuple[bool, int, int]: 97 | assert len(positions.shape) == 2 98 | assert positions.shape[1] == 3 99 | 100 | atom_decoder = dataset_info['atom_decoder'] 101 | n = len(positions) 102 | 103 | dists = torch.cdist(positions, positions, p=2.0).reshape(-1) 104 | atoms1, atoms2 = torch.meshgrid(atom_types, atom_types, indexing="xy") 105 | atoms1, atoms2 = atoms1.reshape(-1), atoms2.reshape(-1) 106 | order = get_bond_order_batch(atoms1, atoms2, dists, dataset_info).numpy().reshape(n, n) 107 | np.fill_diagonal(order, 0) # mask out diagonal (i.e., self) bonds 108 | nr_bonds = np.sum(order, axis=1) 109 | 110 | nr_stable_bonds = 0 111 | for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds): 112 | possible_bonds = edm_constants.allowed_bonds[atom_decoder[atom_type_i]] 113 | if type(possible_bonds) == int: 114 | is_stable = possible_bonds == nr_bonds_i 115 | else: 116 | is_stable = nr_bonds_i in possible_bonds 117 | if not is_stable and verbose: 118 | 
log.info("Invalid bonds for atom type %s with %d bonds" % (atom_decoder[atom_type_i], nr_bonds_i)) 119 | nr_stable_bonds += int(is_stable) 120 | 121 | molecule_stable = nr_stable_bonds == n 122 | return molecule_stable, nr_stable_bonds, n 123 | -------------------------------------------------------------------------------- /src/datamodules/components/edm/collate.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/ehoogeboom/e3_diffusion_for_molecules/ 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def batch_stack(props): 9 | """ 10 | Stack a list of torch.tensors so they are padded to the size of the 11 | largest tensor along each axis. 12 | 13 | Parameters 14 | ---------- 15 | props : list of Pytorch Tensors 16 | Pytorch tensors to stack 17 | 18 | Returns 19 | ------- 20 | props : Pytorch tensor 21 | Stacked pytorch tensor. 22 | 23 | Notes 24 | ----- 25 | TODO : Review whether the behavior when elements are not tensors is safe. 26 | """ 27 | if not torch.is_tensor(props[0]): 28 | return torch.tensor(props) 29 | elif props[0].dim() == 0: 30 | return torch.stack(props) 31 | else: 32 | return torch.nn.utils.rnn.pad_sequence(props, batch_first=True, padding_value=0) 33 | 34 | 35 | def drop_zeros(props, to_keep): 36 | """ 37 | Function to drop zeros from batches when the entire dataset is padded to the largest molecule size. 38 | 39 | Parameters 40 | ---------- 41 | props : Pytorch tensor 42 | Full Dataset 43 | to_keep : Pytorch tensor 44 | Boolean mask over the padded atom dimension indicating which entries to retain. 45 | Returns 46 | ------- 47 | props : Pytorch tensor 48 | The dataset with only the retained information. 49 | 50 | Notes 51 | ----- 52 | TODO : Review whether the behavior when elements are not tensors is safe. 53 | """ 54 | if not torch.is_tensor(props[0]): 55 | return props 56 | elif props[0].dim() == 0: 57 | return props 58 | else: 59 | return props[:, to_keep, ...] 60 | 61 | 62 | class PreprocessQM9: 63 | def __init__(self, load_charges=True): 64 | self.load_charges = load_charges 65 | self.tricks = [] 66 | def add_trick(self, trick): 67 | self.tricks.append(trick) 68 | 69 | def collate_fn(self, batch): 70 | """ 71 | Collation function that collates datapoints into the batch format for cormorant 72 | 73 | Parameters 74 | ---------- 75 | batch : list of datapoints 76 | The data to be collated. 77 | 78 | Returns 79 | ------- 80 | batch : dict of Pytorch tensors 81 | The collated data.
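Notes
-----
Sketch of the resulting masks (assuming a batch of B molecules padded to N atoms each):
`batch["atom_mask"]` is a (B, N) boolean tensor that is True wherever `charges > 0`, and
`batch["edge_mask"]` is its pairwise outer product with the diagonal (self) edges removed,
reshaped to (B * N * N, 1).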
82 | """ 83 | batch = {prop: batch_stack([mol[prop] for mol in batch]) for prop in batch[0].keys()} 84 | 85 | to_keep = (batch["charges"].sum(0) > 0) 86 | 87 | batch = {key: drop_zeros(prop, to_keep) for key, prop in batch.items()} 88 | 89 | atom_mask = batch["charges"] > 0 90 | batch["atom_mask"] = atom_mask 91 | 92 | # Obtain edges 93 | batch_size, n_nodes = atom_mask.size() 94 | edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2) 95 | 96 | # mask diagonal 97 | diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0) 98 | edge_mask *= diag_mask 99 | 100 | #edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2) 101 | batch["edge_mask"] = edge_mask.view(batch_size * n_nodes * n_nodes, 1) 102 | 103 | if self.load_charges: 104 | batch["charges"] = batch["charges"].unsqueeze(2) 105 | else: 106 | batch["charges"] = torch.zeros(0) 107 | return batch 108 | -------------------------------------------------------------------------------- /src/datamodules/components/edm/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/ehoogeboom/e3_diffusion_for_molecules/ 3 | """ 4 | 5 | import torch 6 | import os 7 | 8 | import src.datamodules.components.edm.build_geom_dataset as build_geom_dataset 9 | 10 | from omegaconf import DictConfig 11 | from functools import partial 12 | 13 | from torch.utils.data.dataloader import DataLoader as TorchDataLoader 14 | from torch_geometric.loader.dataloader import DataLoader as PyGDataLoader 15 | 16 | from src.datamodules.components.edm.collate import PreprocessQM9 17 | from src.datamodules.components.edm.datasets_config import get_dataset_info 18 | from src.datamodules.components.edm.utils import initialize_datasets 19 | from src.utils.pylogger import get_pylogger 20 | 21 | 22 | log = get_pylogger(__name__) 23 | 24 | 25 | SHARING_STRATEGY = "file_system" 26 | torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY) 27 | 28 | 29 | def set_worker_sharing_strategy(worker_id: int): 30 | torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY) 31 | 32 | 33 | def retrieve_dataloaders(dataloader_cfg: DictConfig): 34 | if "QM9" in dataloader_cfg.dataset: 35 | batch_size = dataloader_cfg.batch_size 36 | num_workers = dataloader_cfg.num_workers 37 | filter_n_atoms = dataloader_cfg.filter_n_atoms 38 | # Initialize dataloader 39 | cfg_, datasets, _, charge_scale = initialize_datasets(dataloader_cfg, 40 | dataloader_cfg.data_dir, 41 | dataloader_cfg.dataset, 42 | subtract_thermo=dataloader_cfg.subtract_thermo, 43 | force_download=dataloader_cfg.force_download, 44 | remove_h=dataloader_cfg.remove_h, 45 | create_pyg_graphs=dataloader_cfg.create_pyg_graphs, 46 | num_radials=dataloader_cfg.num_radials, 47 | device=dataloader_cfg.device) 48 | qm9_to_eV = { 49 | "U0": 27.2114, "U": 27.2114, "G": 27.2114, "H": 27.2114, 50 | "zpve": 27211.4, "gap": 27.2114, "homo": 27.2114, "lumo": 27.2114 51 | } 52 | 53 | for dataset in datasets.values(): 54 | dataset.convert_units(qm9_to_eV) 55 | 56 | if filter_n_atoms is not None: 57 | log.info("Retrieving molecules with only %d atoms" % filter_n_atoms) 58 | datasets = filter_atoms(datasets, filter_n_atoms) 59 | 60 | # Construct PyTorch dataloaders from datasets 61 | preprocess = PreprocessQM9(load_charges=dataloader_cfg.include_charges) 62 | dataloader_class = ( 63 | partial(PyGDataLoader, prefetch_factor=100, worker_init_fn=set_worker_sharing_strategy) 64 | if dataloader_cfg.create_pyg_graphs 65 | else partial(TorchDataLoader, 
collate_fn=preprocess.collate_fn) 66 | ) 67 | dataloaders = {split: dataloader_class(dataset, 68 | num_workers=num_workers, 69 | batch_size=batch_size, 70 | shuffle=cfg_.shuffle if (split == "train") else False, 71 | drop_last=cfg_.drop_last if (split == "train") else False) 72 | for split, dataset in datasets.items()} 73 | elif "GEOM" in dataloader_cfg.dataset: 74 | data_file = os.path.join("data", "EDM", "GEOM", "GEOM_drugs_30.npy") 75 | dataset_info = get_dataset_info(dataloader_cfg.dataset, dataloader_cfg.remove_h) 76 | 77 | # Retrieve GEOM dataloaders 78 | split_data = build_geom_dataset.load_split_data(data_file, 79 | val_proportion=0.1, 80 | test_proportion=0.1, 81 | filter_size=dataloader_cfg.filter_molecule_size) 82 | transform = build_geom_dataset.GeomDrugsTransform(dataset_info, 83 | dataloader_cfg.include_charges, 84 | dataloader_cfg.device, 85 | dataloader_cfg.sequential) 86 | dataloaders = {} 87 | for split, data_list in zip(["train", "valid", "test"], split_data): 88 | dataset = build_geom_dataset.GeomDrugsDataset(data_list, 89 | transform=transform, 90 | create_pyg_graphs=dataloader_cfg.create_pyg_graphs, 91 | num_radials=dataloader_cfg.num_radials, 92 | device=dataloader_cfg.device) 93 | shuffle = (split == "train") and not dataloader_cfg.sequential 94 | 95 | # Sequential dataloading disabled for now. 96 | dataloader_class = ( 97 | partial(build_geom_dataset.GeomDrugsPyGDataLoader, sequential=dataloader_cfg.sequential) 98 | if dataloader_cfg.create_pyg_graphs 99 | else partial(build_geom_dataset.GeomDrugsTorchDataLoader, sequential=dataloader_cfg.sequential) 100 | ) 101 | dataloaders[split] = dataloader_class( 102 | dataset=dataset, 103 | batch_size=dataloader_cfg.batch_size, 104 | shuffle=shuffle 105 | ) 106 | del split_data 107 | charge_scale = None 108 | else: 109 | raise ValueError(f"Unknown dataset {dataloader_cfg.dataset}") 110 | 111 | return dataloaders, charge_scale 112 | 113 | 114 | def filter_atoms(datasets, n_nodes): 115 | for key in datasets: 116 | dataset = datasets[key] 117 | idxs = dataset.data["num_atoms"] == n_nodes 118 | for key2 in dataset.data: 119 | dataset.data[key2] = dataset.data[key2][idxs] 120 | 121 | datasets[key].num_pts = dataset.data["one_hot"].size(0) 122 | datasets[key].perm = None 123 | return datasets 124 | -------------------------------------------------------------------------------- /src/datamodules/components/edm/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/ehoogeboom/e3_diffusion_for_molecules/ 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | from src.datamodules.components.edm.md17 import download_dataset_md17 9 | from src.datamodules.components.edm.qm9 import download_dataset_qm9 10 | 11 | 12 | def prepare_dataset(data_dir, dataset, subset=None, splits=None, cleanup=True, force_download=False): 13 | """ 14 | Download and process dataset. 15 | 16 | Parameters 17 | ---------- 18 | data_dir : str 19 | Path to the directory where the data and calculations are, or will be, stored. 20 | dataset : str 21 | String specification of the dataset. If it is not already downloaded, must currently be "QM9" or "MD17". 22 | subset : str, optional 23 | Which subset of a dataset to use. Action is dependent on the dataset given. 24 | Must be specified if the dataset has subsets (e.g., MD17). Otherwise ignored (e.g., GDB9). 25 | splits : dict, optional 26 | Dataset splits to use. 27 | cleanup : bool, optional 28 | Clean up files created while preparing the data.
29 | force_download : bool, optional 30 | If true, forces a fresh download of the dataset. 31 | 32 | Returns 33 | ------- 34 | datafiles : dict of strings 35 | Dictionary of strings pointing to the files containing the data. 36 | 37 | Notes 38 | ----- 39 | TODO: Delete the splits argument? 40 | """ 41 | 42 | # If the dataset has subsets, include the subset in the dataset directory path 43 | if subset: 44 | dataset_dir = [data_dir, dataset, subset] 45 | else: 46 | dataset_dir = [data_dir, dataset] 47 | 48 | # Names of splits, based upon keys if split dictionary exists, otherwise default to train/valid/test. 49 | split_names = splits.keys() if splits is not None else [ 50 | "train", "valid", "test" 51 | ] 52 | 53 | # Assume one data file for each split 54 | datafiles = {split: os.path.join( 55 | *(dataset_dir + [split + ".npz"])) for split in split_names} 56 | 57 | # Check datafiles exist 58 | datafiles_checks = [os.path.exists(datafile) 59 | for datafile in datafiles.values()] 60 | 61 | # Check if prepared dataset exists, and if not set flag to download below. 62 | # Probably should add more consistency checks, such as number of datapoints, etc... 63 | new_download = False 64 | if all(datafiles_checks): 65 | logging.info("Dataset exists and is processed.") 66 | elif all([not x for x in datafiles_checks]): 67 | # If no processed files exist yet, download and process the dataset. 68 | new_download = True 69 | else: 70 | raise ValueError( 71 | "Dataset only partially processed. Try deleting {} and running again to download/process.".format(os.path.join(*dataset_dir))) 72 | 73 | # If need to download dataset, pass to appropriate downloader 74 | if new_download or force_download: 75 | logging.info("Dataset does not exist. Downloading!") 76 | if dataset.lower().startswith("qm9"): 77 | download_dataset_qm9(data_dir, dataset, splits, cleanup=cleanup) 78 | elif dataset.lower().startswith("md17"): 79 | download_dataset_md17(data_dir, dataset, subset, 80 | splits, cleanup=cleanup) 81 | else: 82 | raise ValueError( 83 | "Incorrect choice of dataset! Must choose QM9/MD17!") 84 | 85 | return datafiles 86 | -------------------------------------------------------------------------------- /src/datamodules/components/edm/md17.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/ehoogeboom/e3_diffusion_for_molecules/ 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | import logging 9 | import os 10 | 11 | from urllib.request import urlopen 12 | 13 | md17_base_url = "http://quantum-machine.org/gdml/data/npz/" 14 | 15 | md17_subsets = {"benzene": "benzene_old_dft", 16 | "uracil": "uracil_dft", 17 | "naphthalene": "naphthalene_dft", 18 | "aspirin": "aspirin_dft", 19 | "salicylic_acid": "salicylic_dft", 20 | "malonaldehyde": "malonaldehyde_dft", 21 | "ethanol": "ethanol_dft", 22 | "toluene": "toluene_dft", 23 | "paracetamol": "paracetamol_dft", 24 | "azobenzene": "azobenzene_dft" 25 | } 26 | 27 | 28 | def download_data(url, outfile="", binary=False): 29 | """ 30 | Downloads data from a URL and returns raw data. 31 | 32 | Parameters 33 | ---------- 34 | url : str 35 | URL to get the data from 36 | outfile : str, optional 37 | Where to save the data. 38 | binary : bool, optional 39 | If true, writes data in binary. 40 | """ 41 | # Try statement to catch downloads.
42 | try: 43 | # Download url using urlopen 44 | with urlopen(url) as f: 45 | data = f.read() 46 | logging.info("Data download success!") 47 | success = True 48 | except Exception: 49 | logging.info("Data download failed!") 50 | return None, False # nothing was downloaded, so return early 51 | 52 | if binary: 53 | # If data is binary, use "wb" if outputting to file 54 | writeflag = "wb" 55 | else: 56 | # If data is string, convert to string and use "w" if outputting to file 57 | writeflag = "w" 58 | data = data.decode("utf-8") 59 | 60 | if outfile: 61 | logging.info("Saving downloaded data to file: {}".format(outfile)) 62 | 63 | with open(outfile, writeflag) as f: 64 | f.write(data) 65 | 66 | return data, success 67 | 68 | 69 | def cleanup_file(file, cleanup=True): 70 | if cleanup: 71 | try: 72 | os.remove(file) 73 | except OSError: 74 | pass 75 | 76 | 77 | def download_dataset_md17(datadir, dataname, subset, splits=None, cleanup=True): 78 | """ 79 | Downloads the MD17 dataset. 80 | """ 81 | if subset not in md17_subsets: 82 | logging.info( 83 | "Molecule {} not included in list of downloadable MD17 datasets! Attempting to download based directly upon input key.".format(subset)) 84 | md17_molecule = subset 85 | else: 86 | md17_molecule = md17_subsets[subset] 87 | 88 | # Define directory for which data will be output. 89 | md17dir = os.path.join(*[datadir, dataname, subset]) 90 | 91 | # Important to avoid a race condition 92 | os.makedirs(md17dir, exist_ok=True) 93 | 94 | logging.info("Downloading and processing molecule {} from MD17 dataset. Output will be in directory: {}.".format(subset, md17dir)) 95 | 96 | md17_data_url = md17_base_url + md17_molecule + ".npz" 97 | md17_data_npz = os.path.join(md17dir, md17_molecule + ".npz") 98 | 99 | download_data(md17_data_url, outfile=md17_data_npz, binary=True) 100 | 101 | # Convert raw MD17 data to torch tensors. 102 | md17_raw_data = np.load(md17_data_npz) 103 | 104 | # Number of molecules in dataset: 105 | num_tot_mols = len(md17_raw_data["E"]) 106 | 107 | # Dictionary to convert keys in MD17 database to those used in this code. 108 | md17_keys = {"E": "energies", "R": "positions", "F": "forces"} 109 | 110 | # Convert numpy arrays to torch.Tensors 111 | md17_data = {new_key: md17_raw_data[old_key] for old_key, new_key in md17_keys.items()} 112 | 113 | # Reshape energies to remove final singleton dimension 114 | md17_data["energies"] = md17_data["energies"].squeeze(1) 115 | 116 | # Add charges to md17_data 117 | md17_data["charges"] = np.tile(md17_raw_data["z"], (num_tot_mols, 1)) 118 | 119 | # If splits are not specified, automatically generate them. 120 | if splits is None: 121 | splits = gen_splits_md17(num_tot_mols) 122 | 123 | # Process MD17 dataset, and return dictionary of splits 124 | md17_data_split = {} 125 | for split, split_idx in splits.items(): 126 | md17_data_split[split] = {key: val[split_idx] if type( 127 | val) is np.ndarray else val for key, val in md17_data.items()} 128 | 129 | # Save processed MD17 data into train/validation/test splits 130 | logging.info("Saving processed data:") 131 | for split, data_split in md17_data_split.items(): 132 | savefile = os.path.join(md17dir, split + ".npz") 133 | np.savez_compressed(savefile, **data_split) 134 | 135 | cleanup_file(md17_data_npz, cleanup) 136 | 137 | 138 | def gen_splits_md17(num_pts): 139 | """ 140 | Generate the splits used to train/evaluate the network in the original Cormorant paper.
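The permutation is generated with a fixed NumPy seed (0) and sliced into the fixed
50k/10k/10k train/valid/test masks constructed below.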
141 | """ 142 | # deterministically generate random split based upon random permutation 143 | np.random.seed(0) 144 | data_perm = np.random.permutation(num_pts) 145 | 146 | # Create masks for which splits to invoke 147 | mask_train = np.zeros(num_pts, dtype=np.bool) 148 | mask_valid = np.zeros(num_pts, dtype=np.bool) 149 | mask_test = np.zeros(num_pts, dtype=np.bool) 150 | 151 | # For historical reasons, this is the indexing on the 152 | # 50k/10k/10k train/valid/test splits used in the paper. 153 | mask_train[:10000] = True 154 | mask_valid[10000:20000] = True 155 | mask_test[20000:30000] = True 156 | mask_train[30000:70000] = True 157 | 158 | # COnvert masks to splits 159 | splits = {} 160 | splits["train"] = torch.tensor(data_perm[mask_train]) 161 | splits["valid"] = torch.tensor(data_perm[mask_valid]) 162 | splits["test"] = torch.tensor(data_perm[mask_test]) 163 | 164 | return splits 165 | -------------------------------------------------------------------------------- /src/datamodules/components/edm/process.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/ehoogeboom/e3_diffusion_for_molecules/ 3 | """ 4 | 5 | import logging 6 | import os 7 | import torch 8 | import tarfile 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | charge_dict = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9} 12 | 13 | 14 | def split_dataset(data, split_idxs): 15 | """ 16 | Splits a dataset according to the indices given. 17 | 18 | Parameters 19 | ---------- 20 | data : dict 21 | Dictionary to split. 22 | split_idxs : dict 23 | Dictionary defining the split. Keys are the name of the split, and 24 | values are the keys for the items in data that go into the split. 25 | 26 | Returns 27 | ------- 28 | split_dataset : dict 29 | The split dataset. 30 | """ 31 | split_data = {} 32 | for set, split in split_idxs.items(): 33 | split_data[set] = {key: val[split] for key, val in data.items()} 34 | 35 | return split_data 36 | 37 | 38 | def process_xyz_files(data, process_file_fn, file_ext=None, file_idx_list=None, stack=True): 39 | """ 40 | Take a set of datafiles and apply a predefined data processing script to each 41 | one. Data can be stored in a directory, tarfile, or zipfile. An optional 42 | file extension can be added. 43 | 44 | Parameters 45 | ---------- 46 | data : str 47 | Complete path to datafiles. Files must be in a directory, tarball, or zip archive. 48 | process_file_fn : callable 49 | Function to process files. Can be defined externally. 50 | Must input a file, and output a dictionary of properties, each of which 51 | is a torch.tensor. Dictionary must contain at least three properties: 52 | {"num_elements", "charges", "positions"} 53 | file_ext : str, optional 54 | Optionally add a file extension if multiple types of files exist. 55 | file_idx_list : ?????, optional 56 | Optionally add a file filter to check a file index is in a 57 | predefined list, for example, when constructing a train/valid/test split. 58 | stack : bool, optional 59 | ????? 
60 | """ 61 | logging.info("Processing data file: {}".format(data)) 62 | if tarfile.is_tarfile(data): 63 | tardata = tarfile.open(data, "r") 64 | files = tardata.getmembers() 65 | 66 | def readfile(data_pt): return tardata.extractfile(data_pt) 67 | 68 | elif os.is_dir(data): 69 | files = os.listdir(data) 70 | files = [os.path.join(data, file) for file in files] 71 | 72 | def readfile(data_pt): return open(data_pt, "r") 73 | 74 | else: 75 | raise ValueError("Can only read from directory or tarball archive!") 76 | 77 | # Use only files that end with specified extension. 78 | if file_ext is not None: 79 | files = [file for file in files if file.endswith(file_ext)] 80 | 81 | # Use only files that match desired filter. 82 | if file_idx_list is not None: 83 | files = [file for idx, file in enumerate(files) if idx in file_idx_list] 84 | 85 | # Now loop over files using readfile function defined above 86 | # Process each file accordingly using process_file_fn 87 | 88 | molecules = [] 89 | 90 | for file in files: 91 | with readfile(file) as openfile: 92 | molecules.append(process_file_fn(openfile)) 93 | 94 | # Check that all molecules have the same set of items in their dictionary: 95 | props = molecules[0].keys() 96 | assert all(props == mol.keys() for mol in molecules), "All molecules must have same set of properties/keys!" 97 | 98 | # Convert list-of-dicts to dict-of-lists 99 | molecules = {prop: [mol[prop] for mol in molecules] for prop in props} 100 | 101 | # If stacking is desireable, pad and then stack. 102 | if stack: 103 | molecules = {key: pad_sequence(val, batch_first=True) if val[0].dim( 104 | ) > 0 else torch.stack(val) for key, val in molecules.items()} 105 | 106 | return molecules 107 | 108 | 109 | def process_xyz_md17(datafile): 110 | """ 111 | Read xyz file and return a molecular dict with number of atoms, energy, forces, coordinates and atom-type for the MD-17 dataset. 112 | 113 | Parameters 114 | ---------- 115 | datafile : python file object 116 | File object containing the molecular data in the MD17 dataset. 117 | 118 | Returns 119 | ------- 120 | molecule : dict 121 | Dictionary containing the molecular properties of the associated file object. 122 | """ 123 | xyz_lines = [line.decode("UTF-8") for line in datafile.readlines()] 124 | 125 | line_counter = 0 126 | atom_positions = [] 127 | atom_types = [] 128 | for line in xyz_lines: 129 | if line[0] is "#": 130 | continue 131 | if line_counter is 0: 132 | num_atoms = int(line) 133 | elif line_counter is 1: 134 | split = line.split(";") 135 | assert (len(split) == 1 or len(split) == 2), "Improperly formatted energy/force line." 
136 | if (len(split) == 1): 137 | e = split[0] 138 | f = None 139 | elif (len(split) == 2): 140 | e, f = split 141 | f = f.split("],[") 142 | atom_energy = float(e) 143 | atom_forces = [[float(x.strip("[]\n")) for x in force.split(",")] for force in f] if f is not None else [] 144 | else: 145 | split = line.split() 146 | if len(split) == 4: 147 | type, x, y, z = split 148 | atom_types.append(split[0]) 149 | atom_positions.append([float(x) for x in split[1:]]) 150 | else: 151 | logging.debug(line) 152 | line_counter += 1 153 | 154 | atom_charges = [charge_dict[type] for type in atom_types] 155 | 156 | molecule = {"num_atoms": num_atoms, "energy": atom_energy, "charges": atom_charges, 157 | "forces": atom_forces, "positions": atom_positions} 158 | 159 | molecule = {key: torch.tensor(val) for key, val in molecule.items()} 160 | 161 | return molecule 162 | 163 | 164 | def process_xyz_gdb9(datafile): 165 | """ 166 | Read xyz file and return a molecular dict with number of atoms, energy, forces, coordinates and atom-type for the gdb9 dataset. 167 | 168 | Parameters 169 | ---------- 170 | datafile : python file object 171 | File object containing the molecular data in the GDB9 dataset. 172 | 173 | Returns 174 | ------- 175 | molecule : dict 176 | Dictionary containing the molecular properties of the associated file object. 177 | """ 178 | xyz_lines = [line.decode("UTF-8") for line in datafile.readlines()] 179 | 180 | num_atoms = int(xyz_lines[0]) 181 | mol_props = xyz_lines[1].split() 182 | mol_xyz = xyz_lines[2:num_atoms+2] 183 | mol_freq = xyz_lines[num_atoms+2] 184 | 185 | atom_charges, atom_positions = [], [] 186 | for line in mol_xyz: 187 | atom, posx, posy, posz, _ = line.replace("*^", "e").split() 188 | atom_charges.append(charge_dict[atom]) 189 | atom_positions.append([float(posx), float(posy), float(posz)]) 190 | 191 | prop_strings = ["tag", "index", "A", "B", "C", "mu", "alpha", 192 | "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] 193 | prop_strings = prop_strings[1:] 194 | mol_props = [int(mol_props[1])] + [float(x) for x in mol_props[2:]] 195 | mol_props = dict(zip(prop_strings, mol_props)) 196 | mol_props["omega1"] = max(float(omega) for omega in mol_freq.split()) 197 | 198 | molecule = {"num_atoms": num_atoms, "charges": atom_charges, "positions": atom_positions} 199 | molecule.update(mol_props) 200 | molecule = {key: torch.tensor(val) for key, val in molecule.items()} 201 | 202 | return molecule 203 | -------------------------------------------------------------------------------- /src/datamodules/components/helper.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code adapted from GCPNet (https://github.com/BioinfoMachineLearning/GCPNet): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import torch 6 | from typing import Union 7 | 8 | from torchtyping import patch_typeguard 9 | from typeguard import typechecked 10 | 11 | patch_typeguard() # use before @typechecked 12 | 13 | 14 | @typechecked 15 | def _normalize( 16 | tensor: torch.Tensor, 17 | dim: int = -1 18 | ) -> torch.Tensor: 19 | """ 20 | From https://github.com/drorlab/gvp-pytorch 21 | """ 22 | return torch.nan_to_num( 23 | torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)) 24 | ) 25 | 26 | @typechecked 27 | def _rbf( 28 | D: 
torch.Tensor, 29 | D_min: float = 0.0, 30 | D_max: float = 20.0, 31 | D_count: int = 16, 32 | device: Union[torch.device, str] = "cpu" 33 | ) -> torch.Tensor: 34 | """ 35 | From https://github.com/jingraham/neurips19-graph-protein-design 36 | 37 | Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1. 38 | That is, if `D` has shape [...dims], then the returned tensor will have 39 | shape [...dims, D_count]. 40 | """ 41 | D_mu = torch.linspace(D_min, D_max, D_count, device=device) 42 | D_mu = D_mu.view([1, -1]) 43 | D_sigma = (D_max - D_min) / D_count 44 | D_expand = torch.unsqueeze(D, -1) 45 | 46 | RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2) 47 | return RBF 48 | -------------------------------------------------------------------------------- /src/datamodules/components/protein_graph_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code adapted from GCPNet (https://github.com/BioinfoMachineLearning/GCPNet): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | from __future__ import print_function, absolute_import, division 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch_cluster 12 | from omegaconf import DictConfig 13 | from torch.nn import functional as F 14 | from torch.utils import data as data 15 | from torch_geometric.data import Data 16 | from typing import Any, Dict, List, Optional, Union 17 | 18 | from src.datamodules.components.helper import _normalize, _rbf 19 | 20 | from torchtyping import TensorType, patch_typeguard 21 | from typeguard import typechecked 22 | 23 | patch_typeguard() # use before @typechecked 24 | 25 | 26 | class ProteinGraphDataset(data.Dataset): 27 | """ 28 | From https://github.com/drorlab/gvp-pytorch 29 | """ 30 | 31 | def __init__(self, 32 | data_list: List[Dict[str, Any]], 33 | features_cfg: DictConfig, 34 | num_positional_embeddings: int = 16, 35 | top_k: int = 30, 36 | num_rbf: int = 16, 37 | device: Union[torch.device, str] = "cpu"): 38 | 39 | super().__init__() 40 | 41 | self.features_cfg = features_cfg 42 | self.data_list = data_list 43 | self.top_k = top_k 44 | self.num_rbf = num_rbf 45 | self.num_positional_embeddings = num_positional_embeddings 46 | self.device = device 47 | self.node_counts = [len(datum["seq"]) for datum in data_list] 48 | self.edge_counts = [len(datum["seq"]) * top_k for datum in data_list] 49 | 50 | self.letter_to_num = { 51 | "C": 4, 52 | "D": 3, 53 | "S": 15, 54 | "Q": 5, 55 | "K": 11, 56 | "I": 9, 57 | "P": 14, 58 | "T": 16, 59 | "F": 13, 60 | "A": 0, 61 | "G": 7, 62 | "H": 8, 63 | "E": 6, 64 | "L": 10, 65 | "R": 1, 66 | "W": 17, 67 | "V": 19, 68 | "N": 2, 69 | "Y": 18, 70 | "M": 12, 71 | } 72 | self.num_to_letter = {v: k for k, v in self.letter_to_num.items()} 73 | self.num_to_letter_list = [None] * 20 74 | for k in self.letter_to_num: 75 | self.num_to_letter_list[self.letter_to_num[k]] = k 76 | 77 | @typechecked 78 | def num_to_letter(self) -> List[str]: 79 | letter_to_num = { 80 | "C": 4, 81 | "D": 3, 82 | "S": 15, 83 | "Q": 5, 84 | "K": 11, 85 | "I": 9, 86 | "P": 14, 87 | "T": 16, 88 | "F": 13, 89 | "A": 0, 90 | "G": 7, 91 | "H": 8, 92 | "E": 6, 93 | "L": 10, 94 | "R": 1, 95 | "W": 17, 96 | "V": 19, 97 | "N": 2, 98 | "Y": 18, 99 | "M": 12, 100 | } 101 | num_to_letter_list = 
[None] * 20 102 | for k in letter_to_num: 103 | num_to_letter_list[letter_to_num[k]] = k 104 | return num_to_letter_list 105 | 106 | def __len__(self): 107 | return len(self.data_list) 108 | 109 | def __getitem__(self, idx: int): 110 | return self._featurize_as_graph(self.data_list[idx]) 111 | 112 | @typechecked 113 | def _featurize_as_graph(self, protein: Dict[str, Any]) -> Data: 114 | if "name" not in protein: 115 | name = protein["id"] 116 | else: 117 | name = protein["name"] 118 | with torch.no_grad(): 119 | coords = torch.as_tensor(protein["coords"], device=self.device, dtype=torch.float32) 120 | seq = torch.as_tensor([self.letter_to_num[a] for a in protein["seq"]], device=self.device, dtype=torch.long) 121 | 122 | mask = torch.isfinite(coords.sum(dim=(1, 2))) 123 | coords[~mask] = np.inf # ensure missing nodes are assigned no edges 124 | 125 | X_ca = coords[:, 1] 126 | edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k) 127 | 128 | pos_embeddings = self._positional_embeddings(edge_index) 129 | E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]] 130 | rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=self.device) 131 | 132 | dihedrals = self._dihedrals(coords) 133 | if not self.features_cfg.dihedral: 134 | dihedrals = torch.zeros_like(dihedrals) 135 | orientations = self._orientations(X_ca) 136 | if not self.features_cfg.orientations: 137 | orientations = torch.zeros_like(orientations) 138 | sidechains = self._sidechains(coords) 139 | if not self.features_cfg.sidechain: 140 | sidechains = torch.zeros_like(sidechains) 141 | 142 | if not self.features_cfg.relative_distance: 143 | rbf = torch.zeros_like(rbf) 144 | if not self.features_cfg.relative_position: 145 | pos_embeddings = torch.zeros_like(pos_embeddings) 146 | if not self.features_cfg.direction_unit: 147 | E_vectors = torch.zeros_like(E_vectors) 148 | 149 | node_s = dihedrals 150 | node_v = torch.cat((orientations, sidechains.unsqueeze(-2)), dim=-2) 151 | edge_s = torch.cat((rbf, pos_embeddings), dim=-1) 152 | edge_v = _normalize(E_vectors).unsqueeze(-2) 153 | 154 | node_s, node_v, edge_s, edge_v = map(torch.nan_to_num, (node_s, node_v, edge_s, edge_v)) 155 | 156 | data = Data( 157 | x=X_ca, 158 | seq=seq, 159 | name=name, 160 | h=node_s, 161 | chi=node_v, 162 | e=edge_s, 163 | xi=edge_v, 164 | edge_index=edge_index, 165 | mask=mask 166 | ) 167 | return data 168 | 169 | @staticmethod 170 | def _dihedrals( 171 | X: TensorType["num_residues", "num_atoms_per_residue", 3], 172 | eps: float = 1e-7 173 | ) -> TensorType["num_residues", 6]: 174 | # From https://github.com/jingraham/neurips19-graph-protein-design 175 | X = torch.reshape(X[:, :3], [3 * X.shape[0], 3]) 176 | dX = X[1:] - X[:-1] 177 | U = _normalize(dX, dim=-1) 178 | u_2 = U[:-2] 179 | u_1 = U[1:-1] 180 | u_0 = U[2:] 181 | 182 | # Backbone normals 183 | n_2 = _normalize(torch.cross(u_2, u_1), dim=-1) 184 | n_1 = _normalize(torch.cross(u_1, u_0), dim=-1) 185 | 186 | # Angle between normals 187 | cosD = torch.sum(n_2 * n_1, -1) 188 | cosD = torch.clamp(cosD, -1 + eps, 1 - eps) 189 | D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD) 190 | 191 | # This scheme will remove phi[0], psi[-1], omega[-1] 192 | D = F.pad(D, [1, 2]) 193 | D = torch.reshape(D, [-1, 3]) 194 | 195 | # Lift angle representations to the circle 196 | D_features = torch.cat((torch.cos(D), torch.sin(D)), dim=1) 197 | return D_features 198 | 199 | @typechecked 200 | def _positional_embeddings( 201 | self, 202 | edge_index: TensorType[2, "num_edges"], 203 | num_embeddings: 
Optional[int] = None 204 | ) -> TensorType["num_edges", "num_embeddings_per_edge"]: 205 | # From https://github.com/jingraham/neurips19-graph-protein-design 206 | num_embeddings = num_embeddings or self.num_positional_embeddings 207 | d = edge_index[0] - edge_index[1] 208 | 209 | frequency = torch.exp( 210 | torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=self.device) 211 | * -(np.log(10000.0) / num_embeddings) 212 | ) 213 | angles = d.unsqueeze(-1) * frequency 214 | E = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1) 215 | return E 216 | 217 | @staticmethod 218 | def _orientations( 219 | X: TensorType["num_nodes", 3] 220 | ) -> TensorType["num_nodes", 2, 3]: 221 | forward = _normalize(X[1:] - X[:-1]) 222 | backward = _normalize(X[:-1] - X[1:]) 223 | forward = F.pad(forward, [0, 0, 0, 1]) 224 | backward = F.pad(backward, [0, 0, 1, 0]) 225 | return torch.cat((forward.unsqueeze(-2), backward.unsqueeze(-2)), dim=-2) 226 | 227 | @staticmethod 228 | def _sidechains( 229 | X: TensorType["num_residues", "num_atoms_per_residue", 3] 230 | ) -> TensorType["num_residues", 3]: 231 | n, origin, c = X[:, 0], X[:, 1], X[:, 2] 232 | c, n = _normalize(c - origin), _normalize(n - origin) 233 | bisector = _normalize(c + n) 234 | perp = _normalize(torch.cross(c, n)) 235 | vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3) 236 | return vec 237 | -------------------------------------------------------------------------------- /src/datamodules/components/sampler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code adapted from GCPNet (https://github.com/BioinfoMachineLearning/GCPNet): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import random 6 | from operator import itemgetter 7 | from typing import List, Optional, Iterator 8 | 9 | import numpy as np 10 | from torch.utils import data as data 11 | from torch.utils.data import Dataset, Sampler, DistributedSampler 12 | 13 | 14 | class BatchSampler(data.Sampler): 15 | """ 16 | From https://github.com/jingraham/neurips19-graph-protein-design. 
17 | """ 18 | 19 | def __init__( 20 | self, 21 | unit_counts: List[int], 22 | max_units: int = 3000, 23 | shuffle: bool = True, 24 | hard_shuffle: bool = False, 25 | **kwargs 26 | ): 27 | self.hard_shuffle = hard_shuffle 28 | self.unit_counts = unit_counts 29 | self.idx = [i for i in range(len(unit_counts)) if unit_counts[i] <= max_units] 30 | self.shuffle = shuffle 31 | self.max_units = max_units 32 | self._form_batches() 33 | 34 | def _form_batches(self): 35 | self.batches = [] 36 | if self.shuffle: 37 | random.shuffle(self.idx) 38 | idx = self.idx 39 | while idx: 40 | batch = [] 41 | n_nodes = 0 42 | while idx and n_nodes + self.unit_counts[idx[0]] <= self.max_units: 43 | next_idx, idx = idx[0], idx[1:] 44 | n_nodes += self.unit_counts[next_idx] 45 | batch.append(next_idx) 46 | self.batches.append(batch) 47 | 48 | def __len__(self) -> int: 49 | if not self.batches: 50 | self._form_batches() 51 | return len(self.batches) 52 | 53 | def __iter__(self): 54 | if not self.batches or (self.shuffle and self.hard_shuffle): 55 | self._form_batches() 56 | elif self.shuffle: 57 | np.random.shuffle(self.batches) 58 | for batch in self.batches: 59 | yield batch 60 | 61 | 62 | class DatasetFromSampler(Dataset): 63 | def __init__(self, sampler: Sampler): 64 | self.sampler = sampler 65 | self.sampler_list = None 66 | 67 | def __getitem__(self, index: int): 68 | if self.sampler_list is None: 69 | self.sampler_list = list(self.sampler) 70 | return self.sampler_list[index] 71 | 72 | def __len__(self) -> int: 73 | return len(self.sampler) 74 | 75 | 76 | class DistributedSamplerWrapper(DistributedSampler): 77 | """ 78 | From https://github.com/catalyst-team/catalyst 79 | 80 | Wrapper over `Sampler` for distributed training. 81 | Allows you to use any sampler in distributed mode. 82 | 83 | It is especially useful in conjunction with 84 | `torch.nn.parallel.DistributedDataParallel`. In such case, each 85 | process can pass a DistributedSamplerWrapper instance as a DataLoader 86 | sampler, and load a subset of subsampled data of the original dataset 87 | that is exclusive to it. 88 | 89 | .. note:: 90 | Sampler is assumed to be of constant size. 91 | """ 92 | 93 | def __init__( 94 | self, 95 | sampler: Sampler, 96 | num_replicas: Optional[int] = None, 97 | rank: Optional[int] = None, 98 | shuffle: bool = True, 99 | **kwargs 100 | ): 101 | """ 102 | 103 | Args: 104 | sampler: Sampler used for subsampling 105 | num_replicas (int, optional): Number of processes participating in 106 | distributed training 107 | rank (int, optional): Rank of the current process 108 | within ``num_replicas`` 109 | shuffle (bool, optional): If true (default), 110 | sampler will shuffle the indices 111 | """ 112 | super().__init__( 113 | DatasetFromSampler(sampler), num_replicas=num_replicas, rank=rank, shuffle=shuffle 114 | ) 115 | self.sampler = sampler 116 | 117 | def __iter__(self) -> Iterator[int]: 118 | """Iterate over sampler. 
119 | 120 | Returns: 121 | python iterator 122 | """ 123 | self.dataset = DatasetFromSampler(self.sampler) 124 | indexes_of_indexes = super().__iter__() 125 | subsampler_indexes = self.dataset 126 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) 127 | -------------------------------------------------------------------------------- /src/datamodules/edm_datamodule.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import os 6 | 7 | from typing import Any, Dict, Optional 8 | from omegaconf import DictConfig 9 | 10 | from pytorch_lightning import LightningDataModule 11 | from torch.utils.data import DataLoader 12 | 13 | from src.datamodules.components.edm.dataset import retrieve_dataloaders 14 | 15 | 16 | class EDMDataModule(LightningDataModule): 17 | """ 18 | A data wrapper for the EDM datasets. It downloads any missing 19 | data files from Springer Nature or Zenodo. 20 | 21 | :param dataloader_cfg: configuration arguments for EDM dataloaders. 22 | """ 23 | 24 | def __init__(self, dataloader_cfg: DictConfig): 25 | super().__init__() 26 | 27 | # this line allows to access init params with `self.hparams` attribute 28 | # also ensures init params will be stored in ckpt 29 | self.save_hyperparameters(logger=False) 30 | 31 | self.dataloader_train: Optional[DataLoader] = None 32 | self.dataloader_val: Optional[DataLoader] = None 33 | self.dataloader_test: Optional[DataLoader] = None 34 | 35 | def prepare_data(self): 36 | """Download data if needed. 37 | 38 | Do not use it to assign state (e.g., self.x = y). 39 | """ 40 | data_path = os.path.join(self.hparams.dataloader_cfg.data_dir, self.hparams.dataloader_cfg.dataset) 41 | 42 | if "QM9" in self.hparams.dataloader_cfg.dataset and not all([ 43 | os.path.exists(os.path.join(data_path, "train.npz")), 44 | os.path.exists(os.path.join(data_path, "valid.npz")), 45 | os.path.exists(os.path.join(data_path, "test.npz")) 46 | ]): 47 | retrieve_dataloaders(self.hparams.dataloader_cfg) 48 | 49 | elif "GEOM" in self.hparams.dataloader_cfg.dataset and not all([ 50 | os.path.exists(os.path.join(data_path, "GEOM_drugs_30.npy")), 51 | os.path.exists(os.path.join(data_path, "GEOM_drugs_n_30.npy")), 52 | os.path.exists(os.path.join(data_path, "GEOM_drugs_smiles.txt")) 53 | ]): 54 | retrieve_dataloaders(self.hparams.dataloader_cfg) 55 | 56 | def setup(self, stage: Optional[str] = None): 57 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 58 | 59 | Note: This method is called by Lightning with both `trainer.fit()` and `trainer.test()`. 
60 | """ 61 | # load dataloaders only if not loaded already 62 | if not self.dataloader_train and not self.dataloader_val and not self.dataloader_test: 63 | self.dataloaders, self.charge_scale = retrieve_dataloaders(self.hparams.dataloader_cfg) 64 | self.dataloader_train, self.dataloader_val, self.dataloader_test = ( 65 | self.dataloaders["train"], self.dataloaders["valid"], self.dataloaders["test"] 66 | ) 67 | 68 | def train_dataloader(self): 69 | return self.dataloader_train 70 | 71 | def val_dataloader(self): 72 | return self.dataloader_val 73 | 74 | def test_dataloader(self): 75 | return self.dataloader_test 76 | 77 | def teardown(self, stage: Optional[str] = None): 78 | """Clean up after fit or test.""" 79 | pass 80 | 81 | def state_dict(self): 82 | """Extra things to save to checkpoint.""" 83 | return {} 84 | 85 | def load_state_dict(self, state_dict: Dict[str, Any]): 86 | """Things to do when loading checkpoint.""" 87 | pass 88 | 89 | 90 | if __name__ == "__main__": 91 | import hydra 92 | import omegaconf 93 | import pyrootutils 94 | 95 | root = pyrootutils.setup_root(__file__, pythonpath=True) 96 | 97 | cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "edm_qm9.yaml") 98 | cfg.data_dir = str(root / "data" / "EDM") 99 | _ = hydra.utils.instantiate(cfg) 100 | 101 | cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "edm_geom.yaml") 102 | cfg.data_dir = str(root / "data" / "EDM") 103 | _ = hydra.utils.instantiate(cfg) 104 | 105 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import logging 6 | 7 | from pytorch_lightning.utilities import rank_zero_only 8 | 9 | 10 | def get_pylogger(name=__name__) -> logging.Logger: 11 | """Initializes multi-GPU-friendly python command line logger.""" 12 | 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | from pathlib import Path 6 | from typing import Sequence 7 | 8 | import rich 9 | import rich.syntax 10 | import rich.tree 11 | from hydra.core.hydra_config import HydraConfig 12 | from omegaconf import DictConfig, OmegaConf, open_dict 13 | from 
pytorch_lightning.utilities import rank_zero_only 14 | from rich.prompt import Prompt 15 | 16 | from src.utils import pylogger 17 | 18 | log = pylogger.get_pylogger(__name__) 19 | 20 | 21 | @rank_zero_only 22 | def print_config_tree( 23 | cfg: DictConfig, 24 | print_order: Sequence[str] = ( 25 | "datamodule", 26 | "model", 27 | "callbacks", 28 | "logger", 29 | "trainer", 30 | "paths", 31 | "extras", 32 | ), 33 | resolve: bool = False, 34 | save_to_file: bool = False, 35 | ) -> None: 36 | """Prints content of DictConfig using Rich library and its tree structure. 37 | 38 | Args: 39 | cfg (DictConfig): Configuration composed by Hydra. 40 | print_order (Sequence[str], optional): Determines in what order config components are printed. 41 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 42 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 43 | """ 44 | 45 | style = "dim" 46 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 47 | 48 | queue = [] 49 | 50 | # add fields from `print_order` to queue 51 | for field in print_order: 52 | queue.append(field) if field in cfg else log.warning( 53 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 54 | ) 55 | 56 | # add all the other fields to queue (not specified in `print_order`) 57 | for field in cfg: 58 | if field not in queue: 59 | queue.append(field) 60 | 61 | # generate config tree from queue 62 | for field in queue: 63 | branch = tree.add(field, style=style, guide_style=style) 64 | 65 | config_group = cfg[field] 66 | if isinstance(config_group, DictConfig): 67 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 68 | else: 69 | branch_content = str(config_group) 70 | 71 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 72 | 73 | # print config tree 74 | rich.print(tree) 75 | 76 | # save config tree to file 77 | if save_to_file: 78 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 79 | rich.print(tree, file=file) 80 | 81 | 82 | @rank_zero_only 83 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 84 | """Prompts user to input tags from command line if no tags are provided in config.""" 85 | 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. 
Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | 103 | 104 | if __name__ == "__main__": 105 | from hydra import compose, initialize 106 | 107 | with initialize(version_base="1.2", config_path="../../configs"): 108 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) 109 | print_config_tree(cfg, resolve=False, save_to_file=False) 110 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import pyrootutils 6 | import pytest 7 | from hydra import compose, initialize 8 | from hydra.core.global_hydra import GlobalHydra 9 | from omegaconf import DictConfig, open_dict 10 | 11 | 12 | @pytest.fixture(scope="package") 13 | def cfg_train_global() -> DictConfig: 14 | with initialize(version_base="1.2", config_path="../configs"): 15 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 16 | 17 | # set defaults for all tests 18 | with open_dict(cfg): 19 | cfg.paths.root_dir = str(pyrootutils.find_root()) 20 | cfg.trainer.max_epochs = 1 21 | cfg.trainer.limit_train_batches = 0.01 22 | cfg.trainer.limit_val_batches = 0.1 23 | cfg.trainer.limit_test_batches = 0.1 24 | cfg.trainer.accelerator = "cpu" 25 | cfg.trainer.devices = 1 26 | cfg.datamodule.num_workers = 0 27 | cfg.datamodule.pin_memory = False 28 | cfg.extras.print_config = False 29 | cfg.extras.enforce_tags = False 30 | cfg.logger = None 31 | 32 | return cfg 33 | 34 | 35 | @pytest.fixture(scope="package") 36 | def cfg_eval_global() -> DictConfig: 37 | with initialize(version_base="1.2", config_path="../configs"): 38 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) 39 | 40 | # set defaults for all tests 41 | with open_dict(cfg): 42 | cfg.paths.root_dir = str(pyrootutils.find_root()) 43 | cfg.trainer.max_epochs = 1 44 | cfg.trainer.limit_test_batches = 0.1 45 | cfg.trainer.accelerator = "cpu" 46 | cfg.trainer.devices = 1 47 | cfg.datamodule.num_workers = 0 48 | cfg.datamodule.pin_memory = False 49 | cfg.extras.print_config = False 50 | cfg.extras.enforce_tags = False 51 | cfg.logger = None 52 | 53 | return cfg 54 | 55 | 56 | # this is called by each test which uses `cfg_train` arg 57 | # each test generates its own temporary logging path 58 | @pytest.fixture(scope="function") 59 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig: 60 | cfg = cfg_train_global.copy() 61 | 62 
| with open_dict(cfg): 63 | cfg.paths.output_dir = str(tmp_path) 64 | cfg.paths.log_dir = str(tmp_path) 65 | 66 | yield cfg 67 | 68 | GlobalHydra.instance().clear() 69 | 70 | 71 | # this is called by each test which uses `cfg_eval` arg 72 | # each test generates its own temporary logging path 73 | @pytest.fixture(scope="function") 74 | def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig: 75 | cfg = cfg_eval_global.copy() 76 | 77 | with open_dict(cfg): 78 | cfg.paths.output_dir = str(tmp_path) 79 | cfg.paths.log_dir = str(tmp_path) 80 | 81 | yield cfg 82 | 83 | GlobalHydra.instance().clear() 84 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BioinfoMachineLearning/bio-diffusion/a328950c5d23ed4333df9a10830913450d9d71a9/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import platform 6 | 7 | import pkg_resources 8 | from pytorch_lightning.utilities.xla_device import XLADeviceUtils 9 | 10 | 11 | def _package_available(package_name: str) -> bool: 12 | """Check if a package is available in your environment.""" 13 | try: 14 | return pkg_resources.require(package_name) is not None 15 | except pkg_resources.DistributionNotFound: 16 | return False 17 | 18 | 19 | _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() 20 | 21 | _IS_WINDOWS = platform.system() == "Windows" 22 | 23 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 24 | 25 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 26 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 27 | 28 | _WANDB_AVAILABLE = _package_available("wandb") 29 | _NEPTUNE_AVAILABLE = _package_available("neptune") 30 | _COMET_AVAILABLE = _package_available("comet_ml") 31 | _MLFLOW_AVAILABLE = _package_available("mlflow") 32 | -------------------------------------------------------------------------------- /tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | """Adapted from: 6 | 7 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 8 | """ 9 | 10 | import sys 11 | from typing import Optional 12 | 13 | import pytest 14 | import torch 15 | from packaging.version import Version 16 | from pkg_resources import get_distribution 17 | 18 | from tests.helpers.package_available import ( 19 | _COMET_AVAILABLE, 20 | _DEEPSPEED_AVAILABLE, 21 | _FAIRSCALE_AVAILABLE, 22 | _IS_WINDOWS, 23 | 
_MLFLOW_AVAILABLE, 24 | _NEPTUNE_AVAILABLE, 25 | _SH_AVAILABLE, 26 | _TPU_AVAILABLE, 27 | _WANDB_AVAILABLE, 28 | ) 29 | 30 | 31 | class RunIf: 32 | """RunIf wrapper for conditional skipping of tests. 33 | 34 | Fully compatible with `@pytest.mark`. 35 | 36 | Example: 37 | 38 | @RunIf(min_torch="1.8") 39 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 40 | def test_wrapper(arg1): 41 | assert arg1 > 0 42 | """ 43 | 44 | def __new__( 45 | self, 46 | min_gpus: int = 0, 47 | min_torch: Optional[str] = None, 48 | max_torch: Optional[str] = None, 49 | min_python: Optional[str] = None, 50 | skip_windows: bool = False, 51 | sh: bool = False, 52 | tpu: bool = False, 53 | fairscale: bool = False, 54 | deepspeed: bool = False, 55 | wandb: bool = False, 56 | neptune: bool = False, 57 | comet: bool = False, 58 | mlflow: bool = False, 59 | **kwargs, 60 | ): 61 | """ 62 | Args: 63 | min_gpus: min number of GPUs required to run test 64 | min_torch: minimum pytorch version to run test 65 | max_torch: maximum pytorch version to run test 66 | min_python: minimum python version required to run test 67 | skip_windows: skip test for Windows platform 68 | tpu: if TPU is available 69 | sh: if `sh` module is required to run the test 70 | fairscale: if `fairscale` module is required to run the test 71 | deepspeed: if `deepspeed` module is required to run the test 72 | wandb: if `wandb` module is required to run the test 73 | neptune: if `neptune` module is required to run the test 74 | comet: if `comet` module is required to run the test 75 | mlflow: if `mlflow` module is required to run the test 76 | kwargs: native pytest.mark.skipif keyword arguments 77 | """ 78 | conditions = [] 79 | reasons = [] 80 | 81 | if min_gpus: 82 | conditions.append(torch.cuda.device_count() < min_gpus) 83 | reasons.append(f"GPUs>={min_gpus}") 84 | 85 | if min_torch: 86 | torch_version = get_distribution("torch").version 87 | conditions.append(Version(torch_version) < Version(min_torch)) 88 | reasons.append(f"torch>={min_torch}") 89 | 90 | if max_torch: 91 | torch_version = get_distribution("torch").version 92 | conditions.append(Version(torch_version) >= Version(max_torch)) 93 | reasons.append(f"torch<{max_torch}") 94 | 95 | if min_python: 96 | py_version = ( 97 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 98 | ) 99 | conditions.append(Version(py_version) < Version(min_python)) 100 | reasons.append(f"python>={min_python}") 101 | 102 | if skip_windows: 103 | conditions.append(_IS_WINDOWS) 104 | reasons.append("does not run on Windows") 105 | 106 | if tpu: 107 | conditions.append(not _TPU_AVAILABLE) 108 | reasons.append("TPU") 109 | 110 | if sh: 111 | conditions.append(not _SH_AVAILABLE) 112 | reasons.append("sh") 113 | 114 | if fairscale: 115 | conditions.append(not _FAIRSCALE_AVAILABLE) 116 | reasons.append("fairscale") 117 | 118 | if deepspeed: 119 | conditions.append(not _DEEPSPEED_AVAILABLE) 120 | reasons.append("deepspeed") 121 | 122 | if wandb: 123 | conditions.append(not _WANDB_AVAILABLE) 124 | reasons.append("wandb") 125 | 126 | if neptune: 127 | conditions.append(not _NEPTUNE_AVAILABLE) 128 | reasons.append("neptune") 129 | 130 | if comet: 131 | conditions.append(not _COMET_AVAILABLE) 132 | reasons.append("comet") 133 | 134 | if mlflow: 135 | conditions.append(not _MLFLOW_AVAILABLE) 136 | reasons.append("mlflow") 137 | 138 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 139 | return pytest.mark.skipif( 140 | condition=any(conditions), 141 | reason=f"Requires: [{' + 
'.join(reasons)}]", 142 | **kwargs, 143 | ) 144 | -------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | from typing import List 6 | 7 | import pytest 8 | 9 | from tests.helpers.package_available import _SH_AVAILABLE 10 | 11 | if _SH_AVAILABLE: 12 | import sh 13 | 14 | 15 | def run_sh_command(command: List[str]): 16 | """Default method for executing shell commands with pytest and sh package.""" 17 | msg = None 18 | try: 19 | sh.python(command) 20 | except sh.ErrorReturnCode as e: 21 | msg = e.stderr.decode() 22 | if msg: 23 | pytest.fail(msg=msg) 24 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import hydra 6 | from hydra.core.hydra_config import HydraConfig 7 | from omegaconf import DictConfig 8 | 9 | 10 | def test_train_config(cfg_train: DictConfig): 11 | assert cfg_train 12 | assert cfg_train.datamodule 13 | assert cfg_train.model 14 | assert cfg_train.trainer 15 | 16 | HydraConfig().set_config(cfg_train) 17 | 18 | hydra.utils.instantiate(cfg_train.datamodule) 19 | hydra.utils.instantiate(cfg_train.model) 20 | hydra.utils.instantiate(cfg_train.trainer) 21 | 22 | 23 | def test_eval_config(cfg_eval: DictConfig): 24 | assert cfg_eval 25 | assert cfg_eval.datamodule 26 | assert cfg_eval.model 27 | assert cfg_eval.trainer 28 | 29 | HydraConfig().set_config(cfg_eval) 30 | 31 | hydra.utils.instantiate(cfg_eval.datamodule) 32 | hydra.utils.instantiate(cfg_eval.model) 33 | hydra.utils.instantiate(cfg_eval.trainer) 34 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import os 6 | 7 | import pytest 8 | from hydra.core.hydra_config import HydraConfig 9 | from omegaconf import open_dict 10 | 11 | from src.mol_gen_eval import evaluate 12 | from src.train import train 13 | 14 | 15 | @pytest.mark.slow 16 | def test_train_eval(tmp_path, cfg_train, cfg_eval): 17 | """Train for 1 epoch with `train.py` and evaluate with `eval.py`""" 18 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir 19 | 20 | with 
open_dict(cfg_train): 21 | cfg_train.trainer.max_epochs = 1 22 | cfg_train.test = True 23 | 24 | HydraConfig().set_config(cfg_train) 25 | train_metric_dict, _ = train(cfg_train) 26 | 27 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") 28 | 29 | with open_dict(cfg_eval): 30 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 31 | 32 | HydraConfig().set_config(cfg_eval) 33 | test_metric_dict, _ = evaluate(cfg_eval) 34 | 35 | assert test_metric_dict["test/loss"] < 1e8 36 | assert abs(train_metric_dict["test/loss"].item() - test_metric_dict["test/loss"].item()) < 0.001 37 | -------------------------------------------------------------------------------- /tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import pytest 6 | 7 | from tests.helpers.run_if import RunIf 8 | from tests.helpers.run_sh_command import run_sh_command 9 | 10 | startfile = "src/train.py" 11 | overrides = ["logger=[]"] 12 | 13 | 14 | @RunIf(sh=True) 15 | @pytest.mark.slow 16 | def test_experiments(tmp_path): 17 | """Test running all available experiment configs with fast_dev_run=True.""" 18 | command = [ 19 | startfile, 20 | "-m", 21 | "experiment=glob(*)", 22 | "hydra.sweep.dir=" + str(tmp_path), 23 | "++trainer.fast_dev_run=true", 24 | ] + overrides 25 | run_sh_command(command) 26 | 27 | 28 | @RunIf(sh=True) 29 | @pytest.mark.slow 30 | def test_hydra_sweep(tmp_path): 31 | """Test default hydra sweep.""" 32 | command = [ 33 | startfile, 34 | "-m", 35 | "hydra.sweep.dir=" + str(tmp_path), 36 | "model.optimizer.lr=0.005,0.01", 37 | "++trainer.fast_dev_run=true", 38 | ] + overrides 39 | 40 | run_sh_command(command) 41 | 42 | 43 | @RunIf(sh=True) 44 | @pytest.mark.slow 45 | def test_hydra_sweep_ddp_sim(tmp_path): 46 | """Test default hydra sweep with ddp sim.""" 47 | command = [ 48 | startfile, 49 | "-m", 50 | "hydra.sweep.dir=" + str(tmp_path), 51 | "trainer=ddp_sim", 52 | "trainer.max_epochs=3", 53 | "+trainer.limit_train_batches=0.01", 54 | "+trainer.limit_val_batches=0.1", 55 | "+trainer.limit_test_batches=0.1", 56 | "model.optimizer.lr=0.005,0.01,0.02", 57 | ] + overrides 58 | run_sh_command(command) 59 | 60 | 61 | @RunIf(sh=True) 62 | @pytest.mark.slow 63 | def test_optuna_sweep(tmp_path): 64 | """Test optuna sweep.""" 65 | command = [ 66 | startfile, 67 | "-m", 68 | "hparams_search=qm9_optuna", 69 | "hydra.sweep.dir=" + str(tmp_path), 70 | "hydra.sweeper.n_trials=10", 71 | "hydra.sweeper.sampler.n_startup_trials=5", 72 | "++trainer.fast_dev_run=true", 73 | ] + overrides 74 | run_sh_command(command) 75 | 76 | 77 | @RunIf(wandb=True, sh=True) 78 | @pytest.mark.slow 79 | def test_optuna_sweep_ddp_sim_wandb(tmp_path): 80 | """Test optuna sweep with wandb and ddp sim.""" 81 | command = [ 82 | startfile, 83 | "-m", 84 | "hparams_search=qm9_optuna", 85 | "hydra.sweep.dir=" + str(tmp_path), 86 | "hydra.sweeper.n_trials=5", 87 | "trainer=ddp_sim", 88 | "trainer.max_epochs=3", 89 | "+trainer.limit_train_batches=0.01", 90 | "+trainer.limit_val_batches=0.1", 91 | "+trainer.limit_test_batches=0.1", 92 | "logger=wandb", 93 | ] 94 | run_sh_command(command) 95 | 
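For orientation, each sweep test above just assembles a list of Hydra command-line overrides and hands it to `run_sh_command`, which shells out to `src/train.py`. A minimal sketch of an analogous test for the GEOM search space follows (illustrative only and not part of the suite; the `hparams_search=geom_optuna` override and the trial count are assumptions):

@RunIf(sh=True)
@pytest.mark.slow
def test_optuna_sweep_geom(tmp_path):
    """Illustrative sketch: Optuna sweep over the (assumed) GEOM hyperparameter search space."""
    command = [
        startfile,                         # src/train.py, defined at the top of this module
        "-m",                              # Hydra multirun mode
        "hparams_search=geom_optuna",      # assumed config under configs/hparams_search/
        "hydra.sweep.dir=" + str(tmp_path),
        "hydra.sweeper.n_trials=2",
        "++trainer.fast_dev_run=true",
    ] + overrides
    run_sh_command(command)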
-------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import os 6 | 7 | import pytest 8 | from hydra.core.hydra_config import HydraConfig 9 | from omegaconf import open_dict 10 | 11 | from src.train import train 12 | from tests.helpers.run_if import RunIf 13 | 14 | 15 | def test_train_fast_dev_run(cfg_train): 16 | """Run for 1 train, val and test step.""" 17 | HydraConfig().set_config(cfg_train) 18 | with open_dict(cfg_train): 19 | cfg_train.trainer.fast_dev_run = True 20 | cfg_train.trainer.accelerator = "cpu" 21 | train(cfg_train) 22 | 23 | 24 | @RunIf(min_gpus=1) 25 | def test_train_fast_dev_run_gpu(cfg_train): 26 | """Run for 1 train, val and test step on GPU.""" 27 | HydraConfig().set_config(cfg_train) 28 | with open_dict(cfg_train): 29 | cfg_train.trainer.fast_dev_run = True 30 | cfg_train.trainer.accelerator = "gpu" 31 | train(cfg_train) 32 | 33 | 34 | @RunIf(min_gpus=1) 35 | @pytest.mark.slow 36 | def test_train_epoch_gpu_amp(cfg_train): 37 | """Train 1 epoch on GPU with mixed-precision.""" 38 | HydraConfig().set_config(cfg_train) 39 | with open_dict(cfg_train): 40 | cfg_train.trainer.max_epochs = 1 41 | cfg_train.trainer.accelerator = "gpu" 42 | cfg_train.trainer.precision = 16 43 | train(cfg_train) 44 | 45 | 46 | @pytest.mark.slow 47 | def test_train_epoch_double_val_loop(cfg_train): 48 | """Train 1 epoch with validation loop twice per epoch.""" 49 | HydraConfig().set_config(cfg_train) 50 | with open_dict(cfg_train): 51 | cfg_train.trainer.max_epochs = 1 52 | cfg_train.trainer.val_check_interval = 0.5 53 | train(cfg_train) 54 | 55 | 56 | @pytest.mark.slow 57 | def test_train_ddp_sim(cfg_train): 58 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.""" 59 | HydraConfig().set_config(cfg_train) 60 | with open_dict(cfg_train): 61 | cfg_train.trainer.max_epochs = 2 62 | cfg_train.trainer.accelerator = "cpu" 63 | cfg_train.trainer.devices = 2 64 | cfg_train.trainer.strategy = "ddp_spawn" 65 | train(cfg_train) 66 | 67 | 68 | @pytest.mark.slow 69 | def test_train_resume(tmp_path, cfg_train): 70 | """Run 1 epoch, finish, and resume for another epoch.""" 71 | with open_dict(cfg_train): 72 | cfg_train.trainer.max_epochs = 1 73 | 74 | HydraConfig().set_config(cfg_train) 75 | metric_dict_1, _ = train(cfg_train) 76 | 77 | files = os.listdir(tmp_path / "checkpoints") 78 | assert "last.ckpt" in files 79 | assert "epoch_000.ckpt" in files 80 | 81 | with open_dict(cfg_train): 82 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 83 | cfg_train.trainer.max_epochs = 2 84 | 85 | metric_dict_2, _ = train(cfg_train) 86 | 87 | files = os.listdir(tmp_path / "checkpoints") 88 | assert "epoch_001.ckpt" in files 89 | assert "epoch_002.ckpt" not in files 90 | 91 | assert metric_dict_1["train/loss"] > metric_dict_2["train/loss"] 92 | assert metric_dict_1["val/loss"] > metric_dict_2["val/loss"] 93 | --------------------------------------------------------------------------------
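As a closing illustration of how these tests drive training, the same fast-dev-run pattern can be reproduced outside pytest by composing the training config directly with Hydra. This is a minimal sketch under the assumptions that it is run from the repository root and that `configs/train.yaml` composes cleanly without further overrides:

# standalone_fast_dev_run.py -- illustrative sketch only, mirroring tests/test_train.py
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from omegaconf import open_dict

from src.train import train

with initialize(version_base="1.2", config_path="configs"):
    # compose the same top-level training config used by the test fixtures
    cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])

with open_dict(cfg):
    cfg.trainer.fast_dev_run = True  # run a single train/val/test batch, as in test_train_fast_dev_run
    cfg.trainer.accelerator = "cpu"
    cfg.logger = None                # skip experiment loggers for a quick smoke run

HydraConfig().set_config(cfg)
train(cfg)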