├── .env.example ├── .gitignore ├── .pre-commit-config.yaml ├── .project-root ├── LICENSE ├── Makefile ├── README.md ├── checkpoints └── .gitkeep ├── configs ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── ema.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── data │ └── pdb_na.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── experiment │ └── pdb_prot_na_gen_se3.yaml ├── extras │ └── default.yaml ├── hparams_search │ └── pdb_prot_gen_se3_module_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ ├── aim.yaml │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── diffusion_cfg │ │ └── pdb_prot_na_gen_se3_ddpm.yaml │ ├── model_cfg │ │ └── pdb_prot_na_gen_se3_model.yaml │ └── pdb_prot_na_gen_se3_module.yaml ├── paths │ ├── default.yaml │ └── pdb_metadata.yaml ├── sample.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── data └── .gitkeep ├── environment.yaml ├── eval_results ├── analyze_eval_results.py ├── analyze_grouped_eval_results.py └── collect_eval_results.py ├── forks ├── ProteinMPNN │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── ca_model_weights │ │ ├── v_48_002.pt │ │ ├── v_48_010.pt │ │ └── v_48_020.pt │ ├── colab_notebooks │ │ ├── README.md │ │ ├── ca_only_quickdemo.ipynb │ │ ├── quickdemo.ipynb │ │ └── quickdemo_wAF2.ipynb │ ├── examples │ │ ├── submit_example_1.sh │ │ ├── submit_example_2.sh │ │ ├── submit_example_3.sh │ │ ├── submit_example_3_score_only.sh │ │ ├── submit_example_3_score_only_from_fasta.sh │ │ ├── submit_example_4.sh │ │ ├── submit_example_4_non_fixed.sh │ │ ├── submit_example_5.sh │ │ ├── submit_example_6.sh │ │ ├── submit_example_7.sh │ │ ├── submit_example_8.sh │ │ └── submit_example_pssm.sh │ ├── helper_scripts │ │ ├── assign_fixed_chains.py │ │ ├── make_bias_AA.py │ │ ├── make_bias_per_res_dict.py │ │ ├── make_fixed_positions_dict.py │ │ ├── make_pos_neg_tied_positions_dict.py │ │ ├── make_pssm_input_dict.py │ │ ├── make_tied_positions_dict.py │ │ ├── other_tools │ │ │ ├── make_omit_AA.py │ │ │ └── make_pssm_dict.py │ │ ├── parse_multiple_chains.out │ │ ├── parse_multiple_chains.py │ │ └── parse_multiple_chains.sh │ ├── inputs │ │ └── PSSM_inputs │ │ │ ├── 3HTN.npz │ │ │ └── 4YOW.npz │ ├── outputs │ │ └── training_test_output │ │ │ └── seqs │ │ │ └── 5L33.fa │ ├── protein_mpnn_run.py │ ├── protein_mpnn_utils.py │ ├── soluble_model_weights │ │ ├── v_48_002.pt │ │ ├── v_48_010.pt │ │ ├── v_48_020.pt │ │ └── v_48_030.pt │ ├── training │ │ ├── LICENSE │ │ ├── README.md │ │ ├── colab_training_example.ipynb │ │ ├── exp_020 │ │ │ ├── log.txt │ │ │ └── model_weights │ │ │ │ └── epoch_last.pt │ │ ├── model_utils.py │ │ ├── parse_cif_noX.py │ │ ├── plot_training_results.ipynb │ │ ├── submit_exp_020.sh │ │ ├── test_inference.sh │ │ ├── training.py │ │ └── utils.py │ └── vanilla_model_weights │ │ ├── v_48_002.pt │ │ ├── v_48_010.pt │ │ ├── v_48_020.pt │ │ └── v_48_030.pt └── RoseTTAFold2NA │ ├── LICENSE │ ├── README.md │ ├── RF2na-linux.yml │ ├── SE3Transformer │ ├── Dockerfile │ ├── LICENSE │ ├── NOTICE │ ├── README.md │ ├── images │ │ └── se3-transformer.png │ ├── requirements.txt │ ├── scripts │ │ ├── benchmark_inference.sh │ │ ├── benchmark_train.sh │ │ ├── benchmark_train_multi_gpu.sh │ │ ├── predict.sh │ │ ├── 
train.sh │ │ └── train_multi_gpu.sh │ ├── se3_transformer │ │ ├── __init__.py │ │ ├── data_loading │ │ │ ├── __init__.py │ │ │ ├── data_module.py │ │ │ └── qm9.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── basis.py │ │ │ ├── fiber.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── convolution.py │ │ │ │ ├── linear.py │ │ │ │ ├── norm.py │ │ │ │ └── pooling.py │ │ │ └── transformer.py │ │ └── runtime │ │ │ ├── __init__.py │ │ │ ├── arguments.py │ │ │ ├── callbacks.py │ │ │ ├── gpu_affinity.py │ │ │ ├── inference.py │ │ │ ├── loggers.py │ │ │ ├── metrics.py │ │ │ ├── training.py │ │ │ └── utils.py │ ├── setup.py │ └── tests │ │ ├── __init__.py │ │ ├── test_equivariance.py │ │ └── utils.py │ ├── example │ ├── RNA.fa │ ├── dna_binding_protein.fa │ └── rna_binding_protein.fa │ ├── input_prep │ ├── make_pMSAs_prot_RNA.py │ ├── make_protein_msa.sh │ ├── make_rna_msa.sh │ └── reprocess_rnac.pl │ ├── network │ ├── Attention_module.py │ ├── AuxiliaryPredictor.py │ ├── Embeddings.py │ ├── RoseTTAFoldModel.py │ ├── SE3_network.py │ ├── Track_module.py │ ├── arguments.py │ ├── chemical.py │ ├── coords6d.py │ ├── data_loader.py │ ├── ffindex.py │ ├── kinematics.py │ ├── loss.py │ ├── models.json │ ├── parsers.py │ ├── predict.py │ ├── resnet.py │ ├── scheduler.py │ ├── scoring.py │ ├── util.py │ └── util_module.py │ ├── pdb100_2021Mar03 │ ├── pdb100_2021Mar03_pdb.ffdata │ └── pdb100_2021Mar03_pdb.ffindex │ └── run_RF2NA.sh ├── img └── Nucleic_Acid_Diffusion.gif ├── logs └── .gitkeep ├── metadata └── PDB_NA_Dataset.csv ├── notebooks ├── .gitkeep ├── analyze_dataset_characteristics.ipynb ├── analyze_dataset_characteristics.py ├── creating_protein_na_datasets_from_the_pdb.ipynb └── creating_protein_na_datasets_from_the_pdb.py ├── pyproject.toml ├── scripts └── schedule.sh ├── setup.py ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── pdb │ │ │ ├── all_atom.py │ │ │ ├── chemical.py │ │ │ ├── complex.py │ │ │ ├── complex_constants.py │ │ │ ├── data_transforms.py │ │ │ ├── data_utils.py │ │ │ ├── errors.py │ │ │ ├── join_pdb_metadata.py │ │ │ ├── mmcif_parsing.py │ │ │ ├── nucleotide_constants.py │ │ │ ├── parsers.py │ │ │ ├── parsing.py │ │ │ ├── pdb_na_dataset.py │ │ │ ├── process_pdb_mmcif_files.py │ │ │ ├── process_pdb_na_files.py │ │ │ ├── protein.py │ │ │ ├── protein_constants.py │ │ │ ├── relax │ │ │ ├── amber_minimize.py │ │ │ ├── cleanup.py │ │ │ └── relax_utils.py │ │ │ ├── rigid_utils.py │ │ │ ├── so3_utils.py │ │ │ ├── stereo_chemical_props.txt │ │ │ ├── validate_na_frames.py │ │ │ ├── validate_protein_frames.py │ │ │ └── vocabulary.py │ └── pdb_na_datamodule.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── pdb │ │ │ ├── analysis_utils.py │ │ │ ├── embedders.py │ │ │ ├── framediff.py │ │ │ ├── loss.py │ │ │ ├── metrics.py │ │ │ ├── se3_diffusion.py │ │ │ └── sequence_diffusion.py │ └── pdb_prot_na_gen_se3_module.py ├── sample.py ├── train.py └── utils │ ├── __init__.py │ ├── instantiators.py │ ├── logging_utils.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── helpers ├── __init__.py ├── package_available.py ├── run_if.py └── run_sh_command.py ├── test_configs.py ├── test_datamodules.py ├── test_eval.py ├── test_sweeps.py └── test_train.py /.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 
2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # MMDiff 148 | configs/local/default.yaml 149 | .env 150 | *.pdb 151 | *.csv 152 | eval_results*/ 153 | data/ 154 | generations/ 155 | *_outputs*/ 156 | logs/ 157 | scripts/ 158 | checkpoints/* 159 | !checkpoints/.gitkeep 160 | !metadata/ 161 | 162 | # RoseTTAFold2NA 163 | forks/RoseTTAFold2NA/RF2NA/ 164 | forks/RoseTTAFold2NA/network/RF2NA_apr23.tgz 165 | forks/RoseTTAFold2NA/network/weights/ 166 | forks/RoseTTAFold2NA/UniRef30_2020_06* 167 | forks/RoseTTAFold2NA/bfd* 168 | forks/RoseTTAFold2NA/pdb100* 169 | forks/RoseTTAFold2NA/RNA* 170 | forks/RoseTTAFold2NA/**/*.stderr 171 | forks/RoseTTAFold2NA/**/*.stdout 172 | forks/RoseTTAFold2NA/**/*.a3m 173 | forks/RoseTTAFold2NA/**/*.*tab 174 | forks/RoseTTAFold2NA/**/*.hhr 175 | forks/RoseTTAFold2NA/**/*.*fa 176 | forks/RoseTTAFold2NA/**/*.npz 177 | forks/RoseTTAFold2NA/**/*db* 178 | 179 | # Aim logging 180 | .aim 181 | 182 | # MLFlow logging 183 | mlruns/ 184 | 185 | # Conda/Mamba 186 | MMDiff*/ 187 | 188 | # Hydra 189 | .hydra 190 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | args: ["--maxkb=200000"] 20 | 21 | # python code formatting 22 | - repo: https://github.com/psf/black 23 | rev: 23.3.0 24 | hooks: 25 | - id: black 26 | args: [--line-length, "99"] 27 | 28 | # python import sorting 29 | - repo: https://github.com/PyCQA/isort 30 | rev: 5.12.0 31 | hooks: 32 | - id: isort 33 | args: ["--profile", "black", "--filter-files"] 34 | 35 | # python upgrading syntax to newer version 36 | - repo: https://github.com/asottile/pyupgrade 37 | rev: v3.4.0 38 | hooks: 39 | - id: pyupgrade 40 | args: [--py38-plus] 41 | 42 | # python docstring formatting 43 | - repo: https://github.com/myint/docformatter 44 | rev: v1.7.0 45 | hooks: 46 | - id: docformatter 47 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 48 | 49 | # python check (PEP8), programming errors and code complexity 50 | - repo: https://github.com/PyCQA/flake8 51 | rev: 6.0.0 52 | hooks: 53 | - id: flake8 54 | args: 55 | [ 56 | "--extend-ignore", 
57 | "E203,E402,E501,F401,F841", 58 | "--exclude", 59 | "logs/*,data/*", 60 | ] 61 | 62 | # python security linter 63 | - repo: https://github.com/PyCQA/bandit 64 | rev: "1.7.5" 65 | hooks: 66 | - id: bandit 67 | args: ["-s", "B101"] 68 | 69 | # yaml formatting 70 | - repo: https://github.com/pre-commit/mirrors-prettier 71 | rev: v3.0.0-alpha.9-for-vscode 72 | hooks: 73 | - id: prettier 74 | types: [yaml] 75 | exclude: "environment.yaml" 76 | 77 | # shell scripts linter 78 | - repo: https://github.com/shellcheck-py/shellcheck-py 79 | rev: v0.9.0.2 80 | hooks: 81 | - id: shellcheck 82 | 83 | # md formatting 84 | - repo: https://github.com/executablebooks/mdformat 85 | rev: 0.7.16 86 | hooks: 87 | - id: mdformat 88 | args: ["--number"] 89 | additional_dependencies: 90 | - mdformat-gfm 91 | - mdformat-tables 92 | - mdformat_frontmatter 93 | # - mdformat-toc 94 | # - mdformat-black 95 | 96 | # word spelling linter 97 | - repo: https://github.com/codespell-project/codespell 98 | rev: v2.2.4 99 | hooks: 100 | - id: codespell 101 | args: 102 | - --skip=logs/**,data/**,*.ipynb,*.csv 103 | - --ignore-words-list=ser,coo,nd 104 | 105 | # jupyter notebook cell output clearing 106 | - repo: https://github.com/kynan/nbstripout 107 | rev: 0.6.1 108 | hooks: 109 | - id: nbstripout 110 | 111 | # jupyter notebook linting 112 | - repo: https://github.com/nbQA-dev/nbQA 113 | rev: 1.7.0 114 | hooks: 115 | - id: nbqa-black 116 | args: ["--line-length=99"] 117 | - id: nbqa-isort 118 | args: ["--profile=black"] 119 | - id: nbqa-flake8 120 | args: 121 | [ 122 | "--extend-ignore=E203,E402,E501,F401,F841", 123 | "--exclude=logs/*,data/*", 124 | ] 125 | 126 | exclude: "forks/|checkpoints/" 127 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Profluent Bio 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py 31 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/checkpoints/.gitkeep -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | 7 | model_checkpoint: 8 | dirpath: ${paths.output_dir}/checkpoints 9 | filename: "epoch_{epoch:03d}" 10 | monitor: "val/na_num_c4_prime_steric_clashes" 11 | mode: "min" 12 | save_last: True 13 | auto_insert_metric_name: False 14 | 15 | model_summary: 16 | max_depth: -1 17 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: ??? # quantity to be monitored, must be specified !!! 6 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 7 | patience: 3 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /configs/callbacks/ema.yaml: -------------------------------------------------------------------------------- 1 | # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py 2 | 3 | # Maintains an exponential moving average (EMA) of model weights. 4 | # Look at the above link for more detailed information regarding the original implementation. 5 | ema: 6 | _target_: src.models.EMA 7 | decay: 0.9999 8 | apply_ema_every_n_steps: 1 9 | start_step: 0 10 | save_ema_weights_in_callback_state: true 11 | evaluate_ema_weights_instead: true 12 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: null # directory to save the model file 6 | filename: null # checkpoint filename 7 | monitor: null # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/data/pdb_na.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.pdb_na_datamodule.PDBNADataModule 2 | data_cfg: 3 | # CSV for path and metadata to training examples 4 | csv_path: ${paths.data_dir}/PDB-NA/processed/metadata.csv 5 | annot_path: ${paths.root_dir}/metadata/PDB_NA_Dataset.csv 6 | cluster_path: ${paths.data_dir}/PDB-NA/processed/clusters-by-entity-30.txt 7 | cluster_examples_by_structure: true # Note: Corresponds to using qTMclust structure-based clustering to select training examples 8 | qtmclust_exec_path: ${paths.qtmclust_exec_path} 9 | filtering: 10 | max_len: 768 11 | min_len: 10 12 | # Select a subset of examples, which could be useful for debugging 13 | subset: null 14 | mmcif_allowed_oligomer: [monomeric] # Note: Corresponds to filtering complexes originating from mmCIF files 15 | pdb_allowed_oligomer: null # Note: Corresponds to filtering complexes originating from PDB files 16 | max_helix_percent: 1.0 17 | max_loop_percent: 0.5 18 | min_beta_percent: -1.0 19 | rog_quantile: 0.96 20 | # Specify which types of molecules to keep in the dataset 21 | allowed_molecule_types: [protein, na] # Note: Value must be in `[[protein, na], [protein], [na], null]` 22 | # As a cross-validation holdout, remove examples containing proteins belonging to a subset of Pfam protein families 23 | holdout: null 24 | min_t: ${model.diffusion_cfg.min_timestep} 25 | samples_per_eval_length: 4 26 | num_eval_lengths: 10 27 | num_t: ${model.diffusion_cfg.num_timesteps} 28 | max_squared_res: 600000 29 | batch_size: 256 30 | eval_batch_size: ${.samples_per_eval_length} 31 | sample_mode: time_batch # Note: Must be in [`time_batch`, `length_batch`, `cluster_time_batch`, `cluster_length_batch`] 32 | num_workers: 5 33 | prefetch_factor: 100 34 | # Sequence diffusion arguments 35 | diffuse_sequence: ${model.diffusion_cfg.diffuse_sequence} 36 | num_sequence_t: ${model.diffusion_cfg.num_sequence_timesteps} 37 | sequence_noise_schedule: ${model.diffusion_cfg.sequence_noise_schedule} # note: must be a value in [`linear`, `cosine`, `sqrt`] 38 | sequence_sample_distribution: ${model.diffusion_cfg.sequence_sample_distribution} # note: if value is not `normal`, then, instead of a Gaussian distribution, a Gaussian mixture model (GMM) is used as the sequence noising function 39 | sequence_sample_distribution_gmm_means: ${model.diffusion_cfg.sequence_sample_distribution_gmm_means} 40 | sequence_sample_distribution_gmm_variances: ${model.diffusion_cfg.sequence_sample_distribution_gmm_variances} 41 | -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other 
debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/experiment/pdb_prot_na_gen_se3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=pdb_prot_na_gen_se3 5 | 6 | defaults: 7 | - override /data: pdb_na.yaml 8 | - override /model: pdb_prot_na_gen_se3_module.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["pdb", "prot_na_gen", "se3"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 100 21 | max_epochs: 1000 22 | 23 | callbacks: 24 | model_checkpoint: 25 | monitor: "val/na_num_c4_prime_steric_clashes" 26 | mode: "min" 27 | 28 | model: 29 | 
optimizer: 30 | lr: 0.0001 31 | 32 | data: 33 | data_cfg: 34 | batch_size: 256 35 | 36 | logger: 37 | # mlflow: 38 | # _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 39 | # experiment_name: MMDiff 40 | # run_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}_PDBProtGenSE3 41 | # # tags: ${tags} 42 | # # save_dir: "./mlruns" 43 | # # log_model: false 44 | # prefix: "" 45 | # artifact_location: null 46 | # # run_id: "" 47 | wandb: 48 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 49 | # name: "" # name of the run (normally generated by wandb) 50 | save_dir: "${paths.output_dir}" 51 | offline: False 52 | id: null # pass correct id to resume experiment! 53 | anonymous: null # enable anonymous logging 54 | project: "MMDiff" 55 | log_model: False # upload lightning ckpts 56 | prefix: "" # a string to put at the beginning of metric keys 57 | # entity: "" # set to name of your wandb team 58 | group: "" 59 | tags: [] 60 | job_type: "" 61 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/pdb_prot_gen_se3_module_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/na_num_c4_prime_steric_clashes" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: minimize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(8, 16, 32, 64) 50 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 
22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "MMDiff" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | experiment_name: MMDiff 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | # log_model: false 11 | prefix: "" 12 | artifact_location: null 13 | # run_id: "" 14 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/MMDiff 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: 
lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "MMDiff" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/diffusion_cfg/pdb_prot_na_gen_se3_ddpm.yaml: -------------------------------------------------------------------------------- 1 | # general diffusion arguments 2 | ddpm_mode: se3_unconditional # [se3_unconditional] 3 | diffusion_network: se3_score # [se3_score] 4 | dynamics_network: framediff # [framediff] 5 | eval_epochs: 10 6 | 7 | # SE(3) diffusion arguments 8 | diffuse_rotations: true 9 | diffuse_translations: true 10 | min_timestep: 0.01 11 | num_timesteps: 100 12 | 13 | # R(3) diffusion arguments 14 | r3: 15 | min_b: 0.1 16 | max_b: 20.0 17 | coordinate_scaling: 0.1 18 | 19 | # SO(3) diffusion arguments 20 | so3: 21 | num_omega: 1000 22 | num_sigma: 1000 23 | min_sigma: 0.1 24 | max_sigma: 1.5 25 | schedule: logarithmic 26 | cache_dir: .cache/ 27 | use_cached_score: false 28 | 29 | # sequence diffusion arguments 30 | diffuse_sequence: true 31 | num_sequence_timesteps: ${.num_timesteps} 32 | sequence_noise_schedule: sqrt # note: must be a value in [`linear`, `cosine`, `sqrt`] 33 | sequence_sample_distribution: normal # note: if value is not `normal`, then, instead of a Gaussian distribution, a Gaussian mixture model (GMM) is used as the sequence noising function 34 | sequence_sample_distribution_gmm_means: [-1.0, 1.0] 35 | sequence_sample_distribution_gmm_variances: [1.0, 1.0] 36 | 37 | # sampling arguments 38 | sampling: 39 | sequence_noise_scale: 1.0 40 | structure_noise_scale: 1.0 41 | apply_na_consensus_sampling: false 42 | -------------------------------------------------------------------------------- /configs/model/model_cfg/pdb_prot_na_gen_se3_model.yaml: -------------------------------------------------------------------------------- 1 | node_input_dim: 1 2 | edge_input_dim: 2 3 | existing_edge_embedding_dim: 0 4 | 5 | node_hidden_dim: 256 6 | edge_hidden_dim: 128 7 | num_layers: 5 8 | dropout: 0.0 9 | 10 | c_skip: 64 11 | num_angles: ${subtract:${resolve_variable:src.data.components.pdb.complex_constants.NUM_PROT_NA_TORSIONS},2} # note: must be `1` for protein design tasks and `8` for either nucleic acid or protein-nucleic acid design tasks 12 | 13 | clip_gradients: false 14 | log_grad_flow_steps: 3000 # after how many steps to log gradient flow 15 | 16 | embedding: 17 | index_embedding_size: 128 18 | max_relative_idx: 32 19 | max_relative_chain: 2 20 | use_chain_relative: true 21 | molecule_type_embedding_size: 4 22 | molecule_type_embedded_size: 128 23 | embed_molecule_type_conditioning: True 24 | embed_self_conditioning: True 25 | num_bins: 22 26 | min_bin: 1e-5 27 | max_bin: 20.0 28 | 29 | ipa: 30 | c_s: ${..node_hidden_dim} 31 | c_z: ${..edge_hidden_dim} 32 | c_hidden: 256 33 | c_resnet: 128 34 | num_resnet_blocks: 2 35 | num_heads: 8 36 | num_qk_points: 8 37 | num_v_points: 12 38 | coordinate_scaling: ${...diffusion_cfg.r3.coordinate_scaling} 39 | inf: 1e5 40 | epsilon: 1e-8 41 | 42 | tfmr: 43 | num_heads: 4 44 | num_layers: 2 45 | 46 | loss_cfg: 47 | trans_loss_weight: 1.0 48 | 
rot_loss_weight: 0.5 49 | rot_loss_t_threshold: 0.2 50 | separate_rot_loss: False 51 | trans_x0_threshold: 1.0 52 | bb_atom_loss_weight: 1.0 53 | bb_atom_loss_t_filter: 0.25 54 | dist_mat_loss_weight: 1.0 55 | dist_mat_loss_t_filter: 0.25 56 | interface_dist_mat_loss_weight: 1.0 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_interfaces}` is `true` 57 | interface_dist_mat_loss_t_filter: 0.25 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_interfaces}` is `true` 58 | aux_loss_weight: 0.25 59 | torsion_loss_weight: 1.0 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_torsion_angles}` is `true` 60 | torsion_loss_t_filter: 0.25 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_torsion_angles}` is `true` 61 | torsion_norm_loss_weight: 0.02 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_torsion_angles}` is `true` 62 | cce_seq_loss_weight: 1.0 # note: will only be used if `${model.diffusion_cfg.diffuse_sequence}` is `true` 63 | kl_seq_loss_weight: 0.25 # note: will only be used if `${model.diffusion_cfg.diffuse_sequence}` is `true` 64 | supervise_n1_atom_positions: true # note: if `true`, for pyrimidine residues will instead supervise predicted N9 atom positions using ground-truth N1 atom positions 65 | supervise_interfaces: true # note: if `true`, will supervise the model's predicted pairwise distances specifically for interfacing residues 66 | supervise_torsion_angles: false # note: if `false`, will supervise the model's predicted torsion angles indirectly using predicted coordinates 67 | coordinate_scaling: ${...diffusion_cfg.r3.coordinate_scaling} 68 | -------------------------------------------------------------------------------- /configs/model/pdb_prot_na_gen_se3_module.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.pdb_prot_na_gen_se3_module.PDBProtNAGenSE3LitModule 2 | 3 | optimizer: 4 | _target_: torch.optim.Adam 5 | _partial_: true 6 | lr: 0.0001 7 | 8 | scheduler: 9 | # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 10 | # _partial_: true 11 | # mode: min 12 | # factor: 0.1 13 | # patience: 10 14 | 15 | defaults: 16 | - model_cfg: pdb_prot_na_gen_se3_model.yaml 17 | - diffusion_cfg: pdb_prot_na_gen_se3_ddpm.yaml 18 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." 
if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | 20 | # paths to local executables 21 | rf2na_exec_path: ${paths.root_dir}/forks/RoseTTAFold2NA/run_RF2NA.sh 22 | proteinmpnn_dir: ${paths.root_dir}/forks/ProteinMPNN/ 23 | usalign_exec_path: ~/Programs/USalign/USalign # note: must be an absolute path during runtime 24 | qtmclust_exec_path: ~/Programs/USalign/qTMclust # note: must be an absolute path during runtime 25 | -------------------------------------------------------------------------------- /configs/paths/pdb_metadata.yaml: -------------------------------------------------------------------------------- 1 | na_metadata_csv_path: ${paths.data_dir}/PDB-NA/processed_pdb/na_metadata.csv 2 | protein_metadata_csv_path: ${paths.data_dir}/PDB-NA/processed_pdb/protein_metadata.csv 3 | metadata_output_csv_path: ${paths.data_dir}/PDB-NA/processed_pdb/metadata.csv 4 | -------------------------------------------------------------------------------- /configs/sample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: pdb_na.yaml # choose datamodule with `predict_dataloader()` for inference 6 | - model: pdb_prot_na_gen_se3_module.yaml 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | task_name: "sample" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for sampling 18 | ckpt_path: checkpoints/protein_na_sequence_structure_g42jpyug_rotations_epoch_286.ckpt 19 | 20 | # establishing inference arguments 21 | inference: 22 | name: protein_na_sequence_structure_se3_discrete_diffusion_sampling_${now:%Y-%m-%d}_${now:%H:%M:%S} 23 | seed: 123 24 | run_statified_eval: false # note: if `true`, will instead use the `validation` dataset's range of examples for sampling evaluation; if `false`, (large) multi-state PDB trajectories will be recorded 25 | filter_eval_split: false # note: if `true`, will use `samples.min/max_length` and `samples.min/max_num_chains` to filter out examples from the evaluation dataset 26 | # whether to compute evaluation metrics for generated samples e.g., using RoseTTAFold2NA 27 | # warning: `self_consistency` and `novelty` will be time and memory-intensive metrics to compute 28 | run_self_consistency_eval: false 29 | run_diversity_eval: false 30 | run_novelty_eval: false 31 | use_rf2na_single_sequence_mode: true # note: trades prediction accuracy for time complexity 32 | generate_protein_sequences_using_pmpnn: false # note: should only be `true` if generating protein-only samples and instead using ProteinMPNN for backbone sequence design 33 | measure_auxiliary_na_metrics: false # note: should only be `true` if generating nucleic acid-only samples 34 | 35 | # output directory for samples 36 | output_dir: ./inference_outputs/ 37 | 38 | diffusion: 39 | # number of diffusion steps for sampling 40 | num_t: 500 41 | # note: analogous to sampling 
temperature 42 | sequence_noise_scale: 1.0 43 | structure_noise_scale: 0.1 44 | # final diffusion step `t` for sampling 45 | min_t: 0.01 46 | # whether to apply a 50% majority rule that transforms all generated nucleotide residue types to be exclusively of DNA or RNA types 47 | apply_na_consensus_sampling: true 48 | # whether to employ a random diffusion baseline for sequence-structure generation 49 | employ_random_baseline: false 50 | # note: the following are for overriding sequence diffusion arguments 51 | diffuse_sequence: true 52 | num_sequence_timesteps: ${.num_t} 53 | sequence_noise_schedule: sqrt # note: must be a value in [`linear`, `cosine`, `sqrt`], where `linear` typically yields the best looking structures; `cosine` often yields the most designable structures; and `sqrt` is what is used during training 54 | sequence_sample_distribution: normal # note: if value is not `normal`, then, instead of a Gaussian distribution, a Gaussian mixture model (GMM) is used as the sequence noising function 55 | sequence_sample_distribution_gmm_means: [-1.0, 1.0] 56 | sequence_sample_distribution_gmm_variances: [1.0, 1.0] 57 | 58 | samples: 59 | # number of backbone structures and sequences to sample and score per sequence length 60 | samples_per_length: 30 61 | # minimum sequence length to sample 62 | min_length: 10 63 | # maximum sequence length to sample 64 | max_length: 50 65 | # note: `num_length_steps` will only be used if `inference.run_statified_eval` is `true` 66 | num_length_steps: 10 67 | # note: `min_num_chains` will only be used if `inference.run_statified_eval` is `true` 68 | min_num_chains: 1 69 | # note: `max_num_chains` will only be used if `inference.run_statified_eval` is `true` 70 | max_num_chains: 4 71 | # gap between lengths to sample 72 | # (note: this script will sample all lengths 73 | # in range(min_length, max_length, length_step)) 74 | length_step: 10 75 | # a syntactic specification mapping the standardized molecule types 76 | # (`A`: amino acid, `D`: deoxyribonucleic acid, `R`: ribonucleic acid) to each residue index; 77 | # note: the sum of each string's residue index annotations must sum to the 78 | # current length specified by the interval [`min_length`, `max_length`, `length_step`]: 79 | # e.g., `residue_molecule_type_mappings: ['R:100', 'A:75,D:75']` 80 | residue_molecule_type_mappings: 81 | ["R:10", "D:20", "R:10,D:20", "R:20,R:20", "A:40,R:10"] 82 | # residue_molecule_type_mappings: ["R:90", "R:40,A:60"] # e.g., for protein-nucleic acid generation 83 | # a syntactic specification mapping to each residue index 84 | # one of the PDB's 62 alphanumeric chain identifiers: 85 | # (i.e., one of `ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789`); 86 | # note: the sum of each string's residue index annotations must sum to the 87 | # current length specified by the interval [`min_length`, `max_length`, `length_step`]: 88 | # e.g., `residue_chain_mappings: ['a:50,b:40', 'a:50,b:50']` 89 | residue_chain_mappings: 90 | ["a:10", "a:20", "a:10,b:20", "a:20,b:20", "a:40,b:10"] 91 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: pdb_na.yaml 8 | - model: pdb_prot_na_gen_se3_module.yaml 9 | - callbacks: default.yaml 10 | - logger: null # set 
logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default.yaml 12 | - paths: default.yaml 13 | - extras: default.yaml 14 | - hydra: default.yaml 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | tags: ["dev"] 37 | 38 | # set False to skip model training 39 | train: True 40 | 41 | # evaluate on test set, using best model weights achieved during training 42 | # lightning chooses best weights based on the metric specified in checkpoint callback 43 | test: True 44 | 45 | # compile model for faster training with pytorch 2.0 46 | compile: False 47 | 48 | # simply provide checkpoint path to resume training 49 | ckpt_path: null 50 | 51 | # seed for random number generators in pytorch, numpy and python.random 52 | seed: null 53 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: 4 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: gpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # number of sanity-check validation forward passes to run prior to model training 15 | num_sanity_val_steps: 0 16 | 17 | # perform a validation loop every N training epochs 18 | check_val_every_n_epoch: null 19 | val_check_interval: 10000 20 | # note: when `check_val_every_n_epoch` is `null`, 21 | # Lightning will require you to provide an integer 22 | # value for `val_check_interval` to instead perform a 23 | # validation epoch every `val_check_interval` steps 24 | 25 | # gradient accumulation to simulate larger-than-GPU-memory batch sizes 26 | accumulate_grad_batches: 1 
27 | 28 | # set True to to ensure deterministic results 29 | # makes training slower but gives more reproducibility than just setting seeds 30 | deterministic: False 31 | 32 | # track and log the vector norm of each gradient 33 | # track_grad_norm: 2.0 34 | 35 | # profile code comprehensively 36 | profiler: 37 | # _target_: pytorch_lightning.profilers.PyTorchProfiler 38 | 39 | # inform Lightning that we will be supplying a custom `Sampler` in our `DataModule` class 40 | use_distributed_sampler: False 41 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/data/.gitkeep -------------------------------------------------------------------------------- /eval_results/collect_eval_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | 6 | def main(eval_outputs_dir: str, eval_results_dir: str): 7 | for item in os.listdir(eval_outputs_dir): 8 | eval_outputs_subdir = os.path.join(eval_outputs_dir, item) 9 | if os.path.isdir(eval_outputs_subdir): 10 | all_rank_predictions_csv_path = os.path.join( 11 | eval_outputs_subdir, "all_rank_predictions.csv" 12 | ) 13 | if os.path.exists(all_rank_predictions_csv_path): 14 | eval_results_run_name = Path(eval_outputs_subdir).stem 15 | eval_results_run_type = "_".join( 16 | [ 17 | s.replace("naive", "rb") 18 | for s in eval_results_run_name.split("_se3_discrete_diffusion_stratified")[ 19 | 0 20 | ].split("_") 21 | ] 22 | ) 23 | eval_results_run_category = "_".join( 24 | eval_results_run_name.split("_se3_discrete_diffusion_stratified")[1] 25 | .split("eval")[0] 26 | .split("_") 27 | ) 28 | eval_results_csv_name = os.path.join( 29 | eval_results_dir, 30 | f"{eval_results_run_type}{eval_results_run_category}all_rank_predictions.csv", 31 | ) 32 | shutil.copyfile(all_rank_predictions_csv_path, eval_results_csv_name) 33 | 34 | 35 | if __name__ == "__main__": 36 | main(eval_outputs_dir="inference_eval_outputs/", eval_results_dir="eval_results/") 37 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to 
inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Justas Dauparas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/ca_model_weights/v_48_002.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/ca_model_weights/v_48_002.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/ca_model_weights/v_48_010.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/ca_model_weights/v_48_010.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/ca_model_weights/v_48_020.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/ca_model_weights/v_48_020.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/colab_notebooks/README.md: -------------------------------------------------------------------------------- 1 | Open In Colab 2 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 2 6 | #SBATCH --output=example_1.out 7 | 8 | source activate mlfold 9 | 10 | folder_with_pdbs="../inputs/PDB_monomers/pdbs/" 11 | 12 | output_dir="../outputs/example_1_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 19 | 20 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 21 | 22 | python ../protein_mpnn_run.py \ 23 | --jsonl_path $path_for_parsed_chains \ 24 | --out_folder $output_dir \ 25 | --num_seq_per_target 2 \ 26 | --sampling_temp "0.1" \ 27 | --seed 37 \ 28 | --batch_size 1 29 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 2 6 | #SBATCH --output=example_2.out 7 | 8 | source activate mlfold 9 | 10 | 11 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/" 12 | 13 | output_dir="../outputs/example_2_outputs" 14 | if [ ! 
-d $output_dir ] 15 | then 16 | mkdir -p $output_dir 17 | fi 18 | 19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl" 21 | chains_to_design="A B" 22 | 23 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 24 | 25 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design" 26 | 27 | python ../protein_mpnn_run.py \ 28 | --jsonl_path $path_for_parsed_chains \ 29 | --chain_id_jsonl $path_for_assigned_chains \ 30 | --out_folder $output_dir \ 31 | --num_seq_per_target 2 \ 32 | --sampling_temp "0.1" \ 33 | --seed 37 \ 34 | --batch_size 1 35 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_3.out 7 | 8 | source activate mlfold 9 | 10 | path_to_PDB="../inputs/PDB_complexes/pdbs/3HTN.pdb" 11 | 12 | output_dir="../outputs/example_3_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | chains_to_design="A B" 19 | 20 | python ../protein_mpnn_run.py \ 21 | --pdb_path $path_to_PDB \ 22 | --pdb_path_chains "$chains_to_design" \ 23 | --out_folder $output_dir \ 24 | --num_seq_per_target 2 \ 25 | --sampling_temp "0.1" \ 26 | --seed 37 \ 27 | --batch_size 1 28 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_3_score_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_3.out 7 | 8 | source activate mlfold 9 | 10 | path_to_PDB="../inputs/PDB_complexes/pdbs/3HTN.pdb" 11 | 12 | output_dir="../outputs/example_3_score_only_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | chains_to_design="A B" 19 | 20 | python ../protein_mpnn_run.py \ 21 | --pdb_path $path_to_PDB \ 22 | --pdb_path_chains "$chains_to_design" \ 23 | --out_folder $output_dir \ 24 | --num_seq_per_target 10 \ 25 | --sampling_temp "0.1" \ 26 | --score_only 1 \ 27 | --seed 37 \ 28 | --batch_size 1 29 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_3_score_only_from_fasta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_3_from_fasta.out 7 | 8 | source activate mlfold 9 | 10 | path_to_PDB="../inputs/PDB_complexes/pdbs/3HTN.pdb" 11 | path_to_fasta="/home/justas/projects/github/ProteinMPNN/outputs/example_3_outputs/seqs/3HTN.fa" 12 | 13 | output_dir="../outputs/example_3_score_only_from_fasta_outputs" 14 | if [ ! 
-d $output_dir ] 15 | then 16 | mkdir -p $output_dir 17 | fi 18 | 19 | chains_to_design="A B" 20 | 21 | python ../protein_mpnn_run.py \ 22 | --path_to_fasta $path_to_fasta \ 23 | --pdb_path $path_to_PDB \ 24 | --pdb_path_chains "$chains_to_design" \ 25 | --out_folder $output_dir \ 26 | --num_seq_per_target 5 \ 27 | --sampling_temp "0.1" \ 28 | --score_only 1 \ 29 | --seed 13 \ 30 | --batch_size 1 31 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_4.out 7 | 8 | source activate mlfold 9 | 10 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/" 11 | 12 | output_dir="../outputs/example_4_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | 19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl" 21 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl" 22 | chains_to_design="A C" 23 | #The first amino acid in the chain corresponds to 1 and not PDB residues index for now. 24 | fixed_positions="1 2 3 4 5 6 7 8 23 25, 10 11 12 13 14 15 16 17 18 19 20 40" #fixing/not designing residues 1 2 3...25 in chain A and residues 10 11 12...40 in chain C 25 | 26 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 27 | 28 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design" 29 | 30 | python ../helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$fixed_positions" 31 | 32 | python ../protein_mpnn_run.py \ 33 | --jsonl_path $path_for_parsed_chains \ 34 | --chain_id_jsonl $path_for_assigned_chains \ 35 | --fixed_positions_jsonl $path_for_fixed_positions \ 36 | --out_folder $output_dir \ 37 | --num_seq_per_target 2 \ 38 | --sampling_temp "0.1" \ 39 | --seed 37 \ 40 | --batch_size 1 41 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_4_non_fixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_4_non_fixed.out 7 | 8 | source activate mlfold 9 | 10 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/" 11 | 12 | output_dir="../outputs/example_4_non_fixed_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | 19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl" 21 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl" 22 | chains_to_design="A C" 23 | #The first amino acid in the chain corresponds to 1 and not PDB residues index for now. 
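# note: the position lists are 1-indexed and comma-separated per chain, so the first list below applies to chain A and the second to chain C (matching chains_to_design above)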
24 | design_only_positions="1 2 3 4 5 6 7 8 9 10, 3 4 5 6 7 8" #design only these residues; use flag --specify_non_fixed 25 | 26 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 27 | 28 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design" 29 | 30 | python ../helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$design_only_positions" --specify_non_fixed 31 | 32 | python ../protein_mpnn_run.py \ 33 | --jsonl_path $path_for_parsed_chains \ 34 | --chain_id_jsonl $path_for_assigned_chains \ 35 | --fixed_positions_jsonl $path_for_fixed_positions \ 36 | --out_folder $output_dir \ 37 | --num_seq_per_target 2 \ 38 | --sampling_temp "0.1" \ 39 | --seed 37 \ 40 | --batch_size 1 41 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_5.out 7 | 8 | source activate mlfold 9 | 10 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/" 11 | 12 | output_dir="../outputs/example_5_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | 19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl" 21 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl" 22 | path_for_tied_positions=$output_dir"/tied_pdbs.jsonl" 23 | chains_to_design="A C" 24 | fixed_positions="9 10 11 12 13 14 15 16 17 18 19 20 21 22 23, 10 11 18 19 20 22" 25 | tied_positions="1 2 3 4 5 6 7 8, 1 2 3 4 5 6 7 8" #two list must match in length; residue 1 in chain A and C will be sampled togther; 26 | 27 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 28 | 29 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design" 30 | 31 | python ../helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$fixed_positions" 32 | 33 | python ../helper_scripts/make_tied_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_tied_positions --chain_list "$chains_to_design" --position_list "$tied_positions" 34 | 35 | python ../protein_mpnn_run.py \ 36 | --jsonl_path $path_for_parsed_chains \ 37 | --chain_id_jsonl $path_for_assigned_chains \ 38 | --fixed_positions_jsonl $path_for_fixed_positions \ 39 | --tied_positions_jsonl $path_for_tied_positions \ 40 | --out_folder $output_dir \ 41 | --num_seq_per_target 2 \ 42 | --sampling_temp "0.1" \ 43 | --seed 37 \ 44 | --batch_size 1 45 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_6.out 7 | 8 | source activate mlfold 9 | 10 | 
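# homooligomer example: make_tied_positions_dict.py is called below with --homooligomer 1, which ties equivalent positions across all chains so identical residues are sampled together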
folder_with_pdbs="../inputs/PDB_homooligomers/pdbs/" 11 | 12 | output_dir="../outputs/example_6_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | 19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 20 | path_for_tied_positions=$output_dir"/tied_pdbs.jsonl" 21 | path_for_designed_sequences=$output_dir"/temp_0.1" 22 | 23 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 24 | 25 | python ../helper_scripts/make_tied_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_tied_positions --homooligomer 1 26 | 27 | python ../protein_mpnn_run.py \ 28 | --jsonl_path $path_for_parsed_chains \ 29 | --tied_positions_jsonl $path_for_tied_positions \ 30 | --out_folder $output_dir \ 31 | --num_seq_per_target 2 \ 32 | --sampling_temp "0.2" \ 33 | --seed 37 \ 34 | --batch_size 1 35 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_7.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 2 6 | #SBATCH --output=example_7.out 7 | 8 | source activate mlfold 9 | 10 | folder_with_pdbs="../inputs/PDB_monomers/pdbs/" 11 | 12 | output_dir="../outputs/example_7_outputs" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 19 | 20 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 21 | 22 | python ../protein_mpnn_run.py \ 23 | --jsonl_path $path_for_parsed_chains \ 24 | --out_folder $output_dir \ 25 | --num_seq_per_target 1 \ 26 | --sampling_temp "0.1" \ 27 | --unconditional_probs_only 1 \ 28 | --seed 37 \ 29 | --batch_size 1 30 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 2 6 | #SBATCH --output=example_8.out 7 | 8 | source activate mlfold 9 | 10 | folder_with_pdbs="../inputs/PDB_monomers/pdbs/" 11 | 12 | output_dir="../outputs/example_8_outputs" 13 | if [ ! 
-d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | path_for_bias=$output_dir"/bias_pdbs.jsonl" 19 | #Adding global polar amino acid bias (Doug Tischer) 20 | AA_list="D E H K N Q R S T W Y" 21 | bias_list="1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39" 22 | python ../helper_scripts/make_bias_AA.py --output_path=$path_for_bias --AA_list="$AA_list" --bias_list="$bias_list" 23 | 24 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 25 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 26 | 27 | python ../protein_mpnn_run.py \ 28 | --jsonl_path $path_for_parsed_chains \ 29 | --out_folder $output_dir \ 30 | --bias_AA_jsonl $path_for_bias \ 31 | --num_seq_per_target 2 \ 32 | --sampling_temp "0.1" \ 33 | --seed 37 \ 34 | --batch_size 1 35 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/examples/submit_example_pssm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 2 6 | #SBATCH --output=example_2.out 7 | 8 | source activate mlfold 9 | 10 | 11 | #new_probabilities_using_PSSM = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered 12 | #probs - predictions from MPNN 13 | #pssm_bias_gathered - input PSSM bias (needs to be a probability distribution) 14 | #pssm_multi - a number between 0.0 (no bias) and 1.0 (no MPNN) inputed via flag --pssm_multi; this is a global number equally applied to all the residues 15 | #pssm_coef_gathered - a number between 0.0 (no bias) and 1.0 (no MPNN) inputed via ../helper_scripts/make_pssm_input_dict.py can be adjusted per residue level; i.e only apply PSSM bias to specific residues; or chains 16 | 17 | 18 | 19 | pssm_input_path="../inputs/PSSM_inputs" 20 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/" 21 | 22 | output_dir="../outputs/example_pssm_outputs" 23 | if [ ! 
-d $output_dir ] 24 | then 25 | mkdir -p $output_dir 26 | fi 27 | 28 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl" 29 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl" 30 | pssm=$output_dir"/pssm.jsonl" 31 | chains_to_design="A B" 32 | 33 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains 34 | 35 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design" 36 | 37 | python ../helper_scripts/make_pssm_input_dict.py --jsonl_input_path=$path_for_parsed_chains --PSSM_input_path=$pssm_input_path --output_path=$pssm 38 | 39 | python ../protein_mpnn_run.py \ 40 | --jsonl_path $path_for_parsed_chains \ 41 | --chain_id_jsonl $path_for_assigned_chains \ 42 | --out_folder $output_dir \ 43 | --num_seq_per_target 2 \ 44 | --sampling_temp "0.1" \ 45 | --seed 37 \ 46 | --batch_size 1 \ 47 | --pssm_jsonl $pssm \ 48 | --pssm_multi 0.3 \ 49 | --pssm_bias_flag 1 50 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/assign_fixed_chains.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import json 5 | 6 | with open(args.input_path, 'r') as json_file: 7 | json_list = list(json_file) 8 | 9 | global_designed_chain_list = [] 10 | if args.chain_list != '': 11 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 12 | my_dict = {} 13 | for json_str in json_list: 14 | result = json.loads(json_str) 15 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] #['A','B', 'C',...] 16 | if len(global_designed_chain_list) > 0: 17 | designed_chain_list = global_designed_chain_list 18 | else: 19 | #manually specify, e.g. 
20 | designed_chain_list = ["A"] 21 | fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list] #fix/do not redesign these chains 22 | my_dict[result['name']]= (designed_chain_list, fixed_chain_list) 23 | 24 | with open(args.output_path, 'w') as f: 25 | f.write(json.dumps(my_dict) + '\n') 26 | 27 | 28 | if __name__ == "__main__": 29 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 31 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 32 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be designed") 33 | 34 | args = argparser.parse_args() 35 | main(args) 36 | 37 | # Output looks like this: 38 | # {"5TTA": [["A"], ["B"]], "3LIS": [["A"], ["B"]]} 39 | 40 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/make_bias_AA.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | 5 | import numpy as np 6 | import json 7 | 8 | bias_list = [float(item) for item in args.bias_list.split()] 9 | AA_list = [str(item) for item in args.AA_list.split()] 10 | 11 | my_dict = dict(zip(AA_list, bias_list)) 12 | 13 | with open(args.output_path, 'w') as f: 14 | f.write(json.dumps(my_dict) + '\n') 15 | 16 | 17 | if __name__ == "__main__": 18 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 20 | argparser.add_argument("--AA_list", type=str, default='', help="List of AAs to be biased") 21 | argparser.add_argument("--bias_list", type=str, default='', help="AA bias strengths") 22 | 23 | args = argparser.parse_args() 24 | main(args) 25 | 26 | #e.g. 
output 27 | #{"A": -0.01, "G": 0.02} 28 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/make_bias_per_res_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import glob 5 | import random 6 | import numpy as np 7 | import json 8 | 9 | mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' 10 | 11 | mpnn_alphabet_dict = {'A': 0,'C': 1,'D': 2,'E': 3,'F': 4,'G': 5,'H': 6,'I': 7,'K': 8,'L': 9,'M': 10,'N': 11,'P': 12,'Q': 13,'R': 14,'S': 15,'T': 16,'V': 17,'W': 18,'Y': 19,'X': 20} 12 | 13 | with open(args.input_path, 'r') as json_file: 14 | json_list = list(json_file) 15 | 16 | my_dict = {} 17 | for json_str in json_list: 18 | result = json.loads(json_str) 19 | all_chain_list = [item[-1:] for item in list(result) if item[:10]=='seq_chain_'] 20 | bias_by_res_dict = {} 21 | for chain in all_chain_list: 22 | chain_length = len(result[f'seq_chain_{chain}']) 23 | bias_per_residue = np.zeros([chain_length, 21]) 24 | 25 | 26 | if chain == 'A': 27 | residues = [0, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15] 28 | amino_acids = [5, 9] #[G, L] 29 | for res in residues: 30 | for aa in amino_acids: 31 | bias_per_residue[res, aa] = 100.5 32 | 33 | if chain == 'C': 34 | residues = [0, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15] 35 | amino_acids = range(21)[1:] #[G, L] 36 | for res in residues: 37 | for aa in amino_acids: 38 | bias_per_residue[res, aa] = -100.5 39 | 40 | bias_by_res_dict[chain] = bias_per_residue.tolist() 41 | my_dict[result['name']] = bias_by_res_dict 42 | 43 | with open(args.output_path, 'w') as f: 44 | f.write(json.dumps(my_dict) + '\n') 45 | 46 | 47 | if __name__ == "__main__": 48 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 50 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 51 | 52 | args = argparser.parse_args() 53 | main(args) 54 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/make_fixed_positions_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import glob 5 | import random 6 | import numpy as np 7 | import json 8 | import itertools 9 | 10 | with open(args.input_path, 'r') as json_file: 11 | json_list = list(json_file) 12 | 13 | fixed_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")] 14 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 15 | my_dict = {} 16 | 17 | if not args.specify_non_fixed: 18 | for json_str in json_list: 19 | result = json.loads(json_str) 20 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 21 | fixed_position_dict = {} 22 | for i, chain in enumerate(global_designed_chain_list): 23 | fixed_position_dict[chain] = fixed_list[i] 24 | for chain in all_chain_list: 25 | if chain not in global_designed_chain_list: 26 | fixed_position_dict[chain] = [] 27 | my_dict[result['name']] = fixed_position_dict 28 | else: 29 | for json_str in json_list: 30 | result = json.loads(json_str) 31 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 32 | fixed_position_dict = {} 33 | for chain in all_chain_list: 34 | seq_length = len(result[f'seq_chain_{chain}']) 35 | all_residue_list = 
(np.arange(seq_length)+1).tolist() 36 | if chain not in global_designed_chain_list: 37 | fixed_position_dict[chain] = all_residue_list 38 | else: 39 | idx = np.argwhere(np.array(global_designed_chain_list) == chain)[0][0] 40 | fixed_position_dict[chain] = list(set(all_residue_list)-set(fixed_list[idx])) 41 | my_dict[result['name']] = fixed_position_dict 42 | 43 | with open(args.output_path, 'w') as f: 44 | f.write(json.dumps(my_dict) + '\n') 45 | 46 | #e.g. output 47 | #{"5TTA": {"A": [1, 2, 3, 7, 8, 9, 22, 25, 33], "B": []}, "3LIS": {"A": [], "B": []}} 48 | 49 | if __name__ == "__main__": 50 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 52 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 53 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed") 54 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain") 55 | argparser.add_argument("--specify_non_fixed", action="store_true", default=False, help="Allows specifying just residues that need to be designed (default: false)") 56 | 57 | args = argparser.parse_args() 58 | main(args) 59 | 60 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/make_pssm_input_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def main(args): 4 | import json 5 | import numpy as np 6 | with open(args.jsonl_input_path, 'r') as json_file: 7 | json_list = list(json_file) 8 | 9 | my_dict = {} 10 | for json_str in json_list: 11 | result = json.loads(json_str) 12 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 13 | path_to_PSSM = args.PSSM_input_path+"/"+result['name'] + ".npz" 14 | print(path_to_PSSM) 15 | pssm_input = np.load(path_to_PSSM) 16 | pssm_dict = {} 17 | for chain in all_chain_list: 18 | pssm_dict[chain] = {} 19 | pssm_dict[chain]['pssm_coef'] = pssm_input[chain+'_coef'].tolist() #[L] per position coefficient to trust PSSM; 0.0 - do not use it; 1.0 - just use PSSM only 20 | pssm_dict[chain]['pssm_bias'] = pssm_input[chain+'_bias'].tolist() #[L,21] probability (sums up to 1.0 over alphabet of size 21) from PSSM 21 | pssm_dict[chain]['pssm_log_odds'] = pssm_input[chain+'_odds'].tolist() #[L,21] log_odds ratios coming from PSSM; optional/not needed 22 | my_dict[result['name']] = pssm_dict 23 | 24 | #Write output to: 25 | with open(args.output_path, 'w') as f: 26 | f.write(json.dumps(my_dict) + '\n') 27 | 28 | if __name__ == "__main__": 29 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | 31 | argparser.add_argument("--PSSM_input_path", type=str, help="Path to PSSMs saved as npz files.") 32 | argparser.add_argument("--jsonl_input_path", type=str, help="Path where to load .jsonl dictionary of parsed pdbs.") 33 | argparser.add_argument("--output_path", type=str, help="Path where to save .jsonl dictionary with PSSM bias.") 34 | 35 | args = argparser.parse_args() 36 | main(args) 37 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/make_tied_positions_dict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def 
main(args): 4 | 5 | import glob 6 | import random 7 | import numpy as np 8 | import json 9 | import itertools 10 | 11 | with open(args.input_path, 'r') as json_file: 12 | json_list = list(json_file) 13 | 14 | homooligomeric_state = args.homooligomer 15 | 16 | if homooligomeric_state == 0: 17 | tied_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")] 18 | global_designed_chain_list = [str(item) for item in args.chain_list.split()] 19 | my_dict = {} 20 | for json_str in json_list: 21 | result = json.loads(json_str) 22 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ... 23 | tied_positions_list = [] 24 | for i, pos in enumerate(tied_list[0]): 25 | temp_dict = {} 26 | for j, chain in enumerate(global_designed_chain_list): 27 | temp_dict[chain] = [tied_list[j][i]] #needs to be a list 28 | tied_positions_list.append(temp_dict) 29 | my_dict[result['name']] = tied_positions_list 30 | else: 31 | my_dict = {} 32 | for json_str in json_list: 33 | result = json.loads(json_str) 34 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ... 35 | tied_positions_list = [] 36 | chain_length = len(result[f"seq_chain_{all_chain_list[0]}"]) 37 | for i in range(1,chain_length+1): 38 | temp_dict = {} 39 | for j, chain in enumerate(all_chain_list): 40 | temp_dict[chain] = [i] #needs to be a list 41 | tied_positions_list.append(temp_dict) 42 | my_dict[result['name']] = tied_positions_list 43 | 44 | with open(args.output_path, 'w') as f: 45 | f.write(json.dumps(my_dict) + '\n') 46 | 47 | if __name__ == "__main__": 48 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs") 50 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary") 51 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed") 52 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain") 53 | argparser.add_argument("--homooligomer", type=int, default=0, help="If 0 do not use, if 1 then design homooligomer") 54 | 55 | args = argparser.parse_args() 56 | main(args) 57 | 58 | 59 | #e.g. 
output 60 | #{"5TTA": [], "3LIS": [{"A": [1], "B": [1]}, {"A": [2], "B": [2]}, {"A": [3], "B": [3]}, {"A": [4], "B": [4]}, {"A": [5], "B": [5]}, {"A": [6], "B": [6]}, {"A": [7], "B": [7]}, {"A": [8], "B": [8]}, {"A": [9], "B": [9]}, {"A": [10], "B": [10]}, {"A": [11], "B": [11]}, {"A": [12], "B": [12]}, {"A": [13], "B": [13]}, {"A": [14], "B": [14]}, {"A": [15], "B": [15]}, {"A": [16], "B": [16]}, {"A": [17], "B": [17]}, {"A": [18], "B": [18]}, {"A": [19], "B": [19]}, {"A": [20], "B": [20]}, {"A": [21], "B": [21]}, {"A": [22], "B": [22]}, {"A": [23], "B": [23]}, {"A": [24], "B": [24]}, {"A": [25], "B": [25]}, {"A": [26], "B": [26]}, {"A": [27], "B": [27]}, {"A": [28], "B": [28]}, {"A": [29], "B": [29]}, {"A": [30], "B": [30]}, {"A": [31], "B": [31]}, {"A": [32], "B": [32]}, {"A": [33], "B": [33]}, {"A": [34], "B": [34]}, {"A": [35], "B": [35]}, {"A": [36], "B": [36]}, {"A": [37], "B": [37]}, {"A": [38], "B": [38]}, {"A": [39], "B": [39]}, {"A": [40], "B": [40]}, {"A": [41], "B": [41]}, {"A": [42], "B": [42]}, {"A": [43], "B": [43]}, {"A": [44], "B": [44]}, {"A": [45], "B": [45]}, {"A": [46], "B": [46]}, {"A": [47], "B": [47]}, {"A": [48], "B": [48]}, {"A": [49], "B": [49]}, {"A": [50], "B": [50]}, {"A": [51], "B": [51]}, {"A": [52], "B": [52]}, {"A": [53], "B": [53]}, {"A": [54], "B": [54]}, {"A": [55], "B": [55]}, {"A": [56], "B": [56]}, {"A": [57], "B": [57]}, {"A": [58], "B": [58]}, {"A": [59], "B": [59]}, {"A": [60], "B": [60]}, {"A": [61], "B": [61]}, {"A": [62], "B": [62]}, {"A": [63], "B": [63]}, {"A": [64], "B": [64]}, {"A": [65], "B": [65]}, {"A": [66], "B": [66]}, {"A": [67], "B": [67]}, {"A": [68], "B": [68]}, {"A": [69], "B": [69]}, {"A": [70], "B": [70]}, {"A": [71], "B": [71]}, {"A": [72], "B": [72]}, {"A": [73], "B": [73]}, {"A": [74], "B": [74]}, {"A": [75], "B": [75]}, {"A": [76], "B": [76]}, {"A": [77], "B": [77]}, {"A": [78], "B": [78]}, {"A": [79], "B": [79]}, {"A": [80], "B": [80]}, {"A": [81], "B": [81]}, {"A": [82], "B": [82]}, {"A": [83], "B": [83]}, {"A": [84], "B": [84]}, {"A": [85], "B": [85]}, {"A": [86], "B": [86]}, {"A": [87], "B": [87]}, {"A": [88], "B": [88]}, {"A": [89], "B": [89]}, {"A": [90], "B": [90]}, {"A": [91], "B": [91]}, {"A": [92], "B": [92]}, {"A": [93], "B": [93]}, {"A": [94], "B": [94]}, {"A": [95], "B": [95]}, {"A": [96], "B": [96]}]} 61 | 62 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/other_tools/make_omit_AA.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import numpy as np 4 | import json 5 | import itertools 6 | 7 | #MODIFY this path 8 | with open('/home/justas/projects/lab_github/mpnn/data/pdbs.jsonl', 'r') as json_file: 9 | json_list = list(json_file) 10 | 11 | my_dict = {} 12 | for json_str in json_list: 13 | result = json.loads(json_str) 14 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 15 | fixed_position_dict = {} 16 | print(result['name']) 17 | if result['name'] == '5TTA': 18 | for chain in all_chain_list: 19 | if chain == 'A': 20 | fixed_position_dict[chain] = [ 21 | [[int(item) for item in list(itertools.chain(list(np.arange(1,4)), list(np.arange(7,10)), [22, 25, 33]))], 'GPL'], 22 | [[int(item) for item in list(itertools.chain([40, 41, 42, 43]))], 'WC'], 23 | [[int(item) for item in list(itertools.chain(list(np.arange(50,150))))], 'ACEFGHIKLMNRSTVWYX'], 24 | [[int(item) for item in list(itertools.chain(list(np.arange(160,200))))], 
'FGHIKLPQDMNRSTVWYX']] 25 | else: 26 | fixed_position_dict[chain] = [] 27 | else: 28 | for chain in all_chain_list: 29 | fixed_position_dict[chain] = [] 30 | my_dict[result['name']] = fixed_position_dict 31 | 32 | #MODIFY this path 33 | with open('/home/justas/projects/lab_github/mpnn/data/omit_AA.jsonl', 'w') as f: 34 | f.write(json.dumps(my_dict) + '\n') 35 | 36 | 37 | print('Finished') 38 | #e.g. output 39 | #{"5TTA": {"A": [[[1, 2, 3, 7, 8, 9, 22, 25, 33], "GPL"], [[40, 41, 42, 43], "WC"], [[50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149], "ACEFGHIKLMNRSTVWYX"], [[160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199], "FGHIKLPQDMNRSTVWYX"]], "B": []}, "3LIS": {"A": [], "B": []}} 40 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/other_tools/make_pssm_dict.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import glob 5 | import random 6 | import numpy as np 7 | import json 8 | 9 | 10 | def softmax(x, T): 11 | return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True) 12 | 13 | def parse_pssm(path): 14 | data = pd.read_csv(path, skiprows=2) 15 | floats_list_list = [] 16 | for i in range(data.values.shape[0]): 17 | str1 = data.values[i][0][4:] 18 | floats_list = [] 19 | for item in str1.split(): 20 | floats_list.append(float(item)) 21 | floats_list_list.append(floats_list) 22 | np_lines = np.array(floats_list_list) 23 | return np_lines 24 | 25 | np_lines = parse_pssm('/home/swang523/RLcage/capsid/monomersfordesign/8-16-21/pssm_rainity_final_8-16-21_int/build_0.2089_0.98_0.4653_19_2.00_0.005745.pssm') 26 | 27 | mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX' 28 | input_alphabet = 'ARNDCQEGHILKMFPSTWYV' 29 | 30 | permutation_matrix = np.zeros([20,21]) 31 | for i in range(20): 32 | letter1 = input_alphabet[i] 33 | for j in range(21): 34 | letter2 = mpnn_alphabet[j] 35 | if letter1 == letter2: 36 | permutation_matrix[i,j]=1. 
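# permutation_matrix is a 20x21 one-hot map that reorders PSSM columns from the input
# alphabet (ARNDCQEGHILKMFPSTWYV) into the MPNN alphabet (ACDEFGHIKLMNPQRSTVWYX); the
# trailing X column has no counterpart in the input alphabet and therefore stays zero,
# so the matrix products below simply permute the 20 amino-acid columns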
37 | 38 | pssm_log_odds = np_lines[:,:20] @ permutation_matrix 39 | pssm_probs = np_lines[:,20:40] @ permutation_matrix 40 | 41 | X_mask = np.concatenate([np.zeros([1,20]), np.ones([1,1])], -1) 42 | 43 | def softmax(x, T): 44 | return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True) 45 | 46 | #Load parsed PDBs: 47 | with open('/home/justas/projects/cages/parsed/test.jsonl', 'r') as json_file: 48 | json_list = list(json_file) 49 | 50 | my_dict = {} 51 | for json_str in json_list: 52 | result = json.loads(json_str) 53 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] 54 | pssm_dict = {} 55 | for chain in all_chain_list: 56 | pssm_dict[chain] = {} 57 | pssm_dict[chain]['pssm_coef'] = (np.ones(len(result['seq_chain_A']))).tolist() #a number between 0.0 and 1.0 specifying how much attention put to PSSM, can be adjusted later as a flag 58 | pssm_dict[chain]['pssm_bias'] = (softmax(pssm_log_odds-X_mask*1e8, 1.0)).tolist() #PSSM like, [length, 21] such that sum over the last dimension adds up to 1.0 59 | pssm_dict[chain]['pssm_log_odds'] = (pssm_log_odds).tolist() 60 | my_dict[result['name']] = pssm_dict 61 | 62 | #Write output to: 63 | with open('/home/justas/projects/lab_github/mpnn/data/pssm_dict.jsonl', 'w') as f: 64 | f.write(json.dumps(my_dict) + '\n') 65 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/parse_multiple_chains.out: -------------------------------------------------------------------------------- 1 | Successfully finished: 2 pdbs 2 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/helper_scripts/parse_multiple_chains.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem=32g 3 | #SBATCH -c 2 4 | #SBATCH --output=parse_multiple_chains.out 5 | 6 | source activate mlfold 7 | python parse_multiple_chains.py --input_path='../PDB_complexes/pdbs/' --output_path='../PDB_complexes/parsed_pdbs.jsonl' 8 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/inputs/PSSM_inputs/3HTN.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/inputs/PSSM_inputs/3HTN.npz -------------------------------------------------------------------------------- /forks/ProteinMPNN/inputs/PSSM_inputs/4YOW.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/inputs/PSSM_inputs/4YOW.npz -------------------------------------------------------------------------------- /forks/ProteinMPNN/outputs/training_test_output/seqs/5L33.fa: -------------------------------------------------------------------------------- 1 | >5L33, score=1.9496, global_score=1.9496, fixed_chains=[], designed_chains=['A'], model_name=epoch_last, git_hash=78fdc70e6da22ba7a535f8913f62310b738f0b33, seed=37 2 | HMPEEEKAARLFIEALEKGDPELMRKVISPDTRMEDNGREFTGDEVVEYVKEIQKRGEQWHLRRYTKEGNSWRFEVQVDNNGQTEQWEVQIEVRNGRIKRVTITHV 3 | >T=0.1, sample=1, score=1.3913, global_score=1.3913, seq_recovery=0.3396 4 | ALPAAEKVALALLDALATGDPELLKAVLTADSKFTDNGKEFKGEDLVDFVEKLKKEGKKFKPTSGSVTGDSFTLTLTVSSNGKTETATLTVKVENGKLSSLTITKN 5 | >T=0.1, sample=2, score=1.3315, global_score=1.3315, seq_recovery=0.3868 6 | 
KLPEEEKVAKELLKALENGDPELAKKVLTPDAKFTLNGEEFSGEELVKFVKELKEKGEKFKPLSSKKDGDSYTFTLKLSKNGKTETATLKVKVKNGKVEEIELSSD 7 | >T=0.1, sample=3, score=1.3452, global_score=1.3452, seq_recovery=0.4151 8 | KLPEEEKVAKKLLEALEKKDPELLKKVLTPDSEFTINGKKFKGEELVKLVKELKKKGEKFELVSSSKTGDSFTFTLKVSKNGKTKEATLTVTVKNGKLDKLTLSFK 9 | >T=0.1, sample=4, score=1.3763, global_score=1.3763, seq_recovery=0.3208 10 | KLDEKLKVAEKLIKALENGDPALLKEVLTKDSKFTINGKEFTGEDAVDFVKKLKKAGEKFKKLSGELKGDEFTFKLELEKDGEKKTAELTVKVENGKLTSLTIKDK 11 | >T=0.1, sample=5, score=1.3108, global_score=1.3108, seq_recovery=0.3208 12 | KLPEEEKVAKKLLDALENKDPELAKKVLTKDAKFKINGKEFSGEDLVDFVKKLKEDGKEFKLLSGKKVGDKYVFTLKISKDGEEKEAKLEVKVKNGKVEEIKIESK 13 | >T=0.1, sample=6, score=1.3122, global_score=1.3122, seq_recovery=0.3302 14 | KLPEEEKVAKKLLEALEKGDPELLKEVLTKDAEFTKNGEKFKGPDLVKFVEKLKAAGEKFKLLSSKKEGDKYTLTLELEKNGEKKKATLTVDVKNGKVESLELSDK 15 | >T=0.1, sample=7, score=1.3217, global_score=1.3217, seq_recovery=0.3774 16 | KKPEKEKVADKLVDALEKGDPELLKKVLTKDAKFEKNGKKFKGEDLVKYVEELKAKGEKFEPVGGEKKGDSYKFKLKISKNGKTKTATLEVKVENGKVKELKISEK 17 | >T=0.1, sample=8, score=1.2603, global_score=1.2603, seq_recovery=0.3491 18 | KLPEKEKVAKKLLDALEKGDPELLKEVLNKDAKFTINGKKFKGDDLVKFVEELKKKGEKFEPLSGEKKGDEFVFKLKIEKNGEKKEVELKVKVEDGKVKDIEIKDK 19 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/soluble_model_weights/v_48_002.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_002.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/soluble_model_weights/v_48_010.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_010.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/soluble_model_weights/v_48_020.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_020.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/soluble_model_weights/v_48_030.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_030.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/training/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Justas Dauparas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | 
copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/training/exp_020/log.txt: -------------------------------------------------------------------------------- 1 | Epoch Train Validation 2 | epoch: 1, step: 74, time: 45.7, train: 23.565, valid: 17.468, train_acc: 0.072, valid_acc: 0.113 3 | epoch: 2, step: 148, time: 36.1, train: 17.790, valid: 15.872, train_acc: 0.108, valid_acc: 0.133 4 | epoch: 3, step: 221, time: 41.2, train: 16.398, valid: 14.980, train_acc: 0.128, valid_acc: 0.154 5 | epoch: 4, step: 294, time: 39.3, train: 15.417, valid: 15.041, train_acc: 0.148, valid_acc: 0.153 6 | epoch: 5, step: 369, time: 40.8, train: 14.736, valid: 13.124, train_acc: 0.167, valid_acc: 0.202 7 | epoch: 6, step: 444, time: 39.7, train: 13.954, valid: 12.683, train_acc: 0.187, valid_acc: 0.209 8 | epoch: 7, step: 515, time: 41.2, train: 13.665, valid: 12.079, train_acc: 0.193, valid_acc: 0.229 9 | epoch: 8, step: 586, time: 39.6, train: 13.105, valid: 11.938, train_acc: 0.208, valid_acc: 0.237 10 | epoch: 9, step: 656, time: 41.0, train: 12.714, valid: 11.232, train_acc: 0.219, valid_acc: 0.248 11 | epoch: 10, step: 726, time: 37.7, train: 12.391, valid: 11.386, train_acc: 0.225, valid_acc: 0.249 12 | epoch: 11, step: 796, time: 39.2, train: 12.098, valid: 10.990, train_acc: 0.231, valid_acc: 0.261 13 | epoch: 12, step: 866, time: 37.0, train: 11.845, valid: 10.554, train_acc: 0.238, valid_acc: 0.267 14 | epoch: 13, step: 940, time: 42.0, train: 11.742, valid: 10.673, train_acc: 0.240, valid_acc: 0.270 15 | epoch: 14, step: 1014, time: 38.7, train: 11.503, valid: 10.455, train_acc: 0.245, valid_acc: 0.271 16 | epoch: 15, step: 1089, time: 42.3, train: 11.284, valid: 10.303, train_acc: 0.251, valid_acc: 0.278 17 | epoch: 16, step: 1164, time: 40.1, train: 11.335, valid: 9.982, train_acc: 0.249, valid_acc: 0.285 18 | epoch: 17, step: 1239, time: 43.7, train: 10.959, valid: 9.796, train_acc: 0.260, valid_acc: 0.292 19 | epoch: 18, step: 1314, time: 36.3, train: 10.726, valid: 9.472, train_acc: 0.265, valid_acc: 0.301 20 | epoch: 19, step: 1383, time: 44.5, train: 10.730, valid: 9.604, train_acc: 0.267, valid_acc: 0.295 21 | epoch: 20, step: 1452, time: 34.0, train: 10.583, valid: 9.378, train_acc: 0.270, valid_acc: 0.305 22 | epoch: 21, step: 1522, time: 41.2, train: 10.396, valid: 9.458, train_acc: 0.275, valid_acc: 0.304 23 | epoch: 22, step: 1592, time: 37.0, train: 10.324, valid: 9.444, train_acc: 0.277, valid_acc: 0.301 24 | epoch: 23, step: 1662, time: 39.9, train: 10.250, valid: 9.518, train_acc: 0.278, valid_acc: 0.302 25 | epoch: 24, step: 1732, time: 38.0, train: 10.092, valid: 9.048, train_acc: 0.284, valid_acc: 0.315 26 | epoch: 25, step: 1808, time: 46.0, train: 10.221, valid: 8.969, train_acc: 0.279, valid_acc: 0.322 27 | epoch: 26, step: 1884, time: 36.7, train: 10.010, valid: 8.876, train_acc: 0.285, valid_acc: 0.320 28 | epoch: 27, step: 1959, time: 43.3, train: 10.013, valid: 8.652, 
train_acc: 0.286, valid_acc: 0.328 29 | epoch: 28, step: 2034, time: 40.4, train: 9.694, valid: 8.779, train_acc: 0.296, valid_acc: 0.324 30 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/training/exp_020/model_weights/epoch_last.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/training/exp_020/model_weights/epoch_last.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/training/submit_exp_020.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p gpu 4 | #SBATCH --mem=128g 5 | #SBATCH --gres=gpu:a100:1 6 | #SBATCH -c 12 7 | #SBATCH -t 7-00:00:00 8 | #SBATCH --output=exp_020.out 9 | 10 | source activate mlfold-test 11 | python ./training.py \ 12 | --path_for_outputs "./exp_020" \ 13 | --path_for_training_data "path_to/pdb_2021aug02" \ 14 | --num_examples_per_epoch 1000 \ 15 | --save_model_every_n_epochs 50 16 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/training/test_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p gpu 3 | #SBATCH --mem=32g 4 | #SBATCH --gres=gpu:rtx2080:1 5 | #SBATCH -c 3 6 | #SBATCH --output=example_3_model_w_test.out 7 | 8 | source activate mlfold 9 | 10 | path_to_PDB="../inputs/PDB_monomers/pdbs/5L33.pdb" 11 | 12 | output_dir="../outputs/training_test_output" 13 | if [ ! -d $output_dir ] 14 | then 15 | mkdir -p $output_dir 16 | fi 17 | 18 | chains_to_design="A" 19 | 20 | 21 | python ../protein_mpnn_run.py \ 22 | --path_to_model_weights "../training/exp_020/model_weights" \ 23 | --model_name "epoch_last" \ 24 | --pdb_path $path_to_PDB \ 25 | --pdb_path_chains "$chains_to_design" \ 26 | --out_folder $output_dir \ 27 | --num_seq_per_target 8 \ 28 | --sampling_temp "0.1" \ 29 | --seed 37 \ 30 | --batch_size 1 31 | -------------------------------------------------------------------------------- /forks/ProteinMPNN/vanilla_model_weights/v_48_002.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_002.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/vanilla_model_weights/v_48_010.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_010.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/vanilla_model_weights/v_48_020.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_020.pt -------------------------------------------------------------------------------- /forks/ProteinMPNN/vanilla_model_weights/v_48_030.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_030.pt -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Institute for Protein Design 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/README.md: -------------------------------------------------------------------------------- 1 | # RF2NA 2 | GitHub repo for RoseTTAFold2 with nucleic acids 3 | 4 | **New: April 13, 2023 v0.2** 5 | * Updated weights (https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz) for better prediction of homodimer:DNA interactions and better DNA-specific sequence recognition 6 | * Bugfixes in MSA generation pipeline 7 | * Support for paired protein/RNA MSAs 8 | 9 | ## Installation 10 | 11 | 1. Clone the package 12 | ``` 13 | git clone https://github.com/uw-ipd/RoseTTAFold2NA.git 14 | cd RoseTTAFold2NA 15 | ``` 16 | 17 | 2. Create conda environment 18 | All external dependencies are contained in `RF2na-linux.yml` 19 | ``` 20 | # create conda environment for RoseTTAFold2NA 21 | conda env create -f RF2na-linux.yml 22 | ``` 23 | You also need to install NVIDIA's SE(3)-Transformer (**please use SE3Transformer in this repo to install**). 24 | ``` 25 | conda activate RF2NA 26 | cd SE3Transformer 27 | pip install --no-cache-dir -r requirements.txt 28 | python setup.py install 29 | ``` 30 | 31 | 3. Download pre-trained weights under network directory 32 | ``` 33 | cd network 34 | wget https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz 35 | tar xvfz RF2NA_apr23.tgz 36 | ls weights/ # it should contain a 1.1GB weights file 37 | cd .. 38 | ``` 39 | 40 | 4. 
Download sequence and structure databases 41 | ``` 42 | # uniref30 [46G] 43 | wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz 44 | mkdir -p UniRef30_2020_06 45 | tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06 46 | 47 | # BFD [272G] 48 | wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz 49 | mkdir -p bfd 50 | tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd 51 | 52 | # structure templates (including *_a3m.ffdata, *_a3m.ffindex) 53 | wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz 54 | tar xfz pdb100_2021Mar03.tar.gz 55 | 56 | # RNA databases 57 | mkdir -p RNA 58 | cd RNA 59 | 60 | # Rfam [300M] 61 | wget ftp://ftp.ebi.ac.uk/pub/databases/Rfam/CURRENT/Rfam.full_region.gz 62 | wget ftp://ftp.ebi.ac.uk/pub/databases/Rfam/CURRENT/Rfam.cm.gz 63 | gunzip Rfam.cm.gz 64 | cmpress Rfam.cm 65 | 66 | # RNAcentral [12G] 67 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/rfam/rfam_annotations.tsv.gz 68 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/id_mapping/id_mapping.tsv.gz 69 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/sequences/rnacentral_species_specific_ids.fasta.gz 70 | ../input_prep/reprocess_rnac.pl id_mapping.tsv.gz rfam_annotations.tsv.gz # ~8 minutes 71 | gunzip -c rnacentral_species_specific_ids.fasta.gz | makeblastdb -in - -dbtype nucl -parse_seqids -out rnacentral.fasta -title "RNACentral" 72 | 73 | # nt [151G] 74 | update_blastdb.pl --decompress nt 75 | cd .. 76 | ``` 77 | 78 | ## Usage 79 | ``` 80 | conda activate RF2NA 81 | cd example 82 | # run Protein/RNA prediction 83 | ../run_RF2NA.sh rna_pred 0 rna_binding_protein.fa R:RNA.fa 84 | # run Protein/DNA prediction 85 | ../run_RF2NA.sh dna_pred 0 dna_binding_protein.fa D:DNA.fa 86 | ``` 87 | ### Inputs 88 | * The first argument to the script is the output folder 89 | * The second argument to the script is an integer indicating whether to use single-sequence mode (`1`) or not (`0`) 90 | * The remaining arguments are fasta files for individual chains in the structure. Use the tags `P:xxx.fa` `R:xxx.fa` `D:xxx.fa` `S:xxx.fa` to specify protein, RNA, double-stranded DNA, and single-stranded DNA, respectively. Use the tag `PR:xxx.fa` to specify paired protein/RNA. Each chain is a separate file; 'D' will automatically generate a complementary DNA strand to the input strand. 91 | 92 | ### Expected outputs 93 | * Outputs are written to the folder provided as the first argument (`dna_pred` and `rna_pred`). 94 | * Model outputs are placed in a subfolder, `models` (e.g., `dna_pred.models`) 95 | * You will get a predicted structure with estimated per-residue LDDT in the B-factor column (`models/model_00.pdb`) 96 | * You will get a numpy `.npz` file (`models/model_00.npz`).
This can be read with `numpy.load` and contains three tables (L=complex length): 97 | - dist (L x L x 37) - the predicted distogram 98 | - lddt (L) - the per-residue predicted lddt 99 | - pae (L x L) - the per-residue pair predicted error 100 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/RF2na-linux.yml: -------------------------------------------------------------------------------- 1 | name: RF2NA 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - python=3.10 9 | - pip 10 | - pytorch 11 | - requests 12 | - pytorch-cuda=11.7 13 | - dglteam/label/cu117::dgl 14 | - pyg::pyg 15 | - bioconda::mafft 16 | - bioconda::hhsuite 17 | - bioconda::blast 18 | - bioconda::hmmer>=3.3 19 | - bioconda::infernal 20 | - bioconda::cd-hit 21 | - bioconda::csblast 22 | - pip: 23 | - psutil 24 | - tqdm -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | # run docker daemon with --default-runtime=nvidia for GPU detection during build 25 | # multistage build for DGL with CUDA and FP16 26 | 27 | ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3 28 | 29 | FROM ${FROM_IMAGE_NAME} AS dgl_builder 30 | 31 | ENV DEBIAN_FRONTEND=noninteractive 32 | RUN apt-get update \ 33 | && apt-get install -y git build-essential python3-dev make cmake \ 34 | && rm -rf /var/lib/apt/lists/* 35 | WORKDIR /dgl 36 | RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git . 37 | RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake 38 | WORKDIR build 39 | RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON .. 40 | RUN make -j8 41 | 42 | 43 | FROM ${FROM_IMAGE_NAME} 44 | 45 | RUN rm -rf /workspace/* 46 | WORKDIR /workspace/se3-transformer 47 | 48 | # copy built DGL and install it 49 | COPY --from=dgl_builder /dgl ./dgl 50 | RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl 51 | 52 | ADD requirements.txt . 
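# requirements.txt is added on its own before the final ADD . . so Docker caches the dependency-install layers and rebuilds them only when the pinned requirements change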
53 | RUN pip install --no-cache-dir --upgrade --pre pip 54 | RUN pip install --no-cache-dir -r requirements.txt 55 | ADD . . 56 | 57 | ENV DGLBACKEND=pytorch 58 | ENV OMP_NUM_THREADS=1 59 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 NVIDIA CORPORATION & AFFILIATES 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/NOTICE: -------------------------------------------------------------------------------- 1 | SE(3)-Transformer PyTorch 2 | 3 | This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public 4 | licensed under the MIT License. 5 | 6 | This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch 7 | licensed under the MIT License. 
8 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/images/se3-transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/images/se3-transformer.png -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/requirements.txt: -------------------------------------------------------------------------------- 1 | e3nn==0.3.3 2 | wandb==0.12.0 3 | pynvml==11.0.0 4 | git+https://github.com/NVIDIA/dllogger#egg=dllogger -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/scripts/benchmark_inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to benchmark inference performance, without bases precomputation 3 | 4 | # CLI args with defaults 5 | BATCH_SIZE=${1:-240} 6 | AMP=${2:-true} 7 | 8 | CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.inference \ 9 | --amp "$AMP" \ 10 | --batch_size "$BATCH_SIZE" \ 11 | --use_layer_norm \ 12 | --norm \ 13 | --task homo \ 14 | --seed 42 \ 15 | --benchmark 16 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/scripts/benchmark_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to benchmark single-GPU training performance, with bases precomputation 3 | 4 | # CLI args with defaults 5 | BATCH_SIZE=${1:-240} 6 | AMP=${2:-true} 7 | 8 | CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \ 9 | --amp "$AMP" \ 10 | --batch_size "$BATCH_SIZE" \ 11 | --epochs 6 \ 12 | --use_layer_norm \ 13 | --norm \ 14 | --save_ckpt_path model_qm9.pth \ 15 | --task homo \ 16 | --precompute_bases \ 17 | --seed 42 \ 18 | --benchmark 19 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/scripts/benchmark_train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to benchmark multi-GPU training performance, with bases precomputation 3 | 4 | # CLI args with defaults 5 | BATCH_SIZE=${1:-240} 6 | AMP=${2:-true} 7 | 8 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \ 9 | se3_transformer.runtime.training \ 10 | --amp "$AMP" \ 11 | --batch_size "$BATCH_SIZE" \ 12 | --epochs 6 \ 13 | --use_layer_norm \ 14 | --norm \ 15 | --save_ckpt_path model_qm9.pth \ 16 | --task homo \ 17 | --precompute_bases \ 18 | --seed 42 \ 19 | --benchmark 20 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/scripts/predict.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CLI args with defaults 4 | BATCH_SIZE=${1:-240} 5 | AMP=${2:-true} 6 | 7 | 8 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 9 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C' 10 | TASK=homo 11 | 12 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \ 13 | se3_transformer.runtime.inference \ 14 | --amp "$AMP" \ 15 | --batch_size 
"$BATCH_SIZE" \ 16 | --use_layer_norm \ 17 | --norm \ 18 | --load_ckpt_path model_qm9.pth \ 19 | --task "$TASK" 20 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CLI args with defaults 4 | BATCH_SIZE=${1:-240} 5 | AMP=${2:-true} 6 | NUM_EPOCHS=${3:-100} 7 | LEARNING_RATE=${4:-0.002} 8 | WEIGHT_DECAY=${5:-0.1} 9 | 10 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 11 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C' 12 | TASK=homo 13 | 14 | python -m se3_transformer.runtime.training \ 15 | --amp "$AMP" \ 16 | --batch_size "$BATCH_SIZE" \ 17 | --epochs "$NUM_EPOCHS" \ 18 | --lr "$LEARNING_RATE" \ 19 | --weight_decay "$WEIGHT_DECAY" \ 20 | --use_layer_norm \ 21 | --norm \ 22 | --save_ckpt_path model_qm9.pth \ 23 | --precompute_bases \ 24 | --seed 42 \ 25 | --task "$TASK" -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/scripts/train_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CLI args with defaults 4 | BATCH_SIZE=${1:-240} 5 | AMP=${2:-true} 6 | NUM_EPOCHS=${3:-130} 7 | LEARNING_RATE=${4:-0.01} 8 | WEIGHT_DECAY=${5:-0.1} 9 | 10 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 11 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C' 12 | TASK=homo 13 | 14 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \ 15 | se3_transformer.runtime.training \ 16 | --amp "$AMP" \ 17 | --batch_size "$BATCH_SIZE" \ 18 | --epochs "$NUM_EPOCHS" \ 19 | --lr "$LEARNING_RATE" \ 20 | --min_lr 0.00001 \ 21 | --weight_decay "$WEIGHT_DECAY" \ 22 | --use_layer_norm \ 23 | --norm \ 24 | --save_ckpt_path model_qm9.pth \ 25 | --precompute_bases \ 26 | --seed 42 \ 27 | --task "$TASK" -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/__init__.py -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/data_loading/__init__.py: -------------------------------------------------------------------------------- 1 | from .qm9 import QM9DataModule 2 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/data_loading/data_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | import torch.distributed as dist 25 | from abc import ABC 26 | from torch.utils.data import DataLoader, DistributedSampler, Dataset 27 | 28 | from se3_transformer.runtime.utils import get_local_rank 29 | 30 | 31 | def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader: 32 | # Classic or distributed dataloader depending on the context 33 | sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None 34 | return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs) 35 | 36 | 37 | class DataModule(ABC): 38 | """ Abstract DataModule. Children must define self.ds_{train | val | test}. """ 39 | 40 | def __init__(self, **dataloader_kwargs): 41 | super().__init__() 42 | if get_local_rank() == 0: 43 | self.prepare_data() 44 | 45 | # Wait until rank zero has prepared the data (download, preprocessing, ...) 46 | if dist.is_initialized(): 47 | dist.barrier(device_ids=[get_local_rank()]) 48 | 49 | self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs} 50 | self.ds_train, self.ds_val, self.ds_test = None, None, None 51 | 52 | def prepare_data(self): 53 | """ Method called only once per node. 
Put here any downloading or preprocessing """ 54 | pass 55 | 56 | def train_dataloader(self) -> DataLoader: 57 | return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs) 58 | 59 | def val_dataloader(self) -> DataLoader: 60 | return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs) 61 | 62 | def test_dataloader(self) -> DataLoader: 63 | return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs) 64 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import SE3Transformer, SE3TransformerPooled 2 | from .fiber import Fiber 3 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import LinearSE3 2 | from .norm import NormSE3 3 | from .pooling import GPooling 4 | from .convolution import ConvSE3 5 | from .attention import AttentionBlockSE3 -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | 25 | from typing import Dict 26 | 27 | import numpy as np 28 | import torch 29 | import torch.nn as nn 30 | from torch import Tensor 31 | 32 | from se3_transformer.model.fiber import Fiber 33 | 34 | 35 | class LinearSE3(nn.Module): 36 | """ 37 | Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution. 38 | Maps a fiber to a fiber with the same degrees (channels may be different). 39 | No interaction between degrees, but interaction between channels. 
40 | 41 | type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels) 42 | type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels) 43 | : 44 | type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels) 45 | """ 46 | 47 | def __init__(self, fiber_in: Fiber, fiber_out: Fiber): 48 | super().__init__() 49 | self.weights = nn.ParameterDict({ 50 | str(degree_out): nn.Parameter( 51 | torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) 52 | for degree_out, channels_out in fiber_out 53 | }) 54 | 55 | def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: 56 | return { 57 | degree: self.weights[degree] @ features[degree] 58 | for degree, weight in self.weights.items() 59 | } 60 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | 25 | from typing import Dict 26 | 27 | import torch 28 | import torch.nn as nn 29 | from torch import Tensor 30 | from torch.cuda.nvtx import range as nvtx_range 31 | 32 | from se3_transformer.model.fiber import Fiber 33 | 34 | 35 | class NormSE3(nn.Module): 36 | """ 37 | Norm-based SE(3)-equivariant nonlinearity. 
38 | 39 | ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐ 40 | feature_in ──┤ * ──> feature_out 41 | └──> feature_phase ────────────────────────────┘ 42 | """ 43 | 44 | NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16 45 | 46 | def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()): 47 | super().__init__() 48 | self.fiber = fiber 49 | self.nonlinearity = nonlinearity 50 | 51 | if len(set(fiber.channels)) == 1: 52 | # Fuse all the layer normalizations into a group normalization 53 | self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels)) 54 | else: 55 | # Use multiple layer normalizations 56 | self.layer_norms = nn.ModuleDict({ 57 | str(degree): nn.LayerNorm(channels) 58 | for degree, channels in fiber 59 | }) 60 | 61 | def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: 62 | with nvtx_range('NormSE3'): 63 | output = {} 64 | #print ('NormSE3 features',[torch.sum(torch.isnan(v)) for v in features.values()]) 65 | if hasattr(self, 'group_norm'): 66 | # Compute per-degree norms of features 67 | norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) 68 | for d in self.fiber.degrees] 69 | fused_norms = torch.cat(norms, dim=-2) 70 | 71 | # Transform the norms only 72 | new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1) 73 | new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2) 74 | 75 | # Scale features to the new norms 76 | for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees): 77 | output[str(d)] = features[str(d)] / norm * new_norm 78 | else: 79 | for degree, feat in features.items(): 80 | norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) 81 | new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1)) 82 | output[degree] = new_norm * feat / norm 83 | #print ('NormSE3 output',[torch.sum(torch.isnan(v)) for v in output.values()]) 84 | 85 | return output 86 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/pooling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | from typing import Dict, Literal 25 | 26 | import torch.nn as nn 27 | from dgl import DGLGraph 28 | from dgl.nn.pytorch import AvgPooling, MaxPooling 29 | from torch import Tensor 30 | 31 | 32 | class GPooling(nn.Module): 33 | """ 34 | Graph max/average pooling on a given feature type. 35 | The average can be taken for any feature type, and equivariance will be maintained. 36 | The maximum can only be taken for invariant features (type 0). 37 | If you want max-pooling for type > 0 features, look into Vector Neurons. 38 | """ 39 | 40 | def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'): 41 | """ 42 | :param feat_type: Feature type to pool 43 | :param pool: Type of pooling: max or avg 44 | """ 45 | super().__init__() 46 | assert pool in ['max', 'avg'], f'Unknown pooling: {pool}' 47 | assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance' 48 | self.feat_type = feat_type 49 | self.pool = MaxPooling() if pool == 'max' else AvgPooling() 50 | 51 | def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor: 52 | pooled = self.pool(graph, features[str(self.feat_type)]) 53 | return pooled.squeeze(dim=-1) 54 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/__init__.py -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | import argparse 25 | import pathlib 26 | 27 | from se3_transformer.data_loading import QM9DataModule 28 | from se3_transformer.model import SE3TransformerPooled 29 | from se3_transformer.runtime.utils import str2bool 30 | 31 | PARSER = argparse.ArgumentParser(description='SE(3)-Transformer') 32 | 33 | paths = PARSER.add_argument_group('Paths') 34 | paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'), 35 | help='Directory where the data is located or should be downloaded') 36 | paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'), 37 | help='Directory where the results logs should be saved') 38 | paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json', 39 | help='Name for the resulting DLLogger JSON file') 40 | paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None, 41 | help='File where the checkpoint should be saved') 42 | paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None, 43 | help='File of the checkpoint to be loaded') 44 | 45 | optimizer = PARSER.add_argument_group('Optimizer') 46 | optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam') 47 | optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002) 48 | optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None) 49 | optimizer.add_argument('--momentum', type=float, default=0.9) 50 | optimizer.add_argument('--weight_decay', type=float, default=0.1) 51 | 52 | PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs') 53 | PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size') 54 | PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally') 55 | PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers') 56 | 57 | PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision') 58 | PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms') 59 | PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation') 60 | PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs') 61 | PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1, 62 | help='Do an evaluation round every N epochs') 63 | PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False, 64 | help='Minimize stdout output') 65 | 66 | PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False, 67 | help='Benchmark mode') 68 | 69 | QM9DataModule.add_argparse_args(PARSER) 70 | SE3TransformerPooled.add_argparse_args(PARSER) 71 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | from abc import ABC, abstractmethod 25 | 26 | import torch 27 | import torch.distributed as dist 28 | from torch import Tensor 29 | 30 | 31 | class Metric(ABC): 32 | """ Metric class with synchronization capabilities similar to TorchMetrics """ 33 | 34 | def __init__(self): 35 | self.states = {} 36 | 37 | def add_state(self, name: str, default: Tensor): 38 | assert name not in self.states 39 | self.states[name] = default.clone() 40 | setattr(self, name, default) 41 | 42 | def synchronize(self): 43 | if dist.is_initialized(): 44 | for state in self.states: 45 | dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD) 46 | 47 | def __call__(self, *args, **kwargs): 48 | self.update(*args, **kwargs) 49 | 50 | def reset(self): 51 | for name, default in self.states.items(): 52 | setattr(self, name, default.clone()) 53 | 54 | def compute(self): 55 | self.synchronize() 56 | value = self._compute().item() 57 | self.reset() 58 | return value 59 | 60 | @abstractmethod 61 | def _compute(self): 62 | pass 63 | 64 | @abstractmethod 65 | def update(self, preds: Tensor, targets: Tensor): 66 | pass 67 | 68 | 69 | class MeanAbsoluteError(Metric): 70 | def __init__(self): 71 | super().__init__() 72 | self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda')) 73 | self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda')) 74 | 75 | def update(self, preds: Tensor, targets: Tensor): 76 | preds = preds.detach() 77 | n = preds.shape[0] 78 | error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum() 79 | self.total += n 80 | self.error += error 81 | 82 | def _compute(self): 83 | return self.error / self.total 84 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='se3-transformer', 5 | packages=find_packages(), 6 | include_package_data=True, 7 | version='1.0.0', 8 | description='PyTorch + DGL implementation of SE(3)-Transformers', 9 | author='Alexandre Milesi', 10 | author_email='alexandrem@nvidia.com', 11 | ) 12 | 
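A minimal usage sketch for the `Metric`/`MeanAbsoluteError` classes defined above in `se3_transformer/runtime/metrics.py` (not part of the upstream sources; it assumes a CUDA device, since the metric allocates its state tensors on `cuda`, and it works whether or not `torch.distributed` has been initialized):
```
import torch

from se3_transformer.runtime.metrics import MeanAbsoluteError

mae = MeanAbsoluteError()  # allocates the 'error' and 'total' state tensors on the GPU

# hypothetical validation loop: __call__ forwards to update(), which accumulates
# sum(|preds - targets|) into 'error' and the number of samples into 'total'
for _ in range(3):
    preds = torch.randn(16, 1, device="cuda")
    targets = torch.randn(16, 1, device="cuda")
    mae(preds, targets)

# compute() all-reduces the state across ranks (only if torch.distributed is
# initialized), returns error / total as a Python float, then resets the state
print(f"validation MAE: {mae.compute():.4f}")
```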
-------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/tests/__init__.py -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/tests/test_equivariance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 
20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | import torch 25 | 26 | from se3_transformer.model import SE3Transformer 27 | from se3_transformer.model.fiber import Fiber 28 | from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot 29 | 30 | # Tolerances for equivariance error abs( f(x) @ R - f(x @ R) ) 31 | TOL = 1e-3 32 | CHANNELS, NODES = 32, 512 33 | 34 | 35 | def _get_outputs(model, R): 36 | feats0 = torch.randn(NODES, CHANNELS, 1) 37 | feats1 = torch.randn(NODES, CHANNELS, 3) 38 | 39 | coords = torch.randn(NODES, 3) 40 | graph = get_random_graph(NODES) 41 | if torch.cuda.is_available(): 42 | feats0 = feats0.cuda() 43 | feats1 = feats1.cuda() 44 | R = R.cuda() 45 | coords = coords.cuda() 46 | graph = graph.to('cuda') 47 | model.cuda() 48 | 49 | graph1 = assign_relative_pos(graph, coords) 50 | out1 = model(graph1, {'0': feats0, '1': feats1}, {}) 51 | graph2 = assign_relative_pos(graph, coords @ R) 52 | out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {}) 53 | 54 | return out1, out2 55 | 56 | 57 | def _get_model(**kwargs): 58 | return SE3Transformer( 59 | num_layers=4, 60 | fiber_in=Fiber.create(2, CHANNELS), 61 | fiber_hidden=Fiber.create(3, CHANNELS), 62 | fiber_out=Fiber.create(2, CHANNELS), 63 | fiber_edge=Fiber({}), 64 | num_heads=8, 65 | channels_div=2, 66 | **kwargs 67 | ) 68 | 69 | 70 | def test_equivariance(): 71 | model = _get_model() 72 | R = rot(*torch.rand(3)) 73 | if torch.cuda.is_available(): 74 | R = R.cuda() 75 | out1, out2 = _get_outputs(model, R) 76 | 77 | assert torch.allclose(out2['0'], out1['0'], atol=TOL), \ 78 | f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}' 79 | assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \ 80 | f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}' 81 | 82 | 83 | def test_equivariance_pooled(): 84 | model = _get_model(pooling='avg', return_type=1) 85 | R = rot(*torch.rand(3)) 86 | if torch.cuda.is_available(): 87 | R = R.cuda() 88 | out1, out2 = _get_outputs(model, R) 89 | 90 | assert torch.allclose(out2, (out1 @ R), atol=TOL), \ 91 | f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}' 92 | 93 | 94 | def test_invariance_pooled(): 95 | model = _get_model(pooling='avg', return_type=0) 96 | R = rot(*torch.rand(3)) 97 | if torch.cuda.is_available(): 98 | R = R.cuda() 99 | out1, out2 = _get_outputs(model, R) 100 | 101 | assert torch.allclose(out2, out1, atol=TOL), \ 102 | f'type-0 features should be invariant {get_max_diff(out1, out2)}' 103 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/SE3Transformer/tests/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a 4 | # copy of this software and associated documentation files (the "Software"), 5 | # to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | # and/or sell copies of the Software, and to permit persons to whom the 8 | # Software is furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 
12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | # DEALINGS IN THE SOFTWARE. 20 | # 21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES 22 | # SPDX-License-Identifier: MIT 23 | 24 | import dgl 25 | import torch 26 | 27 | 28 | def get_random_graph(N, num_edges_factor=18): 29 | graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor)) 30 | return graph 31 | 32 | 33 | def assign_relative_pos(graph, coords): 34 | src, dst = graph.edges() 35 | graph.edata['rel_pos'] = coords[src] - coords[dst] 36 | return graph 37 | 38 | 39 | def get_max_diff(a, b): 40 | return (a - b).abs().max().item() 41 | 42 | 43 | def rot_z(gamma): 44 | return torch.tensor([ 45 | [torch.cos(gamma), -torch.sin(gamma), 0], 46 | [torch.sin(gamma), torch.cos(gamma), 0], 47 | [0, 0, 1] 48 | ], dtype=gamma.dtype) 49 | 50 | 51 | def rot_y(beta): 52 | return torch.tensor([ 53 | [torch.cos(beta), 0, torch.sin(beta)], 54 | [0, 1, 0], 55 | [-torch.sin(beta), 0, torch.cos(beta)] 56 | ], dtype=beta.dtype) 57 | 58 | 59 | def rot(alpha, beta, gamma): 60 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) 61 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/example/RNA.fa: -------------------------------------------------------------------------------- 1 | > RNA 2 | GAGAGAGAAGTCAACCAGAGAAACACACCAACCCATTGCACTCCGGGTTGGTGGTATATTACCTGGTACGGGGGAAACTTCGTGGTGGCCGGCCACCTGACA 3 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/example/dna_binding_protein.fa: -------------------------------------------------------------------------------- 1 | > ANTENNAPEDIA HOMEODOMAIN|Drosophila melanogaster (7227) 2 | MERKRGRQTYTRYQTLELEKEFHFNRYLTRRRRIEIAHALSLTERQIKIWFQNRRMKWKKEN 3 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/example/rna_binding_protein.fa: -------------------------------------------------------------------------------- 1 | > prot 2 | TRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKM 3 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/input_prep/make_protein_msa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # inputs 4 | in_fasta="$1" 5 | out_dir="$2" 6 | tag="$3" 7 | 8 | # resources 9 | CPU="$4" 10 | MEM="$5" 11 | 12 | # single-sequence mode 13 | SINGLE_SEQ_MODE="$6" 14 | 15 | # validate if the single-sequence mode argument is a valid integer 16 | re='^[0-1]+$' 17 | if ! [[ $SINGLE_SEQ_MODE =~ $re ]]; then 18 | echo "Error: The single-sequence mode argument must be an integer ('1' meaning true and '0' otherwise)." 
19 | exit 1 20 | fi 21 | 22 | # sequence databases 23 | DB_UR30="$PIPEDIR/UniRef30_2020_06/UniRef30_2020_06" 24 | DB_BFD="$PIPEDIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt" 25 | 26 | # setup hhblits command 27 | HHBLITS_UR30="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_UR30" 28 | HHBLITS_BFD="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_BFD" 29 | 30 | mkdir -p $out_dir/hhblits 31 | tmp_dir="$out_dir/hhblits" 32 | out_prefix="$out_dir/$tag" 33 | 34 | echo out_prefix $out_prefix 35 | 36 | # perform iterative searches against UniRef30 37 | prev_a3m="$in_fasta" 38 | 39 | # check if single-sequence mode was requested; if so, skip the remainder of the script 40 | if [ "$SINGLE_SEQ_MODE" -eq 1 ]; then 41 | cp $prev_a3m ${out_prefix}.msa0.a3m 42 | exit 0 43 | fi 44 | 45 | for e in 1e-10 1e-6 1e-3 46 | do 47 | echo "Running HHblits against UniRef30 with E-value cutoff $e" 48 | $HHBLITS_UR30 -i $prev_a3m -oa3m $tmp_dir/t000_.$e.a3m -e $e -v 0 49 | hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov75.a3m 50 | hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov50.a3m 51 | prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m" 52 | n75=`grep -c "^>" $tmp_dir/t000_.$e.id90cov75.a3m` 53 | n50=`grep -c "^>" $tmp_dir/t000_.$e.id90cov50.a3m` 54 | 55 | if ((n75>2000)) 56 | then 57 | if [ ! -s ${out_prefix}.msa0.a3m ] 58 | then 59 | cp $tmp_dir/t000_.$e.id90cov75.a3m ${out_prefix}.msa0.a3m 60 | break 61 | fi 62 | elif ((n50>4000)) 63 | then 64 | if [ ! -s ${out_prefix}.msa0.a3m ] 65 | then 66 | cp $tmp_dir/t000_.$e.id90cov50.a3m ${out_prefix}.msa0.a3m 67 | break 68 | fi 69 | else 70 | continue 71 | fi 72 | done 73 | 74 | # perform iterative searches against BFD if it failes to get enough sequences 75 | if [ ! -s ${out_prefix}.msa0.a3m ] 76 | then 77 | e=1e-3 78 | echo "Running HHblits against BFD with E-value cutoff $e" 79 | $HHBLITS_BFD -i $prev_a3m -oa3m $tmp_dir/t000_.$e.bfd.a3m -e $e -v 0 80 | hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov75.a3m 81 | hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov50.a3m 82 | prev_a3m="$tmp_dir/t000_.$e.bfd.id90cov50.a3m" 83 | n75=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov75.a3m` 84 | n50=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov50.a3m` 85 | 86 | if ((n75>2000)) 87 | then 88 | if [ ! -s ${out_prefix}.msa0.a3m ] 89 | then 90 | cp $tmp_dir/t000_.$e.bfd.id90cov75.a3m ${out_prefix}.msa0.a3m 91 | fi 92 | elif ((n50>4000)) 93 | then 94 | if [ ! -s ${out_prefix}.msa0.a3m ] 95 | then 96 | cp $tmp_dir/t000_.$e.bfd.id90cov50.a3m ${out_prefix}.msa0.a3m 97 | fi 98 | fi 99 | fi 100 | 101 | if [ ! -s ${out_prefix}.msa0.a3m ] 102 | then 103 | cp $prev_a3m ${out_prefix}.msa0.a3m 104 | fi -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/input_prep/reprocess_rnac.pl: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/perl 2 | use strict; 3 | 4 | my $taxids = shift @ARGV; 5 | my $idfile = shift @ARGV; 6 | 7 | my %ids; 8 | open(GZIN, "gunzip -c $taxids |") or die("gunzip $taxids: $!"); 9 | foreach my $line () { 10 | my ($id,$taxid); 11 | ($id,$_,$_,$taxid,$_,$_) = split ' ',$line; 12 | if (not defined $ids{$id}) { 13 | $ids{$id} = [] 14 | } 15 | if (not $taxid ~~ @{$ids{$id}}) { 16 | push (@{$ids{$id}}, $taxid) 17 | } 18 | } 19 | close(GZIN); 20 | 21 | system ("mv $idfile $idfile.bak"); 22 | open (GZOUT, "| gzip -c > $idfile") or die("gzip $idfile: $!"); 23 | open(GZIN, "gunzip -c $idfile.bak |") or die("gunzip $idfile: $!"); 24 | foreach my $line () { 25 | #URS0000000001 RF00177 109.4 3.3e-33 2 200 29 230 Bacterial small subunit ribosomal RNA 26 | my @fields = split /\t/,$line; 27 | my $id = $fields[0]; 28 | 29 | foreach my $taxid (@{$ids{$id}}) { 30 | print GZOUT $id."_".$taxid."\t".join("\t",@fields[1..$#fields]); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/network/AuxiliaryPredictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from chemical import NAATOKENS 4 | 5 | class DistanceNetwork(nn.Module): 6 | def __init__(self, n_feat, p_drop=0.1): 7 | super(DistanceNetwork, self).__init__() 8 | # 9 | self.proj_symm = nn.Linear(n_feat, 37*2) 10 | self.proj_asymm = nn.Linear(n_feat, 37+19) 11 | 12 | self.reset_parameter() 13 | 14 | def reset_parameter(self): 15 | # initialize linear layer for final logit prediction 16 | nn.init.zeros_(self.proj_symm.weight) 17 | nn.init.zeros_(self.proj_asymm.weight) 18 | nn.init.zeros_(self.proj_symm.bias) 19 | nn.init.zeros_(self.proj_asymm.bias) 20 | 21 | def forward(self, x): 22 | # input: pair info (B, L, L, C) 23 | 24 | # predict theta, phi (non-symmetric) 25 | logits_asymm = self.proj_asymm(x) 26 | logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2) 27 | logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2) 28 | 29 | # predict dist, omega 30 | logits_symm = self.proj_symm(x) 31 | logits_symm = logits_symm + logits_symm.permute(0,2,1,3) 32 | logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2) 33 | logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2) 34 | 35 | return logits_dist, logits_omega, logits_theta, logits_phi 36 | 37 | class MaskedTokenNetwork(nn.Module): 38 | def __init__(self, n_feat, p_drop=0.1): 39 | super(MaskedTokenNetwork, self).__init__() 40 | self.proj = nn.Linear(n_feat, NAATOKENS) 41 | 42 | self.reset_parameter() 43 | 44 | def reset_parameter(self): 45 | nn.init.zeros_(self.proj.weight) 46 | nn.init.zeros_(self.proj.bias) 47 | 48 | def forward(self, x): 49 | B, N, L = x.shape[:3] 50 | logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L) 51 | 52 | return logits 53 | 54 | class LDDTNetwork(nn.Module): 55 | def __init__(self, n_feat, n_bin_lddt=50): 56 | super(LDDTNetwork, self).__init__() 57 | self.proj = nn.Linear(n_feat, n_bin_lddt) 58 | 59 | self.reset_parameter() 60 | 61 | def reset_parameter(self): 62 | nn.init.zeros_(self.proj.weight) 63 | nn.init.zeros_(self.proj.bias) 64 | 65 | def forward(self, x): 66 | logits = self.proj(x) # (B, L, 50) 67 | 68 | return logits.permute(0,2,1) 69 | 70 | class PAENetwork(nn.Module): 71 | def __init__(self, n_feat, n_bin_pae=64): 72 | super(PAENetwork, self).__init__() 73 | self.proj = nn.Linear(n_feat, n_bin_pae) 74 | self.reset_parameter() 75 | def reset_parameter(self): 76 | nn.init.zeros_(self.proj.weight) 77 | 
nn.init.zeros_(self.proj.bias) 78 | 79 | def forward(self, pair, state): 80 | L = pair.shape[1] 81 | left = state.unsqueeze(2).expand(-1,-1,L,-1) 82 | right = state.unsqueeze(1).expand(-1,L,-1,-1) 83 | 84 | logits = self.proj( torch.cat((pair, left, right), dim=-1) ) # (B, L, L, 64) 85 | 86 | return logits.permute(0,3,1,2) 87 | 88 | class BinderNetwork(nn.Module): 89 | def __init__(self, n_hidden=64, n_bin_pae=64): 90 | super(BinderNetwork, self).__init__() 91 | #self.proj = nn.Linear(n_bin_pae, n_hidden) 92 | #self.classify = torch.nn.Linear(2*n_hidden, 1) 93 | self.classify = torch.nn.Linear(n_bin_pae, 1) 94 | self.reset_parameter() 95 | 96 | def reset_parameter(self): 97 | #nn.init.zeros_(self.proj.weight) 98 | #nn.init.zeros_(self.proj.bias) 99 | nn.init.zeros_(self.classify.weight) 100 | nn.init.zeros_(self.classify.bias) 101 | 102 | def forward(self, pae, same_chain): 103 | #logits = self.proj( pae.permute(0,2,3,1) ) 104 | logits = pae.permute(0,2,3,1) 105 | #logits_intra = torch.mean( logits[same_chain==1], dim=0 ) 106 | logits_inter = torch.mean( logits[same_chain==0], dim=0 ).nan_to_num() # all zeros if single chain 107 | #prob = torch.sigmoid( self.classify( torch.cat((logits_intra,logits_inter)) ) ) 108 | prob = torch.sigmoid( self.classify( logits_inter ) ) 109 | return prob 110 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/network/SE3_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | #from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias 5 | #from equivariant_attention.modules import GConvSE3, GNormSE3 6 | #from equivariant_attention.fibers import Fiber 7 | 8 | from util_module import init_lecun_normal_param 9 | from se3_transformer.model import SE3Transformer 10 | from se3_transformer.model.fiber import Fiber 11 | 12 | class SE3TransformerWrapper(nn.Module): 13 | """SE(3) equivariant GCN with attention""" 14 | def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4, 15 | l0_in_features=32, l0_out_features=32, 16 | l1_in_features=3, l1_out_features=2, 17 | num_edge_features=32): 18 | super().__init__() 19 | # Build the network 20 | self.l1_in = l1_in_features 21 | self.l1_out = l1_out_features 22 | # 23 | fiber_edge = Fiber({0: num_edge_features}) 24 | if l1_out_features > 0: 25 | if l1_in_features > 0: 26 | fiber_in = Fiber({0: l0_in_features, 1: l1_in_features}) 27 | fiber_hidden = Fiber.create(num_degrees, num_channels) 28 | fiber_out = Fiber({0: l0_out_features, 1: l1_out_features}) 29 | else: 30 | fiber_in = Fiber({0: l0_in_features}) 31 | fiber_hidden = Fiber.create(num_degrees, num_channels) 32 | fiber_out = Fiber({0: l0_out_features, 1: l1_out_features}) 33 | else: 34 | if l1_in_features > 0: 35 | fiber_in = Fiber({0: l0_in_features, 1: l1_in_features}) 36 | fiber_hidden = Fiber.create(num_degrees, num_channels) 37 | fiber_out = Fiber({0: l0_out_features}) 38 | else: 39 | fiber_in = Fiber({0: l0_in_features}) 40 | fiber_hidden = Fiber.create(num_degrees, num_channels) 41 | fiber_out = Fiber({0: l0_out_features}) 42 | 43 | self.se3 = SE3Transformer(num_layers=num_layers, 44 | fiber_in=fiber_in, 45 | fiber_hidden=fiber_hidden, 46 | fiber_out = fiber_out, 47 | num_heads=n_heads, 48 | channels_div=div, 49 | fiber_edge=fiber_edge, 50 | #populate_edges=False, 51 | #sum_over_edge=False, 52 | use_layer_norm=True, 53 | tensor_cores=True, 54 | low_memory=True)#, 55 | 
#populate_edge='log') 56 | 57 | self.reset_parameter() 58 | 59 | def reset_parameter(self): 60 | 61 | # make sure linear layer before ReLu are initialized with kaiming_normal_ 62 | for n, p in self.se3.named_parameters(): 63 | if "bias" in n: 64 | nn.init.zeros_(p) 65 | elif len(p.shape) == 1: 66 | continue 67 | else: 68 | if "radial_func" not in n: 69 | p = init_lecun_normal_param(p) 70 | else: 71 | if "net.6" in n: 72 | nn.init.zeros_(p) 73 | else: 74 | nn.init.kaiming_normal_(p, nonlinearity='relu') 75 | 76 | # make last layers to be zero-initialized 77 | self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0']) 78 | self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1']) 79 | nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0']) 80 | if self.l1_out > 0: 81 | nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1']) 82 | 83 | def forward(self, G, type_0_features, type_1_features=None, edge_features=None): 84 | if self.l1_in > 0: 85 | node_features = {'0': type_0_features, '1': type_1_features} 86 | else: 87 | node_features = {'0': type_0_features} 88 | edge_features = {'0': edge_features} 89 | return self.se3(G, node_features, edge_features) 90 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/network/coords6d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import scipy.spatial 4 | from util import generate_Cbeta 5 | 6 | # calculate dihedral angles defined by 4 sets of points 7 | def get_dihedrals(a, b, c, d): 8 | 9 | b0 = -1.0*(b - a) 10 | b1 = c - b 11 | b2 = d - c 12 | 13 | b1 /= np.linalg.norm(b1, axis=-1)[:,None] 14 | 15 | v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1 16 | w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1 17 | 18 | x = np.sum(v*w, axis=-1) 19 | y = np.sum(np.cross(b1, v)*w, axis=-1) 20 | 21 | return np.arctan2(y, x) 22 | 23 | # calculate planar angles defined by 3 sets of points 24 | def get_angles(a, b, c): 25 | 26 | v = a - b 27 | v /= np.linalg.norm(v, axis=-1)[:,None] 28 | 29 | w = c - b 30 | w /= np.linalg.norm(w, axis=-1)[:,None] 31 | 32 | x = np.sum(v*w, axis=1) 33 | 34 | #return np.arccos(x) 35 | return np.arccos(np.clip(x, -1.0, 1.0)) 36 | 37 | # get 6d coordinates from x,y,z coords of N,Ca,C atoms 38 | def get_coords6d(xyz, dmax): 39 | 40 | nres = xyz.shape[1] 41 | 42 | # three anchor atoms 43 | N = xyz[0] 44 | Ca = xyz[1] 45 | C = xyz[2] 46 | 47 | # recreate Cb given N,Ca,C 48 | Cb = generate_Cbeta(N,Ca,C) 49 | 50 | # fast neighbors search to collect all 51 | # Cb-Cb pairs within dmax 52 | kdCb = scipy.spatial.cKDTree(Cb) 53 | indices = kdCb.query_ball_tree(kdCb, dmax) 54 | 55 | # indices of contacting residues 56 | idx = np.array([[i,j] for i in range(len(indices)) for j in indices[i] if i != j]).T 57 | idx0 = idx[0] 58 | idx1 = idx[1] 59 | 60 | # Cb-Cb distance matrix 61 | dist6d = np.full((nres, nres),999.9, dtype=np.float32) 62 | dist6d[idx0,idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1) 63 | 64 | # matrix of Ca-Cb-Cb-Ca dihedrals 65 | omega6d = np.zeros((nres, nres), dtype=np.float32) 66 | omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1]) 67 | 68 | # matrix of polar coord theta 69 | theta6d = np.zeros((nres, nres), dtype=np.float32) 70 | theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1]) 71 | 72 | # matrix of polar coord phi 73 | phi6d 
= np.zeros((nres, nres), dtype=np.float32) 74 | phi6d[idx0,idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1]) 75 | 76 | mask = np.zeros((nres, nres), dtype=np.float32) 77 | mask[idx0, idx1] = 1.0 78 | return dist6d, omega6d, theta6d, phi6d, mask 79 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/network/ffindex.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py 3 | 4 | ''' 5 | Created on Apr 30, 2014 6 | 7 | @author: meiermark 8 | ''' 9 | 10 | 11 | import os 12 | import mmap 13 | from collections import namedtuple 14 | 15 | FFindexEntry = namedtuple("FFindexEntry", "name, offset, length") 16 | 17 | 18 | def read_index(ffindex_filename): 19 | if os.path.getsize(ffindex_filename) == 0: 20 | return None 21 | 22 | entries = [] 23 | 24 | fh = open(ffindex_filename) 25 | for line in fh: 26 | tokens = line.split("\t") 27 | entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2]))) 28 | fh.close() 29 | 30 | return entries 31 | 32 | 33 | def read_data(ffdata_filename): 34 | if os.path.getsize(ffdata_filename) == 0: 35 | return None 36 | 37 | fh = open(ffdata_filename, "rb") 38 | data = mmap.mmap(fh.fileno(), 0, prot=mmap.PROT_READ) 39 | fh.close() 40 | 41 | return data 42 | 43 | 44 | def get_entry_by_name(name, index): 45 | #TODO: bsearch 46 | if index is None: 47 | return None 48 | for entry in index: 49 | if(name == entry.name): 50 | return entry 51 | return None 52 | 53 | 54 | def read_entry_lines(entry, data): 55 | if data is None: 56 | return [] 57 | lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n") 58 | return lines 59 | 60 | 61 | def read_entry_data(entry, data): 62 | return data[entry.offset:entry.offset + entry.length - 1] 63 | 64 | 65 | def write_entry(entries, data_fh, entry_name, offset, data): 66 | data_fh.write(data[:-1]) 67 | data_fh.write(bytearray(1)) 68 | 69 | entry = FFindexEntry(entry_name, offset, len(data)) 70 | entries.append(entry) 71 | 72 | return offset + len(data) 73 | 74 | 75 | def write_entry_with_file(entries, data_fh, entry_name, offset, file_name): 76 | with open(file_name, "rb") as fh: 77 | data = bytearray(fh.read()) 78 | return write_entry(entries, data_fh, entry_name, offset, data) 79 | 80 | 81 | def finish_db(entries, ffindex_filename, data_fh): 82 | data_fh.close() 83 | write_entries_to_db(entries, ffindex_filename) 84 | 85 | 86 | def write_entries_to_db(entries, ffindex_filename): 87 | sorted(entries, key=lambda x: x.name) 88 | index_fh = open(ffindex_filename, "w") 89 | 90 | for entry in entries: 91 | index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length)) 92 | 93 | index_fh.close() 94 | 95 | 96 | def write_entry_to_file(entry, data, file): 97 | lines = read_lines(entry, data) 98 | 99 | fh = open(file, "w") 100 | for line in lines: 101 | fh.write(line+"\n") 102 | fh.close() 103 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/network/models.json: -------------------------------------------------------------------------------- 1 | { 2 | "full_bigSE3": 3 | { 4 | "description": "deep architecture w/ big SE(3)-Transformer on fully connected graph. 
Trained on biounit", 5 | "model_param":{ 6 | "n_extra_block" : 4, 7 | "n_main_block" : 32, 8 | "n_ref_block" : 4, 9 | "d_msa" : 256 , 10 | "d_pair" : 128, 11 | "d_templ" : 64, 12 | "n_head_msa" : 8, 13 | "n_head_pair" : 4, 14 | "n_head_templ" : 4, 15 | "d_hidden" : 32, 16 | "d_hidden_templ" : 64, 17 | "p_drop" : 0.0, 18 | "lj_lin" : 0.75, 19 | "SE3_param": { 20 | "num_layers" : 1, 21 | "num_channels" : 32, 22 | "num_degrees" : 2, 23 | "l0_in_features": 64, 24 | "l0_out_features": 64, 25 | "l1_in_features": 3, 26 | "l1_out_features": 2, 27 | "num_edge_features": 64, 28 | "div": 4, 29 | "n_heads": 4 30 | } 31 | }, 32 | "weight_fn": ["full_bigSE3_model1.pt", "full_bigSE3_model2.pt", "full_bigSE3_model3.pt"] 33 | }, 34 | "full_smallSE3": 35 | { 36 | "description": "deep architecture w/ small SE(3)-Transformer on fully connected graph. Trained on biounit", 37 | "model_param":{ 38 | "n_extra_block" : 4, 39 | "n_main_block" : 32, 40 | "n_ref_block" : 4, 41 | "d_msa" : 256 , 42 | "d_pair" : 128, 43 | "d_templ" : 64, 44 | "n_head_msa" : 8, 45 | "n_head_pair" : 4, 46 | "n_head_templ" : 4, 47 | "d_hidden" : 32, 48 | "d_hidden_templ" : 32, 49 | "p_drop" : 0.0, 50 | "SE3_param_full": { 51 | "num_layers" : 1, 52 | "num_channels" : 32, 53 | "num_degrees" : 2, 54 | "l0_in_features": 8, 55 | "l0_out_features": 8, 56 | "l1_in_features": 3, 57 | "l1_out_features": 2, 58 | "num_edge_features": 32, 59 | "div": 4, 60 | "n_heads": 4 61 | }, 62 | "SE3_param_topk": { 63 | "num_layers" : 1, 64 | "num_channels" : 32, 65 | "num_degrees" : 2, 66 | "l0_in_features": 64, 67 | "l0_out_features": 64, 68 | "l1_in_features": 3, 69 | "l1_out_features": 2, 70 | "num_edge_features": 64, 71 | "div": 4, 72 | "n_heads": 4 73 | } 74 | }, 75 | "weight_fn": ["full_smallSE3_model1.pt", "full_smallSE3_model2.pt"] 76 | } 77 | } 78 | 79 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/network/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | 5 | # pre-activation bottleneck resblock 6 | class ResBlock2D_bottleneck(nn.Module): 7 | def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15): 8 | super(ResBlock2D_bottleneck, self).__init__() 9 | padding = self._get_same_padding(kernel, dilation) 10 | 11 | n_b = n_c // 2 # bottleneck channel 12 | 13 | layer_s = list() 14 | # pre-activation 15 | layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6)) 16 | layer_s.append(nn.ELU(inplace=True)) 17 | # project down to n_b 18 | layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False)) 19 | layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6)) 20 | layer_s.append(nn.ELU(inplace=True)) 21 | # convolution 22 | layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False)) 23 | layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6)) 24 | layer_s.append(nn.ELU(inplace=True)) 25 | # dropout 26 | layer_s.append(nn.Dropout(p_drop)) 27 | # project up 28 | layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False)) 29 | 30 | # make final layer initialize with zeros 31 | #nn.init.zeros_(layer_s[-1].weight) 32 | 33 | self.layer = nn.Sequential(*layer_s) 34 | 35 | self.reset_parameter() 36 | 37 | def reset_parameter(self): 38 | # zero-initialize final layer right before residual connection 39 | nn.init.zeros_(self.layer[-1].weight) 40 | 41 | def _get_same_padding(self, kernel, dilation): 42 | return (kernel + (kernel - 1) * (dilation - 1) 
- 1) // 2 43 | 44 | def forward(self, x): 45 | out = self.layer(x) 46 | return x + out 47 | 48 | class ResidualNetwork(nn.Module): 49 | def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out, 50 | dilation=[1,2,4,8], p_drop=0.15): 51 | super(ResidualNetwork, self).__init__() 52 | 53 | 54 | layer_s = list() 55 | # project to n_feat_block 56 | if n_feat_in != n_feat_block: 57 | layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False)) 58 | 59 | # add resblocks 60 | for i_block in range(n_block): 61 | d = dilation[i_block%len(dilation)] 62 | res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop) 63 | layer_s.append(res_block) 64 | 65 | if n_feat_out != n_feat_block: 66 | # project to n_feat_out 67 | layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1)) 68 | 69 | self.layer = nn.Sequential(*layer_s) 70 | 71 | def forward(self, x): 72 | return self.layer(x) 73 | -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffdata: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffdata -------------------------------------------------------------------------------- /forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffindex: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffindex -------------------------------------------------------------------------------- /img/Nucleic_Acid_Diffusion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/img/Nucleic_Acid_Diffusion.gif -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/logs/.gitkeep -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/notebooks/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /scripts/schedule.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="MMDiff", 7 | version="0.0.1", 8 | description="Official PyTorch implementation of 'Joint Sequence-Structure Generation of Nucleic Acid and Protein Complexes with SE(3)-Discrete Diffusion'.", 9 | author="Alex Morehead", 10 | author_email="alex.morehead@gmail.com", 11 | url="https://github.com/Profluent-Internships/MMDiff", 12 | install_requires=["lightning", "hydra-core"], 13 | packages=find_packages(), 14 | # use this to customize global commands available in the terminal after installing the package 15 | entry_points={ 16 | "console_scripts": [ 17 | "train_command = src.train:main", 18 | "sample_command = src.sample:main", 19 | ] 20 | }, 21 | ) 22 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | from beartype import beartype 4 | from beartype.typing import Any 5 | from omegaconf import OmegaConf 6 | 7 | 8 | @beartype 9 | def resolve_omegaconf_variable(variable_path: str) -> Any: 10 | # split the string into parts using the dot separator 11 | parts = variable_path.rsplit(".", 1) 12 | 13 | # get the module name from the first part of the path 14 | module_name = parts[0] 15 | 16 | # dynamically import the module using the module name 17 | module = importlib.import_module(module_name) 18 | 19 | # use the imported module to get the requested attribute value 20 | attribute = getattr(module, parts[1]) 21 | 22 | return attribute 23 | 24 | 25 | def register_custom_omegaconf_resolvers(): 26 | OmegaConf.register_new_resolver("add", lambda x, y: x + y) 27 | OmegaConf.register_new_resolver("subtract", lambda x, y: x - y) 28 | OmegaConf.register_new_resolver( 29 | "resolve_variable", lambda variable_path: resolve_omegaconf_variable(variable_path) 30 | ) 31 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/src/data/components/__init__.py -------------------------------------------------------------------------------- /src/data/components/pdb/complex_constants.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for (https://github.com/Profluent-Internships/MMDiff): 3 | # 
------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | NUM_PROT_TORSIONS = 1 6 | NUM_NA_TORSIONS = 8 7 | NUM_PROT_NA_TORSIONS = 10 8 | PYRIMIDINE_RESIDUE_TOKENS = [22, 24, 26, 28] 9 | 10 | # This is the standard residue order when coding protein-nucleic acid residue types as a number (for PDB visualization purposes). 11 | restypes = [ 12 | "A", 13 | "C", 14 | "D", 15 | "E", 16 | "F", 17 | "G", 18 | "H", 19 | "I", 20 | "K", 21 | "L", 22 | "M", 23 | "N", 24 | "P", 25 | "Q", 26 | "R", 27 | "S", 28 | "T", 29 | "V", 30 | "W", 31 | "Y", 32 | "X", 33 | "da", 34 | "dc", 35 | "dg", 36 | "dt", 37 | "a", 38 | "c", 39 | "g", 40 | "u", 41 | "x", 42 | "-", 43 | "_", 44 | "1", 45 | "2", 46 | "3", 47 | "4", 48 | "5", 49 | ] 50 | restype_order = {restype: i for i, restype in enumerate(restypes)} 51 | restype_num = len(restypes) # := 37. 52 | 53 | restype_1to3 = { 54 | "A": "ALA", 55 | "C": "CYS", 56 | "D": "ASP", 57 | "E": "GLU", 58 | "F": "PHE", 59 | "G": "GLY", 60 | "H": "HIS", 61 | "I": "ILE", 62 | "K": "LYS", 63 | "L": "LEU", 64 | "M": "MET", 65 | "N": "ASN", 66 | "P": "PRO", 67 | "Q": "GLN", 68 | "R": "ARG", 69 | "S": "SER", 70 | "T": "THR", 71 | "V": "VAL", 72 | "W": "TRP", 73 | "Y": "TYR", 74 | "X": "UNK", 75 | "da": "DA", 76 | "dc": "DC", 77 | "dg": "DG", 78 | "dt": "DT", 79 | "a": "A", 80 | "c": "C", 81 | "g": "G", 82 | "u": "U", 83 | "x": "unk", 84 | "-": "GAP", 85 | "_": "PAD", 86 | "1": "SP1", 87 | "2": "SP2", 88 | "3": "SP3", 89 | "4": "SP4", 90 | "5": "SP5", 91 | } 92 | restype_3to1 = {v: k for k, v in restype_1to3.items()} 93 | restypes_3 = list(restype_3to1.keys()) 94 | restypes_1 = list(restype_1to3.keys()) 95 | 96 | protein_restypes = [ 97 | "A", 98 | "R", 99 | "N", 100 | "D", 101 | "C", 102 | "Q", 103 | "E", 104 | "G", 105 | "H", 106 | "I", 107 | "L", 108 | "K", 109 | "M", 110 | "F", 111 | "P", 112 | "S", 113 | "T", 114 | "W", 115 | "Y", 116 | "V", 117 | ] 118 | nucleic_restypes = ["DA", "DC", "DG", "DT", "A", "C", "G", "U"] 119 | special_restypes = ["-", "_", "1", "2", "3", "4", "5"] 120 | unknown_protein_restype = "X" 121 | unknown_nucleic_restype = "x" 122 | gap_token = restypes.index("-") 123 | pad_token = restypes.index("_") 124 | unknown_protein_token = restypes.index(unknown_protein_restype) # := 20 125 | unknown_nucleic_token = restypes.index(unknown_nucleic_restype) # := 29 126 | 127 | protein_restype_num = len(protein_restypes + [unknown_protein_restype]) # := 21 128 | na_restype_num = len(nucleic_restypes + [unknown_nucleic_restype]) # := 9 129 | protein_na_restype_num = len( 130 | protein_restypes + [unknown_protein_restype] + nucleic_restypes + [unknown_nucleic_restype] 131 | ) # := 30 132 | 133 | default_protein_restype = restypes.index("A") # := 0 134 | default_na_restype = restypes.index("da") # := 21 135 | 136 | alternative_restypes_map = { 137 | # Protein 138 | "MSE": "MET", 139 | } 140 | allowable_restypes = set(restypes + list(alternative_restypes_map.keys())) 141 | -------------------------------------------------------------------------------- /src/data/components/pdb/errors.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code adapted from se3_diffusion (https://github.com/jasonkyuyim/se3_diffusion): 3 | # 
------------------------------------------------------------------------------------------------------------------------------------- 4 | """Error class for handled errors.""" 5 | 6 | 7 | class DataError(Exception): 8 | """Data exception.""" 9 | 10 | pass 11 | 12 | 13 | class FileExistsError(DataError): 14 | """Raised when file already exists.""" 15 | 16 | pass 17 | 18 | 19 | class MmcifParsingError(DataError): 20 | """Raised when mmcif parsing fails.""" 21 | 22 | pass 23 | 24 | 25 | class ResolutionError(DataError): 26 | """Raised when resolution isn't acceptable.""" 27 | 28 | pass 29 | 30 | 31 | class LengthError(DataError): 32 | """Raised when length isn't acceptable.""" 33 | 34 | pass 35 | -------------------------------------------------------------------------------- /src/data/components/pdb/join_pdb_metadata.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for (https://github.com/Profluent-Internships/MMDiff): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import hydra 6 | import pandas as pd 7 | from omegaconf import DictConfig 8 | 9 | 10 | @hydra.main( 11 | version_base="1.3", config_path="../../../../configs/paths", config_name="pdb_metadata.yaml" 12 | ) 13 | def main(cfg: DictConfig): 14 | na_df = pd.read_csv(cfg.na_metadata_csv_path) 15 | protein_df = pd.read_csv(cfg.protein_metadata_csv_path) 16 | # impute missing columns for the nucleic acid DataFrame 17 | na_df["oligomeric_count"] = na_df["num_chains"].astype(str) 18 | na_df["oligomeric_detail"] = na_df["num_chains"].apply( 19 | lambda x: "heteromeric" if x > 1 else "monomeric" 20 | ) 21 | # note: we can reasonably assume the following two column values due to our initial filtering for nucleic acid molecules 22 | na_df["resolution"] = 0.0 23 | na_df["structure_method"] = "x-ray diffraction" 24 | output_df = pd.concat([na_df, protein_df]).drop_duplicates( 25 | subset=["pdb_name"], keep="first", ignore_index=True 26 | ) 27 | output_df.to_csv(cfg.metadata_output_csv_path, index=False) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /src/data/components/pdb/relax/relax_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code adapted from openfold (https://github.com/aqlaboratory/openfold): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | # Copyright 2021 AlQuraishi Laboratory 6 | # Copyright 2021 DeepMind Technologies Limited 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | """Utils for minimization.""" 20 | import io 21 | 22 | import numpy as np 23 | from Bio import PDB 24 | 25 | try: 26 | # openmm >= 7.6 27 | from openmm import app as openmm_app 28 | from openmm.app.internal.pdbstructure import PdbStructure 29 | except ImportError: 30 | # openmm < 7.6 (requires DeepMind patch) 31 | from simtk.openmm import app as openmm_app 32 | from simtk.openmm.app.internal.pdbstructure import PdbStructure 33 | 34 | from src.data.components.pdb import protein_constants 35 | 36 | 37 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: 38 | pdb_file = io.StringIO(pdb_str) 39 | structure = PdbStructure(pdb_file) 40 | topology = openmm_app.PDBFile(structure).getTopology() 41 | with io.StringIO() as f: 42 | openmm_app.PDBFile.writeFile(topology, pos, f) 43 | return f.getvalue() 44 | 45 | 46 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 47 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 48 | 49 | Args: 50 | pdb_str: An input PDB string. 51 | bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the 52 | B-factors are per residue; i.e. that the nonzero entries are identical in 53 | [0, i, :]. 54 | 55 | Returns: 56 | A new PDB string with the B-factors replaced. 57 | """ 58 | if bfactors.shape[-1] != protein_constants.atom_type_num: 59 | raise ValueError(f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}.") 60 | 61 | parser = PDB.PDBParser(QUIET=True) 62 | handle = io.StringIO(pdb_str) 63 | structure = parser.get_structure("", handle) 64 | 65 | curr_resid = ("", "", "") 66 | idx = -1 67 | for atom in structure.get_atoms(): 68 | atom_resid = atom.parent.get_id() 69 | if atom_resid != curr_resid: 70 | idx += 1 71 | if idx >= bfactors.shape[0]: 72 | raise ValueError( 73 | "Index into bfactors exceeds number of residues. " 74 | "B-factors shape: {shape}, idx: {idx}." 75 | ) 76 | curr_resid = atom_resid 77 | atom.bfactor = bfactors[idx, protein_constants.atom_order["CA"]] 78 | 79 | new_pdb = io.StringIO() 80 | pdb_io = PDB.PDBIO() 81 | pdb_io.set_structure(structure) 82 | pdb_io.save(new_pdb) 83 | return new_pdb.getvalue() 84 | 85 | 86 | def assert_equal_nonterminal_atom_types(atom_mask: np.ndarray, ref_atom_mask: np.ndarray): 87 | """Checks that pre- and post-minimized proteins have same atom set.""" 88 | # Ignore any terminal OXT atoms which may have been added by minimization. 89 | oxt = protein_constants.atom_order["OXT"] 90 | no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool) 91 | no_oxt_mask[..., oxt] = False 92 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]) 93 | -------------------------------------------------------------------------------- /src/data/components/pdb/vocabulary.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for MMDiff (https://github.com/Profluent-Internships/MMDiff): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | # This is the standard residue order when coding residues type as a number. 
6 | restypes = [ 7 | "A", 8 | "C", 9 | "D", 10 | "E", 11 | "F", 12 | "G", 13 | "H", 14 | "I", 15 | "K", 16 | "L", 17 | "M", 18 | "N", 19 | "P", 20 | "Q", 21 | "R", 22 | "S", 23 | "T", 24 | "V", 25 | "W", 26 | "Y", 27 | "X", 28 | "a", 29 | "c", 30 | "g", 31 | "t", 32 | "u", 33 | "x", 34 | "-", 35 | "_", 36 | "1", 37 | "2", 38 | "3", 39 | "4", 40 | "5", 41 | ] 42 | restype_order = {restype: i for i, restype in enumerate(restypes)} 43 | 44 | restype_1to3 = { 45 | "A": "ALA", 46 | "C": "CYS", 47 | "D": "ASP", 48 | "E": "GLU", 49 | "F": "PHE", 50 | "G": "GLY", 51 | "H": "HIS", 52 | "I": "ILE", 53 | "K": "LYS", 54 | "L": "LEU", 55 | "M": "MET", 56 | "N": "ASN", 57 | "P": "PRO", 58 | "Q": "GLN", 59 | "R": "ARG", 60 | "S": "SER", 61 | "T": "THR", 62 | "V": "VAL", 63 | "W": "TRP", 64 | "Y": "TYR", 65 | "X": "UNK", 66 | "a": "A", 67 | "c": "C", 68 | "g": "G", 69 | "t": "T", 70 | "u": "U", 71 | "x": "unk", 72 | "-": "GAP", 73 | "_": "PAD", 74 | "1": "SP1", 75 | "2": "SP2", 76 | "3": "SP3", 77 | "4": "SP4", 78 | "5": "SP5", 79 | } 80 | restype_3to1 = {v: k for k, v in restype_1to3.items()} 81 | restype_3to1.update( 82 | { 83 | "DA": "a", 84 | "DC": "c", 85 | "DG": "g", 86 | "DT": "t", 87 | } 88 | ) 89 | 90 | protein_restypes = [ 91 | "A", 92 | "C", 93 | "D", 94 | "E", 95 | "F", 96 | "G", 97 | "H", 98 | "I", 99 | "K", 100 | "L", 101 | "M", 102 | "N", 103 | "P", 104 | "Q", 105 | "R", 106 | "S", 107 | "T", 108 | "V", 109 | "W", 110 | "Y", 111 | ] 112 | protein_restype_order = {restype: i for i, restype in enumerate(protein_restypes)} 113 | 114 | protein_restype_1to3 = { 115 | "A": "ALA", 116 | "C": "CYS", 117 | "D": "ASP", 118 | "E": "GLU", 119 | "F": "PHE", 120 | "G": "GLY", 121 | "H": "HIS", 122 | "I": "ILE", 123 | "K": "LYS", 124 | "L": "LEU", 125 | "M": "MET", 126 | "N": "ASN", 127 | "P": "PRO", 128 | "Q": "GLN", 129 | "R": "ARG", 130 | "S": "SER", 131 | "T": "THR", 132 | "V": "VAL", 133 | "W": "TRP", 134 | "Y": "TYR", 135 | } 136 | protein_restype_3to1 = {v: k for k, v in protein_restype_1to3.items()} 137 | 138 | protein_restypes_with_x = protein_restypes + ["X"] 139 | protein_restype_order_with_x = {restype: i for i, restype in enumerate(protein_restypes_with_x)} 140 | nucleic_restypes = ["a", "c", "g", "t", "u"] 141 | special_restypes = ["-", "_", "1", "2", "3", "4", "5"] 142 | unknown_protein_restype = "X" 143 | unknown_nucleic_restype = "x" 144 | gap_token = restypes.index("-") 145 | pad_token = restypes.index("_") 146 | 147 | restype_num = len(restypes) # := 34. 
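# note: the 34 entries in `restypes` above comprise the 20 standard amino acids, "X" (unknown
# protein residue), the 5 nucleotides "a"/"c"/"g"/"t"/"u", "x" (unknown nucleotide), the gap and
# pad tokens "-"/"_", and the 5 special tokens "1" through "5".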
148 | protein_restype_num = len(protein_restypes) # := 20 149 | protein_restype_num_with_x = len(protein_restypes_with_x) # := 21 150 | 151 | protein_resnames = [restype_1to3[r] for r in protein_restypes] 152 | protein_resname_to_idx = {resname: i for i, resname in enumerate(protein_resnames)} 153 | 154 | alternative_restypes_map = { 155 | # Protein 156 | "MSE": "MET", 157 | } 158 | allowable_restypes = set(restypes + list(alternative_restypes_map.keys())) 159 | 160 | 161 | def is_protein_sequence(sequence: str) -> bool: 162 | """Check if a sequence is a protein sequence.""" 163 | 164 | return all([s in protein_restypes for s in sequence]) 165 | 166 | 167 | def is_nucleic_sequence(sequence: str) -> bool: 168 | """Check if a sequence is a nucleic acid sequence.""" 169 | 170 | return all([s in nucleic_restypes for s in sequence]) 171 | -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/src/models/components/__init__.py -------------------------------------------------------------------------------- /src/models/components/pdb/embedders.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------------------------- 2 | # Following code curated for MMDiff (https://github.com/Profluent-Internships/MMDiff): 3 | # ------------------------------------------------------------------------------------------------------------------------------------- 4 | 5 | import torch 6 | import torch.nn as nn 7 | from beartype import beartype 8 | from beartype.typing import Optional 9 | 10 | from src.data.components.pdb.data_transforms import make_one_hot 11 | from src.models.components.pdb.framediff import Linear 12 | 13 | 14 | class RelPosEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | embedding_size: int, 18 | max_relative_idx: int, 19 | max_relative_chain: int = 0, 20 | use_chain_relative: bool = False, 21 | ): 22 | super().__init__() 23 | self.max_relative_idx = max_relative_idx 24 | self.max_relative_chain = max_relative_chain 25 | self.use_chain_relative = use_chain_relative 26 | self.num_bins = 2 * max_relative_idx + 2 27 | if max_relative_chain > 0: 28 | self.num_bins += 2 * max_relative_chain + 2 29 | if use_chain_relative: 30 | self.num_bins += 1 31 | 32 | self.linear_relpos = Linear(self.num_bins, embedding_size) 33 | 34 | @beartype 35 | def forward( 36 | self, 37 | residue_index: torch.Tensor, 38 | asym_id: Optional[torch.Tensor] = None, 39 | sym_id: Optional[torch.Tensor] = None, 40 | entity_id: Optional[torch.Tensor] = None, 41 | ) -> torch.Tensor: 42 | d = residue_index[..., None] - residue_index[..., None, :] 43 | 44 | if asym_id is None: 45 | # compute relative position encoding according to AlphaFold's `relpos` algorithm 46 | boundaries = torch.arange( 47 | start=-self.max_relative_idx, end=self.max_relative_idx + 1, device=d.device 48 | ) 49 | reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),)) 50 | d = d[..., None] - reshaped_bins 51 | d = torch.abs(d) 52 | d = torch.argmin(d, dim=-1) 53 | d = nn.functional.one_hot(d, num_classes=len(boundaries)).float() 54 | d = d.to(residue_index.dtype) 55 | rel_feat = d 56 | else: 57 | # compute relative position encoding according to 
AlphaFold-Multimer's `relpos` algorithm 58 | rel_feats = [] 59 | asym_id_same = torch.eq(asym_id[..., None], asym_id[..., None, :]) 60 | offset = residue_index[..., None] - residue_index[..., None, :] 61 | 62 | clipped_offset = torch.clamp( 63 | input=offset + self.max_relative_idx, min=0, max=(2 * self.max_relative_idx) 64 | ) 65 | 66 | final_offset = torch.where( 67 | condition=asym_id_same, 68 | input=clipped_offset, 69 | other=((2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset)), 70 | ) 71 | 72 | rel_pos = make_one_hot(x=final_offset, num_classes=(2 * self.max_relative_idx + 2)) 73 | rel_feats.append(rel_pos) 74 | 75 | if self.use_chain_relative: 76 | entity_id_same = torch.eq(entity_id[..., None], entity_id[..., None, :]) 77 | rel_feats.append(entity_id_same.type(rel_pos.dtype)[..., None]) 78 | 79 | if self.max_relative_chain > 0: 80 | rel_sym_id = sym_id[..., None] - sym_id[..., None, :] 81 | max_rel_chain = self.max_relative_chain 82 | 83 | clipped_rel_chain = torch.clamp( 84 | input=rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain 85 | ) 86 | 87 | if not self.use_chain_relative: 88 | # ensure `entity_id_same` is constructed for `rel_chain` 89 | entity_id_same = torch.eq(entity_id[..., None], entity_id[..., None, :]) 90 | 91 | final_rel_chain = torch.where( 92 | condition=entity_id_same, 93 | input=clipped_rel_chain, 94 | other=(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain), 95 | ) 96 | rel_chain = make_one_hot( 97 | x=final_rel_chain, num_classes=2 * self.max_relative_chain + 2 98 | ) 99 | rel_feats.append(rel_chain) 100 | 101 | rel_feat = torch.cat(rel_feats, dim=-1) 102 | 103 | return self.linear_relpos(rel_feat) 104 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from src.utils.logging_utils import log_hyperparameters 3 | from src.utils.pylogger import get_pylogger 4 | from src.utils.rich_utils import enforce_tags, print_config_tree 5 | from src.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /src/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from beartype.typing import List 3 | from lightning import Callback 4 | from lightning.pytorch.loggers import Logger 5 | from omegaconf import DictConfig 6 | 7 | from src.utils import pylogger 8 | 9 | log = pylogger.get_pylogger(__name__) 10 | 11 | 12 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 13 | """Instantiates callbacks from config.""" 14 | 15 | callbacks: List[Callback] = [] 16 | 17 | if not callbacks_cfg: 18 | log.warning("No callback configs found! 
Skipping..") 19 | return callbacks 20 | 21 | if not isinstance(callbacks_cfg, DictConfig): 22 | raise TypeError("Callbacks config must be a DictConfig!") 23 | 24 | for _, cb_conf in callbacks_cfg.items(): 25 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 26 | log.info(f"Instantiating callback <{cb_conf._target_}>") 27 | callbacks.append(hydra.utils.instantiate(cb_conf)) 28 | 29 | return callbacks 30 | 31 | 32 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 33 | """Instantiates loggers from config.""" 34 | 35 | logger: List[Logger] = [] 36 | 37 | if not logger_cfg: 38 | log.warning("No logger configs found! Skipping...") 39 | return logger 40 | 41 | if not isinstance(logger_cfg, DictConfig): 42 | raise TypeError("Logger config must be a DictConfig!") 43 | 44 | for _, lg_conf in logger_cfg.items(): 45 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 46 | log.info(f"Instantiating logger <{lg_conf._target_}>") 47 | logger.append(hydra.utils.instantiate(lg_conf)) 48 | 49 | return logger 50 | -------------------------------------------------------------------------------- /src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | 3 | from src.utils import pylogger 4 | 5 | log = pylogger.get_pylogger(__name__) 6 | 7 | 8 | @rank_zero_only 9 | def log_hyperparameters(object_dict: dict) -> None: 10 | """Controls which config parts are saved by lightning loggers. 11 | 12 | Additionally saves: 13 | - Number of model parameters 14 | """ 15 | 16 | hparams = {} 17 | 18 | cfg = object_dict["cfg"] 19 | model = object_dict["model"] 20 | trainer = object_dict["trainer"] 21 | 22 | if not trainer.logger: 23 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 24 | return 25 | 26 | hparams["model"] = cfg["model"] 27 | 28 | # save number of model parameters 29 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 30 | hparams["model/params/trainable"] = sum( 31 | p.numel() for p in model.parameters() if p.requires_grad 32 | ) 33 | hparams["model/params/non_trainable"] = sum( 34 | p.numel() for p in model.parameters() if not p.requires_grad 35 | ) 36 | 37 | hparams["data"] = cfg["data"] 38 | hparams["trainer"] = cfg["trainer"] 39 | 40 | hparams["callbacks"] = cfg.get("callbacks") 41 | hparams["extras"] = cfg.get("extras") 42 | 43 | hparams["task_name"] = cfg.get("task_name") 44 | hparams["tags"] = cfg.get("tags") 45 | hparams["ckpt_path"] = cfg.get("ckpt_path") 46 | hparams["seed"] = cfg.get("seed") 47 | 48 | # send hparams to all loggers 49 | for logger in trainer.loggers: 50 | logger.log_hyperparams(hparams) 51 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import rich 4 | import rich.syntax 5 | import rich.tree 6 | from beartype.typing import Sequence 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints content of DictConfig using Rich library and its tree structure. 33 | 34 | Args: 35 | cfg (DictConfig): Configuration composed by Hydra. 36 | print_order (Sequence[str], optional): Determines in what order config components are printed. 37 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 38 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 39 | """ 40 | 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 75 | rich.print(tree, file=file) 76 | 77 | 78 | @rank_zero_only 79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 80 | """Prompts user to input tags from command line if no tags are provided in config.""" 81 | 82 | if not cfg.get("tags"): 83 | if "id" in HydraConfig().cfg.hydra.job: 84 | raise ValueError("Specify tags before launching a multirun!") 85 | 86 | log.warning("No tags provided in config. Prompting user to input tags...") 87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 88 | tags = [t.strip() for t in tags.split(",") if t != ""] 89 | 90 | with open_dict(cfg): 91 | cfg.tags = tags 92 | 93 | log.info(f"Tags: {cfg.tags}") 94 | 95 | if save_to_file: 96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 97 | rich.print(cfg.tags, file=file) 98 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """This file prepares config fixtures for other tests.""" 2 | 3 | import pyrootutils 4 | import pytest 5 | from hydra import compose, initialize 6 | from hydra.core.global_hydra import GlobalHydra 7 | from omegaconf import DictConfig, open_dict 8 | 9 | 10 | @pytest.fixture(scope="package") 11 | def cfg_train_global() -> DictConfig: 12 | with initialize(version_base="1.3", config_path="../configs"): 13 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 14 | 15 | # set defaults for all tests 16 | with open_dict(cfg): 17 | cfg.paths.root_dir = str(pyrootutils.find_root(indicator=".project-root")) 18 | cfg.trainer.max_epochs = 1 19 | cfg.trainer.limit_train_batches = 0.01 20 | cfg.trainer.limit_val_batches = 0.1 21 | cfg.trainer.limit_test_batches = 0.1 22 | cfg.trainer.accelerator = "cpu" 23 | cfg.trainer.devices = 1 24 | cfg.data.num_workers = 0 25 | cfg.data.pin_memory = False 26 | cfg.extras.print_config = False 27 | cfg.extras.enforce_tags = False 28 | cfg.logger = None 29 | 30 | return cfg 31 | 32 | 33 | @pytest.fixture(scope="package") 34 | def cfg_sample_global() -> DictConfig: 35 | with initialize(version_base="1.3", config_path="../configs"): 36 | cfg = compose( 37 | config_name="sample.yaml", return_hydra_config=True, overrides=["ckpt_path=."] 38 | ) 39 | 40 | # set defaults for all tests 41 | with open_dict(cfg): 42 | cfg.paths.root_dir = str(pyrootutils.find_root(indicator=".project-root")) 43 | cfg.trainer.max_epochs = 1 
44 | cfg.trainer.limit_test_batches = 0.1 45 | cfg.trainer.accelerator = "cpu" 46 | cfg.trainer.devices = 1 47 | cfg.data.num_workers = 0 48 | cfg.data.pin_memory = False 49 | cfg.extras.print_config = False 50 | cfg.extras.enforce_tags = False 51 | cfg.logger = None 52 | 53 | return cfg 54 | 55 | 56 | # this is called by each test which uses `cfg_train` arg 57 | # each test generates its own temporary logging path 58 | @pytest.fixture(scope="function") 59 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig: 60 | cfg = cfg_train_global.copy() 61 | 62 | with open_dict(cfg): 63 | cfg.paths.output_dir = str(tmp_path) 64 | cfg.paths.log_dir = str(tmp_path) 65 | 66 | yield cfg 67 | 68 | GlobalHydra.instance().clear() 69 | 70 | 71 | # this is called by each test which uses `cfg_sample` arg 72 | # each test generates its own temporary logging path 73 | @pytest.fixture(scope="function") 74 | def cfg_sample(cfg_sample_global, tmp_path) -> DictConfig: 75 | cfg = cfg_sample_global.copy() 76 | 77 | with open_dict(cfg): 78 | cfg.paths.output_dir = str(tmp_path) 79 | cfg.paths.log_dir = str(tmp_path) 80 | 81 | yield cfg 82 | 83 | GlobalHydra.instance().clear() 84 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from lightning.fabric.accelerators import TPUAccelerator 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment.""" 9 | try: 10 | return pkg_resources.require(package_name) is not None 11 | except pkg_resources.DistributionNotFound: 12 | return False 13 | 14 | 15 | _TPU_AVAILABLE = TPUAccelerator.is_available() 16 | 17 | _IS_WINDOWS = platform.system() == "Windows" 18 | 19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 20 | 21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 23 | 24 | _WANDB_AVAILABLE = _package_available("wandb") 25 | _NEPTUNE_AVAILABLE = _package_available("neptune") 26 | _COMET_AVAILABLE = _package_available("comet_ml") 27 | _MLFLOW_AVAILABLE = _package_available("mlflow") 28 | -------------------------------------------------------------------------------- /tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | 3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 4 | """ 5 | 6 | import sys 7 | 8 | import pytest 9 | import torch 10 | from beartype.typing import Optional 11 | from packaging.version import Version 12 | from pkg_resources import get_distribution 13 | 14 | from tests.helpers.package_available import ( 15 | _COMET_AVAILABLE, 16 | _DEEPSPEED_AVAILABLE, 17 | _FAIRSCALE_AVAILABLE, 18 | _IS_WINDOWS, 19 | _MLFLOW_AVAILABLE, 20 | _NEPTUNE_AVAILABLE, 21 | _SH_AVAILABLE, 22 | _TPU_AVAILABLE, 23 | _WANDB_AVAILABLE, 24 | ) 25 | 26 | 27 | class RunIf: 28 | """RunIf wrapper for conditional skipping of tests. 29 | 30 | Fully compatible with `@pytest.mark`. 
31 | 32 | Example: 33 | 34 | @RunIf(min_torch="1.8") 35 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 36 | def test_wrapper(arg1): 37 | assert arg1 > 0 38 | """ 39 | 40 | def __new__( 41 | self, 42 | min_gpus: int = 0, 43 | min_torch: Optional[str] = None, 44 | max_torch: Optional[str] = None, 45 | min_python: Optional[str] = None, 46 | skip_windows: bool = False, 47 | sh: bool = False, 48 | tpu: bool = False, 49 | fairscale: bool = False, 50 | deepspeed: bool = False, 51 | wandb: bool = False, 52 | neptune: bool = False, 53 | comet: bool = False, 54 | mlflow: bool = False, 55 | **kwargs, 56 | ): 57 | """ 58 | Args: 59 | min_gpus: min number of GPUs required to run test 60 | min_torch: minimum pytorch version to run test 61 | max_torch: maximum pytorch version to run test 62 | min_python: minimum python version required to run test 63 | skip_windows: skip test for Windows platform 64 | tpu: if TPU is available 65 | sh: if `sh` module is required to run the test 66 | fairscale: if `fairscale` module is required to run the test 67 | deepspeed: if `deepspeed` module is required to run the test 68 | wandb: if `wandb` module is required to run the test 69 | neptune: if `neptune` module is required to run the test 70 | comet: if `comet` module is required to run the test 71 | mlflow: if `mlflow` module is required to run the test 72 | kwargs: native pytest.mark.skipif keyword arguments 73 | """ 74 | conditions = [] 75 | reasons = [] 76 | 77 | if min_gpus: 78 | conditions.append(torch.cuda.device_count() < min_gpus) 79 | reasons.append(f"GPUs>={min_gpus}") 80 | 81 | if min_torch: 82 | torch_version = get_distribution("torch").version 83 | conditions.append(Version(torch_version) < Version(min_torch)) 84 | reasons.append(f"torch>={min_torch}") 85 | 86 | if max_torch: 87 | torch_version = get_distribution("torch").version 88 | conditions.append(Version(torch_version) >= Version(max_torch)) 89 | reasons.append(f"torch<{max_torch}") 90 | 91 | if min_python: 92 | py_version = ( 93 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 94 | ) 95 | conditions.append(Version(py_version) < Version(min_python)) 96 | reasons.append(f"python>={min_python}") 97 | 98 | if skip_windows: 99 | conditions.append(_IS_WINDOWS) 100 | reasons.append("does not run on Windows") 101 | 102 | if tpu: 103 | conditions.append(not _TPU_AVAILABLE) 104 | reasons.append("TPU") 105 | 106 | if sh: 107 | conditions.append(not _SH_AVAILABLE) 108 | reasons.append("sh") 109 | 110 | if fairscale: 111 | conditions.append(not _FAIRSCALE_AVAILABLE) 112 | reasons.append("fairscale") 113 | 114 | if deepspeed: 115 | conditions.append(not _DEEPSPEED_AVAILABLE) 116 | reasons.append("deepspeed") 117 | 118 | if wandb: 119 | conditions.append(not _WANDB_AVAILABLE) 120 | reasons.append("wandb") 121 | 122 | if neptune: 123 | conditions.append(not _NEPTUNE_AVAILABLE) 124 | reasons.append("neptune") 125 | 126 | if comet: 127 | conditions.append(not _COMET_AVAILABLE) 128 | reasons.append("comet") 129 | 130 | if mlflow: 131 | conditions.append(not _MLFLOW_AVAILABLE) 132 | reasons.append("mlflow") 133 | 134 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 135 | return pytest.mark.skipif( 136 | condition=any(conditions), 137 | reason=f"Requires: [{' + '.join(reasons)}]", 138 | **kwargs, 139 | ) 140 | -------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | import pytest 2 
| from beartype.typing import List 3 | 4 | from tests.helpers.package_available import _SH_AVAILABLE 5 | 6 | if _SH_AVAILABLE: 7 | import sh 8 | 9 | 10 | def run_sh_command(command: List[str]): 11 | """Default method for executing shell commands with pytest and sh package.""" 12 | msg = None 13 | try: 14 | sh.python(command) 15 | except sh.ErrorReturnCode as e: 16 | msg = e.stderr.decode() 17 | if msg: 18 | pytest.fail(msg=msg) 19 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.hydra_config import HydraConfig 3 | from omegaconf import DictConfig 4 | 5 | 6 | def test_train_config(cfg_train: DictConfig): 7 | assert cfg_train 8 | assert cfg_train.data 9 | assert cfg_train.model 10 | assert cfg_train.trainer 11 | 12 | HydraConfig().set_config(cfg_train) 13 | 14 | hydra.utils.instantiate(cfg_train.data) 15 | hydra.utils.instantiate(cfg_train.model) 16 | hydra.utils.instantiate(cfg_train.trainer) 17 | 18 | 19 | def cfg_sample_config(cfg_sample: DictConfig): 20 | assert cfg_sample 21 | assert cfg_sample.data 22 | assert cfg_sample.model 23 | assert cfg_sample.trainer 24 | 25 | HydraConfig().set_config(cfg_sample) 26 | 27 | hydra.utils.instantiate(cfg_sample.data) 28 | hydra.utils.instantiate(cfg_sample.model) 29 | hydra.utils.instantiate(cfg_sample.trainer) 30 | -------------------------------------------------------------------------------- /tests/test_datamodules.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from src.data.mnist_datamodule import MNISTDataModule 7 | 8 | 9 | @pytest.mark.parametrize("batch_size", [32, 128]) 10 | def test_mnist_datamodule(batch_size): 11 | data_dir = "data/" 12 | 13 | dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) 14 | dm.prepare_data() 15 | 16 | assert not dm.data_train and not dm.data_val and not dm.data_test 17 | assert Path(data_dir, "MNIST").exists() 18 | assert Path(data_dir, "MNIST", "raw").exists() 19 | 20 | dm.setup() 21 | assert dm.data_train and dm.data_val and dm.data_test 22 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() 23 | 24 | num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) 25 | assert num_datapoints == 70_000 26 | 27 | batch = next(iter(dm.train_dataloader())) 28 | x, y = batch 29 | assert len(x) == batch_size 30 | assert len(y) == batch_size 31 | assert x.dtype == torch.float32 32 | assert y.dtype == torch.int64 33 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.sample import sample 8 | from src.train import train 9 | 10 | 11 | @pytest.mark.slow 12 | def test_train_sample(tmp_path, cfg_train, cfg_sample): 13 | """Train for 1 epoch with `train.py` and run inference with `sample.py`""" 14 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_sample.paths.output_dir 15 | 16 | with open_dict(cfg_train): 17 | cfg_train.trainer.max_epochs = 1 18 | cfg_train.test = True 19 | 20 | HydraConfig().set_config(cfg_train) 21 | train_metric_dict, _ = train(cfg_train) 22 | 23 | assert "last.ckpt" in os.listdir(tmp_path / 
"checkpoints") 24 | 25 | with open_dict(cfg_sample): 26 | cfg_sample.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 27 | 28 | HydraConfig().set_config(cfg_sample) 29 | test_metric_dict, _ = sample(cfg_sample) 30 | 31 | assert test_metric_dict["test/acc"] > 0.0 32 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001 33 | -------------------------------------------------------------------------------- /tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_if import RunIf 4 | from tests.helpers.run_sh_command import run_sh_command 5 | 6 | startfile = "src/train.py" 7 | overrides = ["logger=[]"] 8 | 9 | 10 | @RunIf(sh=True) 11 | @pytest.mark.slow 12 | def test_experiments(tmp_path): 13 | """Test running all available experiment configs with fast_dev_run=True.""" 14 | command = [ 15 | startfile, 16 | "-m", 17 | "experiment=glob(*)", 18 | "hydra.sweep.dir=" + str(tmp_path), 19 | "++trainer.fast_dev_run=true", 20 | ] + overrides 21 | run_sh_command(command) 22 | 23 | 24 | @RunIf(sh=True) 25 | @pytest.mark.slow 26 | def test_hydra_sweep(tmp_path): 27 | """Test default hydra sweep.""" 28 | command = [ 29 | startfile, 30 | "-m", 31 | "hydra.sweep.dir=" + str(tmp_path), 32 | "model.optimizer.lr=0.005,0.01", 33 | "++trainer.fast_dev_run=true", 34 | ] + overrides 35 | 36 | run_sh_command(command) 37 | 38 | 39 | @RunIf(sh=True) 40 | @pytest.mark.slow 41 | def test_hydra_sweep_ddp_sim(tmp_path): 42 | """Test default hydra sweep with ddp sim.""" 43 | command = [ 44 | startfile, 45 | "-m", 46 | "hydra.sweep.dir=" + str(tmp_path), 47 | "trainer=ddp_sim", 48 | "trainer.max_epochs=3", 49 | "+trainer.limit_train_batches=0.01", 50 | "+trainer.limit_val_batches=0.1", 51 | "+trainer.limit_test_batches=0.1", 52 | "model.optimizer.lr=0.005,0.01,0.02", 53 | ] + overrides 54 | run_sh_command(command) 55 | 56 | 57 | @RunIf(sh=True) 58 | @pytest.mark.slow 59 | def test_optuna_sweep(tmp_path): 60 | """Test optuna sweep.""" 61 | command = [ 62 | startfile, 63 | "-m", 64 | "hparams_search=mnist_optuna", 65 | "hydra.sweep.dir=" + str(tmp_path), 66 | "hydra.sweeper.n_trials=10", 67 | "hydra.sweeper.sampler.n_startup_trials=5", 68 | "++trainer.fast_dev_run=true", 69 | ] + overrides 70 | run_sh_command(command) 71 | 72 | 73 | @RunIf(wandb=True, sh=True) 74 | @pytest.mark.slow 75 | def test_optuna_sweep_ddp_sim_wandb(tmp_path): 76 | """Test optuna sweep with wandb and ddp sim.""" 77 | command = [ 78 | startfile, 79 | "-m", 80 | "hparams_search=mnist_optuna", 81 | "hydra.sweep.dir=" + str(tmp_path), 82 | "hydra.sweeper.n_trials=5", 83 | "trainer=ddp_sim", 84 | "trainer.max_epochs=3", 85 | "+trainer.limit_train_batches=0.01", 86 | "+trainer.limit_val_batches=0.1", 87 | "+trainer.limit_test_batches=0.1", 88 | "logger=wandb", 89 | ] 90 | run_sh_command(command) 91 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.train import train 8 | from tests.helpers.run_if import RunIf 9 | 10 | 11 | def test_train_fast_dev_run(cfg_train): 12 | """Run for 1 train, val and test step.""" 13 | HydraConfig().set_config(cfg_train) 14 | with open_dict(cfg_train): 15 | cfg_train.trainer.fast_dev_run = True 16 | 
cfg_train.trainer.accelerator = "cpu" 17 | train(cfg_train) 18 | 19 | 20 | @RunIf(min_gpus=1) 21 | def test_train_fast_dev_run_gpu(cfg_train): 22 | """Run for 1 train, val and test step on GPU.""" 23 | HydraConfig().set_config(cfg_train) 24 | with open_dict(cfg_train): 25 | cfg_train.trainer.fast_dev_run = True 26 | cfg_train.trainer.accelerator = "gpu" 27 | train(cfg_train) 28 | 29 | 30 | @RunIf(min_gpus=1) 31 | @pytest.mark.slow 32 | def test_train_epoch_gpu_amp(cfg_train): 33 | """Train 1 epoch on GPU with mixed-precision.""" 34 | HydraConfig().set_config(cfg_train) 35 | with open_dict(cfg_train): 36 | cfg_train.trainer.max_epochs = 1 37 | cfg_train.trainer.accelerator = "cpu" 38 | cfg_train.trainer.precision = 16 39 | train(cfg_train) 40 | 41 | 42 | @pytest.mark.slow 43 | def test_train_epoch_double_val_loop(cfg_train): 44 | """Train 1 epoch with validation loop twice per epoch.""" 45 | HydraConfig().set_config(cfg_train) 46 | with open_dict(cfg_train): 47 | cfg_train.trainer.max_epochs = 1 48 | cfg_train.trainer.val_check_interval = 0.5 49 | train(cfg_train) 50 | 51 | 52 | @pytest.mark.slow 53 | def test_train_ddp_sim(cfg_train): 54 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.""" 55 | HydraConfig().set_config(cfg_train) 56 | with open_dict(cfg_train): 57 | cfg_train.trainer.max_epochs = 2 58 | cfg_train.trainer.accelerator = "cpu" 59 | cfg_train.trainer.devices = 2 60 | cfg_train.trainer.strategy = "ddp_spawn" 61 | train(cfg_train) 62 | 63 | 64 | @pytest.mark.slow 65 | def test_train_resume(tmp_path, cfg_train): 66 | """Run 1 epoch, finish, and resume for another epoch.""" 67 | with open_dict(cfg_train): 68 | cfg_train.trainer.max_epochs = 1 69 | 70 | HydraConfig().set_config(cfg_train) 71 | metric_dict_1, _ = train(cfg_train) 72 | 73 | files = os.listdir(tmp_path / "checkpoints") 74 | assert "last.ckpt" in files 75 | assert "epoch_000.ckpt" in files 76 | 77 | with open_dict(cfg_train): 78 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 79 | cfg_train.trainer.max_epochs = 2 80 | 81 | metric_dict_2, _ = train(cfg_train) 82 | 83 | files = os.listdir(tmp_path / "checkpoints") 84 | assert "epoch_001.ckpt" in files 85 | assert "epoch_002.ckpt" not in files 86 | 87 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] 88 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"] 89 | --------------------------------------------------------------------------------
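A minimal usage sketch of the custom OmegaConf resolvers registered in src/__init__.py (shown earlier in this listing). It is illustrative rather than part of the repository: it assumes the `src` package is importable (e.g., after `pip install -e .`), and `math.pi` is only an arbitrary dotted attribute path for the `resolve_variable` resolver.

# Minimal sketch, not a repository file: exercises the resolvers defined in src/__init__.py.
from omegaconf import OmegaConf

from src import register_custom_omegaconf_resolvers

register_custom_omegaconf_resolvers()

cfg = OmegaConf.create(
    {
        "a": 2,
        "b": "${add:${a},3}",  # custom "add" resolver -> 5
        "c": "${subtract:${a},1}",  # custom "subtract" resolver -> 1
        "pi": "${resolve_variable:math.pi}",  # dotted path resolved via importlib + getattr
    }
)
print(OmegaConf.to_container(cfg, resolve=True))
# {'a': 2, 'b': 5, 'c': 1, 'pi': 3.141592653589793}

Once `register_custom_omegaconf_resolvers()` has been called, the same `${add:...}`, `${subtract:...}`, and `${resolve_variable:...}` interpolations can be used inside the Hydra configs composed by the training and sampling entry points.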