├── .gitattributes
├── .gitignore
├── LICENSE.txt
├── MODEL_CARD.md
├── README.md
├── assets
│   ├── fig1.png
│   └── image.png
├── model_weights
│   ├── trained_on_afdb.ckpt
│   └── trained_on_pdb.ckpt
├── multiflow_config.yaml
├── proteinblobs
│   ├── blobs.py
│   ├── constants.py
│   ├── dataset.py
│   ├── designability_utils.py
│   ├── logger.py
│   ├── multiflow
│   │   ├── LICENSE.txt
│   │   ├── cross_attention.py
│   │   ├── data
│   │   │   ├── all_atom.py
│   │   │   ├── interpolant.py
│   │   │   ├── protein.py
│   │   │   ├── residue_constants.py
│   │   │   ├── so3_utils.py
│   │   │   └── utils.py
│   │   ├── datasets.py
│   │   ├── edge_feature_net.py
│   │   ├── flow_model.py
│   │   ├── ipa_pytorch.py
│   │   ├── node_feature_net.py
│   │   ├── rigid_utils.py
│   │   └── utils.py
│   ├── multiflow_wrapper.py
│   ├── parsing.py
│   ├── utils.py
│   ├── visualize.py
│   └── wrapper.py
├── run_multiflow.py
├── sample.py
├── scripts
│   ├── ca_only_helicity.py
│   ├── create_multiflowdb.py
│   ├── evaluate_alignment.py
│   ├── evaluate_designability.py
│   ├── evaluate_designability_codesign.py
│   ├── preprocess_clusters.py
│   └── unpack_geniepdb.py
└── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ckpt filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | weights
3 | __pycache__
4 | workdir
5 | wandb
6 | results
7 | .cache
8 | afdb*
9 | index.txt
10 | run.sh
11 | outpdb
12 | sbatch
13 | run_tmp.sh
14 | *.ipynb
15 | aln
16 | outpdb
17 | tmp
18 | outpdb2
19 | data
20 | *.pkl
21 | figures
22 | *.pse
23 | *.pdb
24 | *.out
25 | core.*
26 | .DS_Store
27 | proteinblobs/multiflow/metadata
28 | *.html
29 | slurm_out
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | NVIDIA License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9 |
10 | 2. License Grant
11 |
12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13 |
14 | 3. Limitations
15 |
16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17 |
18 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19 |
20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
21 |
22 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23 |
24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25 |
26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27 |
28 | 4. Disclaimer of Warranty.
29 |
30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32 |
33 | 5. Limitation of Liability.
34 |
35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
36 |
--------------------------------------------------------------------------------
/MODEL_CARD.md:
--------------------------------------------------------------------------------
1 | # ProtComposer Model Card
2 |
3 | ## 1. Model Overview
4 |
5 | ### Description:
6 | ProtComposer is a generative model for novel protein structures. It leverages joint protein backbone and sequence flow matching and offers controllability through 3D ellipsoids that can be positioned by the user or an auxiliary model to guide the shape of the protein structure during generation. This enables compositional protein structure generation and advances controllability in protein design tasks.
7 |
8 | The model is for research and non-commercial use only.
9 |
10 | ### License/Terms of Use:
11 | ProtComposer source code and models are licensed under the NVIDIA license, see [LICENSE.txt](LICENSE.txt). By using ProtComposer, you accept the terms and conditions of this license.
12 |
13 | ### References:
14 | Research paper: [“ProtComposer: Compositional Protein Structure Generation with 3D Ellipsoids,”](https://openreview.net/forum?id=0ctvBgKFgc)
15 |
16 | ### Model Architecture:
17 | - **Architecture Type:** Transformer (Attention, Cross-Attention, Invariant Point Attention)
18 | - **Network Architecture:** ProtComposer
19 |
20 | ### Input:
21 | - **Input Types:** Number (numbers for total number of protein residues to generate, generation step size, guidance scale and rotational annealing scale). Tensor (tensors for ellipsoid means, flattened ellipsoid covariance matrices and ellipsoid features).
22 | - **Input Formats:**
23 | - Number: Integer (residue number), FP32 (step size, guidance scale, annealing scale)
24 | - Tensor: Pytorch Tensor
25 | - **Input Parameters:**
26 | - Number (Integer and FP32): 1D
27 | - Tensor: 3D (batch elements, ellipsoids, features)
28 | - **Other Properties Related to Input:** The total number of protein residues to generate, the generation step size, the guidance scale, the rotational annealing scale, the per-ellipsoid residue counts, and the secondary-structure types should be positive numbers.
29 |
30 | ### Output:
31 | - **Output Types:** Tensor (residue coordinates of generated protein). Tensor (amino acid types of residues of generated protein)
32 | - **Output Formats:**
33 | - Tensor (coordinates): Pytorch tensor
34 | - Tensor (amino acid types): Pytorch integer tensor
35 | - **Output Parameters:**
36 | - Tensor (coordinates): 3D (batch, length of protein, spatial dimensions)
37 | - Tensor (amino acid types): 2D (batch, length of protein)
38 |
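For illustration, the input and output descriptions above correspond to tensors shaped like the following. This is a minimal sketch only; the variable names are hypothetical and do not reflect the model's actual API.

```python
import torch

# Hypothetical example: 1 batch element, 9 ellipsoids, 128 residues.
batch, num_ellipsoids, num_residues = 1, 9, 128

# Inputs (3D tensors: batch elements x ellipsoids x features).
ellipsoid_means = torch.zeros(batch, num_ellipsoids, 3)    # ellipsoid centers
ellipsoid_covars = torch.zeros(batch, num_ellipsoids, 9)   # flattened 3x3 covariance matrices
ellipsoid_feats = torch.zeros(batch, num_ellipsoids, 2)    # residue count and secondary-structure type
guidance_scale, step_size = 1.0, 1.0 / 100                 # FP32 scalars

# Outputs.
coords = torch.zeros(batch, num_residues, 3)                  # residue coordinates
aatypes = torch.zeros(batch, num_residues, dtype=torch.long)  # amino-acid type indices
```
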
39 | ### Software Integration:
40 | - **Runtime Engine:** PyTorch
41 | - **Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper
42 | - **Preferred Operating System:** Linux
43 |
44 | ### Model Versions:
45 | - **ProtComposer v1.0** (trained on Protein Data Bank)
46 | - **ProtComposer v1.1** (trained on AlphaFold Database)
47 |
48 | ## 2. Training and Evaluation Datasets
49 |
50 | ### Training Datasets:
51 | - **Protein Data Bank (PDB)**
52 | - **Link:** [https://www.rcsb.org/](https://www.rcsb.org/)
53 | - **Data Collection Method by dataset:** Automatic/Sensors/Human (experimental protein structure determination)
54 | - **Labeling Method by dataset:** N/A
55 | - **Properties:** The Protein Data Bank (PDB) contains approx. 200K experimentally determined three-dimensional structures of large biological molecules, such as proteins and nucleic acids, along with auxiliary information such as the protein sequences. We train ProtComposer on a filtered subset of the PDB, comprising 20,312 proteins. Specifically, we use monomers between length 60 and 512 with resolution < 5A downloaded from the PDB with a time cutoff on August 8, 2021. This resulted in 23,913 proteins. The data is further filtered by only including proteins with high secondary structure compositions. For each monomer, we ran the Define Secondary Structure of Proteins (DSSP) algorithm and removed monomers with more than 50% loops, resulting in 20,312 proteins.
56 | - **AlphaFold Protein Structure Database (AFDB)**
57 | - **Link:** [https://alphafold.ebi.ac.uk/](https://alphafold.ebi.ac.uk/)
58 | - **Data Collection Method by dataset:** Synthetic (AlphaFold predictions)
59 | - **Labeling Method by dataset:** N/A
60 | - **Properties:** The AlphaFold Protein Structure Database (AFDB) contains approx. 214M synthetic three-dimensional protein structures predicted by AlphaFold2, along with their corresponding sequences. We train ProtComposer on a subset of the AFDB, comprising 588,570 structures. We use a subsampled version that applies FoldSeek to cluster entries based on structural similarity. We start with all cluster representatives from the FoldSeek-clustered database and then filter them using a pLDDT threshold of >80, to enrich for highly confident predictions, and a maximum sequence length of 256. This results in 588,570 structures.
61 |
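Both training sets above are filtered programmatically. As a minimal sketch (not the exact preprocessing code), the loop-fraction filter described for the PDB set can be written with pydssp, which this repository also uses; here `bb_pos` is a hypothetical `(L, 4, 3)` array of backbone N, CA, C, O coordinates.

```python
import numpy as np
import pydssp


def loop_fraction(bb_pos: np.ndarray) -> float:
    """Fraction of residues assigned to loops (DSSP index 0)."""
    # pydssp returns per-residue indices: 0 = loop, 1 = alpha-helix, 2 = beta-strand.
    dssp = np.asarray(pydssp.assign(bb_pos, out_type="index"))
    return float((dssp == 0).mean())


# keep = loop_fraction(bb_pos) <= 0.5  # drop monomers with more than 50% loop residues
```
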
62 | ### Evaluation Dataset:
63 | - **Protein Data Bank (PDB)**
64 | - **Link:** [https://www.rcsb.org/](https://www.rcsb.org/)
65 | - **Data Collection Method by dataset:** Automatic/Sensors/Human (experimental protein structure determination)
66 | - **Labeling Method by dataset:** N/A
67 | - **Properties:** The Protein Data Bank (PDB) contains approx. 200K experimentally determined three-dimensional structures of large biological molecules, such as proteins and nucleic acids, along with auxiliary information such as the protein sequences. Similar to Multiflow [Campbell et al., “Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design”](https://arxiv.org/abs/2402.04997), we evaluate ProtComposer on a subset of the PDB, comprising 449 protein structures. Our evaluation set is based on a time-based split of the PDB. We downloaded structures and sequences from the PDB that were released between 1st September 2021 and 28th December 2023. We then select all single chain monomeric proteins with length between 50 and 400 inclusive. We further filter out proteins that are more than 50% coil residues and proteins that have a radius of gyration in the 96th percentile of the original dataset or above. We also filter out structures that have missing residues. We cluster proteins using the 30% sequence identity MMSeqs2 clustering provided by RCSB.org. We take a single protein from each cluster that matches our filtering criteria. This gives us an evaluation set of 449 proteins with minimum length 51 and maximum length 398.
68 |
69 | ### Inference:
70 | - **Engine:** PyTorch
71 | - **Test Hardware:** A100, H100
72 |
73 | ### Ethical Considerations:
74 | NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
75 |
76 | Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProtComposer: Compositional Protein Structure Generation with 3D Ellipsoids
2 |
3 | Oral at ICLR 2025
4 |
21 | ## Environment
22 |
23 | We recommend installing Miniconda (https://docs.anaconda.com/miniconda/install/) and creating the following conda environment:
24 | ```shell script
25 | conda create -n nv python=3.9
26 | pip install jupyterlab
27 | pip install numpy==1.21.2 pandas==1.5.3
28 | pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
29 | pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.1 absl-py einops
30 | pip install pytorch_lightning==2.0.4 fair-esm
31 | pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@5484c38'
32 | pip install matplotlib==3.7.2
33 | pip install pydssp biotite omegaconf wandb
34 | pip install numpy==1.21.2
35 | # pip install numpy throws an error for contourpy but that is fine
36 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cu113
37 | pip3 install -U scikit-learn
38 | pip install gpustat
39 | ```
40 |
41 | ## Pretrained checkpoints
42 |
43 | We provide two pretrained checkpoints in the directory `model_weights`. One is trained on data from the Protein Data Bank, the other on data from the AlphaFold Database.
44 |
45 | | Description | Checkpoint path |
46 | |:----------|:----------|
47 | | AlphaFold Database training data | [`model_weights/trained_on_afdb.ckpt`](model_weights/trained_on_afdb.ckpt) |
48 | | Protein Data Bank training data | [`model_weights/trained_on_pdb.ckpt`](model_weights/trained_on_pdb.ckpt) |
49 |
50 | ## Sampling
51 |
52 | To sample ProtComposer conditioned on ellipsoids from our ellipsoid statistical model:
53 |
54 | ```shell script
55 | python sample.py --guidance 1.0 --num_prots 6 --nu 5 --sigma 6 --helix_frac 0.4 --seed 1 --outdir results --num_blobs 9 --ckpt "model_weights/trained_on_pdb.ckpt"
56 | ```
57 |
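The ellipsoid statistical model referenced above corresponds to `sample_blobs` in `proteinblobs/blobs.py`, which draws covariances from a Wishart distribution, draws centers from an isotropic Gaussian, and rejection-samples based on an overlap score. A minimal sketch of calling it directly is shown below; the parameter values mirror the flags above, the scale matrix follows `get_synthetic_blobs` in `proteinblobs/dataset.py`, and `psi` is a hypothetical scale value.

```python
import numpy as np
from proteinblobs.blobs import sample_blobs

num_blobs, nu, sigma, psi = 9, 5, 6.0, 3.0  # psi is a hypothetical scale value
centers, covars = sample_blobs(
    num_blobs,
    nu=nu,
    psi=(1 / nu) * psi**2 * np.eye(3),  # Wishart scale matrix, as in get_synthetic_blobs
    sigma=sigma,
)
print(centers.shape, covars.shape)  # (9, 3) (9, 3, 3)
```
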
58 | ## Evaluation
59 |
60 | We use these scripts to compute metrics for the outputs (`evaluate_alignment` computes the ellipsoid adherence metrics):
61 | ```shell script
62 | python -m scripts.evaluate_designability --dir results
63 | python -m scripts.evaluate_alignment --dir results
64 | ```
65 |
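The underlying ellipsoid-adherence metrics live in `proteinblobs/blobs.py` and can also be called directly. Below is a minimal sketch on toy data; the clustered points stand in for the CA coordinates and per-residue DSSP labels of a real structure.

```python
import numpy as np
from proteinblobs.blobs import blobify, blob_accuracy, blob_coverage

# Toy stand-in for a generated structure: three tight clusters of CA coordinates,
# each labelled with one secondary-structure type (0: loop, 1: helix, 2: strand).
rng = np.random.default_rng(0)
pos = np.concatenate(
    [rng.normal(c, 1.5, size=(40, 3)) for c in ([0, 0, 0], [20, 0, 0], [0, 20, 0])]
)
dssp = np.repeat([1, 2, 1], 40)

ellipsoids = blobify(pos, dssp)  # re-fit ellipsoids ("blobs") to the structure
print(len(ellipsoids), "ellipsoids")
print("coverage:", blob_coverage(pos, dssp, ellipsoids))
print("accuracy:", blob_accuracy(pos, dssp, ellipsoids))
```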
66 |
67 | ## Training
68 |
69 |
70 | ### Data preparations
71 |
72 | We use the data from MultiFlow: https://github.com/jasonkyuyim/multiflow
73 |
74 | They host the datasets on Zenodo [here](https://zenodo.org/records/10714631?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjJjMTk2YjlmLTM4OTUtNGVhYi1hODcxLWE1ZjExOTczY2IzZiIsImRhdGEiOnt9LCJyYW5kb20iOiI4MDY5ZDUzYjVjMTNhNDllMDYxNmI3Yjc2NjcwYjYxZiJ9.C2eZZmRu-nu7H330G-DkV5kttfjYB3ANozdOMNm19uPahvtLrDRvd_4Eqlyb7lp24m06e4OHhHQ4zlj68S1O_A).
75 | Download the following files, and place them into the directory `data`
76 | * `real_train_set.tar.gz` (2.5 GB)
77 | * `synthetic_train_set.tar.gz` (220 MB)
78 | * `test_set.tar.gz` (347 MB)
79 | Next, untar the files
80 | ```bash
81 | # Uncompress training data
82 | mkdir train_set
83 | tar -xzvf real_train_set.tar.gz -C train_set/
84 | tar -xzvf synthetic_train_set.tar.gz -C train_set/
85 |
86 | # Uncompress test data
87 | mkdir test_set
88 | tar -xzvf test_set.tar.gz -C test_set/
89 | ```
90 |
91 | Place the untarred `train_set` and `test_set` directories inside the directory `data` so that you obtain the following file structure (the `--pkl_dir` argument is `data` by default):
92 | ```bash
93 | data
94 | ├── train_set
95 | │ ├── processed_pdb
96 | | | ├──
97 | | | | └── .pkl
98 | │ ├── processed_synthetic
99 | | | └── .pkl
100 | ├── test_set
101 | | └── processed
102 | | | ├──
103 | | | | └── .pkl
104 | ...
105 | ```
106 |
107 | ### Launch training run
108 |
109 | By default, we train on 8 GPUs.
110 |
111 | ```shell script
112 | python train.py --batch_size 8 --designability --designability_freq 5 --num_designability_prots 50 --accumulate_grad 8 --inf_batches 5 --val_batches 5 --finetune --dataset multiflow --self_condition --num_workers 10 --save_val --epochs 1000 --run_name my_run_name --wandb
113 | ```
114 |
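Settings for the MultiFlow backbone live in `multiflow_config.yaml` at the repository root. As a minimal sketch, the file can be inspected with OmegaConf (installed in the environment above); how `train.py` actually consumes these values may differ.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("multiflow_config.yaml")
print(cfg.data.dataset)           # pdb
print(cfg.model.node_embed_size)  # 256
print(cfg.model.ipa.num_blocks)   # 8
```
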
115 | ## License
116 |
117 | Code and model weights are released under an NVIDIA license for non-commercial or research purposes only. Please see the [LICENSE.txt](https://github.com/NVlabs/protcomposer/blob/main/LICENSE.txt) file.
118 |
119 |
120 | ## Citation
121 |
122 |
123 | ```bib
124 | @inproceedings{stark2025protcomposer,
125 | title={ProtComposer: Compositional Protein Structure Generation with 3D Ellipsoids},
126 | author={Hannes Stark and Bowen Jing and Tomas Geffner and Jason Yim and Tommi Jaakkola and Arash Vahdat and Karsten Kreis},
127 | booktitle={The Thirteenth International Conference on Learning Representations (ICLR)},
128 | year={2025},
129 | url={https://openreview.net/forum?id=0ctvBgKFgc}
130 | }
131 | ```
132 |
--------------------------------------------------------------------------------
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/protcomposer/0eefdc79100a4ed5b7f585f264821b87e6e14fed/assets/fig1.png
--------------------------------------------------------------------------------
/assets/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/protcomposer/0eefdc79100a4ed5b7f585f264821b87e6e14fed/assets/image.png
--------------------------------------------------------------------------------
/model_weights/trained_on_afdb.ckpt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:735949839257b87cd8bc5aa6e9770867685a8c3328ad33be351ddb5fa0cf1c69
3 | size 294278008
4 |
--------------------------------------------------------------------------------
/model_weights/trained_on_pdb.ckpt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c48de733ac64d4d38475f1b2e8e820be297aa992c9cea1cb7acf339762428c40
3 | size 294344824
4 |
--------------------------------------------------------------------------------
/multiflow_config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset: pdb
3 | loader:
4 | num_workers: 8
5 | prefetch_factor: 10
6 | sampler:
7 | max_batch_size: 64
8 | max_num_res_squared: 400000
9 | task: hallucination
10 | interpolant:
11 | min_t: 0.01
12 | separate_t: false
13 | provide_kappa: false
14 | hierarchical_t: false
15 | codesign_separate_t: true
16 | codesign_forward_fold_prop: 0.1
17 | codesign_inverse_fold_prop: 0.1
18 | rots:
19 | corrupt: true
20 | train_schedule: linear
21 | sample_schedule: exp
22 | exp_rate: 10
23 | trans:
24 | corrupt: true
25 | batch_ot: true
26 | train_schedule: linear
27 | sample_schedule: linear
28 | sample_temp: 1.0
29 | vpsde_bmin: 0.1
30 | vpsde_bmax: 20.0
31 | potential: null
32 | potential_t_scaling: false
33 | rog:
34 | weight: 10.0
35 | cutoff: 5.0
36 | aatypes:
37 | corrupt: true
38 | schedule: linear
39 | schedule_exp_rate: -3
40 | temp: 0.1
41 | noise: 0.0
42 | do_purity: true
43 | train_extra_mask: 0.0
44 | interpolant_type: masking
45 | sampling:
46 | num_timesteps: 100
47 | do_sde: false
48 | self_condition: ${model.edge_features.self_condition}
49 | twisting:
50 | use: false
51 | folding:
52 | seq_per_sample: 8
53 | own_device: false
54 | folding_model: esmf
55 | pmpnn_path: ./ProteinMPNN/
56 | pt_hub_dir: ./.cache/torch/
57 | experiment:
58 | debug: false
59 | seed: 123
60 | num_devices: 4
61 | warm_start: flow-matching/ckpt/se3-fm/reproduce_v2/2023-12-28_14-23-11/epoch=59-step=28380.ckpt
62 | warm_start_cfg_override: false
63 | raw_state_dict_reload: null
64 | training:
65 | mask_plddt: true
66 | bb_atom_scale: 0.1
67 | trans_scale: 0.1
68 | aatypes_loss_weight: 1.0
69 | aatypes_label_smoothing: 0.0
70 | aatypes_loss_mean_or_sum: mean
71 | aatypes_loss_use_likelihood_weighting: false
72 | translation_loss_weight: 2.0
73 | t_normalize_clip: 0.9
74 | rotation_loss_weights: 1.0
75 | aux_loss_weight: 0.0
76 | aux_loss_use_bb_loss: true
77 | aux_loss_use_pair_loss: true
78 | aux_loss_t_pass: 0.5
79 | wandb:
80 | name: separate_t
81 | project: se3-fm
82 | optimizer:
83 | lr: 0.0001
84 | trainer:
85 | overfit_batches: 0
86 | min_epochs: 1
87 | max_epochs: 200
88 | accelerator: gpu
89 | log_every_n_steps: 1
90 | deterministic: false
91 | strategy: ddp
92 | check_val_every_n_epoch: 6
93 | accumulate_grad_batches: 2
94 | checkpointer:
95 | dirpath: ./
96 | save_last: true
97 | save_top_k: 3
98 | every_n_epochs: 4
99 | monitor: valid/codesign_bb_rmsd
100 | mode: min
101 | inference_dir: ./inference_outputs/./weights/last/unconditional/run_2024-03-02_11-49-28
102 | shared:
103 | seed: 123
104 | max_cache_size: 100000
105 | samples_per_eval_length: 5
106 | num_eval_lengths: 8
107 | max_eval_length: 256
108 | min_motif_percent: 0.05
109 | max_motif_percent: 0.5
110 | pdb_dataset:
111 | seed: ${shared.seed}
112 | csv_path: proteinblobs/multiflow/metadata/pdb_metadata.csv
113 | cluster_path: proteinblobs/multiflow/metadata/pdb.clusters
114 | test_set_pdb_ids_path: null
115 | max_cache_size: ${shared.max_cache_size}
116 | cache_num_res: 0
117 | inpainting_percent: 1.0
118 | add_plddt_mask: false
119 | max_eval_length: ${shared.max_eval_length}
120 | redesigned_csv_path: proteinblobs/multiflow/metadata/pdb_redesigned.csv
121 | use_redesigned: true
122 | synthetic_csv_path: proteinblobs/multiflow/metadata/distillation_metadata.csv
123 | synthetic_cluster_path: proteinblobs/multiflow/metadata/distillation.clusters
124 | use_synthetic: true
125 | samples_per_eval_length: ${shared.samples_per_eval_length}
126 | num_eval_lengths: ${shared.num_eval_lengths}
127 | filter:
128 | max_num_res: 384
129 | min_num_res: 60
130 | max_coil_percent: 0.5
131 | rog_quantile: 0.96
132 | oligomeric:
133 | - monomeric
134 | num_chains:
135 | - 1
136 | min_motif_percent: ${shared.min_motif_percent}
137 | max_motif_percent: ${shared.max_motif_percent}
138 | is_mixed: false
139 | pdb_post2021_dataset:
140 | seed: ${shared.seed}
141 | csv_path: proteinblobs/multiflow/metadata/test_set_metadata.csv
142 | cluster_path: proteinblobs/multiflow/metadata/test_set_clusters.csv
143 | test_set_pdb_ids_path: proteinblobs/multiflow/metadata/test_set_pdb_ids.csv
144 | max_cache_size: ${shared.max_cache_size}
145 | cache_num_res: 0
146 | add_plddt_mask: False
147 | max_eval_length: ${shared.max_eval_length}
148 | redesigned_csv_path: null
149 | use_redesigned: False
150 |
151 | synthetic_csv_path: null
152 | synthetic_cluster_path: null
153 | use_synthetic: False
154 |
155 | # Eval parameters
156 | samples_per_eval_length: ${shared.samples_per_eval_length}
157 | num_eval_lengths: ${shared.num_eval_lengths}
158 |
159 | # Filtering
160 | filter:
161 | max_num_res: 400
162 | min_num_res: 50
163 | max_coil_percent: 0.5
164 | rog_quantile: 0.96
165 | oligomeric: ['monomeric']
166 | num_chains: [1]
167 | model:
168 | node_embed_size: 256
169 | edge_embed_size: 128
170 | symmetric: false
171 | aatype_pred: true
172 | transformer_dropout: 0.2
173 | aatype_pred_num_tokens: 21
174 | node_features:
175 | c_s: ${model.node_embed_size}
176 | c_pos_emb: 128
177 | c_timestep_emb: 128
178 | max_num_res: 2000
179 | timestep_int: 1000
180 | embed_chain: false
181 | embed_aatype: true
182 | use_mlp: true
183 | aatype_pred_num_tokens: ${model.aatype_pred_num_tokens}
184 | edge_features:
185 | single_bias_transition_n: 2
186 | c_s: ${model.node_embed_size}
187 | c_p: ${model.edge_embed_size}
188 | relpos_k: 64
189 | feat_dim: 64
190 | num_bins: 22
191 | self_condition: true
192 | embed_chain: false
193 | embed_diffuse_mask: true
194 | ipa:
195 | c_s: ${model.node_embed_size}
196 | c_z: ${model.edge_embed_size}
197 | c_hidden: 16
198 | no_heads: 8
199 | no_qk_points: 8
200 | no_v_points: 12
201 | seq_tfmr_num_heads: 4
202 | seq_tfmr_num_layers: 4
203 | num_blocks: 8
204 | dropout: 0.0
205 | use_sequence_net: false
206 | sequence_net:
207 | num_layers: 6
208 | use_init_embed: true
209 | use_init_rigid: false
210 | use_local_attention: false
211 | init_node_embed: ${model.node_embed_size}
212 | init_edge_embed: ${model.edge_embed_size}
213 | ipa:
214 | c_s: 128
215 | c_z: 64
216 | c_hidden: 128
217 | no_heads: 4
218 | no_qk_points: 8
219 | no_v_points: 8
220 | dropout: 0.2
221 | inference:
222 | predict_dir: ./inference_outputs/
223 | inference_subdir: run_${now:%Y-%m-%d}_${now:%H-%M-%S}
224 | task: unconditional
225 | seed: 123
226 | use_gpu: true
227 | num_gpus: 1
228 | saved_ckpt_dir: ./saved_ckpts/frameflow
229 | unconditional_ckpt_path: ./weights/last.ckpt
230 | also_fold_pmpnn_seq: true
231 | write_sample_trajectories: false
232 | interpolant:
233 | min_t: 0.01
234 | provide_kappa: false
235 | codesign_separate_t: false
236 | rots:
237 | corrupt: true
238 | sample_schedule: exp
239 | exp_rate: 10
240 | trans:
241 | corrupt: true
242 | pre_align: true
243 | train_schedule: linear
244 | sample_schedule: linear
245 | sample_temp: 1.0
246 | potential: null
247 | potential_t_scaling: false
248 | rog:
249 | weight: 20.0
250 | cutoff: 5.0
251 | aatypes:
252 | corrupt: true
253 | schedule: linear
254 | schedule_exp_rate: -3
255 | temp: 0.1
256 | noise: 20.0
257 | do_purity: true
258 | interpolant_type: masking
259 | sampling:
260 | num_timesteps: 500
261 | do_sde: false
262 | self_condition: true
263 | pmpnn_dir: ./ProteinMPNN
264 | folding:
265 | seq_per_sample: 8
266 | folding_model: esmf
267 | own_device: false
268 | pmpnn_path: ./ProteinMPNN/
269 | pt_hub_dir: ./.cache/torch/
270 | colabfold_path: path/to/colabfold-conda/bin/colabfold_batch
271 | samples:
272 | samples_per_length: 100
273 | num_batch: 1
274 | length_subset:
275 | - 256
276 | min_length: 60
277 | max_length: 256
278 | length_step: 1
--------------------------------------------------------------------------------
/proteinblobs/blobs.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | import numpy as np
13 | import networkx as nx
15 | from scipy.special import logsumexp, softmax
16 | import torch
17 | from torch.distributions.multivariate_normal import MultivariateNormal
18 | from scipy.stats import wishart
19 |
20 | alpha_slope = 0.22213567936924072
21 | alpha_intercept = 7.512307373103329
22 |
23 | beta_slope = 0.2727479880396384
24 | beta_intercept = 8.28561006547174
25 |
26 |
27 | def blobify(pos, dssp, radius_thresh=5, size_thresh=5):
28 | distmat = np.square(pos[:, None] - pos[None]).sum(-1) ** 0.5
29 | G = nx.Graph()
30 | edges = np.argwhere((distmat < radius_thresh) & (dssp[None] == dssp[:, None]))
31 | G.add_edges_from(edges)
32 | blobs = []
33 | for con in nx.connected_components(G):
34 | con = list(con)
35 | if dssp[con[0]] == 0:
36 | continue
37 | if len(con) < size_thresh:
38 | continue
39 | blobs.append(
40 | {
41 | "residues": con,
42 | "count": len(con),
43 | "dssp": dssp[con[0]],
44 | "pos": pos[con].mean(0).astype(float),
45 | "covar": np.cov(pos[con].T).astype(float),
46 | }
47 | )
48 |
49 | return blobs
50 |
51 |
52 | def shannon_complexity(pos, dssp, thresh=5):
53 | blobs = blobify(pos, dssp, radius_thresh=thresh)
54 | counts = np.array([b["count"] for b in blobs])
55 | probs = counts / counts.sum()
56 | return -(probs * np.log(probs)).sum()
57 |
58 |
59 | def score_blobs(centers, covars):
60 | energy = 0
61 | for pos, cov in zip(centers, covars):
62 | relpos = centers - pos
63 | d = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(cov), relpos) ** 0.5
64 | d = d[d != 0]
65 | energy += (1 / d**2).sum()
66 | return energy
67 |
68 |
69 | def sample_blobs(k, nu, psi, sigma):
70 | while True:
71 | covar = wishart.rvs(nu, psi, size=k)
72 | if k == 1:
73 | covar = covar[None]
74 | pos = np.random.randn(k, 3) * sigma
75 | score = score_blobs(pos, covar)
76 | if np.random.rand() < np.exp(-score):
77 | return pos, covar
78 |
79 |
80 | def gmm_ll(pos, centers, covars, probs):
81 | probs = probs / probs.sum()
82 | relpos = pos[:, None] - centers[None]
83 | ll = (
84 | -0.5 * np.einsum("lni,nij,lnj->ln", relpos, np.linalg.inv(covars), relpos)
85 | - 0.5 * np.linalg.slogdet(covars)[1][None]
86 | + np.log(probs)[None]
87 | )
88 | return logsumexp(ll, -1)
89 |
90 |
91 | def dssp_gmm_ll(pos, centers, covars, probs, pos_dssp, centers_dssp):
92 | if np.any(centers_dssp == 1):
93 | gmm_ll_1 = gmm_ll(
94 | pos,
95 | centers[centers_dssp == 1],
96 | covars[centers_dssp == 1],
97 | probs[centers_dssp == 1],
98 | )
99 | gmm_ll_1[pos_dssp != 1] = -np.inf
100 | else:
101 | gmm_ll_1 = -np.inf * np.ones(len(pos))
102 |
103 | if np.any(centers_dssp == 2):
104 | gmm_ll_2 = gmm_ll(
105 | pos,
106 | centers[centers_dssp == 2],
107 | covars[centers_dssp == 2],
108 | probs[centers_dssp == 2],
109 | )
110 | gmm_ll_2[pos_dssp != 2] = -np.inf
111 | else:
112 | gmm_ll_2 = -np.inf * np.ones(len(pos))
113 |
114 | dssp_1_count = probs[centers_dssp == 1].sum() / probs.sum()
115 | dssp_2_count = probs[centers_dssp == 2].sum() / probs.sum()
116 |
117 | return np.log(np.exp(gmm_ll_1) * dssp_1_count + np.exp(gmm_ll_2) * dssp_2_count)
118 |
119 |
120 | """
121 | Treats the blobs as a GMM and evaluates the average log-DENSITY at each pos.
122 | Note that DENSITY is N * PDF where N is the total number of residues in the blobs.
123 | Not suitable for partial blobs.
124 | """
125 |
126 |
127 | def blob_likelihood(pos, dssp, blobs, structured_only=False, mean=True):
128 | if structured_only:
129 | pos = pos[dssp != 0]
130 | dssp = dssp[dssp != 0]
131 | if len(dssp) == 0:
132 | return -10
133 | if len(blobs) == 0:
134 | ll = -np.inf * np.ones(pos.shape[0])
135 | else:
136 | covars = np.stack([blob["covar"] for blob in blobs])
137 | counts = np.array([blob["count"] for blob in blobs])
138 | centers = np.stack([blob["pos"] for blob in blobs])
139 |
140 | ll = gmm_ll(pos, centers, covars, counts)
141 | ll += np.log(counts.sum())
142 |
143 | if not mean:
144 | return ll
145 | return ll.mean()
146 |
147 |
148 | """
149 | Reblobs the protein and evaluates the number of residues per blob
150 | """
151 |
152 |
153 | def blobs_per_res(pos, dssp, thresh=5):
154 | blobs = blobify(pos, dssp, thresh)
155 | return len(blobs) / len(pos)
156 |
157 |
158 | def res_per_blob(pos, dssp, thresh=5):
159 | blobs = blobify(pos, dssp, thresh)
160 | return len(pos) / len(blobs)
161 |
162 |
163 | """
164 | Reblobs the protein and evaluates the JSD between the GMMs
165 | defined by the original and new blobs.
166 | Not suitable for partial blobs.
167 | If use_dssp=True, consider sample space R^3 x {1,2}
168 | """
169 |
170 |
171 | def reblob_jsd(pos, dssp, blobs, num_samples=1000, use_dssp=False, thresh=5):
172 | p_blobs = blobs
173 | q_blobs = blobify(pos, dssp, thresh)
174 |
175 | p_covars = np.stack([blob["covar"] for blob in p_blobs])
176 | p_counts = np.array([blob["count"] for blob in p_blobs])
177 | p_centers = np.stack([blob["pos"] for blob in p_blobs])
178 | p_dssp = np.array([blob["dssp"] for blob in p_blobs])
179 |
180 | q_covars = np.stack([blob["covar"] for blob in q_blobs])
181 | q_counts = np.array([blob["count"] for blob in q_blobs])
182 | q_centers = np.stack([blob["pos"] for blob in q_blobs])
183 | q_dssp = np.array([blob["dssp"] for blob in q_blobs])
184 |
185 | p_idx = np.random.choice(
186 | np.arange(len(p_blobs)), size=num_samples, p=p_counts / p_counts.sum()
187 | )
188 | p_samps = MultivariateNormal(
189 | torch.from_numpy(p_centers[p_idx]).float(),
190 | torch.from_numpy(p_covars[p_idx]).float(),
191 | ).sample()
192 | if use_dssp:
193 | p_p = dssp_gmm_ll(p_samps, p_centers, p_covars, p_counts, p_dssp[p_idx], p_dssp)
194 | p_q = dssp_gmm_ll(p_samps, q_centers, q_covars, q_counts, p_dssp[p_idx], q_dssp)
195 |
196 | else:
197 | p_p = gmm_ll(p_samps, p_centers, p_covars, p_counts)
198 | p_q = gmm_ll(p_samps, q_centers, q_covars, q_counts)
199 |
200 | p_m = np.log(np.exp(p_p) / 2 + np.exp(p_q) / 2)
201 | kl_p = (p_p - p_m).mean()
202 |
203 | q_idx = np.random.choice(
204 | np.arange(len(q_blobs)), size=num_samples, p=q_counts / q_counts.sum()
205 | )
206 | q_samps = MultivariateNormal(
207 | torch.from_numpy(q_centers[q_idx]).float(),
208 | torch.from_numpy(q_covars[q_idx]).float(),
209 | ).sample()
210 |
211 | if use_dssp:
212 | q_p = dssp_gmm_ll(q_samps, p_centers, p_covars, p_counts, q_dssp[q_idx], p_dssp)
213 | q_q = dssp_gmm_ll(q_samps, q_centers, q_covars, q_counts, q_dssp[q_idx], q_dssp)
214 | else:
215 | q_p = gmm_ll(q_samps, p_centers, p_covars, p_counts)
216 | q_q = gmm_ll(q_samps, q_centers, q_covars, q_counts)
217 | q_m = np.log(np.exp(q_p) / 2 + np.exp(q_q) / 2)
218 | kl_q = (q_q - q_m).mean()
219 |
220 | return (kl_p + kl_q) / 2
221 |
222 |
223 | """
224 | Calculate the sum of the difference in the fraction of residues
225 | assigned to a blob and the fraction of total residues actually placed in the blob.
226 | """
227 |
228 |
229 | def blob_misplacement(pos, dssp, blobs, thresh=2.25, structured_only=False):
230 | if structured_only:
231 | pos = pos[dssp != 0]
232 | dssp = dssp[dssp != 0]
233 | if len(dssp) == 0:
234 | return 1
235 |
236 | assigned = []
237 | actual = []
238 | for b in blobs:
239 | relpos = pos - b["pos"]
240 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos)
241 | mask = dists**0.5 < thresh
242 | actual.append(mask.sum())
243 | assigned.append(b["count"])
244 |
245 | assigned = np.array(assigned) / np.sum(assigned)
246 | actual = np.array(actual) / np.sum(actual)
247 | return np.sum(np.abs(assigned - actual))
248 |
249 |
250 | """
251 | The fraction of residues within some Mahalanobis distance of each blob
252 | with the correct secondary structure.
253 | Suitable for partial blobs.
254 | """
255 |
256 |
257 | def blob_accuracy(pos, dssp, blobs, thresh=2.25, structured_only=False):
258 | if structured_only:
259 | pos = pos[dssp != 0]
260 | dssp = dssp[dssp != 0]
261 | if len(dssp) == 0:
262 | return 0
263 |
264 | correct, total = 0, 0
265 |
266 | for b in blobs:
267 | relpos = pos - b["pos"]
268 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos)
269 | mask = dists**0.5 < thresh
270 | correct += (dssp[mask] == b["dssp"]).sum()
271 | total += mask.sum()
272 |
273 | return correct / total
274 |
275 |
276 | """
277 | Based on sheet and helix GMMs, computes p(ss | pos) for each pos in the new protein.
278 | Reports the average p(ss | pos) over all non-loop residues.
279 | Not suitable for partial blobs.
280 | """
281 |
282 |
283 | def soft_blob_accuracy(pos, dssp, blobs, structured_only=False):
284 | if structured_only:
285 | pos = pos[dssp != 0]
286 | dssp = dssp[dssp != 0]
287 | if len(dssp) == 0:
288 | return 0
289 |
290 | sheet_ll = blob_likelihood(
291 | pos, dssp, [b for b in blobs if b["dssp"] == 1], mean=False
292 | )
293 | helix_ll = blob_likelihood(
294 | pos, dssp, [b for b in blobs if b["dssp"] == 2], mean=False
295 | )
296 |
297 | probs = softmax(np.stack([sheet_ll, helix_ll], -1), -1)
298 | return ((probs[:, 0][dssp == 1]).sum() + (probs[:, 1][dssp == 2]).sum()) / len(dssp)
299 |
300 |
301 | """
302 | Reports the fraction of residues within some Mahalanobis distance of some blob.
303 | Not suitable for partial blobs.
304 | """
305 |
306 |
307 | def blob_coverage(pos, dssp, blobs, thresh=2.25, structured_only=False):
308 | if structured_only:
309 | pos = pos[dssp != 0]
310 | dssp = dssp[dssp != 0]
311 | if len(dssp) == 0:
312 | return 0
313 |
314 | idx = np.arange(len(dssp))
315 | seen = []
316 |
317 | for b in blobs:
318 | relpos = pos - b["pos"]
319 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos)
320 | mask = dists**0.5 < thresh
321 | seen.extend(idx[mask])
322 |
323 | seen = list(set(seen))
324 | total = len(dssp)
325 | return len(seen) / total
326 |
327 |
328 | """
329 | Counts the number of residues within some Mahalanobis distance of some blob,
330 | relative to the total number expected for all blobs collectively.
331 | If unique=True, then residues within cutoff of multiple blobs will only be counted once.
332 | Suitable for partial blobs.
333 | """
334 |
335 |
336 | def blob_occupancy(pos, dssp, blobs, thresh=2.25, structured_only=False, unique=False):
337 | if structured_only:
338 | pos = pos[dssp != 0]
339 | dssp = dssp[dssp != 0]
340 | if len(dssp) == 0:
341 | return 0
342 |
343 | idx = np.arange(len(dssp))
344 | seen = []
345 |
346 | for b in blobs:
347 | relpos = pos - b["pos"]
348 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos)
349 | mask = dists**0.5 < thresh
350 | seen.extend(idx[mask])
351 |
352 | if unique:
353 | seen = list(set(seen))
354 | total = sum(b["count"] for b in blobs)
355 | return len(seen) / total
356 |
--------------------------------------------------------------------------------
/proteinblobs/constants.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | PYMOL_COLORS = ['grey',
13 | 'dash',
14 | 'lightteal',
15 | 'br5',
16 | 'hotpink',
17 | 'deepteal',
18 | 'yellow',
19 | 'tv_blue',
20 | 'yelloworange',
21 | 'slate',
22 | 'palecyan',
23 | 'oxygen',
24 | 'salmon',
25 | 'chocolate',
26 | 'gray',
27 | 'orange',
28 | 'blue',
29 | 'brown',
30 | 'green',
31 | 'ruby',
32 | 'forest',
33 | 'br0',
34 | 'black',
35 | 'firebrick',
36 | 'lightmagenta',
37 | 'dirtyviolet',
38 | 'lightpink',
39 | 'sulfur',
40 | 'deeppurple',
41 | 'br3',
42 | 'marine',
43 | 'darksalmon',
44 | 'br7',
45 | 'nitrogen',
46 | 'limon',
47 | 'br6',
48 | 'deepblue',
49 | 'bluewhite',
50 | 'lightorange',
51 | 'palegreen',
52 | 'carbon',
53 | 'lime',
54 | 'teal',
55 | 'tv_green',
56 | 'smudge',
57 | 'wheat',
58 | 'cyan',
59 | 'deepsalmon',
60 | 'splitpea',
61 | 'limegreen',
62 | 'purpleblue',
63 | 'red',
64 | 'tv_red',
65 | 'density',
66 | 'warmpink',
67 | 'deepolive',
68 | 'sand',
69 | 'raspberry',
70 | 'violetpurple',
71 | 'paleyellow',
72 | 'hydrogen',
73 | 'tv_orange',
74 | 'brightorange',
75 | 'br8',
76 | 'chartreuse',
77 | 'lightblue',
78 | 'br4',
79 | 'aquamarine',
80 | 'tv_yellow',
81 | 'br1',
82 | 'greencyan',
83 | 'magenta',
84 | 'pink',
85 | 'br9',
86 | 'olive',
87 | 'purple',
88 | 'white',
89 | 'skyblue',
90 | 'violet',
91 | 'br2']
--------------------------------------------------------------------------------
/proteinblobs/dataset.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import os
12 | import pickle
13 | import pydssp
14 | import math
15 | import numpy as np
16 | import torch
18 | from openfold.np import protein
19 | from torch.utils.data import default_collate
20 | from sklearn.cluster import KMeans
21 | import pandas as pd
22 | from openfold.utils.rigid_utils import Rigid, Rotation
23 | from openfold.np.residue_constants import restype_order, atom_order
24 | from scipy.spatial.transform import Rotation as spRotation
25 | from .blobs import blobify
26 | import proteinblobs.blobs as blobs
27 | import proteinblobs.multiflow.datasets as mfdatasets
28 | from pathlib import Path
29 |
30 |
31 | def from_3_points(
32 | p_neg_x_axis: torch.Tensor,
33 | origin: torch.Tensor,
34 | p_xy_plane: torch.Tensor,
35 | eps: float = 1e-8,
36 | ) -> Rigid:
37 | p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
38 | origin = torch.unbind(origin, dim=-1)
39 | p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
40 |
41 | e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
42 | e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
43 |
44 | the_sum = sum((c * c for c in e0))
45 | the_other_sum = the_sum + eps
46 |
47 |     ###### Workaround to prevent an inexplicable error when taking sqrt with torch.
48 | sum_np = the_other_sum.numpy()
49 | sqrt = np.sqrt(sum_np)
50 | # denom = torch.sqrt(the_other_sum)
51 | denom = torch.from_numpy(sqrt)
52 |
53 | e0 = [c / denom for c in e0]
54 | dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
55 | e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
56 | denom = torch.sqrt(sum((c * c for c in e1)) + eps)
57 | e1 = [c / denom for c in e1]
58 | e2 = [
59 | e0[1] * e1[2] - e0[2] * e1[1],
60 | e0[2] * e1[0] - e0[0] * e1[2],
61 | e0[0] * e1[1] - e0[1] * e1[0],
62 | ]
63 | rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
64 | rots = rots.reshape(rots.shape[:-1] + (3, 3))
65 | rot_obj = Rotation(rot_mats=rots, quats=None)
66 |
67 | return Rigid(rot_obj, torch.stack(origin, dim=-1))
68 |
69 |
70 | def prot_to_frames(ca_coords, c_coords, n_coords):
71 | prot_frames = from_3_points(
72 | torch.from_numpy(c_coords),
73 | torch.from_numpy(ca_coords),
74 | torch.from_numpy(n_coords),
75 | )
76 | rots = torch.eye(3)
77 | rots[0, 0] = -1
78 | rots[2, 2] = -1
79 | rots = Rotation(rot_mats=rots)
80 | frames = prot_frames.compose(Rigid(rots, None))
81 | return frames
82 |
83 |
84 | class SeqCollate:
85 | def __init__(self, args):
86 | self.args = args
87 | if self.args.no_crop or self.args.no_pad:
88 | self.seq_len_keys = [
89 | "grounding_feat",
90 | "grounding_pos",
91 | "grounding_mask",
92 | "chain_idx",
93 | "res_idx",
94 | "seqres",
95 | "res_mask",
96 | "mask",
97 | "rots",
98 | "trans",
99 | ]
100 | else:
101 | self.seq_len_keys = [
102 | "grounding_feat",
103 | "grounding_pos",
104 | "grounding_mask",
105 | ]
106 |
107 | def __call__(self, batch):
108 | seq_len_batch = {}
109 | for key in self.seq_len_keys:
110 | elems = [item[key] for item in batch]
111 | max_L = max([len(elem) for elem in elems])
112 | mask = torch.zeros((len(elems), max_L), dtype=torch.int16)
113 | elem_tensor = []
114 | for i, elem in enumerate(elems):
115 | L = len(elem)
116 | if isinstance(elem, torch.Tensor):
117 | elem = elem.numpy()
118 | elem = np.concatenate(
119 | [elem, np.zeros((max_L - L, *elem.shape[1:]), dtype=elem.dtype)],
120 | axis=0,
121 | )
122 | elem_tensor.append(elem)
123 | mask[i, :L] = 1
124 | seq_len_batch[key] = torch.from_numpy(np.stack(elem_tensor, axis=0))
125 | seq_len_batch[f"{key}_mask"] = mask
126 |
127 | # remove all self.seq_len_keys from batch and put it through default collate
128 | for item in batch:
129 | for key in self.seq_len_keys:
130 | del item[key]
131 |
132 | batch = default_collate(batch)
133 | batch.update(seq_len_batch)
134 | return batch
135 |
136 |
137 | class BlobDataset(torch.utils.data.Dataset):
138 | def __init__(self, blobs):
139 | self.blobs = blobs
140 |
141 | def __len__(self):
142 | return len(self.blobs)
143 |
144 |
145 | class StructureDataset(torch.utils.data.Dataset):
146 | def __init__(self, args):
147 | super().__init__()
148 | self.args = args
149 | if self.args.length_dist_npz is not None:
150 | lens = np.load(self.args.length_dist_npz)["lengths"]
151 | sample_lens = []
152 | for i in range(100000):
153 | choice = 0
154 | while choice < 4 or choice > args.crop:
155 | choice = np.random.choice(lens)
156 | sample_lens.append(choice)
157 | self.length_dist = np.array(sample_lens)
158 | if self.args.use_latents:
159 | latents_names = pickle.load(open(self.args.latents_order, "rb"))
160 | latents_names = [Path(f).stem for f in latents_names]
161 | latents_tensors = torch.load(self.args.latents_path)
162 | self.latents = {
163 | latents_names[i].lower(): latents_tensors[i]
164 | for i in range(len(latents_names))
165 | }
166 |
167 | def __len__(self):
168 | return None
169 |
170 | def __getitem__(self, idx):
171 | return None
172 |
173 | def process_prot(self, idx, name, pdb=None, prot=None):
174 | if prot is None:
175 | prot = protein.from_pdb_string(pdb)
176 | atom37 = prot.atom_positions.astype(np.float32)
177 | frames = prot_to_frames(
178 | ca_coords=prot.atom_positions[:, atom_order["CA"]],
179 | c_coords=prot.atom_positions[:, atom_order["C"]],
180 | n_coords=prot.atom_positions[:, atom_order["N"]],
181 | )
182 | res_mask = np.ones(atom37.shape[0], dtype=np.float32)
183 | seqres = prot.aatype.astype(int)
184 | res_idx = prot.residue_index
185 | chain_idx = np.zeros_like(seqres)
186 | return self.item_from_prot(
187 | name, atom37, frames, res_mask, seqres, res_idx, chain_idx
188 | )
189 |
190 | def item_from_prot(
191 | self, name, atom37, frames, res_mask, seqres, res_idx, chain_idx
192 | ):
193 | mask = np.ones(atom37.shape[0], dtype=np.float32)
194 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... (see atom_types in residue_constants.py)
195 | bb_pos = np.concatenate(
196 | [atom37[:, :3, :], atom37[:, 4:5, :]], axis=1
197 | ) # (L, 4, 3)
198 |
199 | L = frames.shape[0]
200 |         ## filter length and run pydssp secondary structure determination
201 | if L < 8:
202 | print(
203 | f"Only {L} residues in the protein {name}. Resampling another protein."
204 | )
205 | return self.__getitem__(np.random.randint(len(self)))
206 | try:
207 | dssp = pydssp.assign(
208 | bb_pos, out_type="index"
209 | ) # 0: loop, 1: alpha-helix, 2: beta-strand
210 | except Exception as e:
211 | print(
212 | f"Running pydssp failed in the protein {name}. Resampling another protein."
213 | )
214 | print(str(e))
215 | return self.__getitem__(np.random.randint(len(self)))
216 |
217 | ## crop
218 | if (not self.args.no_crop) and L > self.args.crop:
219 | start = np.random.randint(0, L - self.args.crop + 1)
220 | if self.args.overfit:
221 | start = 0
222 | frames = frames[start : start + self.args.crop]
223 | mask = mask[start : start + self.args.crop]
224 | res_mask = res_mask[start : start + self.args.crop]
225 | seqres = seqres[start : start + self.args.crop]
226 | chain_idx = chain_idx[start : start + self.args.crop]
227 | res_idx = res_idx[start : start + self.args.crop]
228 | dssp = dssp[start : start + self.args.crop]
229 |
230 | ## center
231 | com = (frames._trans * mask[:, None]).sum(0) / mask.sum()
232 | frames._trans -= com
233 |
234 | ## rotate
235 | randrot = spRotation.random().as_matrix().astype(np.float32)
236 | randrot = torch.from_numpy(randrot)
237 | frames._trans = frames._trans @ randrot.T
238 | frames._rots._rot_mats = randrot @ frames._rots._rot_mats
239 |
240 | ## label
241 | thresh = (
242 | np.random.rand() * (self.args.max_blob_thresh - self.args.min_blob_thresh)
243 | + self.args.min_blob_thresh
244 | )
245 | blobs = blobify(frames._trans.numpy(), dssp, thresh)
246 |
247 | if not blobs:
248 | print(f"No blobs in the protein {name}. Resampling another protein.")
249 | return self.__getitem__(np.random.randint(len(self)))
250 |
251 | if self.args.blob_drop_prob > 0.0 and len(blobs) > 1:
252 | np.random.shuffle(blobs)
253 | n = np.random.randint(1, len(blobs))
254 | blobs = blobs[:n]
255 |
256 | ## pad
257 | if (not self.args.no_pad) and L < self.args.crop:
258 | pad = self.args.crop - L
259 | frames = Rigid.cat(
260 | [frames, Rigid.identity((pad,), requires_grad=False, fmt="rot_mat")], 0
261 | )
262 |
263 | mask = np.concatenate([mask, np.zeros(pad, dtype=np.float32)])
264 | res_mask = np.concatenate([res_mask, np.zeros(pad, dtype=np.float32)])
265 | seqres = np.concatenate([seqres, np.zeros(pad, dtype=int)])
266 | res_idx = np.concatenate([res_idx, np.zeros(pad, dtype=int)])
267 | chain_idx = np.concatenate([chain_idx, np.zeros(pad, dtype=int)])
268 |
269 | ## featurize
270 | if self.args.synthetic_blobs:
271 | grounding_pos, grounding_feat, grounding_mask = self.get_synthetic_blobs()
272 | prot_size = int(grounding_feat[:, 1].sum())
273 | mask[prot_size:] = 0
274 | res_mask[prot_size:] = 0
275 | mask[:prot_size] = 1
276 | res_mask[:prot_size] = 1
277 | else:
278 | grounding_pos = []
279 | grounding_feat = []
280 | grounding_covar = []
281 | for blob in blobs:
282 | grounding_pos.append(blob["pos"])
283 | grounding_feat.append((blob["dssp"], blob["count"]))
284 | grounding_covar.append(blob["covar"].flatten())
285 | grounding_pos = np.array(grounding_pos).astype(np.float32)
286 | grounding_feat = np.array(grounding_feat)
287 | grounding_covar = np.array(grounding_covar).astype(np.float32)
288 | grounding_feat = np.concatenate(
289 | [grounding_feat, grounding_covar], axis=-1
290 | ).astype(np.float32)
291 | grounding_mask = np.ones_like(grounding_feat[:, 0])
292 |
293 | if self.args.fixed_inference_size is not None:
294 | assert self.args.fixed_inference_size <= self.args.crop
295 | mask[self.args.fixed_inference_size :] = 0
296 | res_mask[self.args.fixed_inference_size :] = 0
297 | mask[: self.args.fixed_inference_size] = 1
298 | res_mask[: self.args.fixed_inference_size] = 1
299 |
300 | if self.args.length_dist_npz is not None:
301 | prot_size = np.random.choice(self.length_dist)
302 | mask[prot_size:] = 0
303 | res_mask[prot_size:] = 0
304 | mask[:prot_size] = 1
305 | res_mask[:prot_size] = 1
306 |
307 | if self.args.use_latents:
308 | try:
309 | latents = self.latents[name.lower()]
310 | except Exception as e:
311 | print(f"key error for {name.lower()}")
312 | return self.__getitem__(np.random.randint(len(self)))
313 | else:
314 | latents = 0
315 |
316 | return {
317 | "name": name,
318 | "grounding_pos": grounding_pos,
319 | "grounding_feat": grounding_feat,
320 | "grounding_mask": grounding_mask,
321 | "trans": frames._trans,
322 | "rots": frames._rots._rot_mats,
323 | "mask": mask,
324 | "res_mask": res_mask,
325 | "seqres": seqres,
326 | "res_idx": res_idx,
327 | "chain_idx": chain_idx,
328 | "blobs": pickle.dumps(blobs),
329 | "thresh": thresh,
330 | "latents": latents,
331 | }
332 |
333 | def get_synthetic_blobs(self):
334 | pos, covar = blobs.sample_blobs(
335 | self.args.num_blobs,
336 | nu=self.args.nu,
337 | psi=(1 / self.args.nu) * self.args.psi**2 * np.eye(3),
338 | sigma=self.args.sigma,
339 | )
340 | is_helix = np.random.rand(self.args.num_blobs) < self.args.helix_frac
341 | volume = np.linalg.det(covar) ** 0.5
342 |
343 | counts = np.where(
344 | is_helix,
345 | blobs.alpha_slope * volume + blobs.alpha_intercept,
346 | blobs.beta_slope * volume + blobs.beta_intercept,
347 | ).astype(int)
348 |
349 | dssp = np.where(is_helix, 1, 2)
350 |
351 | grounding_pos = pos[None].astype(np.float32)
352 | grounding_feat = np.zeros((1, self.args.num_blobs, 11), dtype=np.float32)
353 |         grounding_feat[:, :, 0] = dssp  # 1: alpha-helix, 2: beta-strand
354 | grounding_feat[:, :, 1] = counts
355 | grounding_feat[:, :, 2:] = covar.reshape(1, self.args.num_blobs, 9)
356 | grounding_mask = np.ones(self.args.num_blobs, dtype=np.float32)
357 | return grounding_pos[0], grounding_feat[0], grounding_mask
358 |
359 |
360 | class MultiflowDataset(StructureDataset):
361 | def __init__(self, args, dataset_cfg):
362 | super().__init__(args)
363 | self.args = args
364 | self.mf_dataset = mfdatasets.PdbDataset(
365 | dataset_cfg=dataset_cfg, task="hallucination", is_training=True
366 | )
367 |
368 | def __len__(self):
369 | if self.args.dataset_multiplicity is not None:
370 | return len(self.mf_dataset.csv) * self.args.dataset_multiplicity
371 | else:
372 | return len(self.mf_dataset.csv)
373 |
374 | def __getitem__(self, idx):
375 | if self.args.overfit:
376 | idx = 0
377 | if self.args.dataset_multiplicity is not None:
378 |             idx = idx % len(self.mf_dataset.csv)  # map the enlarged index range back onto the underlying csv
379 | row = self.mf_dataset.csv.iloc[idx]
380 | name = row["pdb_name"]
381 | path = f"{self.args.pkl_dir}/{row['processed_path'].replace('./', '')}"
382 |
383 | # res_idx, chain_idx, res_mask, trans_1, rotmats_1, aatypes_1, res_plddt
384 | mf_prot = mfdatasets._process_csv_row(path)
385 |
386 | atom37 = mf_prot["all_atom_positions"].float().numpy()
387 | # atom37_to_pdb(atom37[None], f'../protfid/data/mf_testset/{name}.pdb')
388 | frames = Rigid(
389 | Rotation(rot_mats=mf_prot["rotmats_1"].float(), quats=None),
390 | mf_prot["trans_1"].float(),
391 | )
392 |
393 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... (see atom_types in residue_constants.py)
394 | bb_pos = np.concatenate(
395 | [atom37[:, :3, :], atom37[:, 4:5, :]], axis=1
396 | ) # (L, 4, 3)
397 |
398 | L = frames.shape[0]
399 | res_mask = mf_prot["res_mask"].float().numpy()
400 | seqres = mf_prot["aatypes_1"].numpy()
401 | res_idx = mf_prot["res_idx"]
402 | chain_idx = mf_prot["chain_idx"]
403 | return self.item_from_prot(
404 | name, atom37, frames, res_mask, seqres, res_idx, chain_idx
405 | )
406 |
407 |
408 | class GenieDBDataset(StructureDataset):
409 | def __init__(self, args):
410 | super().__init__(args)
411 | self.args = args
412 | files = os.listdir(self.args.genie_db_path)
413 | self.db = [file for file in files if ".pdb" in file]
414 | self.db = self.db * args.repeat_dataset
415 |
416 | def __len__(self):
417 | return len(self.db)
418 |
419 | def __getitem__(self, idx):
420 | if self.args.overfit:
421 | idx = 0
422 | name = self.db[idx]
423 | with open(f"{self.args.genie_db_path}/{name}", "r") as f:
424 | pdb = f.read()
425 | try:
426 | sample = self.process_prot(idx, name, pdb)
427 | except Exception as e:
428 |             # for some reason the name printing here does not work, so don't rely on it; we instead write the name to a file.
429 | print("name", name, flush=True)
430 | with open(os.path.join(os.environ["MODEL_DIR"], "debug.txt"), "w") as f:
431 | f.write(name)
432 | raise e
433 | return sample
434 |
435 |
436 | class SCOPDataset(torch.utils.data.Dataset):
437 | def __init__(self, args):
438 | super().__init__()
439 | self.args = args
440 | names = os.listdir(args.scop_dir)
441 | nmr_names = pd.read_csv(args.scop_nmr_csv)
442 | self.names = [
443 | name for name in names if name not in nmr_names["nmr_names"].tolist()
444 | ]
445 |
446 | def __len__(self):
447 | return len(self.names)
448 |
449 | def __getitem__(self, idx):
450 | if self.args.overfit:
451 | idx = 0
452 |
453 | name = self.names[idx]
454 | pdb = open(f"{self.args.scop_dir}/{name}", "r").read()
455 | prot = protein.from_pdb_string(pdb)
456 |
457 | atom37 = prot.atom_positions.astype(np.float32)
458 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... (see atom_types in residue_constants.py)
459 | bb_pos = np.concatenate(
460 | [atom37[:, :3, :], atom37[:, 4:5, :]], axis=1
461 | ) # (L, 4, 3)
462 | dssp = pydssp.assign(
463 | bb_pos, out_type="index"
464 | ) # 0: loop, 1: alpha-helix, 2: beta-strand
465 |
466 | pos = atom37[:, 1, :]
467 | L = atom37.shape[0]
468 | mask = np.ones(L, dtype=np.float32)
469 | seqres = prot.aatype.astype(int)
470 | resid = prot.residue_index.astype(int)
471 |
472 | cm = (pos * mask[:, None]).sum(0, keepdims=True) / mask.sum()
473 | pos = pos - cm
474 | if not self.args.overfit_rot:
475 | rot = spRotation.random().as_matrix().astype(np.float32)
476 | pos = pos @ rot.T
477 |
478 | n_clusters = math.ceil(L / self.args.res_per_cluster)
479 | kmeans = KMeans(n_clusters=n_clusters, n_init=1).fit(bb_pos[:, 1])
480 | labels = kmeans.labels_
481 | centers = kmeans.cluster_centers_
482 | dssp_onehot = torch.nn.functional.one_hot(torch.from_numpy(dssp), num_classes=3)
483 |
484 | dssp_count = torch.zeros(len(centers), 3, dtype=torch.long).scatter_add_(
485 | dim=0,
486 | index=torch.from_numpy(labels).long()[:, None].expand(-1, 3),
487 | src=dssp_onehot,
488 | )
489 | dssp_dist = dssp_count / dssp_count.sum(1)[:, None]
490 |
491 | return {
492 | "prot": prot,
493 | "labels": labels,
494 | "centers": centers,
495 | "dssp_dist": dssp_dist,
496 | "resid": resid,
497 | "name": name,
498 | "dssp": dssp,
499 | "bb_pos": pos,
500 | "mask": mask,
501 | "seqres": seqres,
502 | }
503 |
--------------------------------------------------------------------------------
/proteinblobs/designability_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import subprocess
12 |
13 | import esm, tqdm, torch
14 | import os
15 | from biotite.sequence.io import fasta
16 | import numpy as np
17 |
18 | from proteinblobs.utils import atom37_to_pdb
19 | from openfold.np import residue_constants
20 |
21 |
22 |
23 | def run_designability(
24 | prots, pmpnn_path, process_rank=0, seqs_per_struct=8, sequences=None
25 | ):
26 |     # prots is a list of np arrays of shape (L, 37, 3) (atom37 coordinates)
27 |
28 | print("Loading ESMFold model for designability evaluation")
29 | torch.cuda.empty_cache()
30 | esmf_model = esm.pretrained.esmfold_v1().eval()
31 | esmf_model = esmf_model.to("cuda")
32 |
33 | all_tm_scores = []
34 | all_rmsds = []
35 | top_rmsds = []
36 | top_tm_scores = []
37 | for i, prot in tqdm.tqdm(enumerate(prots), desc="Running PMPNN and ESMFold"):
38 |         sample_dir = os.path.join(
39 | os.environ["MODEL_DIR"],
40 | f"tmp_design_dir_process{process_rank}",
41 | f"sample{i}_process{process_rank}",
42 | )
43 | os.makedirs(sample_dir, exist_ok=True)
44 | sample_path = os.path.join(sample_dir, f"sample.pdb")
45 |
46 | rmsds = []
47 | #tm_scores = []
48 | if sequences is not None:
49 | restypes = residue_constants.restypes + ["X"]
50 | seqs = ["".join(map(lambda x: restypes[x], sequences[i]))]
51 |
52 | else:
53 | atom37_to_pdb(prot[None], sample_path)
54 | run_pmpnn(
55 | pdb_dir=sample_dir, num_seqs=seqs_per_struct, pmpnn_path=pmpnn_path
56 | )
57 | mpnn_fasta_path = os.path.join(
58 | sample_dir, "seqs", os.path.basename(sample_path).replace(".pdb", ".fa")
59 | )
60 | fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path)
61 |
62 | seqs = [v for k, v in fasta_seqs.items()]
63 |             # Drop the first FASTA entry: it is the input sequence itself (a poly-A
64 |             # placeholder when the structure was written without a sequence).
65 |             seqs = seqs[1:]
66 |
67 | for j, seq in enumerate(seqs):
68 | seq = seq.replace("X", "A")
69 |
70 | with torch.no_grad():
71 | output = esmf_model.infer(seq)
72 |
73 | out_ca_pos = output["positions"][-1].squeeze()[:, 2].cpu().numpy()
74 | #_, tm_score = get_tm_score(prot[:, 1], out_ca_pos, seq, seq)
75 | rmsd = get_aligned_rmsd(prot[:, 1], out_ca_pos)
76 | rmsds.append(rmsd)
77 | #tm_scores.append(tm_score)
78 | all_rmsds.append(np.array(rmsds).mean())
79 | #all_tm_scores.append(np.array(tm_scores).mean())
80 | top_rmsds.append(np.array(rmsds).min())
81 | #top_tm_scores.append(np.array(tm_scores).max())
82 | del esmf_model
83 | torch.cuda.empty_cache()
84 | #all_tm_scores = np.array(all_tm_scores)
85 | all_rmsds = np.array(all_rmsds)
86 | top_rmsds = np.array(top_rmsds)
87 | #top_tm_scores = np.array(top_tm_scores)
88 |
89 | return {
90 | #"tm_score": all_tm_scores,
91 | "rmsd": all_rmsds,
92 | #"top_tm_score": top_tm_scores,
93 | "top_rmsd": top_rmsds,
94 | }
95 |
96 |
97 | def run_pmpnn(pdb_dir, num_seqs=8, pmpnn_path="../ProteinMPNN", ca_only=False):
98 | os.makedirs(os.path.join(pdb_dir, "seqs"), exist_ok=True)
99 | parsed_chains_path = os.path.join(pdb_dir, "seqs", "parsed_chains.jsonl")
100 |
101 | process = subprocess.Popen(
102 | [
103 | "python",
104 | os.path.join(pmpnn_path, "helper_scripts/parse_multiple_chains.py"),
105 | f"--input_path={pdb_dir}",
106 | f"--output_path={parsed_chains_path}",
107 | ]
108 | )
109 | _ = process.wait()
110 |
111 | pmpnn_args = [
112 | "python",
113 | os.path.join(pmpnn_path, "protein_mpnn_run.py"),
114 | "--out_folder",
115 | pdb_dir,
116 | "--jsonl_path",
117 | parsed_chains_path,
118 | "--num_seq_per_target",
119 | str(num_seqs),
120 | "--sampling_temp",
121 | "0.1",
122 | "--seed",
123 | "38",
124 | "--batch_size",
125 | "1",
126 | ]
127 | if ca_only:
128 | pmpnn_args.append("--ca_only")
129 | print(" ".join(pmpnn_args))
130 |
131 | process = subprocess.run(pmpnn_args)
132 |
133 |
134 | def get_aligned_rmsd(pos_1, pos_2):
135 | aligned_pos_1 = rigid_transform_3D(pos_1, pos_2)[0]
136 | return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1))
137 |
138 |
139 | def get_tm_score(pos_1, pos_2, seq_1, seq_2):
140 |     from tmtools import tm_align  # assumes the external `tmtools` package; imported lazily since the TM-score calls above are commented out
141 |     tm_results = tm_align(pos_1, pos_2, seq_1, seq_2)
142 |     return tm_results.tm_norm_chain1, tm_results.tm_norm_chain2
143 | 
144 | def rigid_transform_3D(A, B, verbose=False):
145 | # Transforms A to look like B
146 | # https://github.com/nghiaho12/rigid_transform_3D
147 | assert A.shape == B.shape
148 | A = A.T
149 | B = B.T
150 |
151 | num_rows, num_cols = A.shape
152 | if num_rows != 3:
153 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
154 |
155 | num_rows, num_cols = B.shape
156 | if num_rows != 3:
157 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
158 |
159 | # find mean column wise
160 | centroid_A = np.mean(A, axis=1)
161 | centroid_B = np.mean(B, axis=1)
162 |
163 | # ensure centroids are 3x1
164 | centroid_A = centroid_A.reshape(-1, 1)
165 | centroid_B = centroid_B.reshape(-1, 1)
166 |
167 | # subtract mean
168 | Am = A - centroid_A
169 | Bm = B - centroid_B
170 |
171 | H = Am @ np.transpose(Bm)
172 |
173 | # sanity check
174 | # if linalg.matrix_rank(H) < 3:
175 | # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H)))
176 |
177 | # find rotation
178 | U, S, Vt = np.linalg.svd(H)
179 | R = Vt.T @ U.T
180 |
181 | # special reflection case
182 | reflection_detected = False
183 | if np.linalg.det(R) < 0:
184 | if verbose:
185 |             print("det(R) < 0, reflection detected! Correcting for it ...")
186 | Vt[2, :] *= -1
187 | R = Vt.T @ U.T
188 | reflection_detected = True
189 |
190 | t = -R @ centroid_A + centroid_B
191 | optimal_A = R @ A + t
192 |
193 | return optimal_A.T, R, t, reflection_detected
194 |
--------------------------------------------------------------------------------
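A minimal, self-contained sketch (not part of the repository) of how the alignment helpers in designability_utils.py behave, assuming the module is importable as below: the aligned RMSD of a rigidly transformed copy of a structure should be numerically zero.

import numpy as np
from proteinblobs.designability_utils import get_aligned_rmsd, rigid_transform_3D

rng = np.random.default_rng(0)
ca_ref = rng.normal(size=(64, 3)).astype(np.float32)      # fake C-alpha coordinates

# Rotate and translate the reference; the Kabsch alignment should undo this exactly.
theta = 0.3
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta),  np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]], dtype=np.float32)
ca_moved = ca_ref @ rot.T + np.array([1.0, -2.0, 0.5], dtype=np.float32)

aligned, R, t, reflected = rigid_transform_3D(ca_ref, ca_moved)
print(get_aligned_rmsd(ca_ref, ca_moved))                  # ~0 up to float error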
/proteinblobs/logger.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | import yaml, logging, socket, os, sys
13 | model_dir = os.environ.get("MODEL_DIR", "./workdir/default")
14 | def get_logger(name):
15 | logger = logging.Logger(name)
16 |     level = {"critical": 50, "error": 40, "warning": 30, "info": 20, "debug": 10}[
17 | os.environ.get("LOGGER_LEVEL", "info")
18 | ]
19 | logger.setLevel(level)
20 |
21 | ch = logging.StreamHandler()
22 | ch.setLevel(logging.INFO)
23 | os.makedirs(model_dir, exist_ok=True)
24 | fh = logging.FileHandler(os.path.join(model_dir, "log.out"))
25 | fh.setLevel(logging.DEBUG)
26 | formatter = logging.Formatter(
27 | f"%(asctime)s [{socket.gethostname()}:%(process)d] [%(levelname)s] %(message)s"
28 | )
29 | ch.setFormatter(formatter)
30 | fh.setFormatter(formatter)
31 | logger.addHandler(ch)
32 | logger.addHandler(fh)
33 | return logger
34 |
35 |
--------------------------------------------------------------------------------
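A brief usage sketch for get_logger above: INFO-and-above records go to stderr, DEBUG-and-above to MODEL_DIR/log.out, gated by the logger level read from LOGGER_LEVEL.

from proteinblobs.logger import get_logger

logger = get_logger(__name__)
logger.info("starting run")     # goes to stderr and to MODEL_DIR/log.out
logger.debug("verbose detail")  # dropped at the default level; with LOGGER_LEVEL=debug it reaches the log file only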
/proteinblobs/multiflow/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Andrew Campbell, Jason Yim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/proteinblobs/multiflow/cross_attention.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import torch.nn as nn
12 |
13 | import torch
14 | import math
15 |
16 | def positional_encoding(x, channels, max_period=10000, min_period=None):
17 | min_period = min_period or max_period / 10000 # second one recommended
18 | freqs = torch.exp(-torch.linspace(math.log(min_period), math.log(max_period), channels // 2, device=x.device))
19 | emb = freqs * x.unsqueeze(-1) # [..., C]
20 | return torch.cat([torch.sin(emb), torch.cos(emb)], -1) # [..., 2C]
21 |
22 |
23 |
24 | class InvariantCrossAttention(nn.Module):
25 | def __init__(self, cfg):
26 | super().__init__()
27 | self.cfg = cfg
28 | pos_emb_dim = 64
29 | self.linear_pos = nn.Linear(3 * pos_emb_dim, cfg.c_s)
30 | self.linear_kv = nn.Linear(cfg.c_s, 2 * cfg.c_s)
31 | self.linear_covar = nn.Linear(9, cfg.c_s)
32 |
33 | self.linear_q = nn.Linear(cfg.c_s, cfg.c_s)
34 | self.linear_out = nn.Linear(cfg.c_s, cfg.c_s)
35 | self.gamma = nn.Parameter(torch.zeros(1))
36 | self.beta = nn.Parameter(torch.ones(1))
37 | self.beta.requires_grad_(False)
38 | self.inf = 1e10
39 |
40 | def forward(self, s, r, x, y, c, mask):
41 | '''
42 | s: frame features
43 | r: frames
44 | x: point positions
45 | y: point features
46 | c: blob covariance
47 | mask: point mask
48 | '''
49 |
50 | # [B, L, N, 3]
51 | relpos = r[:,:,None].invert_apply(x[:,None])
52 |
53 | # [B, L, 3, 3]
54 | R = r._rots.get_rot_mats()
55 |
56 | # [B, L, N, 3, 3]
57 | covar = torch.einsum('blij,bnjk,blkm->blnim', R.transpose(-1, -2), c, R)
58 |
59 | # [B, L, N, C]
60 | relpos = positional_encoding(relpos, 64, max_period=1000) # 1000 nm???
61 | B, L, N, _, _ = relpos.shape
62 |
63 | # [B, L, N, C]
64 | y = y[:,None] + self.linear_pos(relpos.view(B, L, N, 3*64)) + self.linear_covar(covar.view(B, L, N, 9))
65 | B, L, N, C = y.shape
66 |
67 | # [B, L, N, C]
68 | k, v = self.linear_kv(y).chunk(2, dim=-1)
69 |
70 | # [B, L, C]
71 | q = self.linear_q(s)
72 |
73 | H = self.cfg.no_heads
74 | k = k.view(B, L, N, H, C//H)
75 | v = v.view(B, L, N, H, C//H)
76 | q = q.view(B, L, 1, H, C//H)
77 |
78 | # [B, L, N, H]
79 | attn = torch.einsum('BLNHC,BLNHC->BLNH', q, k) / math.sqrt(C//H)
80 | attn = attn + (mask[:,None,:,None] - 1) * self.inf
81 | attn = torch.softmax(attn, dim=-2)
82 |
83 | # [B, L, C]
84 | output = torch.einsum('BLNH,BLNHC->BLHC', attn, v).reshape(B, L, C)
85 | output = self.linear_out(output)
86 |
87 | output = self.beta * torch.tanh(self.gamma) * output
88 |
89 | return output
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
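A quick shape check (a sketch, not part of the model code) for positional_encoding above: each scalar input is expanded into `channels` sin/cos features.

import torch
from proteinblobs.multiflow.cross_attention import positional_encoding

x = torch.linspace(0.0, 50.0, steps=128)            # e.g. blob sizes or distances, shape [128]
emb = positional_encoding(x, channels=64, max_period=1000)
print(emb.shape)                                     # torch.Size([128, 64]): 32 sine + 32 cosine features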
/proteinblobs/multiflow/data/all_atom.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | # Copyright 2021 AlQuraishi Laboratory
13 | # Copyright 2021 DeepMind Technologies Limited
14 | #
15 | # Licensed under the Apache License, Version 2.0 (the "License");
16 | # you may not use this file except in compliance with the License.
17 | # You may obtain a copy of the License at
18 | #
19 | # http://www.apache.org/licenses/LICENSE-2.0
20 | #
21 | # Unless required by applicable law or agreed to in writing, software
22 | # distributed under the License is distributed on an "AS IS" BASIS,
23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24 | # See the License for the specific language governing permissions and
25 | # limitations under the License.
26 |
27 | """Utilities for calculating all atom representations."""
28 |
29 | import torch
30 | from openfold.data import data_transforms
31 | from openfold.np import residue_constants
32 | from .. import rigid_utils as ru
33 | from . import utils as du
34 |
35 | Rigid = ru.Rigid
36 | Rotation = ru.Rotation
37 |
38 | # Residue Constants from OpenFold/AlphaFold2.
39 |
40 |
41 | IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions)
42 | DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame)
43 | ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask)
44 | GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group)
45 |
46 |
47 | def to_atom37(trans, rots):
48 | num_batch, num_res, _ = trans.shape
49 | final_atom37 = compute_backbone(
50 | du.create_rigid(rots, trans),
51 | torch.zeros(num_batch, num_res, 2, device=trans.device),
52 | )[0]
53 | return final_atom37
54 |
55 |
56 | def torsion_angles_to_frames(
57 | r: Rigid, # type: ignore [valid-type]
58 | alpha: torch.Tensor,
59 | aatype: torch.Tensor,
60 | ):
61 | """Conversion method of torsion angles to frames provided the backbone.
62 |
63 | Args:
64 | r: Backbone rigid groups.
65 | alpha: Torsion angles.
66 | aatype: residue types.
67 |
68 | Returns:
69 | All 8 frames corresponding to each torsion frame.
70 |
71 | """
72 | # [*, N, 8, 4, 4]
73 | with torch.no_grad():
74 | default_4x4 = DEFAULT_FRAMES.to(aatype.device)[aatype, ...] # type: ignore [attr-defined]
75 |
76 | # [*, N, 8] transformations, i.e.
77 | # One [*, N, 8, 3, 3] rotation matrix and
78 | # One [*, N, 8, 3] translation matrix
79 | default_r = r.from_tensor_4x4(default_4x4) # type: ignore [attr-defined]
80 |
81 | bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
82 | bb_rot[..., 1] = 1
83 |
84 | # [*, N, 8, 2]
85 | alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
86 |
87 | # [*, N, 8, 3, 3]
88 | # Produces rotation matrices of the form:
89 | # [
90 | # [1, 0 , 0 ],
91 | # [0, a_2,-a_1],
92 | # [0, a_1, a_2]
93 | # ]
94 | # This follows the original code rather than the supplement, which uses
95 | # different indices.
96 |
97 | all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
98 | all_rots[..., 0, 0] = 1
99 | all_rots[..., 1, 1] = alpha[..., 1]
100 | all_rots[..., 1, 2] = -alpha[..., 0]
101 | all_rots[..., 2, 1:] = alpha
102 |
103 | all_rots = Rigid(Rotation(rot_mats=all_rots), None)
104 |
105 | all_frames = default_r.compose(all_rots)
106 |
107 | chi2_frame_to_frame = all_frames[..., 5]
108 | chi3_frame_to_frame = all_frames[..., 6]
109 | chi4_frame_to_frame = all_frames[..., 7]
110 |
111 | chi1_frame_to_bb = all_frames[..., 4]
112 | chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
113 | chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
114 | chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
115 |
116 | all_frames_to_bb = Rigid.cat(
117 | [
118 | all_frames[..., :5],
119 | chi2_frame_to_bb.unsqueeze(-1),
120 | chi3_frame_to_bb.unsqueeze(-1),
121 | chi4_frame_to_bb.unsqueeze(-1),
122 | ],
123 | dim=-1,
124 | )
125 |
126 | all_frames_to_global = r[..., None].compose(all_frames_to_bb) # type: ignore [index]
127 |
128 | return all_frames_to_global
129 |
130 |
131 | def prot_to_torsion_angles(aatype, atom37, atom37_mask):
132 | """Calculate torsion angle features from protein features."""
133 | prot_feats = {
134 | "aatype": aatype,
135 | "all_atom_positions": atom37,
136 | "all_atom_mask": atom37_mask,
137 | }
138 | torsion_angles_feats = data_transforms.atom37_to_torsion_angles()(prot_feats)
139 | torsion_angles = torsion_angles_feats["torsion_angles_sin_cos"]
140 | torsion_mask = torsion_angles_feats["torsion_angles_mask"]
141 | return torsion_angles, torsion_mask
142 |
143 |
144 | def frames_to_atom14_pos(
145 | r: Rigid, # type: ignore [valid-type]
146 | aatype: torch.Tensor,
147 | ):
148 | """Convert frames to their idealized all atom representation.
149 |
150 | Args:
151 | r: All rigid groups. [..., N, 8, 3]
152 | aatype: Residue types. [..., N]
153 |
154 | Returns:
155 |
156 | """
157 | with torch.no_grad():
158 | group_mask = GROUP_IDX.to(aatype.device)[aatype, ...]
159 | group_mask = torch.nn.functional.one_hot(
160 | group_mask,
161 | num_classes=DEFAULT_FRAMES.shape[-3],
162 | )
163 | frame_atom_mask = ATOM_MASK.to(aatype.device)[aatype, ...].unsqueeze(-1) # type: ignore [attr-defined]
164 | frame_null_pos = IDEALIZED_POS.to(aatype.device)[aatype, ...] # type: ignore [attr-defined]
165 |
166 | # [*, N, 14, 8]
167 | t_atoms_to_global = r[..., None, :] * group_mask # type: ignore [index]
168 |
169 | # [*, N, 14]
170 | t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
171 |
172 | # [*, N, 14, 3]
173 | pred_positions = t_atoms_to_global.apply(frame_null_pos)
174 | pred_positions = pred_positions * frame_atom_mask
175 |
176 | return pred_positions
177 |
178 |
179 | def compute_backbone(bb_rigids, psi_torsions):
180 | torsion_angles = torch.tile(
181 | psi_torsions[..., None, :],
182 | tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1),
183 | )
184 | aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long()
185 | # aatype = torch.zeros(bb_rigids.shape).long().to(bb_rigids.device)
186 | all_frames = torsion_angles_to_frames(
187 | bb_rigids,
188 | torsion_angles,
189 | aatype,
190 | )
191 | atom14_pos = frames_to_atom14_pos(all_frames, aatype)
192 | atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=bb_rigids.device)
193 | # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
194 | # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
195 | atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
196 | atom37_bb_pos[..., 3, :] = atom14_pos[..., 4, :]
197 | atom37_bb_pos[..., 4, :] = atom14_pos[..., 3, :]
198 | atom37_mask = torch.any(atom37_bb_pos, axis=-1)
199 | return atom37_bb_pos, atom37_mask, aatype, atom14_pos
200 |
201 |
202 | def calculate_neighbor_angles(R_ac, R_ab):
203 | """Calculate angles between atoms c <- a -> b.
204 |
205 | Parameters
206 | ----------
207 | R_ac: Tensor, shape = (N,3)
208 | Vector from atom a to c.
209 | R_ab: Tensor, shape = (N,3)
210 | Vector from atom a to b.
211 |
212 | Returns
213 | -------
214 | angle_cab: Tensor, shape = (N,)
215 | Angle between atoms c <- a -> b.
216 | """
217 | # cos(alpha) = (u * v) / (|u|*|v|)
218 | x = torch.sum(R_ac * R_ab, dim=1) # shape = (N,)
219 | # sin(alpha) = |u x v| / (|u|*|v|)
220 | y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,)
221 | # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
222 | y = torch.max(y, torch.tensor(1e-9))
223 | angle = torch.atan2(y, x)
224 | return angle
225 |
226 |
227 | def vector_projection(R_ab, P_n):
228 | """
229 | Project the vector R_ab onto a plane with normal vector P_n.
230 |
231 | Parameters
232 | ----------
233 | R_ab: Tensor, shape = (N,3)
234 | Vector from atom a to b.
235 | P_n: Tensor, shape = (N,3)
236 | Normal vector of a plane onto which to project R_ab.
237 |
238 | Returns
239 | -------
240 | R_ab_proj: Tensor, shape = (N,3)
241 | Projected vector (orthogonal to P_n).
242 | """
243 | a_x_b = torch.sum(R_ab * P_n, dim=-1)
244 | b_x_b = torch.sum(P_n * P_n, dim=-1)
245 | return R_ab - (a_x_b / b_x_b)[:, None] * P_n
246 |
247 |
248 | def transrot_traj_to_atom37(transrot_traj, res_mask):
249 | atom37_traj = []
250 | for trans, rots in transrot_traj:
251 | atom37_traj.append(atom37_from_trans_rot(trans, rots, res_mask))
252 | return atom37_traj
253 |
254 |
255 | def atom37_from_trans_rot(trans, rots, res_mask=None):
256 | if res_mask is None:
257 | res_mask = torch.ones([*trans.shape[:-1]], device=trans.device)
258 | rigids = du.create_rigid(rots, trans)
259 | atom37 = compute_backbone(
260 | rigids, torch.zeros(trans.shape[0], trans.shape[1], 2, device=trans.device)
261 | )[0]
262 | batch_atom37 = []
263 | num_batch = res_mask.shape[0]
264 | for i in range(num_batch):
265 | batch_atom37.append(du.adjust_oxygen_pos(atom37[i], res_mask[i]))
266 | return torch.stack(batch_atom37)
267 |
268 | def process_trans_rot_traj(trans_traj, rots_traj, res_mask):
269 | res_mask = res_mask.detach().cpu()
270 | atom37_traj = [
271 | atom37_from_trans_rot(trans, rots, res_mask)
272 | for trans, rots in zip(trans_traj, rots_traj)
273 | ]
274 | atom37_traj = torch.stack(atom37_traj).swapaxes(0, 1)
275 | return atom37_traj
276 |
--------------------------------------------------------------------------------
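A small numeric sketch for the geometry helpers at the end of all_atom.py (assuming the import path below): orthogonal vectors give an angle of pi/2, and projecting onto the plane with normal z zeroes the z-component.

import torch
from proteinblobs.multiflow.data.all_atom import calculate_neighbor_angles, vector_projection

r_ac = torch.tensor([[1.0, 0.0, 0.0]])
r_ab = torch.tensor([[0.0, 1.0, 0.0]])
print(calculate_neighbor_angles(r_ac, r_ab))     # tensor([1.5708]) = pi/2

normal = torch.tensor([[0.0, 0.0, 1.0]])
print(vector_projection(torch.tensor([[1.0, 2.0, 3.0]]), normal))   # tensor([[1., 2., 0.]])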
/proteinblobs/multiflow/data/protein.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Protein data type."""
16 | import dataclasses
17 | import io
18 | from typing import Any, Mapping, Optional
19 | from . import residue_constants
20 | from Bio.PDB import PDBParser
21 | import numpy as np
22 |
23 | FeatureDict = Mapping[str, np.ndarray]
24 | ModelOutput = Mapping[str, Any] # Is a nested dict.
25 |
26 | # Complete sequence of chain IDs supported by the PDB format.
27 | PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
28 | PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
29 |
30 |
31 | @dataclasses.dataclass(frozen=True)
32 | class Protein:
33 | """Protein structure representation."""
34 |
35 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to
36 |     # residue_constants.atom_types, i.e. the first three are N, CA, C.
37 | atom_positions: np.ndarray # [num_res, num_atom_type, 3]
38 |
39 | # Amino-acid type for each residue represented as an integer between 0 and
40 | # 20, where 20 is 'X'.
41 | aatype: np.ndarray # [num_res]
42 |
43 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
44 | # is present and 0.0 if not. This should be used for loss masking.
45 | atom_mask: np.ndarray # [num_res, num_atom_type]
46 |
47 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
48 | residue_index: np.ndarray # [num_res]
49 |
50 | # 0-indexed number corresponding to the chain in the protein that this residue
51 | # belongs to.
52 | chain_index: np.ndarray # [num_res]
53 |
54 | # B-factors, or temperature factors, of each residue (in sq. angstroms units),
55 | # representing the displacement of the residue from its ground truth mean
56 | # value.
57 | b_factors: np.ndarray # [num_res, num_atom_type]
58 |
59 | def __post_init__(self):
60 | if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
61 | raise ValueError(
62 | f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
63 | 'because these cannot be written to PDB format.')
64 |
65 |
66 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
67 | """Takes a PDB string and constructs a Protein object.
68 |
69 | WARNING: All non-standard residue types will be converted into UNK. All
70 | non-standard atoms will be ignored.
71 |
72 | Args:
73 | pdb_str: The contents of the pdb file
74 | chain_id: If chain_id is specified (e.g. A), then only that chain
75 | is parsed. Otherwise all chains are parsed.
76 |
77 | Returns:
78 | A new `Protein` parsed from the pdb contents.
79 | """
80 | pdb_fh = io.StringIO(pdb_str)
81 | parser = PDBParser(QUIET=True)
82 | structure = parser.get_structure('none', pdb_fh)
83 | models = list(structure.get_models())
84 | if len(models) != 1:
85 | raise ValueError(
86 | f'Only single model PDBs are supported. Found {len(models)} models.')
87 | model = models[0]
88 |
89 | atom_positions = []
90 | aatype = []
91 | atom_mask = []
92 | residue_index = []
93 | chain_ids = []
94 | b_factors = []
95 |
96 | for chain in model:
97 | if chain_id is not None and chain.id != chain_id:
98 | continue
99 | for res in chain:
100 | if res.id[2] != ' ':
101 | raise ValueError(
102 | f'PDB contains an insertion code at chain {chain.id} and residue '
103 | f'index {res.id[1]}. These are not supported.')
104 | res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
105 | restype_idx = residue_constants.restype_order.get(
106 | res_shortname, residue_constants.restype_num)
107 | pos = np.zeros((residue_constants.atom_type_num, 3))
108 | mask = np.zeros((residue_constants.atom_type_num,))
109 | res_b_factors = np.zeros((residue_constants.atom_type_num,))
110 | for atom in res:
111 | if atom.name not in residue_constants.atom_types:
112 | continue
113 | pos[residue_constants.atom_order[atom.name]] = atom.coord
114 | mask[residue_constants.atom_order[atom.name]] = 1.
115 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
116 | if np.sum(mask) < 0.5:
117 | # If no known atom positions are reported for the residue then skip it.
118 | continue
119 | aatype.append(restype_idx)
120 | atom_positions.append(pos)
121 | atom_mask.append(mask)
122 | residue_index.append(res.id[1])
123 | chain_ids.append(chain.id)
124 | b_factors.append(res_b_factors)
125 |
126 | # Chain IDs are usually characters so map these to ints.
127 | unique_chain_ids = np.unique(chain_ids)
128 | chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
129 | chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
130 |
131 | return Protein(
132 | atom_positions=np.array(atom_positions),
133 | atom_mask=np.array(atom_mask),
134 | aatype=np.array(aatype),
135 | residue_index=np.array(residue_index),
136 | chain_index=chain_index,
137 | b_factors=np.array(b_factors))
138 |
139 |
140 | def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
141 | chain_end = 'TER'
142 | return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
143 | f'{chain_name:>1}{residue_index:>4}')
144 |
145 |
146 | def to_pdb(prot: Protein, model=1, add_end=True) -> str:
147 | """Converts a `Protein` instance to a PDB string.
148 |
149 | Args:
150 | prot: The protein to convert to PDB.
151 |
152 | Returns:
153 | PDB string.
154 | """
155 | restypes = residue_constants.restypes + ['X']
156 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
157 | atom_types = residue_constants.atom_types
158 |
159 | pdb_lines = []
160 |
161 | atom_mask = prot.atom_mask
162 | aatype = prot.aatype
163 | atom_positions = prot.atom_positions
164 | residue_index = prot.residue_index.astype(int)
165 | chain_index = prot.chain_index.astype(int)
166 | b_factors = prot.b_factors
167 |
168 | if np.any(aatype > residue_constants.restype_num):
169 | raise ValueError('Invalid aatypes.')
170 |
171 | # Construct a mapping from chain integer indices to chain ID strings.
172 | chain_ids = {}
173 | for i in np.unique(chain_index): # np.unique gives sorted output.
174 | if i >= PDB_MAX_CHAINS:
175 | raise ValueError(
176 | f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
177 | chain_ids[i] = PDB_CHAIN_IDS[i]
178 |
179 | pdb_lines.append(f'MODEL {model}')
180 | atom_index = 1
181 | last_chain_index = chain_index[0]
182 | # Add all atom sites.
183 | for i in range(aatype.shape[0]):
184 | # Close the previous chain if in a multichain PDB.
185 | if last_chain_index != chain_index[i]:
186 | pdb_lines.append(_chain_end(
187 | atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
188 | residue_index[i - 1]))
189 | last_chain_index = chain_index[i]
190 | atom_index += 1 # Atom index increases at the TER symbol.
191 |
192 | res_name_3 = res_1to3(aatype[i])
193 | for atom_name, pos, mask, b_factor in zip(
194 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
195 | if mask < 0.5:
196 | continue
197 |
198 | record_type = 'ATOM'
199 | name = atom_name if len(atom_name) == 4 else f' {atom_name}'
200 | alt_loc = ''
201 | insertion_code = ''
202 | occupancy = 1.00
203 | element = atom_name[0] # Protein supports only C, N, O, S, this works.
204 | charge = ''
205 | # PDB is a columnar format, every space matters here!
206 | atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
207 | f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
208 | f'{residue_index[i]:>4}{insertion_code:>1} '
209 | f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
210 | f'{occupancy:>6.2f}{b_factor:>6.2f} '
211 | f'{element:>2}{charge:>2}')
212 | pdb_lines.append(atom_line)
213 | atom_index += 1
214 |
215 | # Close the final chain.
216 | pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
217 | chain_ids[chain_index[-1]], residue_index[-1]))
218 | pdb_lines.append('ENDMDL')
219 | if add_end:
220 | pdb_lines.append('END')
221 |
222 | # Pad all lines to 80 characters.
223 | pdb_lines = [line.ljust(80) for line in pdb_lines]
224 | return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
225 |
226 |
227 | def ideal_atom_mask(prot: Protein) -> np.ndarray:
228 | """Computes an ideal atom mask.
229 |
230 | `Protein.atom_mask` typically is defined according to the atoms that are
231 | reported in the PDB. This function computes a mask according to heavy atoms
232 | that should be present in the given sequence of amino acids.
233 |
234 | Args:
235 | prot: `Protein` whose fields are `numpy.ndarray` objects.
236 |
237 | Returns:
238 | An ideal atom mask.
239 | """
240 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
241 |
242 |
243 | def from_prediction(
244 | features: FeatureDict,
245 | result: ModelOutput,
246 | b_factors: Optional[np.ndarray] = None,
247 | remove_leading_feature_dimension: bool = True) -> Protein:
248 | """Assembles a protein from a prediction.
249 |
250 | Args:
251 | features: Dictionary holding model inputs.
252 | result: Dictionary holding model outputs.
253 | b_factors: (Optional) B-factors to use for the protein.
254 | remove_leading_feature_dimension: Whether to remove the leading dimension
255 | of the `features` values.
256 |
257 | Returns:
258 | A protein instance.
259 | """
260 | fold_output = result['structure_module']
261 |
262 | def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
263 | return arr[0] if remove_leading_feature_dimension else arr
264 |
265 | if 'asym_id' in features:
266 | chain_index = _maybe_remove_leading_dim(features['asym_id'])
267 | else:
268 | chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))
269 |
270 | if b_factors is None:
271 | b_factors = np.zeros_like(fold_output['final_atom_mask'])
272 |
273 | return Protein(
274 | aatype=_maybe_remove_leading_dim(features['aatype']),
275 | atom_positions=fold_output['final_atom_positions'],
276 | atom_mask=fold_output['final_atom_mask'],
277 | residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
278 | chain_index=chain_index,
279 | b_factors=b_factors)
--------------------------------------------------------------------------------
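A minimal parse/serialize round-trip sketch for the Protein helpers above; the file paths are placeholders.

from proteinblobs.multiflow.data import protein

with open("example.pdb") as f:                     # hypothetical input file
    prot = protein.from_pdb_string(f.read())

print(prot.atom_positions.shape)                   # (num_res, 37, 3) in atom37 ordering
with open("roundtrip.pdb", "w") as f:
    f.write(protein.to_pdb(prot))                  # 80-column padded PDB string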
/proteinblobs/multiflow/datasets.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import abc
12 | import numpy as np
13 | import pandas as pd
14 | import logging
15 | import tree
16 | import torch
17 | import random
18 | import os
19 |
20 | from glob import glob
21 | from torch.utils.data import Dataset
22 | import proteinblobs.multiflow.utils as du
23 | from openfold.data import data_transforms
24 | from openfold.utils import rigid_utils
25 | from sklearn.preprocessing import PolynomialFeatures
26 | from sklearn.linear_model import LinearRegression
27 | from pytorch_lightning.utilities import rank_zero_only
28 |
29 |
30 | def _rog_filter(df, quantile):
31 | y_quant = pd.pivot_table(
32 | df,
33 | values='radius_gyration',
34 | index='modeled_seq_len',
35 | aggfunc=lambda x: np.quantile(x, quantile)
36 | )
37 | x_quant = y_quant.index.to_numpy()
38 | y_quant = y_quant.radius_gyration.to_numpy()
39 |
40 | # Fit polynomial regressor
41 | poly = PolynomialFeatures(degree=4, include_bias=True)
42 | poly_features = poly.fit_transform(x_quant[:, None])
43 | poly_reg_model = LinearRegression()
44 | poly_reg_model.fit(poly_features, y_quant)
45 |
46 | # Calculate cutoff for all sequence lengths
47 | max_len = df.modeled_seq_len.max()
48 | pred_poly_features = poly.fit_transform(np.arange(max_len)[:, None])
49 | # Add a little more.
50 | pred_y = poly_reg_model.predict(pred_poly_features) + 0.1
51 |
52 | row_rog_cutoffs = df.modeled_seq_len.map(lambda x: pred_y[x-1])
53 | return df[df.radius_gyration < row_rog_cutoffs]
54 |
55 |
56 | def _length_filter(data_csv, min_res, max_res):
57 | return data_csv[
58 | (data_csv.modeled_seq_len >= min_res)
59 | & (data_csv.modeled_seq_len <= max_res)
60 | ]
61 |
62 |
63 | def _plddt_percent_filter(data_csv, min_plddt_percent):
64 | return data_csv[data_csv.num_confident_plddt > min_plddt_percent]
65 |
66 |
67 | def _max_coil_filter(data_csv, max_coil_percent):
68 | return data_csv[data_csv.coil_percent <= max_coil_percent]
69 |
70 |
71 | def _process_csv_row(processed_file_path):
72 | processed_feats = du.read_pkl(processed_file_path)
73 | processed_feats = du.parse_chain_feats(processed_feats)
74 |
75 | # Only take modeled residues.
76 | modeled_idx = processed_feats['modeled_idx']
77 | min_idx = np.min(modeled_idx)
78 | max_idx = np.max(modeled_idx)
79 | del processed_feats['modeled_idx']
80 | processed_feats = tree.map_structure(
81 | lambda x: x[min_idx:(max_idx+1)], processed_feats)
82 |
83 | # Run through OpenFold data transforms.
84 | chain_feats = {
85 | 'aatype': torch.tensor(processed_feats['aatype']).long(),
86 | 'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
87 | 'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double()
88 | }
89 | chain_feats = data_transforms.atom37_to_frames(chain_feats)
90 | rigids_1 = rigid_utils.Rigid.from_tensor_4x4(chain_feats['rigidgroups_gt_frames'])[:, 0]
91 | rotmats_1 = rigids_1.get_rots().get_rot_mats()
92 | trans_1 = rigids_1.get_trans()
93 | res_plddt = processed_feats['b_factors'][:, 1]
94 | res_mask = torch.tensor(processed_feats['bb_mask']).int()
95 |
96 | # Re-number residue indices for each chain such that it starts from 1.
97 | # Randomize chain indices.
98 | chain_idx = processed_feats['chain_index']
99 | res_idx = processed_feats['residue_index']
100 | new_res_idx = np.zeros_like(res_idx)
101 | new_chain_idx = np.zeros_like(res_idx)
102 | all_chain_idx = np.unique(chain_idx).tolist()
103 | shuffled_chain_idx = np.array(
104 | random.sample(all_chain_idx, len(all_chain_idx))) - np.min(all_chain_idx) + 1
105 | for i,chain_id in enumerate(all_chain_idx):
106 | chain_mask = (chain_idx == chain_id).astype(int)
107 | chain_min_idx = np.min(res_idx + (1 - chain_mask) * 1e3).astype(int)
108 | new_res_idx = new_res_idx + (res_idx - chain_min_idx + 1) * chain_mask
109 |
110 | # Shuffle chain_index
111 | replacement_chain_id = shuffled_chain_idx[i]
112 | new_chain_idx = new_chain_idx + replacement_chain_id * chain_mask
113 |
114 | if torch.isnan(trans_1).any() or torch.isnan(rotmats_1).any():
115 | raise ValueError(f'Found NaNs in {processed_file_path}')
116 |
117 | return {
118 | 'res_plddt': res_plddt,
119 | 'aatypes_1': chain_feats['aatype'],
120 | 'all_atom_positions': chain_feats['all_atom_positions'],
121 | 'rotmats_1': rotmats_1,
122 | 'trans_1': trans_1,
123 | 'res_mask': res_mask,
124 | 'chain_idx': new_chain_idx,
125 | 'res_idx': new_res_idx,
126 | }
127 |
128 |
129 | def _add_plddt_mask(feats, plddt_threshold):
130 | feats['plddt_mask'] = torch.tensor(
131 | feats['res_plddt'] > plddt_threshold).int()
132 |
133 |
134 | def _read_clusters(cluster_path, synthetic=False):
135 | pdb_to_cluster = {}
136 | with open(cluster_path, "r") as f:
137 | for i,line in enumerate(f):
138 | for chain in line.split(' '):
139 | if not synthetic:
140 | pdb = chain.split('_')[0].strip()
141 | else:
142 | pdb = chain.strip()
143 | pdb_to_cluster[pdb.upper()] = i
144 | return pdb_to_cluster
145 |
146 |
147 | class BaseDataset(Dataset):
148 | def __init__(
149 | self,
150 | *,
151 | dataset_cfg,
152 | is_training,
153 | task,
154 | ):
155 | self._log = logging.getLogger(__name__)
156 | self._is_training = is_training
157 | self._dataset_cfg = dataset_cfg
158 | self.task = task
159 | self.raw_csv = pd.read_csv(self.dataset_cfg.csv_path)
160 | metadata_csv = self._filter_metadata(self.raw_csv)
161 | metadata_csv = metadata_csv.sort_values(
162 | 'modeled_seq_len', ascending=False)
163 | if self._dataset_cfg.use_redesigned:
164 | self.redesigned_csv = pd.read_csv(self._dataset_cfg.redesigned_csv_path)
165 | metadata_csv = metadata_csv.merge(
166 | self.redesigned_csv, left_on='pdb_name', right_on='example')
167 | metadata_csv = metadata_csv[metadata_csv.best_rmsd < 2.0]
168 | if self._dataset_cfg.cluster_path is not None:
169 | pdb_to_cluster = _read_clusters(self._dataset_cfg.cluster_path, synthetic=True)
170 | def cluster_lookup(pdb):
171 | pdb = pdb.upper()
172 | if pdb not in pdb_to_cluster:
173 | raise ValueError(f'Cluster not found for {pdb}')
174 | return pdb_to_cluster[pdb]
175 | metadata_csv['cluster'] = metadata_csv['pdb_name'].map(cluster_lookup)
176 | self._create_split(metadata_csv)
177 | self._cache = {}
178 | self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)
179 |
180 | @property
181 | def is_training(self):
182 | return self._is_training
183 |
184 | @property
185 | def dataset_cfg(self):
186 | return self._dataset_cfg
187 |
188 | def __len__(self):
189 | return len(self.csv)
190 |
191 | @abc.abstractmethod
192 | def _filter_metadata(self, raw_csv: pd.DataFrame) -> pd.DataFrame:
193 | pass
194 |
195 | def _create_split(self, data_csv):
196 | # Training or validation specific logic.
197 | if self.is_training:
198 | self.csv = data_csv
199 | self._log.info(
200 | f'Training: {len(self.csv)} examples')
201 | else:
202 | if self._dataset_cfg.max_eval_length is None:
203 | eval_lengths = data_csv.modeled_seq_len
204 | else:
205 | eval_lengths = data_csv.modeled_seq_len[
206 | data_csv.modeled_seq_len <= self._dataset_cfg.max_eval_length
207 | ]
208 | all_lengths = np.sort(eval_lengths.unique())
209 | length_indices = (len(all_lengths) - 1) * np.linspace(
210 | 0.0, 1.0, self.dataset_cfg.num_eval_lengths)
211 | length_indices = length_indices.astype(int)
212 | eval_lengths = all_lengths[length_indices]
213 | eval_csv = data_csv[data_csv.modeled_seq_len.isin(eval_lengths)]
214 |
215 | # Fix a random seed to get the same split each time.
216 | eval_csv = eval_csv.groupby('modeled_seq_len').sample(
217 | self.dataset_cfg.samples_per_eval_length,
218 | replace=True,
219 | random_state=123
220 | )
221 | eval_csv = eval_csv.sort_values('modeled_seq_len', ascending=False)
222 | self.csv = eval_csv
223 | self._log.info(
224 | f'Validation: {len(self.csv)} examples with lengths {eval_lengths}')
225 | self.csv['index'] = list(range(len(self.csv)))
226 |
227 | def process_csv_row(self, csv_row):
228 | path = csv_row['processed_path']
229 | seq_len = csv_row['modeled_seq_len']
230 | # Large protein files are slow to read. Cache them.
231 | use_cache = seq_len > self._dataset_cfg.cache_num_res
232 | if use_cache and path in self._cache:
233 | return self._cache[path]
234 | processed_row = _process_csv_row(path)
235 | processed_row['pdb_name'] = csv_row['pdb_name']
236 | if self._dataset_cfg.use_redesigned:
237 | best_seq = csv_row['best_seq']
238 | if not isinstance(best_seq, float):
239 | best_aatype = torch.tensor(du.seq_to_aatype(best_seq)).long()
240 | assert processed_row['aatypes_1'].shape == best_aatype.shape
241 | processed_row['aatypes_1'] = best_aatype
242 | aatypes_1 = du.to_numpy(processed_row['aatypes_1'])
243 | if len(set(aatypes_1)) == 1:
244 | raise ValueError(f'Example {path} has only one amino acid.')
245 | if use_cache:
246 | self._cache[path] = processed_row
247 | return processed_row
248 |
249 |
250 | def __getitem__(self, row_idx):
251 | # Process data example.
252 | csv_row = self.csv.iloc[row_idx]
253 | feats = self.process_csv_row(csv_row)
254 |
255 | if self._dataset_cfg.add_plddt_mask:
256 | _add_plddt_mask(feats, self._dataset_cfg.min_plddt_threshold)
257 | else:
258 | feats['plddt_mask'] = torch.ones_like(feats['res_mask'])
259 |
260 | if self.task == 'hallucination':
261 | feats['diffuse_mask'] = torch.ones_like(feats['res_mask']).bool()
262 | else:
263 | raise ValueError(f'Unknown task {self.task}')
264 | feats['diffuse_mask'] = feats['diffuse_mask'].int()
265 |
266 | # Storing the csv index is helpful for debugging.
267 | feats['csv_idx'] = torch.ones(1, dtype=torch.long) * row_idx
268 | return feats
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 | def pdb_init_(
278 | self,
279 | *,
280 | dataset_cfg,
281 | is_training,
282 | task,
283 | ):
284 | self._log = logging.getLogger(__name__)
285 | self._is_training = is_training
286 | self._dataset_cfg = dataset_cfg
287 | self.task = task
288 | self._cache = {}
289 | self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)
290 |
291 | # Process clusters
292 | self.raw_csv = pd.read_csv(self.dataset_cfg.csv_path)
293 | metadata_csv = self._filter_metadata(self.raw_csv)
294 | metadata_csv = metadata_csv.sort_values(
295 | 'modeled_seq_len', ascending=False)
296 |
297 | self._pdb_to_cluster = _read_clusters(self._dataset_cfg.cluster_path, synthetic=False)
298 | self._max_cluster = max(self._pdb_to_cluster.values())
299 | self._missing_pdbs = 0
300 | def cluster_lookup(pdb):
301 | pdb = pdb.upper()
302 | if pdb not in self._pdb_to_cluster:
303 | self._pdb_to_cluster[pdb] = self._max_cluster + 1
304 | self._max_cluster += 1
305 | self._missing_pdbs += 1
306 | return self._pdb_to_cluster[pdb]
307 | metadata_csv['cluster'] = metadata_csv['pdb_name'].map(cluster_lookup)
308 | if self._dataset_cfg.use_redesigned:
309 | self.redesigned_csv = pd.read_csv(self._dataset_cfg.redesigned_csv_path)
310 | metadata_csv = metadata_csv.merge(
311 | self.redesigned_csv, left_on='pdb_name', right_on='example')
312 | metadata_csv = metadata_csv[metadata_csv.best_rmsd < 2.0]
313 | if self._dataset_cfg.use_synthetic:
314 | self.synthetic_csv = pd.read_csv(self._dataset_cfg.synthetic_csv_path)
315 | self._synthetic_pdb_to_cluster = _read_clusters(self._dataset_cfg.synthetic_cluster_path, synthetic=True)
316 |
317 | # offset all the cluster numbers by the number of real data clusters
318 | num_real_clusters = metadata_csv['cluster'].max() + 1
319 | def synthetic_cluster_lookup(pdb):
320 | pdb = pdb.upper()
321 | if pdb not in self._synthetic_pdb_to_cluster:
322 | raise ValueError(f"Synthetic example {pdb} not in synthetic cluster file!")
323 | return self._synthetic_pdb_to_cluster[pdb] + num_real_clusters
324 | self.synthetic_csv['cluster'] = self.synthetic_csv['pdb_name'].map(synthetic_cluster_lookup)
325 |
326 | metadata_csv = pd.concat([metadata_csv, self.synthetic_csv])
327 | self._create_split(metadata_csv)
328 |
329 |
330 | if dataset_cfg.test_set_pdb_ids_path is not None:
331 |
332 | test_set_df = pd.read_csv(dataset_cfg.test_set_pdb_ids_path)
333 |
334 | self.csv = self.csv[self.csv['pdb_name'].isin(test_set_df['pdb_name'].values)]
335 |
336 | def pdb_filter_metadata(self, raw_csv):
337 | """Filter metadata."""
338 | filter_cfg = self.dataset_cfg.filter
339 | data_csv = raw_csv[
340 | raw_csv.oligomeric_detail.isin(filter_cfg.oligomeric)]
341 | data_csv = data_csv[
342 | data_csv.num_chains.isin(filter_cfg.num_chains)]
343 | data_csv = _length_filter(
344 | data_csv, filter_cfg.min_num_res, filter_cfg.max_num_res)
345 | data_csv = _max_coil_filter(data_csv, filter_cfg.max_coil_percent)
346 | data_csv = _rog_filter(data_csv, filter_cfg.rog_quantile)
347 | return data_csv
348 |
349 | class PdbDataset(BaseDataset):
350 |
351 | def __init__(self, *, dataset_cfg, is_training, task):
352 | pdb_init_(self, dataset_cfg=dataset_cfg, is_training=is_training, task=task)
353 |
354 | def _filter_metadata(self, raw_csv):
355 | return pdb_filter_metadata(self, raw_csv)
--------------------------------------------------------------------------------
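A sketch of the cluster-file format consumed by _read_clusters above: one cluster per line, whitespace-separated members; with synthetic=False the chain suffix after "_" is dropped and keys are upper-cased.

from proteinblobs.multiflow.datasets import _read_clusters

with open("clusters.txt", "w") as f:               # hypothetical path
    f.write("1abc_A 1abd_A\n2xyz_B\n")

print(_read_clusters("clusters.txt", synthetic=False))
# {'1ABC': 0, '1ABD': 0, '2XYZ': 1}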
/proteinblobs/multiflow/edge_feature_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .utils import get_index_embedding, calc_distogram
5 |
6 | class EdgeFeatureNet(nn.Module):
7 |
8 | def __init__(self, module_cfg):
9 | # c_s, c_p, relpos_k, template_type):
10 | super(EdgeFeatureNet, self).__init__()
11 | self._cfg = module_cfg
12 |
13 | self.c_s = self._cfg.c_s
14 | self.c_p = self._cfg.c_p
15 | self.feat_dim = self._cfg.feat_dim
16 |
17 | self.linear_s_p = nn.Linear(self.c_s, self.feat_dim)
18 | self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim)
19 |
20 | total_edge_feats = self.feat_dim * 3 + self._cfg.num_bins * 2
21 | if self._cfg.embed_chain:
22 | total_edge_feats += 1
23 | if self._cfg.embed_diffuse_mask:
24 | total_edge_feats += 2
25 | self.edge_embedder = nn.Sequential(
26 | nn.Linear(total_edge_feats, self.c_p),
27 | nn.ReLU(),
28 | nn.Linear(self.c_p, self.c_p),
29 | nn.ReLU(),
30 | nn.Linear(self.c_p, self.c_p),
31 | nn.LayerNorm(self.c_p),
32 | )
33 |
34 | def embed_relpos(self, r):
35 | # AlphaFold 2 Algorithm 4 & 5
36 | # Based on OpenFold utils/tensor_utils.py
37 | # Input: [b, n_res]
38 | # [b, n_res, n_res]
39 | d = r[:, :, None] - r[:, None, :]
40 | pos_emb = get_index_embedding(d, self._cfg.feat_dim, max_len=2056)
41 | return self.linear_relpos(pos_emb)
42 |
43 | def _cross_concat(self, feats_1d, num_batch, num_res):
44 | return torch.cat([
45 | torch.tile(feats_1d[:, :, None, :], (1, 1, num_res, 1)),
46 | torch.tile(feats_1d[:, None, :, :], (1, num_res, 1, 1)),
47 | ], dim=-1).float().reshape([num_batch, num_res, num_res, -1])
48 |
49 | def forward(self, s, t, sc_t, p_mask, diffuse_mask, chain_idx):
50 | # Input: [b, n_res, c_s]
51 | num_batch, num_res, _ = s.shape
52 |
53 | # [b, n_res, c_p]
54 | p_i = self.linear_s_p(s)
55 | cross_node_feats = self._cross_concat(p_i, num_batch, num_res)
56 |
57 | # [b, n_res]
58 | r = torch.arange(
59 | num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1)
60 | relpos_feats = self.embed_relpos(r)
61 |
62 | dist_feats = calc_distogram(
63 | t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins)
64 | sc_feats = calc_distogram(
65 | sc_t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins)
66 |
67 | all_edge_feats = [cross_node_feats, relpos_feats, dist_feats, sc_feats]
68 | if self._cfg.embed_chain:
69 | rel_chain = (chain_idx[:, :, None] == chain_idx[:, None, :]).float()
70 | all_edge_feats.append(rel_chain[..., None])
71 | if self._cfg.embed_diffuse_mask:
72 | diff_feat = self._cross_concat(diffuse_mask[..., None], num_batch, num_res)
73 | all_edge_feats.append(diff_feat)
74 | edge_feats = self.edge_embedder(torch.concat(all_edge_feats, dim=-1))
75 | edge_feats *= p_mask.unsqueeze(-1)
76 | return edge_feats
--------------------------------------------------------------------------------
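A standalone sketch of the pairwise "cross concat" pattern used by _cross_concat above: each (i, j) pair receives the concatenation of node feature i and node feature j, giving a [B, N, N, 2C] tensor.

import torch

B, N, C = 2, 5, 8
feats_1d = torch.randn(B, N, C)
pairwise = torch.cat([
    torch.tile(feats_1d[:, :, None, :], (1, 1, N, 1)),
    torch.tile(feats_1d[:, None, :, :], (1, N, 1, 1)),
], dim=-1)
print(pairwise.shape)   # torch.Size([2, 5, 5, 16])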
/proteinblobs/multiflow/flow_model.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import torch
12 | from torch import nn
13 |
14 | from .node_feature_net import NodeFeatureNet
15 | from .edge_feature_net import EdgeFeatureNet
16 | from .cross_attention import InvariantCrossAttention
17 | from . import ipa_pytorch
18 | from .data import utils as du
19 | import math
20 |
21 | def positional_encoding(x, c, max_period=10000, min_period=None):
22 | min_period = min_period or max_period / 10000 # second one recommended
23 | freqs = torch.exp(-torch.linspace(math.log(min_period), math.log(max_period), c // 2, device=x.device))
24 | emb = freqs * x.unsqueeze(-1) # [..., C]
25 | return torch.cat([torch.sin(emb), torch.cos(emb)], -1) # [..., 2C]
26 |
27 | class FlowModel(nn.Module):
28 |
29 | def __init__(self, model_conf, args):
30 | super(FlowModel, self).__init__()
31 | self._model_conf = model_conf
32 | self.args = args
33 | self._ipa_conf = model_conf.ipa
34 | self.rigids_ang_to_nm = lambda x: x.apply_trans_fn(lambda x: x * du.ANG_TO_NM_SCALE)
35 | self.rigids_nm_to_ang = lambda x: x.apply_trans_fn(lambda x: x * du.NM_TO_ANG_SCALE)
36 | self.node_feature_net = NodeFeatureNet(model_conf.node_features)
37 | self.edge_feature_net = EdgeFeatureNet(model_conf.edge_features)
38 | if self.args.use_latents:
39 | self.latent_embedder = ipa_pytorch.Linear(32, model_conf.node_features.c_s, init="final")
40 |
41 | if self.args.blob_attention:
42 | self.ground_feature_net = nn.Embedding(3, self._ipa_conf.c_s)
43 | self.ground_trace_feature_net = nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s)
44 | self.ground_size_feature_net = nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s)
45 |
46 | if self._model_conf.aatype_pred:
47 | node_embed_size = self._model_conf.node_embed_size
48 | self.aatype_pred_net = nn.Sequential(
49 | nn.Linear(node_embed_size, node_embed_size),
50 | nn.ReLU(),
51 | nn.Linear(node_embed_size, node_embed_size),
52 | nn.ReLU(),
53 | nn.Linear(node_embed_size, self._model_conf.aatype_pred_num_tokens),
54 | )
55 |
56 | # Attention trunk
57 | self.trunk = nn.ModuleDict()
58 | for b in range(self._ipa_conf.num_blocks):
59 | self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
60 | self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)
61 | if self.args.blob_attention:
62 | self.trunk[f'ica_{b}'] = InvariantCrossAttention(self._ipa_conf)
63 | tfmr_in = self._ipa_conf.c_s
64 | tfmr_layer = torch.nn.TransformerEncoderLayer(
65 | d_model=tfmr_in,
66 | nhead=self._ipa_conf.seq_tfmr_num_heads,
67 | dim_feedforward=tfmr_in,
68 | batch_first=True,
69 | dropout=self._model_conf.transformer_dropout,
70 | norm_first=False
71 | )
72 | self.trunk[f'seq_tfmr_{b}'] = nn.TransformerEncoder(tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
73 | self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(tfmr_in, self._ipa_conf.c_s, init="final")
74 |
75 | if self.args.extra_attn_layer:
76 | self.trunk[f'cond_tfmr_{b}'] = nn.TransformerEncoder(tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
77 | self.trunk[f'post_cond_tfmr_{b}'] = ipa_pytorch.Linear(tfmr_in, self._ipa_conf.c_s, init="final")
78 |
79 |
80 | self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(c=self._ipa_conf.c_s)
81 | self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(self._ipa_conf.c_s, use_rot_updates=True)
82 |
83 | if b < self._ipa_conf.num_blocks-1:
84 | # No edge update on the last block.
85 | edge_in = self._model_conf.edge_embed_size
86 | self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
87 | node_embed_size=self._ipa_conf.c_s,
88 | edge_embed_in=edge_in,
89 | edge_embed_out=self._model_conf.edge_embed_size,
90 | )
91 |
92 | if args.freeze_weights:
93 | for name, p in self.named_parameters():
94 | if ('ica' not in name) and ('cond_tfmr' not in name) and ('ground' not in name):
95 | p.requires_grad_(False)
96 |
97 | def forward(self, input_feats):
98 | node_mask = input_feats['res_mask']
99 | edge_mask = node_mask[:, None] * node_mask[:, :, None]
100 | diffuse_mask = input_feats['diffuse_mask']
101 | chain_index = input_feats['chain_idx']
102 | res_index = input_feats['res_idx']
103 | latents = input_feats['latents']
104 | so3_t = input_feats['so3_t']
105 | r3_t = input_feats['r3_t']
106 | cat_t = input_feats['cat_t']
107 | trans_t = input_feats['trans_t']
108 | rotmats_t = input_feats['rotmats_t']
109 | aatypes_t = input_feats['aatypes_t'].long()
110 | trans_sc = input_feats['trans_sc']
111 | aatypes_sc = input_feats['aatypes_sc']
112 |
113 | ##########
114 | if self.args.blob_attention:
115 | grounding_feats = input_feats['grounding_feat']
116 | grounding_pos = input_feats['grounding_pos']
117 | grounding_mask = input_feats['grounding_mask']
118 |
119 | ###
120 | B, L, _ = grounding_pos.shape
121 | grounding_covar = grounding_feats[:,:,-9:].view(B, L, 3, 3)
122 | trace = grounding_covar[:,:,torch.arange(3),torch.arange(3)].sum(-1)
123 | grounding_covar = grounding_covar / (trace[:,:,None,None] + 1e-3) # "normalized covariance"
124 | ###
125 | grounding_feats = (
126 | self.ground_feature_net(grounding_feats[...,0].long()) +
127 | self.ground_size_feature_net(
128 | positional_encoding(grounding_feats[...,1].float(), self._ipa_conf.c_s)
129 | ) +
130 | self.ground_trace_feature_net(
131 | positional_encoding(trace, self._ipa_conf.c_s)
132 | )
133 | )
134 | assert not torch.isnan(grounding_feats).any()
135 | ##########
136 |
137 | # Initialize node and edge embeddings
138 | init_node_embed = self.node_feature_net(
139 | so3_t=so3_t,
140 | r3_t=r3_t,
141 | cat_t=cat_t,
142 | res_mask=node_mask,
143 | diffuse_mask=diffuse_mask,
144 | chain_index=chain_index,
145 | pos=res_index,
146 | aatypes=aatypes_t,
147 | aatypes_sc=aatypes_sc,
148 | )
149 | if self.args.use_latents:
150 | init_node_embed = init_node_embed + self.latent_embedder(latents)[:,None,:]
151 |
152 | init_edge_embed = self.edge_feature_net(
153 | init_node_embed,
154 | trans_t,
155 | trans_sc,
156 | edge_mask,
157 | diffuse_mask,
158 | chain_index
159 | )
160 |
161 |
162 | # Initial rigids
163 | init_rigids = du.create_rigid(rotmats_t, trans_t)
164 | curr_rigids = du.create_rigid(rotmats_t, trans_t)
165 |
166 | # Main trunk
167 | curr_rigids = self.rigids_ang_to_nm(curr_rigids)
168 | init_node_embed = init_node_embed * node_mask[..., None]
169 | node_embed = init_node_embed * node_mask[..., None]
170 | edge_embed = init_edge_embed * edge_mask[..., None]
171 | for b in range(self._ipa_conf.num_blocks):
172 | ipa_embed = self.trunk[f'ipa_{b}'](
173 | node_embed,
174 | edge_embed,
175 | curr_rigids,
176 | node_mask)
177 | ipa_embed *= node_mask[..., None]
178 | node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
179 |
180 | ############
181 | if self.args.blob_attention:
182 | ica_out = self.trunk[f'ica_{b}'](
183 | node_embed,
184 | curr_rigids,
185 | grounding_pos * du.ANG_TO_NM_SCALE, # FIX,
186 | grounding_feats,
187 | grounding_covar,
188 | grounding_mask
189 | )
190 | node_embed = node_embed + ica_out * (1 if self.training else self.args.inference_gating)
191 | ###########
192 |
193 | def heterogenous_attention(layer, post_layer, node_embed, node_mask, grounding_feats, grounding_mask):
194 |
195 | node_embed_ = torch.cat([node_embed, grounding_feats], 1)
196 | node_mask_ = torch.cat([node_mask, grounding_mask], 1)
197 | cond_tfmr_out = layer(node_embed_, src_key_padding_mask=(1 - node_mask_).bool())
198 | L = node_embed.shape[1]
199 |
200 | cond_tfmr_out = post_layer(cond_tfmr_out)
201 | node_embed = node_embed + cond_tfmr_out[:,:L] * (1 if self.training else self.args.inference_gating)
202 | grounding_feats = grounding_feats + cond_tfmr_out[:,L:]
203 |
204 | return node_embed, grounding_feats
205 |
206 | if self.args.blob_attention and not self.args.extra_attn_layer:
207 | node_embed, grounding_feats = heterogenous_attention(
208 | self.trunk[f'seq_tfmr_{b}'],
209 | self.trunk[f'post_tfmr_{b}'],
210 | node_embed, node_mask, grounding_feats, grounding_mask)
211 |
212 | else:
213 | seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
214 | node_embed, src_key_padding_mask=(1 - node_mask).to(torch.bool))
215 | node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
216 |
217 | ################
218 | if self.args.blob_attention and self.args.extra_attn_layer:
219 | node_embed, grounding_feats = heterogenous_attention(
220 | self.trunk[f'cond_tfmr_{b}'],
221 | self.trunk[f'post_cond_tfmr_{b}'],
222 | node_embed, node_mask, grounding_feats, grounding_mask)
223 |
224 | ################
225 | node_embed = self.trunk[f'node_transition_{b}'](node_embed)
226 | node_embed = node_embed * node_mask[..., None]
227 | rigid_update = self.trunk[f'bb_update_{b}'](
228 | node_embed * node_mask[..., None])
229 | curr_rigids = curr_rigids.compose_q_update_vec(
230 | rigid_update, (node_mask * diffuse_mask)[..., None])
231 | if b < self._ipa_conf.num_blocks-1:
232 | edge_embed = self.trunk[f'edge_transition_{b}'](
233 | node_embed, edge_embed)
234 | edge_embed *= edge_mask[..., None]
235 |
236 | curr_rigids = self.rigids_nm_to_ang(curr_rigids)
237 | pred_trans = curr_rigids.get_trans()
238 | pred_rotmats = curr_rigids.get_rots().get_rot_mats()
239 | if self._model_conf.aatype_pred:
240 | pred_logits = self.aatype_pred_net(node_embed)
241 | pred_aatypes = torch.argmax(pred_logits, dim=-1)
242 | if self._model_conf.aatype_pred_num_tokens == du.NUM_TOKENS + 1:
243 | pred_logits_wo_mask = pred_logits.clone()
244 | pred_logits_wo_mask[:, :, du.MASK_TOKEN_INDEX] = -1e9
245 | pred_aatypes = torch.argmax(pred_logits_wo_mask, dim=-1)
246 | else:
247 | pred_aatypes = torch.argmax(pred_logits, dim=-1)
248 | else:
249 | pred_aatypes = aatypes_t
250 | pred_logits = nn.functional.one_hot(
251 | pred_aatypes, num_classes=self._model_conf.aatype_pred_num_tokens
252 | ).float()
253 | return {
254 | 'pred_trans': pred_trans,
255 | 'pred_rotmats': pred_rotmats,
256 | 'pred_logits': pred_logits,
257 | 'pred_aatypes': pred_aatypes,
258 | }
259 |
--------------------------------------------------------------------------------
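A note on the trunk above: residue tokens and blob ("grounding") tokens share one sequence transformer by being concatenated along the length axis and split again afterwards (the inline `heterogenous_attention` helper), and the residue-side update is scaled by `inference_gating` at sampling time so the conditioning strength can be tuned without retraining. The standalone sketch below reproduces just that concatenate-attend-split pattern; all shapes, sizes, and names are illustrative and are not taken from the repository's configs.

import torch
import torch.nn as nn

B, L, G, C = 2, 16, 4, 32                           # batch, residues, blobs, channels (toy values)
node_embed = torch.randn(B, L, C)                   # per-residue features
blob_embed = torch.randn(B, G, C)                   # per-blob features
node_mask = torch.ones(B, L)
blob_mask = torch.ones(B, G)

layer = nn.TransformerEncoderLayer(d_model=C, nhead=4, dim_feedforward=C, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1, enable_nested_tensor=False)

tokens = torch.cat([node_embed, blob_embed], dim=1)               # [B, L+G, C]
padding = (1 - torch.cat([node_mask, blob_mask], dim=1)).bool()   # True marks padded tokens
out = encoder(tokens, src_key_padding_mask=padding)

node_update, blob_update = out[:, :L], out[:, L:]   # split back into the two token sets
node_embed = node_embed + node_update               # residual update, as in the trunk
blob_embed = blob_embed + blob_update
print(node_embed.shape, blob_embed.shape)           # torch.Size([2, 16, 32]) torch.Size([2, 4, 32])
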
/proteinblobs/multiflow/node_feature_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .utils import get_index_embedding, get_time_embedding
4 |
5 |
6 | class NodeFeatureNet(nn.Module):
7 |
8 | def __init__(self, module_cfg):
9 | super(NodeFeatureNet, self).__init__()
10 | self._cfg = module_cfg
11 | self.c_s = self._cfg.c_s
12 | self.c_pos_emb = self._cfg.c_pos_emb
13 | self.c_timestep_emb = self._cfg.c_timestep_emb
14 | embed_size = self._cfg.c_pos_emb + self._cfg.c_timestep_emb * 2 + 1
15 | if self._cfg.embed_chain:
16 | embed_size += self._cfg.c_pos_emb
17 | if self._cfg.embed_aatype:
18 | self.aatype_embedding = nn.Embedding(21, self.c_s) # Always 21 because of 20 amino acids + 1 for unk
19 | embed_size += self.c_s + self._cfg.c_timestep_emb + self._cfg.aatype_pred_num_tokens
20 | if self._cfg.use_mlp:
21 | self.linear = nn.Sequential(
22 | nn.Linear(embed_size, self.c_s),
23 | nn.ReLU(),
24 | nn.Linear(self.c_s, self.c_s),
25 | nn.ReLU(),
26 | nn.Linear(self.c_s, self.c_s),
27 | nn.LayerNorm(self.c_s),
28 | )
29 | else:
30 | self.linear = nn.Linear(embed_size, self.c_s)
31 |
32 | def embed_t(self, timesteps, mask):
33 | timestep_emb = get_time_embedding(
34 | timesteps[:, 0],
35 | self.c_timestep_emb,
36 | max_positions=2056
37 | )[:, None, :].repeat(1, mask.shape[1], 1)
38 | return timestep_emb * mask.unsqueeze(-1)
39 |
40 | def forward(
41 | self,
42 | *,
43 | so3_t,
44 | r3_t,
45 | cat_t,
46 | res_mask,
47 | diffuse_mask,
48 | chain_index,
49 | pos,
50 | aatypes,
51 | aatypes_sc,
52 | ):
53 | # s: [b]
54 |
55 | # [b, n_res, c_pos_emb]
56 | pos_emb = get_index_embedding(pos, self.c_pos_emb, max_len=2056)
57 | pos_emb = pos_emb * res_mask.unsqueeze(-1)
58 |
59 | # [b, n_res, c_timestep_emb]
60 | input_feats = [
61 | pos_emb,
62 | diffuse_mask[..., None],
63 | self.embed_t(so3_t, res_mask),
64 | self.embed_t(r3_t, res_mask)
65 | ]
66 | if self._cfg.embed_aatype:
67 | input_feats.append(self.aatype_embedding(aatypes))
68 | input_feats.append(self.embed_t(cat_t, res_mask))
69 | input_feats.append(aatypes_sc)
70 | if self._cfg.embed_chain:
71 | input_feats.append(
72 | get_index_embedding(
73 | chain_index,
74 | self.c_pos_emb,
75 | max_len=100
76 | )
77 | )
78 | return self.linear(torch.cat(input_feats, dim=-1))
79 |
--------------------------------------------------------------------------------
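For reference, the concatenation assembled in `NodeFeatureNet.forward` has to match the `embed_size` computed in `__init__`. The arithmetic below checks this under assumed config values (placeholders, not the repository's defaults), with `embed_aatype=True` and `embed_chain=False`:

c_pos_emb, c_timestep_emb, c_s, aatype_pred_num_tokens = 128, 128, 256, 21

width = c_pos_emb                # pos_emb
width += 1                       # diffuse_mask[..., None]
width += 2 * c_timestep_emb      # embed_t(so3_t) and embed_t(r3_t)
width += c_s                     # aatype_embedding(aatypes)
width += c_timestep_emb          # embed_t(cat_t)
width += aatype_pred_num_tokens  # aatypes_sc
print(width)                     # 790, the same total as embed_size in __init__
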
/proteinblobs/multiflow/utils.py:
--------------------------------------------------------------------------------
1 | import math, io
2 | import re, pickle
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import mdtraj as md
7 | from torch.nn import functional as F
8 | from .data import utils as du
9 | from openfold.utils.superimposition import superimpose
10 | from openfold.np import residue_constants
11 |
12 | CA_IDX = residue_constants.atom_order['CA']
13 |
14 | def calc_distogram(pos, min_bin, max_bin, num_bins):
15 | dists_2d = torch.linalg.norm(
16 | pos[:, :, None, :] - pos[:, None, :, :], axis=-1)[..., None]
17 | lower = torch.linspace(
18 | min_bin,
19 | max_bin,
20 | num_bins,
21 | device=pos.device)
22 | upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
23 | dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
24 | return dgram
25 |
26 |
27 | def get_index_embedding(indices, embed_size, max_len=2056):
28 |     """Creates sine / cosine positional embeddings from prespecified indices.
29 |
30 | Args:
31 | indices: offsets of size [..., N_edges] of type integer
32 | max_len: maximum length.
33 | embed_size: dimension of the embeddings to create
34 |
35 | Returns:
36 | positional embedding of shape [N, embed_size]
37 | """
38 | K = torch.arange(embed_size//2, device=indices.device)
39 | pos_embedding_sin = torch.sin(
40 | indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device)
41 | pos_embedding_cos = torch.cos(
42 | indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device)
43 | pos_embedding = torch.cat([
44 | pos_embedding_sin, pos_embedding_cos], axis=-1)
45 | return pos_embedding
46 |
47 |
48 | def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
49 | # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
50 | assert len(timesteps.shape) == 1
51 | timesteps = timesteps * max_positions
52 | half_dim = embedding_dim // 2
53 | emb = math.log(max_positions) / (half_dim - 1)
54 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
55 | emb = timesteps.float()[:, None] * emb[None, :]
56 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
57 | if embedding_dim % 2 == 1: # zero pad
58 | emb = F.pad(emb, (0, 1), mode='constant')
59 | assert emb.shape == (timesteps.shape[0], embedding_dim)
60 | return emb
61 |
62 |
63 | def sinusoidal_encoding(v, N, D):
64 |     """Sinusoidal positional encoding, taken from GENIE.
65 | 
66 |     Args:
67 |         v: [*] tensor of values to encode.
68 |         N: base wavelength / scale of the encoding.
69 |         D: dimension of the returned encoding.
70 |     """
71 | # [D]
72 | k = torch.arange(1, D+1).to(v.device)
73 |
74 | # [*, D]
75 | sin_div_term = N ** (2 * k / D)
76 | sin_div_term = sin_div_term.view(*((1, ) * len(v.shape) + (len(sin_div_term), )))
77 | sin_enc = torch.sin(v.unsqueeze(-1) * math.pi / sin_div_term)
78 |
79 | # [*, D]
80 | cos_div_term = N ** (2 * (k - 1) / D)
81 | cos_div_term = cos_div_term.view(*((1, ) * len(v.shape) + (len(cos_div_term), )))
82 | cos_enc = torch.cos(v.unsqueeze(-1) * math.pi / cos_div_term)
83 |
84 | # [*, D]
85 | enc = torch.zeros_like(sin_enc).to(v.device)
86 | enc[..., 0::2] = cos_enc[..., 0::2]
87 | enc[..., 1::2] = sin_enc[..., 1::2]
88 |
89 | return enc.to(v.dtype)
90 |
91 |
92 | def distance(p, eps=1e-10):
93 | # [*, 2, 3]
94 | return (eps + torch.sum((p[..., 0, :] - p[..., 1, :]) ** 2, dim=-1)) ** 0.5
95 |
96 |
97 | def dist_from_ca(trans):
98 |
99 | # [b, n_res, n_res, 1]
100 | d = distance(torch.stack([
101 | trans.unsqueeze(2).repeat(1, 1, trans.shape[1], 1), # Ca_1
102 | trans.unsqueeze(1).repeat(1, trans.shape[1], 1, 1), # Ca_2
103 | ], dim=-2)).unsqueeze(-1)
104 |
105 | return d
106 |
107 |
108 | def calc_rbf(ca_dists, num_rbf, D_min=1e-3, D_max=22.):
109 | # Distance radial basis function
110 | device = ca_dists.device
111 | D_mu = torch.linspace(D_min, D_max, num_rbf).to(device)
112 | D_mu = D_mu.view([1,1,1,-1])
113 | D_sigma = (D_max - D_min) / num_rbf
114 | return torch.exp(-((ca_dists - D_mu) / D_sigma)**2)
115 |
116 |
117 | def t_stratified_loss(batch_t, batch_loss, num_bins=4, loss_name=None):
118 | """Stratify loss by binning t."""
119 | batch_t = du.to_numpy(batch_t)
120 | batch_loss = du.to_numpy(batch_loss)
121 | flat_losses = batch_loss.flatten()
122 | flat_t = batch_t.flatten()
123 | bin_edges = np.linspace(0.0, 1.0 + 1e-3, num_bins+1)
124 | bin_idx = np.sum(bin_edges[:, None] <= flat_t[None, :], axis=0) - 1
125 | t_binned_loss = np.bincount(bin_idx, weights=flat_losses)
126 | t_binned_n = np.bincount(bin_idx)
127 | stratified_losses = {}
128 | if loss_name is None:
129 | loss_name = 'loss'
130 | for t_bin in np.unique(bin_idx).tolist():
131 | bin_start = bin_edges[t_bin]
132 | bin_end = bin_edges[t_bin+1]
133 | t_range = f'{loss_name} t=[{bin_start:.2f},{bin_end:.2f})'
134 | range_loss = t_binned_loss[t_bin] / t_binned_n[t_bin]
135 | stratified_losses[t_range] = range_loss
136 | return stratified_losses
137 |
138 |
139 |
140 | def process_folded_outputs(sample_path, folded_output, true_bb_pos=None):
141 | mpnn_results = {
142 | 'header': [],
143 | 'sequence': [],
144 | 'bb_rmsd': [],
145 | 'mean_plddt': [],
146 | 'folded_path': [],
147 | }
148 |
149 | if true_bb_pos is not None:
150 | mpnn_results['bb_rmsd_to_gt'] = []
151 | mpnn_results['fold_model_bb_rmsd_to_gt'] = []
152 |
153 | sample_feats = du.parse_pdb_feats('sample', sample_path)
154 | sample_ca_pos = sample_feats['bb_positions']
155 | def _calc_ca_rmsd(mask, folded_ca_pos):
156 | return superimpose(
157 | torch.tensor(sample_ca_pos)[None],
158 | torch.tensor(folded_ca_pos[None]),
159 | mask
160 | )[1].rmsd[0].item()
161 |
162 | sample_bb_pos = sample_feats['atom_positions'][:, :3].reshape(-1, 3)
163 | def _calc_bb_rmsd(mask, sample_bb_pos, folded_bb_pos):
164 | aligned_rmsd = superimpose(
165 | torch.tensor(sample_bb_pos)[None],
166 | torch.tensor(folded_bb_pos[None]),
167 | mask[:, None].repeat(1, 3).reshape(-1)
168 | )
169 | return aligned_rmsd[1].item()
170 |
171 | for _, row in folded_output.iterrows():
172 | folded_feats = du.parse_pdb_feats('folded', row.folded_path)
173 | seq = du.aatype_to_seq(folded_feats['aatype'])
174 | folded_ca_pos = folded_feats['bb_positions']
175 | folded_bb_pos = folded_feats['atom_positions'][:, :3].reshape(-1, 3)
176 |
177 | res_mask = torch.ones(folded_ca_pos.shape[0])
178 |
179 | if true_bb_pos is not None:
180 | bb_rmsd_to_gt = _calc_bb_rmsd(res_mask, sample_bb_pos, true_bb_pos)
181 | mpnn_results['bb_rmsd_to_gt'].append(bb_rmsd_to_gt)
182 | fold_model_bb_rmsd_to_gt = _calc_bb_rmsd(res_mask, folded_bb_pos, true_bb_pos)
183 | mpnn_results['fold_model_bb_rmsd_to_gt'].append(fold_model_bb_rmsd_to_gt)
184 | bb_rmsd = _calc_bb_rmsd(res_mask, sample_bb_pos, folded_bb_pos)
185 | mpnn_results['bb_rmsd'].append(bb_rmsd)
186 | mpnn_results['folded_path'].append(row.folded_path)
187 | mpnn_results['header'].append(row.header)
188 | mpnn_results['sequence'].append(seq)
189 | mpnn_results['mean_plddt'].append(row.plddt)
190 | mpnn_results = pd.DataFrame(mpnn_results)
191 | mpnn_results['sample_path'] = sample_path
192 | return mpnn_results
193 |
194 | def extract_clusters_from_maxcluster_out(file_path):
195 | # Extracts cluster information from the stdout of a maxcluster run
196 | cluster_to_paths = {}
197 | paths_to_cluster = {}
198 | read_mode = False
199 | with open(file_path, 'r') as file:
200 | lines = file.readlines()
201 | for line in lines:
202 | if line == "INFO : Item Cluster\n":
203 | read_mode = True
204 | continue
205 |
206 | if line == "INFO : ======================================\n":
207 | read_mode = False
208 |
209 | if read_mode:
210 | # Define a regex pattern to match the second number and the path
211 | pattern = r"INFO\s+:\s+\d+\s:\s+(\d+)\s+(\S+)"
212 |
213 | # Use re.search to find the first match in the string
214 | match = re.search(pattern, line)
215 |
216 | # Check if a match is found
217 | if match:
218 | # Extract the second number and the path
219 | cluster_id = match.group(1)
220 | path = match.group(2)
221 | if cluster_id not in cluster_to_paths:
222 | cluster_to_paths[cluster_id] = [path]
223 | else:
224 | cluster_to_paths[cluster_id].append(path)
225 | paths_to_cluster[path] = cluster_id
226 |
227 | else:
228 | raise ValueError(f"Could not parse line: {line}")
229 |
230 | return cluster_to_paths, paths_to_cluster
231 |
232 | def calc_mdtraj_metrics(pdb_path):
233 | try:
234 | traj = md.load(pdb_path)
235 | pdb_ss = md.compute_dssp(traj, simplified=True)
236 | pdb_coil_percent = np.mean(pdb_ss == 'C')
237 | pdb_helix_percent = np.mean(pdb_ss == 'H')
238 | pdb_strand_percent = np.mean(pdb_ss == 'E')
239 | pdb_ss_percent = pdb_helix_percent + pdb_strand_percent
240 | pdb_rg = md.compute_rg(traj)[0]
241 | except IndexError as e:
242 | print('Error in calc_mdtraj_metrics: {}'.format(e))
243 | pdb_ss_percent = 0.0
244 | pdb_coil_percent = 0.0
245 | pdb_helix_percent = 0.0
246 | pdb_strand_percent = 0.0
247 | pdb_rg = 0.0
248 | return {
249 | 'non_coil_percent': pdb_ss_percent,
250 | 'coil_percent': pdb_coil_percent,
251 | 'helix_percent': pdb_helix_percent,
252 | 'strand_percent': pdb_strand_percent,
253 | 'radius_of_gyration': pdb_rg,
254 | }
255 |
256 | def calc_aatype_metrics(generated_aatypes):
257 | # generated_aatypes (B, N)
258 | unique_aatypes, raw_counts = np.unique(generated_aatypes, return_counts=True)
259 |
260 | # pad with 0's in case it didn't generate any of a certain type
261 | clean_counts = []
262 | for i in range(20):
263 | if i in unique_aatypes:
264 | clean_counts.append(raw_counts[np.where(unique_aatypes == i)[0][0]])
265 | else:
266 | clean_counts.append(0)
267 |
268 | # from the scope128 dataset
269 | reference_normalized_counts = [
270 | 0.0739, 0.05378621, 0.0410424, 0.05732177, 0.01418736, 0.03995128,
271 | 0.07562267, 0.06695857, 0.02163064, 0.0580802, 0.09333149, 0.06777057,
272 | 0.02034217, 0.03673995, 0.04428474, 0.05987899, 0.05502958, 0.01228988,
273 | 0.03233601, 0.07551553
274 | ]
275 |
276 | reference_normalized_counts = np.array(reference_normalized_counts)
277 |
278 | normalized_counts = clean_counts / np.sum(clean_counts)
279 |
280 | # compute the hellinger distance between the normalized counts
281 | # and the reference normalized counts
282 |
283 | hellinger_distance = np.sqrt(np.sum(np.square(np.sqrt(normalized_counts) - np.sqrt(reference_normalized_counts))))
284 |
285 | return {
286 | 'aatype_histogram_dist': hellinger_distance
287 | }
288 |
289 | def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0):
290 | ca_bond_dists = np.linalg.norm(
291 | ca_pos - np.roll(ca_pos, 1, axis=0), axis=-1)[1:]
292 | ca_ca_dev = np.mean(np.abs(ca_bond_dists - residue_constants.ca_ca))
293 | ca_ca_valid = np.mean(ca_bond_dists < (residue_constants.ca_ca + bond_tol))
294 |
295 | ca_ca_dists2d = np.linalg.norm(
296 | ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
297 | inter_dists = ca_ca_dists2d[np.where(np.triu(ca_ca_dists2d, k=0) > 0)]
298 | clashes = inter_dists < clash_tol
299 | return {
300 | 'ca_ca_deviation': ca_ca_dev,
301 | 'ca_ca_valid_percent': ca_ca_valid,
302 | 'num_ca_ca_clashes': np.sum(clashes),
303 | }
304 |
305 | class CPU_Unpickler(pickle.Unpickler):
306 | """Pytorch pickle loading workaround.
307 |
308 | https://github.com/pytorch/pytorch/issues/16797
309 | """
310 | def find_class(self, module, name):
311 | if module == 'torch.storage' and name == '_load_from_bytes':
312 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
313 | else: return super().find_class(module, name)
314 |
315 |
316 |
317 | def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
318 | """Read data from a pickle file."""
319 | try:
320 | if use_torch:
321 | return torch.load(read_path, map_location=map_location)
322 | else:
323 | with open(read_path, 'rb') as handle:
324 | return pickle.load(handle)
325 | except Exception as e:
326 | try:
327 | with open(read_path, 'rb') as handle:
328 | return CPU_Unpickler(handle).load()
329 | except Exception as e2:
330 | if verbose:
331 | print(f'Failed to read {read_path}. First error: {e}\n Second error: {e2}')
332 | raise(e)
333 |
334 | def parse_chain_feats(chain_feats, scale_factor=1., center=True):
335 | chain_feats['bb_mask'] = chain_feats['atom_mask'][:, CA_IDX]
336 | bb_pos = chain_feats['atom_positions'][:, CA_IDX]
337 | if center:
338 | bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5)
339 | centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
340 | scaled_pos = centered_pos / scale_factor
341 | else:
342 | scaled_pos = chain_feats['atom_positions'] / scale_factor
343 | chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
344 | chain_feats['bb_positions'] = chain_feats['atom_positions'][:, CA_IDX]
345 | return chain_feats
346 |
--------------------------------------------------------------------------------
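The distogram in `calc_distogram` one-hot-encodes every pairwise distance against `num_bins` half-open intervals, with the last bin extended to infinity. The self-contained re-implementation below (toy values, no package imports) shows the output shape and the at-most-one-hot property:

import torch

def distogram(pos, min_bin=2.0, max_bin=22.0, num_bins=8):
    d = torch.linalg.norm(pos[:, :, None, :] - pos[:, None, :, :], dim=-1)[..., None]
    lower = torch.linspace(min_bin, max_bin, num_bins)
    upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)  # last bin is open-ended
    return ((d > lower) * (d < upper)).float()

pos = torch.randn(1, 5, 3) * 10.0        # 5 random "residues"
dgram = distogram(pos)
print(dgram.shape)                       # torch.Size([1, 5, 5, 8])
print(dgram.sum(-1).max())               # at most one bin is active per pair (none if d <= min_bin)
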
/proteinblobs/parsing.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from argparse import ArgumentParser
12 | import subprocess, os, sys
13 |
14 |
15 | def parse_train_args(args=sys.argv):
16 | parser = ArgumentParser()
17 |
18 |
19 |
20 | ## Trainer settings
21 | group = parser.add_argument_group("Epoch settings")
22 | group.add_argument("--ckpt", type=str, default=None)
23 | group.add_argument("--validate", action='store_true', default=False)
24 | group.add_argument("--num_workers", type=int, default=10)
25 | group.add_argument("--epochs", type=int, default=100)
26 | group.add_argument("--train_batches", type=int, default=None)
27 | group.add_argument("--val_batches", type=int, default=None)
28 | group.add_argument("--batch_size", type=int, default=16)
29 | group.add_argument("--val_freq", type=int, default=None)
30 | group.add_argument("--val_epoch_freq", type=int, default=1)
31 | group.add_argument("--seed", type=int, default=137)
32 | group.add_argument("--no_validate", action='store_true')
33 |
34 | ## Inference settings
35 | group = parser.add_argument_group("Inference Settings")
36 | group.add_argument("--inference_gating", type=float, default=1)
37 | group.add_argument("--guidance_weight", type=float, default=1)
38 | group.add_argument("--seq_guidance_weight", type=float, default=1)
39 | group.add_argument("--guidance", action='store_true')
40 |     group.add_argument("--sc_guidance_mode", choices=['separate', 'guided', 'unguided'], default='separate', help='which model output to use as self conditioning input to the models')
41 | group.add_argument("--interpolate_sc", action='store_true')
42 | group.add_argument("--only_save_blobs", action='store_true')
43 | group.add_argument("--interpolate_sc_weight", type=float, default=1)
44 | group.add_argument("--inference_rot_scaling", type=float, default=10)
45 | group.add_argument("--fixed_inference_size", type=int, default=None)
46 |
47 | ## Model settings
48 | group = parser.add_argument_group("Model Settings")
49 | group.add_argument("--pretrained_mf_path", type=str, default='weights/last.ckpt')
50 | group.add_argument("--finetune", action='store_true')
51 | group.add_argument("--freeze_weights", action='store_true')
52 | group.add_argument("--extra_attn_layer", action='store_true')
53 | group.add_argument("--blob_attention", action='store_true')
54 | group.add_argument("--blob_drop_prob", type=float, default=0.0)
55 | group.add_argument("--self_condition", action='store_true')
56 |
57 | ## Data settings
58 | group = parser.add_argument_group("Data Settings")
59 | group.add_argument('--dataset', choices=['multiflow', 'genie'], default='genie')
60 | group.add_argument("--scop_dir", type=str, default='')
61 | group.add_argument("--genie_db_path", type=str, default='', help= '')
62 | group.add_argument("--scop_nmr_csv", type=str, default='scop_nmr_files.csv')
63 | group.add_argument('--pkl_dir', type=str, default='data')
64 | group.add_argument('--clustering_dir', type=str, default='proteinblobs/clustering_20rpc')
65 | group.add_argument('--latents_path', type=str, default='projects/protfid/cache/encodings/SimCLR-ca-dim_32/mf_all.pth')
66 | group.add_argument('--latents_order', type=str, default='projects/protfid/cache/encodings/SimCLR-ca-dim_32/mf_all_names.pickle')
67 | group.add_argument("--use_latents", action='store_true')
68 | group.add_argument('--length_dist_npz', type=str, default=None)
69 | group.add_argument("--crop", type=int, default=256)
70 | group.add_argument("--no_crop", action='store_true')
71 | group.add_argument("--res_per_cluster", type=int, default=20)
72 | group.add_argument("--min_blob_thresh", type=float, default=5)
73 | group.add_argument("--max_blob_thresh", type=float, default=10)
74 | group.add_argument("--dataset_multiplicity", type=int, default=None)
75 | group.add_argument("--repeat_dataset", type=int, default=1)
76 | group.add_argument("--no_pad", action='store_true')
77 | group.add_argument("--overfit", action='store_true')
78 | group.add_argument("--overfit_rot", action='store_true')
79 |
80 | ## Synthetic Blob data settings
81 |     group = parser.add_argument_group("Synthetic Blob Settings")
82 | group.add_argument("--synthetic_blobs", action='store_true')
83 | group.add_argument('--num_blobs', type=int, default=5)
84 | group.add_argument('--nu', type=int, default=10) # vary the blob anisotropy [10, 20, 50, 100]
85 | group.add_argument('--psi', type=int, default=5) # FIX this to give proteins with O(200) residues
86 | group.add_argument('--sigma', type=int, default=8) # vary this compactness parameter
87 | group.add_argument('--helix_frac', type=float, default=0.5)
88 |
89 | ## Logging args
90 | group = parser.add_argument_group("Logging settings")
91 | group.add_argument("--print_freq", type=int, default=100)
92 | group.add_argument("--ckpt_freq", type=int, default=1)
93 | group.add_argument("--wandb", action="store_true")
94 | group.add_argument("--save_val", action="store_true")
95 | group.add_argument("--save_single_pdb", action="store_true")
96 | group.add_argument("--run_name", type=str, default="default")
97 | group.add_argument("--inf_batches", type=int, default=4)
98 | group.add_argument("--designability", action='store_true')
99 | group.add_argument("--self_consistency", action='store_true')
100 | group.add_argument("--ref_as_sample", action='store_true')
101 | group.add_argument("--num_timesteps", type=int, default=500)
102 | group.add_argument("--designability_freq", type=int, default=1)
103 | group.add_argument("--num_designability_prots", type=int, default=1000)
104 | group.add_argument("--pmpnn_path", type=str, default='../ProteinMPNN')
105 |
106 | ## Optimization settings
107 | group = parser.add_argument_group("Optimization settings")
108 | group.add_argument("--accumulate_grad", type=int, default=1)
109 | group.add_argument("--lr_scheduler", action='store_true')
110 | group.add_argument("--grad_clip", type=float, default=1.)
111 | group.add_argument("--check_grad", action='store_true')
112 | group.add_argument('--grad_checkpointing', action='store_true')
113 | group.add_argument('--adamW', action='store_true')
114 | group.add_argument('--ema', action='store_true')
115 | group.add_argument('--ema_decay', type=float, default=0.999)
116 | group.add_argument("--lr", type=float, default=1e-4)
117 |
118 | args = parser.parse_args()
119 | os.environ["MODEL_DIR"] = os.path.join("workdir", args.run_name)
120 | os.environ["WANDB_LOGGING"] = str(int(args.wandb))
121 | return args
122 |
123 |
124 |
--------------------------------------------------------------------------------
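A hypothetical way to exercise the parser above from Python (the flag values are arbitrary and the run name is a placeholder); note that `parse_train_args` also sets the `MODEL_DIR` and `WANDB_LOGGING` environment variables as a side effect:

import sys
from proteinblobs.parsing import parse_train_args

sys.argv = ["train.py", "--run_name", "blob_ft", "--blob_attention",
            "--extra_attn_layer", "--freeze_weights", "--batch_size", "8"]
args = parse_train_args()
print(args.blob_attention, args.freeze_weights, args.batch_size)   # True True 8
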
/proteinblobs/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import os
12 | import re
13 | import subprocess
14 |
15 | import numpy as np
16 | import torch
17 |
18 | from .multiflow.data.all_atom import compute_backbone
19 | import openfold.np.protein as protein
20 |
21 |
22 | def rigid_transform_3D(A, B, verbose=False):
23 | # Transforms A to look like B
24 | # https://github.com/nghiaho12/rigid_transform_3D
25 | assert A.shape == B.shape
26 | A = A.T
27 | B = B.T
28 |
29 | num_rows, num_cols = A.shape
30 | if num_rows != 3:
31 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
32 |
33 | num_rows, num_cols = B.shape
34 | if num_rows != 3:
35 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
36 |
37 | # find mean column wise
38 | centroid_A = np.mean(A, axis=1)
39 | centroid_B = np.mean(B, axis=1)
40 |
41 | # ensure centroids are 3x1
42 | centroid_A = centroid_A.reshape(-1, 1)
43 | centroid_B = centroid_B.reshape(-1, 1)
44 |
45 | # subtract mean
46 | Am = A - centroid_A
47 | Bm = B - centroid_B
48 |
49 | H = Am @ np.transpose(Bm)
50 |
51 | # sanity check
52 | # if linalg.matrix_rank(H) < 3:
53 | # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H)))
54 |
55 | # find rotation
56 | U, S, Vt = np.linalg.svd(H)
57 | R = Vt.T @ U.T
58 |
59 | # special reflection case
60 | reflection_detected = False
61 | if np.linalg.det(R) < 0:
62 | if verbose:
63 |             print("det(R) < 0, reflection detected! Correcting for it ...")
64 | Vt[2, :] *= -1
65 | R = Vt.T @ U.T
66 | reflection_detected = True
67 |
68 | t = -R @ centroid_A + centroid_B
69 | optimal_A = R @ A + t
70 |
71 | return optimal_A.T, R, t, reflection_detected
72 |
73 |
74 | def run_pmpnn(input_dir, output_path, pmpnn_path="../ProteinMPNN"):
75 | os.makedirs(os.path.join(input_dir, "seqs"), exist_ok=True)
76 | process = subprocess.Popen(
77 | [
78 | "python",
79 | os.path.join(pmpnn_path, "helper_scripts/parse_multiple_chains.py"),
80 | f"--input_path={input_dir}",
81 | f"--output_path={output_path}",
82 | ]
83 | )
84 | _ = process.wait()
85 |
86 | pmpnn_args = [
87 | "python",
88 | os.path.join(pmpnn_path, "protein_mpnn_run.py"),
89 | "--out_folder",
90 | input_dir,
91 | "--jsonl_path",
92 | output_path,
93 | "--num_seq_per_target",
94 | "8",
95 | "--sampling_temp",
96 | "0.1",
97 | "--seed",
98 | "38",
99 | "--batch_size",
100 | "1",
101 | ]
102 | print(" ".join(pmpnn_args))
103 |
104 | process = subprocess.run(pmpnn_args)
105 |
106 |
107 | def get_aligned_rmsd(pos_1, pos_2):
108 | aligned_pos_1 = rigid_transform_3D(pos_1, pos_2)[0]
109 | return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1))
110 |
111 |
112 | def upgrade_state_dict(state_dict, prefixes=["encoder.sentence_encoder.", "encoder."]):
113 | """Removes prefixes like 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
114 | pattern = re.compile("^" + "|".join(prefixes))
115 | state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
116 | return state_dict
117 |
118 |
119 | def softmax_cross_entropy(logits, labels):
120 | loss = -1 * torch.sum(
121 | labels * torch.nn.functional.log_softmax(logits, dim=-1),
122 | dim=-1,
123 | )
124 | return loss
125 |
126 |
127 | def compute_distogram_loss(
128 | logits,
129 | pseudo_beta,
130 | pseudo_beta_mask,
131 | min_bin=2.3125,
132 | max_bin=21.6875,
133 | no_bins=64,
134 | eps=1e-6,
135 | **kwargs,
136 | ):
137 | boundaries = torch.linspace(
138 | min_bin,
139 | max_bin,
140 | no_bins - 1,
141 | device=logits.device,
142 | )
143 | boundaries = boundaries**2
144 |
145 | dists = torch.sum(
146 | (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
147 | dim=-1,
148 | keepdims=True,
149 | )
150 |
151 | true_bins = torch.sum(dists > boundaries, dim=-1)
152 |
153 | errors = softmax_cross_entropy(
154 | logits,
155 | torch.nn.functional.one_hot(true_bins, no_bins),
156 | )
157 |
158 | square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
159 |
160 | # FP16-friendly sum. Equivalent to:
161 | # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
162 | # (eps + torch.sum(square_mask, dim=(-1, -2))))
163 | denom = eps + torch.sum(square_mask, dim=(-1, -2))
164 | mean = errors * square_mask
165 | mean = torch.sum(mean, dim=-1)
166 | mean = mean / denom[..., None]
167 | mean = torch.sum(mean, dim=-1)
168 |
169 | # Average over the batch dimensions
170 | mean = torch.mean(mean)
171 |
172 | return mean
173 |
174 |
175 | class HarmonicPrior:
176 | def __init__(self, N=256, a=3 / (3.8**2)):
177 | J = torch.zeros(N, N)
178 | for i, j in zip(np.arange(N - 1), np.arange(1, N)):
179 | J[i, i] += a
180 | J[j, j] += a
181 | J[i, j] = J[j, i] = -a
182 | D, P = torch.linalg.eigh(J)
183 | D_inv = 1 / D
184 | D_inv[0] = 0
185 | self.P, self.D_inv = P, D_inv
186 | self.N = N
187 |
188 | def to(self, device):
189 | self.P = self.P.to(device)
190 | self.D_inv = self.D_inv.to(device)
191 | return self
192 |
193 | def sample(self, batch_dims=()):
194 | return self.P @ (
195 | torch.sqrt(self.D_inv)[:, None]
196 | * torch.randn(*batch_dims, self.N, 3, device=self.P.device)
197 | )
198 |
199 |
200 | def adjust_oxygen_pos(atom_37: torch.Tensor, pos_is_known=None) -> torch.Tensor:
201 | """
202 | Imputes the position of the oxygen atom on the backbone by using adjacent frame information.
203 | Specifically, we say that the oxygen atom is in the plane created by the Calpha and C from the
204 | current frame and the nitrogen of the next frame. The oxygen is then placed c_o_bond_length Angstrom
205 | away from the C in the current frame in the direction away from the Ca-C-N triangle.
206 |
207 | For cases where the next frame is not available, for example we are at the C-terminus or the
208 | next frame is not available in the data then we place the oxygen in the same plane as the
209 | N-Ca-C of the current frame and pointing in the same direction as the average of the
210 | Ca->C and Ca->N vectors.
211 |
212 | Args:
213 | atom_37 (torch.Tensor): (N, 37, 3) tensor of positions of the backbone atoms in atom_37 ordering
214 | which is ['N', 'CA', 'C', 'CB', 'O', ...]
215 | pos_is_known (torch.Tensor): (N,) mask for known residues.
216 | """
217 |
218 | N = atom_37.shape[0]
219 | assert atom_37.shape == (N, 37, 3)
220 |
221 | # Get vectors to Carbonly from Carbon alpha and N of next residue. (N-1, 3)
222 | # Note that the (N,) ordering is from N-terminal to C-terminal.
223 |
224 | # Calpha to carbonyl both in the current frame.
225 | calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / (
226 | torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7
227 | )
228 | # For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0.
229 | # The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change.
230 |
231 | # Nitrogen of the next frame to carbonyl of the current frame.
232 | nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / (
233 | torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7
234 | )
235 |
236 | carbonyl_to_oxygen: torch.Tensor = (
237 | calpha_to_carbonyl + nitrogen_to_carbonyl
238 | ) # (N-1, 3)
239 | carbonyl_to_oxygen = carbonyl_to_oxygen / (
240 | torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7
241 | )
242 |
243 | atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23
244 |
245 | # Now we deal with frames for which there is no next frame available.
246 |
247 | # Calpha to carbonyl both in the current frame. (N, 3)
248 | calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / (
249 | torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
250 | )
251 | # Calpha to nitrogen both in the current frame. (N, 3)
252 | calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / (
253 | torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
254 | )
255 | carbonyl_to_oxygen_term: torch.Tensor = (
256 | calpha_to_carbonyl_term + calpha_to_nitrogen_term
257 | ) # (N, 3)
258 | carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
259 | torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7
260 | )
261 |
262 | # Create a mask that is 1 when the next residue is not available either
263 | # due to this frame being the C-terminus or the next residue is not
264 | # known due to pos_is_known being false.
265 |
266 | if pos_is_known is None:
267 | pos_is_known = torch.ones(
268 | (atom_37.shape[0],), dtype=torch.int64, device=atom_37.device
269 | )
270 |
271 | next_res_gone: torch.Tensor = ~pos_is_known.bool() # (N,)
272 | next_res_gone = torch.cat(
273 | [next_res_gone, torch.ones((1,), device=pos_is_known.device).bool()], dim=0
274 | ) # (N+1, )
275 | next_res_gone = next_res_gone[1:] # (N,)
276 |
277 | atom_37[next_res_gone, 4, :] = (
278 | atom_37[next_res_gone, 2, :] + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23
279 | )
280 |
281 | return atom_37
282 |
283 |
284 | def trans_to_atom37(trans):
285 | B, L, _ = trans.shape
286 | atom37 = trans.new_zeros(B, L, 37, 3)
287 | atom37[:, :, 1] = trans
288 | return atom37
289 |
290 |
291 | def transrot_to_atom37(rigids):
292 | atom37_traj = []
293 | B, L = rigids.shape
294 |
295 | atom37 = compute_backbone(rigids, torch.zeros((B, L, 2), device=rigids.device))[0]
296 | for i in range(B):
297 | atom37[i] = adjust_oxygen_pos(atom37[i], None)
298 | return atom37
299 |
300 |
301 | def rigid_transform_3D(A, B, verbose=False):
302 | # Transforms A to look like B
303 | # https://github.com/nghiaho12/rigid_transform_3D
304 | assert A.shape == B.shape
305 | A = A.T
306 | B = B.T
307 |
308 | num_rows, num_cols = A.shape
309 | if num_rows != 3:
310 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
311 |
312 | num_rows, num_cols = B.shape
313 | if num_rows != 3:
314 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
315 |
316 | # find mean column wise
317 | centroid_A = np.mean(A, axis=1)
318 | centroid_B = np.mean(B, axis=1)
319 |
320 | # ensure centroids are 3x1
321 | centroid_A = centroid_A.reshape(-1, 1)
322 | centroid_B = centroid_B.reshape(-1, 1)
323 |
324 | # subtract mean
325 | Am = A - centroid_A
326 | Bm = B - centroid_B
327 |
328 | H = Am @ np.transpose(Bm)
329 |
330 | # sanity check
331 | # if linalg.matrix_rank(H) < 3:
332 | # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H)))
333 |
334 | # find rotation
335 | U, S, Vt = np.linalg.svd(H)
336 | R = Vt.T @ U.T
337 |
338 | # special reflection case
339 | reflection_detected = False
340 | if np.linalg.det(R) < 0:
341 | if verbose:
342 |             print("det(R) < 0, reflection detected! Correcting for it ...")
343 | Vt[2, :] *= -1
344 | R = Vt.T @ U.T
345 | reflection_detected = True
346 |
347 | t = -R @ centroid_A + centroid_B
348 | optimal_A = R @ A + t
349 |
350 | return optimal_A.T, R, t, reflection_detected
351 |
352 |
353 | def get_aligned_rmsd(pos_1, pos_2):
354 | aligned_pos_1 = rigid_transform_3D(pos_1, pos_2)[0]
355 | return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1))
356 |
357 |
358 | def create_full_prot(
359 | atom37: np.ndarray,
360 | aatype=None,
361 | b_factors=None,
362 | ):
363 | assert atom37.ndim == 3
364 | assert atom37.shape[-1] == 3
365 | assert atom37.shape[-2] == 37
366 | n = atom37.shape[0]
367 | residue_index = np.arange(n)
368 | atom37_mask = (np.sum(np.abs(atom37), axis=-1) > 1e-7) & (
369 | np.sum(np.abs(atom37[:, 1:2]), axis=-1) > 1e-7
370 | )
371 | if b_factors is None:
372 | b_factors = np.zeros([n, 37])
373 | if aatype is None:
374 | aatype = np.zeros(n, dtype=int)
375 | return protein.Protein(
376 | atom_positions=atom37,
377 | atom_mask=atom37_mask,
378 | aatype=aatype,
379 | residue_index=residue_index,
380 | b_factors=b_factors,
381 | )
382 |
383 |
384 | def frames_to_pdb(frames, path):
385 | prots = []
386 | atom37 = transrot_to_atom37(frames).cpu().numpy()
387 | for i, pos in enumerate(atom37):
388 | prots.append(create_full_prot(pos))
389 | with open(path, "w") as f:
390 | f.write(prots_to_pdb(prots))
391 |
392 |
393 | def trans_to_pdb(trans, path):
394 | prots = []
395 | atom37 = trans_to_atom37(trans).cpu().numpy()
396 | for i, pos in enumerate(atom37):
397 | prots.append(create_full_prot(pos))
398 | with open(path, "w") as f:
399 | f.write(prots_to_pdb(prots))
400 |
401 |
402 | def atom37_to_pdb(atom37, path, mask=None):
403 | prots = []
404 | for i, pos in enumerate(atom37):
405 | if mask is not None:
406 | pos = pos[mask[i]]
407 | prots.append(create_full_prot(pos))
408 | with open(path, "w") as f:
409 | f.write(prots_to_pdb(prots))
410 |
411 |
412 | def prots_to_pdb(prots):
413 | ss = ""
414 | for i, prot in enumerate(prots):
415 | ss += f"MODEL {i}\n"
416 | prot = protein.to_pdb(prot)
417 | ss += "\n".join(prot.split("\n")[1:-2])
418 | ss += "\nENDMDL\n"
419 | return ss
420 |
421 |
422 | def compute_lddt(pos1, pos2, mask, cutoff=15.0, eps=1e-10, symmetric=False):
423 | dmat1 = torch.sqrt(
424 | eps + torch.sum((pos1[..., None, :] - pos1[..., None, :, :]) ** 2, axis=-1)
425 | )
426 | dmat2 = torch.sqrt(
427 | eps + torch.sum((pos2[..., None, :] - pos2[..., None, :, :]) ** 2, axis=-1)
428 | )
429 | if symmetric:
430 | dists_to_score = (dmat1 < cutoff) | (dmat2 < cutoff)
431 | else:
432 | dists_to_score = dmat1 < cutoff
433 | dists_to_score = (
434 | dists_to_score
435 | * mask.unsqueeze(-2)
436 | * mask.unsqueeze(-1)
437 | * (1.0 - torch.eye(mask.shape[-1]).to(mask))
438 | )
439 | dist_l1 = torch.abs(dmat1 - dmat2)
440 | score = (
441 | (dist_l1[..., None] < torch.tensor([0.5, 1.0, 2.0, 4.0]).to(pos1))
442 | .float()
443 | .mean(-1)
444 | )
445 | score = (dists_to_score * score).sum((-1, -2)) / dists_to_score.sum((-1, -2))
446 |
447 | return score
448 |
449 |
450 | def compute_fape(
451 | pred_frames,
452 | target_frames,
453 | frames_mask,
454 | pred_positions,
455 | target_positions,
456 | positions_mask,
457 | length_scale,
458 | l1_clamp_distance=None,
459 | thresh=None,
460 | eps=1e-8,
461 | ) -> torch.Tensor:
462 | """
463 | Computes FAPE loss.
464 |
465 | Args:
466 | pred_frames:
467 | [*, N_frames] Rigid object of predicted frames
468 | target_frames:
469 | [*, N_frames] Rigid object of ground truth frames
470 | frames_mask:
471 | [*, N_frames] binary mask for the frames
472 | pred_positions:
473 | [*, N_pts, 3] predicted atom positions
474 | target_positions:
475 | [*, N_pts, 3] ground truth positions
476 | positions_mask:
477 | [*, N_pts] positions mask
478 | length_scale:
479 | Length scale by which the loss is divided
480 |         thresh:
481 |             If set, only target atoms within this distance of their
482 |             frame origin contribute to the loss
483 | l1_clamp_distance:
484 | Cutoff above which distance errors are disregarded
485 | eps:
486 | Small value used to regularize denominators
487 | Returns:
488 | [*] loss tensor
489 | """
490 |
491 | # [*, N_frames, N_pts, 3]
492 | local_pred_pos = pred_frames.invert()[..., None].apply(
493 | pred_positions[..., None, :, :],
494 | )
495 | local_target_pos = target_frames.invert()[..., None].apply(
496 | target_positions[..., None, :, :],
497 | )
498 |
499 | error_dist = torch.sqrt(
500 | torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
501 | )
502 |
503 | if l1_clamp_distance is not None:
504 | error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
505 |
506 | normed_error = error_dist / length_scale
507 |
508 | if thresh is not None:
509 | thresh_mask = torch.sqrt(torch.sum(local_target_pos**2, dim=-1)) < thresh
510 | mask = thresh_mask * frames_mask[..., None] * positions_mask[..., None, :]
511 |
512 | normed_error = normed_error * mask
513 | normed_error = torch.sum(normed_error, dim=(-1, -2))
514 | normed_error = normed_error / (eps + torch.sum(mask, dim=(-1, -2)))
515 |
516 | else:
517 | normed_error = normed_error * frames_mask[..., None]
518 | normed_error = normed_error * positions_mask[..., None, :]
519 |
520 | normed_error = torch.sum(normed_error, dim=-1)
521 | normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
522 | normed_error = torch.sum(normed_error, dim=-1)
523 | normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
524 |
525 | return normed_error
526 |
--------------------------------------------------------------------------------
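A usage sketch for the `HarmonicPrior` above. It assumes the snippet is run from the repository root with the package's dependencies (torch, openfold) installed; the sizes are arbitrary. The default coupling a = 3 / 3.8**2 is what keeps consecutive samples roughly a Ca-Ca bond length apart:

import torch
from proteinblobs.utils import HarmonicPrior

prior = HarmonicPrior(N=64)               # 64 "residues"
x = prior.sample(batch_dims=(2,))         # [2, 64, 3]
print(x.shape)
# the average neighbour distance is on the order of 3.8 Angstrom, unlike an isotropic Gaussian
print(torch.linalg.norm(x[:, 1:] - x[:, :-1], dim=-1).mean())
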
/proteinblobs/visualize.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | from pymol import cmd
13 | import matplotlib
14 | import matplotlib.pyplot as plt
15 |
16 | matplotlib.use('Agg')
17 | import numpy as np
18 | import plotly.graph_objects as go
19 | import torch
20 | import os
21 | from pathlib import Path
22 | from pymol.cgo import BEGIN, COLOR, TRIANGLES, VERTEX, NORMAL, END
23 |
24 | import math
25 |
26 |
27 | def vertex(a1, a2, a3, u, v, M, r0):
28 | #https://github.com/vlasenkov/pymol-iellipsoid/blob/master/iellipsoid.py
29 | vrtx = M.dot(np.array([
30 | a1 * math.cos(u) * math.cos(v),
31 | a2 * math.cos(u) * math.sin(v),
32 | a3 * math.sin(u)
33 | ]))
34 | nrml = M.dot(np.array([
35 | math.cos(u) * math.cos(v) / a1,
36 | math.cos(u) * math.sin(v) / a2,
37 | math.sin(u) / a3
38 | ]))
39 | return vrtx + r0, nrml
40 |
41 |
42 | def draw_ellipse(r0, C, color=[0.5, 0.5, 0.5], transparency=0.0, name='', resolution=12):
43 | #https://github.com/vlasenkov/pymol-iellipsoid/blob/master/iellipsoid.py
44 | ws, vs = np.linalg.eig(C)
45 | # M = np.linalg.inv(vs)
46 | M = vs
47 | a1, a2, a3 = np.sqrt(ws)
48 | u_segs = resolution
49 | v_segs = resolution
50 | mesh = [BEGIN, TRIANGLES, COLOR]
51 | mesh.extend(color)
52 | dU = math.pi / u_segs
53 | dV = 2 * math.pi / v_segs
54 | U = -math.pi / 2
55 | for Y in range(0, u_segs):
56 | V = math.pi
57 | for X in range(0, v_segs):
58 |
59 | (x1, y1, z1), (n1x, n1y, n1z) = vertex(a1, a2, a3,
60 | U, V, M, r0)
61 | (x2, y2, z2), (n2x, n2y, n2z) = vertex(a1, a2, a3,
62 | U + dU, V, M, r0)
63 | (x3, y3, z3), (n3x, n3y, n3z) = vertex(a1, a2, a3,
64 | U + dU, V + dV,
65 | M, r0)
66 | (x4, y4, z4), (n4x, n4y, n4z) = vertex(a1, a2, a3,
67 | U, V + dV, M, r0)
68 |
69 | mesh.extend([NORMAL, n1x, n1y, n1z, VERTEX, x1, y1, z1])
70 | mesh.extend([NORMAL, n2x, n2y, n2z, VERTEX, x2, y2, z2])
71 | mesh.extend([NORMAL, n4x, n4y, n4z, VERTEX, x4, y4, z4])
72 | mesh.extend([NORMAL, n2x, n2y, n2z, VERTEX, x2, y2, z2])
73 | mesh.extend([NORMAL, n3x, n3y, n3z, VERTEX, x3, y3, z3])
74 | mesh.extend([NORMAL, n4x, n4y, n4z, VERTEX, x4, y4, z4])
75 |
76 | V += dV
77 | U += dU
78 | mesh.append(END)
79 | cmd.load_cgo(mesh, name)
80 | cmd.set("cgo_transparency", transparency, name)
81 |
82 |
83 | def visualize_blobs(tmp_paths, means, covars, save_path=None, ax=None):
84 | cmd.reinitialize()
85 | for mean, covar in zip(means, covars):
86 | draw_ellipse(mean, 5 * covar, transparency=0.5)
87 | for tmp_path in tmp_paths:
88 | cmd.load(tmp_path)
89 | os.remove(tmp_path)
90 | cmd.orient(Path(tmp_paths[-1]).stem)
91 | #cmd.spectrum()
92 | cmd.bg_color('white')
93 | cmd.set('ray_trace_mode', 0)
94 | cmd.set('depth_cue', 'off')
95 | cmd.set('ray_shadows', 'off')
96 |
97 | if save_path is not None:
98 | cmd.save(save_path)
99 |     # render a PNG in both cases; `img` is used below whether or not an axis was given
100 |     cmd.png('/tmp/tmp.png', 640, 640)
101 |     img = plt.imread('/tmp/tmp.png')
102 |     os.remove('/tmp/tmp.png')
103 |
104 | if ax is None:
105 | plt.imshow(img)
106 | plt.axis('off') # Optional: Hide axis ticks and labels
107 | plt.show()
108 | else:
109 | ax.imshow(img)
110 | ax.axis('off')
111 |
112 | def plot_point_cloud(point_clouds, path=None):
113 | # Takes a list of point cloud tensors and plots them
114 | if not isinstance(point_clouds, list):
115 | point_clouds = [point_clouds]
116 |
117 | colors = ['red', 'blue', 'green', 'yellow', 'orange'] # List of colors for each point cloud
118 | traces = [] # List to store individual traces for each point cloud
119 |
120 | for i, point_cloud in enumerate(point_clouds):
121 | if isinstance(point_cloud, np.ndarray):
122 | pass
123 | elif isinstance(point_cloud, torch.Tensor):
124 | point_cloud = point_cloud.numpy()
125 |
126 | x_data = point_cloud[:, 0]
127 | y_data = point_cloud[:, 1]
128 | z_data = point_cloud[:, 2]
129 |
130 | # Create a trace for each point cloud with a different color
131 | trace = go.Scatter3d(
132 | x=x_data,
133 | y=y_data,
134 | z=z_data,
135 | mode='markers',
136 | marker=dict(
137 | size=5,
138 | opacity=0.8,
139 | color=colors[i % len(colors)] # Assign color based on the index of the point cloud
140 | ),
141 | name=f"Point Cloud {i + 1}"
142 | )
143 | traces.append(trace)
144 |
145 | # Create the layout
146 | layout = go.Layout(
147 | scene=dict(
148 | aspectmode='data'
149 | )
150 | )
151 |
152 | # Create the figure and add the traces to it
153 | fig = go.Figure(data=traces, layout=layout)
154 |
155 | if path is None:
156 | fig.show()
157 | else:
158 | fig.write_html(path)
--------------------------------------------------------------------------------
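The ellipsoid drawn by `draw_ellipse` comes from an eigendecomposition of the blob covariance: the square roots of the eigenvalues are the semi-axes, and the eigenvector matrix rotates a parameterised unit sphere into place. A NumPy-only sketch of that mapping (using `eigh` since a covariance is symmetric; the toy covariance and angles are made up):

import math
import numpy as np

C = np.diag([4.0, 1.0, 0.25])            # toy covariance, already axis-aligned
w, V = np.linalg.eigh(C)                 # eigenvalues / eigenvectors
a1, a2, a3 = np.sqrt(w)                  # semi-axes, as in draw_ellipse
u, v = 0.3, 1.2                          # latitude / longitude parameters
p = np.array([a1 * math.cos(u) * math.cos(v),
              a2 * math.cos(u) * math.sin(v),
              a3 * math.sin(u)])
surface_point = V @ p                    # rotate into the covariance's frame
print(a1, a2, a3, surface_point)
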
/proteinblobs/wrapper.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from .logger import get_logger
12 | logger = get_logger(__name__)
13 | import pytorch_lightning as pl
14 | import numpy as np
15 | import torch, time, wandb
16 | from collections import defaultdict
17 | from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
18 | import os, pandas as pd
19 | from openfold.utils.tensor_utils import tensor_tree_map
20 |
21 | def gather_log(log, world_size):
22 | if world_size == 1:
23 | return log
24 | log_list = [None] * world_size
25 | torch.distributed.all_gather_object(log_list, log)
26 | log = {key: sum([l[key] for l in log_list], []) for key in log}
27 | return log
28 |
29 |
30 | def get_log_mean(log):
31 | out = {}
32 | for key in log:
33 | try:
34 | out[key] = np.nanmean(log[key])
35 | except:
36 | pass
37 | return out
38 |
39 | class Wrapper(pl.LightningModule):
40 |
41 | def __init__(self, args):
42 | super().__init__()
43 | self.save_hyperparameters()
44 | self.args = args
45 | self._log = defaultdict(list)
46 | self.last_log_time = time.time()
47 |         self.iter_step = 0
48 |         self.cached_weights = None  # populated when EMA weights are swapped in during validation
49 | def log(self, key, data, extend=False):
50 | if isinstance(data, torch.Tensor):
51 | data = data.mean().item()
52 | log = self._log
53 | if extend:
54 | if self.stage == 'train' or self.args.validate:
55 | log["iter_" + key].extend(data)
56 | log[self.stage + "_" + key].extend(data)
57 | else:
58 | if self.stage == 'train' or self.args.validate:
59 | log["iter_" + key].append(data)
60 | log[self.stage + "_" + key].append(data)
61 |
62 | def load_ema_weights(self):
63 | logger.info('Loading EMA weights')
64 | clone_param = lambda t: t.detach().clone()
65 | self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
66 | self.model.load_state_dict(self.ema.state_dict()["params"])
67 |
68 | def restore_cached_weights(self):
69 | logger.info('Restoring cached weights')
70 | self.model.load_state_dict(self.cached_weights)
71 | self.cached_weights = None
72 |
73 | def on_before_zero_grad(self, *args, **kwargs):
74 | if self.args.ema:
75 | self.ema.update(self.model)
76 |
77 | def training_step(self, batch, batch_idx):
78 | self.stage = 'train'
79 | if self.args.ema:
80 | if (self.ema.device != self.device):
81 | self.ema.to(self.device)
82 | return self.general_step(batch)
83 |
84 | def validation_step(self, batch, batch_idx):
85 | self.stage = 'val'
86 | if self.args.ema:
87 | if (self.ema.device != self.device):
88 | self.ema.to(self.device)
89 | if (self.cached_weights is None):
90 | self.load_ema_weights()
91 |
92 | self.general_step(batch)
93 | self.validation_step_extra(batch, batch_idx)
94 | if self.args.validate and self.iter_step % self.args.print_freq == 0:
95 | self.print_log()
96 |
97 | def general_step(self, batch):
98 | pass
99 |
100 | def validation_step_extra(self, batch, batch_idx):
101 | pass
102 |
103 | def on_train_epoch_end(self):
104 | self.print_log(prefix='train', save=False)
105 |
106 | def on_validation_epoch_end(self):
107 | if self.args.ema:
108 | self.restore_cached_weights()
109 | self.on_validation_epoch_end_extra()
110 | self.print_log(prefix='val', save=self.args.save_val)
111 |
112 | def on_validation_epoch_end_extra(self):
113 | pass
114 |
115 | def on_before_optimizer_step(self, optimizer):
116 | if (self.trainer.global_step + 1) % self.args.print_freq == 0:
117 | self.print_log()
118 |
119 | if self.args.check_grad:
120 | for name, p in self.model.named_parameters():
121 | if p.grad is None:
122 | logger.warning(f"Param {name} has no grad")
123 |
124 | def on_load_checkpoint(self, checkpoint):
125 | logger.info('Loading EMA state dict')
126 | if self.args.ema:
127 | ema = checkpoint["ema"]
128 | self.ema.load_state_dict(ema)
129 |
130 | def on_save_checkpoint(self, checkpoint):
131 | if self.args.ema:
132 | if self.cached_weights is not None:
133 | self.restore_cached_weights()
134 | checkpoint["ema"] = self.ema.state_dict()
135 |
136 | def print_log(self, prefix='iter', save=False, extra_logs=None):
137 | log = self._log
138 | log = {key: log[key] for key in log if f"{prefix}_" in key}
139 | log = gather_log(log, self.trainer.world_size)
140 | mean_log = get_log_mean(log)
141 |
142 | mean_log.update({
143 | 'epoch': self.trainer.current_epoch,
144 | 'trainer_step': self.trainer.global_step + int(prefix == 'iter'),
145 | 'iter_step': self.iter_step,
146 | f'{prefix}_count': len(log[next(iter(log))]),
147 |
148 | })
149 | if extra_logs:
150 | mean_log.update(extra_logs)
151 | try:
152 | for param_group in self.optimizers().optimizer.param_groups:
153 | mean_log['lr'] = param_group['lr']
154 | except:
155 | pass
156 |
157 | if self.trainer.is_global_zero:
158 | logger.info(str(mean_log))
159 | if self.args.wandb:
160 | wandb.log(mean_log)
161 | if save:
162 | path = os.path.join(os.environ["MODEL_DIR"],f"{prefix}_{self.trainer.current_epoch}.csv")
163 | max_len = max([len(v) for k,v in log.items()])
164 | pd.DataFrame({k:v for k,v in log.items() if len(v) == max_len}).to_csv(path)
165 | for key in list(log.keys()):
166 | if f"{prefix}_" in key:
167 | del self._log[key]
168 |
169 | def configure_optimizers(self):
170 | cls = torch.optim.AdamW if self.args.adamW else torch.optim.Adam
171 | optimizer = cls(
172 | filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.args.lr,
173 | )
174 |
175 | if self.args.lr_scheduler:
176 | lr_scheduler = AlphaFoldLRScheduler(optimizer, max_lr=self.args.lr)
177 | return {
178 | "optimizer": optimizer,
179 | "lr_scheduler": {
180 | "scheduler": lr_scheduler,
181 | "interval": "step",
182 | "name": "AlphaFoldLRScheduler",
183 | }
184 | }
185 | else:
186 | return optimizer
187 |
188 |
--------------------------------------------------------------------------------
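A single-process illustration of the logging helpers defined at the top of this file (the keys and numbers are invented; the import assumes the repository's dependencies are installed):

import numpy as np
from proteinblobs.wrapper import gather_log, get_log_mean

log = {"train_loss": [0.9, 0.7, 0.5], "train_rmsd": [2.0, np.nan, 1.0]}
log = gather_log(log, world_size=1)      # no-op without a distributed run
print(get_log_mean(log))                 # {'train_loss': ~0.7, 'train_rmsd': 1.5}; nanmean skips NaNs
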
/run_multiflow.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from omegaconf import OmegaConf
12 | import os, torch
13 | import pytorch_lightning as pl
14 | import numpy as np
15 | from proteinblobs.multiflow_wrapper import MultiflowWrapper
16 | from openfold.np import protein
17 | from proteinblobs.utils import upgrade_state_dict, create_full_prot
18 | import argparse
19 |
20 | args = argparse.Namespace()
21 | args.__dict__.update({
22 | 'multiflow_yaml': "../multiflow/weights/config.yaml"
23 | })
24 | model = MultiflowWrapper(args)
25 | ckpt_path = "../multiflow/weights/last.ckpt"
26 | ckpt = torch.load(ckpt_path, map_location='cpu')
27 | model.load_state_dict(ckpt['state_dict'])
28 | model.eval().cuda()
29 |
30 | out, _ = model.inference(num_batch=2, sample_length=128)
31 |
32 | for i, prot in enumerate(out):
33 | prot = create_full_prot(prot)
34 |
35 | with open(f'test{i}.pdb', 'w') as f:
36 | f.write(protein.to_pdb(prot))
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--ckpt', default='workdir/tune_SC_elli_warmup/epoch=12-step=59683.ckpt')
15 | parser.add_argument('--num_blobs', type=int, default=5)
16 | parser.add_argument('--num_prots', type=int, default=5)
17 | parser.add_argument('--nu', type=int, default=10) # vary the blob anisotropy [10, 20, 50, 100]
18 | parser.add_argument('--psi', type=int, default=5) # FIX this to give proteins with O(200) residues
19 | parser.add_argument('--sigma', type=int, default=8) # vary this compactness parameter
20 | parser.add_argument('--helix_frac', type=float, default=0.5)
21 | parser.add_argument('--num_residues', type=int, default=160)
22 | parser.add_argument('--seed', type=int, default=137)
23 | parser.add_argument('--num_timesteps', type=int, default=500)
24 | parser.add_argument('--inference_rot_scaling', type=float, default=10)
25 | parser.add_argument('--multiflow', action='store_true')
26 | parser.add_argument('--outdir', default='outpdb/default')
27 | parser.add_argument('--guidance', type=float, default=None)
28 | args = parser.parse_args()
29 |
30 | from pymol import cmd
31 | import torch
32 | from proteinblobs.parsing import parse_train_args
33 | from proteinblobs.multiflow_wrapper import MultiflowWrapper
34 | from proteinblobs import blobs
35 | from proteinblobs.utils import atom37_to_pdb
36 | import os, contextlib
37 | import numpy as np
38 | import tqdm
39 |
40 | @contextlib.contextmanager
41 | def temp_seed(seed):
42 | state = np.random.get_state()
43 | np.random.seed(seed)
44 |     torch_state = torch.random.get_rng_state()
45 | torch.manual_seed(seed)
46 | try:
47 | yield
48 | finally:
49 | np.random.set_state(state)
50 |         torch.random.set_rng_state(torch_state)  # restore the torch RNG state captured on entry
51 |
52 |
53 | np.random.seed(args.seed)
54 | torch.manual_seed(args.seed)
55 | os.makedirs(args.outdir, exist_ok=True)
56 | blobss = []
57 | for _ in tqdm.tqdm(range(args.num_prots), desc='sampling dataset'):
58 | pos, covar = blobs.sample_blobs(args.num_blobs, nu=args.nu, psi=(1/args.nu)*args.psi**2*np.eye(3), sigma=args.sigma)
59 | is_helix = np.random.rand(args.num_blobs) < args.helix_frac
60 | volume = np.linalg.det(covar)**0.5
61 |
62 | counts = np.where(
63 | is_helix,
64 | blobs.alpha_slope * volume + blobs.alpha_intercept,
65 | blobs.beta_slope * volume + blobs.beta_intercept,
66 | ).astype(int)
67 |
68 | dssp = np.where(is_helix, 1, 2)
69 | blobss.append((pos, covar, counts, dssp))
70 |
71 |
72 | if args.multiflow:
73 |
74 | args_ = argparse.Namespace()
75 | args_.__dict__.update({
76 | 'finetune': True,
77 | 'multiflow_yaml': "multiflow_config.yaml",
78 | 'num_timesteps': args.num_timesteps,
79 | 'self_condition': True,
80 | 'blob_attention': False,
81 | 'extra_attn_layer': False,
82 | 'freeze_weights': True,
83 | 'pretrained_mf_path': 'weights/last.ckpt',
84 | 'inference_rot_scaling': args.inference_rot_scaling,
85 | })
86 |
87 | model = MultiflowWrapper(args_).eval().cuda()
88 | else:
89 | ckpt = torch.load(args.ckpt)
90 |
91 | args_ = ckpt['hyper_parameters']['args']
92 | if args.guidance is not None: #--guidance --sc_guidance_mode separate --interpolate_sc --interpolate_sc_weight 0.8 --guidance_weight 0.8 --seq_guidance_weight 0.8
93 | args_.__dict__.update({
94 | 'guidance': True,
95 | 'sc_guidance_mode': 'separate',
96 | 'num_timesteps': args.num_timesteps,
97 | 'interpolate_sc': True,
98 | 'pretrained_mf_path': 'weights/last.ckpt',
99 | 'interpolate_sc_weight': args.guidance,
100 | 'guidance_weight': args.guidance,
101 | 'seq_guidance_weight': args.guidance,
102 | })
103 | model = MultiflowWrapper(args_)
104 | model.load_state_dict(ckpt['state_dict'])
105 | model = model.eval().cuda()
106 |
107 | if args.guidance is not None:
108 | with temp_seed(args.seed):
109 | model.on_validation_epoch_start()
110 |
111 | for i, (pos, covar, counts, dssp) in tqdm.tqdm(enumerate(blobss), desc='running inference'):
112 |
113 | batch = {}
114 | if args.multiflow:
115 | batch['res_mask'] = np.ones((1, args.num_residues), dtype=np.float32)
116 | batch['grounding_pos'] = batch['grounding_feat'] = batch['grounding_mask'] = np.zeros(1)
117 | else:
118 | batch['res_mask'] = np.ones((1, counts.sum()), dtype=np.float32)
119 | batch['grounding_pos'] = pos[None].astype(np.float32)
120 | batch['grounding_feat'] = np.zeros((1, args.num_blobs, 11), dtype=np.float32)
121 |         batch['grounding_feat'][:,:,0] = dssp  # per-blob secondary structure label: 1 = helix, 2 = strand
122 | batch['grounding_feat'][:,:,1] = counts
123 | batch['grounding_feat'][:,:,2:] = covar.reshape(1, args.num_blobs, 9)
124 | batch['grounding_mask'] = np.ones((1, args.num_blobs), dtype=np.float32)
125 |
126 |
127 | batch = {k: torch.from_numpy(v).cuda() for k, v in batch.items()}
128 | batch['latents'] = 0
129 | if not args.multiflow:
130 | np.savez(f'{args.outdir}/blobs_{i}.npz', pos=pos, covar=covar, counts=counts, dssp=dssp)
131 | with torch.no_grad():
132 | samples, sequences = model.inference(batch)
133 | atom37_to_pdb(samples, f'{args.outdir}/blobs_{i}.pdb')
134 | np.save(f'{args.outdir}/blobs_{i}_sequence.npy', sequences[0])
135 |
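136 | # Optional post-run summary (illustrative sketch, not required by the pipeline):
137 | # count the structures that the loop above wrote into args.outdir.
138 | written = sorted(f for f in os.listdir(args.outdir) if f.endswith('.pdb'))
139 | print(f'wrote {len(written)} structures to {args.outdir}')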
--------------------------------------------------------------------------------
/scripts/ca_only_helicity.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import os
12 | import tqdm
13 | import argparse
14 | import numpy as np
15 |
16 | import biotite.structure.io.pdb as pdb
17 | import biotite.structure as struc
18 | import numpy as np
19 |
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--dir', required=True)
23 | parser.add_argument('--outdir', type=str, default=None)
24 | args = parser.parse_args()
25 |
26 |
27 |
28 |
29 | if args.outdir is None:
30 | args.outdir = args.dir
31 |
32 |
33 | def calculate_secondary_structure(pdb_file):
34 | """
35 | Calculate secondary structures using Biotite.
36 | """
37 | file = pdb.PDBFile.read(pdb_file)
38 | array = file.get_structure()
39 | sse = struc.annotate_sse(array[0])
40 |
41 | helices = (sse == 'a').sum() # Helices
42 | strands = (sse == 'b').sum() # Strands
43 | coils = (sse == 'c').sum() # Coils
44 | total = len(sse)
45 |
46 | helix_percentage = helices / total if total > 0 else 0
47 | strand_percentage = strands / total if total > 0 else 0
48 | coil_percentage = coils / total if total > 0 else 0
49 |
50 | return helix_percentage, strand_percentage, coil_percentage
51 |
52 | def preprocess_pdb_files(pdb_files):
53 | """
54 |     Ensure every PDB file in the given list has valid occupancy and B-factor values; empty fields are filled with defaults so the parser does not fail.
55 | """
56 | for pdb_path in pdb_files:
57 | pdb_path = str(pdb_path)
58 | corrected_content = []
59 | with open(pdb_path, 'r') as file:
60 | for line in file:
61 | if line.startswith(("ATOM", "HETATM")):
62 |                     # Occupancy occupies columns 55-60; if the field is empty, fill in a default
63 |                     occupancy_value = line[54:60].strip()
64 |                     if not occupancy_value:
65 |                         line = line[:54] + '  1.00' + line[60:]  # 6-character field keeps the fixed-column layout aligned
66 |
67 |                     # B-factor occupies columns 61-66; if the field is empty, fill in a default
68 |                     b_factor_value = line[60:66].strip()
69 |                     if not b_factor_value:
70 |                         line = line[:60] + '  0.00' + line[66:]  # 6-character field keeps the fixed-column layout aligned
71 |
72 | corrected_content.append(line)
73 |
74 | with open(pdb_path, 'w') as file:
75 | file.writelines(corrected_content)
76 |
77 | files = os.listdir(args.dir)
78 | files = sorted([f for f in files if '.pdb' in f])
79 |
80 | helicity = []
81 | for i, file in tqdm.tqdm(enumerate(files)):
82 | preprocess_pdb_files([f"{args.dir}/{file}"])
83 | helix_percent, _, _ = calculate_secondary_structure(f"{args.dir}/{file}")
84 | helicity.append(helix_percent)
85 | old_dict = dict(np.load(f"{args.outdir}/res.npz"))
86 | old_dict['helicity'] = np.array(helicity)
87 | np.savez(f"{args.outdir}/res.npz", **old_dict)
88 |
89 |
90 |
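91 | # Optional summary (illustrative): report the mean of the helicity values that
92 | # were just appended to res.npz.
93 | print(f"mean helicity over {len(helicity)} structures: {np.mean(helicity):.3f}")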
--------------------------------------------------------------------------------
/scripts/create_multiflowdb.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from proteinblobs.parsing import parse_train_args
12 |
13 | args = parse_train_args()
14 |
15 | import torch
16 | from omegaconf import OmegaConf
17 | import numpy as np
18 | from proteinblobs.multiflow.flow_model import FlowModel
19 | from proteinblobs.multiflow.data.interpolant import Interpolant
20 | from proteinblobs.multiflow.data import utils as du
21 | from proteinblobs.utils import atom37_to_pdb
22 | from proteinblobs.dataset import seq_collate, GenieDBDataset
23 |
24 | import copy
25 | from tqdm import tqdm
26 |
27 | cfg = OmegaConf.load("multiflow_config.yaml")
28 | cfg.model.edge_features.self_condition = (
29 | args.self_condition
30 | ) # only the cfg.interpolant.self_condition is used anywhere in the code
31 | cfg.interpolant.self_condition = args.self_condition
32 |
33 | print("init model")
34 | args_uncond = copy.deepcopy(args)
35 | args_uncond.freeze_weights = True
36 | args_uncond.extra_attn_layer = False
37 | args_uncond.blob_attention = False
38 | model_uncond = FlowModel(cfg.model, args_uncond)
39 | ckpt = torch.load(args.pretrained_mf_path, map_location="cuda")
40 | model_uncond.load_state_dict(
41 | {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
42 | )
43 | model_uncond = model_uncond.to("cuda")
44 |
45 | print("init dataset")
46 | dataset = GenieDBDataset(args)
47 | dataset[0]  # fetch one item up front; surfaces dataset errors before sampling begins
48 | loader = torch.utils.data.DataLoader(
49 | dataset,
50 | batch_size=args.batch_size,
51 | num_workers=args.num_workers,
52 | collate_fn=seq_collate,
53 | shuffle=True,
54 | )
55 |
56 | interpolant = Interpolant(cfg.interpolant, args)
57 | interpolant.set_device("cuda")
58 | for batch in tqdm(loader):
59 | true_bb_pos = None
60 | num_batch, sample_length = batch["mask"].shape
61 | prot_traj, model_traj = interpolant.sample(
62 | batch["mask"].cuda(),
63 | model_uncond,
64 | num_timesteps=args.num_timesteps,
65 | separate_t=cfg.inference.interpolant.codesign_separate_t,
66 | )
67 | diffuse_mask = torch.ones(1, sample_length)
68 | atom37_traj = [x[0] for x in prot_traj]
69 | atom37_model_traj = [x[0] for x in model_traj]
70 |
71 | bb_trajs = du.to_numpy(torch.stack(atom37_traj, dim=0).transpose(0, 1))
72 | noisy_traj_length = bb_trajs.shape[1]
73 | assert bb_trajs.shape == (num_batch, noisy_traj_length, sample_length, 37, 3)
74 |
75 | model_trajs = du.to_numpy(torch.stack(atom37_model_traj, dim=0).transpose(0, 1))
76 | clean_traj_length = model_trajs.shape[1]
77 | assert model_trajs.shape == (num_batch, clean_traj_length, sample_length, 37, 3)
78 |
79 | aa_traj = [x[1] for x in prot_traj]
80 | clean_aa_traj = [x[1] for x in model_traj]
81 |
82 | aa_trajs = du.to_numpy(torch.stack(aa_traj, dim=0).transpose(0, 1).long())
83 | assert aa_trajs.shape == (num_batch, noisy_traj_length, sample_length)
84 | for i in range(aa_trajs.shape[0]):
85 | for j in range(aa_trajs.shape[2]):
86 | if aa_trajs[i, -1, j] == du.MASK_TOKEN_INDEX:
87 | print("WARNING mask in predicted AA")
88 | aa_trajs[i, -1, j] = 0
89 | clean_aa_trajs = du.to_numpy(
90 | torch.stack(clean_aa_traj, dim=0).transpose(0, 1).long()
91 | )
92 | assert clean_aa_trajs.shape == (num_batch, clean_traj_length, sample_length)
93 |
94 | samples = bb_trajs[:, -1]
95 | for i, sample in enumerate(samples):
96 | sample_path = f"/lustre/fsw/portfolios/nvr/projects/nvr_lpr_compgenai/hstaerk_bjing/multiflow_output/{batch['name'][i].replace('.pdb', '')}_{np.random.randint(100000)}.pdb"
97 | atom37_to_pdb(sample[None], sample_path)
98 |
--------------------------------------------------------------------------------
/scripts/evaluate_alignment.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import os
12 | from proteinblobs import blobs as pb
13 | import esm
14 | from openfold.np import protein
15 | import tqdm
16 | import argparse
17 | import pydssp
18 | import pandas as pd
19 | import numpy as np
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--dir", required=True)
23 | args = parser.parse_args()
24 |
25 | files = os.listdir(args.dir)
26 | files = sorted([f for f in files if ".pdb" in f])
27 |
28 | df = []
29 | for i, pdb in tqdm.tqdm(enumerate(files)):
30 | prot = protein.from_pdb_string(open(f"{args.dir}/{pdb}").read())
31 | blobs = dict(np.load(f"{args.dir}/{pdb.replace('pdb', 'npz')}"))
32 |
33 | pos = np.concatenate(
34 | [prot.atom_positions[:, :3, :], prot.atom_positions[:, 4:5, :]], axis=1
35 | ) # (L, 4, 3)
36 | dssp = pydssp.assign(pos, out_type="index")
37 | pos = pos[:, 1]
38 | blobs = [
39 | {
40 | "pos": blobs["pos"][i],
41 | "covar": blobs["covar"][i],
42 | "dssp": blobs["dssp"][i],
43 | "count": blobs["counts"][i],
44 | }
45 | for i in range(len(blobs["dssp"]))
46 | ]
47 | df.append(
48 | {
49 | "name": pdb,
50 | "coverage": pb.blob_coverage(pos, dssp, blobs, structured_only=True),
51 | "misplacement": pb.blob_misplacement(
52 | pos, dssp, blobs, structured_only=True
53 | ),
54 | "accuracy": pb.blob_accuracy(pos, dssp, blobs, structured_only=True),
55 | "likelihood": pb.blob_likelihood(pos, dssp, blobs, structured_only=True),
56 | "soft_accuracy": pb.soft_blob_accuracy(
57 | pos, dssp, blobs, structured_only=True
58 | ),
59 | "reblob_jsd": pb.reblob_jsd(pos, dssp, blobs, use_dssp=True),
60 | "shannon_6.5": pb.shannon_complexity(pos, dssp, 6.5),
61 | }
62 | )
63 | df = pd.DataFrame(df)
64 | df.set_index("name").to_csv(f"{args.dir}/alignment.csv")
65 | print(df)
66 | for colname in df.columns.tolist():
67 | if colname != "name":
68 | print(f"{colname}: {df[colname].mean()}")
69 |
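70 | # Expected inputs (as written by sample.py): every <name>.pdb in --dir needs a
71 | # matching <name>.npz with the conditioning blobs ('pos', 'covar', 'counts',
72 | # 'dssp'); per-structure metrics are collected above and saved to <dir>/alignment.csv.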
--------------------------------------------------------------------------------
/scripts/evaluate_designability.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | import os, torch
13 | from biotite.sequence.io import fasta
14 | from proteinblobs.designability_utils import run_pmpnn, get_aligned_rmsd
15 |
16 | import esm
17 | from openfold.np import protein, residue_constants
18 | from proteinblobs.utils import atom37_to_pdb
19 | import json
20 | import argparse
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--dir", required=True)
24 | parser.add_argument("--filter_missing_res", action="store_true", default=False)
25 | parser.add_argument("--ca_only", action="store_true", default=False)
26 | args = parser.parse_args()
27 |
28 | tmpdir = f"/tmp/{os.getpid()}"
29 | os.makedirs(tmpdir, exist_ok=True)
30 |
31 | print("Loading ESMFold model for designability evaluation")
32 | esmf_model = esm.pretrained.esmfold_v1().eval()
33 | esmf_model = esmf_model.to("cuda")
34 |
35 |
36 | def designability(path, seqs_per_struct=8):
37 |     sample_path = os.path.join(tmpdir, "sample.pdb")
38 | if args.filter_missing_res:
39 | prot = protein.from_pdb_string(open(path, "r").read())
40 | ca_idx = residue_constants.atom_order["CA"]
41 | c_idx = residue_constants.atom_order["C"]
42 | n_idx = residue_constants.atom_order["N"]
43 | present_mask = prot.atom_mask[:, [ca_idx, c_idx, n_idx]].all(-1)
44 | pos = prot.atom_positions[present_mask]
45 | atom37_to_pdb(pos[None], sample_path)
46 | else:
47 | os.system(f"cp {path} {tmpdir}/sample.pdb")
48 |
49 | run_pmpnn(
50 | pdb_dir=tmpdir,
51 | num_seqs=seqs_per_struct,
52 | pmpnn_path="/data/rbg/users/hstark/ProteinMPNN",
53 | ca_only=args.ca_only,
54 | )
55 | mpnn_fasta_path = os.path.join(
56 | tmpdir, "seqs", os.path.basename(sample_path).replace(".pdb", ".fa")
57 | )
58 | fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path)
59 |
60 | seqs = [v for k, v in fasta_seqs.items()]
61 |     # drop the first record: ProteinMPNN echoes the input sequence there
62 |     # (poly-A if the structure carries no sequence) before the designed ones
63 |     seqs = seqs[1:]
64 | rmsds = []
65 | for j, seq in enumerate(seqs):
66 | seq = seq.replace("X", "A")
67 |
68 | with torch.no_grad():
69 | output = esmf_model.infer(seq)
70 |
71 | with open(sample_path) as f:
72 | prot = protein.from_pdb_string(f.read())
73 | out_ca_pos = output["positions"][-1].squeeze()[:, 2].cpu().numpy()
74 | rmsd = get_aligned_rmsd(prot.atom_positions[:, 1], out_ca_pos)
75 | rmsds.append(rmsd)
76 | return {"rmsd": rmsds}
77 |
78 |
79 | files = os.listdir(args.dir)
80 | files = [f for f in files if ".pdb" in f]
81 | out = {}
82 | for f in files:
83 | out[f] = designability(f"{args.dir}/{f}")
84 |     with open(f"{args.dir}/designability.json", "w") as fh:  # avoid clobbering the loop variable f
85 |         fh.write(json.dumps(out, indent=4))
86 |
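87 | # Optional summary (illustrative): average the best (lowest) scRMSD per structure
88 | # over all evaluated files.
89 | best = [min(v["rmsd"]) for v in out.values() if v["rmsd"]]
90 | if best:
91 |     print(f"mean best scRMSD over {len(best)} structures: {sum(best) / len(best):.2f}")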
--------------------------------------------------------------------------------
/scripts/evaluate_designability_codesign.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | import os, torch
13 | from biotite.sequence.io import fasta
14 | import esm
15 | import numpy as np
16 | from openfold.np import protein, residue_constants
17 | from proteinblobs.designability_utils import get_aligned_rmsd
18 | from proteinblobs.utils import atom37_to_pdb
19 | import json
20 | import argparse
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--dir", required=True)
24 | parser.add_argument("--filter_missing_res", action="store_true", default=False)
25 | parser.add_argument("--ca_only", action="store_true", default=False)
26 | args = parser.parse_args()
27 |
28 | tmpdir = f"/tmp/{os.getpid()}"
29 | os.makedirs(tmpdir, exist_ok=True)
30 |
31 | print("Loading ESMFold model for designability evaluation")
32 | esmf_model = esm.pretrained.esmfold_v1().eval()
33 | esmf_model = esmf_model.to("cuda")
34 |
35 |
36 | def designability(path):
37 |     sample_path = os.path.join(tmpdir, "sample.pdb")
38 | if args.filter_missing_res:
39 | prot = protein.from_pdb_string(open(path, "r").read())
40 | ca_idx = residue_constants.atom_order["CA"]
41 | c_idx = residue_constants.atom_order["C"]
42 | n_idx = residue_constants.atom_order["N"]
43 | present_mask = prot.atom_mask[:, [ca_idx, c_idx, n_idx]].all(-1)
44 | pos = prot.atom_positions[present_mask]
45 | atom37_to_pdb(pos[None], sample_path)
46 | else:
47 | os.system(f"cp {path} {tmpdir}/sample.pdb")
48 |
49 | seq_numeric = np.load(path.replace(".pdb", "_sequence.npy"))
50 | seq = "".join([residue_constants.restypes_with_x[r] for r in seq_numeric])
51 | seq = seq.replace("X", "A")
52 |
53 | with torch.no_grad():
54 | output = esmf_model.infer(seq)
55 |
56 | with open(sample_path) as f:
57 | prot = protein.from_pdb_string(f.read())
58 | out_ca_pos = output["positions"][-1].squeeze()[:, 2].cpu().numpy()
59 | rmsd = get_aligned_rmsd(prot.atom_positions[:, 1], out_ca_pos)
60 |
61 | return {"rmsd": [rmsd]}
62 |
63 |
64 | files = os.listdir(args.dir)
65 | files = [f for f in files if ".pdb" in f]
66 | out = {}
67 | for f in files:
68 | out[f] = designability(f"{args.dir}/{f}")
69 |     with open(f"{args.dir}/designability_codesign.json", "w") as fh:  # avoid clobbering the loop variable f
70 |         fh.write(json.dumps(out, indent=4))
71 |
--------------------------------------------------------------------------------
/scripts/preprocess_clusters.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 | import argparse
11 | import pickle
12 | from sklearn.cluster import SpectralClustering
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--num_workers", type=int, default=1)
16 | parser.add_argument("--worker_id", type=int, default=0)
17 | parser.add_argument("--res_per_cluster", type=int, default=20)
18 | parser.add_argument("--dssp_weight", type=float, default=3)
19 | parser.add_argument("--seq_weight", type=float, default=0.1)
20 | parser.add_argument(
21 | "--pkl_dir", type=str, default="data/multiflow/train_set/processed_pdb"
22 | )
23 | parser.add_argument(
24 | "--target_dir", type=str, default="data/multiflow/train_set/processed_pdb_clusters"
25 | )
26 | parser.add_argument("--flat_dir", action="store_true")
27 |
28 |
29 | args = parser.parse_args()
30 |
31 | import math
32 | import os
33 | import tqdm
34 | import numpy as np
35 | import pydssp
36 |
37 |
38 | if args.flat_dir:
39 | files = os.listdir(args.pkl_dir)
40 | else:
41 | paths = os.listdir(args.pkl_dir)
42 | files = []
43 | for path in paths:
44 | files.extend(os.listdir(f"{args.pkl_dir}/{path}"))
45 | print(len(files))
46 |
47 |
48 | def do_job(file):
49 | if args.flat_dir:
50 | file_path = f"{args.pkl_dir}/{file}"
51 | else:
52 | file_path = f"{args.pkl_dir}/{file[1:3]}/{file}"
53 | try:
54 | with open(file_path, "rb") as f:
55 | prot = pickle.load(f)
56 |     except Exception:  # unreadable or corrupt pickle
57 | print("Failure", file)
58 | return None
59 |
60 | mask = prot["bb_mask"].astype(bool)
61 | atom37 = prot["atom_positions"][mask]
62 |
63 | bb_pos = np.concatenate([atom37[:, :3, :], atom37[:, 4:5, :]], axis=1) # (L, 4, 3)
64 |
65 | try:
66 | dssp = pydssp.assign(
67 | bb_pos, out_type="index"
68 | ) # 0: loop, 1: alpha-helix, 2: beta-strand
69 |     except Exception:  # pydssp assignment failed
70 | print("Failure", file, bb_pos.shape)
71 | return None
72 |
73 | pos = atom37[:, 1, :]
74 | n_clusters = math.ceil(len(pos) / args.res_per_cluster)
75 |
76 | dssp_one_hot = dssp[:, None] == np.arange(3)
77 | feat = np.concatenate(
78 | [
79 | pos,
80 | dssp_one_hot * args.dssp_weight,
81 | np.arange(len(dssp))[:, None] * args.seq_weight,
82 | ],
83 | axis=1,
84 | )
85 |
86 | distmat = np.square(feat[None] - feat[:, None]).sum(-1) ** 0.5
87 | r = 2
88 | W = -(distmat**2) / (2 * r**2)
89 | W = np.exp(W)
90 | n_clusters = math.ceil(len(dssp) / args.res_per_cluster)
91 | labels = (
92 | SpectralClustering(n_clusters=n_clusters, affinity="precomputed").fit(W).labels_
93 | )
94 |
95 | dssp_count = np.zeros((n_clusters, 3))
96 | np.add.at(dssp_count, labels, dssp_one_hot)
97 |
98 |     centers = np.zeros((n_clusters, 3))  # per-cluster mean CA position (np.mean has no .at, so accumulate then normalize)
99 |     np.add.at(centers, labels, pos); centers /= np.maximum(np.bincount(labels, minlength=n_clusters), 1)[:, None]
100 |
101 | clustering = {
102 | "centers": centers,
103 | "labels": labels,
104 | "dssp_count": dssp_count,
105 | }
106 |
107 | return clustering
108 |
109 |
110 | out = {}
111 | for job in tqdm.tqdm(files):
112 | out[job] = do_job(job)
113 |
114 | with open(args.target_dir, "wb") as f:
115 | f.write(pickle.dumps(out))
116 |
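117 | # Optional sanity check (illustrative): reload the pickle that was just written and
118 | # report how many structures failed clustering (stored as None above).
119 | with open(args.target_dir, "rb") as f:
120 |     reloaded = pickle.load(f)
121 | num_failed = sum(v is None for v in reloaded.values())
122 | print(f"clustered {len(reloaded) - num_failed} / {len(reloaded)} files")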
--------------------------------------------------------------------------------
/scripts/unpack_geniepdb.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import numpy as np
12 | import argparse, tqdm
13 | from openfold.np import protein
14 | import openfold.np.residue_constants as rc
15 | import pandas as pd
16 | from multiprocessing import Pool
17 | import pydssp
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--indir', type=str, default='../afdb')
21 | parser.add_argument('--outdir', type=str, default='./afdb_npz')
22 | parser.add_argument('--outcsv', type=str, default='./afdb.csv')
23 | parser.add_argument('--num_workers', type=int, default=15)
24 | args = parser.parse_args()
25 |
26 |
27 | def main():
28 | pdbs = list(map(lambda x: x.strip(), open('index.txt')))
29 | if args.num_workers > 1:
30 | p = Pool(args.num_workers)
31 | p.__enter__()
32 | __map__ = p.imap
33 | else:
34 | __map__ = map
35 | infos = list(tqdm.tqdm(__map__(unpack_mmcif, pdbs), total=len(pdbs)))
36 | if args.num_workers > 1:
37 | p.__exit__(None, None, None)
38 | df = pd.DataFrame(infos).set_index('name')
39 | df.to_csv(args.outcsv)
40 |
41 | def unpack_mmcif(name):
42 | path = f"{args.indir}/{name}.pdb"
43 |
44 | with open(path, 'r') as f:
45 | prot = protein.from_pdb_string(f.read())
46 |
47 | atom37 = prot.atom_positions
48 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... (see atom_types in residue_constants.py)
49 | bb_pos = np.concatenate([atom37[ :, :3, :], atom37[ :, 4:5, :]], axis=1) # (L, 4, 3)
50 | dssp = pydssp.assign(bb_pos, out_type='index') # 0: loop, 1: alpha-helix, 2: beta-strand
51 |
52 | data = {
53 | 'atom_positions': prot.atom_positions,
54 | 'aatype': prot.aatype,
55 | 'atom_mask': prot.atom_mask,
56 | 'b_factors': prot.b_factors,
57 | 'dssp': dssp,
58 | }
59 | np.savez(f"{args.outdir}/{name}.npz", **data)
60 |
61 | out = {
62 | 'name': name,
63 | 'seqres': "".join([rc.restypes[x] for x in prot.aatype]),
64 | 'length': len(prot.aatype),
65 | 'plddt': 100 - prot.b_factors.mean(),
66 | 'loop_frac': (dssp == 0).mean(),
67 | 'helix_frac': (dssp == 1).mean(),
68 | 'sheet_frac': (dssp == 2).mean(),
69 | }
70 |
71 | return out
72 |
73 | if __name__ == "__main__":
74 | main()
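75 |
76 | # Input layout assumed above: index.txt in the working directory lists one PDB stem
77 | # per line, resolved as {args.indir}/{stem}.pdb; each structure is saved to
78 | # {args.outdir}/{stem}.npz and the per-structure summary rows (sequence, length,
79 | # pLDDT, secondary-structure fractions) are written to args.outcsv.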
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from proteinblobs.parsing import parse_train_args
12 | args = parse_train_args()
13 | from proteinblobs.logger import get_logger
14 | logger = get_logger(__name__)
15 | from proteinblobs.multiflow_wrapper import MultiflowWrapper
16 |
17 |
18 | from omegaconf import OmegaConf
19 | import torch, wandb, os
20 | from proteinblobs.dataset import SeqCollate
21 | from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
22 | from proteinblobs.dataset import GenieDBDataset, MultiflowDataset
23 | import pytorch_lightning as pl
24 | from pytorch_lightning.utilities import rank_zero_only
25 | import hashlib
26 | import numpy as np
27 |
28 | torch.set_float32_matmul_precision('medium')
29 |
30 | torch.manual_seed(args.seed)
31 | np.random.seed(args.seed)
32 |
33 | @rank_zero_only
34 | def init_wandb():
35 | wandb.init(
36 | entity="no-graining-mit",
37 | settings=wandb.Settings(start_method="fork"),
38 | project="proteinblobs",
39 | name=args.run_name,
40 | id=hashlib.md5(str(args).encode("utf-8")).hexdigest(),
41 |         resume='allow',  # per https://docs.wandb.ai/ref/python/init, "allow" automatically resumes a previous run when the id set above matches it
42 | config=args,
43 | )
44 |
45 | if args.wandb:
46 | init_wandb()
47 |
48 | args.__dict__.update({
49 | 'multiflow_yaml': "multiflow_config.yaml"
50 | })
51 | cfg = OmegaConf.load(args.multiflow_yaml)
52 |
53 | val_len = 1000
54 | if args.dataset == 'genie':
55 | dataset = GenieDBDataset(args)
56 | trainset, valset = torch.utils.data.random_split(dataset, [len(dataset) - val_len, val_len])
57 | elif args.dataset == 'multiflow':
58 | trainset = MultiflowDataset(args, cfg.pdb_dataset)
59 | valset = MultiflowDataset(args, cfg.pdb_post2021_dataset)
60 | trainset[0]  # fetch one example up front; surfaces dataset errors before training starts
61 | valset[0]
62 | print('len trainset', len(trainset))
63 | print('len valset', len(valset))
64 |
65 |
66 |
67 | train_loader = torch.utils.data.DataLoader(
68 | trainset,
69 | batch_size=args.batch_size,
70 | num_workers=args.num_workers,
71 | collate_fn=SeqCollate(args),
72 | shuffle=True,
73 | )
74 | val_loader = torch.utils.data.DataLoader(
75 | valset,
76 | batch_size=args.batch_size,
77 | collate_fn=SeqCollate(args),
78 | num_workers=args.num_workers,
79 | shuffle=False
80 | )
81 |
82 | model = MultiflowWrapper(args, cfg)
83 |
84 | trainer = pl.Trainer(
85 | accelerator="gpu" if torch.cuda.is_available() else 'auto',
86 | strategy='ddp',
87 | max_epochs=args.epochs,
88 | limit_train_batches=args.train_batches or 1.0,
89 | limit_val_batches=0.0 if args.no_validate else (args.val_batches or 1.0),
90 | num_sanity_val_steps=0,
91 | enable_progress_bar=not args.wandb,
92 | gradient_clip_val=args.grad_clip,
93 | default_root_dir=os.environ["MODEL_DIR"],
94 | callbacks=[
95 | ModelCheckpoint(
96 | dirpath=os.environ["MODEL_DIR"],
97 | save_top_k=-1,
98 | every_n_epochs=args.ckpt_freq,
99 | ),
100 | ModelSummary(max_depth=2),
101 | ],
102 | accumulate_grad_batches=args.accumulate_grad,
103 | val_check_interval=args.val_freq,
104 | check_val_every_n_epoch=args.val_epoch_freq,
105 | logger=False
106 | )
107 |
108 |
109 | if args.ckpt is not None:
110 | # if there is an hpc checkpoint in the current workdir, then we use that to resume instead of the file in the args.ckpt
111 | if any(['hpc' in file for file in os.listdir(os.environ['MODEL_DIR'])]):
112 | ckpt_path = 'hpc'
113 | else:
114 | ckpt_path = args.ckpt
115 | else:
116 | ckpt_path = None
117 |
118 | if args.validate:
119 | trainer.validate(model, val_loader, ckpt_path=ckpt_path)
120 | else:
121 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
122 |
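123 | # Runtime notes (from the code above): the MODEL_DIR environment variable must point
124 | # at the run directory (checkpoints and validation CSVs are written there), and the
125 | # wandb run id is a hash of the argument namespace, so relaunching with identical
126 | # arguments resumes the same wandb run.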
--------------------------------------------------------------------------------