├── .env.example
├── .gitignore
├── .pre-commit-config.yaml
├── .project-root
├── LICENSE
├── Makefile
├── README.md
├── checkpoints
└── .gitkeep
├── configs
├── __init__.py
├── callbacks
│ ├── default.yaml
│ ├── early_stopping.yaml
│ ├── ema.yaml
│ ├── model_checkpoint.yaml
│ ├── model_summary.yaml
│ ├── none.yaml
│ └── rich_progress_bar.yaml
├── data
│ └── pdb_na.yaml
├── debug
│ ├── default.yaml
│ ├── fdr.yaml
│ ├── limit.yaml
│ ├── overfit.yaml
│ └── profiler.yaml
├── experiment
│ └── pdb_prot_na_gen_se3.yaml
├── extras
│ └── default.yaml
├── hparams_search
│ └── pdb_prot_gen_se3_module_optuna.yaml
├── hydra
│ └── default.yaml
├── local
│ └── .gitkeep
├── logger
│ ├── aim.yaml
│ ├── comet.yaml
│ ├── csv.yaml
│ ├── many_loggers.yaml
│ ├── mlflow.yaml
│ ├── neptune.yaml
│ ├── tensorboard.yaml
│ └── wandb.yaml
├── model
│ ├── diffusion_cfg
│ │ └── pdb_prot_na_gen_se3_ddpm.yaml
│ ├── model_cfg
│ │ └── pdb_prot_na_gen_se3_model.yaml
│ └── pdb_prot_na_gen_se3_module.yaml
├── paths
│ ├── default.yaml
│ └── pdb_metadata.yaml
├── sample.yaml
├── train.yaml
└── trainer
│ ├── cpu.yaml
│ ├── ddp.yaml
│ ├── ddp_sim.yaml
│ ├── default.yaml
│ ├── gpu.yaml
│ └── mps.yaml
├── data
└── .gitkeep
├── environment.yaml
├── eval_results
├── analyze_eval_results.py
├── analyze_grouped_eval_results.py
└── collect_eval_results.py
├── forks
├── ProteinMPNN
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── ca_model_weights
│ │ ├── v_48_002.pt
│ │ ├── v_48_010.pt
│ │ └── v_48_020.pt
│ ├── colab_notebooks
│ │ ├── README.md
│ │ ├── ca_only_quickdemo.ipynb
│ │ ├── quickdemo.ipynb
│ │ └── quickdemo_wAF2.ipynb
│ ├── examples
│ │ ├── submit_example_1.sh
│ │ ├── submit_example_2.sh
│ │ ├── submit_example_3.sh
│ │ ├── submit_example_3_score_only.sh
│ │ ├── submit_example_3_score_only_from_fasta.sh
│ │ ├── submit_example_4.sh
│ │ ├── submit_example_4_non_fixed.sh
│ │ ├── submit_example_5.sh
│ │ ├── submit_example_6.sh
│ │ ├── submit_example_7.sh
│ │ ├── submit_example_8.sh
│ │ └── submit_example_pssm.sh
│ ├── helper_scripts
│ │ ├── assign_fixed_chains.py
│ │ ├── make_bias_AA.py
│ │ ├── make_bias_per_res_dict.py
│ │ ├── make_fixed_positions_dict.py
│ │ ├── make_pos_neg_tied_positions_dict.py
│ │ ├── make_pssm_input_dict.py
│ │ ├── make_tied_positions_dict.py
│ │ ├── other_tools
│ │ │ ├── make_omit_AA.py
│ │ │ └── make_pssm_dict.py
│ │ ├── parse_multiple_chains.out
│ │ ├── parse_multiple_chains.py
│ │ └── parse_multiple_chains.sh
│ ├── inputs
│ │ └── PSSM_inputs
│ │ │ ├── 3HTN.npz
│ │ │ └── 4YOW.npz
│ ├── outputs
│ │ └── training_test_output
│ │ │ └── seqs
│ │ │ └── 5L33.fa
│ ├── protein_mpnn_run.py
│ ├── protein_mpnn_utils.py
│ ├── soluble_model_weights
│ │ ├── v_48_002.pt
│ │ ├── v_48_010.pt
│ │ ├── v_48_020.pt
│ │ └── v_48_030.pt
│ ├── training
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── colab_training_example.ipynb
│ │ ├── exp_020
│ │ │ ├── log.txt
│ │ │ └── model_weights
│ │ │ │ └── epoch_last.pt
│ │ ├── model_utils.py
│ │ ├── parse_cif_noX.py
│ │ ├── plot_training_results.ipynb
│ │ ├── submit_exp_020.sh
│ │ ├── test_inference.sh
│ │ ├── training.py
│ │ └── utils.py
│ └── vanilla_model_weights
│ │ ├── v_48_002.pt
│ │ ├── v_48_010.pt
│ │ ├── v_48_020.pt
│ │ └── v_48_030.pt
└── RoseTTAFold2NA
│ ├── LICENSE
│ ├── README.md
│ ├── RF2na-linux.yml
│ ├── SE3Transformer
│ ├── Dockerfile
│ ├── LICENSE
│ ├── NOTICE
│ ├── README.md
│ ├── images
│ │ └── se3-transformer.png
│ ├── requirements.txt
│ ├── scripts
│ │ ├── benchmark_inference.sh
│ │ ├── benchmark_train.sh
│ │ ├── benchmark_train_multi_gpu.sh
│ │ ├── predict.sh
│ │ ├── train.sh
│ │ └── train_multi_gpu.sh
│ ├── se3_transformer
│ │ ├── __init__.py
│ │ ├── data_loading
│ │ │ ├── __init__.py
│ │ │ ├── data_module.py
│ │ │ └── qm9.py
│ │ ├── model
│ │ │ ├── __init__.py
│ │ │ ├── basis.py
│ │ │ ├── fiber.py
│ │ │ ├── layers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── convolution.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── norm.py
│ │ │ │ └── pooling.py
│ │ │ └── transformer.py
│ │ └── runtime
│ │ │ ├── __init__.py
│ │ │ ├── arguments.py
│ │ │ ├── callbacks.py
│ │ │ ├── gpu_affinity.py
│ │ │ ├── inference.py
│ │ │ ├── loggers.py
│ │ │ ├── metrics.py
│ │ │ ├── training.py
│ │ │ └── utils.py
│ ├── setup.py
│ └── tests
│ │ ├── __init__.py
│ │ ├── test_equivariance.py
│ │ └── utils.py
│ ├── example
│ ├── RNA.fa
│ ├── dna_binding_protein.fa
│ └── rna_binding_protein.fa
│ ├── input_prep
│ ├── make_pMSAs_prot_RNA.py
│ ├── make_protein_msa.sh
│ ├── make_rna_msa.sh
│ └── reprocess_rnac.pl
│ ├── network
│ ├── Attention_module.py
│ ├── AuxiliaryPredictor.py
│ ├── Embeddings.py
│ ├── RoseTTAFoldModel.py
│ ├── SE3_network.py
│ ├── Track_module.py
│ ├── arguments.py
│ ├── chemical.py
│ ├── coords6d.py
│ ├── data_loader.py
│ ├── ffindex.py
│ ├── kinematics.py
│ ├── loss.py
│ ├── models.json
│ ├── parsers.py
│ ├── predict.py
│ ├── resnet.py
│ ├── scheduler.py
│ ├── scoring.py
│ ├── util.py
│ └── util_module.py
│ ├── pdb100_2021Mar03
│ ├── pdb100_2021Mar03_pdb.ffdata
│ └── pdb100_2021Mar03_pdb.ffindex
│ └── run_RF2NA.sh
├── img
└── Nucleic_Acid_Diffusion.gif
├── logs
└── .gitkeep
├── metadata
└── PDB_NA_Dataset.csv
├── notebooks
├── .gitkeep
├── analyze_dataset_characteristics.ipynb
├── analyze_dataset_characteristics.py
├── creating_protein_na_datasets_from_the_pdb.ipynb
└── creating_protein_na_datasets_from_the_pdb.py
├── pyproject.toml
├── scripts
└── schedule.sh
├── setup.py
├── src
├── __init__.py
├── data
│ ├── __init__.py
│ ├── components
│ │ ├── __init__.py
│ │ └── pdb
│ │ │ ├── all_atom.py
│ │ │ ├── chemical.py
│ │ │ ├── complex.py
│ │ │ ├── complex_constants.py
│ │ │ ├── data_transforms.py
│ │ │ ├── data_utils.py
│ │ │ ├── errors.py
│ │ │ ├── join_pdb_metadata.py
│ │ │ ├── mmcif_parsing.py
│ │ │ ├── nucleotide_constants.py
│ │ │ ├── parsers.py
│ │ │ ├── parsing.py
│ │ │ ├── pdb_na_dataset.py
│ │ │ ├── process_pdb_mmcif_files.py
│ │ │ ├── process_pdb_na_files.py
│ │ │ ├── protein.py
│ │ │ ├── protein_constants.py
│ │ │ ├── relax
│ │ │ ├── amber_minimize.py
│ │ │ ├── cleanup.py
│ │ │ └── relax_utils.py
│ │ │ ├── rigid_utils.py
│ │ │ ├── so3_utils.py
│ │ │ ├── stereo_chemical_props.txt
│ │ │ ├── validate_na_frames.py
│ │ │ ├── validate_protein_frames.py
│ │ │ └── vocabulary.py
│ └── pdb_na_datamodule.py
├── models
│ ├── __init__.py
│ ├── components
│ │ ├── __init__.py
│ │ └── pdb
│ │ │ ├── analysis_utils.py
│ │ │ ├── embedders.py
│ │ │ ├── framediff.py
│ │ │ ├── loss.py
│ │ │ ├── metrics.py
│ │ │ ├── se3_diffusion.py
│ │ │ └── sequence_diffusion.py
│ └── pdb_prot_na_gen_se3_module.py
├── sample.py
├── train.py
└── utils
│ ├── __init__.py
│ ├── instantiators.py
│ ├── logging_utils.py
│ ├── pylogger.py
│ ├── rich_utils.py
│ └── utils.py
└── tests
├── __init__.py
├── conftest.py
├── helpers
├── __init__.py
├── package_available.py
├── run_if.py
└── run_sh_command.py
├── test_configs.py
├── test_datamodules.py
├── test_eval.py
├── test_sweeps.py
└── test_train.py
/.env.example:
--------------------------------------------------------------------------------
1 | # example file for storing private and user-specific environment variables, like keys or system paths
2 | # rename it to ".env" (excluded from version control by default)
3 | # .env is loaded by train.py automatically
4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
5 |
6 | MY_VAR="/home/user/my/system/path"
7 |
--------------------------------------------------------------------------------
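For reference, the `${oc.env:MY_VAR}` syntax mentioned above is handled by OmegaConf's built-in `oc.env` resolver, which Hydra uses under the hood. A minimal sketch of the round trip, assuming `python-dotenv` and `omegaconf` are available and using the illustrative `MY_VAR` value from the file above:

from dotenv import load_dotenv
from omegaconf import OmegaConf

load_dotenv()  # read `.env` from the current working directory into os.environ

# any config value written as ${oc.env:MY_VAR} resolves to the environment variable
cfg = OmegaConf.create({"my_path": "${oc.env:MY_VAR}"})
print(OmegaConf.to_container(cfg, resolve=True))  # {'my_path': '/home/user/my/system/path'}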
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | ### VisualStudioCode
131 | .vscode/*
132 | !.vscode/settings.json
133 | !.vscode/tasks.json
134 | !.vscode/launch.json
135 | !.vscode/extensions.json
136 | *.code-workspace
137 | **/.vscode
138 |
139 | # JetBrains
140 | .idea/
141 |
142 | # Data & Models
143 | *.h5
144 | *.tar
145 | *.tar.gz
146 |
147 | # MMDiff
148 | configs/local/default.yaml
149 | .env
150 | *.pdb
151 | *.csv
152 | eval_results*/
153 | data/
154 | generations/
155 | *_outputs*/
156 | logs/
157 | scripts/
158 | checkpoints/*
159 | !checkpoints/.gitkeep
160 | !metadata/
161 |
162 | # RoseTTAFold2NA
163 | forks/RoseTTAFold2NA/RF2NA/
164 | forks/RoseTTAFold2NA/network/RF2NA_apr23.tgz
165 | forks/RoseTTAFold2NA/network/weights/
166 | forks/RoseTTAFold2NA/UniRef30_2020_06*
167 | forks/RoseTTAFold2NA/bfd*
168 | forks/RoseTTAFold2NA/pdb100*
169 | forks/RoseTTAFold2NA/RNA*
170 | forks/RoseTTAFold2NA/**/*.stderr
171 | forks/RoseTTAFold2NA/**/*.stdout
172 | forks/RoseTTAFold2NA/**/*.a3m
173 | forks/RoseTTAFold2NA/**/*.*tab
174 | forks/RoseTTAFold2NA/**/*.hhr
175 | forks/RoseTTAFold2NA/**/*.*fa
176 | forks/RoseTTAFold2NA/**/*.npz
177 | forks/RoseTTAFold2NA/**/*db*
178 |
179 | # Aim logging
180 | .aim
181 |
182 | # MLFlow logging
183 | mlruns/
184 |
185 | # Conda/Mamba
186 | MMDiff*/
187 |
188 | # Hydra
189 | .hydra
190 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v4.4.0
7 | hooks:
8 | # list of supported hooks: https://pre-commit.com/hooks.html
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-docstring-first
12 | - id: check-yaml
13 | - id: debug-statements
14 | - id: detect-private-key
15 | - id: check-executables-have-shebangs
16 | - id: check-toml
17 | - id: check-case-conflict
18 | - id: check-added-large-files
19 | args: ["--maxkb=200000"]
20 |
21 | # python code formatting
22 | - repo: https://github.com/psf/black
23 | rev: 23.3.0
24 | hooks:
25 | - id: black
26 | args: [--line-length, "99"]
27 |
28 | # python import sorting
29 | - repo: https://github.com/PyCQA/isort
30 | rev: 5.12.0
31 | hooks:
32 | - id: isort
33 | args: ["--profile", "black", "--filter-files"]
34 |
35 | # python upgrading syntax to newer version
36 | - repo: https://github.com/asottile/pyupgrade
37 | rev: v3.4.0
38 | hooks:
39 | - id: pyupgrade
40 | args: [--py38-plus]
41 |
42 | # python docstring formatting
43 | - repo: https://github.com/myint/docformatter
44 | rev: v1.7.0
45 | hooks:
46 | - id: docformatter
47 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
48 |
49 | # python check (PEP8), programming errors and code complexity
50 | - repo: https://github.com/PyCQA/flake8
51 | rev: 6.0.0
52 | hooks:
53 | - id: flake8
54 | args:
55 | [
56 | "--extend-ignore",
57 | "E203,E402,E501,F401,F841",
58 | "--exclude",
59 | "logs/*,data/*",
60 | ]
61 |
62 | # python security linter
63 | - repo: https://github.com/PyCQA/bandit
64 | rev: "1.7.5"
65 | hooks:
66 | - id: bandit
67 | args: ["-s", "B101"]
68 |
69 | # yaml formatting
70 | - repo: https://github.com/pre-commit/mirrors-prettier
71 | rev: v3.0.0-alpha.9-for-vscode
72 | hooks:
73 | - id: prettier
74 | types: [yaml]
75 | exclude: "environment.yaml"
76 |
77 | # shell scripts linter
78 | - repo: https://github.com/shellcheck-py/shellcheck-py
79 | rev: v0.9.0.2
80 | hooks:
81 | - id: shellcheck
82 |
83 | # md formatting
84 | - repo: https://github.com/executablebooks/mdformat
85 | rev: 0.7.16
86 | hooks:
87 | - id: mdformat
88 | args: ["--number"]
89 | additional_dependencies:
90 | - mdformat-gfm
91 | - mdformat-tables
92 | - mdformat_frontmatter
93 | # - mdformat-toc
94 | # - mdformat-black
95 |
96 | # word spelling linter
97 | - repo: https://github.com/codespell-project/codespell
98 | rev: v2.2.4
99 | hooks:
100 | - id: codespell
101 | args:
102 | - --skip=logs/**,data/**,*.ipynb,*.csv
103 | - --ignore-words-list=ser,coo,nd
104 |
105 | # jupyter notebook cell output clearing
106 | - repo: https://github.com/kynan/nbstripout
107 | rev: 0.6.1
108 | hooks:
109 | - id: nbstripout
110 |
111 | # jupyter notebook linting
112 | - repo: https://github.com/nbQA-dev/nbQA
113 | rev: 1.7.0
114 | hooks:
115 | - id: nbqa-black
116 | args: ["--line-length=99"]
117 | - id: nbqa-isort
118 | args: ["--profile=black"]
119 | - id: nbqa-flake8
120 | args:
121 | [
122 | "--extend-ignore=E203,E402,E501,F401,F841",
123 | "--exclude=logs/*,data/*",
124 | ]
125 |
126 | exclude: "forks/|checkpoints/"
127 |
--------------------------------------------------------------------------------
/.project-root:
--------------------------------------------------------------------------------
1 | # this file is required for inferring the project root directory
2 | # do not delete
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Profluent Bio
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | help: ## Show help
3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
4 |
5 | clean: ## Clean autogenerated files
6 | rm -rf dist
7 | find . -type f -name "*.DS_Store" -ls -delete
8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
9 | find . | grep -E ".pytest_cache" | xargs rm -rf
10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
11 | rm -f .coverage
12 |
13 | clean-logs: ## Clean logs
14 | rm -rf logs/**
15 |
16 | format: ## Run pre-commit hooks
17 | pre-commit run -a
18 |
19 | sync: ## Merge changes from main branch to your current branch
20 | git pull
21 | git pull origin main
22 |
23 | test: ## Run not slow tests
24 | pytest -k "not slow"
25 |
26 | test-full: ## Run all tests
27 | pytest
28 |
29 | train: ## Train the model
30 | python src/train.py
31 |
--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # this file is needed here to include configs when building project as a package
2 |
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model_checkpoint.yaml
3 | - model_summary.yaml
4 | - rich_progress_bar.yaml
5 | - _self_
6 |
7 | model_checkpoint:
8 | dirpath: ${paths.output_dir}/checkpoints
9 | filename: "epoch_{epoch:03d}"
10 | monitor: "val/na_num_c4_prime_steric_clashes"
11 | mode: "min"
12 | save_last: True
13 | auto_insert_metric_name: False
14 |
15 | model_summary:
16 | max_depth: -1
17 |
--------------------------------------------------------------------------------
/configs/callbacks/early_stopping.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
2 |
3 | early_stopping:
4 | _target_: lightning.pytorch.callbacks.EarlyStopping
5 | monitor: ??? # quantity to be monitored, must be specified !!!
6 | min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
7 | patience: 3 # number of checks with no improvement after which training will be stopped
8 | verbose: False # verbosity mode
9 | mode: "min" # "max" means higher metric value is better, can be also "min"
10 | strict: True # whether to crash the training if monitor is not found in the validation metrics
11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version
16 |
--------------------------------------------------------------------------------
/configs/callbacks/ema.yaml:
--------------------------------------------------------------------------------
1 | # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
2 |
3 | # Maintains an exponential moving average (EMA) of model weights.
4 | # Look at the above link for more detailed information regarding the original implementation.
5 | ema:
6 | _target_: src.models.EMA
7 | decay: 0.9999
8 | apply_ema_every_n_steps: 1
9 | start_step: 0
10 | save_ema_weights_in_callback_state: true
11 | evaluate_ema_weights_instead: true
12 |
--------------------------------------------------------------------------------
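The callback above keeps a shadow copy of the model whose weights are an exponential moving average of the training weights. A generic sketch of the update rule it applies every `apply_ema_every_n_steps` optimizer steps; this illustrates the math only, not the NeMo-derived `src.models.EMA` implementation referenced above:

import copy

import torch

def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.9999) -> None:
    """In-place EMA update: ema <- decay * ema + (1 - decay) * online."""
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

model = torch.nn.Linear(8, 8)
ema_model = copy.deepcopy(model)  # evaluated instead of `model` when `evaluate_ema_weights_instead` is true
# ... after every optimizer step (start_step: 0, apply_ema_every_n_steps: 1):
ema_update(ema_model, model, decay=0.9999)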
/configs/callbacks/model_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
2 |
3 | model_checkpoint:
4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
5 | dirpath: null # directory to save the model file
6 | filename: null # checkpoint filename
7 | monitor: null # name of the logged metric which determines when model is improving
8 | verbose: False # verbosity mode
9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
10 | save_top_k: 1 # save k best models (determined by above metric)
11 | mode: "min" # "max" means higher metric value is better, can be also "min"
12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
13 | save_weights_only: False # if True, then only the model’s weights will be saved
14 | every_n_train_steps: null # number of training steps between checkpoints
15 | train_time_interval: null # checkpoints are monitored at the specified time interval
16 | every_n_epochs: null # number of epochs between checkpoints
17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
18 |
--------------------------------------------------------------------------------
/configs/callbacks/model_summary.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2 |
3 | model_summary:
4 | _target_: lightning.pytorch.callbacks.RichModelSummary
5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include
6 |
--------------------------------------------------------------------------------
/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/configs/callbacks/none.yaml
--------------------------------------------------------------------------------
/configs/callbacks/rich_progress_bar.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2 |
3 | rich_progress_bar:
4 | _target_: lightning.pytorch.callbacks.RichProgressBar
5 |
--------------------------------------------------------------------------------
/configs/data/pdb_na.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data.pdb_na_datamodule.PDBNADataModule
2 | data_cfg:
3 | # CSV for path and metadata to training examples
4 | csv_path: ${paths.data_dir}/PDB-NA/processed/metadata.csv
5 | annot_path: ${paths.root_dir}/metadata/PDB_NA_Dataset.csv
6 | cluster_path: ${paths.data_dir}/PDB-NA/processed/clusters-by-entity-30.txt
7 | cluster_examples_by_structure: true # Note: Corresponds to using qTMclust structure-based clustering to select training examples
8 | qtmclust_exec_path: ${paths.qtmclust_exec_path}
9 | filtering:
10 | max_len: 768
11 | min_len: 10
12 | # Select a subset of examples, which could be useful for debugging
13 | subset: null
14 | mmcif_allowed_oligomer: [monomeric] # Note: Corresponds to filtering complexes originating from mmCIF files
15 | pdb_allowed_oligomer: null # Note: Corresponds to filtering complexes originating from PDB files
16 | max_helix_percent: 1.0
17 | max_loop_percent: 0.5
18 | min_beta_percent: -1.0
19 | rog_quantile: 0.96
20 | # Specify which types of molecules to keep in the dataset
21 | allowed_molecule_types: [protein, na] # Note: Value must be in `[[protein, na], [protein], [na], null]`
22 | # As a cross-validation holdout, remove examples containing proteins belonging to a subset of Pfam protein families
23 | holdout: null
24 | min_t: ${model.diffusion_cfg.min_timestep}
25 | samples_per_eval_length: 4
26 | num_eval_lengths: 10
27 | num_t: ${model.diffusion_cfg.num_timesteps}
28 | max_squared_res: 600000
29 | batch_size: 256
30 | eval_batch_size: ${.samples_per_eval_length}
31 | sample_mode: time_batch # Note: Must be in [`time_batch`, `length_batch`, `cluster_time_batch`, `cluster_length_batch`]
32 | num_workers: 5
33 | prefetch_factor: 100
34 | # Sequence diffusion arguments
35 | diffuse_sequence: ${model.diffusion_cfg.diffuse_sequence}
36 | num_sequence_t: ${model.diffusion_cfg.num_sequence_timesteps}
37 | sequence_noise_schedule: ${model.diffusion_cfg.sequence_noise_schedule} # note: must be a value in [`linear`, `cosine`, `sqrt`]
38 | sequence_sample_distribution: ${model.diffusion_cfg.sequence_sample_distribution} # note: if value is not `normal`, then, instead of a Gaussian distribution, a Gaussian mixture model (GMM) is used as the sequence noising function
39 | sequence_sample_distribution_gmm_means: ${model.diffusion_cfg.sequence_sample_distribution_gmm_means}
40 | sequence_sample_distribution_gmm_variances: ${model.diffusion_cfg.sequence_sample_distribution_gmm_variances}
41 |
--------------------------------------------------------------------------------
/configs/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup, runs 1 full epoch
4 | # other debugging configs can inherit from this one
5 |
6 | # overwrite task name so debugging logs are stored in separate folder
7 | task_name: "debug"
8 |
9 | # disable callbacks and loggers during debugging
10 | callbacks: null
11 | logger: null
12 |
13 | extras:
14 | ignore_warnings: False
15 | enforce_tags: False
16 |
17 | # sets level of all command line loggers to 'DEBUG'
18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
19 | hydra:
20 | job_logging:
21 | root:
22 | level: DEBUG
23 |
24 | # use this to also set hydra loggers to 'DEBUG'
25 | # verbose: True
26 |
27 | trainer:
28 | max_epochs: 1
29 | accelerator: cpu # debuggers don't like gpus
30 | devices: 1 # debuggers don't like multiprocessing
31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
32 |
33 | data:
34 | num_workers: 0 # debuggers don't like multiprocessing
35 | pin_memory: False # disable gpu memory pin
36 |
--------------------------------------------------------------------------------
/configs/debug/fdr.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs 1 train, 1 validation and 1 test step
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | fast_dev_run: true
10 |
--------------------------------------------------------------------------------
/configs/debug/limit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # uses only 1% of the training data and 5% of validation/test data
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 3
10 | limit_train_batches: 0.01
11 | limit_val_batches: 0.05
12 | limit_test_batches: 0.05
13 |
--------------------------------------------------------------------------------
/configs/debug/overfit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # overfits to 3 batches
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 20
10 | overfit_batches: 3
11 |
12 | # model ckpt and early stopping need to be disabled during overfitting
13 | callbacks: null
14 |
--------------------------------------------------------------------------------
/configs/debug/profiler.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs with execution time profiling
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 1
10 | profiler: "simple"
11 | # profiler: "advanced"
12 | # profiler: "pytorch"
13 |
--------------------------------------------------------------------------------
/configs/experiment/pdb_prot_na_gen_se3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=pdb_prot_na_gen_se3
5 |
6 | defaults:
7 | - override /data: pdb_na.yaml
8 | - override /model: pdb_prot_na_gen_se3_module.yaml
9 | - override /callbacks: default.yaml
10 | - override /trainer: default.yaml
11 |
12 | # all parameters below will be merged with parameters from default configurations set above
13 | # this allows you to overwrite only specified parameters
14 |
15 | tags: ["pdb", "prot_na_gen", "se3"]
16 |
17 | seed: 12345
18 |
19 | trainer:
20 | min_epochs: 100
21 | max_epochs: 1000
22 |
23 | callbacks:
24 | model_checkpoint:
25 | monitor: "val/na_num_c4_prime_steric_clashes"
26 | mode: "min"
27 |
28 | model:
29 | optimizer:
30 | lr: 0.0001
31 |
32 | data:
33 | data_cfg:
34 | batch_size: 256
35 |
36 | logger:
37 | # mlflow:
38 | # _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
39 | # experiment_name: MMDiff
40 | # run_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}_PDBProtGenSE3
41 | # # tags: ${tags}
42 | # # save_dir: "./mlruns"
43 | # # log_model: false
44 | # prefix: ""
45 | # artifact_location: null
46 | # # run_id: ""
47 | wandb:
48 | _target_: lightning.pytorch.loggers.wandb.WandbLogger
49 | # name: "" # name of the run (normally generated by wandb)
50 | save_dir: "${paths.output_dir}"
51 | offline: False
52 | id: null # pass correct id to resume experiment!
53 | anonymous: null # enable anonymous logging
54 | project: "MMDiff"
55 | log_model: False # upload lightning ckpts
56 | prefix: "" # a string to put at the beginning of metric keys
57 | # entity: "" # set to name of your wandb team
58 | group: ""
59 | tags: []
60 | job_type: ""
61 |
--------------------------------------------------------------------------------
/configs/extras/default.yaml:
--------------------------------------------------------------------------------
1 | # disable python warnings if they annoy you
2 | ignore_warnings: False
3 |
4 | # ask user for tags if none are provided in the config
5 | enforce_tags: True
6 |
7 | # pretty print config tree at the start of the run using Rich library
8 | print_config: True
9 |
--------------------------------------------------------------------------------
/configs/hparams_search/pdb_prot_gen_se3_module_optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python train.py -m hparams_search=mnist_optuna experiment=example
5 |
6 | defaults:
7 | - override /hydra/sweeper: optuna
8 |
9 | # choose metric which will be optimized by Optuna
10 | # make sure this is the correct name of some metric logged in lightning module!
11 | optimized_metric: "val/na_num_c4_prime_steric_clashes"
12 |
13 | # here we define Optuna hyperparameter search
14 | # it optimizes for value returned from function with @hydra.main decorator
15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16 | hydra:
17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
18 |
19 | sweeper:
20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
21 |
22 | # storage URL to persist optimization results
23 | # for example, you can use SQLite if you set 'sqlite:///example.db'
24 | storage: null
25 |
26 | # name of the study to persist optimization results
27 | study_name: null
28 |
29 | # number of parallel workers
30 | n_jobs: 1
31 |
32 | # 'minimize' or 'maximize' the objective
33 | direction: minimize
34 |
35 | # total number of runs that will be executed
36 | n_trials: 20
37 |
38 | # choose Optuna hyperparameter sampler
39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
41 | sampler:
42 | _target_: optuna.samplers.TPESampler
43 | seed: 1234
44 | n_startup_trials: 10 # number of random sampling runs before optimization starts
45 |
46 | # define hyperparameter search space
47 | params:
48 | model.optimizer.lr: interval(0.0001, 0.1)
49 | data.batch_size: choice(8, 16, 32, 64)
50 |
--------------------------------------------------------------------------------
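Conceptually, the sweeper configuration above runs Optuna's TPE sampler over the two listed parameters and minimizes the monitored metric across 20 trials. A standalone sketch of that loop; `train_and_eval` is a toy placeholder standing in for a full Hydra training run:

import optuna

def train_and_eval(lr: float, batch_size: int) -> float:
    # toy placeholder objective; a real run would train the model and return
    # the logged `val/na_num_c4_prime_steric_clashes` value
    return (lr - 0.001) ** 2 + abs(batch_size - 32) / 1000.0

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("model.optimizer.lr", 1e-4, 1e-1)
    batch_size = trial.suggest_categorical("data.batch_size", [8, 16, 32, 64])
    return train_and_eval(lr=lr, batch_size=batch_size)

study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=1234, n_startup_trials=10),
)
study.optimize(objective, n_trials=20)
print(study.best_params)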
/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
11 | sweep:
12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
13 | subdir: ${hydra.job.num}
14 |
--------------------------------------------------------------------------------
/configs/local/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/configs/local/.gitkeep
--------------------------------------------------------------------------------
/configs/logger/aim.yaml:
--------------------------------------------------------------------------------
1 | # https://aimstack.io/
2 |
3 | # example usage in lightning module:
4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
5 |
6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder):
7 | # `aim up`
8 |
9 | aim:
10 | _target_: aim.pytorch_lightning.AimLogger
11 | repo: ${paths.root_dir} # .aim folder will be created here
12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
13 |
14 |   # aim allows grouping runs under an experiment name

15 | experiment: null # any string, set to "default" if not specified
16 |
17 | train_metric_prefix: "train/"
18 | val_metric_prefix: "val/"
19 | test_metric_prefix: "test/"
20 |
21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
22 | system_tracking_interval: 10 # set to null to disable system metrics tracking
23 |
24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc.
25 | log_system_params: true
26 |
27 | # enable/disable tracking console logs (default value is true)
28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
29 |
--------------------------------------------------------------------------------
/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
1 | # https://www.comet.ml
2 |
3 | comet:
4 | _target_: lightning.pytorch.loggers.comet.CometLogger
5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6 | save_dir: "${paths.output_dir}"
7 | project_name: "MMDiff"
8 | rest_api_key: null
9 | # experiment_name: ""
10 | experiment_key: null # set to resume experiment
11 | offline: False
12 | prefix: ""
13 |
--------------------------------------------------------------------------------
/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # csv logger built in lightning
2 |
3 | csv:
4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
5 | save_dir: "${paths.output_dir}"
6 | name: "csv/"
7 | prefix: ""
8 |
--------------------------------------------------------------------------------
/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 | # - comet.yaml
5 | - csv.yaml
6 | # - mlflow.yaml
7 | # - neptune.yaml
8 | - tensorboard.yaml
9 | - wandb.yaml
10 |
--------------------------------------------------------------------------------
/configs/logger/mlflow.yaml:
--------------------------------------------------------------------------------
1 | # https://mlflow.org
2 |
3 | mlflow:
4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
5 | experiment_name: MMDiff
6 | # run_name: ""
7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
8 | tags: null
9 | # save_dir: "./mlruns"
10 | # log_model: false
11 | prefix: ""
12 | artifact_location: null
13 | # run_id: ""
14 |
--------------------------------------------------------------------------------
/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
1 | # https://neptune.ai
2 |
3 | neptune:
4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6 | project: username/MMDiff
7 | # name: ""
8 | log_model_checkpoints: True
9 | prefix: ""
10 |
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # https://www.tensorflow.org/tensorboard/
2 |
3 | tensorboard:
4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
5 | save_dir: "${paths.output_dir}/tensorboard/"
6 | name: null
7 | log_graph: False
8 | default_hp_metric: True
9 | prefix: ""
10 | # version: ""
11 |
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 |
3 | wandb:
4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger
5 | # name: "" # name of the run (normally generated by wandb)
6 | save_dir: "${paths.output_dir}"
7 | offline: False
8 | id: null # pass correct id to resume experiment!
9 | anonymous: null # enable anonymous logging
10 | project: "MMDiff"
11 | log_model: False # upload lightning ckpts
12 | prefix: "" # a string to put at the beginning of metric keys
13 | # entity: "" # set to name of your wandb team
14 | group: ""
15 | tags: []
16 | job_type: ""
17 |
--------------------------------------------------------------------------------
/configs/model/diffusion_cfg/pdb_prot_na_gen_se3_ddpm.yaml:
--------------------------------------------------------------------------------
1 | # general diffusion arguments
2 | ddpm_mode: se3_unconditional # [se3_unconditional]
3 | diffusion_network: se3_score # [se3_score]
4 | dynamics_network: framediff # [framediff]
5 | eval_epochs: 10
6 |
7 | # SE(3) diffusion arguments
8 | diffuse_rotations: true
9 | diffuse_translations: true
10 | min_timestep: 0.01
11 | num_timesteps: 100
12 |
13 | # R(3) diffusion arguments
14 | r3:
15 | min_b: 0.1
16 | max_b: 20.0
17 | coordinate_scaling: 0.1
18 |
19 | # SO(3) diffusion arguments
20 | so3:
21 | num_omega: 1000
22 | num_sigma: 1000
23 | min_sigma: 0.1
24 | max_sigma: 1.5
25 | schedule: logarithmic
26 | cache_dir: .cache/
27 | use_cached_score: false
28 |
29 | # sequence diffusion arguments
30 | diffuse_sequence: true
31 | num_sequence_timesteps: ${.num_timesteps}
32 | sequence_noise_schedule: sqrt # note: must be a value in [`linear`, `cosine`, `sqrt`]
33 | sequence_sample_distribution: normal # note: if value is not `normal`, then, instead of a Gaussian distribution, a Gaussian mixture model (GMM) is used as the sequence noising function
34 | sequence_sample_distribution_gmm_means: [-1.0, 1.0]
35 | sequence_sample_distribution_gmm_variances: [1.0, 1.0]
36 |
37 | # sampling arguments
38 | sampling:
39 | sequence_noise_scale: 1.0
40 | structure_noise_scale: 1.0
41 | apply_na_consensus_sampling: false
42 |
--------------------------------------------------------------------------------
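The `sequence_noise_schedule` options above (`linear`, `cosine`, `sqrt`) follow naming that is standard in the diffusion literature. As a rough sketch, and assuming the usual formulations (DDPM-style linear betas, the Nichol & Dhariwal cosine schedule, the Diffusion-LM sqrt schedule) rather than the repo's exact parameterization, the cumulative signal level alpha_bar could look like this:

import numpy as np

def alpha_bar_schedule(num_steps: int, schedule: str = "sqrt", s: float = 1e-4) -> np.ndarray:
    """Cumulative signal level alpha_bar(t) for t = 1/num_steps, ..., 1 (generic sketch)."""
    t = np.linspace(1.0 / num_steps, 1.0, num_steps)
    if schedule == "linear":
        betas = np.linspace(1e-4, 0.02, num_steps)  # DDPM-style linearly spaced betas
        return np.cumprod(1.0 - betas)
    if schedule == "cosine":
        return np.cos(((t + s) / (1.0 + s)) * np.pi / 2.0) ** 2
    if schedule == "sqrt":
        return 1.0 - np.sqrt(t + s)  # Diffusion-LM's sqrt schedule
    raise ValueError(f"unknown schedule: {schedule}")

# e.g., the 100-step training setting configured above
print(alpha_bar_schedule(100, "sqrt")[:5])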
/configs/model/model_cfg/pdb_prot_na_gen_se3_model.yaml:
--------------------------------------------------------------------------------
1 | node_input_dim: 1
2 | edge_input_dim: 2
3 | existing_edge_embedding_dim: 0
4 |
5 | node_hidden_dim: 256
6 | edge_hidden_dim: 128
7 | num_layers: 5
8 | dropout: 0.0
9 |
10 | c_skip: 64
11 | num_angles: ${subtract:${resolve_variable:src.data.components.pdb.complex_constants.NUM_PROT_NA_TORSIONS},2} # note: must be `1` for protein design tasks and `8` for either nucleic acid or protein-nucleic acid design tasks
12 |
13 | clip_gradients: false
14 | log_grad_flow_steps: 3000 # after how many steps to log gradient flow
15 |
16 | embedding:
17 | index_embedding_size: 128
18 | max_relative_idx: 32
19 | max_relative_chain: 2
20 | use_chain_relative: true
21 | molecule_type_embedding_size: 4
22 | molecule_type_embedded_size: 128
23 | embed_molecule_type_conditioning: True
24 | embed_self_conditioning: True
25 | num_bins: 22
26 | min_bin: 1e-5
27 | max_bin: 20.0
28 |
29 | ipa:
30 | c_s: ${..node_hidden_dim}
31 | c_z: ${..edge_hidden_dim}
32 | c_hidden: 256
33 | c_resnet: 128
34 | num_resnet_blocks: 2
35 | num_heads: 8
36 | num_qk_points: 8
37 | num_v_points: 12
38 | coordinate_scaling: ${...diffusion_cfg.r3.coordinate_scaling}
39 | inf: 1e5
40 | epsilon: 1e-8
41 |
42 | tfmr:
43 | num_heads: 4
44 | num_layers: 2
45 |
46 | loss_cfg:
47 | trans_loss_weight: 1.0
48 | rot_loss_weight: 0.5
49 | rot_loss_t_threshold: 0.2
50 | separate_rot_loss: False
51 | trans_x0_threshold: 1.0
52 | bb_atom_loss_weight: 1.0
53 | bb_atom_loss_t_filter: 0.25
54 | dist_mat_loss_weight: 1.0
55 | dist_mat_loss_t_filter: 0.25
56 | interface_dist_mat_loss_weight: 1.0 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_interfaces}` is `true`
57 | interface_dist_mat_loss_t_filter: 0.25 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_interfaces}` is `true`
58 | aux_loss_weight: 0.25
59 | torsion_loss_weight: 1.0 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_torsion_angles}` is `true`
60 | torsion_loss_t_filter: 0.25 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_torsion_angles}` is `true`
61 | torsion_norm_loss_weight: 0.02 # note: will only be used if `${model.model_cfg.loss_cfg.supervise_torsion_angles}` is `true`
62 | cce_seq_loss_weight: 1.0 # note: will only be used if `${model.diffusion_cfg.diffuse_sequence}` is `true`
63 | kl_seq_loss_weight: 0.25 # note: will only be used if `${model.diffusion_cfg.diffuse_sequence}` is `true`
64 | supervise_n1_atom_positions: true # note: if `true`, for pyrimidine residues will instead supervise predicted N9 atom positions using ground-truth N1 atom positions
65 | supervise_interfaces: true # note: if `true`, will supervise the model's predicted pairwise distances specifically for interfacing residues
66 | supervise_torsion_angles: false # note: if `false`, will supervise the model's predicted torsion angles indirectly using predicted coordinates
67 | coordinate_scaling: ${...diffusion_cfg.r3.coordinate_scaling}
68 |
--------------------------------------------------------------------------------
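The `num_angles` entry above relies on two custom OmegaConf resolvers, `${subtract:...}` and `${resolve_variable:...}`, which are not built into Hydra and must be registered by the codebase at import time. A sketch of how such resolvers are typically registered; this is an illustrative assumption, not the repo's actual registration code:

import importlib

from omegaconf import OmegaConf

def resolve_variable(dotted_path: str):
    """Import e.g. `src.data.components.pdb.complex_constants.NUM_PROT_NA_TORSIONS` by name."""
    module_path, attr_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)

OmegaConf.register_new_resolver("subtract", lambda a, b: a - b)
OmegaConf.register_new_resolver("resolve_variable", resolve_variable)

# mirroring the note above (num_angles = NUM_PROT_NA_TORSIONS - 2 = 8 for protein-nucleic acid tasks)
cfg = OmegaConf.create({"num_angles": "${subtract:10,2}"})
print(OmegaConf.to_container(cfg, resolve=True))  # {'num_angles': 8}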
/configs/model/pdb_prot_na_gen_se3_module.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.pdb_prot_na_gen_se3_module.PDBProtNAGenSE3LitModule
2 |
3 | optimizer:
4 | _target_: torch.optim.Adam
5 | _partial_: true
6 | lr: 0.0001
7 |
8 | scheduler:
9 | # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
10 | # _partial_: true
11 | # mode: min
12 | # factor: 0.1
13 | # patience: 10
14 |
15 | defaults:
16 | - model_cfg: pdb_prot_na_gen_se3_model.yaml
17 | - diffusion_cfg: pdb_prot_na_gen_se3_ddpm.yaml
18 |
--------------------------------------------------------------------------------
/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | # path to root directory
2 | # this requires PROJECT_ROOT environment variable to exist
3 | # you can replace it with "." if you want the root to be the current working directory
4 | root_dir: ${oc.env:PROJECT_ROOT}
5 |
6 | # path to data directory
7 | data_dir: ${paths.root_dir}/data/
8 |
9 | # path to logging directory
10 | log_dir: ${paths.root_dir}/logs/
11 |
12 | # path to output directory, created dynamically by hydra
13 | # path generation pattern is specified in `configs/hydra/default.yaml`
14 | # use it to store all files generated during the run, like ckpts and metrics
15 | output_dir: ${hydra:runtime.output_dir}
16 |
17 | # path to working directory
18 | work_dir: ${hydra:runtime.cwd}
19 |
20 | # paths to local executables
21 | rf2na_exec_path: ${paths.root_dir}/forks/RoseTTAFold2NA/run_RF2NA.sh
22 | proteinmpnn_dir: ${paths.root_dir}/forks/ProteinMPNN/
23 | usalign_exec_path: ~/Programs/USalign/USalign # note: must be an absolute path during runtime
24 | qtmclust_exec_path: ~/Programs/USalign/qTMclust # note: must be an absolute path during runtime
25 |
--------------------------------------------------------------------------------
/configs/paths/pdb_metadata.yaml:
--------------------------------------------------------------------------------
1 | na_metadata_csv_path: ${paths.data_dir}/PDB-NA/processed_pdb/na_metadata.csv
2 | protein_metadata_csv_path: ${paths.data_dir}/PDB-NA/processed_pdb/protein_metadata.csv
3 | metadata_output_csv_path: ${paths.data_dir}/PDB-NA/processed_pdb/metadata.csv
4 |
--------------------------------------------------------------------------------
/configs/sample.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - data: pdb_na.yaml # choose datamodule with `predict_dataloader()` for inference
6 | - model: pdb_prot_na_gen_se3_module.yaml
7 | - logger: null
8 | - trainer: default.yaml
9 | - paths: default.yaml
10 | - extras: default.yaml
11 | - hydra: default.yaml
12 |
13 | task_name: "sample"
14 |
15 | tags: ["dev"]
16 |
17 | # passing checkpoint path is necessary for sampling
18 | ckpt_path: checkpoints/protein_na_sequence_structure_g42jpyug_rotations_epoch_286.ckpt
19 |
20 | # establishing inference arguments
21 | inference:
22 | name: protein_na_sequence_structure_se3_discrete_diffusion_sampling_${now:%Y-%m-%d}_${now:%H:%M:%S}
23 | seed: 123
24 | run_statified_eval: false # note: if `true`, will instead use the `validation` dataset's range of examples for sampling evaluation; if `false`, (large) multi-state PDB trajectories will be recorded
25 | filter_eval_split: false # note: if `true`, will use `samples.min/max_length` and `samples.min/max_num_chains` to filter out examples from the evaluation dataset
26 | # whether to compute evaluation metrics for generated samples e.g., using RoseTTAFold2NA
27 | # warning: `self_consistency` and `novelty` will be time and memory-intensive metrics to compute
28 | run_self_consistency_eval: false
29 | run_diversity_eval: false
30 | run_novelty_eval: false
31 |   use_rf2na_single_sequence_mode: true # note: trades some prediction accuracy for a faster runtime
32 | generate_protein_sequences_using_pmpnn: false # note: should only be `true` if generating protein-only samples and instead using ProteinMPNN for backbone sequence design
33 | measure_auxiliary_na_metrics: false # note: should only be `true` if generating nucleic acid-only samples
34 |
35 | # output directory for samples
36 | output_dir: ./inference_outputs/
37 |
38 | diffusion:
39 | # number of diffusion steps for sampling
40 | num_t: 500
41 | # note: analogous to sampling temperature
42 | sequence_noise_scale: 1.0
43 | structure_noise_scale: 0.1
44 | # final diffusion step `t` for sampling
45 | min_t: 0.01
46 | # whether to apply a 50% majority rule that transforms all generated nucleotide residue types to be exclusively of DNA or RNA types
47 | apply_na_consensus_sampling: true
48 | # whether to employ a random diffusion baseline for sequence-structure generation
49 | employ_random_baseline: false
50 | # note: the following are for overriding sequence diffusion arguments
51 | diffuse_sequence: true
52 | num_sequence_timesteps: ${.num_t}
53 | sequence_noise_schedule: sqrt # note: must be a value in [`linear`, `cosine`, `sqrt`], where `linear` typically yields the best looking structures; `cosine` often yields the most designable structures; and `sqrt` is what is used during training
54 | sequence_sample_distribution: normal # note: if value is not `normal`, then, instead of a Gaussian distribution, a Gaussian mixture model (GMM) is used as the sequence noising function
55 | sequence_sample_distribution_gmm_means: [-1.0, 1.0]
56 | sequence_sample_distribution_gmm_variances: [1.0, 1.0]
57 |
58 | samples:
59 | # number of backbone structures and sequences to sample and score per sequence length
60 | samples_per_length: 30
61 | # minimum sequence length to sample
62 | min_length: 10
63 | # maximum sequence length to sample
64 | max_length: 50
65 | # note: `num_length_steps` will only be used if `inference.run_statified_eval` is `true`
66 | num_length_steps: 10
67 | # note: `min_num_chains` will only be used if `inference.run_statified_eval` is `true`
68 | min_num_chains: 1
69 | # note: `max_num_chains` will only be used if `inference.run_statified_eval` is `true`
70 | max_num_chains: 4
71 | # gap between lengths to sample
72 | # (note: this script will sample all lengths
73 | # in range(min_length, max_length, length_step))
74 | length_step: 10
75 | # a syntactic specification mapping the standardized molecule types
76 | # (`A`: amino acid, `D`: deoxyribonucleic acid, `R`: ribonucleic acid) to each residue index;
77 |   # note: each string's residue index annotations must sum to the
78 | # current length specified by the interval [`min_length`, `max_length`, `length_step`]:
79 | # e.g., `residue_molecule_type_mappings: ['R:100', 'A:75,D:75']`
80 | residue_molecule_type_mappings:
81 | ["R:10", "D:20", "R:10,D:20", "R:20,R:20", "A:40,R:10"]
82 | # residue_molecule_type_mappings: ["R:90", "R:40,A:60"] # e.g., for protein-nucleic acid generation
83 | # a syntactic specification mapping to each residue index
84 | # one of the PDB's 62 alphanumeric chain identifiers:
85 | # (i.e., one of `ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789`);
86 |   # note: each string's residue index annotations must sum to the
87 | # current length specified by the interval [`min_length`, `max_length`, `length_step`]:
88 | # e.g., `residue_chain_mappings: ['a:50,b:40', 'a:50,b:50']`
89 | residue_chain_mappings:
90 | ["a:10", "a:20", "a:10,b:20", "a:20,b:20", "a:40,b:10"]
91 |
--------------------------------------------------------------------------------
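To make the `residue_molecule_type_mappings` and `residue_chain_mappings` syntax above concrete, here is a small illustrative parser (a hypothetical helper, not code from this repo) that expands a mapping string into one label per residue; as the notes above require, the expanded length must equal the sampled sequence length:

from typing import List

def expand_mapping(mapping: str) -> List[str]:
    """Expand e.g. 'A:40,R:10' into a per-residue label list of length 50."""
    labels: List[str] = []
    for segment in mapping.split(","):
        label, count = segment.split(":")
        labels.extend([label] * int(count))
    return labels

molecule_types = expand_mapping("A:40,R:10")  # 40 protein residues followed by 10 RNA residues
chain_ids = expand_mapping("a:40,b:10")       # chain 'a' covers the protein, chain 'b' the RNA
assert len(molecule_types) == len(chain_ids) == 50  # must equal the sampled length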
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default configuration
4 | # order of defaults determines the order in which configs override each other
5 | defaults:
6 | - _self_
7 | - data: pdb_na.yaml
8 | - model: pdb_prot_na_gen_se3_module.yaml
9 | - callbacks: default.yaml
10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11 | - trainer: default.yaml
12 | - paths: default.yaml
13 | - extras: default.yaml
14 | - hydra: default.yaml
15 |
16 | # experiment configs allow for version control of specific hyperparameters
17 | # e.g. best hyperparameters for given model and datamodule
18 | - experiment: null
19 |
20 | # config for hyperparameter optimization
21 | - hparams_search: null
22 |
23 | # optional local config for machine/user specific settings
24 | # it's optional since it doesn't need to exist and is excluded from version control
25 | - optional local: default.yaml
26 |
27 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
28 | - debug: null
29 |
30 | # task name, determines output directory path
31 | task_name: "train"
32 |
33 | # tags to help you identify your experiments
34 | # you can overwrite this in experiment configs
35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
36 | tags: ["dev"]
37 |
38 | # set False to skip model training
39 | train: True
40 |
41 | # evaluate on test set, using best model weights achieved during training
42 | # lightning chooses best weights based on the metric specified in checkpoint callback
43 | test: True
44 |
45 | # compile model for faster training with pytorch 2.0
46 | compile: False
47 |
48 | # simply provide checkpoint path to resume training
49 | ckpt_path: null
50 |
51 | # seed for random number generators in pytorch, numpy and python.random
52 | seed: null
53 |
--------------------------------------------------------------------------------
/configs/trainer/cpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: cpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | strategy: ddp
5 |
6 | accelerator: gpu
7 | devices: 4
8 | num_nodes: 1
9 | sync_batchnorm: True
10 |
--------------------------------------------------------------------------------
/configs/trainer/ddp_sim.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # simulate DDP on CPU, useful for debugging
5 | accelerator: cpu
6 | devices: 2
7 | strategy: ddp_spawn
8 |
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: lightning.pytorch.trainer.Trainer
2 |
3 | default_root_dir: ${paths.output_dir}
4 |
5 | min_epochs: 1 # prevents early stopping
6 | max_epochs: 10
7 |
8 | accelerator: gpu
9 | devices: 1
10 |
11 | # mixed precision for extra speed-up
12 | # precision: 16
13 |
14 | # number of sanity-check validation forward passes to run prior to model training
15 | num_sanity_val_steps: 0
16 |
17 | # perform a validation loop every N training epochs
18 | check_val_every_n_epoch: null
19 | val_check_interval: 10000
20 | # note: when `check_val_every_n_epoch` is `null`,
21 | # Lightning will require you to provide an integer
22 | # value for `val_check_interval` to instead perform a
23 | # validation epoch every `val_check_interval` steps
24 |
25 | # gradient accumulation to simulate larger-than-GPU-memory batch sizes
26 | accumulate_grad_batches: 1
27 |
28 | # set True to to ensure deterministic results
29 | # makes training slower but gives more reproducibility than just setting seeds
30 | deterministic: False
31 |
32 | # track and log the vector norm of each gradient
33 | # track_grad_norm: 2.0
34 |
35 | # profile code comprehensively
36 | profiler:
37 | # _target_: pytorch_lightning.profilers.PyTorchProfiler
38 |
39 | # inform Lightning that we will be supplying a custom `Sampler` in our `DataModule` class
40 | use_distributed_sampler: False
41 |
--------------------------------------------------------------------------------
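Roughly, the trainer config above corresponds to the Lightning `Trainer` call sketched below; in particular, with `check_val_every_n_epoch=None`, Lightning treats the integer `val_check_interval` as a number of training batches between validation loops. This is a sketch of the resulting object, not the instantiation code Hydra generates:

from lightning.pytorch import Trainer

trainer = Trainer(
    default_root_dir="outputs/",    # ${paths.output_dir} at runtime
    min_epochs=1,
    max_epochs=10,
    accelerator="gpu",
    devices=1,
    num_sanity_val_steps=0,
    check_val_every_n_epoch=None,
    val_check_interval=10000,       # validate every 10,000 training batches
    accumulate_grad_batches=1,
    deterministic=False,
    use_distributed_sampler=False,  # the DataModule supplies its own Sampler
)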
/configs/trainer/gpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: gpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/configs/trainer/mps.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: mps
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/data/.gitkeep
--------------------------------------------------------------------------------
/eval_results/collect_eval_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 |
5 |
6 | def main(eval_outputs_dir: str, eval_results_dir: str):
7 | for item in os.listdir(eval_outputs_dir):
8 | eval_outputs_subdir = os.path.join(eval_outputs_dir, item)
9 | if os.path.isdir(eval_outputs_subdir):
10 | all_rank_predictions_csv_path = os.path.join(
11 | eval_outputs_subdir, "all_rank_predictions.csv"
12 | )
13 | if os.path.exists(all_rank_predictions_csv_path):
14 | eval_results_run_name = Path(eval_outputs_subdir).stem
15 | eval_results_run_type = "_".join(
16 | [
17 | s.replace("naive", "rb")
18 | for s in eval_results_run_name.split("_se3_discrete_diffusion_stratified")[
19 | 0
20 | ].split("_")
21 | ]
22 | )
23 | eval_results_run_category = "_".join(
24 | eval_results_run_name.split("_se3_discrete_diffusion_stratified")[1]
25 | .split("eval")[0]
26 | .split("_")
27 | )
28 | eval_results_csv_name = os.path.join(
29 | eval_results_dir,
30 | f"{eval_results_run_type}{eval_results_run_category}all_rank_predictions.csv",
31 | )
32 | shutil.copyfile(all_rank_predictions_csv_path, eval_results_csv_name)
33 |
34 |
35 | if __name__ == "__main__":
36 | main(eval_outputs_dir="inference_eval_outputs/", eval_results_dir="eval_results/")
37 |
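As an illustration of the name mangling above, using a hypothetical run-directory name (real names come from `inference_eval_outputs/`): the run type is everything before the `_se3_discrete_diffusion_stratified` marker with `naive` rewritten to `rb`, and the run category is everything between that marker and `eval`.

```python
from pathlib import Path

# Hypothetical run-directory name, used only to trace the string slicing above.
run_name = Path(
    "inference_eval_outputs/naive_protein_se3_discrete_diffusion_stratified_eval"
).stem

run_type = "_".join(
    s.replace("naive", "rb")
    for s in run_name.split("_se3_discrete_diffusion_stratified")[0].split("_")
)
run_category = "_".join(
    run_name.split("_se3_discrete_diffusion_stratified")[1].split("eval")[0].split("_")
)

# -> "rb_protein_all_rank_predictions.csv" for this made-up example
print(f"{run_type}{run_category}all_rank_predictions.csv")
```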
--------------------------------------------------------------------------------
/forks/ProteinMPNN/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Justas Dauparas
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/ca_model_weights/v_48_002.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/ca_model_weights/v_48_002.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/ca_model_weights/v_48_010.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/ca_model_weights/v_48_010.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/ca_model_weights/v_48_020.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/ca_model_weights/v_48_020.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/colab_notebooks/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 2
6 | #SBATCH --output=example_1.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_monomers/pdbs/"
11 |
12 | output_dir="../outputs/example_1_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
19 |
20 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
21 |
22 | python ../protein_mpnn_run.py \
23 | --jsonl_path $path_for_parsed_chains \
24 | --out_folder $output_dir \
25 | --num_seq_per_target 2 \
26 | --sampling_temp "0.1" \
27 | --seed 37 \
28 | --batch_size 1
29 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 2
6 | #SBATCH --output=example_2.out
7 |
8 | source activate mlfold
9 |
10 |
11 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/"
12 |
13 | output_dir="../outputs/example_2_outputs"
14 | if [ ! -d $output_dir ]
15 | then
16 | mkdir -p $output_dir
17 | fi
18 |
19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl"
21 | chains_to_design="A B"
22 |
23 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
24 |
25 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design"
26 |
27 | python ../protein_mpnn_run.py \
28 | --jsonl_path $path_for_parsed_chains \
29 | --chain_id_jsonl $path_for_assigned_chains \
30 | --out_folder $output_dir \
31 | --num_seq_per_target 2 \
32 | --sampling_temp "0.1" \
33 | --seed 37 \
34 | --batch_size 1
35 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_3.out
7 |
8 | source activate mlfold
9 |
10 | path_to_PDB="../inputs/PDB_complexes/pdbs/3HTN.pdb"
11 |
12 | output_dir="../outputs/example_3_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 | chains_to_design="A B"
19 |
20 | python ../protein_mpnn_run.py \
21 | --pdb_path $path_to_PDB \
22 | --pdb_path_chains "$chains_to_design" \
23 | --out_folder $output_dir \
24 | --num_seq_per_target 2 \
25 | --sampling_temp "0.1" \
26 | --seed 37 \
27 | --batch_size 1
28 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_3_score_only.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_3.out
7 |
8 | source activate mlfold
9 |
10 | path_to_PDB="../inputs/PDB_complexes/pdbs/3HTN.pdb"
11 |
12 | output_dir="../outputs/example_3_score_only_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 | chains_to_design="A B"
19 |
20 | python ../protein_mpnn_run.py \
21 | --pdb_path $path_to_PDB \
22 | --pdb_path_chains "$chains_to_design" \
23 | --out_folder $output_dir \
24 | --num_seq_per_target 10 \
25 | --sampling_temp "0.1" \
26 | --score_only 1 \
27 | --seed 37 \
28 | --batch_size 1
29 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_3_score_only_from_fasta.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_3_from_fasta.out
7 |
8 | source activate mlfold
9 |
10 | path_to_PDB="../inputs/PDB_complexes/pdbs/3HTN.pdb"
11 | path_to_fasta="/home/justas/projects/github/ProteinMPNN/outputs/example_3_outputs/seqs/3HTN.fa"
12 |
13 | output_dir="../outputs/example_3_score_only_from_fasta_outputs"
14 | if [ ! -d $output_dir ]
15 | then
16 | mkdir -p $output_dir
17 | fi
18 |
19 | chains_to_design="A B"
20 |
21 | python ../protein_mpnn_run.py \
22 | --path_to_fasta $path_to_fasta \
23 | --pdb_path $path_to_PDB \
24 | --pdb_path_chains "$chains_to_design" \
25 | --out_folder $output_dir \
26 | --num_seq_per_target 5 \
27 | --sampling_temp "0.1" \
28 | --score_only 1 \
29 | --seed 13 \
30 | --batch_size 1
31 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_4.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/"
11 |
12 | output_dir="../outputs/example_4_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 |
19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl"
21 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl"
22 | chains_to_design="A C"
23 | #For now, positions are 1-indexed from the first amino acid in each chain, not by PDB residue numbering.
24 | fixed_positions="1 2 3 4 5 6 7 8 23 25, 10 11 12 13 14 15 16 17 18 19 20 40" #fixing/not designing residues 1 2 3...25 in chain A and residues 10 11 12...40 in chain C
25 |
26 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
27 |
28 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design"
29 |
30 | python ../helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$fixed_positions"
31 |
32 | python ../protein_mpnn_run.py \
33 | --jsonl_path $path_for_parsed_chains \
34 | --chain_id_jsonl $path_for_assigned_chains \
35 | --fixed_positions_jsonl $path_for_fixed_positions \
36 | --out_folder $output_dir \
37 | --num_seq_per_target 2 \
38 | --sampling_temp "0.1" \
39 | --seed 37 \
40 | --batch_size 1
41 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_4_non_fixed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_4_non_fixed.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/"
11 |
12 | output_dir="../outputs/example_4_non_fixed_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 |
19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl"
21 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl"
22 | chains_to_design="A C"
23 | #For now, positions are 1-indexed from the first amino acid in each chain, not by PDB residue numbering.
24 | design_only_positions="1 2 3 4 5 6 7 8 9 10, 3 4 5 6 7 8" #design only these residues; use flag --specify_non_fixed
25 |
26 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
27 |
28 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design"
29 |
30 | python ../helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$design_only_positions" --specify_non_fixed
31 |
32 | python ../protein_mpnn_run.py \
33 | --jsonl_path $path_for_parsed_chains \
34 | --chain_id_jsonl $path_for_assigned_chains \
35 | --fixed_positions_jsonl $path_for_fixed_positions \
36 | --out_folder $output_dir \
37 | --num_seq_per_target 2 \
38 | --sampling_temp "0.1" \
39 | --seed 37 \
40 | --batch_size 1
41 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_5.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_5.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/"
11 |
12 | output_dir="../outputs/example_5_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 |
19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
20 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl"
21 | path_for_fixed_positions=$output_dir"/fixed_pdbs.jsonl"
22 | path_for_tied_positions=$output_dir"/tied_pdbs.jsonl"
23 | chains_to_design="A C"
24 | fixed_positions="9 10 11 12 13 14 15 16 17 18 19 20 21 22 23, 10 11 18 19 20 22"
25 | tied_positions="1 2 3 4 5 6 7 8, 1 2 3 4 5 6 7 8" #the two lists must match in length; residue 1 in chain A and residue 1 in chain C will be sampled together
26 |
27 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
28 |
29 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design"
30 |
31 | python ../helper_scripts/make_fixed_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_fixed_positions --chain_list "$chains_to_design" --position_list "$fixed_positions"
32 |
33 | python ../helper_scripts/make_tied_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_tied_positions --chain_list "$chains_to_design" --position_list "$tied_positions"
34 |
35 | python ../protein_mpnn_run.py \
36 | --jsonl_path $path_for_parsed_chains \
37 | --chain_id_jsonl $path_for_assigned_chains \
38 | --fixed_positions_jsonl $path_for_fixed_positions \
39 | --tied_positions_jsonl $path_for_tied_positions \
40 | --out_folder $output_dir \
41 | --num_seq_per_target 2 \
42 | --sampling_temp "0.1" \
43 | --seed 37 \
44 | --batch_size 1
45 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_6.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_6.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_homooligomers/pdbs/"
11 |
12 | output_dir="../outputs/example_6_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 |
19 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
20 | path_for_tied_positions=$output_dir"/tied_pdbs.jsonl"
21 | path_for_designed_sequences=$output_dir"/temp_0.1"
22 |
23 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
24 |
25 | python ../helper_scripts/make_tied_positions_dict.py --input_path=$path_for_parsed_chains --output_path=$path_for_tied_positions --homooligomer 1
26 |
27 | python ../protein_mpnn_run.py \
28 | --jsonl_path $path_for_parsed_chains \
29 | --tied_positions_jsonl $path_for_tied_positions \
30 | --out_folder $output_dir \
31 | --num_seq_per_target 2 \
32 | --sampling_temp "0.2" \
33 | --seed 37 \
34 | --batch_size 1
35 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_7.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 2
6 | #SBATCH --output=example_7.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_monomers/pdbs/"
11 |
12 | output_dir="../outputs/example_7_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
19 |
20 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
21 |
22 | python ../protein_mpnn_run.py \
23 | --jsonl_path $path_for_parsed_chains \
24 | --out_folder $output_dir \
25 | --num_seq_per_target 1 \
26 | --sampling_temp "0.1" \
27 | --unconditional_probs_only 1 \
28 | --seed 37 \
29 | --batch_size 1
30 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 2
6 | #SBATCH --output=example_8.out
7 |
8 | source activate mlfold
9 |
10 | folder_with_pdbs="../inputs/PDB_monomers/pdbs/"
11 |
12 | output_dir="../outputs/example_8_outputs"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 | path_for_bias=$output_dir"/bias_pdbs.jsonl"
19 | #Adding global polar amino acid bias (Doug Tischer)
20 | AA_list="D E H K N Q R S T W Y"
21 | bias_list="1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39 1.39"
22 | python ../helper_scripts/make_bias_AA.py --output_path=$path_for_bias --AA_list="$AA_list" --bias_list="$bias_list"
23 |
24 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
25 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
26 |
27 | python ../protein_mpnn_run.py \
28 | --jsonl_path $path_for_parsed_chains \
29 | --out_folder $output_dir \
30 | --bias_AA_jsonl $path_for_bias \
31 | --num_seq_per_target 2 \
32 | --sampling_temp "0.1" \
33 | --seed 37 \
34 | --batch_size 1
35 |
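The bias values above are 1.39 ≈ ln(4). As a generic, repo-independent illustration (ProteinMPNN's sampler may additionally scale the bias, e.g. by temperature), adding ln(4) to a subset of logits multiplies the unnormalized probability of those letters by roughly four before renormalization:

```python
import numpy as np

# Generic sketch: effect of adding ln(4) to the logits of the polar amino acids
# in the 21-letter alphabet 'ACDEFGHIKLMNPQRSTVWYX'.
def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=21)                          # one position, 21 letters
polar_idx = [2, 3, 6, 8, 11, 13, 14, 15, 16, 18, 19]  # D E H K N Q R S T W Y

biased = logits.copy()
biased[polar_idx] += 1.39                             # ~= ln(4)

print("polar mass before:", softmax(logits)[polar_idx].sum())
print("polar mass after: ", softmax(biased)[polar_idx].sum())
```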
--------------------------------------------------------------------------------
/forks/ProteinMPNN/examples/submit_example_pssm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 2
6 | #SBATCH --output=example_2.out
7 |
8 | source activate mlfold
9 |
10 |
11 | #new_probabilities_using_PSSM = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
12 | #probs - predictions from MPNN
13 | #pssm_bias_gathered - input PSSM bias (needs to be a probability distribution)
14 | #pssm_multi - a number between 0.0 (no bias) and 1.0 (PSSM only, no MPNN), input via the flag --pssm_multi; a global weight applied equally to all residues
15 | #pssm_coef_gathered - a per-residue number between 0.0 (no bias) and 1.0 (PSSM only, no MPNN), built via ../helper_scripts/make_pssm_input_dict.py; lets you apply the PSSM bias only to specific residues or chains (a standalone numpy sketch of this blend follows this script)
16 |
17 |
18 |
19 | pssm_input_path="../inputs/PSSM_inputs"
20 | folder_with_pdbs="../inputs/PDB_complexes/pdbs/"
21 |
22 | output_dir="../outputs/example_pssm_outputs"
23 | if [ ! -d $output_dir ]
24 | then
25 | mkdir -p $output_dir
26 | fi
27 |
28 | path_for_parsed_chains=$output_dir"/parsed_pdbs.jsonl"
29 | path_for_assigned_chains=$output_dir"/assigned_pdbs.jsonl"
30 | pssm=$output_dir"/pssm.jsonl"
31 | chains_to_design="A B"
32 |
33 | python ../helper_scripts/parse_multiple_chains.py --input_path=$folder_with_pdbs --output_path=$path_for_parsed_chains
34 |
35 | python ../helper_scripts/assign_fixed_chains.py --input_path=$path_for_parsed_chains --output_path=$path_for_assigned_chains --chain_list "$chains_to_design"
36 |
37 | python ../helper_scripts/make_pssm_input_dict.py --jsonl_input_path=$path_for_parsed_chains --PSSM_input_path=$pssm_input_path --output_path=$pssm
38 |
39 | python ../protein_mpnn_run.py \
40 | --jsonl_path $path_for_parsed_chains \
41 | --chain_id_jsonl $path_for_assigned_chains \
42 | --out_folder $output_dir \
43 | --num_seq_per_target 2 \
44 | --sampling_temp "0.1" \
45 | --seed 37 \
46 | --batch_size 1 \
47 | --pssm_jsonl $pssm \
48 | --pssm_multi 0.3 \
49 | --pssm_bias_flag 1
50 |
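The blending rule spelled out in the comments at the top of this script can be written as a standalone numpy sketch (shapes and names follow those comments; this is not the repo's sampling code):

```python
import numpy as np

# Standalone sketch of the PSSM blending rule described in the comments above.
L, A = 5, 21                                   # residues, alphabet size
rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(A), size=L)      # model predictions; rows sum to 1
pssm_bias = rng.dirichlet(np.ones(A), size=L)  # PSSM-derived distribution; rows sum to 1
pssm_coef = np.ones(L)                         # per-residue trust in the PSSM
pssm_multi = 0.3                               # global mixing weight (--pssm_multi)

w = pssm_multi * pssm_coef[:, None]
new_probs = (1.0 - w) * probs + w * pssm_bias  # still a valid distribution per residue
assert np.allclose(new_probs.sum(-1), 1.0)
```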
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/assign_fixed_chains.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def main(args):
4 | import json
5 |
6 | with open(args.input_path, 'r') as json_file:
7 | json_list = list(json_file)
8 |
9 | global_designed_chain_list = []
10 | if args.chain_list != '':
11 | global_designed_chain_list = [str(item) for item in args.chain_list.split()]
12 | my_dict = {}
13 | for json_str in json_list:
14 | result = json.loads(json_str)
15 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain'] #['A','B', 'C',...]
16 | if len(global_designed_chain_list) > 0:
17 | designed_chain_list = global_designed_chain_list
18 | else:
19 | #manually specify, e.g.
20 | designed_chain_list = ["A"]
21 | fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list] #fix/do not redesign these chains
22 | my_dict[result['name']]= (designed_chain_list, fixed_chain_list)
23 |
24 | with open(args.output_path, 'w') as f:
25 | f.write(json.dumps(my_dict) + '\n')
26 |
27 |
28 | if __name__ == "__main__":
29 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs")
31 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary")
32 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be designed")
33 |
34 | args = argparser.parse_args()
35 | main(args)
36 |
37 | # Output looks like this:
38 | # {"5TTA": [["A"], ["B"]], "3LIS": [["A"], ["B"]]}
39 |
40 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/make_bias_AA.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def main(args):
4 |
5 | import numpy as np
6 | import json
7 |
8 | bias_list = [float(item) for item in args.bias_list.split()]
9 | AA_list = [str(item) for item in args.AA_list.split()]
10 |
11 | my_dict = dict(zip(AA_list, bias_list))
12 |
13 | with open(args.output_path, 'w') as f:
14 | f.write(json.dumps(my_dict) + '\n')
15 |
16 |
17 | if __name__ == "__main__":
18 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary")
20 | argparser.add_argument("--AA_list", type=str, default='', help="List of AAs to be biased")
21 | argparser.add_argument("--bias_list", type=str, default='', help="AA bias strengths")
22 |
23 | args = argparser.parse_args()
24 | main(args)
25 |
26 | #e.g. output
27 | #{"A": -0.01, "G": 0.02}
28 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/make_bias_per_res_dict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def main(args):
4 | import glob
5 | import random
6 | import numpy as np
7 | import json
8 |
9 | mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
10 |
11 | mpnn_alphabet_dict = {'A': 0,'C': 1,'D': 2,'E': 3,'F': 4,'G': 5,'H': 6,'I': 7,'K': 8,'L': 9,'M': 10,'N': 11,'P': 12,'Q': 13,'R': 14,'S': 15,'T': 16,'V': 17,'W': 18,'Y': 19,'X': 20}
12 |
13 | with open(args.input_path, 'r') as json_file:
14 | json_list = list(json_file)
15 |
16 | my_dict = {}
17 | for json_str in json_list:
18 | result = json.loads(json_str)
19 | all_chain_list = [item[-1:] for item in list(result) if item[:10]=='seq_chain_']
20 | bias_by_res_dict = {}
21 | for chain in all_chain_list:
22 | chain_length = len(result[f'seq_chain_{chain}'])
23 | bias_per_residue = np.zeros([chain_length, 21])
24 |
25 |
26 | if chain == 'A':
27 | residues = [0, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15]
28 | amino_acids = [5, 9] #[G, L]
29 | for res in residues:
30 | for aa in amino_acids:
31 | bias_per_residue[res, aa] = 100.5
32 |
33 | if chain == 'C':
34 | residues = [0, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15]
35 |                 amino_acids = range(21)[1:] #all alphabet indices except 0, i.e., every residue type except A
36 | for res in residues:
37 | for aa in amino_acids:
38 | bias_per_residue[res, aa] = -100.5
39 |
40 | bias_by_res_dict[chain] = bias_per_residue.tolist()
41 | my_dict[result['name']] = bias_by_res_dict
42 |
43 | with open(args.output_path, 'w') as f:
44 | f.write(json.dumps(my_dict) + '\n')
45 |
46 |
47 | if __name__ == "__main__":
48 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs")
50 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary")
51 |
52 | args = argparser.parse_args()
53 | main(args)
54 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/make_fixed_positions_dict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def main(args):
4 | import glob
5 | import random
6 | import numpy as np
7 | import json
8 | import itertools
9 |
10 | with open(args.input_path, 'r') as json_file:
11 | json_list = list(json_file)
12 |
13 | fixed_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")]
14 | global_designed_chain_list = [str(item) for item in args.chain_list.split()]
15 | my_dict = {}
16 |
17 | if not args.specify_non_fixed:
18 | for json_str in json_list:
19 | result = json.loads(json_str)
20 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
21 | fixed_position_dict = {}
22 | for i, chain in enumerate(global_designed_chain_list):
23 | fixed_position_dict[chain] = fixed_list[i]
24 | for chain in all_chain_list:
25 | if chain not in global_designed_chain_list:
26 | fixed_position_dict[chain] = []
27 | my_dict[result['name']] = fixed_position_dict
28 | else:
29 | for json_str in json_list:
30 | result = json.loads(json_str)
31 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
32 | fixed_position_dict = {}
33 | for chain in all_chain_list:
34 | seq_length = len(result[f'seq_chain_{chain}'])
35 | all_residue_list = (np.arange(seq_length)+1).tolist()
36 | if chain not in global_designed_chain_list:
37 | fixed_position_dict[chain] = all_residue_list
38 | else:
39 | idx = np.argwhere(np.array(global_designed_chain_list) == chain)[0][0]
40 | fixed_position_dict[chain] = list(set(all_residue_list)-set(fixed_list[idx]))
41 | my_dict[result['name']] = fixed_position_dict
42 |
43 | with open(args.output_path, 'w') as f:
44 | f.write(json.dumps(my_dict) + '\n')
45 |
46 | #e.g. output
47 | #{"5TTA": {"A": [1, 2, 3, 7, 8, 9, 22, 25, 33], "B": []}, "3LIS": {"A": [], "B": []}}
48 |
49 | if __name__ == "__main__":
50 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
51 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs")
52 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary")
53 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed")
54 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain")
55 | argparser.add_argument("--specify_non_fixed", action="store_true", default=False, help="Allows specifying just residues that need to be designed (default: false)")
56 |
57 | args = argparser.parse_args()
58 | main(args)
59 |
60 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/make_pssm_input_dict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def main(args):
4 | import json
5 | import numpy as np
6 | with open(args.jsonl_input_path, 'r') as json_file:
7 | json_list = list(json_file)
8 |
9 | my_dict = {}
10 | for json_str in json_list:
11 | result = json.loads(json_str)
12 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
13 | path_to_PSSM = args.PSSM_input_path+"/"+result['name'] + ".npz"
14 | print(path_to_PSSM)
15 | pssm_input = np.load(path_to_PSSM)
16 | pssm_dict = {}
17 | for chain in all_chain_list:
18 | pssm_dict[chain] = {}
19 | pssm_dict[chain]['pssm_coef'] = pssm_input[chain+'_coef'].tolist() #[L] per position coefficient to trust PSSM; 0.0 - do not use it; 1.0 - just use PSSM only
20 | pssm_dict[chain]['pssm_bias'] = pssm_input[chain+'_bias'].tolist() #[L,21] probability (sums up to 1.0 over alphabet of size 21) from PSSM
21 | pssm_dict[chain]['pssm_log_odds'] = pssm_input[chain+'_odds'].tolist() #[L,21] log_odds ratios coming from PSSM; optional/not needed
22 | my_dict[result['name']] = pssm_dict
23 |
24 | #Write output to:
25 | with open(args.output_path, 'w') as f:
26 | f.write(json.dumps(my_dict) + '\n')
27 |
28 | if __name__ == "__main__":
29 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30 |
31 | argparser.add_argument("--PSSM_input_path", type=str, help="Path to PSSMs saved as npz files.")
32 | argparser.add_argument("--jsonl_input_path", type=str, help="Path where to load .jsonl dictionary of parsed pdbs.")
33 | argparser.add_argument("--output_path", type=str, help="Path where to save .jsonl dictionary with PSSM bias.")
34 |
35 | args = argparser.parse_args()
36 | main(args)
37 |
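For each parsed PDB name, this script expects an `.npz` file with keys `<chain>_coef` (shape `[L]`), `<chain>_bias` (`[L, 21]`, rows summing to 1), and `<chain>_odds` (`[L, 21]`). A hypothetical example of writing such a file (values are placeholders; only the key names and shapes matter):

```python
import numpy as np

# Hypothetical per-PDB PSSM input; key names and shapes follow make_pssm_input_dict.py.
L = 100  # placeholder length of chain A
np.savez(
    "EXAMPLE.npz",
    A_coef=np.ones(L),                  # [L] per-position trust in the PSSM
    A_bias=np.full((L, 21), 1.0 / 21),  # [L, 21] probabilities; each row sums to 1
    A_odds=np.zeros((L, 21)),           # [L, 21] log-odds ratios (optional downstream)
)
```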
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/make_tied_positions_dict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def main(args):
4 |
5 | import glob
6 | import random
7 | import numpy as np
8 | import json
9 | import itertools
10 |
11 | with open(args.input_path, 'r') as json_file:
12 | json_list = list(json_file)
13 |
14 | homooligomeric_state = args.homooligomer
15 |
16 | if homooligomeric_state == 0:
17 | tied_list = [[int(item) for item in one.split()] for one in args.position_list.split(",")]
18 | global_designed_chain_list = [str(item) for item in args.chain_list.split()]
19 | my_dict = {}
20 | for json_str in json_list:
21 | result = json.loads(json_str)
22 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...
23 | tied_positions_list = []
24 | for i, pos in enumerate(tied_list[0]):
25 | temp_dict = {}
26 | for j, chain in enumerate(global_designed_chain_list):
27 | temp_dict[chain] = [tied_list[j][i]] #needs to be a list
28 | tied_positions_list.append(temp_dict)
29 | my_dict[result['name']] = tied_positions_list
30 | else:
31 | my_dict = {}
32 | for json_str in json_list:
33 | result = json.loads(json_str)
34 | all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...
35 | tied_positions_list = []
36 | chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
37 | for i in range(1,chain_length+1):
38 | temp_dict = {}
39 | for j, chain in enumerate(all_chain_list):
40 | temp_dict[chain] = [i] #needs to be a list
41 | tied_positions_list.append(temp_dict)
42 | my_dict[result['name']] = tied_positions_list
43 |
44 | with open(args.output_path, 'w') as f:
45 | f.write(json.dumps(my_dict) + '\n')
46 |
47 | if __name__ == "__main__":
48 | argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | argparser.add_argument("--input_path", type=str, help="Path to the parsed PDBs")
50 | argparser.add_argument("--output_path", type=str, help="Path to the output dictionary")
51 | argparser.add_argument("--chain_list", type=str, default='', help="List of the chains that need to be fixed")
52 | argparser.add_argument("--position_list", type=str, default='', help="Position lists, e.g. 11 12 14 18, 1 2 3 4 for first chain and the second chain")
53 | argparser.add_argument("--homooligomer", type=int, default=0, help="If 0 do not use, if 1 then design homooligomer")
54 |
55 | args = argparser.parse_args()
56 | main(args)
57 |
58 |
59 | #e.g. output
60 | #{"5TTA": [], "3LIS": [{"A": [1], "B": [1]}, {"A": [2], "B": [2]}, {"A": [3], "B": [3]}, {"A": [4], "B": [4]}, {"A": [5], "B": [5]}, {"A": [6], "B": [6]}, {"A": [7], "B": [7]}, {"A": [8], "B": [8]}, {"A": [9], "B": [9]}, {"A": [10], "B": [10]}, {"A": [11], "B": [11]}, {"A": [12], "B": [12]}, {"A": [13], "B": [13]}, {"A": [14], "B": [14]}, {"A": [15], "B": [15]}, {"A": [16], "B": [16]}, {"A": [17], "B": [17]}, {"A": [18], "B": [18]}, {"A": [19], "B": [19]}, {"A": [20], "B": [20]}, {"A": [21], "B": [21]}, {"A": [22], "B": [22]}, {"A": [23], "B": [23]}, {"A": [24], "B": [24]}, {"A": [25], "B": [25]}, {"A": [26], "B": [26]}, {"A": [27], "B": [27]}, {"A": [28], "B": [28]}, {"A": [29], "B": [29]}, {"A": [30], "B": [30]}, {"A": [31], "B": [31]}, {"A": [32], "B": [32]}, {"A": [33], "B": [33]}, {"A": [34], "B": [34]}, {"A": [35], "B": [35]}, {"A": [36], "B": [36]}, {"A": [37], "B": [37]}, {"A": [38], "B": [38]}, {"A": [39], "B": [39]}, {"A": [40], "B": [40]}, {"A": [41], "B": [41]}, {"A": [42], "B": [42]}, {"A": [43], "B": [43]}, {"A": [44], "B": [44]}, {"A": [45], "B": [45]}, {"A": [46], "B": [46]}, {"A": [47], "B": [47]}, {"A": [48], "B": [48]}, {"A": [49], "B": [49]}, {"A": [50], "B": [50]}, {"A": [51], "B": [51]}, {"A": [52], "B": [52]}, {"A": [53], "B": [53]}, {"A": [54], "B": [54]}, {"A": [55], "B": [55]}, {"A": [56], "B": [56]}, {"A": [57], "B": [57]}, {"A": [58], "B": [58]}, {"A": [59], "B": [59]}, {"A": [60], "B": [60]}, {"A": [61], "B": [61]}, {"A": [62], "B": [62]}, {"A": [63], "B": [63]}, {"A": [64], "B": [64]}, {"A": [65], "B": [65]}, {"A": [66], "B": [66]}, {"A": [67], "B": [67]}, {"A": [68], "B": [68]}, {"A": [69], "B": [69]}, {"A": [70], "B": [70]}, {"A": [71], "B": [71]}, {"A": [72], "B": [72]}, {"A": [73], "B": [73]}, {"A": [74], "B": [74]}, {"A": [75], "B": [75]}, {"A": [76], "B": [76]}, {"A": [77], "B": [77]}, {"A": [78], "B": [78]}, {"A": [79], "B": [79]}, {"A": [80], "B": [80]}, {"A": [81], "B": [81]}, {"A": [82], "B": [82]}, {"A": [83], "B": [83]}, {"A": [84], "B": [84]}, {"A": [85], "B": [85]}, {"A": [86], "B": [86]}, {"A": [87], "B": [87]}, {"A": [88], "B": [88]}, {"A": [89], "B": [89]}, {"A": [90], "B": [90]}, {"A": [91], "B": [91]}, {"A": [92], "B": [92]}, {"A": [93], "B": [93]}, {"A": [94], "B": [94]}, {"A": [95], "B": [95]}, {"A": [96], "B": [96]}]}
61 |
62 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/other_tools/make_omit_AA.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | import numpy as np
4 | import json
5 | import itertools
6 |
7 | #MODIFY this path
8 | with open('/home/justas/projects/lab_github/mpnn/data/pdbs.jsonl', 'r') as json_file:
9 | json_list = list(json_file)
10 |
11 | my_dict = {}
12 | for json_str in json_list:
13 | result = json.loads(json_str)
14 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
15 | fixed_position_dict = {}
16 | print(result['name'])
17 | if result['name'] == '5TTA':
18 | for chain in all_chain_list:
19 | if chain == 'A':
20 | fixed_position_dict[chain] = [
21 | [[int(item) for item in list(itertools.chain(list(np.arange(1,4)), list(np.arange(7,10)), [22, 25, 33]))], 'GPL'],
22 | [[int(item) for item in list(itertools.chain([40, 41, 42, 43]))], 'WC'],
23 | [[int(item) for item in list(itertools.chain(list(np.arange(50,150))))], 'ACEFGHIKLMNRSTVWYX'],
24 | [[int(item) for item in list(itertools.chain(list(np.arange(160,200))))], 'FGHIKLPQDMNRSTVWYX']]
25 | else:
26 | fixed_position_dict[chain] = []
27 | else:
28 | for chain in all_chain_list:
29 | fixed_position_dict[chain] = []
30 | my_dict[result['name']] = fixed_position_dict
31 |
32 | #MODIFY this path
33 | with open('/home/justas/projects/lab_github/mpnn/data/omit_AA.jsonl', 'w') as f:
34 | f.write(json.dumps(my_dict) + '\n')
35 |
36 |
37 | print('Finished')
38 | #e.g. output
39 | #{"5TTA": {"A": [[[1, 2, 3, 7, 8, 9, 22, 25, 33], "GPL"], [[40, 41, 42, 43], "WC"], [[50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149], "ACEFGHIKLMNRSTVWYX"], [[160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199], "FGHIKLPQDMNRSTVWYX"]], "B": []}, "3LIS": {"A": [], "B": []}}
40 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/other_tools/make_pssm_dict.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | import glob
5 | import random
6 | import numpy as np
7 | import json
8 |
9 |
10 | def softmax(x, T):
11 | return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True)
12 |
13 | def parse_pssm(path):
14 | data = pd.read_csv(path, skiprows=2)
15 | floats_list_list = []
16 | for i in range(data.values.shape[0]):
17 | str1 = data.values[i][0][4:]
18 | floats_list = []
19 | for item in str1.split():
20 | floats_list.append(float(item))
21 | floats_list_list.append(floats_list)
22 | np_lines = np.array(floats_list_list)
23 | return np_lines
24 |
25 | np_lines = parse_pssm('/home/swang523/RLcage/capsid/monomersfordesign/8-16-21/pssm_rainity_final_8-16-21_int/build_0.2089_0.98_0.4653_19_2.00_0.005745.pssm')
26 |
27 | mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
28 | input_alphabet = 'ARNDCQEGHILKMFPSTWYV'
29 |
30 | permutation_matrix = np.zeros([20,21])
31 | for i in range(20):
32 | letter1 = input_alphabet[i]
33 | for j in range(21):
34 | letter2 = mpnn_alphabet[j]
35 | if letter1 == letter2:
36 | permutation_matrix[i,j]=1.
37 |
38 | pssm_log_odds = np_lines[:,:20] @ permutation_matrix
39 | pssm_probs = np_lines[:,20:40] @ permutation_matrix
40 |
41 | X_mask = np.concatenate([np.zeros([1,20]), np.ones([1,1])], -1)
42 |
43 | def softmax(x, T):
44 | return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True)
45 |
46 | #Load parsed PDBs:
47 | with open('/home/justas/projects/cages/parsed/test.jsonl', 'r') as json_file:
48 | json_list = list(json_file)
49 |
50 | my_dict = {}
51 | for json_str in json_list:
52 | result = json.loads(json_str)
53 | all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
54 | pssm_dict = {}
55 | for chain in all_chain_list:
56 | pssm_dict[chain] = {}
57 | pssm_dict[chain]['pssm_coef'] = (np.ones(len(result['seq_chain_A']))).tolist() #a number between 0.0 and 1.0 specifying how much attention put to PSSM, can be adjusted later as a flag
58 | pssm_dict[chain]['pssm_bias'] = (softmax(pssm_log_odds-X_mask*1e8, 1.0)).tolist() #PSSM like, [length, 21] such that sum over the last dimension adds up to 1.0
59 | pssm_dict[chain]['pssm_log_odds'] = (pssm_log_odds).tolist()
60 | my_dict[result['name']] = pssm_dict
61 |
62 | #Write output to:
63 | with open('/home/justas/projects/lab_github/mpnn/data/pssm_dict.jsonl', 'w') as f:
64 | f.write(json.dumps(my_dict) + '\n')
65 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/parse_multiple_chains.out:
--------------------------------------------------------------------------------
1 | Successfully finished: 2 pdbs
2 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/helper_scripts/parse_multiple_chains.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --mem=32g
3 | #SBATCH -c 2
4 | #SBATCH --output=parse_multiple_chains.out
5 |
6 | source activate mlfold
7 | python parse_multiple_chains.py --input_path='../PDB_complexes/pdbs/' --output_path='../PDB_complexes/parsed_pdbs.jsonl'
8 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/inputs/PSSM_inputs/3HTN.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/inputs/PSSM_inputs/3HTN.npz
--------------------------------------------------------------------------------
/forks/ProteinMPNN/inputs/PSSM_inputs/4YOW.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/inputs/PSSM_inputs/4YOW.npz
--------------------------------------------------------------------------------
/forks/ProteinMPNN/outputs/training_test_output/seqs/5L33.fa:
--------------------------------------------------------------------------------
1 | >5L33, score=1.9496, global_score=1.9496, fixed_chains=[], designed_chains=['A'], model_name=epoch_last, git_hash=78fdc70e6da22ba7a535f8913f62310b738f0b33, seed=37
2 | HMPEEEKAARLFIEALEKGDPELMRKVISPDTRMEDNGREFTGDEVVEYVKEIQKRGEQWHLRRYTKEGNSWRFEVQVDNNGQTEQWEVQIEVRNGRIKRVTITHV
3 | >T=0.1, sample=1, score=1.3913, global_score=1.3913, seq_recovery=0.3396
4 | ALPAAEKVALALLDALATGDPELLKAVLTADSKFTDNGKEFKGEDLVDFVEKLKKEGKKFKPTSGSVTGDSFTLTLTVSSNGKTETATLTVKVENGKLSSLTITKN
5 | >T=0.1, sample=2, score=1.3315, global_score=1.3315, seq_recovery=0.3868
6 | KLPEEEKVAKELLKALENGDPELAKKVLTPDAKFTLNGEEFSGEELVKFVKELKEKGEKFKPLSSKKDGDSYTFTLKLSKNGKTETATLKVKVKNGKVEEIELSSD
7 | >T=0.1, sample=3, score=1.3452, global_score=1.3452, seq_recovery=0.4151
8 | KLPEEEKVAKKLLEALEKKDPELLKKVLTPDSEFTINGKKFKGEELVKLVKELKKKGEKFELVSSSKTGDSFTFTLKVSKNGKTKEATLTVTVKNGKLDKLTLSFK
9 | >T=0.1, sample=4, score=1.3763, global_score=1.3763, seq_recovery=0.3208
10 | KLDEKLKVAEKLIKALENGDPALLKEVLTKDSKFTINGKEFTGEDAVDFVKKLKKAGEKFKKLSGELKGDEFTFKLELEKDGEKKTAELTVKVENGKLTSLTIKDK
11 | >T=0.1, sample=5, score=1.3108, global_score=1.3108, seq_recovery=0.3208
12 | KLPEEEKVAKKLLDALENKDPELAKKVLTKDAKFKINGKEFSGEDLVDFVKKLKEDGKEFKLLSGKKVGDKYVFTLKISKDGEEKEAKLEVKVKNGKVEEIKIESK
13 | >T=0.1, sample=6, score=1.3122, global_score=1.3122, seq_recovery=0.3302
14 | KLPEEEKVAKKLLEALEKGDPELLKEVLTKDAEFTKNGEKFKGPDLVKFVEKLKAAGEKFKLLSSKKEGDKYTLTLELEKNGEKKKATLTVDVKNGKVESLELSDK
15 | >T=0.1, sample=7, score=1.3217, global_score=1.3217, seq_recovery=0.3774
16 | KKPEKEKVADKLVDALEKGDPELLKKVLTKDAKFEKNGKKFKGEDLVKYVEELKAKGEKFEPVGGEKKGDSYKFKLKISKNGKTKTATLEVKVENGKVKELKISEK
17 | >T=0.1, sample=8, score=1.2603, global_score=1.2603, seq_recovery=0.3491
18 | KLPEKEKVAKKLLDALEKGDPELLKEVLNKDAKFTINGKKFKGDDLVKFVEELKKKGEKFEPLSGEKKGDEFVFKLKIEKNGEKKEVELKVKVEDGKVKDIEIKDK
19 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/soluble_model_weights/v_48_002.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_002.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/soluble_model_weights/v_48_010.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_010.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/soluble_model_weights/v_48_020.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_020.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/soluble_model_weights/v_48_030.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/soluble_model_weights/v_48_030.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/training/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Justas Dauparas
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/training/exp_020/log.txt:
--------------------------------------------------------------------------------
1 | Epoch Train Validation
2 | epoch: 1, step: 74, time: 45.7, train: 23.565, valid: 17.468, train_acc: 0.072, valid_acc: 0.113
3 | epoch: 2, step: 148, time: 36.1, train: 17.790, valid: 15.872, train_acc: 0.108, valid_acc: 0.133
4 | epoch: 3, step: 221, time: 41.2, train: 16.398, valid: 14.980, train_acc: 0.128, valid_acc: 0.154
5 | epoch: 4, step: 294, time: 39.3, train: 15.417, valid: 15.041, train_acc: 0.148, valid_acc: 0.153
6 | epoch: 5, step: 369, time: 40.8, train: 14.736, valid: 13.124, train_acc: 0.167, valid_acc: 0.202
7 | epoch: 6, step: 444, time: 39.7, train: 13.954, valid: 12.683, train_acc: 0.187, valid_acc: 0.209
8 | epoch: 7, step: 515, time: 41.2, train: 13.665, valid: 12.079, train_acc: 0.193, valid_acc: 0.229
9 | epoch: 8, step: 586, time: 39.6, train: 13.105, valid: 11.938, train_acc: 0.208, valid_acc: 0.237
10 | epoch: 9, step: 656, time: 41.0, train: 12.714, valid: 11.232, train_acc: 0.219, valid_acc: 0.248
11 | epoch: 10, step: 726, time: 37.7, train: 12.391, valid: 11.386, train_acc: 0.225, valid_acc: 0.249
12 | epoch: 11, step: 796, time: 39.2, train: 12.098, valid: 10.990, train_acc: 0.231, valid_acc: 0.261
13 | epoch: 12, step: 866, time: 37.0, train: 11.845, valid: 10.554, train_acc: 0.238, valid_acc: 0.267
14 | epoch: 13, step: 940, time: 42.0, train: 11.742, valid: 10.673, train_acc: 0.240, valid_acc: 0.270
15 | epoch: 14, step: 1014, time: 38.7, train: 11.503, valid: 10.455, train_acc: 0.245, valid_acc: 0.271
16 | epoch: 15, step: 1089, time: 42.3, train: 11.284, valid: 10.303, train_acc: 0.251, valid_acc: 0.278
17 | epoch: 16, step: 1164, time: 40.1, train: 11.335, valid: 9.982, train_acc: 0.249, valid_acc: 0.285
18 | epoch: 17, step: 1239, time: 43.7, train: 10.959, valid: 9.796, train_acc: 0.260, valid_acc: 0.292
19 | epoch: 18, step: 1314, time: 36.3, train: 10.726, valid: 9.472, train_acc: 0.265, valid_acc: 0.301
20 | epoch: 19, step: 1383, time: 44.5, train: 10.730, valid: 9.604, train_acc: 0.267, valid_acc: 0.295
21 | epoch: 20, step: 1452, time: 34.0, train: 10.583, valid: 9.378, train_acc: 0.270, valid_acc: 0.305
22 | epoch: 21, step: 1522, time: 41.2, train: 10.396, valid: 9.458, train_acc: 0.275, valid_acc: 0.304
23 | epoch: 22, step: 1592, time: 37.0, train: 10.324, valid: 9.444, train_acc: 0.277, valid_acc: 0.301
24 | epoch: 23, step: 1662, time: 39.9, train: 10.250, valid: 9.518, train_acc: 0.278, valid_acc: 0.302
25 | epoch: 24, step: 1732, time: 38.0, train: 10.092, valid: 9.048, train_acc: 0.284, valid_acc: 0.315
26 | epoch: 25, step: 1808, time: 46.0, train: 10.221, valid: 8.969, train_acc: 0.279, valid_acc: 0.322
27 | epoch: 26, step: 1884, time: 36.7, train: 10.010, valid: 8.876, train_acc: 0.285, valid_acc: 0.320
28 | epoch: 27, step: 1959, time: 43.3, train: 10.013, valid: 8.652, train_acc: 0.286, valid_acc: 0.328
29 | epoch: 28, step: 2034, time: 40.4, train: 9.694, valid: 8.779, train_acc: 0.296, valid_acc: 0.324
30 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/training/exp_020/model_weights/epoch_last.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/training/exp_020/model_weights/epoch_last.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/training/submit_exp_020.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -p gpu
4 | #SBATCH --mem=128g
5 | #SBATCH --gres=gpu:a100:1
6 | #SBATCH -c 12
7 | #SBATCH -t 7-00:00:00
8 | #SBATCH --output=exp_020.out
9 |
10 | source activate mlfold-test
11 | python ./training.py \
12 | --path_for_outputs "./exp_020" \
13 | --path_for_training_data "path_to/pdb_2021aug02" \
14 | --num_examples_per_epoch 1000 \
15 | --save_model_every_n_epochs 50
16 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/training/test_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu
3 | #SBATCH --mem=32g
4 | #SBATCH --gres=gpu:rtx2080:1
5 | #SBATCH -c 3
6 | #SBATCH --output=example_3_model_w_test.out
7 |
8 | source activate mlfold
9 |
10 | path_to_PDB="../inputs/PDB_monomers/pdbs/5L33.pdb"
11 |
12 | output_dir="../outputs/training_test_output"
13 | if [ ! -d $output_dir ]
14 | then
15 | mkdir -p $output_dir
16 | fi
17 |
18 | chains_to_design="A"
19 |
20 |
21 | python ../protein_mpnn_run.py \
22 | --path_to_model_weights "../training/exp_020/model_weights" \
23 | --model_name "epoch_last" \
24 | --pdb_path $path_to_PDB \
25 | --pdb_path_chains "$chains_to_design" \
26 | --out_folder $output_dir \
27 | --num_seq_per_target 8 \
28 | --sampling_temp "0.1" \
29 | --seed 37 \
30 | --batch_size 1
31 |
--------------------------------------------------------------------------------
/forks/ProteinMPNN/vanilla_model_weights/v_48_002.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_002.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/vanilla_model_weights/v_48_010.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_010.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/vanilla_model_weights/v_48_020.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_020.pt
--------------------------------------------------------------------------------
/forks/ProteinMPNN/vanilla_model_weights/v_48_030.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/ProteinMPNN/vanilla_model_weights/v_48_030.pt
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Institute for Protein Design
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/README.md:
--------------------------------------------------------------------------------
1 | # RF2NA
2 | GitHub repo for RoseTTAFold2 with nucleic acids
3 |
4 | **New: April 13, 2023 v0.2**
5 | * Updated weights (https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz) for better prediction of homodimer:DNA interactions and better DNA-specific sequence recognition
6 | * Bugfixes in MSA generation pipeline
7 | * Support for paired protein/RNA MSAs
8 |
9 | ## Installation
10 |
11 | 1. Clone the package
12 | ```
13 | git clone https://github.com/uw-ipd/RoseTTAFold2NA.git
14 | cd RoseTTAFold2NA
15 | ```
16 |
17 | 2. Create conda environment
18 | All external dependencies are contained in `RF2na-linux.yml`
19 | ```
20 | # create conda environment for RoseTTAFold2NA
21 | conda env create -f RF2na-linux.yml
22 | ```
23 | You also need to install NVIDIA's SE(3)-Transformer (**please install the SE3Transformer bundled in this repo**).
24 | ```
25 | conda activate RF2NA
26 | cd SE3Transformer
27 | pip install --no-cache-dir -r requirements.txt
28 | python setup.py install
29 | ```
30 |
31 | 3. Download the pre-trained weights under the network directory
32 | ```
33 | cd network
34 | wget https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz
35 | tar xvfz RF2NA_apr23.tgz
36 | ls weights/ # it should contain a 1.1GB weights file
37 | cd ..
38 | ```
39 |
40 | 4. Download sequence and structure databases
41 | ```
42 | # uniref30 [46G]
43 | wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz
44 | mkdir -p UniRef30_2020_06
45 | tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06
46 |
47 | # BFD [272G]
48 | wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
49 | mkdir -p bfd
50 | tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
51 |
52 | # structure templates (including *_a3m.ffdata, *_a3m.ffindex)
53 | wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz
54 | tar xfz pdb100_2021Mar03.tar.gz
55 |
56 | # RNA databases
57 | mkdir -p RNA
58 | cd RNA
59 |
60 | # Rfam [300M]
61 | wget ftp://ftp.ebi.ac.uk/pub/databases/Rfam/CURRENT/Rfam.full_region.gz
62 | wget ftp://ftp.ebi.ac.uk/pub/databases/Rfam/CURRENT/Rfam.cm.gz
63 | gunzip Rfam.cm.gz
64 | cmpress Rfam.cm
65 |
66 | # RNAcentral [12G]
67 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/rfam/rfam_annotations.tsv.gz
68 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/id_mapping/id_mapping.tsv.gz
69 | wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/sequences/rnacentral_species_specific_ids.fasta.gz
70 | ../input_prep/reprocess_rnac.pl id_mapping.tsv.gz rfam_annotations.tsv.gz # ~8 minutes
71 | gunzip -c rnacentral_species_specific_ids.fasta.gz | makeblastdb -in - -dbtype nucl -parse_seqids -out rnacentral.fasta -title "RNACentral"
72 |
73 | # nt [151G]
74 | update_blastdb.pl --decompress nt
75 | cd ..
76 | ```
77 |
78 | ## Usage
79 | ```
80 | conda activate RF2NA
81 | cd example
82 | # run Protein/RNA prediction
83 | ../run_RF2NA.sh rna_pred 0 rna_binding_protein.fa R:RNA.fa
84 | # run Protein/DNA prediction
85 | ../run_RF2NA.sh dna_pred 0 dna_binding_protein.fa D:DNA.fa
86 | ```
87 | ### Inputs
88 | * The first argument to the script is the output folder
89 | * The second argument to the script is an integer indicating whether to use single-sequence mode (`1`) or not (`0`)
90 | * The remaining arguments are fasta files for individual chains in the structure. Use the tags `P:xxx.fa` `R:xxx.fa` `D:xxx.fa` `S:xxx.fa` to specify protein, RNA, double-stranded DNA, and single-stranded DNA, respectively. Use the tag `PR:xxx.fa` to specify paired protein/RNA. Each chain is a separate file; 'D' will automatically generate a complementary DNA strand to the input strand.
91 |
92 | ### Expected outputs
93 | * Outputs are written to the folder provided as the first argument (`dna_pred` and `rna_pred`).
94 | * Model outputs are placed in a subfolder, `models` (e.g., `dna_pred/models`)
95 | * You will get a predicted structure with estimated per-residue LDDT in the B-factor column (`models/model_00.pdb`)
96 | * You will get a numpy `.npz` file (`models/model_00.npz`). This can be read with `numpy.load` and contains three tables (L=complex length):
97 | - dist (L x L x 37) - the predicted distogram
98 | - lddt (L) - the per-residue predicted lddt
99 |     - pae (L x L) - the predicted aligned error (PAE) for each residue pair
100 |
--------------------------------------------------------------------------------
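The `model_00.npz` tables listed above can be inspected directly with NumPy. A minimal sketch follows; the key names `dist`, `lddt`, and `pae` are taken from the README, while the output path is an assumed example.

```python
# Minimal sketch: read the RF2NA output archive described in the README above.
# Assumes a prediction was run with "dna_pred" as the output folder (example path).
import numpy as np

out = np.load("dna_pred/models/model_00.npz")
dist, lddt, pae = out["dist"], out["lddt"], out["pae"]
print(dist.shape)          # (L, L, 37) predicted distogram over distance bins
print(float(lddt.mean()))  # mean per-residue predicted lDDT
print(float(pae.mean()))   # mean predicted aligned error over residue pairs
```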
/forks/RoseTTAFold2NA/RF2na-linux.yml:
--------------------------------------------------------------------------------
1 | name: RF2NA
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | - conda-forge
7 | dependencies:
8 | - python=3.10
9 | - pip
10 | - pytorch
11 | - requests
12 | - pytorch-cuda=11.7
13 | - dglteam/label/cu117::dgl
14 | - pyg::pyg
15 | - bioconda::mafft
16 | - bioconda::hhsuite
17 | - bioconda::blast
18 | - bioconda::hmmer>=3.3
19 | - bioconda::infernal
20 | - bioconda::cd-hit
21 | - bioconda::csblast
22 | - pip:
23 | - psutil
24 | - tqdm
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | # run docker daemon with --default-runtime=nvidia for GPU detection during build
25 | # multistage build for DGL with CUDA and FP16
26 |
27 | ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3
28 |
29 | FROM ${FROM_IMAGE_NAME} AS dgl_builder
30 |
31 | ENV DEBIAN_FRONTEND=noninteractive
32 | RUN apt-get update \
33 | && apt-get install -y git build-essential python3-dev make cmake \
34 | && rm -rf /var/lib/apt/lists/*
35 | WORKDIR /dgl
36 | RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
37 | RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
38 | WORKDIR build
39 | RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
40 | RUN make -j8
41 |
42 |
43 | FROM ${FROM_IMAGE_NAME}
44 |
45 | RUN rm -rf /workspace/*
46 | WORKDIR /workspace/se3-transformer
47 |
48 | # copy built DGL and install it
49 | COPY --from=dgl_builder /dgl ./dgl
50 | RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
51 |
52 | ADD requirements.txt .
53 | RUN pip install --no-cache-dir --upgrade --pre pip
54 | RUN pip install --no-cache-dir -r requirements.txt
55 | ADD . .
56 |
57 | ENV DGLBACKEND=pytorch
58 | ENV OMP_NUM_THREADS=1
59 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2021 NVIDIA CORPORATION & AFFILIATES
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/NOTICE:
--------------------------------------------------------------------------------
1 | SE(3)-Transformer PyTorch
2 |
3 | This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
4 | licensed under the MIT License.
5 |
6 | This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
7 | licensed under the MIT License.
8 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/images/se3-transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/images/se3-transformer.png
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/requirements.txt:
--------------------------------------------------------------------------------
1 | e3nn==0.3.3
2 | wandb==0.12.0
3 | pynvml==11.0.0
4 | git+https://github.com/NVIDIA/dllogger#egg=dllogger
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/scripts/benchmark_inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Script to benchmark inference performance, without bases precomputation
3 |
4 | # CLI args with defaults
5 | BATCH_SIZE=${1:-240}
6 | AMP=${2:-true}
7 |
8 | CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.inference \
9 | --amp "$AMP" \
10 | --batch_size "$BATCH_SIZE" \
11 | --use_layer_norm \
12 | --norm \
13 | --task homo \
14 | --seed 42 \
15 | --benchmark
16 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/scripts/benchmark_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Script to benchmark single-GPU training performance, with bases precomputation
3 |
4 | # CLI args with defaults
5 | BATCH_SIZE=${1:-240}
6 | AMP=${2:-true}
7 |
8 | CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \
9 | --amp "$AMP" \
10 | --batch_size "$BATCH_SIZE" \
11 | --epochs 6 \
12 | --use_layer_norm \
13 | --norm \
14 | --save_ckpt_path model_qm9.pth \
15 | --task homo \
16 | --precompute_bases \
17 | --seed 42 \
18 | --benchmark
19 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/scripts/benchmark_train_multi_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Script to benchmark multi-GPU training performance, with bases precomputation
3 |
4 | # CLI args with defaults
5 | BATCH_SIZE=${1:-240}
6 | AMP=${2:-true}
7 |
8 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
9 | se3_transformer.runtime.training \
10 | --amp "$AMP" \
11 | --batch_size "$BATCH_SIZE" \
12 | --epochs 6 \
13 | --use_layer_norm \
14 | --norm \
15 | --save_ckpt_path model_qm9.pth \
16 | --task homo \
17 | --precompute_bases \
18 | --seed 42 \
19 | --benchmark
20 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/scripts/predict.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # CLI args with defaults
4 | BATCH_SIZE=${1:-240}
5 | AMP=${2:-true}
6 |
7 |
8 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
9 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
10 | TASK=homo
11 |
12 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
13 | se3_transformer.runtime.inference \
14 | --amp "$AMP" \
15 | --batch_size "$BATCH_SIZE" \
16 | --use_layer_norm \
17 | --norm \
18 | --load_ckpt_path model_qm9.pth \
19 | --task "$TASK"
20 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # CLI args with defaults
4 | BATCH_SIZE=${1:-240}
5 | AMP=${2:-true}
6 | NUM_EPOCHS=${3:-100}
7 | LEARNING_RATE=${4:-0.002}
8 | WEIGHT_DECAY=${5:-0.1}
9 |
10 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
11 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
12 | TASK=homo
13 |
14 | python -m se3_transformer.runtime.training \
15 | --amp "$AMP" \
16 | --batch_size "$BATCH_SIZE" \
17 | --epochs "$NUM_EPOCHS" \
18 | --lr "$LEARNING_RATE" \
19 | --weight_decay "$WEIGHT_DECAY" \
20 | --use_layer_norm \
21 | --norm \
22 | --save_ckpt_path model_qm9.pth \
23 | --precompute_bases \
24 | --seed 42 \
25 | --task "$TASK"
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/scripts/train_multi_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # CLI args with defaults
4 | BATCH_SIZE=${1:-240}
5 | AMP=${2:-true}
6 | NUM_EPOCHS=${3:-130}
7 | LEARNING_RATE=${4:-0.01}
8 | WEIGHT_DECAY=${5:-0.1}
9 |
10 | # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
11 | # 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
12 | TASK=homo
13 |
14 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
15 | se3_transformer.runtime.training \
16 | --amp "$AMP" \
17 | --batch_size "$BATCH_SIZE" \
18 | --epochs "$NUM_EPOCHS" \
19 | --lr "$LEARNING_RATE" \
20 | --min_lr 0.00001 \
21 | --weight_decay "$WEIGHT_DECAY" \
22 | --use_layer_norm \
23 | --norm \
24 | --save_ckpt_path model_qm9.pth \
25 | --precompute_bases \
26 | --seed 42 \
27 | --task "$TASK"
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/__init__.py
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/data_loading/__init__.py:
--------------------------------------------------------------------------------
1 | from .qm9 import QM9DataModule
2 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/data_loading/data_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | import torch.distributed as dist
25 | from abc import ABC
26 | from torch.utils.data import DataLoader, DistributedSampler, Dataset
27 |
28 | from se3_transformer.runtime.utils import get_local_rank
29 |
30 |
31 | def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
32 | # Classic or distributed dataloader depending on the context
33 | sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
34 | return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
35 |
36 |
37 | class DataModule(ABC):
38 | """ Abstract DataModule. Children must define self.ds_{train | val | test}. """
39 |
40 | def __init__(self, **dataloader_kwargs):
41 | super().__init__()
42 | if get_local_rank() == 0:
43 | self.prepare_data()
44 |
45 | # Wait until rank zero has prepared the data (download, preprocessing, ...)
46 | if dist.is_initialized():
47 | dist.barrier(device_ids=[get_local_rank()])
48 |
49 | self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
50 | self.ds_train, self.ds_val, self.ds_test = None, None, None
51 |
52 | def prepare_data(self):
53 | """ Method called only once per node. Put here any downloading or preprocessing """
54 | pass
55 |
56 | def train_dataloader(self) -> DataLoader:
57 | return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)
58 |
59 | def val_dataloader(self) -> DataLoader:
60 | return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)
61 |
62 | def test_dataloader(self) -> DataLoader:
63 | return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
64 |
--------------------------------------------------------------------------------
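A hypothetical usage sketch for the abstract `DataModule` above: a subclass assigns `ds_train`, `ds_val`, and `ds_test`, and the base class builds plain or distributed dataloaders from them. The toy dataset and keyword arguments below are illustrative only, not part of the repo.

```python
# Hypothetical DataModule subclass, shown only to illustrate the contract
# stated in the base class docstring.
import torch
from torch.utils.data import TensorDataset

from se3_transformer.data_loading.data_module import DataModule


class ToyDataModule(DataModule):
    def __init__(self, **dataloader_kwargs):
        super().__init__(**dataloader_kwargs)
        xs, ys = torch.randn(128, 8), torch.randn(128, 1)
        self.ds_train = TensorDataset(xs, ys)
        self.ds_val = TensorDataset(xs[:32], ys[:32])
        self.ds_test = TensorDataset(xs[:32], ys[:32])


# persistent_workers defaults to True in the base class, so it is overridden
# here to allow num_workers=0 in a single-process example.
dm = ToyDataModule(batch_size=16, num_workers=0, persistent_workers=False)
xs, ys = next(iter(dm.train_dataloader()))
print(xs.shape, ys.shape)  # torch.Size([16, 8]) torch.Size([16, 1])
```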
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import SE3Transformer, SE3TransformerPooled
2 | from .fiber import Fiber
3 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear import LinearSE3
2 | from .norm import NormSE3
3 | from .pooling import GPooling
4 | from .convolution import ConvSE3
5 | from .attention import AttentionBlockSE3
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 |
25 | from typing import Dict
26 |
27 | import numpy as np
28 | import torch
29 | import torch.nn as nn
30 | from torch import Tensor
31 |
32 | from se3_transformer.model.fiber import Fiber
33 |
34 |
35 | class LinearSE3(nn.Module):
36 | """
37 | Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
38 | Maps a fiber to a fiber with the same degrees (channels may be different).
39 | No interaction between degrees, but interaction between channels.
40 |
41 | type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
42 | type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
43 | :
44 | type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
45 | """
46 |
47 | def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
48 | super().__init__()
49 | self.weights = nn.ParameterDict({
50 | str(degree_out): nn.Parameter(
51 | torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
52 | for degree_out, channels_out in fiber_out
53 | })
54 |
55 | def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
56 | return {
57 | degree: self.weights[degree] @ features[degree]
58 | for degree, weight in self.weights.items()
59 | }
60 |
--------------------------------------------------------------------------------
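A small, hypothetical shape check for `LinearSE3`: as the docstring above describes, each degree is mapped independently, so only the channel counts change while the (2l+1) component dimension is preserved. The fiber sizes below are arbitrary.

```python
# Hypothetical shape check for LinearSE3 (illustrative fiber sizes, not from the repo).
import torch

from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.linear import LinearSE3

layer = LinearSE3(fiber_in=Fiber({0: 16, 1: 16}), fiber_out=Fiber({0: 32, 1: 8}))
feats = {
    "0": torch.randn(10, 16, 1),  # 10 nodes, 16 type-0 channels
    "1": torch.randn(10, 16, 3),  # 10 nodes, 16 type-1 channels
}
out = layer(feats)
print(out["0"].shape, out["1"].shape)  # torch.Size([10, 32, 1]) torch.Size([10, 8, 3])
```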
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 |
25 | from typing import Dict
26 |
27 | import torch
28 | import torch.nn as nn
29 | from torch import Tensor
30 | from torch.cuda.nvtx import range as nvtx_range
31 |
32 | from se3_transformer.model.fiber import Fiber
33 |
34 |
35 | class NormSE3(nn.Module):
36 | """
37 | Norm-based SE(3)-equivariant nonlinearity.
38 |
39 | ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
40 | feature_in ──┤ * ──> feature_out
41 | └──> feature_phase ────────────────────────────┘
42 | """
43 |
44 | NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
45 |
46 | def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
47 | super().__init__()
48 | self.fiber = fiber
49 | self.nonlinearity = nonlinearity
50 |
51 | if len(set(fiber.channels)) == 1:
52 | # Fuse all the layer normalizations into a group normalization
53 | self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
54 | else:
55 | # Use multiple layer normalizations
56 | self.layer_norms = nn.ModuleDict({
57 | str(degree): nn.LayerNorm(channels)
58 | for degree, channels in fiber
59 | })
60 |
61 | def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
62 | with nvtx_range('NormSE3'):
63 | output = {}
64 | #print ('NormSE3 features',[torch.sum(torch.isnan(v)) for v in features.values()])
65 | if hasattr(self, 'group_norm'):
66 | # Compute per-degree norms of features
67 | norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
68 | for d in self.fiber.degrees]
69 | fused_norms = torch.cat(norms, dim=-2)
70 |
71 | # Transform the norms only
72 | new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
73 | new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
74 |
75 | # Scale features to the new norms
76 | for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
77 | output[str(d)] = features[str(d)] / norm * new_norm
78 | else:
79 | for degree, feat in features.items():
80 | norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
81 | new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
82 | output[degree] = new_norm * feat / norm
83 | #print ('NormSE3 output',[torch.sum(torch.isnan(v)) for v in output.values()])
84 |
85 | return output
86 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/model/layers/pooling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | from typing import Dict, Literal
25 |
26 | import torch.nn as nn
27 | from dgl import DGLGraph
28 | from dgl.nn.pytorch import AvgPooling, MaxPooling
29 | from torch import Tensor
30 |
31 |
32 | class GPooling(nn.Module):
33 | """
34 | Graph max/average pooling on a given feature type.
35 | The average can be taken for any feature type, and equivariance will be maintained.
36 | The maximum can only be taken for invariant features (type 0).
37 | If you want max-pooling for type > 0 features, look into Vector Neurons.
38 | """
39 |
40 | def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
41 | """
42 | :param feat_type: Feature type to pool
43 | :param pool: Type of pooling: max or avg
44 | """
45 | super().__init__()
46 | assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
47 | assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
48 | self.feat_type = feat_type
49 | self.pool = MaxPooling() if pool == 'max' else AvgPooling()
50 |
51 | def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
52 | pooled = self.pool(graph, features[str(self.feat_type)])
53 | return pooled.squeeze(dim=-1)
54 |
--------------------------------------------------------------------------------
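A hypothetical sketch of `GPooling` on a toy DGL graph; per the docstring above, average pooling is safe for any feature type, while max pooling is restricted to invariant (type-0) features. Graph and feature sizes below are arbitrary.

```python
# Hypothetical GPooling usage on a random DGL graph (toy sizes).
import dgl
import torch

from se3_transformer.model.layers.pooling import GPooling

graph = dgl.rand_graph(6, 12)              # 6 nodes, 12 edges
feats = {"0": torch.randn(6, 16, 1)}       # type-0 (invariant) node features
pool = GPooling(feat_type=0, pool="avg")
print(pool(feats, graph=graph).shape)      # torch.Size([1, 16]): one pooled vector per graph
```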
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/__init__.py
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | import argparse
25 | import pathlib
26 |
27 | from se3_transformer.data_loading import QM9DataModule
28 | from se3_transformer.model import SE3TransformerPooled
29 | from se3_transformer.runtime.utils import str2bool
30 |
31 | PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
32 |
33 | paths = PARSER.add_argument_group('Paths')
34 | paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
35 | help='Directory where the data is located or should be downloaded')
36 | paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
37 | help='Directory where the results logs should be saved')
38 | paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
39 | help='Name for the resulting DLLogger JSON file')
40 | paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
41 | help='File where the checkpoint should be saved')
42 | paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
43 | help='File of the checkpoint to be loaded')
44 |
45 | optimizer = PARSER.add_argument_group('Optimizer')
46 | optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
47 | optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
48 | optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
49 | optimizer.add_argument('--momentum', type=float, default=0.9)
50 | optimizer.add_argument('--weight_decay', type=float, default=0.1)
51 |
52 | PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
53 | PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
54 | PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
55 | PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
56 |
57 | PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
58 | PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
59 | PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
60 | PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
61 | PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
62 | help='Do an evaluation round every N epochs')
63 | PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
64 | help='Minimize stdout output')
65 |
66 | PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
67 | help='Benchmark mode')
68 |
69 | QM9DataModule.add_argparse_args(PARSER)
70 | SE3TransformerPooled.add_argparse_args(PARSER)
71 |
--------------------------------------------------------------------------------
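The shared argparse configuration above is what the training and benchmark scripts in `scripts/` ultimately drive. A hypothetical sketch of parsing a flag set similar to `scripts/benchmark_train.sh`; it assumes `--task` and `--precompute_bases` are registered by `QM9DataModule.add_argparse_args`, as the scripts imply.

```python
# Hypothetical: parse flags similar to scripts/benchmark_train.sh with the shared PARSER.
from se3_transformer.runtime.arguments import PARSER

args = PARSER.parse_args(
    ["--amp", "true", "--batch_size", "240", "--epochs", "6", "--task", "homo", "--benchmark"]
)
print(args.amp, args.batch_size, args.learning_rate)  # True 240 0.002
```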
/forks/RoseTTAFold2NA/SE3Transformer/se3_transformer/runtime/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | from abc import ABC, abstractmethod
25 |
26 | import torch
27 | import torch.distributed as dist
28 | from torch import Tensor
29 |
30 |
31 | class Metric(ABC):
32 | """ Metric class with synchronization capabilities similar to TorchMetrics """
33 |
34 | def __init__(self):
35 | self.states = {}
36 |
37 | def add_state(self, name: str, default: Tensor):
38 | assert name not in self.states
39 | self.states[name] = default.clone()
40 | setattr(self, name, default)
41 |
42 | def synchronize(self):
43 | if dist.is_initialized():
44 | for state in self.states:
45 | dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
46 |
47 | def __call__(self, *args, **kwargs):
48 | self.update(*args, **kwargs)
49 |
50 | def reset(self):
51 | for name, default in self.states.items():
52 | setattr(self, name, default.clone())
53 |
54 | def compute(self):
55 | self.synchronize()
56 | value = self._compute().item()
57 | self.reset()
58 | return value
59 |
60 | @abstractmethod
61 | def _compute(self):
62 | pass
63 |
64 | @abstractmethod
65 | def update(self, preds: Tensor, targets: Tensor):
66 | pass
67 |
68 |
69 | class MeanAbsoluteError(Metric):
70 | def __init__(self):
71 | super().__init__()
72 | self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
73 | self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
74 |
75 | def update(self, preds: Tensor, targets: Tensor):
76 | preds = preds.detach()
77 | n = preds.shape[0]
78 | error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
79 | self.total += n
80 | self.error += error
81 |
82 | def _compute(self):
83 | return self.error / self.total
84 |
--------------------------------------------------------------------------------
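A hypothetical sketch of the `Metric` API above using `MeanAbsoluteError`; note that the metric allocates its state on `'cuda'`, so a CUDA-enabled PyTorch build and a GPU are assumed.

```python
# Hypothetical MeanAbsoluteError usage (assumes a CUDA device is available).
import torch

from se3_transformer.runtime.metrics import MeanAbsoluteError

mae = MeanAbsoluteError()
preds = torch.tensor([1.0, 2.0, 4.0], device="cuda")
targets = torch.tensor([1.0, 3.0, 2.0], device="cuda")
mae(preds, targets)   # __call__ forwards to update(), accumulating |error| and count
print(mae.compute())  # (0 + 1 + 2) / 3 = 1.0; compute() also synchronizes (under DDP) and resets
```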
/forks/RoseTTAFold2NA/SE3Transformer/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='se3-transformer',
5 | packages=find_packages(),
6 | include_package_data=True,
7 | version='1.0.0',
8 | description='PyTorch + DGL implementation of SE(3)-Transformers',
9 | author='Alexandre Milesi',
10 | author_email='alexandrem@nvidia.com',
11 | )
12 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/SE3Transformer/tests/__init__.py
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/tests/test_equivariance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | import torch
25 |
26 | from se3_transformer.model import SE3Transformer
27 | from se3_transformer.model.fiber import Fiber
28 | from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
29 |
30 | # Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
31 | TOL = 1e-3
32 | CHANNELS, NODES = 32, 512
33 |
34 |
35 | def _get_outputs(model, R):
36 | feats0 = torch.randn(NODES, CHANNELS, 1)
37 | feats1 = torch.randn(NODES, CHANNELS, 3)
38 |
39 | coords = torch.randn(NODES, 3)
40 | graph = get_random_graph(NODES)
41 | if torch.cuda.is_available():
42 | feats0 = feats0.cuda()
43 | feats1 = feats1.cuda()
44 | R = R.cuda()
45 | coords = coords.cuda()
46 | graph = graph.to('cuda')
47 | model.cuda()
48 |
49 | graph1 = assign_relative_pos(graph, coords)
50 | out1 = model(graph1, {'0': feats0, '1': feats1}, {})
51 | graph2 = assign_relative_pos(graph, coords @ R)
52 | out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
53 |
54 | return out1, out2
55 |
56 |
57 | def _get_model(**kwargs):
58 | return SE3Transformer(
59 | num_layers=4,
60 | fiber_in=Fiber.create(2, CHANNELS),
61 | fiber_hidden=Fiber.create(3, CHANNELS),
62 | fiber_out=Fiber.create(2, CHANNELS),
63 | fiber_edge=Fiber({}),
64 | num_heads=8,
65 | channels_div=2,
66 | **kwargs
67 | )
68 |
69 |
70 | def test_equivariance():
71 | model = _get_model()
72 | R = rot(*torch.rand(3))
73 | if torch.cuda.is_available():
74 | R = R.cuda()
75 | out1, out2 = _get_outputs(model, R)
76 |
77 | assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
78 | f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
79 | assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
80 | f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
81 |
82 |
83 | def test_equivariance_pooled():
84 | model = _get_model(pooling='avg', return_type=1)
85 | R = rot(*torch.rand(3))
86 | if torch.cuda.is_available():
87 | R = R.cuda()
88 | out1, out2 = _get_outputs(model, R)
89 |
90 | assert torch.allclose(out2, (out1 @ R), atol=TOL), \
91 | f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
92 |
93 |
94 | def test_invariance_pooled():
95 | model = _get_model(pooling='avg', return_type=0)
96 | R = rot(*torch.rand(3))
97 | if torch.cuda.is_available():
98 | R = R.cuda()
99 | out1, out2 = _get_outputs(model, R)
100 |
101 | assert torch.allclose(out2, out1, atol=TOL), \
102 | f'type-0 features should be invariant {get_max_diff(out1, out2)}'
103 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/SE3Transformer/tests/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a
4 | # copy of this software and associated documentation files (the "Software"),
5 | # to deal in the Software without restriction, including without limitation
6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | # and/or sell copies of the Software, and to permit persons to whom the
8 | # Software is furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | # DEALINGS IN THE SOFTWARE.
20 | #
21 | # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22 | # SPDX-License-Identifier: MIT
23 |
24 | import dgl
25 | import torch
26 |
27 |
28 | def get_random_graph(N, num_edges_factor=18):
29 | graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
30 | return graph
31 |
32 |
33 | def assign_relative_pos(graph, coords):
34 | src, dst = graph.edges()
35 | graph.edata['rel_pos'] = coords[src] - coords[dst]
36 | return graph
37 |
38 |
39 | def get_max_diff(a, b):
40 | return (a - b).abs().max().item()
41 |
42 |
43 | def rot_z(gamma):
44 | return torch.tensor([
45 | [torch.cos(gamma), -torch.sin(gamma), 0],
46 | [torch.sin(gamma), torch.cos(gamma), 0],
47 | [0, 0, 1]
48 | ], dtype=gamma.dtype)
49 |
50 |
51 | def rot_y(beta):
52 | return torch.tensor([
53 | [torch.cos(beta), 0, torch.sin(beta)],
54 | [0, 1, 0],
55 | [-torch.sin(beta), 0, torch.cos(beta)]
56 | ], dtype=beta.dtype)
57 |
58 |
59 | def rot(alpha, beta, gamma):
60 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
61 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/example/RNA.fa:
--------------------------------------------------------------------------------
1 | > RNA
2 | GAGAGAGAAGTCAACCAGAGAAACACACCAACCCATTGCACTCCGGGTTGGTGGTATATTACCTGGTACGGGGGAAACTTCGTGGTGGCCGGCCACCTGACA
3 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/example/dna_binding_protein.fa:
--------------------------------------------------------------------------------
1 | > ANTENNAPEDIA HOMEODOMAIN|Drosophila melanogaster (7227)
2 | MERKRGRQTYTRYQTLELEKEFHFNRYLTRRRRIEIAHALSLTERQIKIWFQNRRMKWKKEN
3 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/example/rna_binding_protein.fa:
--------------------------------------------------------------------------------
1 | > prot
2 | TRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKM
3 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/input_prep/make_protein_msa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # inputs
4 | in_fasta="$1"
5 | out_dir="$2"
6 | tag="$3"
7 |
8 | # resources
9 | CPU="$4"
10 | MEM="$5"
11 |
12 | # single-sequence mode
13 | SINGLE_SEQ_MODE="$6"
14 |
15 | # validate that the single-sequence mode argument is either 0 or 1
16 | re='^[01]$'
17 | if ! [[ $SINGLE_SEQ_MODE =~ $re ]]; then
18 | echo "Error: The single-sequence mode argument must be an integer ('1' meaning true and '0' otherwise)."
19 | exit 1
20 | fi
21 |
22 | # sequence databases
23 | DB_UR30="$PIPEDIR/UniRef30_2020_06/UniRef30_2020_06"
24 | DB_BFD="$PIPEDIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
25 |
26 | # setup hhblits command
27 | HHBLITS_UR30="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_UR30"
28 | HHBLITS_BFD="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_BFD"
29 |
30 | mkdir -p $out_dir/hhblits
31 | tmp_dir="$out_dir/hhblits"
32 | out_prefix="$out_dir/$tag"
33 |
34 | echo out_prefix $out_prefix
35 |
36 | # perform iterative searches against UniRef30
37 | prev_a3m="$in_fasta"
38 |
39 | # check if single-sequence mode was requested; if so, skip the remainder of the script
40 | if [ "$SINGLE_SEQ_MODE" -eq 1 ]; then
41 | cp $prev_a3m ${out_prefix}.msa0.a3m
42 | exit 0
43 | fi
44 |
45 | for e in 1e-10 1e-6 1e-3
46 | do
47 | echo "Running HHblits against UniRef30 with E-value cutoff $e"
48 | $HHBLITS_UR30 -i $prev_a3m -oa3m $tmp_dir/t000_.$e.a3m -e $e -v 0
49 | hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov75.a3m
50 | hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov50.a3m
51 | prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m"
52 | n75=`grep -c "^>" $tmp_dir/t000_.$e.id90cov75.a3m`
53 | n50=`grep -c "^>" $tmp_dir/t000_.$e.id90cov50.a3m`
54 |
55 | if ((n75>2000))
56 | then
57 | if [ ! -s ${out_prefix}.msa0.a3m ]
58 | then
59 | cp $tmp_dir/t000_.$e.id90cov75.a3m ${out_prefix}.msa0.a3m
60 | break
61 | fi
62 | elif ((n50>4000))
63 | then
64 | if [ ! -s ${out_prefix}.msa0.a3m ]
65 | then
66 | cp $tmp_dir/t000_.$e.id90cov50.a3m ${out_prefix}.msa0.a3m
67 | break
68 | fi
69 | else
70 | continue
71 | fi
72 | done
73 |
74 | # perform a search against BFD if the UniRef30 searches fail to find enough sequences
75 | if [ ! -s ${out_prefix}.msa0.a3m ]
76 | then
77 | e=1e-3
78 | echo "Running HHblits against BFD with E-value cutoff $e"
79 | $HHBLITS_BFD -i $prev_a3m -oa3m $tmp_dir/t000_.$e.bfd.a3m -e $e -v 0
80 | hhfilter -id 90 -cov 75 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov75.a3m
81 | hhfilter -id 90 -cov 50 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov50.a3m
82 | prev_a3m="$tmp_dir/t000_.$e.bfd.id90cov50.a3m"
83 | n75=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov75.a3m`
84 | n50=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov50.a3m`
85 |
86 | if ((n75>2000))
87 | then
88 | if [ ! -s ${out_prefix}.msa0.a3m ]
89 | then
90 | cp $tmp_dir/t000_.$e.bfd.id90cov75.a3m ${out_prefix}.msa0.a3m
91 | fi
92 | elif ((n50>4000))
93 | then
94 | if [ ! -s ${out_prefix}.msa0.a3m ]
95 | then
96 | cp $tmp_dir/t000_.$e.bfd.id90cov50.a3m ${out_prefix}.msa0.a3m
97 | fi
98 | fi
99 | fi
100 |
101 | if [ ! -s ${out_prefix}.msa0.a3m ]
102 | then
103 | cp $prev_a3m ${out_prefix}.msa0.a3m
104 | fi
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/input_prep/reprocess_rnac.pl:
--------------------------------------------------------------------------------
1 | #! /usr/bin/perl
2 | use strict;
3 |
4 | my $taxids = shift @ARGV;
5 | my $idfile = shift @ARGV;
6 |
7 | my %ids;
8 | open(GZIN, "gunzip -c $taxids |") or die("gunzip $taxids: $!");
9 | foreach my $line (<GZIN>) {
10 | my ($id,$taxid);
11 | ($id,$_,$_,$taxid,$_,$_) = split ' ',$line;
12 | if (not defined $ids{$id}) {
13 | $ids{$id} = []
14 | }
15 | if (not $taxid ~~ @{$ids{$id}}) {
16 | push (@{$ids{$id}}, $taxid)
17 | }
18 | }
19 | close(GZIN);
20 |
21 | system ("mv $idfile $idfile.bak");
22 | open (GZOUT, "| gzip -c > $idfile") or die("gzip $idfile: $!");
23 | open(GZIN, "gunzip -c $idfile.bak |") or die("gunzip $idfile: $!");
24 | foreach my $line (<GZIN>) {
25 | #URS0000000001 RF00177 109.4 3.3e-33 2 200 29 230 Bacterial small subunit ribosomal RNA
26 | my @fields = split /\t/,$line;
27 | my $id = $fields[0];
28 |
29 | foreach my $taxid (@{$ids{$id}}) {
30 | print GZOUT $id."_".$taxid."\t".join("\t",@fields[1..$#fields]);
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/network/AuxiliaryPredictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from chemical import NAATOKENS
4 |
5 | class DistanceNetwork(nn.Module):
6 | def __init__(self, n_feat, p_drop=0.1):
7 | super(DistanceNetwork, self).__init__()
8 | #
9 | self.proj_symm = nn.Linear(n_feat, 37*2)
10 | self.proj_asymm = nn.Linear(n_feat, 37+19)
11 |
12 | self.reset_parameter()
13 |
14 | def reset_parameter(self):
15 | # initialize linear layer for final logit prediction
16 | nn.init.zeros_(self.proj_symm.weight)
17 | nn.init.zeros_(self.proj_asymm.weight)
18 | nn.init.zeros_(self.proj_symm.bias)
19 | nn.init.zeros_(self.proj_asymm.bias)
20 |
21 | def forward(self, x):
22 | # input: pair info (B, L, L, C)
23 |
24 | # predict theta, phi (non-symmetric)
25 | logits_asymm = self.proj_asymm(x)
26 | logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
27 | logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
28 |
29 | # predict dist, omega
30 | logits_symm = self.proj_symm(x)
31 | logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
32 | logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2)
33 | logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)
34 |
35 | return logits_dist, logits_omega, logits_theta, logits_phi
36 |
37 | class MaskedTokenNetwork(nn.Module):
38 | def __init__(self, n_feat, p_drop=0.1):
39 | super(MaskedTokenNetwork, self).__init__()
40 | self.proj = nn.Linear(n_feat, NAATOKENS)
41 |
42 | self.reset_parameter()
43 |
44 | def reset_parameter(self):
45 | nn.init.zeros_(self.proj.weight)
46 | nn.init.zeros_(self.proj.bias)
47 |
48 | def forward(self, x):
49 | B, N, L = x.shape[:3]
50 | logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
51 |
52 | return logits
53 |
54 | class LDDTNetwork(nn.Module):
55 | def __init__(self, n_feat, n_bin_lddt=50):
56 | super(LDDTNetwork, self).__init__()
57 | self.proj = nn.Linear(n_feat, n_bin_lddt)
58 |
59 | self.reset_parameter()
60 |
61 | def reset_parameter(self):
62 | nn.init.zeros_(self.proj.weight)
63 | nn.init.zeros_(self.proj.bias)
64 |
65 | def forward(self, x):
66 | logits = self.proj(x) # (B, L, 50)
67 |
68 | return logits.permute(0,2,1)
69 |
70 | class PAENetwork(nn.Module):
71 | def __init__(self, n_feat, n_bin_pae=64):
72 | super(PAENetwork, self).__init__()
73 | self.proj = nn.Linear(n_feat, n_bin_pae)
74 | self.reset_parameter()
75 | def reset_parameter(self):
76 | nn.init.zeros_(self.proj.weight)
77 | nn.init.zeros_(self.proj.bias)
78 |
79 | def forward(self, pair, state):
80 | L = pair.shape[1]
81 | left = state.unsqueeze(2).expand(-1,-1,L,-1)
82 | right = state.unsqueeze(1).expand(-1,L,-1,-1)
83 |
84 | logits = self.proj( torch.cat((pair, left, right), dim=-1) ) # (B, L, L, 64)
85 |
86 | return logits.permute(0,3,1,2)
87 |
88 | class BinderNetwork(nn.Module):
89 | def __init__(self, n_hidden=64, n_bin_pae=64):
90 | super(BinderNetwork, self).__init__()
91 | #self.proj = nn.Linear(n_bin_pae, n_hidden)
92 | #self.classify = torch.nn.Linear(2*n_hidden, 1)
93 | self.classify = torch.nn.Linear(n_bin_pae, 1)
94 | self.reset_parameter()
95 |
96 | def reset_parameter(self):
97 | #nn.init.zeros_(self.proj.weight)
98 | #nn.init.zeros_(self.proj.bias)
99 | nn.init.zeros_(self.classify.weight)
100 | nn.init.zeros_(self.classify.bias)
101 |
102 | def forward(self, pae, same_chain):
103 | #logits = self.proj( pae.permute(0,2,3,1) )
104 | logits = pae.permute(0,2,3,1)
105 | #logits_intra = torch.mean( logits[same_chain==1], dim=0 )
106 | logits_inter = torch.mean( logits[same_chain==0], dim=0 ).nan_to_num() # all zeros if single chain
107 | #prob = torch.sigmoid( self.classify( torch.cat((logits_intra,logits_inter)) ) )
108 | prob = torch.sigmoid( self.classify( logits_inter ) )
109 | return prob
110 |
--------------------------------------------------------------------------------
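A hypothetical shape walkthrough for `DistanceNetwork` above: from pair features of shape (B, L, L, C) it predicts symmetrized distance and omega logits plus asymmetric theta and phi logits. Importing the module requires the `network/` directory (including its `chemical` module) on the path; the sizes below are illustrative.

```python
# Hypothetical shape check for DistanceNetwork (assumes forks/RoseTTAFold2NA/network
# is importable, since AuxiliaryPredictor imports the local `chemical` module).
import torch

from AuxiliaryPredictor import DistanceNetwork

net = DistanceNetwork(n_feat=64)
pair = torch.randn(1, 50, 50, 64)  # (B, L, L, C) pair features
dist, omega, theta, phi = net(pair)
print(dist.shape, omega.shape)     # torch.Size([1, 37, 50, 50]) each, symmetric over L x L
print(theta.shape, phi.shape)      # torch.Size([1, 37, 50, 50]) and torch.Size([1, 19, 50, 50])
```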
/forks/RoseTTAFold2NA/network/SE3_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | #from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
5 | #from equivariant_attention.modules import GConvSE3, GNormSE3
6 | #from equivariant_attention.fibers import Fiber
7 |
8 | from util_module import init_lecun_normal_param
9 | from se3_transformer.model import SE3Transformer
10 | from se3_transformer.model.fiber import Fiber
11 |
12 | class SE3TransformerWrapper(nn.Module):
13 | """SE(3) equivariant GCN with attention"""
14 | def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
15 | l0_in_features=32, l0_out_features=32,
16 | l1_in_features=3, l1_out_features=2,
17 | num_edge_features=32):
18 | super().__init__()
19 | # Build the network
20 | self.l1_in = l1_in_features
21 | self.l1_out = l1_out_features
22 | #
23 | fiber_edge = Fiber({0: num_edge_features})
24 | if l1_out_features > 0:
25 | if l1_in_features > 0:
26 | fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
27 | fiber_hidden = Fiber.create(num_degrees, num_channels)
28 | fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
29 | else:
30 | fiber_in = Fiber({0: l0_in_features})
31 | fiber_hidden = Fiber.create(num_degrees, num_channels)
32 | fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
33 | else:
34 | if l1_in_features > 0:
35 | fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
36 | fiber_hidden = Fiber.create(num_degrees, num_channels)
37 | fiber_out = Fiber({0: l0_out_features})
38 | else:
39 | fiber_in = Fiber({0: l0_in_features})
40 | fiber_hidden = Fiber.create(num_degrees, num_channels)
41 | fiber_out = Fiber({0: l0_out_features})
42 |
43 | self.se3 = SE3Transformer(num_layers=num_layers,
44 | fiber_in=fiber_in,
45 | fiber_hidden=fiber_hidden,
46 | fiber_out = fiber_out,
47 | num_heads=n_heads,
48 | channels_div=div,
49 | fiber_edge=fiber_edge,
50 | #populate_edges=False,
51 | #sum_over_edge=False,
52 | use_layer_norm=True,
53 | tensor_cores=True,
54 | low_memory=True)#,
55 | #populate_edge='log')
56 |
57 | self.reset_parameter()
58 |
59 | def reset_parameter(self):
60 |
61 | # make sure linear layers before ReLU are initialized with kaiming_normal_
62 | for n, p in self.se3.named_parameters():
63 | if "bias" in n:
64 | nn.init.zeros_(p)
65 | elif len(p.shape) == 1:
66 | continue
67 | else:
68 | if "radial_func" not in n:
69 | p = init_lecun_normal_param(p)
70 | else:
71 | if "net.6" in n:
72 | nn.init.zeros_(p)
73 | else:
74 | nn.init.kaiming_normal_(p, nonlinearity='relu')
75 |
76 | # zero-initialize the last layers
77 | self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
78 | self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
79 | nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
80 | if self.l1_out > 0:
81 | nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
82 |
83 | def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
84 | if self.l1_in > 0:
85 | node_features = {'0': type_0_features, '1': type_1_features}
86 | else:
87 | node_features = {'0': type_0_features}
88 | edge_features = {'0': edge_features}
89 | return self.se3(G, node_features, edge_features)
90 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/network/coords6d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import scipy.spatial
4 | from util import generate_Cbeta
5 |
6 | # calculate dihedral angles defined by 4 sets of points
7 | def get_dihedrals(a, b, c, d):
8 |
9 | b0 = -1.0*(b - a)
10 | b1 = c - b
11 | b2 = d - c
12 |
13 | b1 /= np.linalg.norm(b1, axis=-1)[:,None]
14 |
15 | v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1
16 | w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1
17 |
18 | x = np.sum(v*w, axis=-1)
19 | y = np.sum(np.cross(b1, v)*w, axis=-1)
20 |
21 | return np.arctan2(y, x)
22 |
23 | # calculate planar angles defined by 3 sets of points
24 | def get_angles(a, b, c):
25 |
26 | v = a - b
27 | v /= np.linalg.norm(v, axis=-1)[:,None]
28 |
29 | w = c - b
30 | w /= np.linalg.norm(w, axis=-1)[:,None]
31 |
32 | x = np.sum(v*w, axis=1)
33 |
34 | #return np.arccos(x)
35 | return np.arccos(np.clip(x, -1.0, 1.0))
36 |
37 | # get 6d coordinates from x,y,z coords of N,Ca,C atoms
38 | def get_coords6d(xyz, dmax):
39 |
40 | nres = xyz.shape[1]
41 |
42 | # three anchor atoms
43 | N = xyz[0]
44 | Ca = xyz[1]
45 | C = xyz[2]
46 |
47 | # recreate Cb given N,Ca,C
48 | Cb = generate_Cbeta(N,Ca,C)
49 |
50 | # fast neighbors search to collect all
51 | # Cb-Cb pairs within dmax
52 | kdCb = scipy.spatial.cKDTree(Cb)
53 | indices = kdCb.query_ball_tree(kdCb, dmax)
54 |
55 | # indices of contacting residues
56 | idx = np.array([[i,j] for i in range(len(indices)) for j in indices[i] if i != j]).T
57 | idx0 = idx[0]
58 | idx1 = idx[1]
59 |
60 | # Cb-Cb distance matrix
61 | dist6d = np.full((nres, nres),999.9, dtype=np.float32)
62 | dist6d[idx0,idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1)
63 |
64 | # matrix of Ca-Cb-Cb-Ca dihedrals
65 | omega6d = np.zeros((nres, nres), dtype=np.float32)
66 | omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
67 |
68 | # matrix of polar coord theta
69 | theta6d = np.zeros((nres, nres), dtype=np.float32)
70 | theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
71 |
72 | # matrix of polar coord phi
73 | phi6d = np.zeros((nres, nres), dtype=np.float32)
74 | phi6d[idx0,idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])
75 |
76 | mask = np.zeros((nres, nres), dtype=np.float32)
77 | mask[idx0, idx1] = 1.0
78 | return dist6d, omega6d, theta6d, phi6d, mask
79 |
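For reference, a minimal sketch of how get_coords6d might be called, assuming this network directory (including its util dependency providing generate_Cbeta) is importable; the coordinates below are random and purely illustrative:

    import numpy as np
    from coords6d import get_coords6d

    L = 16
    # stacked N, Ca, C backbone coordinates, shape (3, L, 3)
    xyz = (np.random.rand(3, L, 3) * 10.0).astype(np.float32)
    dist, omega, theta, phi, mask = get_coords6d(xyz, dmax=20.0)
    # dist holds Cb-Cb distances (999.9 beyond dmax); omega/theta/phi are the
    # inter-residue dihedral/planar angles; mask flags residue pairs within dmax
    print(dist.shape, int(mask.sum()))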
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/network/ffindex.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
3 |
4 | '''
5 | Created on Apr 30, 2014
6 |
7 | @author: meiermark
8 | '''
9 |
10 |
11 | import os
12 | import mmap
13 | from collections import namedtuple
14 |
15 | FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
16 |
17 |
18 | def read_index(ffindex_filename):
19 | if os.path.getsize(ffindex_filename) == 0:
20 | return None
21 |
22 | entries = []
23 |
24 | fh = open(ffindex_filename)
25 | for line in fh:
26 | tokens = line.split("\t")
27 | entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
28 | fh.close()
29 |
30 | return entries
31 |
32 |
33 | def read_data(ffdata_filename):
34 | if os.path.getsize(ffdata_filename) == 0:
35 | return None
36 |
37 | fh = open(ffdata_filename, "rb")
38 | data = mmap.mmap(fh.fileno(), 0, prot=mmap.PROT_READ)
39 | fh.close()
40 |
41 | return data
42 |
43 |
44 | def get_entry_by_name(name, index):
45 | #TODO: bsearch
46 | if index is None:
47 | return None
48 | for entry in index:
49 | if(name == entry.name):
50 | return entry
51 | return None
52 |
53 |
54 | def read_entry_lines(entry, data):
55 | if data is None:
56 | return []
57 | lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n")
58 | return lines
59 |
60 |
61 | def read_entry_data(entry, data):
62 | return data[entry.offset:entry.offset + entry.length - 1]
63 |
64 |
65 | def write_entry(entries, data_fh, entry_name, offset, data):
66 | data_fh.write(data[:-1])
67 | data_fh.write(bytearray(1))
68 |
69 | entry = FFindexEntry(entry_name, offset, len(data))
70 | entries.append(entry)
71 |
72 | return offset + len(data)
73 |
74 |
75 | def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
76 | with open(file_name, "rb") as fh:
77 | data = bytearray(fh.read())
78 | return write_entry(entries, data_fh, entry_name, offset, data)
79 |
80 |
81 | def finish_db(entries, ffindex_filename, data_fh):
82 | data_fh.close()
83 | write_entries_to_db(entries, ffindex_filename)
84 |
85 |
86 | def write_entries_to_db(entries, ffindex_filename):
87 | entries.sort(key=lambda x: x.name)  # sort in place; a bare sorted() call would discard the result
88 | index_fh = open(ffindex_filename, "w")
89 |
90 | for entry in entries:
91 | index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
92 |
93 | index_fh.close()
94 |
95 |
96 | def write_entry_to_file(entry, data, file):
97 | lines = read_entry_lines(entry, data)
98 |
99 | fh = open(file, "w")
100 | for line in lines:
101 | fh.write(line+"\n")
102 | fh.close()
103 |
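For reference, a minimal sketch of reading one record out of an ffindex database with the helpers above; the database paths and entry name are hypothetical placeholders:

    from ffindex import read_index, read_data, get_entry_by_name, read_entry_lines

    index = read_index("pdb100_2021Mar03_pdb.ffindex")  # list of FFindexEntry(name, offset, length)
    data = read_data("pdb100_2021Mar03_pdb.ffdata")     # mmap over the packed data file
    entry = get_entry_by_name("1abc_A", index)          # hypothetical entry name
    if entry is not None:
        lines = read_entry_lines(entry, data)           # decoded text lines for that record
        print(len(lines))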
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/network/models.json:
--------------------------------------------------------------------------------
1 | {
2 | "full_bigSE3":
3 | {
4 | "description": "deep architecture w/ big SE(3)-Transformer on fully connected graph. Trained on biounit",
5 | "model_param":{
6 | "n_extra_block" : 4,
7 | "n_main_block" : 32,
8 | "n_ref_block" : 4,
9 | "d_msa" : 256 ,
10 | "d_pair" : 128,
11 | "d_templ" : 64,
12 | "n_head_msa" : 8,
13 | "n_head_pair" : 4,
14 | "n_head_templ" : 4,
15 | "d_hidden" : 32,
16 | "d_hidden_templ" : 64,
17 | "p_drop" : 0.0,
18 | "lj_lin" : 0.75,
19 | "SE3_param": {
20 | "num_layers" : 1,
21 | "num_channels" : 32,
22 | "num_degrees" : 2,
23 | "l0_in_features": 64,
24 | "l0_out_features": 64,
25 | "l1_in_features": 3,
26 | "l1_out_features": 2,
27 | "num_edge_features": 64,
28 | "div": 4,
29 | "n_heads": 4
30 | }
31 | },
32 | "weight_fn": ["full_bigSE3_model1.pt", "full_bigSE3_model2.pt", "full_bigSE3_model3.pt"]
33 | },
34 | "full_smallSE3":
35 | {
36 | "description": "deep architecture w/ small SE(3)-Transformer on fully connected graph. Trained on biounit",
37 | "model_param":{
38 | "n_extra_block" : 4,
39 | "n_main_block" : 32,
40 | "n_ref_block" : 4,
41 | "d_msa" : 256 ,
42 | "d_pair" : 128,
43 | "d_templ" : 64,
44 | "n_head_msa" : 8,
45 | "n_head_pair" : 4,
46 | "n_head_templ" : 4,
47 | "d_hidden" : 32,
48 | "d_hidden_templ" : 32,
49 | "p_drop" : 0.0,
50 | "SE3_param_full": {
51 | "num_layers" : 1,
52 | "num_channels" : 32,
53 | "num_degrees" : 2,
54 | "l0_in_features": 8,
55 | "l0_out_features": 8,
56 | "l1_in_features": 3,
57 | "l1_out_features": 2,
58 | "num_edge_features": 32,
59 | "div": 4,
60 | "n_heads": 4
61 | },
62 | "SE3_param_topk": {
63 | "num_layers" : 1,
64 | "num_channels" : 32,
65 | "num_degrees" : 2,
66 | "l0_in_features": 64,
67 | "l0_out_features": 64,
68 | "l1_in_features": 3,
69 | "l1_out_features": 2,
70 | "num_edge_features": 64,
71 | "div": 4,
72 | "n_heads": 4
73 | }
74 | },
75 | "weight_fn": ["full_smallSE3_model1.pt", "full_smallSE3_model2.pt"]
76 | }
77 | }
78 |
79 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/network/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 |
5 | # pre-activation bottleneck resblock
6 | class ResBlock2D_bottleneck(nn.Module):
7 | def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
8 | super(ResBlock2D_bottleneck, self).__init__()
9 | padding = self._get_same_padding(kernel, dilation)
10 |
11 | n_b = n_c // 2 # bottleneck channel
12 |
13 | layer_s = list()
14 | # pre-activation
15 | layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
16 | layer_s.append(nn.ELU(inplace=True))
17 | # project down to n_b
18 | layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False))
19 | layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
20 | layer_s.append(nn.ELU(inplace=True))
21 | # convolution
22 | layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False))
23 | layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
24 | layer_s.append(nn.ELU(inplace=True))
25 | # dropout
26 | layer_s.append(nn.Dropout(p_drop))
27 | # project up
28 | layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False))
29 |
30 | # make the final layer zero-initialized (done in reset_parameter below)
31 | #nn.init.zeros_(layer_s[-1].weight)
32 |
33 | self.layer = nn.Sequential(*layer_s)
34 |
35 | self.reset_parameter()
36 |
37 | def reset_parameter(self):
38 | # zero-initialize final layer right before residual connection
39 | nn.init.zeros_(self.layer[-1].weight)
40 |
41 | def _get_same_padding(self, kernel, dilation):
42 | return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
43 |
44 | def forward(self, x):
45 | out = self.layer(x)
46 | return x + out
47 |
48 | class ResidualNetwork(nn.Module):
49 | def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
50 | dilation=[1,2,4,8], p_drop=0.15):
51 | super(ResidualNetwork, self).__init__()
52 |
53 |
54 | layer_s = list()
55 | # project to n_feat_block
56 | if n_feat_in != n_feat_block:
57 | layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
58 |
59 | # add resblocks
60 | for i_block in range(n_block):
61 | d = dilation[i_block%len(dilation)]
62 | res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
63 | layer_s.append(res_block)
64 |
65 | if n_feat_out != n_feat_block:
66 | # project to n_feat_out
67 | layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))
68 |
69 | self.layer = nn.Sequential(*layer_s)
70 |
71 | def forward(self, x):
72 | return self.layer(x)
73 |
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffdata:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffdata
--------------------------------------------------------------------------------
/forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffindex:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/forks/RoseTTAFold2NA/pdb100_2021Mar03/pdb100_2021Mar03_pdb.ffindex
--------------------------------------------------------------------------------
/img/Nucleic_Acid_Diffusion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/img/Nucleic_Acid_Diffusion.gif
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/logs/.gitkeep
--------------------------------------------------------------------------------
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/notebooks/.gitkeep
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | addopts = [
3 | "--color=yes",
4 | "--durations=0",
5 | "--strict-markers",
6 | "--doctest-modules",
7 | ]
8 | filterwarnings = [
9 | "ignore::DeprecationWarning",
10 | "ignore::UserWarning",
11 | ]
12 | log_cli = "True"
13 | markers = [
14 | "slow: slow tests",
15 | ]
16 | minversion = "6.0"
17 | testpaths = "tests/"
18 |
19 | [tool.coverage.report]
20 | exclude_lines = [
21 | "pragma: nocover",
22 | "raise NotImplementedError",
23 | "raise NotImplementedError()",
24 | "if __name__ == .__main__.:",
25 | ]
26 |
--------------------------------------------------------------------------------
/scripts/schedule.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Schedule execution of many runs
3 | # Run from root folder with: bash scripts/schedule.sh
4 |
5 | python src/train.py trainer.max_epochs=5 logger=csv
6 |
7 | python src/train.py trainer.max_epochs=10 logger=csv
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | setup(
6 | name="MMDiff",
7 | version="0.0.1",
8 | description="Official PyTorch implementation of 'Joint Sequence-Structure Generation of Nucleic Acid and Protein Complexes with SE(3)-Discrete Diffusion'.",
9 | author="Alex Morehead",
10 | author_email="alex.morehead@gmail.com",
11 | url="https://github.com/Profluent-Internships/MMDiff",
12 | install_requires=["lightning", "hydra-core"],
13 | packages=find_packages(),
14 | # use this to customize global commands available in the terminal after installing the package
15 | entry_points={
16 | "console_scripts": [
17 | "train_command = src.train:main",
18 | "sample_command = src.sample:main",
19 | ]
20 | },
21 | )
22 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | from beartype import beartype
4 | from beartype.typing import Any
5 | from omegaconf import OmegaConf
6 |
7 |
8 | @beartype
9 | def resolve_omegaconf_variable(variable_path: str) -> Any:
10 | # split the string into parts using the dot separator
11 | parts = variable_path.rsplit(".", 1)
12 |
13 | # get the module name from the first part of the path
14 | module_name = parts[0]
15 |
16 | # dynamically import the module using the module name
17 | module = importlib.import_module(module_name)
18 |
19 | # use the imported module to get the requested attribute value
20 | attribute = getattr(module, parts[1])
21 |
22 | return attribute
23 |
24 |
25 | def register_custom_omegaconf_resolvers():
26 | OmegaConf.register_new_resolver("add", lambda x, y: x + y)
27 | OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
28 | OmegaConf.register_new_resolver(
29 | "resolve_variable", lambda variable_path: resolve_omegaconf_variable(variable_path)
30 | )
31 |
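For reference, a minimal sketch of the registered resolvers in use, with made-up config keys; once register_custom_omegaconf_resolvers() has been called, configs can do small arithmetic and pull attributes out of arbitrary modules at resolution time:

    from omegaconf import OmegaConf
    from src import register_custom_omegaconf_resolvers

    register_custom_omegaconf_resolvers()
    cfg = OmegaConf.create(
        {
            "n_train": 8,
            "n_val": 2,
            "n_total": "${add:${n_train},${n_val}}",
            "pi": "${resolve_variable:math.pi}",
        }
    )
    print(cfg.n_total, cfg.pi)  # 10 3.141592653589793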
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/src/data/components/__init__.py
--------------------------------------------------------------------------------
/src/data/components/pdb/complex_constants.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code curated for (https://github.com/Profluent-Internships/MMDiff):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 |
5 | NUM_PROT_TORSIONS = 1
6 | NUM_NA_TORSIONS = 8
7 | NUM_PROT_NA_TORSIONS = 10
8 | PYRIMIDINE_RESIDUE_TOKENS = [22, 24, 26, 28]
9 |
10 | # This is the standard residue order when coding protein-nucleic acid residue types as a number (for PDB visualization purposes).
11 | restypes = [
12 | "A",
13 | "C",
14 | "D",
15 | "E",
16 | "F",
17 | "G",
18 | "H",
19 | "I",
20 | "K",
21 | "L",
22 | "M",
23 | "N",
24 | "P",
25 | "Q",
26 | "R",
27 | "S",
28 | "T",
29 | "V",
30 | "W",
31 | "Y",
32 | "X",
33 | "da",
34 | "dc",
35 | "dg",
36 | "dt",
37 | "a",
38 | "c",
39 | "g",
40 | "u",
41 | "x",
42 | "-",
43 | "_",
44 | "1",
45 | "2",
46 | "3",
47 | "4",
48 | "5",
49 | ]
50 | restype_order = {restype: i for i, restype in enumerate(restypes)}
51 | restype_num = len(restypes) # := 37.
52 |
53 | restype_1to3 = {
54 | "A": "ALA",
55 | "C": "CYS",
56 | "D": "ASP",
57 | "E": "GLU",
58 | "F": "PHE",
59 | "G": "GLY",
60 | "H": "HIS",
61 | "I": "ILE",
62 | "K": "LYS",
63 | "L": "LEU",
64 | "M": "MET",
65 | "N": "ASN",
66 | "P": "PRO",
67 | "Q": "GLN",
68 | "R": "ARG",
69 | "S": "SER",
70 | "T": "THR",
71 | "V": "VAL",
72 | "W": "TRP",
73 | "Y": "TYR",
74 | "X": "UNK",
75 | "da": "DA",
76 | "dc": "DC",
77 | "dg": "DG",
78 | "dt": "DT",
79 | "a": "A",
80 | "c": "C",
81 | "g": "G",
82 | "u": "U",
83 | "x": "unk",
84 | "-": "GAP",
85 | "_": "PAD",
86 | "1": "SP1",
87 | "2": "SP2",
88 | "3": "SP3",
89 | "4": "SP4",
90 | "5": "SP5",
91 | }
92 | restype_3to1 = {v: k for k, v in restype_1to3.items()}
93 | restypes_3 = list(restype_3to1.keys())
94 | restypes_1 = list(restype_1to3.keys())
95 |
96 | protein_restypes = [
97 | "A",
98 | "R",
99 | "N",
100 | "D",
101 | "C",
102 | "Q",
103 | "E",
104 | "G",
105 | "H",
106 | "I",
107 | "L",
108 | "K",
109 | "M",
110 | "F",
111 | "P",
112 | "S",
113 | "T",
114 | "W",
115 | "Y",
116 | "V",
117 | ]
118 | nucleic_restypes = ["DA", "DC", "DG", "DT", "A", "C", "G", "U"]
119 | special_restypes = ["-", "_", "1", "2", "3", "4", "5"]
120 | unknown_protein_restype = "X"
121 | unknown_nucleic_restype = "x"
122 | gap_token = restypes.index("-")
123 | pad_token = restypes.index("_")
124 | unknown_protein_token = restypes.index(unknown_protein_restype) # := 20
125 | unknown_nucleic_token = restypes.index(unknown_nucleic_restype) # := 29
126 |
127 | protein_restype_num = len(protein_restypes + [unknown_protein_restype]) # := 21
128 | na_restype_num = len(nucleic_restypes + [unknown_nucleic_restype]) # := 9
129 | protein_na_restype_num = len(
130 | protein_restypes + [unknown_protein_restype] + nucleic_restypes + [unknown_nucleic_restype]
131 | ) # := 30
132 |
133 | default_protein_restype = restypes.index("A") # := 0
134 | default_na_restype = restypes.index("da") # := 21
135 |
136 | alternative_restypes_map = {
137 | # Protein
138 | "MSE": "MET",
139 | }
140 | allowable_restypes = set(restypes + list(alternative_restypes_map.keys()))
141 |
--------------------------------------------------------------------------------
/src/data/components/pdb/errors.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code adapted from se3_diffusion (https://github.com/jasonkyuyim/se3_diffusion):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 | """Error class for handled errors."""
5 |
6 |
7 | class DataError(Exception):
8 | """Data exception."""
9 |
10 | pass
11 |
12 |
13 | class FileExistsError(DataError):
14 | """Raised when file already exists."""
15 |
16 | pass
17 |
18 |
19 | class MmcifParsingError(DataError):
20 | """Raised when mmcif parsing fails."""
21 |
22 | pass
23 |
24 |
25 | class ResolutionError(DataError):
26 | """Raised when resolution isn't acceptable."""
27 |
28 | pass
29 |
30 |
31 | class LengthError(DataError):
32 | """Raised when length isn't acceptable."""
33 |
34 | pass
35 |
--------------------------------------------------------------------------------
/src/data/components/pdb/join_pdb_metadata.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code curated for (https://github.com/Profluent-Internships/MMDiff):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 |
5 | import hydra
6 | import pandas as pd
7 | from omegaconf import DictConfig
8 |
9 |
10 | @hydra.main(
11 | version_base="1.3", config_path="../../../../configs/paths", config_name="pdb_metadata.yaml"
12 | )
13 | def main(cfg: DictConfig):
14 | na_df = pd.read_csv(cfg.na_metadata_csv_path)
15 | protein_df = pd.read_csv(cfg.protein_metadata_csv_path)
16 | # impute missing columns for the nucleic acid DataFrame
17 | na_df["oligomeric_count"] = na_df["num_chains"].astype(str)
18 | na_df["oligomeric_detail"] = na_df["num_chains"].apply(
19 | lambda x: "heteromeric" if x > 1 else "monomeric"
20 | )
21 | # note: we can reasonably assume the following two column values due to our initial filtering for nucleic acid molecules
22 | na_df["resolution"] = 0.0
23 | na_df["structure_method"] = "x-ray diffraction"
24 | output_df = pd.concat([na_df, protein_df]).drop_duplicates(
25 | subset=["pdb_name"], keep="first", ignore_index=True
26 | )
27 | output_df.to_csv(cfg.metadata_output_csv_path, index=False)
28 |
29 |
30 | if __name__ == "__main__":
31 | main()
32 |
--------------------------------------------------------------------------------
/src/data/components/pdb/relax/relax_utils.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code adapted from openfold (https://github.com/aqlaboratory/openfold):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 |
5 | # Copyright 2021 AlQuraishi Laboratory
6 | # Copyright 2021 DeepMind Technologies Limited
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | """Utils for minimization."""
20 | import io
21 |
22 | import numpy as np
23 | from Bio import PDB
24 |
25 | try:
26 | # openmm >= 7.6
27 | from openmm import app as openmm_app
28 | from openmm.app.internal.pdbstructure import PdbStructure
29 | except ImportError:
30 | # openmm < 7.6 (requires DeepMind patch)
31 | from simtk.openmm import app as openmm_app
32 | from simtk.openmm.app.internal.pdbstructure import PdbStructure
33 |
34 | from src.data.components.pdb import protein_constants
35 |
36 |
37 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
38 | pdb_file = io.StringIO(pdb_str)
39 | structure = PdbStructure(pdb_file)
40 | topology = openmm_app.PDBFile(structure).getTopology()
41 | with io.StringIO() as f:
42 | openmm_app.PDBFile.writeFile(topology, pos, f)
43 | return f.getvalue()
44 |
45 |
46 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
47 | """Overwrites the B-factors in pdb_str with contents of bfactors array.
48 |
49 | Args:
50 | pdb_str: An input PDB string.
51 | bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
52 | B-factors are per residue; i.e. that the nonzero entries are identical in
53 | [0, i, :].
54 |
55 | Returns:
56 | A new PDB string with the B-factors replaced.
57 | """
58 | if bfactors.shape[-1] != protein_constants.atom_type_num:
59 | raise ValueError(f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}.")
60 |
61 | parser = PDB.PDBParser(QUIET=True)
62 | handle = io.StringIO(pdb_str)
63 | structure = parser.get_structure("", handle)
64 |
65 | curr_resid = ("", "", "")
66 | idx = -1
67 | for atom in structure.get_atoms():
68 | atom_resid = atom.parent.get_id()
69 | if atom_resid != curr_resid:
70 | idx += 1
71 | if idx >= bfactors.shape[0]:
72 | raise ValueError(
73 | "Index into bfactors exceeds number of residues. "
74 | "B-factors shape: {shape}, idx: {idx}."
75 | )
76 | curr_resid = atom_resid
77 | atom.bfactor = bfactors[idx, protein_constants.atom_order["CA"]]
78 |
79 | new_pdb = io.StringIO()
80 | pdb_io = PDB.PDBIO()
81 | pdb_io.set_structure(structure)
82 | pdb_io.save(new_pdb)
83 | return new_pdb.getvalue()
84 |
85 |
86 | def assert_equal_nonterminal_atom_types(atom_mask: np.ndarray, ref_atom_mask: np.ndarray):
87 | """Checks that pre- and post-minimized proteins have same atom set."""
88 | # Ignore any terminal OXT atoms which may have been added by minimization.
89 | oxt = protein_constants.atom_order["OXT"]
90 | no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)  # np.bool was removed in newer NumPy releases
91 | no_oxt_mask[..., oxt] = False
92 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask])
93 |
--------------------------------------------------------------------------------
/src/data/components/pdb/vocabulary.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code curated for MMDiff (https://github.com/Profluent-Internships/MMDiff):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 |
5 | # This is the standard residue order when coding residues type as a number.
6 | restypes = [
7 | "A",
8 | "C",
9 | "D",
10 | "E",
11 | "F",
12 | "G",
13 | "H",
14 | "I",
15 | "K",
16 | "L",
17 | "M",
18 | "N",
19 | "P",
20 | "Q",
21 | "R",
22 | "S",
23 | "T",
24 | "V",
25 | "W",
26 | "Y",
27 | "X",
28 | "a",
29 | "c",
30 | "g",
31 | "t",
32 | "u",
33 | "x",
34 | "-",
35 | "_",
36 | "1",
37 | "2",
38 | "3",
39 | "4",
40 | "5",
41 | ]
42 | restype_order = {restype: i for i, restype in enumerate(restypes)}
43 |
44 | restype_1to3 = {
45 | "A": "ALA",
46 | "C": "CYS",
47 | "D": "ASP",
48 | "E": "GLU",
49 | "F": "PHE",
50 | "G": "GLY",
51 | "H": "HIS",
52 | "I": "ILE",
53 | "K": "LYS",
54 | "L": "LEU",
55 | "M": "MET",
56 | "N": "ASN",
57 | "P": "PRO",
58 | "Q": "GLN",
59 | "R": "ARG",
60 | "S": "SER",
61 | "T": "THR",
62 | "V": "VAL",
63 | "W": "TRP",
64 | "Y": "TYR",
65 | "X": "UNK",
66 | "a": "A",
67 | "c": "C",
68 | "g": "G",
69 | "t": "T",
70 | "u": "U",
71 | "x": "unk",
72 | "-": "GAP",
73 | "_": "PAD",
74 | "1": "SP1",
75 | "2": "SP2",
76 | "3": "SP3",
77 | "4": "SP4",
78 | "5": "SP5",
79 | }
80 | restype_3to1 = {v: k for k, v in restype_1to3.items()}
81 | restype_3to1.update(
82 | {
83 | "DA": "a",
84 | "DC": "c",
85 | "DG": "g",
86 | "DT": "t",
87 | }
88 | )
89 |
90 | protein_restypes = [
91 | "A",
92 | "C",
93 | "D",
94 | "E",
95 | "F",
96 | "G",
97 | "H",
98 | "I",
99 | "K",
100 | "L",
101 | "M",
102 | "N",
103 | "P",
104 | "Q",
105 | "R",
106 | "S",
107 | "T",
108 | "V",
109 | "W",
110 | "Y",
111 | ]
112 | protein_restype_order = {restype: i for i, restype in enumerate(protein_restypes)}
113 |
114 | protein_restype_1to3 = {
115 | "A": "ALA",
116 | "C": "CYS",
117 | "D": "ASP",
118 | "E": "GLU",
119 | "F": "PHE",
120 | "G": "GLY",
121 | "H": "HIS",
122 | "I": "ILE",
123 | "K": "LYS",
124 | "L": "LEU",
125 | "M": "MET",
126 | "N": "ASN",
127 | "P": "PRO",
128 | "Q": "GLN",
129 | "R": "ARG",
130 | "S": "SER",
131 | "T": "THR",
132 | "V": "VAL",
133 | "W": "TRP",
134 | "Y": "TYR",
135 | }
136 | protein_restype_3to1 = {v: k for k, v in protein_restype_1to3.items()}
137 |
138 | protein_restypes_with_x = protein_restypes + ["X"]
139 | protein_restype_order_with_x = {restype: i for i, restype in enumerate(protein_restypes_with_x)}
140 | nucleic_restypes = ["a", "c", "g", "t", "u"]
141 | special_restypes = ["-", "_", "1", "2", "3", "4", "5"]
142 | unknown_protein_restype = "X"
143 | unknown_nucleic_restype = "x"
144 | gap_token = restypes.index("-")
145 | pad_token = restypes.index("_")
146 |
147 | restype_num = len(restypes) # := 34.
148 | protein_restype_num = len(protein_restypes) # := 20
149 | protein_restype_num_with_x = len(protein_restypes_with_x) # := 21
150 |
151 | protein_resnames = [restype_1to3[r] for r in protein_restypes]
152 | protein_resname_to_idx = {resname: i for i, resname in enumerate(protein_resnames)}
153 |
154 | alternative_restypes_map = {
155 | # Protein
156 | "MSE": "MET",
157 | }
158 | allowable_restypes = set(restypes + list(alternative_restypes_map.keys()))
159 |
160 |
161 | def is_protein_sequence(sequence: str) -> bool:
162 | """Check if a sequence is a protein sequence."""
163 |
164 | return all([s in protein_restypes for s in sequence])
165 |
166 |
167 | def is_nucleic_sequence(sequence: str) -> bool:
168 | """Check if a sequence is a nucleic acid sequence."""
169 |
170 | return all([s in nucleic_restypes for s in sequence])
171 |
--------------------------------------------------------------------------------
/src/models/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/src/models/components/__init__.py
--------------------------------------------------------------------------------
/src/models/components/pdb/embedders.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------------------------------------------------------------------
2 | # Following code curated for MMDiff (https://github.com/Profluent-Internships/MMDiff):
3 | # -------------------------------------------------------------------------------------------------------------------------------------
4 |
5 | import torch
6 | import torch.nn as nn
7 | from beartype import beartype
8 | from beartype.typing import Optional
9 |
10 | from src.data.components.pdb.data_transforms import make_one_hot
11 | from src.models.components.pdb.framediff import Linear
12 |
13 |
14 | class RelPosEncoder(nn.Module):
15 | def __init__(
16 | self,
17 | embedding_size: int,
18 | max_relative_idx: int,
19 | max_relative_chain: int = 0,
20 | use_chain_relative: bool = False,
21 | ):
22 | super().__init__()
23 | self.max_relative_idx = max_relative_idx
24 | self.max_relative_chain = max_relative_chain
25 | self.use_chain_relative = use_chain_relative
26 | self.num_bins = 2 * max_relative_idx + 2
27 | if max_relative_chain > 0:
28 | self.num_bins += 2 * max_relative_chain + 2
29 | if use_chain_relative:
30 | self.num_bins += 1
31 |
32 | self.linear_relpos = Linear(self.num_bins, embedding_size)
33 |
34 | @beartype
35 | def forward(
36 | self,
37 | residue_index: torch.Tensor,
38 | asym_id: Optional[torch.Tensor] = None,
39 | sym_id: Optional[torch.Tensor] = None,
40 | entity_id: Optional[torch.Tensor] = None,
41 | ) -> torch.Tensor:
42 | d = residue_index[..., None] - residue_index[..., None, :]
43 |
44 | if asym_id is None:
45 | # compute relative position encoding according to AlphaFold's `relpos` algorithm
46 | boundaries = torch.arange(
47 | start=-self.max_relative_idx, end=self.max_relative_idx + 1, device=d.device
48 | )
49 | reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
50 | d = d[..., None] - reshaped_bins
51 | d = torch.abs(d)
52 | d = torch.argmin(d, dim=-1)
53 | d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
54 | d = d.to(residue_index.dtype)
55 | rel_feat = d
56 | else:
57 | # compute relative position encoding according to AlphaFold-Multimer's `relpos` algorithm
58 | rel_feats = []
59 | asym_id_same = torch.eq(asym_id[..., None], asym_id[..., None, :])
60 | offset = residue_index[..., None] - residue_index[..., None, :]
61 |
62 | clipped_offset = torch.clamp(
63 | input=offset + self.max_relative_idx, min=0, max=(2 * self.max_relative_idx)
64 | )
65 |
66 | final_offset = torch.where(
67 | condition=asym_id_same,
68 | input=clipped_offset,
69 | other=((2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset)),
70 | )
71 |
72 | rel_pos = make_one_hot(x=final_offset, num_classes=(2 * self.max_relative_idx + 2))
73 | rel_feats.append(rel_pos)
74 |
75 | if self.use_chain_relative:
76 | entity_id_same = torch.eq(entity_id[..., None], entity_id[..., None, :])
77 | rel_feats.append(entity_id_same.type(rel_pos.dtype)[..., None])
78 |
79 | if self.max_relative_chain > 0:
80 | rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
81 | max_rel_chain = self.max_relative_chain
82 |
83 | clipped_rel_chain = torch.clamp(
84 | input=rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain
85 | )
86 |
87 | if not self.use_chain_relative:
88 | # ensure `entity_id_same` is constructed for `rel_chain`
89 | entity_id_same = torch.eq(entity_id[..., None], entity_id[..., None, :])
90 |
91 | final_rel_chain = torch.where(
92 | condition=entity_id_same,
93 | input=clipped_rel_chain,
94 | other=(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
95 | )
96 | rel_chain = make_one_hot(
97 | x=final_rel_chain, num_classes=2 * self.max_relative_chain + 2
98 | )
99 | rel_feats.append(rel_chain)
100 |
101 | rel_feat = torch.cat(rel_feats, dim=-1)
102 |
103 | return self.linear_relpos(rel_feat)
104 |
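For reference, a minimal sketch of the multi-chain (asym_id) path of RelPosEncoder, assuming it is run inside this repository so the framediff Linear and make_one_hot imports resolve, and that make_one_hot expects integer indices; shapes are illustrative only:

    import torch
    from src.models.components.pdb.embedders import RelPosEncoder

    encoder = RelPosEncoder(embedding_size=32, max_relative_idx=16)
    residue_index = torch.arange(10).unsqueeze(0)             # (batch=1, length=10), int64
    asym_id = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])  # two chains of five residues each
    pair_feats = encoder(residue_index=residue_index, asym_id=asym_id)
    print(pair_feats.shape)  # torch.Size([1, 10, 10, 32])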
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
2 | from src.utils.logging_utils import log_hyperparameters
3 | from src.utils.pylogger import get_pylogger
4 | from src.utils.rich_utils import enforce_tags, print_config_tree
5 | from src.utils.utils import extras, get_metric_value, task_wrapper
6 |
--------------------------------------------------------------------------------
/src/utils/instantiators.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from beartype.typing import List
3 | from lightning import Callback
4 | from lightning.pytorch.loggers import Logger
5 | from omegaconf import DictConfig
6 |
7 | from src.utils import pylogger
8 |
9 | log = pylogger.get_pylogger(__name__)
10 |
11 |
12 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
13 | """Instantiates callbacks from config."""
14 |
15 | callbacks: List[Callback] = []
16 |
17 | if not callbacks_cfg:
18 | log.warning("No callback configs found! Skipping...")
19 | return callbacks
20 |
21 | if not isinstance(callbacks_cfg, DictConfig):
22 | raise TypeError("Callbacks config must be a DictConfig!")
23 |
24 | for _, cb_conf in callbacks_cfg.items():
25 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
26 | log.info(f"Instantiating callback <{cb_conf._target_}>")
27 | callbacks.append(hydra.utils.instantiate(cb_conf))
28 |
29 | return callbacks
30 |
31 |
32 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
33 | """Instantiates loggers from config."""
34 |
35 | logger: List[Logger] = []
36 |
37 | if not logger_cfg:
38 | log.warning("No logger configs found! Skipping...")
39 | return logger
40 |
41 | if not isinstance(logger_cfg, DictConfig):
42 | raise TypeError("Logger config must be a DictConfig!")
43 |
44 | for _, lg_conf in logger_cfg.items():
45 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
46 | log.info(f"Instantiating logger <{lg_conf._target_}>")
47 | logger.append(hydra.utils.instantiate(lg_conf))
48 |
49 | return logger
50 |
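For reference, a minimal sketch of instantiate_callbacks with a single hypothetical callback entry; any sub-config carrying a _target_ key is instantiated via Hydra:

    from omegaconf import OmegaConf
    from src.utils.instantiators import instantiate_callbacks

    callbacks_cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "dirpath": "checkpoints/",
                "monitor": "val/loss",
            }
        }
    )
    callbacks = instantiate_callbacks(callbacks_cfg)  # -> [ModelCheckpoint(...)]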
--------------------------------------------------------------------------------
/src/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | from lightning.pytorch.utilities import rank_zero_only
2 |
3 | from src.utils import pylogger
4 |
5 | log = pylogger.get_pylogger(__name__)
6 |
7 |
8 | @rank_zero_only
9 | def log_hyperparameters(object_dict: dict) -> None:
10 | """Controls which config parts are saved by lightning loggers.
11 |
12 | Additionally saves:
13 | - Number of model parameters
14 | """
15 |
16 | hparams = {}
17 |
18 | cfg = object_dict["cfg"]
19 | model = object_dict["model"]
20 | trainer = object_dict["trainer"]
21 |
22 | if not trainer.logger:
23 | log.warning("Logger not found! Skipping hyperparameter logging...")
24 | return
25 |
26 | hparams["model"] = cfg["model"]
27 |
28 | # save number of model parameters
29 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
30 | hparams["model/params/trainable"] = sum(
31 | p.numel() for p in model.parameters() if p.requires_grad
32 | )
33 | hparams["model/params/non_trainable"] = sum(
34 | p.numel() for p in model.parameters() if not p.requires_grad
35 | )
36 |
37 | hparams["data"] = cfg["data"]
38 | hparams["trainer"] = cfg["trainer"]
39 |
40 | hparams["callbacks"] = cfg.get("callbacks")
41 | hparams["extras"] = cfg.get("extras")
42 |
43 | hparams["task_name"] = cfg.get("task_name")
44 | hparams["tags"] = cfg.get("tags")
45 | hparams["ckpt_path"] = cfg.get("ckpt_path")
46 | hparams["seed"] = cfg.get("seed")
47 |
48 | # send hparams to all loggers
49 | for logger in trainer.loggers:
50 | logger.log_hyperparams(hparams)
51 |
--------------------------------------------------------------------------------
/src/utils/pylogger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from lightning.pytorch.utilities import rank_zero_only
4 |
5 |
6 | def get_pylogger(name=__name__) -> logging.Logger:
7 | """Initializes multi-GPU-friendly python command line logger."""
8 |
9 | logger = logging.getLogger(name)
10 |
11 | # this ensures all logging levels get marked with the rank zero decorator
12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup
13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
14 | for level in logging_levels:
15 | setattr(logger, level, rank_zero_only(getattr(logger, level)))
16 |
17 | return logger
18 |
--------------------------------------------------------------------------------
/src/utils/rich_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import rich
4 | import rich.syntax
5 | import rich.tree
6 | from beartype.typing import Sequence
7 | from hydra.core.hydra_config import HydraConfig
8 | from lightning.pytorch.utilities import rank_zero_only
9 | from omegaconf import DictConfig, OmegaConf, open_dict
10 | from rich.prompt import Prompt
11 |
12 | from src.utils import pylogger
13 |
14 | log = pylogger.get_pylogger(__name__)
15 |
16 |
17 | @rank_zero_only
18 | def print_config_tree(
19 | cfg: DictConfig,
20 | print_order: Sequence[str] = (
21 | "data",
22 | "model",
23 | "callbacks",
24 | "logger",
25 | "trainer",
26 | "paths",
27 | "extras",
28 | ),
29 | resolve: bool = False,
30 | save_to_file: bool = False,
31 | ) -> None:
32 | """Prints content of DictConfig using Rich library and its tree structure.
33 |
34 | Args:
35 | cfg (DictConfig): Configuration composed by Hydra.
36 | print_order (Sequence[str], optional): Determines in what order config components are printed.
37 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
38 | save_to_file (bool, optional): Whether to export config to the hydra output folder.
39 | """
40 |
41 | style = "dim"
42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
43 |
44 | queue = []
45 |
46 | # add fields from `print_order` to queue
47 | for field in print_order:
48 | queue.append(field) if field in cfg else log.warning(
49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..."
50 | )
51 |
52 | # add all the other fields to queue (not specified in `print_order`)
53 | for field in cfg:
54 | if field not in queue:
55 | queue.append(field)
56 |
57 | # generate config tree from queue
58 | for field in queue:
59 | branch = tree.add(field, style=style, guide_style=style)
60 |
61 | config_group = cfg[field]
62 | if isinstance(config_group, DictConfig):
63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
64 | else:
65 | branch_content = str(config_group)
66 |
67 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
68 |
69 | # print config tree
70 | rich.print(tree)
71 |
72 | # save config tree to file
73 | if save_to_file:
74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
75 | rich.print(tree, file=file)
76 |
77 |
78 | @rank_zero_only
79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
80 | """Prompts user to input tags from command line if no tags are provided in config."""
81 |
82 | if not cfg.get("tags"):
83 | if "id" in HydraConfig().cfg.hydra.job:
84 | raise ValueError("Specify tags before launching a multirun!")
85 |
86 | log.warning("No tags provided in config. Prompting user to input tags...")
87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
88 | tags = [t.strip() for t in tags.split(",") if t != ""]
89 |
90 | with open_dict(cfg):
91 | cfg.tags = tags
92 |
93 | log.info(f"Tags: {cfg.tags}")
94 |
95 | if save_to_file:
96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
97 | rich.print(cfg.tags, file=file)
98 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """This file prepares config fixtures for other tests."""
2 |
3 | import pyrootutils
4 | import pytest
5 | from hydra import compose, initialize
6 | from hydra.core.global_hydra import GlobalHydra
7 | from omegaconf import DictConfig, open_dict
8 |
9 |
10 | @pytest.fixture(scope="package")
11 | def cfg_train_global() -> DictConfig:
12 | with initialize(version_base="1.3", config_path="../configs"):
13 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])
14 |
15 | # set defaults for all tests
16 | with open_dict(cfg):
17 | cfg.paths.root_dir = str(pyrootutils.find_root(indicator=".project-root"))
18 | cfg.trainer.max_epochs = 1
19 | cfg.trainer.limit_train_batches = 0.01
20 | cfg.trainer.limit_val_batches = 0.1
21 | cfg.trainer.limit_test_batches = 0.1
22 | cfg.trainer.accelerator = "cpu"
23 | cfg.trainer.devices = 1
24 | cfg.data.num_workers = 0
25 | cfg.data.pin_memory = False
26 | cfg.extras.print_config = False
27 | cfg.extras.enforce_tags = False
28 | cfg.logger = None
29 |
30 | return cfg
31 |
32 |
33 | @pytest.fixture(scope="package")
34 | def cfg_sample_global() -> DictConfig:
35 | with initialize(version_base="1.3", config_path="../configs"):
36 | cfg = compose(
37 | config_name="sample.yaml", return_hydra_config=True, overrides=["ckpt_path=."]
38 | )
39 |
40 | # set defaults for all tests
41 | with open_dict(cfg):
42 | cfg.paths.root_dir = str(pyrootutils.find_root(indicator=".project-root"))
43 | cfg.trainer.max_epochs = 1
44 | cfg.trainer.limit_test_batches = 0.1
45 | cfg.trainer.accelerator = "cpu"
46 | cfg.trainer.devices = 1
47 | cfg.data.num_workers = 0
48 | cfg.data.pin_memory = False
49 | cfg.extras.print_config = False
50 | cfg.extras.enforce_tags = False
51 | cfg.logger = None
52 |
53 | return cfg
54 |
55 |
56 | # this is called by each test which uses `cfg_train` arg
57 | # each test generates its own temporary logging path
58 | @pytest.fixture(scope="function")
59 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig:
60 | cfg = cfg_train_global.copy()
61 |
62 | with open_dict(cfg):
63 | cfg.paths.output_dir = str(tmp_path)
64 | cfg.paths.log_dir = str(tmp_path)
65 |
66 | yield cfg
67 |
68 | GlobalHydra.instance().clear()
69 |
70 |
71 | # this is called by each test which uses `cfg_sample` arg
72 | # each test generates its own temporary logging path
73 | @pytest.fixture(scope="function")
74 | def cfg_sample(cfg_sample_global, tmp_path) -> DictConfig:
75 | cfg = cfg_sample_global.copy()
76 |
77 | with open_dict(cfg):
78 | cfg.paths.output_dir = str(tmp_path)
79 | cfg.paths.log_dir = str(tmp_path)
80 |
81 | yield cfg
82 |
83 | GlobalHydra.instance().clear()
84 |
--------------------------------------------------------------------------------
/tests/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Profluent-Internships/MMDiff/e21192bb8e815c765eaa18ee0f7bacdcc6af4044/tests/helpers/__init__.py
--------------------------------------------------------------------------------
/tests/helpers/package_available.py:
--------------------------------------------------------------------------------
1 | import platform
2 |
3 | import pkg_resources
4 | from lightning.fabric.accelerators import TPUAccelerator
5 |
6 |
7 | def _package_available(package_name: str) -> bool:
8 | """Check if a package is available in your environment."""
9 | try:
10 | return pkg_resources.require(package_name) is not None
11 | except pkg_resources.DistributionNotFound:
12 | return False
13 |
14 |
15 | _TPU_AVAILABLE = TPUAccelerator.is_available()
16 |
17 | _IS_WINDOWS = platform.system() == "Windows"
18 |
19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh")
20 |
21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed")
22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale")
23 |
24 | _WANDB_AVAILABLE = _package_available("wandb")
25 | _NEPTUNE_AVAILABLE = _package_available("neptune")
26 | _COMET_AVAILABLE = _package_available("comet_ml")
27 | _MLFLOW_AVAILABLE = _package_available("mlflow")
28 |
--------------------------------------------------------------------------------
/tests/helpers/run_if.py:
--------------------------------------------------------------------------------
1 | """Adapted from:
2 |
3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
4 | """
5 |
6 | import sys
7 |
8 | import pytest
9 | import torch
10 | from beartype.typing import Optional
11 | from packaging.version import Version
12 | from pkg_resources import get_distribution
13 |
14 | from tests.helpers.package_available import (
15 | _COMET_AVAILABLE,
16 | _DEEPSPEED_AVAILABLE,
17 | _FAIRSCALE_AVAILABLE,
18 | _IS_WINDOWS,
19 | _MLFLOW_AVAILABLE,
20 | _NEPTUNE_AVAILABLE,
21 | _SH_AVAILABLE,
22 | _TPU_AVAILABLE,
23 | _WANDB_AVAILABLE,
24 | )
25 |
26 |
27 | class RunIf:
28 | """RunIf wrapper for conditional skipping of tests.
29 |
30 | Fully compatible with `@pytest.mark`.
31 |
32 | Example:
33 |
34 | @RunIf(min_torch="1.8")
35 | @pytest.mark.parametrize("arg1", [1.0, 2.0])
36 | def test_wrapper(arg1):
37 | assert arg1 > 0
38 | """
39 |
40 | def __new__(
41 | self,
42 | min_gpus: int = 0,
43 | min_torch: Optional[str] = None,
44 | max_torch: Optional[str] = None,
45 | min_python: Optional[str] = None,
46 | skip_windows: bool = False,
47 | sh: bool = False,
48 | tpu: bool = False,
49 | fairscale: bool = False,
50 | deepspeed: bool = False,
51 | wandb: bool = False,
52 | neptune: bool = False,
53 | comet: bool = False,
54 | mlflow: bool = False,
55 | **kwargs,
56 | ):
57 | """
58 | Args:
59 | min_gpus: min number of GPUs required to run test
60 | min_torch: minimum pytorch version to run test
61 | max_torch: maximum pytorch version to run test
62 | min_python: minimum python version required to run test
63 | skip_windows: skip test for Windows platform
64 | tpu: if TPU is available
65 | sh: if `sh` module is required to run the test
66 | fairscale: if `fairscale` module is required to run the test
67 | deepspeed: if `deepspeed` module is required to run the test
68 | wandb: if `wandb` module is required to run the test
69 | neptune: if `neptune` module is required to run the test
70 | comet: if `comet` module is required to run the test
71 | mlflow: if `mlflow` module is required to run the test
72 | kwargs: native pytest.mark.skipif keyword arguments
73 | """
74 | conditions = []
75 | reasons = []
76 |
77 | if min_gpus:
78 | conditions.append(torch.cuda.device_count() < min_gpus)
79 | reasons.append(f"GPUs>={min_gpus}")
80 |
81 | if min_torch:
82 | torch_version = get_distribution("torch").version
83 | conditions.append(Version(torch_version) < Version(min_torch))
84 | reasons.append(f"torch>={min_torch}")
85 |
86 | if max_torch:
87 | torch_version = get_distribution("torch").version
88 | conditions.append(Version(torch_version) >= Version(max_torch))
89 | reasons.append(f"torch<{max_torch}")
90 |
91 | if min_python:
92 | py_version = (
93 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
94 | )
95 | conditions.append(Version(py_version) < Version(min_python))
96 | reasons.append(f"python>={min_python}")
97 |
98 | if skip_windows:
99 | conditions.append(_IS_WINDOWS)
100 | reasons.append("does not run on Windows")
101 |
102 | if tpu:
103 | conditions.append(not _TPU_AVAILABLE)
104 | reasons.append("TPU")
105 |
106 | if sh:
107 | conditions.append(not _SH_AVAILABLE)
108 | reasons.append("sh")
109 |
110 | if fairscale:
111 | conditions.append(not _FAIRSCALE_AVAILABLE)
112 | reasons.append("fairscale")
113 |
114 | if deepspeed:
115 | conditions.append(not _DEEPSPEED_AVAILABLE)
116 | reasons.append("deepspeed")
117 |
118 | if wandb:
119 | conditions.append(not _WANDB_AVAILABLE)
120 | reasons.append("wandb")
121 |
122 | if neptune:
123 | conditions.append(not _NEPTUNE_AVAILABLE)
124 | reasons.append("neptune")
125 |
126 | if comet:
127 | conditions.append(not _COMET_AVAILABLE)
128 | reasons.append("comet")
129 |
130 | if mlflow:
131 | conditions.append(not _MLFLOW_AVAILABLE)
132 | reasons.append("mlflow")
133 |
134 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
135 | return pytest.mark.skipif(
136 | condition=any(conditions),
137 | reason=f"Requires: [{' + '.join(reasons)}]",
138 | **kwargs,
139 | )
140 |
--------------------------------------------------------------------------------
/tests/helpers/run_sh_command.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from beartype.typing import List
3 |
4 | from tests.helpers.package_available import _SH_AVAILABLE
5 |
6 | if _SH_AVAILABLE:
7 | import sh
8 |
9 |
10 | def run_sh_command(command: List[str]):
11 | """Default method for executing shell commands with pytest and sh package."""
12 | msg = None
13 | try:
14 | sh.python(command)
15 | except sh.ErrorReturnCode as e:
16 | msg = e.stderr.decode()
17 | if msg:
18 | pytest.fail(msg=msg)
19 |
--------------------------------------------------------------------------------
/tests/test_configs.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from hydra.core.hydra_config import HydraConfig
3 | from omegaconf import DictConfig
4 |
5 |
6 | def test_train_config(cfg_train: DictConfig):
7 | assert cfg_train
8 | assert cfg_train.data
9 | assert cfg_train.model
10 | assert cfg_train.trainer
11 |
12 | HydraConfig().set_config(cfg_train)
13 |
14 | hydra.utils.instantiate(cfg_train.data)
15 | hydra.utils.instantiate(cfg_train.model)
16 | hydra.utils.instantiate(cfg_train.trainer)
17 |
18 |
19 | def test_sample_config(cfg_sample: DictConfig):
20 | assert cfg_sample
21 | assert cfg_sample.data
22 | assert cfg_sample.model
23 | assert cfg_sample.trainer
24 |
25 | HydraConfig().set_config(cfg_sample)
26 |
27 | hydra.utils.instantiate(cfg_sample.data)
28 | hydra.utils.instantiate(cfg_sample.model)
29 | hydra.utils.instantiate(cfg_sample.trainer)
30 |
--------------------------------------------------------------------------------
/tests/test_datamodules.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | import torch
5 |
6 | from src.data.mnist_datamodule import MNISTDataModule
7 |
8 |
9 | @pytest.mark.parametrize("batch_size", [32, 128])
10 | def test_mnist_datamodule(batch_size):
11 | data_dir = "data/"
12 |
13 | dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)
14 | dm.prepare_data()
15 |
16 | assert not dm.data_train and not dm.data_val and not dm.data_test
17 | assert Path(data_dir, "MNIST").exists()
18 | assert Path(data_dir, "MNIST", "raw").exists()
19 |
20 | dm.setup()
21 | assert dm.data_train and dm.data_val and dm.data_test
22 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader()
23 |
24 | num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test)
25 | assert num_datapoints == 70_000
26 |
27 | batch = next(iter(dm.train_dataloader()))
28 | x, y = batch
29 | assert len(x) == batch_size
30 | assert len(y) == batch_size
31 | assert x.dtype == torch.float32
32 | assert y.dtype == torch.int64
33 |
--------------------------------------------------------------------------------
/tests/test_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from hydra.core.hydra_config import HydraConfig
5 | from omegaconf import open_dict
6 |
7 | from src.sample import sample
8 | from src.train import train
9 |
10 |
11 | @pytest.mark.slow
12 | def test_train_sample(tmp_path, cfg_train, cfg_sample):
13 | """Train for 1 epoch with `train.py` and run inference with `sample.py`"""
14 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_sample.paths.output_dir
15 |
16 | with open_dict(cfg_train):
17 | cfg_train.trainer.max_epochs = 1
18 | cfg_train.test = True
19 |
20 | HydraConfig().set_config(cfg_train)
21 | train_metric_dict, _ = train(cfg_train)
22 |
23 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints")
24 |
25 | with open_dict(cfg_sample):
26 | cfg_sample.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
27 |
28 | HydraConfig().set_config(cfg_sample)
29 | test_metric_dict, _ = sample(cfg_sample)
30 |
31 | assert test_metric_dict["test/acc"] > 0.0
32 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001
33 |
--------------------------------------------------------------------------------
/tests/test_sweeps.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.helpers.run_if import RunIf
4 | from tests.helpers.run_sh_command import run_sh_command
5 |
6 | startfile = "src/train.py"
7 | overrides = ["logger=[]"]
8 |
9 |
10 | @RunIf(sh=True)
11 | @pytest.mark.slow
12 | def test_experiments(tmp_path):
13 | """Test running all available experiment configs with fast_dev_run=True."""
14 | command = [
15 | startfile,
16 | "-m",
17 | "experiment=glob(*)",
18 | "hydra.sweep.dir=" + str(tmp_path),
19 | "++trainer.fast_dev_run=true",
20 | ] + overrides
21 | run_sh_command(command)
22 |
23 |
24 | @RunIf(sh=True)
25 | @pytest.mark.slow
26 | def test_hydra_sweep(tmp_path):
27 | """Test default hydra sweep."""
28 | command = [
29 | startfile,
30 | "-m",
31 | "hydra.sweep.dir=" + str(tmp_path),
32 | "model.optimizer.lr=0.005,0.01",
33 | "++trainer.fast_dev_run=true",
34 | ] + overrides
35 |
36 | run_sh_command(command)
37 |
38 |
39 | @RunIf(sh=True)
40 | @pytest.mark.slow
41 | def test_hydra_sweep_ddp_sim(tmp_path):
42 | """Test default hydra sweep with ddp sim."""
43 | command = [
44 | startfile,
45 | "-m",
46 | "hydra.sweep.dir=" + str(tmp_path),
47 | "trainer=ddp_sim",
48 | "trainer.max_epochs=3",
49 | "+trainer.limit_train_batches=0.01",
50 | "+trainer.limit_val_batches=0.1",
51 | "+trainer.limit_test_batches=0.1",
52 | "model.optimizer.lr=0.005,0.01,0.02",
53 | ] + overrides
54 | run_sh_command(command)
55 |
56 |
57 | @RunIf(sh=True)
58 | @pytest.mark.slow
59 | def test_optuna_sweep(tmp_path):
60 | """Test optuna sweep."""
61 | command = [
62 | startfile,
63 | "-m",
64 | "hparams_search=mnist_optuna",
65 | "hydra.sweep.dir=" + str(tmp_path),
66 | "hydra.sweeper.n_trials=10",
67 | "hydra.sweeper.sampler.n_startup_trials=5",
68 | "++trainer.fast_dev_run=true",
69 | ] + overrides
70 | run_sh_command(command)
71 |
72 |
73 | @RunIf(wandb=True, sh=True)
74 | @pytest.mark.slow
75 | def test_optuna_sweep_ddp_sim_wandb(tmp_path):
76 | """Test optuna sweep with wandb and ddp sim."""
77 | command = [
78 | startfile,
79 | "-m",
80 | "hparams_search=mnist_optuna",
81 | "hydra.sweep.dir=" + str(tmp_path),
82 | "hydra.sweeper.n_trials=5",
83 | "trainer=ddp_sim",
84 | "trainer.max_epochs=3",
85 | "+trainer.limit_train_batches=0.01",
86 | "+trainer.limit_val_batches=0.1",
87 | "+trainer.limit_test_batches=0.1",
88 | "logger=wandb",
89 | ]
90 | run_sh_command(command)
91 |
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from hydra.core.hydra_config import HydraConfig
5 | from omegaconf import open_dict
6 |
7 | from src.train import train
8 | from tests.helpers.run_if import RunIf
9 |
10 |
11 | def test_train_fast_dev_run(cfg_train):
12 | """Run for 1 train, val and test step."""
13 | HydraConfig().set_config(cfg_train)
14 | with open_dict(cfg_train):
15 | cfg_train.trainer.fast_dev_run = True
16 | cfg_train.trainer.accelerator = "cpu"
17 | train(cfg_train)
18 |
19 |
20 | @RunIf(min_gpus=1)
21 | def test_train_fast_dev_run_gpu(cfg_train):
22 | """Run for 1 train, val and test step on GPU."""
23 | HydraConfig().set_config(cfg_train)
24 | with open_dict(cfg_train):
25 | cfg_train.trainer.fast_dev_run = True
26 | cfg_train.trainer.accelerator = "gpu"
27 | train(cfg_train)
28 |
29 |
30 | @RunIf(min_gpus=1)
31 | @pytest.mark.slow
32 | def test_train_epoch_gpu_amp(cfg_train):
33 | """Train 1 epoch on GPU with mixed-precision."""
34 | HydraConfig().set_config(cfg_train)
35 | with open_dict(cfg_train):
36 | cfg_train.trainer.max_epochs = 1
37 | cfg_train.trainer.accelerator = "gpu"
38 | cfg_train.trainer.precision = 16
39 | train(cfg_train)
40 |
41 |
42 | @pytest.mark.slow
43 | def test_train_epoch_double_val_loop(cfg_train):
44 | """Train 1 epoch with validation loop twice per epoch."""
45 | HydraConfig().set_config(cfg_train)
46 | with open_dict(cfg_train):
47 | cfg_train.trainer.max_epochs = 1
48 | cfg_train.trainer.val_check_interval = 0.5
49 | train(cfg_train)
50 |
51 |
52 | @pytest.mark.slow
53 | def test_train_ddp_sim(cfg_train):
54 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes."""
55 | HydraConfig().set_config(cfg_train)
56 | with open_dict(cfg_train):
57 | cfg_train.trainer.max_epochs = 2
58 | cfg_train.trainer.accelerator = "cpu"
59 | cfg_train.trainer.devices = 2
60 | cfg_train.trainer.strategy = "ddp_spawn"
61 | train(cfg_train)
62 |
63 |
64 | @pytest.mark.slow
65 | def test_train_resume(tmp_path, cfg_train):
66 | """Run 1 epoch, finish, and resume for another epoch."""
67 | with open_dict(cfg_train):
68 | cfg_train.trainer.max_epochs = 1
69 |
70 | HydraConfig().set_config(cfg_train)
71 | metric_dict_1, _ = train(cfg_train)
72 |
73 | files = os.listdir(tmp_path / "checkpoints")
74 | assert "last.ckpt" in files
75 | assert "epoch_000.ckpt" in files
76 |
77 | with open_dict(cfg_train):
78 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
79 | cfg_train.trainer.max_epochs = 2
80 |
81 | metric_dict_2, _ = train(cfg_train)
82 |
83 | files = os.listdir(tmp_path / "checkpoints")
84 | assert "epoch_001.ckpt" in files
85 | assert "epoch_002.ckpt" not in files
86 |
87 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"]
88 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]
89 |
--------------------------------------------------------------------------------