├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── MODEL_CARD.md ├── README.md ├── assets ├── fig1.png └── image.png ├── model_weights ├── trained_on_afdb.ckpt └── trained_on_pdb.ckpt ├── multiflow_config.yaml ├── proteinblobs ├── blobs.py ├── constants.py ├── dataset.py ├── designability_utils.py ├── logger.py ├── multiflow │ ├── LICENSE.txt │ ├── cross_attention.py │ ├── data │ │ ├── all_atom.py │ │ ├── interpolant.py │ │ ├── protein.py │ │ ├── residue_constants.py │ │ ├── so3_utils.py │ │ └── utils.py │ ├── datasets.py │ ├── edge_feature_net.py │ ├── flow_model.py │ ├── ipa_pytorch.py │ ├── node_feature_net.py │ ├── rigid_utils.py │ └── utils.py ├── multiflow_wrapper.py ├── parsing.py ├── utils.py ├── visualize.py └── wrapper.py ├── run_multiflow.py ├── sample.py ├── scripts ├── ca_only_helicity.py ├── create_multiflowdb.py ├── evaluate_alignment.py ├── evaluate_designability.py ├── evaluate_designability_codesign.py ├── preprocess_clusters.py └── unpack_geniepdb.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ckpt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | weights 3 | __pycache__ 4 | workdir 5 | wandb 6 | results 7 | .cache 8 | afdb* 9 | index.txt 10 | run.sh 11 | outpdb 12 | sbatch 13 | run_tmp.sh 14 | *.ipynb 15 | aln 16 | outpdb 17 | tmp 18 | outpdb2 19 | data 20 | *.pkl 21 | figures 22 | *.pse 23 | *.pdb 24 | *.out 25 | core.* 26 | .DS_Store 27 | proteinblobs/multiflow/metadata 28 | *.html 29 | slurm_out -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | NVIDIA License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. 7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. 9 | 10 | 2. License Grant 11 | 12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 13 | 14 | 3. Limitations 15 | 16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 17 | 18 | 3.2 Derivative Works. 
You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 19 | 20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 21 | 22 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 23 | 24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. 25 | 26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. 27 | 28 | 4. Disclaimer of Warranty. 29 | 30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 32 | 33 | 5. Limitation of Liability. 34 | 35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 36 | -------------------------------------------------------------------------------- /MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | # ProtComposer Model Card 2 | 3 | ## 1. Model Overview 4 | 5 | ### Description: 6 | ProtComposer is a generative model that generates novel protein structures. It leverages joint protein backbone and sequence flow matching and offers controllability through 3D ellipsoids that can be positioned by the user or an auxiliary model to guide the shape of the protein structure during generation. This enables compositional protein structure generation, advancing controllability in protein design tasks. 7 | 8 | The model is for research and non-commercial use only. 9 | 10 | ### License/Terms of Use: 11 | ProtComposer source code and models are licensed under the NVIDIA license, see [LICENSE.txt](LICENSE.txt). By using ProtComposer, you accept the terms and conditions of this license. 
12 | 13 | ### References: 14 | Research paper: [“ProtComposer: Compositional Protein Structure Generation with 3D Ellipsoids,”](https://openreview.net/forum?id=0ctvBgKFgc) 15 | 16 | ### Model Architecture: 17 | - **Architecture Type:** Transformer (Attention, Cross-Attention, Invariant Point Attention) 18 | - **Network Architecture:** ProtComposer 19 | 20 | ### Input: 21 | - **Input Types:** Number (numbers for total number of protein residues to generate, generation step size, guidance scale and rotational annealing scale). Tensor (tensors for ellipsoid means, flattened ellipsoid covariance matrices and ellipsoid features). 22 | - **Input Formats:** 23 | - Number: Integer (residue number), FP32 (step size, guidance scale, annealing scale) 24 | - Tensor: Pytorch Tensor 25 | - **Input Parameters:** 26 | - Number (Integer and FP32): 1D 27 | - Tensor: 3D (batch elements, ellipsoids, features) 28 | - **Other Properties Related to Input:** Total number of protein residues to generate, generation step size, guidance scale, rotational annealing scale, ellipsoid number of residues and secondary structure type should be positive numbers. 29 | 30 | ### Output: 31 | - **Output Types:** Tensor (residue coordinates of generated protein). Tensor (amino acid types of residues of generated protein) 32 | - **Output Formats:** 33 | - Tensor (coordinates): Pytorch tensor 34 | - Tensor (amino acid types): Pytorch integer tensor 35 | - **Output Parameters:** 36 | - Tensor (coordinates): 3D (batch, length of protein, spatial dimensions) 37 | - Tensor (amino acid types): 2D (batch, length of protein) 38 | 39 | ### Software Integration: 40 | - **Runtime Engine:** Pytorch 41 | - **Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper 42 | - **Preferred Operating System:** Linux 43 | 44 | ### Model Versions: 45 | - **ProtComposer v1.0** (trained on Protein Data Bank) 46 | - **ProtComposer v1.1** (trained on AlphaFold Database) 47 | 48 | ## 2. Training and Evaluation Datasets 49 | 50 | ### Training Datasets: 51 | - **Protein Data Bank (PDB)** 52 | - **Link:** [https://www.rcsb.org/](https://www.rcsb.org/) 53 | - **Data Collection Method by dataset:** Automatic/Sensors/Human (experimental protein structure determination) 54 | - **Labeling Method by dataset:** N/A 55 | - **Properties:** The Protein Data Bank (PDB) contains approx. 200K experimentally determined three-dimensional structures of large biological molecules, such as proteins and nucleic acids, along with auxiliary information such as the protein sequences. We train ProtComposer on a filtered subset of the PDB, comprising 20,312 proteins. Specifically, we use monomers between length 60 and 512 with resolution < 5A downloaded from the PDB with a time cutoff on August 8, 2021. This resulted in 23,913 proteins. The data is further filtered by only including proteins with high secondary structure compositions. For each monomer, we ran the Define Secondary Structure of Proteins (DSSP) algorithm and removed monomers with more than 50% loops, resulting in 20,312 proteins. 56 | - **AlphaFold Protein Structure Database (AFDB)** 57 | - **Link:** [https://alphafold.ebi.ac.uk/](https://alphafold.ebi.ac.uk/) 58 | - **Data Collection Method by dataset:** Synthetic (AlphaFold predictions) 59 | - **Labeling Method by dataset:** N/A 60 | - **Properties:** The AlphaFold Protein Structure Database (AFDB) contains approx. 214M synthetic three-dimensional protein structures predicted by AlphaFold2, along with their corresponding sequences. 
We train ProtComposer on a subset of the AFDB, comprising 588,570 structures. We use a subsampled version that applies FoldSeek to cluster entries based on structural similarity. We start with all cluster representatives from the FoldSeek-clustered database and then filter them using a pLDDT threshold of >80, to enrich for highly confident predictions, and a maximum sequence length of 256. This results in 588,570 structures. 61 | 62 | ### Evaluation Dataset: 63 | - **Protein Data Bank (PDB)** 64 | - **Link:** [https://www.rcsb.org/](https://www.rcsb.org/) 65 | - **Data Collection Method by dataset:** Automatic/Sensors/Human (experimental protein structure determination) 66 | - **Labeling Method by dataset:** N/A 67 | - **Properties:** The Protein Data Bank (PDB) contains approx. 200K experimentally determined three-dimensional structures of large biological molecules, such as proteins and nucleic acids, along with auxiliary information such as the protein sequences. Similar to Multiflow [Campbell et al., “Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design”](https://arxiv.org/abs/2402.04997), we evaluate ProtComposer on a subset of the PDB, comprising 449 protein structures. Our evaluation set is based on a time-based split of the PDB. We downloaded structures and sequences from the PDB that were released between 1st September 2021 and 28th December 2023. We then select all single chain monomeric proteins with length between 50 and 400 inclusive. We further filter out proteins that are more than 50% coil residues and proteins that have a radius of gyration in the 96th percentile of the original dataset or above. We also filter out structures that have missing residues. We cluster proteins using the 30% sequence identity MMSeqs2 clustering provided by RCSB.org. We take a single protein from each cluster that matches our filtering criteria. This gives us an evaluation set of 449 proteins with minimum length 51 and maximum length 398. 68 | 69 | ### Inference: 70 | - **Engine:** Pytorch 71 | - **Test Hardware:** A100, H100 72 | 73 | ### Ethical Considerations: 74 | NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. 75 | 76 | Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/). -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

ProtComposer: Compositional Protein Structure Generation with 3D Ellipsoids

Oral at ICLR 2025

Hannes Stark* · Bowen Jing* · Tomas Geffner · Jason Yim · Tommi Jaakkola · Arash Vahdat · Karsten Kreis

*equal contribution

[Paper](https://openreview.net/forum?id=0ctvBgKFgc)

(Animated overview figures: assets/fig1.png and assets/image.png)

20 | 21 | ## Environment 22 | 23 | We recommend installing Miniconda (https://docs.anaconda.com/miniconda/install/) and creating the following conda environment: 24 | ```shell script 25 | conda create -n nv python=3.9 26 | pip install jupyterlab 27 | pip install numpy==1.21.2 pandas==1.5.3 28 | pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html 29 | pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.1 absl-py einops 30 | pip install pytorch_lightning==2.0.4 fair-esm 31 | pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@5484c38' 32 | pip install matplotlib==3.7.2 33 | pip install pydssp biotite omegaconf wandb 34 | pip install numpy==1.21.2 35 | # pip install numpy throws an errror for contourpy but that is fine 36 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cu113 37 | pip3 install -U scikit-learn 38 | pip install gpustat 39 | ``` 40 | 41 | ## Pretrained checkpoints 42 | 43 | We provide two pretrained checkpoints in the directory `model_weights`. One is trained on data from the Protein Data Bank, the other on AlphaFold Database. 44 | 45 | | Description | Checkpoint path | 46 | |:----------|:----------| 47 | | AlphaFold Database training data | [`model_weights/trained_on_afdb.ckpt`](model_weights/trained_on_afdb.ckpt) | 48 | | Protein Data Bank training data | [`model_weights/trained_on_pdb.ckpt`](`model_weights/trained_on_pdb.ckpt`) | 49 | 50 | ## Sampling 51 | 52 | To sample ProtComposer conditioned on ellipsoids from our ellipsoid statistical model: 53 | 54 | ```shell script 55 | python sample.py --guidance 1.0 --num_prots 6 --nu 5 --sigma 6 --helix_frac 0.4 --seed 1 --outdir results --num_blobs 9 --ckpt "model_weights/trained_on_pdb.ckpt" 56 | ``` 57 | 58 | ## Evaluation 59 | 60 | We use these scripts to compute metrics for the outputs (`evaluate_alignment` computes the ellipsoid adherence metrics): 61 | ```shell script 62 | python -m scripts.evaluate_designability --dir results 63 | python -m scripts.evaluate_alignment --dir results 64 | ``` 65 | 66 | 67 | ## Training 68 | 69 | 70 | ### Data preparations 71 | 72 | We use the data from MultiFlow: https://github.com/jasonkyuyim/multiflow 73 | 74 | They host the datasets on Zenodo [here](https://zenodo.org/records/10714631?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjJjMTk2YjlmLTM4OTUtNGVhYi1hODcxLWE1ZjExOTczY2IzZiIsImRhdGEiOnt9LCJyYW5kb20iOiI4MDY5ZDUzYjVjMTNhNDllMDYxNmI3Yjc2NjcwYjYxZiJ9.C2eZZmRu-nu7H330G-DkV5kttfjYB3ANozdOMNm19uPahvtLrDRvd_4Eqlyb7lp24m06e4OHhHQ4zlj68S1O_A). 75 | Download the following files, and place them into the directory `data` 76 | * `real_train_set.tar.gz` (2.5 GB) 77 | * `synthetic_train_set.tar.gz` (220 MB) 78 | * `test_set.tar.gz` (347 MB) 79 | Next, untar the files 80 | ```bash 81 | # Uncompress training data 82 | mkdir train_set 83 | tar -xzvf real_train_set.tar.gz -C train_set/ 84 | tar -xzvf synthetic_train_set.tar.gz -C train_set/ 85 | 86 | # Uncompress test data 87 | mkdir test_set 88 | tar -xzvf test_set.tar.gz -C test_set/ 89 | ``` 90 | 91 | Download the files and put them into the directory `data` such that you obtain the following file structure (the `--pkl_dir` argument is `data` by default): 92 | ```bash 93 | data 94 | ├── train_set 95 | │   ├── processed_pdb 96 | | | ├── 97 | | | | └── .pkl 98 | │   ├── processed_synthetic 99 | | | └── .pkl 100 | ├── test_set 101 | | └── processed 102 | | | ├── 103 | | | | └── .pkl 104 | ... 
105 | ``` 106 | 107 | ### Launch training run 108 | 109 | By default, we train on 8 GPUs. 110 | 111 | ```shell script 112 | python train.py --batch_size 8 --designability --designability_freq 5 --num_designability_prots 50 --accumulate_grad 8 --inf_batches 5 --val_batches 5 --finetune --dataset multiflow --self_condition --num_workers 10 --save_val --epochs 1000 --run_name my_run_name --wandb 113 | ``` 114 | 115 | ## License 116 | 117 | Code and model weights are released under an NVIDIA license for non-commercial or research purposes only. Please see the [LICENSE.txt](https://github.com/NVlabs/protcomposer/blob/main/LICENSE.txt) file. 118 | 119 | 120 | ## Citation 121 | 122 | 123 | ```bib 124 | @inproceedings{stark2025protcomposer, 125 | title={ProtComposer: Compositional Protein Structure Generation with 3D Ellipsoids}, 126 | author={Hannes Stark and Bowen Jing and Tomas Geffner and Jason Yim and Tommi Jaakkola and Arash Vahdat and Karsten Kreis}, 127 | booktitle={The Thirteenth International Conference on Learning Representations (ICLR)}, 128 | year={2025}, 129 | url={https://openreview.net/forum?id=0ctvBgKFgc} 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/protcomposer/0eefdc79100a4ed5b7f585f264821b87e6e14fed/assets/fig1.png -------------------------------------------------------------------------------- /assets/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/protcomposer/0eefdc79100a4ed5b7f585f264821b87e6e14fed/assets/image.png -------------------------------------------------------------------------------- /model_weights/trained_on_afdb.ckpt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:735949839257b87cd8bc5aa6e9770867685a8c3328ad33be351ddb5fa0cf1c69 3 | size 294278008 4 | -------------------------------------------------------------------------------- /model_weights/trained_on_pdb.ckpt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c48de733ac64d4d38475f1b2e8e820be297aa992c9cea1cb7acf339762428c40 3 | size 294344824 4 | -------------------------------------------------------------------------------- /multiflow_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: pdb 3 | loader: 4 | num_workers: 8 5 | prefetch_factor: 10 6 | sampler: 7 | max_batch_size: 64 8 | max_num_res_squared: 400000 9 | task: hallucination 10 | interpolant: 11 | min_t: 0.01 12 | separate_t: false 13 | provide_kappa: false 14 | hierarchical_t: false 15 | codesign_separate_t: true 16 | codesign_forward_fold_prop: 0.1 17 | codesign_inverse_fold_prop: 0.1 18 | rots: 19 | corrupt: true 20 | train_schedule: linear 21 | sample_schedule: exp 22 | exp_rate: 10 23 | trans: 24 | corrupt: true 25 | batch_ot: true 26 | train_schedule: linear 27 | sample_schedule: linear 28 | sample_temp: 1.0 29 | vpsde_bmin: 0.1 30 | vpsde_bmax: 20.0 31 | potential: null 32 | potential_t_scaling: false 33 | rog: 34 | weight: 10.0 35 | cutoff: 5.0 36 | aatypes: 37 | corrupt: true 38 | schedule: linear 39 | schedule_exp_rate: -3 40 | temp: 0.1 41 | noise: 0.0 42 | do_purity: true 43 | train_extra_mask: 0.0 44 | 
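# (added note, not part of the original config) The "aatypes" keys above and
# "interpolant_type: masking" below appear to configure MultiFlow's discrete
# flow over amino-acid types: residue identities are corrupted toward a mask
# token during training and progressively unmasked during sampling at
# temperature 0.1, with purity-based selection of which positions to unmask
# (do_purity: true).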
interpolant_type: masking 45 | sampling: 46 | num_timesteps: 100 47 | do_sde: false 48 | self_condition: ${model.edge_features.self_condition} 49 | twisting: 50 | use: false 51 | folding: 52 | seq_per_sample: 8 53 | own_device: false 54 | folding_model: esmf 55 | pmpnn_path: ./ProteinMPNN/ 56 | pt_hub_dir: ./.cache/torch/ 57 | experiment: 58 | debug: false 59 | seed: 123 60 | num_devices: 4 61 | warm_start: flow-matching/ckpt/se3-fm/reproduce_v2/2023-12-28_14-23-11/epoch=59-step=28380.ckpt 62 | warm_start_cfg_override: false 63 | raw_state_dict_reload: null 64 | training: 65 | mask_plddt: true 66 | bb_atom_scale: 0.1 67 | trans_scale: 0.1 68 | aatypes_loss_weight: 1.0 69 | aatypes_label_smoothing: 0.0 70 | aatypes_loss_mean_or_sum: mean 71 | aatypes_loss_use_likelihood_weighting: false 72 | translation_loss_weight: 2.0 73 | t_normalize_clip: 0.9 74 | rotation_loss_weights: 1.0 75 | aux_loss_weight: 0.0 76 | aux_loss_use_bb_loss: true 77 | aux_loss_use_pair_loss: true 78 | aux_loss_t_pass: 0.5 79 | wandb: 80 | name: separate_t 81 | project: se3-fm 82 | optimizer: 83 | lr: 0.0001 84 | trainer: 85 | overfit_batches: 0 86 | min_epochs: 1 87 | max_epochs: 200 88 | accelerator: gpu 89 | log_every_n_steps: 1 90 | deterministic: false 91 | strategy: ddp 92 | check_val_every_n_epoch: 6 93 | accumulate_grad_batches: 2 94 | checkpointer: 95 | dirpath: ./ 96 | save_last: true 97 | save_top_k: 3 98 | every_n_epochs: 4 99 | monitor: valid/codesign_bb_rmsd 100 | mode: min 101 | inference_dir: ./inference_outputs/./weights/last/unconditional/run_2024-03-02_11-49-28 102 | shared: 103 | seed: 123 104 | max_cache_size: 100000 105 | samples_per_eval_length: 5 106 | num_eval_lengths: 8 107 | max_eval_length: 256 108 | min_motif_percent: 0.05 109 | max_motif_percent: 0.5 110 | pdb_dataset: 111 | seed: ${shared.seed} 112 | csv_path: proteinblobs/multiflow/metadata/pdb_metadata.csv 113 | cluster_path: proteinblobs/multiflow/metadata/pdb.clusters 114 | test_set_pdb_ids_path: null 115 | max_cache_size: ${shared.max_cache_size} 116 | cache_num_res: 0 117 | inpainting_percent: 1.0 118 | add_plddt_mask: false 119 | max_eval_length: ${shared.max_eval_length} 120 | redesigned_csv_path: proteinblobs/multiflow/metadata/pdb_redesigned.csv 121 | use_redesigned: true 122 | synthetic_csv_path: proteinblobs/multiflow/metadata/distillation_metadata.csv 123 | synthetic_cluster_path: proteinblobs/multiflow/metadata/distillation.clusters 124 | use_synthetic: true 125 | samples_per_eval_length: ${shared.samples_per_eval_length} 126 | num_eval_lengths: ${shared.num_eval_lengths} 127 | filter: 128 | max_num_res: 384 129 | min_num_res: 60 130 | max_coil_percent: 0.5 131 | rog_quantile: 0.96 132 | oligomeric: 133 | - monomeric 134 | num_chains: 135 | - 1 136 | min_motif_percent: ${shared.min_motif_percent} 137 | max_motif_percent: ${shared.max_motif_percent} 138 | is_mixed: false 139 | pdb_post2021_dataset: 140 | seed: ${shared.seed} 141 | csv_path: proteinblobs/multiflow/metadata/test_set_metadata.csv 142 | cluster_path: proteinblobs/multiflow/metadata/test_set_clusters.csv 143 | test_set_pdb_ids_path: proteinblobs/multiflow/metadata/test_set_pdb_ids.csv 144 | max_cache_size: ${shared.max_cache_size} 145 | cache_num_res: 0 146 | add_plddt_mask: False 147 | max_eval_length: ${shared.max_eval_length} 148 | redesigned_csv_path: null 149 | use_redesigned: False 150 | 151 | synthetic_csv_path: null 152 | synthetic_cluster_path: null 153 | use_synthetic: False 154 | 155 | # Eval parameters 156 | samples_per_eval_length: 
${shared.samples_per_eval_length} 157 | num_eval_lengths: ${shared.num_eval_lengths} 158 | 159 | # Filtering 160 | filter: 161 | max_num_res: 400 162 | min_num_res: 50 163 | max_coil_percent: 0.5 164 | rog_quantile: 0.96 165 | oligomeric: ['monomeric'] 166 | num_chains: [1] 167 | model: 168 | node_embed_size: 256 169 | edge_embed_size: 128 170 | symmetric: false 171 | aatype_pred: true 172 | transformer_dropout: 0.2 173 | aatype_pred_num_tokens: 21 174 | node_features: 175 | c_s: ${model.node_embed_size} 176 | c_pos_emb: 128 177 | c_timestep_emb: 128 178 | max_num_res: 2000 179 | timestep_int: 1000 180 | embed_chain: false 181 | embed_aatype: true 182 | use_mlp: true 183 | aatype_pred_num_tokens: ${model.aatype_pred_num_tokens} 184 | edge_features: 185 | single_bias_transition_n: 2 186 | c_s: ${model.node_embed_size} 187 | c_p: ${model.edge_embed_size} 188 | relpos_k: 64 189 | feat_dim: 64 190 | num_bins: 22 191 | self_condition: true 192 | embed_chain: false 193 | embed_diffuse_mask: true 194 | ipa: 195 | c_s: ${model.node_embed_size} 196 | c_z: ${model.edge_embed_size} 197 | c_hidden: 16 198 | no_heads: 8 199 | no_qk_points: 8 200 | no_v_points: 12 201 | seq_tfmr_num_heads: 4 202 | seq_tfmr_num_layers: 4 203 | num_blocks: 8 204 | dropout: 0.0 205 | use_sequence_net: false 206 | sequence_net: 207 | num_layers: 6 208 | use_init_embed: true 209 | use_init_rigid: false 210 | use_local_attention: false 211 | init_node_embed: ${model.node_embed_size} 212 | init_edge_embed: ${model.edge_embed_size} 213 | ipa: 214 | c_s: 128 215 | c_z: 64 216 | c_hidden: 128 217 | no_heads: 4 218 | no_qk_points: 8 219 | no_v_points: 8 220 | dropout: 0.2 221 | inference: 222 | predict_dir: ./inference_outputs/ 223 | inference_subdir: run_${now:%Y-%m-%d}_${now:%H-%M-%S} 224 | task: unconditional 225 | seed: 123 226 | use_gpu: true 227 | num_gpus: 1 228 | saved_ckpt_dir: ./saved_ckpts/frameflow 229 | unconditional_ckpt_path: ./weights/last.ckpt 230 | also_fold_pmpnn_seq: true 231 | write_sample_trajectories: false 232 | interpolant: 233 | min_t: 0.01 234 | provide_kappa: false 235 | codesign_separate_t: false 236 | rots: 237 | corrupt: true 238 | sample_schedule: exp 239 | exp_rate: 10 240 | trans: 241 | corrupt: true 242 | pre_align: true 243 | train_schedule: linear 244 | sample_schedule: linear 245 | sample_temp: 1.0 246 | potential: null 247 | potential_t_scaling: false 248 | rog: 249 | weight: 20.0 250 | cutoff: 5.0 251 | aatypes: 252 | corrupt: true 253 | schedule: linear 254 | schedule_exp_rate: -3 255 | temp: 0.1 256 | noise: 20.0 257 | do_purity: true 258 | interpolant_type: masking 259 | sampling: 260 | num_timesteps: 500 261 | do_sde: false 262 | self_condition: true 263 | pmpnn_dir: ./ProteinMPNN 264 | folding: 265 | seq_per_sample: 8 266 | folding_model: esmf 267 | own_device: false 268 | pmpnn_path: ./ProteinMPNN/ 269 | pt_hub_dir: ./.cache/torch/ 270 | colabfold_path: path/to/colabfold-conda/bin/colabfold_batch 271 | samples: 272 | samples_per_length: 100 273 | num_batch: 1 274 | length_subset: 275 | - 256 276 | min_length: 60 277 | max_length: 256 278 | length_step: 1 -------------------------------------------------------------------------------- /proteinblobs/blobs.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | 12 | import numpy as np 13 | import networkx as nx 14 | import numpy as np 15 | from scipy.special import logsumexp, softmax 16 | import torch 17 | from torch.distributions.multivariate_normal import MultivariateNormal 18 | from scipy.stats import wishart 19 | 20 | alpha_slope = 0.22213567936924072 21 | alpha_intercept = 7.512307373103329 22 | 23 | beta_slope = 0.2727479880396384 24 | beta_intercept = 8.28561006547174 25 | 26 | 27 | def blobify(pos, dssp, radius_thresh=5, size_thresh=5): 28 | distmat = np.square(pos[:, None] - pos[None]).sum(-1) ** 0.5 29 | G = nx.Graph() 30 | edges = np.argwhere((distmat < radius_thresh) & (dssp[None] == dssp[:, None])) 31 | G.add_edges_from(edges) 32 | blobs = [] 33 | for con in nx.connected_components(G): 34 | con = list(con) 35 | if dssp[con[0]] == 0: 36 | continue 37 | if len(con) < size_thresh: 38 | continue 39 | blobs.append( 40 | { 41 | "residues": con, 42 | "count": len(con), 43 | "dssp": dssp[con[0]], 44 | "pos": pos[con].mean(0).astype(float), 45 | "covar": np.cov(pos[con].T).astype(float), 46 | } 47 | ) 48 | 49 | return blobs 50 | 51 | 52 | def shannon_complexity(pos, dssp, thresh=5): 53 | blobs = blobify(pos, dssp, radius_thresh=thresh) 54 | counts = np.array([b["count"] for b in blobs]) 55 | probs = counts / counts.sum() 56 | return -(probs * np.log(probs)).sum() 57 | 58 | 59 | def score_blobs(centers, covars): 60 | energy = 0 61 | for pos, cov in zip(centers, covars): 62 | relpos = centers - pos 63 | d = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(cov), relpos) ** 0.5 64 | d = d[d != 0] 65 | energy += (1 / d**2).sum() 66 | return energy 67 | 68 | 69 | def sample_blobs(k, nu, psi, sigma): 70 | while True: 71 | covar = wishart.rvs(nu, psi, size=k) 72 | if k == 1: 73 | covar = covar[None] 74 | pos = np.random.randn(k, 3) * sigma 75 | score = score_blobs(pos, covar) 76 | if np.random.rand() < np.exp(-score): 77 | return pos, covar 78 | 79 | 80 | def gmm_ll(pos, centers, covars, probs): 81 | probs = probs / probs.sum() 82 | relpos = pos[:, None] - centers[None] 83 | ll = ( 84 | -0.5 * np.einsum("lni,nij,lnj->ln", relpos, np.linalg.inv(covars), relpos) 85 | - 0.5 * np.linalg.slogdet(covars)[1][None] 86 | + np.log(probs)[None] 87 | ) 88 | return logsumexp(ll, -1) 89 | 90 | 91 | def dssp_gmm_ll(pos, centers, covars, probs, pos_dssp, centers_dssp): 92 | if np.any(centers_dssp == 1): 93 | gmm_ll_1 = gmm_ll( 94 | pos, 95 | centers[centers_dssp == 1], 96 | covars[centers_dssp == 1], 97 | probs[centers_dssp == 1], 98 | ) 99 | gmm_ll_1[pos_dssp != 1] = -np.inf 100 | else: 101 | gmm_ll_1 = -np.inf * np.ones(len(pos)) 102 | 103 | if np.any(centers_dssp == 2): 104 | gmm_ll_2 = gmm_ll( 105 | pos, 106 | centers[centers_dssp == 2], 107 | covars[centers_dssp == 2], 108 | probs[centers_dssp == 2], 109 | ) 110 | gmm_ll_2[pos_dssp != 2] = -np.inf 111 | else: 112 | gmm_ll_2 = -np.inf * np.ones(len(pos)) 113 | 114 | dssp_1_count = probs[centers_dssp == 1].sum() / probs.sum() 115 | dssp_2_count = probs[centers_dssp == 2].sum() / probs.sum() 116 | 117 | return np.log(np.exp(gmm_ll_1) 
* dssp_1_count + np.exp(gmm_ll_2) * dssp_2_count) 118 | 119 | 120 | """ 121 | Treats the blobs as a GMM and evaluates the average log-DENSITY at each pos. 122 | Note that DENSITY is N * PDF where N is the total number of residues in the blobs. 123 | Not suitable for partial blobs. 124 | """ 125 | 126 | 127 | def blob_likelihood(pos, dssp, blobs, structured_only=False, mean=True): 128 | if structured_only: 129 | pos = pos[dssp != 0] 130 | dssp = dssp[dssp != 0] 131 | if len(dssp) == 0: 132 | return -10 133 | if len(blobs) == 0: 134 | ll = -np.inf * np.ones(pos.shape[0]) 135 | else: 136 | covars = np.stack([blob["covar"] for blob in blobs]) 137 | counts = np.array([blob["count"] for blob in blobs]) 138 | centers = np.stack([blob["pos"] for blob in blobs]) 139 | 140 | ll = gmm_ll(pos, centers, covars, counts) 141 | ll += np.log(counts.sum()) 142 | 143 | if not mean: 144 | return ll 145 | return ll.mean() 146 | 147 | 148 | """ 149 | Reblobs the protein and evaluates the number of residues per blob 150 | """ 151 | 152 | 153 | def blobs_per_res(pos, dssp, thresh=5): 154 | blobs = blobify(pos, dssp, thresh) 155 | return len(blobs) / len(pos) 156 | 157 | 158 | def res_per_blob(pos, dssp, thresh=5): 159 | blobs = blobify(pos, dssp, thresh) 160 | return len(pos) / len(blobs) 161 | 162 | 163 | """ 164 | Reblobs the protein and evaluates the JSD between the GMMs 165 | defined by the original and new blobs. 166 | Not suitable for partial blobs. 167 | If use_dssp=True, consider sample space R^3 x {1,2} 168 | """ 169 | 170 | 171 | def reblob_jsd(pos, dssp, blobs, num_samples=1000, use_dssp=False, thresh=5): 172 | p_blobs = blobs 173 | q_blobs = blobify(pos, dssp, thresh) 174 | 175 | p_covars = np.stack([blob["covar"] for blob in p_blobs]) 176 | p_counts = np.array([blob["count"] for blob in p_blobs]) 177 | p_centers = np.stack([blob["pos"] for blob in p_blobs]) 178 | p_dssp = np.array([blob["dssp"] for blob in p_blobs]) 179 | 180 | q_covars = np.stack([blob["covar"] for blob in q_blobs]) 181 | q_counts = np.array([blob["count"] for blob in q_blobs]) 182 | q_centers = np.stack([blob["pos"] for blob in q_blobs]) 183 | q_dssp = np.array([blob["dssp"] for blob in q_blobs]) 184 | 185 | p_idx = np.random.choice( 186 | np.arange(len(p_blobs)), size=num_samples, p=p_counts / p_counts.sum() 187 | ) 188 | p_samps = MultivariateNormal( 189 | torch.from_numpy(p_centers[p_idx]).float(), 190 | torch.from_numpy(p_covars[p_idx]).float(), 191 | ).sample() 192 | if use_dssp: 193 | p_p = dssp_gmm_ll(p_samps, p_centers, p_covars, p_counts, p_dssp[p_idx], p_dssp) 194 | p_q = dssp_gmm_ll(p_samps, q_centers, q_covars, q_counts, p_dssp[p_idx], q_dssp) 195 | 196 | else: 197 | p_p = gmm_ll(p_samps, p_centers, p_covars, p_counts) 198 | p_q = gmm_ll(p_samps, q_centers, q_covars, q_counts) 199 | 200 | p_m = np.log(np.exp(p_p) / 2 + np.exp(p_q) / 2) 201 | kl_p = (p_p - p_m).mean() 202 | 203 | q_idx = np.random.choice( 204 | np.arange(len(q_blobs)), size=num_samples, p=q_counts / q_counts.sum() 205 | ) 206 | q_samps = MultivariateNormal( 207 | torch.from_numpy(q_centers[q_idx]).float(), 208 | torch.from_numpy(q_covars[q_idx]).float(), 209 | ).sample() 210 | 211 | if use_dssp: 212 | q_p = dssp_gmm_ll(q_samps, p_centers, p_covars, p_counts, q_dssp[q_idx], p_dssp) 213 | q_q = dssp_gmm_ll(q_samps, q_centers, q_covars, q_counts, q_dssp[q_idx], q_dssp) 214 | else: 215 | q_p = gmm_ll(q_samps, p_centers, p_covars, p_counts) 216 | q_q = gmm_ll(q_samps, q_centers, q_covars, q_counts) 217 | q_m = np.log(np.exp(q_p) / 2 + np.exp(q_q) / 2) 218 | 
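# (added comment) Monte Carlo estimate of the Jensen-Shannon divergence:
# kl_p above approximates KL(P || M) using samples from the original-blob
# GMM P, and kl_q below approximates KL(Q || M) using samples from the
# re-blobbed GMM Q, where M is the equal-weight mixture of P and Q; the
# function returns their average, i.e. the JSD.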
kl_q = (q_q - q_m).mean() 219 | 220 | return (kl_p + kl_q) / 2 221 | 222 | 223 | """ 224 | Calculate the sum of the difference in the fraction of residues 225 | assigned to a blob and the fraction of total residues actually placed in the blob. 226 | """ 227 | 228 | 229 | def blob_misplacement(pos, dssp, blobs, thresh=2.25, structured_only=False): 230 | if structured_only: 231 | pos = pos[dssp != 0] 232 | dssp = dssp[dssp != 0] 233 | if len(dssp) == 0: 234 | return 1 235 | 236 | assigned = [] 237 | actual = [] 238 | for b in blobs: 239 | relpos = pos - b["pos"] 240 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos) 241 | mask = dists**0.5 < thresh 242 | actual.append(mask.sum()) 243 | assigned.append(b["count"]) 244 | 245 | assigned = np.array(assigned) / np.sum(assigned) 246 | actual = np.array(actual) / np.sum(actual) 247 | return np.sum(np.abs(assigned - actual)) 248 | 249 | 250 | """ 251 | The fraction of residues within some Mahalanobis distance of each blob 252 | with the correct secondary structure. 253 | Suitable for partial blobs. 254 | """ 255 | 256 | 257 | def blob_accuracy(pos, dssp, blobs, thresh=2.25, structured_only=False): 258 | if structured_only: 259 | pos = pos[dssp != 0] 260 | dssp = dssp[dssp != 0] 261 | if len(dssp) == 0: 262 | return 0 263 | 264 | correct, total = 0, 0 265 | 266 | for b in blobs: 267 | relpos = pos - b["pos"] 268 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos) 269 | mask = dists**0.5 < thresh 270 | correct += (dssp[mask] == b["dssp"]).sum() 271 | total += mask.sum() 272 | 273 | return correct / total 274 | 275 | 276 | """ 277 | Based on sheet and helix GMMs, computes p(ss | pos) for each pos in the new protein. 278 | Reports the average p(ss | pos) over all non-loop residues. 279 | Not suitable for partial blobs. 280 | """ 281 | 282 | 283 | def soft_blob_accuracy(pos, dssp, blobs, structured_only=False): 284 | if structured_only: 285 | pos = pos[dssp != 0] 286 | dssp = dssp[dssp != 0] 287 | if len(dssp) == 0: 288 | return 0 289 | 290 | sheet_ll = blob_likelihood( 291 | pos, dssp, [b for b in blobs if b["dssp"] == 1], mean=False 292 | ) 293 | helix_ll = blob_likelihood( 294 | pos, dssp, [b for b in blobs if b["dssp"] == 2], mean=False 295 | ) 296 | 297 | probs = softmax(np.stack([sheet_ll, helix_ll], -1), -1) 298 | return ((probs[:, 0][dssp == 1]).sum() + (probs[:, 1][dssp == 2]).sum()) / len(dssp) 299 | 300 | 301 | """ 302 | Reports the fraction of residues within some Mahalanobis distance of some blob. 303 | Not suitable for partial blobs. 304 | """ 305 | 306 | 307 | def blob_coverage(pos, dssp, blobs, thresh=2.25, structured_only=False): 308 | if structured_only: 309 | pos = pos[dssp != 0] 310 | dssp = dssp[dssp != 0] 311 | if len(dssp) == 0: 312 | return 0 313 | 314 | idx = np.arange(len(dssp)) 315 | seen = [] 316 | 317 | for b in blobs: 318 | relpos = pos - b["pos"] 319 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos) 320 | mask = dists**0.5 < thresh 321 | seen.extend(idx[mask]) 322 | 323 | seen = list(set(seen)) 324 | total = len(dssp) 325 | return len(seen) / total 326 | 327 | 328 | """ 329 | Counts the number of residues within some Mahalanobis distance of some blob, 330 | relative to the total number expected for all blobs collectively. 331 | If unique=True, then residues within cutoff of multiple blobs will only be counted once. 332 | Suitable for partial blobs. 
333 | """ 334 | 335 | 336 | def blob_occupancy(pos, dssp, blobs, thresh=2.25, structured_only=False, unique=False): 337 | if structured_only: 338 | pos = pos[dssp != 0] 339 | dssp = dssp[dssp != 0] 340 | if len(dssp) == 0: 341 | return 0 342 | 343 | idx = np.arange(len(dssp)) 344 | seen = [] 345 | 346 | for b in blobs: 347 | relpos = pos - b["pos"] 348 | dists = np.einsum("li,ij,lj->l", relpos, np.linalg.inv(b["covar"]), relpos) 349 | mask = dists**0.5 < thresh 350 | seen.extend(idx[mask]) 351 | 352 | if unique: 353 | seen = list(set(seen)) 354 | total = sum(b["count"] for b in blobs) 355 | return len(seen) / total 356 | -------------------------------------------------------------------------------- /proteinblobs/constants.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | 12 | PYMOL_COLORS = ['grey', 13 | 'dash', 14 | 'lightteal', 15 | 'br5', 16 | 'hotpink', 17 | 'deepteal', 18 | 'yellow', 19 | 'tv_blue', 20 | 'yelloworange', 21 | 'slate', 22 | 'palecyan', 23 | 'oxygen', 24 | 'salmon', 25 | 'chocolate', 26 | 'gray', 27 | 'orange', 28 | 'blue', 29 | 'brown', 30 | 'green', 31 | 'ruby', 32 | 'forest', 33 | 'br0', 34 | 'black', 35 | 'firebrick', 36 | 'lightmagenta', 37 | 'dirtyviolet', 38 | 'lightpink', 39 | 'sulfur', 40 | 'deeppurple', 41 | 'br3', 42 | 'marine', 43 | 'darksalmon', 44 | 'br7', 45 | 'nitrogen', 46 | 'limon', 47 | 'br6', 48 | 'deepblue', 49 | 'bluewhite', 50 | 'lightorange', 51 | 'palegreen', 52 | 'carbon', 53 | 'lime', 54 | 'teal', 55 | 'tv_green', 56 | 'smudge', 57 | 'wheat', 58 | 'cyan', 59 | 'deepsalmon', 60 | 'splitpea', 61 | 'limegreen', 62 | 'purpleblue', 63 | 'red', 64 | 'tv_red', 65 | 'density', 66 | 'warmpink', 67 | 'deepolive', 68 | 'sand', 69 | 'raspberry', 70 | 'violetpurple', 71 | 'paleyellow', 72 | 'hydrogen', 73 | 'tv_orange', 74 | 'brightorange', 75 | 'br8', 76 | 'chartreuse', 77 | 'lightblue', 78 | 'br4', 79 | 'aquamarine', 80 | 'tv_yellow', 81 | 'br1', 82 | 'greencyan', 83 | 'magenta', 84 | 'pink', 85 | 'br9', 86 | 'olive', 87 | 'purple', 88 | 'white', 89 | 'skyblue', 90 | 'violet', 91 | 'br2'] -------------------------------------------------------------------------------- /proteinblobs/dataset.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
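# NOTE (added summary comment, not part of the original file): this module
# builds ProtComposer training examples. prot_to_frames() constructs backbone
# rigid frames from N/CA/C coordinates; StructureDataset.item_from_prot()
# runs pydssp for secondary structure, crops, centers, randomly rotates and
# pads each protein, and derives ellipsoid ("blob") conditioning with
# proteinblobs.blobs.blobify (per-blob mean, covariance, residue count and
# DSSP class). MultiflowDataset, GenieDBDataset and SCOPDataset adapt the
# different underlying data sources, and SeqCollate pads variable-length
# fields when batching.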
10 | 11 | import os 12 | import pickle 13 | import pydssp 14 | import math 15 | import numpy as np 16 | import torch 17 | import pickle 18 | from openfold.np import protein 19 | from torch.utils.data import default_collate 20 | from sklearn.cluster import KMeans 21 | import pandas as pd 22 | from openfold.utils.rigid_utils import Rigid, Rotation 23 | from openfold.np.residue_constants import restype_order, atom_order 24 | from scipy.spatial.transform import Rotation as spRotation 25 | from .blobs import blobify 26 | import proteinblobs.blobs as blobs 27 | import proteinblobs.multiflow.datasets as mfdatasets 28 | from pathlib import Path 29 | 30 | 31 | def from_3_points( 32 | p_neg_x_axis: torch.Tensor, 33 | origin: torch.Tensor, 34 | p_xy_plane: torch.Tensor, 35 | eps: float = 1e-8, 36 | ) -> Rigid: 37 | p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) 38 | origin = torch.unbind(origin, dim=-1) 39 | p_xy_plane = torch.unbind(p_xy_plane, dim=-1) 40 | 41 | e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] 42 | e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] 43 | 44 | the_sum = sum((c * c for c in e0)) 45 | the_other_sum = the_sum + eps 46 | 47 | ###### Workaround to prevent unexplicable error when taking sqrt with torch. 48 | sum_np = the_other_sum.numpy() 49 | sqrt = np.sqrt(sum_np) 50 | # denom = torch.sqrt(the_other_sum) 51 | denom = torch.from_numpy(sqrt) 52 | 53 | e0 = [c / denom for c in e0] 54 | dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) 55 | e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] 56 | denom = torch.sqrt(sum((c * c for c in e1)) + eps) 57 | e1 = [c / denom for c in e1] 58 | e2 = [ 59 | e0[1] * e1[2] - e0[2] * e1[1], 60 | e0[2] * e1[0] - e0[0] * e1[2], 61 | e0[0] * e1[1] - e0[1] * e1[0], 62 | ] 63 | rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) 64 | rots = rots.reshape(rots.shape[:-1] + (3, 3)) 65 | rot_obj = Rotation(rot_mats=rots, quats=None) 66 | 67 | return Rigid(rot_obj, torch.stack(origin, dim=-1)) 68 | 69 | 70 | def prot_to_frames(ca_coords, c_coords, n_coords): 71 | prot_frames = from_3_points( 72 | torch.from_numpy(c_coords), 73 | torch.from_numpy(ca_coords), 74 | torch.from_numpy(n_coords), 75 | ) 76 | rots = torch.eye(3) 77 | rots[0, 0] = -1 78 | rots[2, 2] = -1 79 | rots = Rotation(rot_mats=rots) 80 | frames = prot_frames.compose(Rigid(rots, None)) 81 | return frames 82 | 83 | 84 | class SeqCollate: 85 | def __init__(self, args): 86 | self.args = args 87 | if self.args.no_crop or self.args.no_pad: 88 | self.seq_len_keys = [ 89 | "grounding_feat", 90 | "grounding_pos", 91 | "grounding_mask", 92 | "chain_idx", 93 | "res_idx", 94 | "seqres", 95 | "res_mask", 96 | "mask", 97 | "rots", 98 | "trans", 99 | ] 100 | else: 101 | self.seq_len_keys = [ 102 | "grounding_feat", 103 | "grounding_pos", 104 | "grounding_mask", 105 | ] 106 | 107 | def __call__(self, batch): 108 | seq_len_batch = {} 109 | for key in self.seq_len_keys: 110 | elems = [item[key] for item in batch] 111 | max_L = max([len(elem) for elem in elems]) 112 | mask = torch.zeros((len(elems), max_L), dtype=torch.int16) 113 | elem_tensor = [] 114 | for i, elem in enumerate(elems): 115 | L = len(elem) 116 | if isinstance(elem, torch.Tensor): 117 | elem = elem.numpy() 118 | elem = np.concatenate( 119 | [elem, np.zeros((max_L - L, *elem.shape[1:]), dtype=elem.dtype)], 120 | axis=0, 121 | ) 122 | elem_tensor.append(elem) 123 | mask[i, :L] = 1 124 | seq_len_batch[key] = torch.from_numpy(np.stack(elem_tensor, axis=0)) 125 | seq_len_batch[f"{key}_mask"] = mask 126 | 127 | # remove 
all self.seq_len_keys from batch and put it through default collate 128 | for item in batch: 129 | for key in self.seq_len_keys: 130 | del item[key] 131 | 132 | batch = default_collate(batch) 133 | batch.update(seq_len_batch) 134 | return batch 135 | 136 | 137 | class BlobDataset(torch.utils.data.Dataset): 138 | def __init__(self, blobs): 139 | self.blobs = blobs 140 | 141 | def __len__(self): 142 | return len(self.blobs) 143 | 144 | 145 | class StructureDataset(torch.utils.data.Dataset): 146 | def __init__(self, args): 147 | super().__init__() 148 | self.args = args 149 | if self.args.length_dist_npz is not None: 150 | lens = np.load(self.args.length_dist_npz)["lengths"] 151 | sample_lens = [] 152 | for i in range(100000): 153 | choice = 0 154 | while choice < 4 or choice > args.crop: 155 | choice = np.random.choice(lens) 156 | sample_lens.append(choice) 157 | self.length_dist = np.array(sample_lens) 158 | if self.args.use_latents: 159 | latents_names = pickle.load(open(self.args.latents_order, "rb")) 160 | latents_names = [Path(f).stem for f in latents_names] 161 | latents_tensors = torch.load(self.args.latents_path) 162 | self.latents = { 163 | latents_names[i].lower(): latents_tensors[i] 164 | for i in range(len(latents_names)) 165 | } 166 | 167 | def __len__(self): 168 | return None 169 | 170 | def __getitem__(self, idx): 171 | return None 172 | 173 | def process_prot(self, idx, name, pdb=None, prot=None): 174 | if prot is None: 175 | prot = protein.from_pdb_string(pdb) 176 | atom37 = prot.atom_positions.astype(np.float32) 177 | frames = prot_to_frames( 178 | ca_coords=prot.atom_positions[:, atom_order["CA"]], 179 | c_coords=prot.atom_positions[:, atom_order["C"]], 180 | n_coords=prot.atom_positions[:, atom_order["N"]], 181 | ) 182 | res_mask = np.ones(atom37.shape[0], dtype=np.float32) 183 | seqres = prot.aatype.astype(int) 184 | res_idx = prot.residue_index 185 | chain_idx = np.zeros_like(seqres) 186 | return self.item_from_prot( 187 | name, atom37, frames, res_mask, seqres, res_idx, chain_idx 188 | ) 189 | 190 | def item_from_prot( 191 | self, name, atom37, frames, res_mask, seqres, res_idx, chain_idx 192 | ): 193 | mask = np.ones(atom37.shape[0], dtype=np.float32) 194 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... (see atom_types in residue_constants.py) 195 | bb_pos = np.concatenate( 196 | [atom37[:, :3, :], atom37[:, 4:5, :]], axis=1 197 | ) # (L, 4, 3) 198 | 199 | L = frames.shape[0] 200 | ## filter lenght and run pydssp secondary structure determination 201 | if L < 8: 202 | print( 203 | f"Only {L} residues in the protein {name}. Resampling another protein." 204 | ) 205 | return self.__getitem__(np.random.randint(len(self))) 206 | try: 207 | dssp = pydssp.assign( 208 | bb_pos, out_type="index" 209 | ) # 0: loop, 1: alpha-helix, 2: beta-strand 210 | except Exception as e: 211 | print( 212 | f"Running pydssp failed in the protein {name}. Resampling another protein." 
213 | ) 214 | print(str(e)) 215 | return self.__getitem__(np.random.randint(len(self))) 216 | 217 | ## crop 218 | if (not self.args.no_crop) and L > self.args.crop: 219 | start = np.random.randint(0, L - self.args.crop + 1) 220 | if self.args.overfit: 221 | start = 0 222 | frames = frames[start : start + self.args.crop] 223 | mask = mask[start : start + self.args.crop] 224 | res_mask = res_mask[start : start + self.args.crop] 225 | seqres = seqres[start : start + self.args.crop] 226 | chain_idx = chain_idx[start : start + self.args.crop] 227 | res_idx = res_idx[start : start + self.args.crop] 228 | dssp = dssp[start : start + self.args.crop] 229 | 230 | ## center 231 | com = (frames._trans * mask[:, None]).sum(0) / mask.sum() 232 | frames._trans -= com 233 | 234 | ## rotate 235 | randrot = spRotation.random().as_matrix().astype(np.float32) 236 | randrot = torch.from_numpy(randrot) 237 | frames._trans = frames._trans @ randrot.T 238 | frames._rots._rot_mats = randrot @ frames._rots._rot_mats 239 | 240 | ## label 241 | thresh = ( 242 | np.random.rand() * (self.args.max_blob_thresh - self.args.min_blob_thresh) 243 | + self.args.min_blob_thresh 244 | ) 245 | blobs = blobify(frames._trans.numpy(), dssp, thresh) 246 | 247 | if not blobs: 248 | print(f"No blobs in the protein {name}. Resampling another protein.") 249 | return self.__getitem__(np.random.randint(len(self))) 250 | 251 | if self.args.blob_drop_prob > 0.0 and len(blobs) > 1: 252 | np.random.shuffle(blobs) 253 | n = np.random.randint(1, len(blobs)) 254 | blobs = blobs[:n] 255 | 256 | ## pad 257 | if (not self.args.no_pad) and L < self.args.crop: 258 | pad = self.args.crop - L 259 | frames = Rigid.cat( 260 | [frames, Rigid.identity((pad,), requires_grad=False, fmt="rot_mat")], 0 261 | ) 262 | 263 | mask = np.concatenate([mask, np.zeros(pad, dtype=np.float32)]) 264 | res_mask = np.concatenate([res_mask, np.zeros(pad, dtype=np.float32)]) 265 | seqres = np.concatenate([seqres, np.zeros(pad, dtype=int)]) 266 | res_idx = np.concatenate([res_idx, np.zeros(pad, dtype=int)]) 267 | chain_idx = np.concatenate([chain_idx, np.zeros(pad, dtype=int)]) 268 | 269 | ## featurize 270 | if self.args.synthetic_blobs: 271 | grounding_pos, grounding_feat, grounding_mask = self.get_synthetic_blobs() 272 | prot_size = int(grounding_feat[:, 1].sum()) 273 | mask[prot_size:] = 0 274 | res_mask[prot_size:] = 0 275 | mask[:prot_size] = 1 276 | res_mask[:prot_size] = 1 277 | else: 278 | grounding_pos = [] 279 | grounding_feat = [] 280 | grounding_covar = [] 281 | for blob in blobs: 282 | grounding_pos.append(blob["pos"]) 283 | grounding_feat.append((blob["dssp"], blob["count"])) 284 | grounding_covar.append(blob["covar"].flatten()) 285 | grounding_pos = np.array(grounding_pos).astype(np.float32) 286 | grounding_feat = np.array(grounding_feat) 287 | grounding_covar = np.array(grounding_covar).astype(np.float32) 288 | grounding_feat = np.concatenate( 289 | [grounding_feat, grounding_covar], axis=-1 290 | ).astype(np.float32) 291 | grounding_mask = np.ones_like(grounding_feat[:, 0]) 292 | 293 | if self.args.fixed_inference_size is not None: 294 | assert self.args.fixed_inference_size <= self.args.crop 295 | mask[self.args.fixed_inference_size :] = 0 296 | res_mask[self.args.fixed_inference_size :] = 0 297 | mask[: self.args.fixed_inference_size] = 1 298 | res_mask[: self.args.fixed_inference_size] = 1 299 | 300 | if self.args.length_dist_npz is not None: 301 | prot_size = np.random.choice(self.length_dist) 302 | mask[prot_size:] = 0 303 | res_mask[prot_size:] = 0 304 | 
mask[:prot_size] = 1 305 | res_mask[:prot_size] = 1 306 | 307 | if self.args.use_latents: 308 | try: 309 | latents = self.latents[name.lower()] 310 | except Exception as e: 311 | print(f"key error for {name.lower()}") 312 | return self.__getitem__(np.random.randint(len(self))) 313 | else: 314 | latents = 0 315 | 316 | return { 317 | "name": name, 318 | "grounding_pos": grounding_pos, 319 | "grounding_feat": grounding_feat, 320 | "grounding_mask": grounding_mask, 321 | "trans": frames._trans, 322 | "rots": frames._rots._rot_mats, 323 | "mask": mask, 324 | "res_mask": res_mask, 325 | "seqres": seqres, 326 | "res_idx": res_idx, 327 | "chain_idx": chain_idx, 328 | "blobs": pickle.dumps(blobs), 329 | "thresh": thresh, 330 | "latents": latents, 331 | } 332 | 333 | def get_synthetic_blobs(self): 334 | pos, covar = blobs.sample_blobs( 335 | self.args.num_blobs, 336 | nu=self.args.nu, 337 | psi=(1 / self.args.nu) * self.args.psi**2 * np.eye(3), 338 | sigma=self.args.sigma, 339 | ) 340 | is_helix = np.random.rand(self.args.num_blobs) < self.args.helix_frac 341 | volume = np.linalg.det(covar) ** 0.5 342 | 343 | counts = np.where( 344 | is_helix, 345 | blobs.alpha_slope * volume + blobs.alpha_intercept, 346 | blobs.beta_slope * volume + blobs.beta_intercept, 347 | ).astype(int) 348 | 349 | dssp = np.where(is_helix, 1, 2) 350 | 351 | grounding_pos = pos[None].astype(np.float32) 352 | grounding_feat = np.zeros((1, self.args.num_blobs, 11), dtype=np.float32) 353 | grounding_feat[:, :, 0] = dssp # all helix for now 354 | grounding_feat[:, :, 1] = counts 355 | grounding_feat[:, :, 2:] = covar.reshape(1, self.args.num_blobs, 9) 356 | grounding_mask = np.ones(self.args.num_blobs, dtype=np.float32) 357 | return grounding_pos[0], grounding_feat[0], grounding_mask 358 | 359 | 360 | class MultiflowDataset(StructureDataset): 361 | def __init__(self, args, dataset_cfg): 362 | super().__init__(args) 363 | self.args = args 364 | self.mf_dataset = mfdatasets.PdbDataset( 365 | dataset_cfg=dataset_cfg, task="hallucination", is_training=True 366 | ) 367 | 368 | def __len__(self): 369 | if self.args.dataset_multiplicity is not None: 370 | return len(self.mf_dataset.csv) * self.args.dataset_multiplicity 371 | else: 372 | return len(self.mf_dataset.csv) 373 | 374 | def __getitem__(self, idx): 375 | if self.args.overfit: 376 | idx = 0 377 | if self.args.dataset_multiplicity is not None: 378 | idx = idx % self.args.dataset_multiplicity 379 | row = self.mf_dataset.csv.iloc[idx] 380 | name = row["pdb_name"] 381 | path = f"{self.args.pkl_dir}/{row['processed_path'].replace('./', '')}" 382 | 383 | # res_idx, chain_idx, res_mask, trans_1, rotmats_1, aatypes_1, res_plddt 384 | mf_prot = mfdatasets._process_csv_row(path) 385 | 386 | atom37 = mf_prot["all_atom_positions"].float().numpy() 387 | # atom37_to_pdb(atom37[None], f'../protfid/data/mf_testset/{name}.pdb') 388 | frames = Rigid( 389 | Rotation(rot_mats=mf_prot["rotmats_1"].float(), quats=None), 390 | mf_prot["trans_1"].float(), 391 | ) 392 | 393 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... 
(see atom_types in residue_constants.py) 394 | bb_pos = np.concatenate( 395 | [atom37[:, :3, :], atom37[:, 4:5, :]], axis=1 396 | ) # (L, 4, 3) 397 | 398 | L = frames.shape[0] 399 | res_mask = mf_prot["res_mask"].float().numpy() 400 | seqres = mf_prot["aatypes_1"].numpy() 401 | res_idx = mf_prot["res_idx"] 402 | chain_idx = mf_prot["chain_idx"] 403 | return self.item_from_prot( 404 | name, atom37, frames, res_mask, seqres, res_idx, chain_idx 405 | ) 406 | 407 | 408 | class GenieDBDataset(StructureDataset): 409 | def __init__(self, args): 410 | super().__init__(args) 411 | self.args = args 412 | files = os.listdir(self.args.genie_db_path) 413 | self.db = [file for file in files if ".pdb" in file] 414 | self.db = self.db * args.repeat_dataset 415 | 416 | def __len__(self): 417 | return len(self.db) 418 | 419 | def __getitem__(self, idx): 420 | if self.args.overfit: 421 | idx = 0 422 | name = self.db[idx] 423 | with open(f"{self.args.genie_db_path}/{name}", "r") as f: 424 | pdb = f.read() 425 | try: 426 | sample = self.process_prot(idx, name, pdb) 427 | except Exception as e: 428 | # for some reason the name printing here does not work so dont rely on it. We instead write it to a file. 429 | print("name", name, flush=True) 430 | with open(os.path.join(os.environ["MODEL_DIR"], "debug.txt"), "w") as f: 431 | f.write(name) 432 | raise e 433 | return sample 434 | 435 | 436 | class SCOPDataset(torch.utils.data.Dataset): 437 | def __init__(self, args): 438 | super().__init__() 439 | self.args = args 440 | names = os.listdir(args.scop_dir) 441 | nmr_names = pd.read_csv(args.scop_nmr_csv) 442 | self.names = [ 443 | name for name in names if name not in nmr_names["nmr_names"].tolist() 444 | ] 445 | 446 | def __len__(self): 447 | return len(self.names) 448 | 449 | def __getitem__(self, idx): 450 | if self.args.overfit: 451 | idx = 0 452 | 453 | name = self.names[idx] 454 | pdb = open(f"{self.args.scop_dir}/{name}", "r").read() 455 | prot = protein.from_pdb_string(pdb) 456 | 457 | atom37 = prot.atom_positions.astype(np.float32) 458 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... 
(see atom_types in residue_constants.py) 459 | bb_pos = np.concatenate( 460 | [atom37[:, :3, :], atom37[:, 4:5, :]], axis=1 461 | ) # (L, 4, 3) 462 | dssp = pydssp.assign( 463 | bb_pos, out_type="index" 464 | ) # 0: loop, 1: alpha-helix, 2: beta-strand 465 | 466 | pos = atom37[:, 1, :] 467 | L = atom37.shape[0] 468 | mask = np.ones(L, dtype=np.float32) 469 | seqres = prot.aatype.astype(int) 470 | resid = prot.residue_index.astype(int) 471 | 472 | cm = (pos * mask[:, None]).sum(0, keepdims=True) / mask.sum() 473 | pos = pos - cm 474 | if not self.args.overfit_rot: 475 | rot = spRotation.random().as_matrix().astype(np.float32) 476 | pos = pos @ rot.T 477 | 478 | n_clusters = math.ceil(L / self.args.res_per_cluster) 479 | kmeans = KMeans(n_clusters=n_clusters, n_init=1).fit(bb_pos[:, 1]) 480 | labels = kmeans.labels_ 481 | centers = kmeans.cluster_centers_ 482 | dssp_onehot = torch.nn.functional.one_hot(torch.from_numpy(dssp), num_classes=3) 483 | 484 | dssp_count = torch.zeros(len(centers), 3, dtype=torch.long).scatter_add_( 485 | dim=0, 486 | index=torch.from_numpy(labels).long()[:, None].expand(-1, 3), 487 | src=dssp_onehot, 488 | ) 489 | dssp_dist = dssp_count / dssp_count.sum(1)[:, None] 490 | 491 | return { 492 | "prot": prot, 493 | "labels": labels, 494 | "centers": centers, 495 | "dssp_dist": dssp_dist, 496 | "resid": resid, 497 | "name": name, 498 | "dssp": dssp, 499 | "bb_pos": pos, 500 | "mask": mask, 501 | "seqres": seqres, 502 | } 503 | -------------------------------------------------------------------------------- /proteinblobs/designability_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
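# The routine below implements the usual self-consistency ("designability")
# check: each generated backbone is written to a temporary PDB under
# $MODEL_DIR, redesigned with ProteinMPNN (run_pmpnn), the designed sequences
# are refolded with ESMFold, and the refolded coordinates are compared to the
# generated CA trace (prot[:, 1]) via an aligned RMSD; per structure it keeps
# the mean ("rmsd") and the best ("top_rmsd") over the designed sequences.
# A minimal usage sketch, assuming `backbones` is a list of per-residue
# backbone coordinate arrays in the layout expected by atom37_to_pdb and that
# the MODEL_DIR environment variable points to a writable work directory:
#
#   metrics = run_designability(backbones, pmpnn_path="../ProteinMPNN",
#                               seqs_per_struct=8)
#   print(metrics["rmsd"].mean(), metrics["top_rmsd"].mean())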
10 | 11 | import subprocess 12 | 13 | import esm, tqdm, torch 14 | import os 15 | from biotite.sequence.io import fasta 16 | import numpy as np 17 | 18 | from proteinblobs.utils import atom37_to_pdb 19 | from openfold.np import residue_constants 20 | 21 | 22 | 23 | def run_designability( 24 | prots, pmpnn_path, process_rank=0, seqs_per_struct=8, sequences=None 25 | ): 26 | # prots in list of np arrays of ( L, 3) 27 | 28 | print("Loading ESMFold model for designability evaluation") 29 | torch.cuda.empty_cache() 30 | esmf_model = esm.pretrained.esmfold_v1().eval() 31 | esmf_model = esmf_model.to("cuda") 32 | 33 | all_tm_scores = [] 34 | all_rmsds = [] 35 | top_rmsds = [] 36 | top_tm_scores = [] 37 | for i, prot in tqdm.tqdm(enumerate(prots), desc="Running PMPNN and ESMFold"): 38 | sample_dir = path = os.path.join( 39 | os.environ["MODEL_DIR"], 40 | f"tmp_design_dir_process{process_rank}", 41 | f"sample{i}_process{process_rank}", 42 | ) 43 | os.makedirs(sample_dir, exist_ok=True) 44 | sample_path = os.path.join(sample_dir, f"sample.pdb") 45 | 46 | rmsds = [] 47 | #tm_scores = [] 48 | if sequences is not None: 49 | restypes = residue_constants.restypes + ["X"] 50 | seqs = ["".join(map(lambda x: restypes[x], sequences[i]))] 51 | 52 | else: 53 | atom37_to_pdb(prot[None], sample_path) 54 | run_pmpnn( 55 | pdb_dir=sample_dir, num_seqs=seqs_per_struct, pmpnn_path=pmpnn_path 56 | ) 57 | mpnn_fasta_path = os.path.join( 58 | sample_dir, "seqs", os.path.basename(sample_path).replace(".pdb", ".fa") 59 | ) 60 | fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path) 61 | 62 | seqs = [v for k, v in fasta_seqs.items()] 63 | seqs = seqs[ 64 | 1: 65 | ] # remove the first sequence, which is the input sequence (and AAAAAAA... if there is no input sequence to the structure) 66 | 67 | for j, seq in enumerate(seqs): 68 | seq = seq.replace("X", "A") 69 | 70 | with torch.no_grad(): 71 | output = esmf_model.infer(seq) 72 | 73 | out_ca_pos = output["positions"][-1].squeeze()[:, 2].cpu().numpy() 74 | #_, tm_score = get_tm_score(prot[:, 1], out_ca_pos, seq, seq) 75 | rmsd = get_aligned_rmsd(prot[:, 1], out_ca_pos) 76 | rmsds.append(rmsd) 77 | #tm_scores.append(tm_score) 78 | all_rmsds.append(np.array(rmsds).mean()) 79 | #all_tm_scores.append(np.array(tm_scores).mean()) 80 | top_rmsds.append(np.array(rmsds).min()) 81 | #top_tm_scores.append(np.array(tm_scores).max()) 82 | del esmf_model 83 | torch.cuda.empty_cache() 84 | #all_tm_scores = np.array(all_tm_scores) 85 | all_rmsds = np.array(all_rmsds) 86 | top_rmsds = np.array(top_rmsds) 87 | #top_tm_scores = np.array(top_tm_scores) 88 | 89 | return { 90 | #"tm_score": all_tm_scores, 91 | "rmsd": all_rmsds, 92 | #"top_tm_score": top_tm_scores, 93 | "top_rmsd": top_rmsds, 94 | } 95 | 96 | 97 | def run_pmpnn(pdb_dir, num_seqs=8, pmpnn_path="../ProteinMPNN", ca_only=False): 98 | os.makedirs(os.path.join(pdb_dir, "seqs"), exist_ok=True) 99 | parsed_chains_path = os.path.join(pdb_dir, "seqs", "parsed_chains.jsonl") 100 | 101 | process = subprocess.Popen( 102 | [ 103 | "python", 104 | os.path.join(pmpnn_path, "helper_scripts/parse_multiple_chains.py"), 105 | f"--input_path={pdb_dir}", 106 | f"--output_path={parsed_chains_path}", 107 | ] 108 | ) 109 | _ = process.wait() 110 | 111 | pmpnn_args = [ 112 | "python", 113 | os.path.join(pmpnn_path, "protein_mpnn_run.py"), 114 | "--out_folder", 115 | pdb_dir, 116 | "--jsonl_path", 117 | parsed_chains_path, 118 | "--num_seq_per_target", 119 | str(num_seqs), 120 | "--sampling_temp", 121 | "0.1", 122 | "--seed", 123 | "38", 124 | 
"--batch_size", 125 | "1", 126 | ] 127 | if ca_only: 128 | pmpnn_args.append("--ca_only") 129 | print(" ".join(pmpnn_args)) 130 | 131 | process = subprocess.run(pmpnn_args) 132 | 133 | 134 | def get_aligned_rmsd(pos_1, pos_2): 135 | aligned_pos_1 = rigid_transform_3D(pos_1, pos_2)[0] 136 | return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1)) 137 | 138 | 139 | def get_tm_score(pos_1, pos_2, seq_1, seq_2): 140 | # tm_results = tm_align(pos_1, pos_2, seq_1, seq_2) 141 | return tm_results.tm_norm_chain1, tm_results.tm_norm_chain2 142 | 143 | 144 | def rigid_transform_3D(A, B, verbose=False): 145 | # Transforms A to look like B 146 | # https://github.com/nghiaho12/rigid_transform_3D 147 | assert A.shape == B.shape 148 | A = A.T 149 | B = B.T 150 | 151 | num_rows, num_cols = A.shape 152 | if num_rows != 3: 153 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 154 | 155 | num_rows, num_cols = B.shape 156 | if num_rows != 3: 157 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 158 | 159 | # find mean column wise 160 | centroid_A = np.mean(A, axis=1) 161 | centroid_B = np.mean(B, axis=1) 162 | 163 | # ensure centroids are 3x1 164 | centroid_A = centroid_A.reshape(-1, 1) 165 | centroid_B = centroid_B.reshape(-1, 1) 166 | 167 | # subtract mean 168 | Am = A - centroid_A 169 | Bm = B - centroid_B 170 | 171 | H = Am @ np.transpose(Bm) 172 | 173 | # sanity check 174 | # if linalg.matrix_rank(H) < 3: 175 | # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H))) 176 | 177 | # find rotation 178 | U, S, Vt = np.linalg.svd(H) 179 | R = Vt.T @ U.T 180 | 181 | # special reflection case 182 | reflection_detected = False 183 | if np.linalg.det(R) < 0: 184 | if verbose: 185 | print("det(R) < R, reflection detected!, correcting for it ...") 186 | Vt[2, :] *= -1 187 | R = Vt.T @ U.T 188 | reflection_detected = True 189 | 190 | t = -R @ centroid_A + centroid_B 191 | optimal_A = R @ A + t 192 | 193 | return optimal_A.T, R, t, reflection_detected 194 | -------------------------------------------------------------------------------- /proteinblobs/logger.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | 12 | import yaml, logging, socket, os, sys 13 | model_dir = os.environ.get("MODEL_DIR", "./workdir/default") 14 | def get_logger(name): 15 | logger = logging.Logger(name) 16 | level = {"crititical": 50, "error": 40, "warning": 30, "info": 20, "debug": 10}[ 17 | os.environ.get("LOGGER_LEVEL", "info") 18 | ] 19 | logger.setLevel(level) 20 | 21 | ch = logging.StreamHandler() 22 | ch.setLevel(logging.INFO) 23 | os.makedirs(model_dir, exist_ok=True) 24 | fh = logging.FileHandler(os.path.join(model_dir, "log.out")) 25 | fh.setLevel(logging.DEBUG) 26 | formatter = logging.Formatter( 27 | f"%(asctime)s [{socket.gethostname()}:%(process)d] [%(levelname)s] %(message)s" 28 | ) 29 | ch.setFormatter(formatter) 30 | fh.setFormatter(formatter) 31 | logger.addHandler(ch) 32 | logger.addHandler(fh) 33 | return logger 34 | 35 | -------------------------------------------------------------------------------- /proteinblobs/multiflow/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Andrew Campbell, Jason Yim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /proteinblobs/multiflow/cross_attention.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
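# The cross attention below is kept SE(3)-invariant by expressing every blob
# center in each residue's local frame (r.invert_apply(x)) and by conjugating
# each blob covariance with the residue rotation (R^T C R) before the features
# enter attention, so a global rotation or translation applied to both the
# protein and the blobs leaves the attention inputs unchanged. The sinusoidal
# helper positional_encoding(x, channels) encodes each scalar coordinate with
# channels//2 sine and channels//2 cosine features, e.g.:
#
#   x = torch.randn(2, 5, 3)                       # arbitrary coordinates
#   emb = positional_encoding(x, 64, max_period=1000)
#   # emb.shape == (2, 5, 3, 64)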
10 | 11 | import torch.nn as nn 12 | 13 | import torch 14 | import math 15 | 16 | def positional_encoding(x, channels, max_period=10000, min_period=None): 17 | min_period = min_period or max_period / 10000 # second one recommended 18 | freqs = torch.exp(-torch.linspace(math.log(min_period), math.log(max_period), channels // 2, device=x.device)) 19 | emb = freqs * x.unsqueeze(-1) # [..., C] 20 | return torch.cat([torch.sin(emb), torch.cos(emb)], -1) # [..., 2C] 21 | 22 | 23 | 24 | class InvariantCrossAttention(nn.Module): 25 | def __init__(self, cfg): 26 | super().__init__() 27 | self.cfg = cfg 28 | pos_emb_dim = 64 29 | self.linear_pos = nn.Linear(3 * pos_emb_dim, cfg.c_s) 30 | self.linear_kv = nn.Linear(cfg.c_s, 2 * cfg.c_s) 31 | self.linear_covar = nn.Linear(9, cfg.c_s) 32 | 33 | self.linear_q = nn.Linear(cfg.c_s, cfg.c_s) 34 | self.linear_out = nn.Linear(cfg.c_s, cfg.c_s) 35 | self.gamma = nn.Parameter(torch.zeros(1)) 36 | self.beta = nn.Parameter(torch.ones(1)) 37 | self.beta.requires_grad_(False) 38 | self.inf = 1e10 39 | 40 | def forward(self, s, r, x, y, c, mask): 41 | ''' 42 | s: frame features 43 | r: frames 44 | x: point positions 45 | y: point features 46 | c: blob covariance 47 | mask: point mask 48 | ''' 49 | 50 | # [B, L, N, 3] 51 | relpos = r[:,:,None].invert_apply(x[:,None]) 52 | 53 | # [B, L, 3, 3] 54 | R = r._rots.get_rot_mats() 55 | 56 | # [B, L, N, 3, 3] 57 | covar = torch.einsum('blij,bnjk,blkm->blnim', R.transpose(-1, -2), c, R) 58 | 59 | # [B, L, N, C] 60 | relpos = positional_encoding(relpos, 64, max_period=1000) # 1000 nm??? 61 | B, L, N, _, _ = relpos.shape 62 | 63 | # [B, L, N, C] 64 | y = y[:,None] + self.linear_pos(relpos.view(B, L, N, 3*64)) + self.linear_covar(covar.view(B, L, N, 9)) 65 | B, L, N, C = y.shape 66 | 67 | # [B, L, N, C] 68 | k, v = self.linear_kv(y).chunk(2, dim=-1) 69 | 70 | # [B, L, C] 71 | q = self.linear_q(s) 72 | 73 | H = self.cfg.no_heads 74 | k = k.view(B, L, N, H, C//H) 75 | v = v.view(B, L, N, H, C//H) 76 | q = q.view(B, L, 1, H, C//H) 77 | 78 | # [B, L, N, H] 79 | attn = torch.einsum('BLNHC,BLNHC->BLNH', q, k) / math.sqrt(C//H) 80 | attn = attn + (mask[:,None,:,None] - 1) * self.inf 81 | attn = torch.softmax(attn, dim=-2) 82 | 83 | # [B, L, C] 84 | output = torch.einsum('BLNH,BLNHC->BLHC', attn, v).reshape(B, L, C) 85 | output = self.linear_out(output) 86 | 87 | output = self.beta * torch.tanh(self.gamma) * output 88 | 89 | return output 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /proteinblobs/multiflow/data/all_atom.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | 12 | # Copyright 2021 AlQuraishi Laboratory 13 | # Copyright 2021 DeepMind Technologies Limited 14 | # 15 | # Licensed under the Apache License, Version 2.0 (the "License"); 16 | # you may not use this file except in compliance with the License. 
17 | # You may obtain a copy of the License at 18 | # 19 | # http://www.apache.org/licenses/LICENSE-2.0 20 | # 21 | # Unless required by applicable law or agreed to in writing, software 22 | # distributed under the License is distributed on an "AS IS" BASIS, 23 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | # See the License for the specific language governing permissions and 25 | # limitations under the License. 26 | 27 | """Utilities for calculating all atom representations.""" 28 | 29 | import torch 30 | from openfold.data import data_transforms 31 | from openfold.np import residue_constants 32 | from .. import rigid_utils as ru 33 | from . import utils as du 34 | 35 | Rigid = ru.Rigid 36 | Rotation = ru.Rotation 37 | 38 | # Residue Constants from OpenFold/AlphaFold2. 39 | 40 | 41 | IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions) 42 | DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame) 43 | ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask) 44 | GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group) 45 | 46 | 47 | def to_atom37(trans, rots): 48 | num_batch, num_res, _ = trans.shape 49 | final_atom37 = compute_backbone( 50 | du.create_rigid(rots, trans), 51 | torch.zeros(num_batch, num_res, 2, device=trans.device), 52 | )[0] 53 | return final_atom37 54 | 55 | 56 | def torsion_angles_to_frames( 57 | r: Rigid, # type: ignore [valid-type] 58 | alpha: torch.Tensor, 59 | aatype: torch.Tensor, 60 | ): 61 | """Conversion method of torsion angles to frames provided the backbone. 62 | 63 | Args: 64 | r: Backbone rigid groups. 65 | alpha: Torsion angles. 66 | aatype: residue types. 67 | 68 | Returns: 69 | All 8 frames corresponding to each torsion frame. 70 | 71 | """ 72 | # [*, N, 8, 4, 4] 73 | with torch.no_grad(): 74 | default_4x4 = DEFAULT_FRAMES.to(aatype.device)[aatype, ...] # type: ignore [attr-defined] 75 | 76 | # [*, N, 8] transformations, i.e. 77 | # One [*, N, 8, 3, 3] rotation matrix and 78 | # One [*, N, 8, 3] translation matrix 79 | default_r = r.from_tensor_4x4(default_4x4) # type: ignore [attr-defined] 80 | 81 | bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) 82 | bb_rot[..., 1] = 1 83 | 84 | # [*, N, 8, 2] 85 | alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) 86 | 87 | # [*, N, 8, 3, 3] 88 | # Produces rotation matrices of the form: 89 | # [ 90 | # [1, 0 , 0 ], 91 | # [0, a_2,-a_1], 92 | # [0, a_1, a_2] 93 | # ] 94 | # This follows the original code rather than the supplement, which uses 95 | # different indices. 
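    # Here alpha[..., 0] is sin(angle) and alpha[..., 1] is cos(angle), the
    # same convention as bb_rot above (which encodes the identity rotation by
    # setting the cosine entry to 1), so a_1 = sin and a_2 = cos produce a
    # rotation about the x-axis.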
96 | 97 | all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) 98 | all_rots[..., 0, 0] = 1 99 | all_rots[..., 1, 1] = alpha[..., 1] 100 | all_rots[..., 1, 2] = -alpha[..., 0] 101 | all_rots[..., 2, 1:] = alpha 102 | 103 | all_rots = Rigid(Rotation(rot_mats=all_rots), None) 104 | 105 | all_frames = default_r.compose(all_rots) 106 | 107 | chi2_frame_to_frame = all_frames[..., 5] 108 | chi3_frame_to_frame = all_frames[..., 6] 109 | chi4_frame_to_frame = all_frames[..., 7] 110 | 111 | chi1_frame_to_bb = all_frames[..., 4] 112 | chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) 113 | chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) 114 | chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) 115 | 116 | all_frames_to_bb = Rigid.cat( 117 | [ 118 | all_frames[..., :5], 119 | chi2_frame_to_bb.unsqueeze(-1), 120 | chi3_frame_to_bb.unsqueeze(-1), 121 | chi4_frame_to_bb.unsqueeze(-1), 122 | ], 123 | dim=-1, 124 | ) 125 | 126 | all_frames_to_global = r[..., None].compose(all_frames_to_bb) # type: ignore [index] 127 | 128 | return all_frames_to_global 129 | 130 | 131 | def prot_to_torsion_angles(aatype, atom37, atom37_mask): 132 | """Calculate torsion angle features from protein features.""" 133 | prot_feats = { 134 | "aatype": aatype, 135 | "all_atom_positions": atom37, 136 | "all_atom_mask": atom37_mask, 137 | } 138 | torsion_angles_feats = data_transforms.atom37_to_torsion_angles()(prot_feats) 139 | torsion_angles = torsion_angles_feats["torsion_angles_sin_cos"] 140 | torsion_mask = torsion_angles_feats["torsion_angles_mask"] 141 | return torsion_angles, torsion_mask 142 | 143 | 144 | def frames_to_atom14_pos( 145 | r: Rigid, # type: ignore [valid-type] 146 | aatype: torch.Tensor, 147 | ): 148 | """Convert frames to their idealized all atom representation. 149 | 150 | Args: 151 | r: All rigid groups. [..., N, 8, 3] 152 | aatype: Residue types. [..., N] 153 | 154 | Returns: 155 | 156 | """ 157 | with torch.no_grad(): 158 | group_mask = GROUP_IDX.to(aatype.device)[aatype, ...] 159 | group_mask = torch.nn.functional.one_hot( 160 | group_mask, 161 | num_classes=DEFAULT_FRAMES.shape[-3], 162 | ) 163 | frame_atom_mask = ATOM_MASK.to(aatype.device)[aatype, ...].unsqueeze(-1) # type: ignore [attr-defined] 164 | frame_null_pos = IDEALIZED_POS.to(aatype.device)[aatype, ...] 
# type: ignore [attr-defined] 165 | 166 | # [*, N, 14, 8] 167 | t_atoms_to_global = r[..., None, :] * group_mask # type: ignore [index] 168 | 169 | # [*, N, 14] 170 | t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) 171 | 172 | # [*, N, 14, 3] 173 | pred_positions = t_atoms_to_global.apply(frame_null_pos) 174 | pred_positions = pred_positions * frame_atom_mask 175 | 176 | return pred_positions 177 | 178 | 179 | def compute_backbone(bb_rigids, psi_torsions): 180 | torsion_angles = torch.tile( 181 | psi_torsions[..., None, :], 182 | tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1), 183 | ) 184 | aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long() 185 | # aatype = torch.zeros(bb_rigids.shape).long().to(bb_rigids.device) 186 | all_frames = torsion_angles_to_frames( 187 | bb_rigids, 188 | torsion_angles, 189 | aatype, 190 | ) 191 | atom14_pos = frames_to_atom14_pos(all_frames, aatype) 192 | atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=bb_rigids.device) 193 | # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB'] 194 | # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O'] 195 | atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :] 196 | atom37_bb_pos[..., 3, :] = atom14_pos[..., 4, :] 197 | atom37_bb_pos[..., 4, :] = atom14_pos[..., 3, :] 198 | atom37_mask = torch.any(atom37_bb_pos, axis=-1) 199 | return atom37_bb_pos, atom37_mask, aatype, atom14_pos 200 | 201 | 202 | def calculate_neighbor_angles(R_ac, R_ab): 203 | """Calculate angles between atoms c <- a -> b. 204 | 205 | Parameters 206 | ---------- 207 | R_ac: Tensor, shape = (N,3) 208 | Vector from atom a to c. 209 | R_ab: Tensor, shape = (N,3) 210 | Vector from atom a to b. 211 | 212 | Returns 213 | ------- 214 | angle_cab: Tensor, shape = (N,) 215 | Angle between atoms c <- a -> b. 216 | """ 217 | # cos(alpha) = (u * v) / (|u|*|v|) 218 | x = torch.sum(R_ac * R_ab, dim=1) # shape = (N,) 219 | # sin(alpha) = |u x v| / (|u|*|v|) 220 | y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,) 221 | # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN 222 | y = torch.max(y, torch.tensor(1e-9)) 223 | angle = torch.atan2(y, x) 224 | return angle 225 | 226 | 227 | def vector_projection(R_ab, P_n): 228 | """ 229 | Project the vector R_ab onto a plane with normal vector P_n. 230 | 231 | Parameters 232 | ---------- 233 | R_ab: Tensor, shape = (N,3) 234 | Vector from atom a to b. 235 | P_n: Tensor, shape = (N,3) 236 | Normal vector of a plane onto which to project R_ab. 237 | 238 | Returns 239 | ------- 240 | R_ab_proj: Tensor, shape = (N,3) 241 | Projected vector (orthogonal to P_n). 
242 | """ 243 | a_x_b = torch.sum(R_ab * P_n, dim=-1) 244 | b_x_b = torch.sum(P_n * P_n, dim=-1) 245 | return R_ab - (a_x_b / b_x_b)[:, None] * P_n 246 | 247 | 248 | def transrot_traj_to_atom37(transrot_traj, res_mask): 249 | atom37_traj = [] 250 | for trans, rots in transrot_traj: 251 | atom37_traj.append(atom37_from_trans_rot(trans, rots, res_mask)) 252 | return atom37_traj 253 | 254 | 255 | def atom37_from_trans_rot(trans, rots, res_mask=None): 256 | if res_mask is None: 257 | res_mask = torch.ones([*trans.shape[:-1]], device=trans.device) 258 | rigids = du.create_rigid(rots, trans) 259 | atom37 = compute_backbone( 260 | rigids, torch.zeros(trans.shape[0], trans.shape[1], 2, device=trans.device) 261 | )[0] 262 | batch_atom37 = [] 263 | num_batch = res_mask.shape[0] 264 | for i in range(num_batch): 265 | batch_atom37.append(du.adjust_oxygen_pos(atom37[i], res_mask[i])) 266 | return torch.stack(batch_atom37) 267 | 268 | def process_trans_rot_traj(trans_traj, rots_traj, res_mask): 269 | res_mask = res_mask.detach().cpu() 270 | atom37_traj = [ 271 | atom37_from_trans_rot(trans, rots, res_mask) 272 | for trans, rots in zip(trans_traj, rots_traj) 273 | ] 274 | atom37_traj = torch.stack(atom37_traj).swapaxes(0, 1) 275 | return atom37_traj 276 | -------------------------------------------------------------------------------- /proteinblobs/multiflow/data/protein.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Protein data type.""" 16 | import dataclasses 17 | import io 18 | from typing import Any, Mapping, Optional 19 | from . import residue_constants 20 | from Bio.PDB import PDBParser 21 | import numpy as np 22 | 23 | FeatureDict = Mapping[str, np.ndarray] 24 | ModelOutput = Mapping[str, Any] # Is a nested dict. 25 | 26 | # Complete sequence of chain IDs supported by the PDB format. 27 | PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' 28 | PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. 29 | 30 | 31 | @dataclasses.dataclass(frozen=True) 32 | class Protein: 33 | """Protein structure representation.""" 34 | 35 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to 36 | # residue_constants.atom_types, i.e. the first three are N, CA, CB. 37 | atom_positions: np.ndarray # [num_res, num_atom_type, 3] 38 | 39 | # Amino-acid type for each residue represented as an integer between 0 and 40 | # 20, where 20 is 'X'. 41 | aatype: np.ndarray # [num_res] 42 | 43 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom 44 | # is present and 0.0 if not. This should be used for loss masking. 45 | atom_mask: np.ndarray # [num_res, num_atom_type] 46 | 47 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. 
48 | residue_index: np.ndarray # [num_res] 49 | 50 | # 0-indexed number corresponding to the chain in the protein that this residue 51 | # belongs to. 52 | chain_index: np.ndarray # [num_res] 53 | 54 | # B-factors, or temperature factors, of each residue (in sq. angstroms units), 55 | # representing the displacement of the residue from its ground truth mean 56 | # value. 57 | b_factors: np.ndarray # [num_res, num_atom_type] 58 | 59 | def __post_init__(self): 60 | if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: 61 | raise ValueError( 62 | f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' 63 | 'because these cannot be written to PDB format.') 64 | 65 | 66 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: 67 | """Takes a PDB string and constructs a Protein object. 68 | 69 | WARNING: All non-standard residue types will be converted into UNK. All 70 | non-standard atoms will be ignored. 71 | 72 | Args: 73 | pdb_str: The contents of the pdb file 74 | chain_id: If chain_id is specified (e.g. A), then only that chain 75 | is parsed. Otherwise all chains are parsed. 76 | 77 | Returns: 78 | A new `Protein` parsed from the pdb contents. 79 | """ 80 | pdb_fh = io.StringIO(pdb_str) 81 | parser = PDBParser(QUIET=True) 82 | structure = parser.get_structure('none', pdb_fh) 83 | models = list(structure.get_models()) 84 | if len(models) != 1: 85 | raise ValueError( 86 | f'Only single model PDBs are supported. Found {len(models)} models.') 87 | model = models[0] 88 | 89 | atom_positions = [] 90 | aatype = [] 91 | atom_mask = [] 92 | residue_index = [] 93 | chain_ids = [] 94 | b_factors = [] 95 | 96 | for chain in model: 97 | if chain_id is not None and chain.id != chain_id: 98 | continue 99 | for res in chain: 100 | if res.id[2] != ' ': 101 | raise ValueError( 102 | f'PDB contains an insertion code at chain {chain.id} and residue ' 103 | f'index {res.id[1]}. These are not supported.') 104 | res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') 105 | restype_idx = residue_constants.restype_order.get( 106 | res_shortname, residue_constants.restype_num) 107 | pos = np.zeros((residue_constants.atom_type_num, 3)) 108 | mask = np.zeros((residue_constants.atom_type_num,)) 109 | res_b_factors = np.zeros((residue_constants.atom_type_num,)) 110 | for atom in res: 111 | if atom.name not in residue_constants.atom_types: 112 | continue 113 | pos[residue_constants.atom_order[atom.name]] = atom.coord 114 | mask[residue_constants.atom_order[atom.name]] = 1. 115 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor 116 | if np.sum(mask) < 0.5: 117 | # If no known atom positions are reported for the residue then skip it. 118 | continue 119 | aatype.append(restype_idx) 120 | atom_positions.append(pos) 121 | atom_mask.append(mask) 122 | residue_index.append(res.id[1]) 123 | chain_ids.append(chain.id) 124 | b_factors.append(res_b_factors) 125 | 126 | # Chain IDs are usually characters so map these to ints. 
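  # e.g. per-residue chain ids ['A', 'A', 'B'] map to [0, 0, 1]; np.unique
  # sorts the ids, so the mapping is deterministic for a given structure.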
127 | unique_chain_ids = np.unique(chain_ids) 128 | chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} 129 | chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) 130 | 131 | return Protein( 132 | atom_positions=np.array(atom_positions), 133 | atom_mask=np.array(atom_mask), 134 | aatype=np.array(aatype), 135 | residue_index=np.array(residue_index), 136 | chain_index=chain_index, 137 | b_factors=np.array(b_factors)) 138 | 139 | 140 | def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: 141 | chain_end = 'TER' 142 | return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' 143 | f'{chain_name:>1}{residue_index:>4}') 144 | 145 | 146 | def to_pdb(prot: Protein, model=1, add_end=True) -> str: 147 | """Converts a `Protein` instance to a PDB string. 148 | 149 | Args: 150 | prot: The protein to convert to PDB. 151 | 152 | Returns: 153 | PDB string. 154 | """ 155 | restypes = residue_constants.restypes + ['X'] 156 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') 157 | atom_types = residue_constants.atom_types 158 | 159 | pdb_lines = [] 160 | 161 | atom_mask = prot.atom_mask 162 | aatype = prot.aatype 163 | atom_positions = prot.atom_positions 164 | residue_index = prot.residue_index.astype(int) 165 | chain_index = prot.chain_index.astype(int) 166 | b_factors = prot.b_factors 167 | 168 | if np.any(aatype > residue_constants.restype_num): 169 | raise ValueError('Invalid aatypes.') 170 | 171 | # Construct a mapping from chain integer indices to chain ID strings. 172 | chain_ids = {} 173 | for i in np.unique(chain_index): # np.unique gives sorted output. 174 | if i >= PDB_MAX_CHAINS: 175 | raise ValueError( 176 | f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') 177 | chain_ids[i] = PDB_CHAIN_IDS[i] 178 | 179 | pdb_lines.append(f'MODEL {model}') 180 | atom_index = 1 181 | last_chain_index = chain_index[0] 182 | # Add all atom sites. 183 | for i in range(aatype.shape[0]): 184 | # Close the previous chain if in a multichain PDB. 185 | if last_chain_index != chain_index[i]: 186 | pdb_lines.append(_chain_end( 187 | atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]], 188 | residue_index[i - 1])) 189 | last_chain_index = chain_index[i] 190 | atom_index += 1 # Atom index increases at the TER symbol. 191 | 192 | res_name_3 = res_1to3(aatype[i]) 193 | for atom_name, pos, mask, b_factor in zip( 194 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]): 195 | if mask < 0.5: 196 | continue 197 | 198 | record_type = 'ATOM' 199 | name = atom_name if len(atom_name) == 4 else f' {atom_name}' 200 | alt_loc = '' 201 | insertion_code = '' 202 | occupancy = 1.00 203 | element = atom_name[0] # Protein supports only C, N, O, S, this works. 204 | charge = '' 205 | # PDB is a columnar format, every space matters here! 206 | atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' 207 | f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' 208 | f'{residue_index[i]:>4}{insertion_code:>1} ' 209 | f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' 210 | f'{occupancy:>6.2f}{b_factor:>6.2f} ' 211 | f'{element:>2}{charge:>2}') 212 | pdb_lines.append(atom_line) 213 | atom_index += 1 214 | 215 | # Close the final chain. 216 | pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]), 217 | chain_ids[chain_index[-1]], residue_index[-1])) 218 | pdb_lines.append('ENDMDL') 219 | if add_end: 220 | pdb_lines.append('END') 221 | 222 | # Pad all lines to 80 characters. 
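  # PDB is a fixed-width, 80-column record format, so shorter lines are padded.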
223 | pdb_lines = [line.ljust(80) for line in pdb_lines] 224 | return '\n'.join(pdb_lines) + '\n' # Add terminating newline. 225 | 226 | 227 | def ideal_atom_mask(prot: Protein) -> np.ndarray: 228 | """Computes an ideal atom mask. 229 | 230 | `Protein.atom_mask` typically is defined according to the atoms that are 231 | reported in the PDB. This function computes a mask according to heavy atoms 232 | that should be present in the given sequence of amino acids. 233 | 234 | Args: 235 | prot: `Protein` whose fields are `numpy.ndarray` objects. 236 | 237 | Returns: 238 | An ideal atom mask. 239 | """ 240 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype] 241 | 242 | 243 | def from_prediction( 244 | features: FeatureDict, 245 | result: ModelOutput, 246 | b_factors: Optional[np.ndarray] = None, 247 | remove_leading_feature_dimension: bool = True) -> Protein: 248 | """Assembles a protein from a prediction. 249 | 250 | Args: 251 | features: Dictionary holding model inputs. 252 | result: Dictionary holding model outputs. 253 | b_factors: (Optional) B-factors to use for the protein. 254 | remove_leading_feature_dimension: Whether to remove the leading dimension 255 | of the `features` values. 256 | 257 | Returns: 258 | A protein instance. 259 | """ 260 | fold_output = result['structure_module'] 261 | 262 | def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: 263 | return arr[0] if remove_leading_feature_dimension else arr 264 | 265 | if 'asym_id' in features: 266 | chain_index = _maybe_remove_leading_dim(features['asym_id']) 267 | else: 268 | chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype'])) 269 | 270 | if b_factors is None: 271 | b_factors = np.zeros_like(fold_output['final_atom_mask']) 272 | 273 | return Protein( 274 | aatype=_maybe_remove_leading_dim(features['aatype']), 275 | atom_positions=fold_output['final_atom_positions'], 276 | atom_mask=fold_output['final_atom_mask'], 277 | residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1, 278 | chain_index=chain_index, 279 | b_factors=b_factors) -------------------------------------------------------------------------------- /proteinblobs/multiflow/datasets.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
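# The helpers below prune the metadata CSV before training: _length_filter,
# _plddt_percent_filter and _max_coil_filter are simple column thresholds,
# while _rog_filter fits a degree-4 polynomial to a per-length quantile of the
# radius of gyration and drops rows above the fitted cutoff (plus a small
# margin). _process_csv_row then turns a preprocessed pickle into frames
# (trans_1, rotmats_1), sequence (aatypes_1) and masks. A minimal construction
# sketch, assuming `dataset_cfg` carries the fields referenced below
# (csv_path, cluster_path, filter.*, seed, ...):
#
#   dataset = PdbDataset(dataset_cfg=dataset_cfg, is_training=True,
#                        task="hallucination")
#   feats = dataset[0]  # dict with trans_1, rotmats_1, aatypes_1, res_mask, ...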
10 | 11 | import abc 12 | import numpy as np 13 | import pandas as pd 14 | import logging 15 | import tree 16 | import torch 17 | import random 18 | import os 19 | 20 | from glob import glob 21 | from torch.utils.data import Dataset 22 | import proteinblobs.multiflow.utils as du 23 | from openfold.data import data_transforms 24 | from openfold.utils import rigid_utils 25 | from sklearn.preprocessing import PolynomialFeatures 26 | from sklearn.linear_model import LinearRegression 27 | from pytorch_lightning.utilities import rank_zero_only 28 | 29 | 30 | def _rog_filter(df, quantile): 31 | y_quant = pd.pivot_table( 32 | df, 33 | values='radius_gyration', 34 | index='modeled_seq_len', 35 | aggfunc=lambda x: np.quantile(x, quantile) 36 | ) 37 | x_quant = y_quant.index.to_numpy() 38 | y_quant = y_quant.radius_gyration.to_numpy() 39 | 40 | # Fit polynomial regressor 41 | poly = PolynomialFeatures(degree=4, include_bias=True) 42 | poly_features = poly.fit_transform(x_quant[:, None]) 43 | poly_reg_model = LinearRegression() 44 | poly_reg_model.fit(poly_features, y_quant) 45 | 46 | # Calculate cutoff for all sequence lengths 47 | max_len = df.modeled_seq_len.max() 48 | pred_poly_features = poly.fit_transform(np.arange(max_len)[:, None]) 49 | # Add a little more. 50 | pred_y = poly_reg_model.predict(pred_poly_features) + 0.1 51 | 52 | row_rog_cutoffs = df.modeled_seq_len.map(lambda x: pred_y[x-1]) 53 | return df[df.radius_gyration < row_rog_cutoffs] 54 | 55 | 56 | def _length_filter(data_csv, min_res, max_res): 57 | return data_csv[ 58 | (data_csv.modeled_seq_len >= min_res) 59 | & (data_csv.modeled_seq_len <= max_res) 60 | ] 61 | 62 | 63 | def _plddt_percent_filter(data_csv, min_plddt_percent): 64 | return data_csv[data_csv.num_confident_plddt > min_plddt_percent] 65 | 66 | 67 | def _max_coil_filter(data_csv, max_coil_percent): 68 | return data_csv[data_csv.coil_percent <= max_coil_percent] 69 | 70 | 71 | def _process_csv_row(processed_file_path): 72 | processed_feats = du.read_pkl(processed_file_path) 73 | processed_feats = du.parse_chain_feats(processed_feats) 74 | 75 | # Only take modeled residues. 76 | modeled_idx = processed_feats['modeled_idx'] 77 | min_idx = np.min(modeled_idx) 78 | max_idx = np.max(modeled_idx) 79 | del processed_feats['modeled_idx'] 80 | processed_feats = tree.map_structure( 81 | lambda x: x[min_idx:(max_idx+1)], processed_feats) 82 | 83 | # Run through OpenFold data transforms. 84 | chain_feats = { 85 | 'aatype': torch.tensor(processed_feats['aatype']).long(), 86 | 'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(), 87 | 'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double() 88 | } 89 | chain_feats = data_transforms.atom37_to_frames(chain_feats) 90 | rigids_1 = rigid_utils.Rigid.from_tensor_4x4(chain_feats['rigidgroups_gt_frames'])[:, 0] 91 | rotmats_1 = rigids_1.get_rots().get_rot_mats() 92 | trans_1 = rigids_1.get_trans() 93 | res_plddt = processed_feats['b_factors'][:, 1] 94 | res_mask = torch.tensor(processed_feats['bb_mask']).int() 95 | 96 | # Re-number residue indices for each chain such that it starts from 1. 97 | # Randomize chain indices. 
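    # e.g. a chain whose residue_index runs 17..141 is renumbered to 1..125,
    # independently per chain, while the chain ids themselves are replaced by
    # a random permutation (shuffled_chain_idx below).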
98 | chain_idx = processed_feats['chain_index'] 99 | res_idx = processed_feats['residue_index'] 100 | new_res_idx = np.zeros_like(res_idx) 101 | new_chain_idx = np.zeros_like(res_idx) 102 | all_chain_idx = np.unique(chain_idx).tolist() 103 | shuffled_chain_idx = np.array( 104 | random.sample(all_chain_idx, len(all_chain_idx))) - np.min(all_chain_idx) + 1 105 | for i,chain_id in enumerate(all_chain_idx): 106 | chain_mask = (chain_idx == chain_id).astype(int) 107 | chain_min_idx = np.min(res_idx + (1 - chain_mask) * 1e3).astype(int) 108 | new_res_idx = new_res_idx + (res_idx - chain_min_idx + 1) * chain_mask 109 | 110 | # Shuffle chain_index 111 | replacement_chain_id = shuffled_chain_idx[i] 112 | new_chain_idx = new_chain_idx + replacement_chain_id * chain_mask 113 | 114 | if torch.isnan(trans_1).any() or torch.isnan(rotmats_1).any(): 115 | raise ValueError(f'Found NaNs in {processed_file_path}') 116 | 117 | return { 118 | 'res_plddt': res_plddt, 119 | 'aatypes_1': chain_feats['aatype'], 120 | 'all_atom_positions': chain_feats['all_atom_positions'], 121 | 'rotmats_1': rotmats_1, 122 | 'trans_1': trans_1, 123 | 'res_mask': res_mask, 124 | 'chain_idx': new_chain_idx, 125 | 'res_idx': new_res_idx, 126 | } 127 | 128 | 129 | def _add_plddt_mask(feats, plddt_threshold): 130 | feats['plddt_mask'] = torch.tensor( 131 | feats['res_plddt'] > plddt_threshold).int() 132 | 133 | 134 | def _read_clusters(cluster_path, synthetic=False): 135 | pdb_to_cluster = {} 136 | with open(cluster_path, "r") as f: 137 | for i,line in enumerate(f): 138 | for chain in line.split(' '): 139 | if not synthetic: 140 | pdb = chain.split('_')[0].strip() 141 | else: 142 | pdb = chain.strip() 143 | pdb_to_cluster[pdb.upper()] = i 144 | return pdb_to_cluster 145 | 146 | 147 | class BaseDataset(Dataset): 148 | def __init__( 149 | self, 150 | *, 151 | dataset_cfg, 152 | is_training, 153 | task, 154 | ): 155 | self._log = logging.getLogger(__name__) 156 | self._is_training = is_training 157 | self._dataset_cfg = dataset_cfg 158 | self.task = task 159 | self.raw_csv = pd.read_csv(self.dataset_cfg.csv_path) 160 | metadata_csv = self._filter_metadata(self.raw_csv) 161 | metadata_csv = metadata_csv.sort_values( 162 | 'modeled_seq_len', ascending=False) 163 | if self._dataset_cfg.use_redesigned: 164 | self.redesigned_csv = pd.read_csv(self._dataset_cfg.redesigned_csv_path) 165 | metadata_csv = metadata_csv.merge( 166 | self.redesigned_csv, left_on='pdb_name', right_on='example') 167 | metadata_csv = metadata_csv[metadata_csv.best_rmsd < 2.0] 168 | if self._dataset_cfg.cluster_path is not None: 169 | pdb_to_cluster = _read_clusters(self._dataset_cfg.cluster_path, synthetic=True) 170 | def cluster_lookup(pdb): 171 | pdb = pdb.upper() 172 | if pdb not in pdb_to_cluster: 173 | raise ValueError(f'Cluster not found for {pdb}') 174 | return pdb_to_cluster[pdb] 175 | metadata_csv['cluster'] = metadata_csv['pdb_name'].map(cluster_lookup) 176 | self._create_split(metadata_csv) 177 | self._cache = {} 178 | self._rng = np.random.default_rng(seed=self._dataset_cfg.seed) 179 | 180 | @property 181 | def is_training(self): 182 | return self._is_training 183 | 184 | @property 185 | def dataset_cfg(self): 186 | return self._dataset_cfg 187 | 188 | def __len__(self): 189 | return len(self.csv) 190 | 191 | @abc.abstractmethod 192 | def _filter_metadata(self, raw_csv: pd.DataFrame) -> pd.DataFrame: 193 | pass 194 | 195 | def _create_split(self, data_csv): 196 | # Training or validation specific logic. 
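        # Training keeps the full filtered CSV; validation picks
        # num_eval_lengths lengths spaced across the available range
        # (optionally capped by max_eval_length) and draws
        # samples_per_eval_length examples per length with a fixed
        # random_state, so the eval split is identical across runs.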
197 | if self.is_training: 198 | self.csv = data_csv 199 | self._log.info( 200 | f'Training: {len(self.csv)} examples') 201 | else: 202 | if self._dataset_cfg.max_eval_length is None: 203 | eval_lengths = data_csv.modeled_seq_len 204 | else: 205 | eval_lengths = data_csv.modeled_seq_len[ 206 | data_csv.modeled_seq_len <= self._dataset_cfg.max_eval_length 207 | ] 208 | all_lengths = np.sort(eval_lengths.unique()) 209 | length_indices = (len(all_lengths) - 1) * np.linspace( 210 | 0.0, 1.0, self.dataset_cfg.num_eval_lengths) 211 | length_indices = length_indices.astype(int) 212 | eval_lengths = all_lengths[length_indices] 213 | eval_csv = data_csv[data_csv.modeled_seq_len.isin(eval_lengths)] 214 | 215 | # Fix a random seed to get the same split each time. 216 | eval_csv = eval_csv.groupby('modeled_seq_len').sample( 217 | self.dataset_cfg.samples_per_eval_length, 218 | replace=True, 219 | random_state=123 220 | ) 221 | eval_csv = eval_csv.sort_values('modeled_seq_len', ascending=False) 222 | self.csv = eval_csv 223 | self._log.info( 224 | f'Validation: {len(self.csv)} examples with lengths {eval_lengths}') 225 | self.csv['index'] = list(range(len(self.csv))) 226 | 227 | def process_csv_row(self, csv_row): 228 | path = csv_row['processed_path'] 229 | seq_len = csv_row['modeled_seq_len'] 230 | # Large protein files are slow to read. Cache them. 231 | use_cache = seq_len > self._dataset_cfg.cache_num_res 232 | if use_cache and path in self._cache: 233 | return self._cache[path] 234 | processed_row = _process_csv_row(path) 235 | processed_row['pdb_name'] = csv_row['pdb_name'] 236 | if self._dataset_cfg.use_redesigned: 237 | best_seq = csv_row['best_seq'] 238 | if not isinstance(best_seq, float): 239 | best_aatype = torch.tensor(du.seq_to_aatype(best_seq)).long() 240 | assert processed_row['aatypes_1'].shape == best_aatype.shape 241 | processed_row['aatypes_1'] = best_aatype 242 | aatypes_1 = du.to_numpy(processed_row['aatypes_1']) 243 | if len(set(aatypes_1)) == 1: 244 | raise ValueError(f'Example {path} has only one amino acid.') 245 | if use_cache: 246 | self._cache[path] = processed_row 247 | return processed_row 248 | 249 | 250 | def __getitem__(self, row_idx): 251 | # Process data example. 252 | csv_row = self.csv.iloc[row_idx] 253 | feats = self.process_csv_row(csv_row) 254 | 255 | if self._dataset_cfg.add_plddt_mask: 256 | _add_plddt_mask(feats, self._dataset_cfg.min_plddt_threshold) 257 | else: 258 | feats['plddt_mask'] = torch.ones_like(feats['res_mask']) 259 | 260 | if self.task == 'hallucination': 261 | feats['diffuse_mask'] = torch.ones_like(feats['res_mask']).bool() 262 | else: 263 | raise ValueError(f'Unknown task {self.task}') 264 | feats['diffuse_mask'] = feats['diffuse_mask'].int() 265 | 266 | # Storing the csv index is helpful for debugging. 
267 | feats['csv_idx'] = torch.ones(1, dtype=torch.long) * row_idx 268 | return feats 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | def pdb_init_( 278 | self, 279 | *, 280 | dataset_cfg, 281 | is_training, 282 | task, 283 | ): 284 | self._log = logging.getLogger(__name__) 285 | self._is_training = is_training 286 | self._dataset_cfg = dataset_cfg 287 | self.task = task 288 | self._cache = {} 289 | self._rng = np.random.default_rng(seed=self._dataset_cfg.seed) 290 | 291 | # Process clusters 292 | self.raw_csv = pd.read_csv(self.dataset_cfg.csv_path) 293 | metadata_csv = self._filter_metadata(self.raw_csv) 294 | metadata_csv = metadata_csv.sort_values( 295 | 'modeled_seq_len', ascending=False) 296 | 297 | self._pdb_to_cluster = _read_clusters(self._dataset_cfg.cluster_path, synthetic=False) 298 | self._max_cluster = max(self._pdb_to_cluster.values()) 299 | self._missing_pdbs = 0 300 | def cluster_lookup(pdb): 301 | pdb = pdb.upper() 302 | if pdb not in self._pdb_to_cluster: 303 | self._pdb_to_cluster[pdb] = self._max_cluster + 1 304 | self._max_cluster += 1 305 | self._missing_pdbs += 1 306 | return self._pdb_to_cluster[pdb] 307 | metadata_csv['cluster'] = metadata_csv['pdb_name'].map(cluster_lookup) 308 | if self._dataset_cfg.use_redesigned: 309 | self.redesigned_csv = pd.read_csv(self._dataset_cfg.redesigned_csv_path) 310 | metadata_csv = metadata_csv.merge( 311 | self.redesigned_csv, left_on='pdb_name', right_on='example') 312 | metadata_csv = metadata_csv[metadata_csv.best_rmsd < 2.0] 313 | if self._dataset_cfg.use_synthetic: 314 | self.synthetic_csv = pd.read_csv(self._dataset_cfg.synthetic_csv_path) 315 | self._synthetic_pdb_to_cluster = _read_clusters(self._dataset_cfg.synthetic_cluster_path, synthetic=True) 316 | 317 | # offset all the cluster numbers by the number of real data clusters 318 | num_real_clusters = metadata_csv['cluster'].max() + 1 319 | def synthetic_cluster_lookup(pdb): 320 | pdb = pdb.upper() 321 | if pdb not in self._synthetic_pdb_to_cluster: 322 | raise ValueError(f"Synthetic example {pdb} not in synthetic cluster file!") 323 | return self._synthetic_pdb_to_cluster[pdb] + num_real_clusters 324 | self.synthetic_csv['cluster'] = self.synthetic_csv['pdb_name'].map(synthetic_cluster_lookup) 325 | 326 | metadata_csv = pd.concat([metadata_csv, self.synthetic_csv]) 327 | self._create_split(metadata_csv) 328 | 329 | 330 | if dataset_cfg.test_set_pdb_ids_path is not None: 331 | 332 | test_set_df = pd.read_csv(dataset_cfg.test_set_pdb_ids_path) 333 | 334 | self.csv = self.csv[self.csv['pdb_name'].isin(test_set_df['pdb_name'].values)] 335 | 336 | def pdb_filter_metadata(self, raw_csv): 337 | """Filter metadata.""" 338 | filter_cfg = self.dataset_cfg.filter 339 | data_csv = raw_csv[ 340 | raw_csv.oligomeric_detail.isin(filter_cfg.oligomeric)] 341 | data_csv = data_csv[ 342 | data_csv.num_chains.isin(filter_cfg.num_chains)] 343 | data_csv = _length_filter( 344 | data_csv, filter_cfg.min_num_res, filter_cfg.max_num_res) 345 | data_csv = _max_coil_filter(data_csv, filter_cfg.max_coil_percent) 346 | data_csv = _rog_filter(data_csv, filter_cfg.rog_quantile) 347 | return data_csv 348 | 349 | class PdbDataset(BaseDataset): 350 | 351 | def __init__(self, *, dataset_cfg, is_training, task): 352 | pdb_init_(self, dataset_cfg=dataset_cfg, is_training=is_training, task=task) 353 | 354 | def _filter_metadata(self, raw_csv): 355 | return pdb_filter_metadata(self, raw_csv) -------------------------------------------------------------------------------- 
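The EdgeFeatureNet defined next builds pair features from cross-concatenated node embeddings, a relative-position embedding, and distograms of the current and self-conditioning translations. A minimal shape sketch (edge_feature_net here stands for a constructed EdgeFeatureNet(module_cfg) instance; argument names follow its forward method):

    # s: [B, L, c_s] node embeddings; t, sc_t: [B, L, 3] translations;
    # p_mask: [B, L, L] pair mask; diffuse_mask: [B, L]; chain_idx: [B, L]
    edge_feats = edge_feature_net(s, t, sc_t, p_mask, diffuse_mask, chain_idx)
    # edge_feats: [B, L, L, c_p]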
/proteinblobs/multiflow/edge_feature_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .utils import get_index_embedding, calc_distogram 5 | 6 | class EdgeFeatureNet(nn.Module): 7 | 8 | def __init__(self, module_cfg): 9 | # c_s, c_p, relpos_k, template_type): 10 | super(EdgeFeatureNet, self).__init__() 11 | self._cfg = module_cfg 12 | 13 | self.c_s = self._cfg.c_s 14 | self.c_p = self._cfg.c_p 15 | self.feat_dim = self._cfg.feat_dim 16 | 17 | self.linear_s_p = nn.Linear(self.c_s, self.feat_dim) 18 | self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim) 19 | 20 | total_edge_feats = self.feat_dim * 3 + self._cfg.num_bins * 2 21 | if self._cfg.embed_chain: 22 | total_edge_feats += 1 23 | if self._cfg.embed_diffuse_mask: 24 | total_edge_feats += 2 25 | self.edge_embedder = nn.Sequential( 26 | nn.Linear(total_edge_feats, self.c_p), 27 | nn.ReLU(), 28 | nn.Linear(self.c_p, self.c_p), 29 | nn.ReLU(), 30 | nn.Linear(self.c_p, self.c_p), 31 | nn.LayerNorm(self.c_p), 32 | ) 33 | 34 | def embed_relpos(self, r): 35 | # AlphaFold 2 Algorithm 4 & 5 36 | # Based on OpenFold utils/tensor_utils.py 37 | # Input: [b, n_res] 38 | # [b, n_res, n_res] 39 | d = r[:, :, None] - r[:, None, :] 40 | pos_emb = get_index_embedding(d, self._cfg.feat_dim, max_len=2056) 41 | return self.linear_relpos(pos_emb) 42 | 43 | def _cross_concat(self, feats_1d, num_batch, num_res): 44 | return torch.cat([ 45 | torch.tile(feats_1d[:, :, None, :], (1, 1, num_res, 1)), 46 | torch.tile(feats_1d[:, None, :, :], (1, num_res, 1, 1)), 47 | ], dim=-1).float().reshape([num_batch, num_res, num_res, -1]) 48 | 49 | def forward(self, s, t, sc_t, p_mask, diffuse_mask, chain_idx): 50 | # Input: [b, n_res, c_s] 51 | num_batch, num_res, _ = s.shape 52 | 53 | # [b, n_res, c_p] 54 | p_i = self.linear_s_p(s) 55 | cross_node_feats = self._cross_concat(p_i, num_batch, num_res) 56 | 57 | # [b, n_res] 58 | r = torch.arange( 59 | num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1) 60 | relpos_feats = self.embed_relpos(r) 61 | 62 | dist_feats = calc_distogram( 63 | t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins) 64 | sc_feats = calc_distogram( 65 | sc_t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins) 66 | 67 | all_edge_feats = [cross_node_feats, relpos_feats, dist_feats, sc_feats] 68 | if self._cfg.embed_chain: 69 | rel_chain = (chain_idx[:, :, None] == chain_idx[:, None, :]).float() 70 | all_edge_feats.append(rel_chain[..., None]) 71 | if self._cfg.embed_diffuse_mask: 72 | diff_feat = self._cross_concat(diffuse_mask[..., None], num_batch, num_res) 73 | all_edge_feats.append(diff_feat) 74 | edge_feats = self.edge_embedder(torch.concat(all_edge_feats, dim=-1)) 75 | edge_feats *= p_mask.unsqueeze(-1) 76 | return edge_feats -------------------------------------------------------------------------------- /proteinblobs/multiflow/flow_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. 
Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import torch 12 | from torch import nn 13 | 14 | from .node_feature_net import NodeFeatureNet 15 | from .edge_feature_net import EdgeFeatureNet 16 | from .cross_attention import InvariantCrossAttention 17 | from . import ipa_pytorch 18 | from .data import utils as du 19 | import math 20 | 21 | def positional_encoding(x, c, max_period=10000, min_period=None): 22 | min_period = min_period or max_period / 10000 # second one recommended 23 | freqs = torch.exp(-torch.linspace(math.log(min_period), math.log(max_period), c // 2, device=x.device)) 24 | emb = freqs * x.unsqueeze(-1) # [..., C] 25 | return torch.cat([torch.sin(emb), torch.cos(emb)], -1) # [..., 2C] 26 | 27 | class FlowModel(nn.Module): 28 | 29 | def __init__(self, model_conf, args): 30 | super(FlowModel, self).__init__() 31 | self._model_conf = model_conf 32 | self.args = args 33 | self._ipa_conf = model_conf.ipa 34 | self.rigids_ang_to_nm = lambda x: x.apply_trans_fn(lambda x: x * du.ANG_TO_NM_SCALE) 35 | self.rigids_nm_to_ang = lambda x: x.apply_trans_fn(lambda x: x * du.NM_TO_ANG_SCALE) 36 | self.node_feature_net = NodeFeatureNet(model_conf.node_features) 37 | self.edge_feature_net = EdgeFeatureNet(model_conf.edge_features) 38 | if self.args.use_latents: 39 | self.latent_embedder = ipa_pytorch.Linear(32, model_conf.node_features.c_s, init="final") 40 | 41 | if self.args.blob_attention: 42 | self.ground_feature_net = nn.Embedding(3, self._ipa_conf.c_s) 43 | self.ground_trace_feature_net = nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s) 44 | self.ground_size_feature_net = nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s) 45 | 46 | if self._model_conf.aatype_pred: 47 | node_embed_size = self._model_conf.node_embed_size 48 | self.aatype_pred_net = nn.Sequential( 49 | nn.Linear(node_embed_size, node_embed_size), 50 | nn.ReLU(), 51 | nn.Linear(node_embed_size, node_embed_size), 52 | nn.ReLU(), 53 | nn.Linear(node_embed_size, self._model_conf.aatype_pred_num_tokens), 54 | ) 55 | 56 | # Attention trunk 57 | self.trunk = nn.ModuleDict() 58 | for b in range(self._ipa_conf.num_blocks): 59 | self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf) 60 | self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s) 61 | if self.args.blob_attention: 62 | self.trunk[f'ica_{b}'] = InvariantCrossAttention(self._ipa_conf) 63 | tfmr_in = self._ipa_conf.c_s 64 | tfmr_layer = torch.nn.TransformerEncoderLayer( 65 | d_model=tfmr_in, 66 | nhead=self._ipa_conf.seq_tfmr_num_heads, 67 | dim_feedforward=tfmr_in, 68 | batch_first=True, 69 | dropout=self._model_conf.transformer_dropout, 70 | norm_first=False 71 | ) 72 | self.trunk[f'seq_tfmr_{b}'] = nn.TransformerEncoder(tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False) 73 | self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(tfmr_in, self._ipa_conf.c_s, init="final") 74 | 75 | if self.args.extra_attn_layer: 76 | self.trunk[f'cond_tfmr_{b}'] = nn.TransformerEncoder(tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False) 77 | self.trunk[f'post_cond_tfmr_{b}'] = ipa_pytorch.Linear(tfmr_in, self._ipa_conf.c_s, init="final") 78 | 79 | 80 | self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(c=self._ipa_conf.c_s) 81 | self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(self._ipa_conf.c_s, 
use_rot_updates=True) 82 | 83 | if b < self._ipa_conf.num_blocks-1: 84 | # No edge update on the last block. 85 | edge_in = self._model_conf.edge_embed_size 86 | self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition( 87 | node_embed_size=self._ipa_conf.c_s, 88 | edge_embed_in=edge_in, 89 | edge_embed_out=self._model_conf.edge_embed_size, 90 | ) 91 | 92 | if args.freeze_weights: 93 | for name, p in self.named_parameters(): 94 | if ('ica' not in name) and ('cond_tfmr' not in name) and ('ground' not in name): 95 | p.requires_grad_(False) 96 | 97 | def forward(self, input_feats): 98 | node_mask = input_feats['res_mask'] 99 | edge_mask = node_mask[:, None] * node_mask[:, :, None] 100 | diffuse_mask = input_feats['diffuse_mask'] 101 | chain_index = input_feats['chain_idx'] 102 | res_index = input_feats['res_idx'] 103 | latents = input_feats['latents'] 104 | so3_t = input_feats['so3_t'] 105 | r3_t = input_feats['r3_t'] 106 | cat_t = input_feats['cat_t'] 107 | trans_t = input_feats['trans_t'] 108 | rotmats_t = input_feats['rotmats_t'] 109 | aatypes_t = input_feats['aatypes_t'].long() 110 | trans_sc = input_feats['trans_sc'] 111 | aatypes_sc = input_feats['aatypes_sc'] 112 | 113 | ########## 114 | if self.args.blob_attention: 115 | grounding_feats = input_feats['grounding_feat'] 116 | grounding_pos = input_feats['grounding_pos'] 117 | grounding_mask = input_feats['grounding_mask'] 118 | 119 | ### 120 | B, L, _ = grounding_pos.shape 121 | grounding_covar = grounding_feats[:,:,-9:].view(B, L, 3, 3) 122 | trace = grounding_covar[:,:,torch.arange(3),torch.arange(3)].sum(-1) 123 | grounding_covar = grounding_covar / (trace[:,:,None,None] + 1e-3) # "normalized covariance" 124 | ### 125 | grounding_feats = ( 126 | self.ground_feature_net(grounding_feats[...,0].long()) + 127 | self.ground_size_feature_net( 128 | positional_encoding(grounding_feats[...,1].float(), self._ipa_conf.c_s) 129 | ) + 130 | self.ground_trace_feature_net( 131 | positional_encoding(trace, self._ipa_conf.c_s) 132 | ) 133 | ) 134 | assert not torch.isnan(grounding_feats).any() 135 | ########## 136 | 137 | # Initialize node and edge embeddings 138 | init_node_embed = self.node_feature_net( 139 | so3_t=so3_t, 140 | r3_t=r3_t, 141 | cat_t=cat_t, 142 | res_mask=node_mask, 143 | diffuse_mask=diffuse_mask, 144 | chain_index=chain_index, 145 | pos=res_index, 146 | aatypes=aatypes_t, 147 | aatypes_sc=aatypes_sc, 148 | ) 149 | if self.args.use_latents: 150 | init_node_embed = init_node_embed + self.latent_embedder(latents)[:,None,:] 151 | 152 | init_edge_embed = self.edge_feature_net( 153 | init_node_embed, 154 | trans_t, 155 | trans_sc, 156 | edge_mask, 157 | diffuse_mask, 158 | chain_index 159 | ) 160 | 161 | 162 | # Initial rigids 163 | init_rigids = du.create_rigid(rotmats_t, trans_t) 164 | curr_rigids = du.create_rigid(rotmats_t, trans_t) 165 | 166 | # Main trunk 167 | curr_rigids = self.rigids_ang_to_nm(curr_rigids) 168 | init_node_embed = init_node_embed * node_mask[..., None] 169 | node_embed = init_node_embed * node_mask[..., None] 170 | edge_embed = init_edge_embed * edge_mask[..., None] 171 | for b in range(self._ipa_conf.num_blocks): 172 | ipa_embed = self.trunk[f'ipa_{b}']( 173 | node_embed, 174 | edge_embed, 175 | curr_rigids, 176 | node_mask) 177 | ipa_embed *= node_mask[..., None] 178 | node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed) 179 | 180 | ############ 181 | if self.args.blob_attention: 182 | ica_out = self.trunk[f'ica_{b}']( 183 | node_embed, 184 | curr_rigids, 185 | grounding_pos * 
du.ANG_TO_NM_SCALE, # FIX, 186 | grounding_feats, 187 | grounding_covar, 188 | grounding_mask 189 | ) 190 | node_embed = node_embed + ica_out * (1 if self.training else self.args.inference_gating) 191 | ########### 192 | 193 | def heterogenous_attention(layer, post_layer, node_embed, node_mask, grounding_feats, grounding_mask): 194 | 195 | node_embed_ = torch.cat([node_embed, grounding_feats], 1) 196 | node_mask_ = torch.cat([node_mask, grounding_mask], 1) 197 | cond_tfmr_out = layer(node_embed_, src_key_padding_mask=(1 - node_mask_).bool()) 198 | L = node_embed.shape[1] 199 | 200 | cond_tfmr_out = post_layer(cond_tfmr_out) 201 | node_embed = node_embed + cond_tfmr_out[:,:L] * (1 if self.training else self.args.inference_gating) 202 | grounding_feats = grounding_feats + cond_tfmr_out[:,L:] 203 | 204 | return node_embed, grounding_feats 205 | 206 | if self.args.blob_attention and not self.args.extra_attn_layer: 207 | node_embed, grounding_feats = heterogenous_attention( 208 | self.trunk[f'seq_tfmr_{b}'], 209 | self.trunk[f'post_tfmr_{b}'], 210 | node_embed, node_mask, grounding_feats, grounding_mask) 211 | 212 | else: 213 | seq_tfmr_out = self.trunk[f'seq_tfmr_{b}']( 214 | node_embed, src_key_padding_mask=(1 - node_mask).to(torch.bool)) 215 | node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out) 216 | 217 | ################ 218 | if self.args.blob_attention and self.args.extra_attn_layer: 219 | node_embed, grounding_feats = heterogenous_attention( 220 | self.trunk[f'cond_tfmr_{b}'], 221 | self.trunk[f'post_cond_tfmr_{b}'], 222 | node_embed, node_mask, grounding_feats, grounding_mask) 223 | 224 | ################ 225 | node_embed = self.trunk[f'node_transition_{b}'](node_embed) 226 | node_embed = node_embed * node_mask[..., None] 227 | rigid_update = self.trunk[f'bb_update_{b}']( 228 | node_embed * node_mask[..., None]) 229 | curr_rigids = curr_rigids.compose_q_update_vec( 230 | rigid_update, (node_mask * diffuse_mask)[..., None]) 231 | if b < self._ipa_conf.num_blocks-1: 232 | edge_embed = self.trunk[f'edge_transition_{b}']( 233 | node_embed, edge_embed) 234 | edge_embed *= edge_mask[..., None] 235 | 236 | curr_rigids = self.rigids_nm_to_ang(curr_rigids) 237 | pred_trans = curr_rigids.get_trans() 238 | pred_rotmats = curr_rigids.get_rots().get_rot_mats() 239 | if self._model_conf.aatype_pred: 240 | pred_logits = self.aatype_pred_net(node_embed) 241 | pred_aatypes = torch.argmax(pred_logits, dim=-1) 242 | if self._model_conf.aatype_pred_num_tokens == du.NUM_TOKENS + 1: 243 | pred_logits_wo_mask = pred_logits.clone() 244 | pred_logits_wo_mask[:, :, du.MASK_TOKEN_INDEX] = -1e9 245 | pred_aatypes = torch.argmax(pred_logits_wo_mask, dim=-1) 246 | else: 247 | pred_aatypes = torch.argmax(pred_logits, dim=-1) 248 | else: 249 | pred_aatypes = aatypes_t 250 | pred_logits = nn.functional.one_hot( 251 | pred_aatypes, num_classes=self._model_conf.aatype_pred_num_tokens 252 | ).float() 253 | return { 254 | 'pred_trans': pred_trans, 255 | 'pred_rotmats': pred_rotmats, 256 | 'pred_logits': pred_logits, 257 | 'pred_aatypes': pred_aatypes, 258 | } 259 | -------------------------------------------------------------------------------- /proteinblobs/multiflow/node_feature_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .utils import get_index_embedding, get_time_embedding 4 | 5 | 6 | class NodeFeatureNet(nn.Module): 7 | 8 | def __init__(self, module_cfg): 9 | super(NodeFeatureNet, self).__init__() 10 | 
self._cfg = module_cfg 11 | self.c_s = self._cfg.c_s 12 | self.c_pos_emb = self._cfg.c_pos_emb 13 | self.c_timestep_emb = self._cfg.c_timestep_emb 14 | embed_size = self._cfg.c_pos_emb + self._cfg.c_timestep_emb * 2 + 1 15 | if self._cfg.embed_chain: 16 | embed_size += self._cfg.c_pos_emb 17 | if self._cfg.embed_aatype: 18 | self.aatype_embedding = nn.Embedding(21, self.c_s) # Always 21 because of 20 amino acids + 1 for unk 19 | embed_size += self.c_s + self._cfg.c_timestep_emb + self._cfg.aatype_pred_num_tokens 20 | if self._cfg.use_mlp: 21 | self.linear = nn.Sequential( 22 | nn.Linear(embed_size, self.c_s), 23 | nn.ReLU(), 24 | nn.Linear(self.c_s, self.c_s), 25 | nn.ReLU(), 26 | nn.Linear(self.c_s, self.c_s), 27 | nn.LayerNorm(self.c_s), 28 | ) 29 | else: 30 | self.linear = nn.Linear(embed_size, self.c_s) 31 | 32 | def embed_t(self, timesteps, mask): 33 | timestep_emb = get_time_embedding( 34 | timesteps[:, 0], 35 | self.c_timestep_emb, 36 | max_positions=2056 37 | )[:, None, :].repeat(1, mask.shape[1], 1) 38 | return timestep_emb * mask.unsqueeze(-1) 39 | 40 | def forward( 41 | self, 42 | *, 43 | so3_t, 44 | r3_t, 45 | cat_t, 46 | res_mask, 47 | diffuse_mask, 48 | chain_index, 49 | pos, 50 | aatypes, 51 | aatypes_sc, 52 | ): 53 | # s: [b] 54 | 55 | # [b, n_res, c_pos_emb] 56 | pos_emb = get_index_embedding(pos, self.c_pos_emb, max_len=2056) 57 | pos_emb = pos_emb * res_mask.unsqueeze(-1) 58 | 59 | # [b, n_res, c_timestep_emb] 60 | input_feats = [ 61 | pos_emb, 62 | diffuse_mask[..., None], 63 | self.embed_t(so3_t, res_mask), 64 | self.embed_t(r3_t, res_mask) 65 | ] 66 | if self._cfg.embed_aatype: 67 | input_feats.append(self.aatype_embedding(aatypes)) 68 | input_feats.append(self.embed_t(cat_t, res_mask)) 69 | input_feats.append(aatypes_sc) 70 | if self._cfg.embed_chain: 71 | input_feats.append( 72 | get_index_embedding( 73 | chain_index, 74 | self.c_pos_emb, 75 | max_len=100 76 | ) 77 | ) 78 | return self.linear(torch.cat(input_feats, dim=-1)) 79 | -------------------------------------------------------------------------------- /proteinblobs/multiflow/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import re 4 | import numpy as np 5 | import pandas as pd 6 | from torch.nn import functional as F 7 | from .data import utils as du 8 | from openfold.utils.superimposition import superimpose 9 | from openfold.np import residue_constants 10 | import pickle 11 | 12 | CA_IDX = residue_constants.atom_order['CA'] 13 | 14 | def calc_distogram(pos, min_bin, max_bin, num_bins): 15 | dists_2d = torch.linalg.norm( 16 | pos[:, :, None, :] - pos[:, None, :, :], axis=-1)[..., None] 17 | lower = torch.linspace( 18 | min_bin, 19 | max_bin, 20 | num_bins, 21 | device=pos.device) 22 | upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1) 23 | dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype) 24 | return dgram 25 | 26 | 27 | def get_index_embedding(indices, embed_size, max_len=2056): 28 | """Creates sine / cosine positional embeddings from a prespecified indices. 29 | 30 | Args: 31 | indices: offsets of size [..., N_edges] of type integer 32 | max_len: maximum length. 
33 | embed_size: dimension of the embeddings to create 34 | 35 | Returns: 36 | positional embedding of shape [N, embed_size] 37 | """ 38 | K = torch.arange(embed_size//2, device=indices.device) 39 | pos_embedding_sin = torch.sin( 40 | indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device) 41 | pos_embedding_cos = torch.cos( 42 | indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device) 43 | pos_embedding = torch.cat([ 44 | pos_embedding_sin, pos_embedding_cos], axis=-1) 45 | return pos_embedding 46 | 47 | 48 | def get_time_embedding(timesteps, embedding_dim, max_positions=2000): 49 | # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py 50 | assert len(timesteps.shape) == 1 51 | timesteps = timesteps * max_positions 52 | half_dim = embedding_dim // 2 53 | emb = math.log(max_positions) / (half_dim - 1) 54 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 55 | emb = timesteps.float()[:, None] * emb[None, :] 56 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 57 | if embedding_dim % 2 == 1: # zero pad 58 | emb = F.pad(emb, (0, 1), mode='constant') 59 | assert emb.shape == (timesteps.shape[0], embedding_dim) 60 | return emb 61 | 62 | 63 | def sinusoidal_encoding(v, N, D): 64 | """Taken from GENIE. 65 | 66 | Args: 67 | 68 | """ 69 | # v: [*] 70 | 71 | # [D] 72 | k = torch.arange(1, D+1).to(v.device) 73 | 74 | # [*, D] 75 | sin_div_term = N ** (2 * k / D) 76 | sin_div_term = sin_div_term.view(*((1, ) * len(v.shape) + (len(sin_div_term), ))) 77 | sin_enc = torch.sin(v.unsqueeze(-1) * math.pi / sin_div_term) 78 | 79 | # [*, D] 80 | cos_div_term = N ** (2 * (k - 1) / D) 81 | cos_div_term = cos_div_term.view(*((1, ) * len(v.shape) + (len(cos_div_term), ))) 82 | cos_enc = torch.cos(v.unsqueeze(-1) * math.pi / cos_div_term) 83 | 84 | # [*, D] 85 | enc = torch.zeros_like(sin_enc).to(v.device) 86 | enc[..., 0::2] = cos_enc[..., 0::2] 87 | enc[..., 1::2] = sin_enc[..., 1::2] 88 | 89 | return enc.to(v.dtype) 90 | 91 | 92 | def distance(p, eps=1e-10): 93 | # [*, 2, 3] 94 | return (eps + torch.sum((p[..., 0, :] - p[..., 1, :]) ** 2, dim=-1)) ** 0.5 95 | 96 | 97 | def dist_from_ca(trans): 98 | 99 | # [b, n_res, n_res, 1] 100 | d = distance(torch.stack([ 101 | trans.unsqueeze(2).repeat(1, 1, trans.shape[1], 1), # Ca_1 102 | trans.unsqueeze(1).repeat(1, trans.shape[1], 1, 1), # Ca_2 103 | ], dim=-2)).unsqueeze(-1) 104 | 105 | return d 106 | 107 | 108 | def calc_rbf(ca_dists, num_rbf, D_min=1e-3, D_max=22.): 109 | # Distance radial basis function 110 | device = ca_dists.device 111 | D_mu = torch.linspace(D_min, D_max, num_rbf).to(device) 112 | D_mu = D_mu.view([1,1,1,-1]) 113 | D_sigma = (D_max - D_min) / num_rbf 114 | return torch.exp(-((ca_dists - D_mu) / D_sigma)**2) 115 | 116 | 117 | def t_stratified_loss(batch_t, batch_loss, num_bins=4, loss_name=None): 118 | """Stratify loss by binning t.""" 119 | batch_t = du.to_numpy(batch_t) 120 | batch_loss = du.to_numpy(batch_loss) 121 | flat_losses = batch_loss.flatten() 122 | flat_t = batch_t.flatten() 123 | bin_edges = np.linspace(0.0, 1.0 + 1e-3, num_bins+1) 124 | bin_idx = np.sum(bin_edges[:, None] <= flat_t[None, :], axis=0) - 1 125 | t_binned_loss = np.bincount(bin_idx, weights=flat_losses) 126 | t_binned_n = np.bincount(bin_idx) 127 | stratified_losses = {} 128 | if loss_name is None: 129 | loss_name = 'loss' 130 | for t_bin in np.unique(bin_idx).tolist(): 131 | bin_start = bin_edges[t_bin] 132 | bin_end = 
bin_edges[t_bin+1] 133 | t_range = f'{loss_name} t=[{bin_start:.2f},{bin_end:.2f})' 134 | range_loss = t_binned_loss[t_bin] / t_binned_n[t_bin] 135 | stratified_losses[t_range] = range_loss 136 | return stratified_losses 137 | 138 | 139 | 140 | def process_folded_outputs(sample_path, folded_output, true_bb_pos=None): 141 | mpnn_results = { 142 | 'header': [], 143 | 'sequence': [], 144 | 'bb_rmsd': [], 145 | 'mean_plddt': [], 146 | 'folded_path': [], 147 | } 148 | 149 | if true_bb_pos is not None: 150 | mpnn_results['bb_rmsd_to_gt'] = [] 151 | mpnn_results['fold_model_bb_rmsd_to_gt'] = [] 152 | 153 | sample_feats = du.parse_pdb_feats('sample', sample_path) 154 | sample_ca_pos = sample_feats['bb_positions'] 155 | def _calc_ca_rmsd(mask, folded_ca_pos): 156 | return superimpose( 157 | torch.tensor(sample_ca_pos)[None], 158 | torch.tensor(folded_ca_pos[None]), 159 | mask 160 | )[1].rmsd[0].item() 161 | 162 | sample_bb_pos = sample_feats['atom_positions'][:, :3].reshape(-1, 3) 163 | def _calc_bb_rmsd(mask, sample_bb_pos, folded_bb_pos): 164 | aligned_rmsd = superimpose( 165 | torch.tensor(sample_bb_pos)[None], 166 | torch.tensor(folded_bb_pos[None]), 167 | mask[:, None].repeat(1, 3).reshape(-1) 168 | ) 169 | return aligned_rmsd[1].item() 170 | 171 | for _, row in folded_output.iterrows(): 172 | folded_feats = du.parse_pdb_feats('folded', row.folded_path) 173 | seq = du.aatype_to_seq(folded_feats['aatype']) 174 | folded_ca_pos = folded_feats['bb_positions'] 175 | folded_bb_pos = folded_feats['atom_positions'][:, :3].reshape(-1, 3) 176 | 177 | res_mask = torch.ones(folded_ca_pos.shape[0]) 178 | 179 | if true_bb_pos is not None: 180 | bb_rmsd_to_gt = _calc_bb_rmsd(res_mask, sample_bb_pos, true_bb_pos) 181 | mpnn_results['bb_rmsd_to_gt'].append(bb_rmsd_to_gt) 182 | fold_model_bb_rmsd_to_gt = _calc_bb_rmsd(res_mask, folded_bb_pos, true_bb_pos) 183 | mpnn_results['fold_model_bb_rmsd_to_gt'].append(fold_model_bb_rmsd_to_gt) 184 | bb_rmsd = _calc_bb_rmsd(res_mask, sample_bb_pos, folded_bb_pos) 185 | mpnn_results['bb_rmsd'].append(bb_rmsd) 186 | mpnn_results['folded_path'].append(row.folded_path) 187 | mpnn_results['header'].append(row.header) 188 | mpnn_results['sequence'].append(seq) 189 | mpnn_results['mean_plddt'].append(row.plddt) 190 | mpnn_results = pd.DataFrame(mpnn_results) 191 | mpnn_results['sample_path'] = sample_path 192 | return mpnn_results 193 | 194 | def extract_clusters_from_maxcluster_out(file_path): 195 | # Extracts cluster information from the stdout of a maxcluster run 196 | cluster_to_paths = {} 197 | paths_to_cluster = {} 198 | read_mode = False 199 | with open(file_path, 'r') as file: 200 | lines = file.readlines() 201 | for line in lines: 202 | if line == "INFO : Item Cluster\n": 203 | read_mode = True 204 | continue 205 | 206 | if line == "INFO : ======================================\n": 207 | read_mode = False 208 | 209 | if read_mode: 210 | # Define a regex pattern to match the second number and the path 211 | pattern = r"INFO\s+:\s+\d+\s:\s+(\d+)\s+(\S+)" 212 | 213 | # Use re.search to find the first match in the string 214 | match = re.search(pattern, line) 215 | 216 | # Check if a match is found 217 | if match: 218 | # Extract the second number and the path 219 | cluster_id = match.group(1) 220 | path = match.group(2) 221 | if cluster_id not in cluster_to_paths: 222 | cluster_to_paths[cluster_id] = [path] 223 | else: 224 | cluster_to_paths[cluster_id].append(path) 225 | paths_to_cluster[path] = cluster_id 226 | 227 | else: 228 | raise ValueError(f"Could not parse line: 
{line}") 229 | 230 | return cluster_to_paths, paths_to_cluster 231 | 232 | def calc_mdtraj_metrics(pdb_path): 233 | try: 234 | traj = md.load(pdb_path) 235 | pdb_ss = md.compute_dssp(traj, simplified=True) 236 | pdb_coil_percent = np.mean(pdb_ss == 'C') 237 | pdb_helix_percent = np.mean(pdb_ss == 'H') 238 | pdb_strand_percent = np.mean(pdb_ss == 'E') 239 | pdb_ss_percent = pdb_helix_percent + pdb_strand_percent 240 | pdb_rg = md.compute_rg(traj)[0] 241 | except IndexError as e: 242 | print('Error in calc_mdtraj_metrics: {}'.format(e)) 243 | pdb_ss_percent = 0.0 244 | pdb_coil_percent = 0.0 245 | pdb_helix_percent = 0.0 246 | pdb_strand_percent = 0.0 247 | pdb_rg = 0.0 248 | return { 249 | 'non_coil_percent': pdb_ss_percent, 250 | 'coil_percent': pdb_coil_percent, 251 | 'helix_percent': pdb_helix_percent, 252 | 'strand_percent': pdb_strand_percent, 253 | 'radius_of_gyration': pdb_rg, 254 | } 255 | 256 | def calc_aatype_metrics(generated_aatypes): 257 | # generated_aatypes (B, N) 258 | unique_aatypes, raw_counts = np.unique(generated_aatypes, return_counts=True) 259 | 260 | # pad with 0's in case it didn't generate any of a certain type 261 | clean_counts = [] 262 | for i in range(20): 263 | if i in unique_aatypes: 264 | clean_counts.append(raw_counts[np.where(unique_aatypes == i)[0][0]]) 265 | else: 266 | clean_counts.append(0) 267 | 268 | # from the scope128 dataset 269 | reference_normalized_counts = [ 270 | 0.0739, 0.05378621, 0.0410424, 0.05732177, 0.01418736, 0.03995128, 271 | 0.07562267, 0.06695857, 0.02163064, 0.0580802, 0.09333149, 0.06777057, 272 | 0.02034217, 0.03673995, 0.04428474, 0.05987899, 0.05502958, 0.01228988, 273 | 0.03233601, 0.07551553 274 | ] 275 | 276 | reference_normalized_counts = np.array(reference_normalized_counts) 277 | 278 | normalized_counts = clean_counts / np.sum(clean_counts) 279 | 280 | # compute the hellinger distance between the normalized counts 281 | # and the reference normalized counts 282 | 283 | hellinger_distance = np.sqrt(np.sum(np.square(np.sqrt(normalized_counts) - np.sqrt(reference_normalized_counts)))) 284 | 285 | return { 286 | 'aatype_histogram_dist': hellinger_distance 287 | } 288 | 289 | def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0): 290 | ca_bond_dists = np.linalg.norm( 291 | ca_pos - np.roll(ca_pos, 1, axis=0), axis=-1)[1:] 292 | ca_ca_dev = np.mean(np.abs(ca_bond_dists - residue_constants.ca_ca)) 293 | ca_ca_valid = np.mean(ca_bond_dists < (residue_constants.ca_ca + bond_tol)) 294 | 295 | ca_ca_dists2d = np.linalg.norm( 296 | ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1) 297 | inter_dists = ca_ca_dists2d[np.where(np.triu(ca_ca_dists2d, k=0) > 0)] 298 | clashes = inter_dists < clash_tol 299 | return { 300 | 'ca_ca_deviation': ca_ca_dev, 301 | 'ca_ca_valid_percent': ca_ca_valid, 302 | 'num_ca_ca_clashes': np.sum(clashes), 303 | } 304 | 305 | class CPU_Unpickler(pickle.Unpickler): 306 | """Pytorch pickle loading workaround. 
307 | 308 | https://github.com/pytorch/pytorch/issues/16797 309 | """ 310 | def find_class(self, module, name): 311 | if module == 'torch.storage' and name == '_load_from_bytes': 312 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 313 | else: return super().find_class(module, name) 314 | 315 | 316 | 317 | def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None): 318 | """Read data from a pickle file.""" 319 | try: 320 | if use_torch: 321 | return torch.load(read_path, map_location=map_location) 322 | else: 323 | with open(read_path, 'rb') as handle: 324 | return pickle.load(handle) 325 | except Exception as e: 326 | try: 327 | with open(read_path, 'rb') as handle: 328 | return CPU_Unpickler(handle).load() 329 | except Exception as e2: 330 | if verbose: 331 | print(f'Failed to read {read_path}. First error: {e}\n Second error: {e2}') 332 | raise(e) 333 | 334 | def parse_chain_feats(chain_feats, scale_factor=1., center=True): 335 | chain_feats['bb_mask'] = chain_feats['atom_mask'][:, CA_IDX] 336 | bb_pos = chain_feats['atom_positions'][:, CA_IDX] 337 | if center: 338 | bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5) 339 | centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :] 340 | scaled_pos = centered_pos / scale_factor 341 | else: 342 | scaled_pos = chain_feats['atom_positions'] / scale_factor 343 | chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None] 344 | chain_feats['bb_positions'] = chain_feats['atom_positions'][:, CA_IDX] 345 | return chain_feats 346 | -------------------------------------------------------------------------------- /proteinblobs/parsing.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
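#
# Usage sketch (illustrative; assumes the repository root is on PYTHONPATH and is
# not part of the original module). parse_train_args() below builds the training
# CLI and also exports MODEL_DIR and WANDB_LOGGING as environment variables:
#
#     from proteinblobs.parsing import parse_train_args
#     args = parse_train_args()   # e.g. python train.py --blob_attention --wandb
#     print(args.batch_size, args.lr)   # defaults below: 16, 1e-4
#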
10 | 11 | from argparse import ArgumentParser 12 | import subprocess, os, sys 13 | 14 | 15 | def parse_train_args(args=sys.argv): 16 | parser = ArgumentParser() 17 | 18 | ## Trainer settings 19 | 20 | ## Trainer settings 21 | group = parser.add_argument_group("Epoch settings") 22 | group.add_argument("--ckpt", type=str, default=None) 23 | group.add_argument("--validate", action='store_true', default=False) 24 | group.add_argument("--num_workers", type=int, default=10) 25 | group.add_argument("--epochs", type=int, default=100) 26 | group.add_argument("--train_batches", type=int, default=None) 27 | group.add_argument("--val_batches", type=int, default=None) 28 | group.add_argument("--batch_size", type=int, default=16) 29 | group.add_argument("--val_freq", type=int, default=None) 30 | group.add_argument("--val_epoch_freq", type=int, default=1) 31 | group.add_argument("--seed", type=int, default=137) 32 | group.add_argument("--no_validate", action='store_true') 33 | 34 | ## Inference settings 35 | group = parser.add_argument_group("Inference Settings") 36 | group.add_argument("--inference_gating", type=float, default=1) 37 | group.add_argument("--guidance_weight", type=float, default=1) 38 | group.add_argument("--seq_guidance_weight", type=float, default=1) 39 | group.add_argument("--guidance", action='store_true') 40 | group.add_argument("--sc_guidance_mode", choices=['separate', 'guided', 'unguided'], default='separate', help='which model output to use as self conditinoing input to the models') 41 | group.add_argument("--interpolate_sc", action='store_true') 42 | group.add_argument("--only_save_blobs", action='store_true') 43 | group.add_argument("--interpolate_sc_weight", type=float, default=1) 44 | group.add_argument("--inference_rot_scaling", type=float, default=10) 45 | group.add_argument("--fixed_inference_size", type=int, default=None) 46 | 47 | ## Model settings 48 | group = parser.add_argument_group("Model Settings") 49 | group.add_argument("--pretrained_mf_path", type=str, default='weights/last.ckpt') 50 | group.add_argument("--finetune", action='store_true') 51 | group.add_argument("--freeze_weights", action='store_true') 52 | group.add_argument("--extra_attn_layer", action='store_true') 53 | group.add_argument("--blob_attention", action='store_true') 54 | group.add_argument("--blob_drop_prob", type=float, default=0.0) 55 | group.add_argument("--self_condition", action='store_true') 56 | 57 | ## Data settings 58 | group = parser.add_argument_group("Data Settings") 59 | group.add_argument('--dataset', choices=['multiflow', 'genie'], default='genie') 60 | group.add_argument("--scop_dir", type=str, default='') 61 | group.add_argument("--genie_db_path", type=str, default='', help= '') 62 | group.add_argument("--scop_nmr_csv", type=str, default='scop_nmr_files.csv') 63 | group.add_argument('--pkl_dir', type=str, default='data') 64 | group.add_argument('--clustering_dir', type=str, default='proteinblobs/clustering_20rpc') 65 | group.add_argument('--latents_path', type=str, default='projects/protfid/cache/encodings/SimCLR-ca-dim_32/mf_all.pth') 66 | group.add_argument('--latents_order', type=str, default='projects/protfid/cache/encodings/SimCLR-ca-dim_32/mf_all_names.pickle') 67 | group.add_argument("--use_latents", action='store_true') 68 | group.add_argument('--length_dist_npz', type=str, default=None) 69 | group.add_argument("--crop", type=int, default=256) 70 | group.add_argument("--no_crop", action='store_true') 71 | group.add_argument("--res_per_cluster", type=int, default=20) 72 | 
group.add_argument("--min_blob_thresh", type=float, default=5) 73 | group.add_argument("--max_blob_thresh", type=float, default=10) 74 | group.add_argument("--dataset_multiplicity", type=int, default=None) 75 | group.add_argument("--repeat_dataset", type=int, default=1) 76 | group.add_argument("--no_pad", action='store_true') 77 | group.add_argument("--overfit", action='store_true') 78 | group.add_argument("--overfit_rot", action='store_true') 79 | 80 | ## Synthetic Blob data settings 81 | group = parser.add_argument_group("Data Settings") 82 | group.add_argument("--synthetic_blobs", action='store_true') 83 | group.add_argument('--num_blobs', type=int, default=5) 84 | group.add_argument('--nu', type=int, default=10) # vary the blob anisotropy [10, 20, 50, 100] 85 | group.add_argument('--psi', type=int, default=5) # FIX this to give proteins with O(200) residues 86 | group.add_argument('--sigma', type=int, default=8) # vary this compactness parameter 87 | group.add_argument('--helix_frac', type=float, default=0.5) 88 | 89 | ## Logging args 90 | group = parser.add_argument_group("Logging settings") 91 | group.add_argument("--print_freq", type=int, default=100) 92 | group.add_argument("--ckpt_freq", type=int, default=1) 93 | group.add_argument("--wandb", action="store_true") 94 | group.add_argument("--save_val", action="store_true") 95 | group.add_argument("--save_single_pdb", action="store_true") 96 | group.add_argument("--run_name", type=str, default="default") 97 | group.add_argument("--inf_batches", type=int, default=4) 98 | group.add_argument("--designability", action='store_true') 99 | group.add_argument("--self_consistency", action='store_true') 100 | group.add_argument("--ref_as_sample", action='store_true') 101 | group.add_argument("--num_timesteps", type=int, default=500) 102 | group.add_argument("--designability_freq", type=int, default=1) 103 | group.add_argument("--num_designability_prots", type=int, default=1000) 104 | group.add_argument("--pmpnn_path", type=str, default='../ProteinMPNN') 105 | 106 | ## Optimization settings 107 | group = parser.add_argument_group("Optimization settings") 108 | group.add_argument("--accumulate_grad", type=int, default=1) 109 | group.add_argument("--lr_scheduler", action='store_true') 110 | group.add_argument("--grad_clip", type=float, default=1.) 111 | group.add_argument("--check_grad", action='store_true') 112 | group.add_argument('--grad_checkpointing', action='store_true') 113 | group.add_argument('--adamW', action='store_true') 114 | group.add_argument('--ema', action='store_true') 115 | group.add_argument('--ema_decay', type=float, default=0.999) 116 | group.add_argument("--lr", type=float, default=1e-4) 117 | 118 | args = parser.parse_args() 119 | os.environ["MODEL_DIR"] = os.path.join("workdir", args.run_name) 120 | os.environ["WANDB_LOGGING"] = str(int(args.wandb)) 121 | return args 122 | 123 | 124 | -------------------------------------------------------------------------------- /proteinblobs/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. 
Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | import re 13 | import subprocess 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from .multiflow.data.all_atom import compute_backbone 19 | import openfold.np.protein as protein 20 | 21 | 22 | def rigid_transform_3D(A, B, verbose=False): 23 | # Transforms A to look like B 24 | # https://github.com/nghiaho12/rigid_transform_3D 25 | assert A.shape == B.shape 26 | A = A.T 27 | B = B.T 28 | 29 | num_rows, num_cols = A.shape 30 | if num_rows != 3: 31 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 32 | 33 | num_rows, num_cols = B.shape 34 | if num_rows != 3: 35 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 36 | 37 | # find mean column wise 38 | centroid_A = np.mean(A, axis=1) 39 | centroid_B = np.mean(B, axis=1) 40 | 41 | # ensure centroids are 3x1 42 | centroid_A = centroid_A.reshape(-1, 1) 43 | centroid_B = centroid_B.reshape(-1, 1) 44 | 45 | # subtract mean 46 | Am = A - centroid_A 47 | Bm = B - centroid_B 48 | 49 | H = Am @ np.transpose(Bm) 50 | 51 | # sanity check 52 | # if linalg.matrix_rank(H) < 3: 53 | # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H))) 54 | 55 | # find rotation 56 | U, S, Vt = np.linalg.svd(H) 57 | R = Vt.T @ U.T 58 | 59 | # special reflection case 60 | reflection_detected = False 61 | if np.linalg.det(R) < 0: 62 | if verbose: 63 | print("det(R) < R, reflection detected!, correcting for it ...") 64 | Vt[2, :] *= -1 65 | R = Vt.T @ U.T 66 | reflection_detected = True 67 | 68 | t = -R @ centroid_A + centroid_B 69 | optimal_A = R @ A + t 70 | 71 | return optimal_A.T, R, t, reflection_detected 72 | 73 | 74 | def run_pmpnn(input_dir, output_path, pmpnn_path="../ProteinMPNN"): 75 | os.makedirs(os.path.join(input_dir, "seqs"), exist_ok=True) 76 | process = subprocess.Popen( 77 | [ 78 | "python", 79 | os.path.join(pmpnn_path, "helper_scripts/parse_multiple_chains.py"), 80 | f"--input_path={input_dir}", 81 | f"--output_path={output_path}", 82 | ] 83 | ) 84 | _ = process.wait() 85 | 86 | pmpnn_args = [ 87 | "python", 88 | os.path.join(pmpnn_path, "protein_mpnn_run.py"), 89 | "--out_folder", 90 | input_dir, 91 | "--jsonl_path", 92 | output_path, 93 | "--num_seq_per_target", 94 | "8", 95 | "--sampling_temp", 96 | "0.1", 97 | "--seed", 98 | "38", 99 | "--batch_size", 100 | "1", 101 | ] 102 | print(" ".join(pmpnn_args)) 103 | 104 | process = subprocess.run(pmpnn_args) 105 | 106 | 107 | def get_aligned_rmsd(pos_1, pos_2): 108 | aligned_pos_1 = rigid_transform_3D(pos_1, pos_2)[0] 109 | return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1)) 110 | 111 | 112 | def upgrade_state_dict(state_dict, prefixes=["encoder.sentence_encoder.", "encoder."]): 113 | """Removes prefixes like 'model.encoder.sentence_encoder.' 
and 'model.encoder.'.""" 114 | pattern = re.compile("^" + "|".join(prefixes)) 115 | state_dict = {pattern.sub("", name): param for name, param in state_dict.items()} 116 | return state_dict 117 | 118 | 119 | def softmax_cross_entropy(logits, labels): 120 | loss = -1 * torch.sum( 121 | labels * torch.nn.functional.log_softmax(logits, dim=-1), 122 | dim=-1, 123 | ) 124 | return loss 125 | 126 | 127 | def compute_distogram_loss( 128 | logits, 129 | pseudo_beta, 130 | pseudo_beta_mask, 131 | min_bin=2.3125, 132 | max_bin=21.6875, 133 | no_bins=64, 134 | eps=1e-6, 135 | **kwargs, 136 | ): 137 | boundaries = torch.linspace( 138 | min_bin, 139 | max_bin, 140 | no_bins - 1, 141 | device=logits.device, 142 | ) 143 | boundaries = boundaries**2 144 | 145 | dists = torch.sum( 146 | (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, 147 | dim=-1, 148 | keepdims=True, 149 | ) 150 | 151 | true_bins = torch.sum(dists > boundaries, dim=-1) 152 | 153 | errors = softmax_cross_entropy( 154 | logits, 155 | torch.nn.functional.one_hot(true_bins, no_bins), 156 | ) 157 | 158 | square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :] 159 | 160 | # FP16-friendly sum. Equivalent to: 161 | # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) / 162 | # (eps + torch.sum(square_mask, dim=(-1, -2)))) 163 | denom = eps + torch.sum(square_mask, dim=(-1, -2)) 164 | mean = errors * square_mask 165 | mean = torch.sum(mean, dim=-1) 166 | mean = mean / denom[..., None] 167 | mean = torch.sum(mean, dim=-1) 168 | 169 | # Average over the batch dimensions 170 | mean = torch.mean(mean) 171 | 172 | return mean 173 | 174 | 175 | class HarmonicPrior: 176 | def __init__(self, N=256, a=3 / (3.8**2)): 177 | J = torch.zeros(N, N) 178 | for i, j in zip(np.arange(N - 1), np.arange(1, N)): 179 | J[i, i] += a 180 | J[j, j] += a 181 | J[i, j] = J[j, i] = -a 182 | D, P = torch.linalg.eigh(J) 183 | D_inv = 1 / D 184 | D_inv[0] = 0 185 | self.P, self.D_inv = P, D_inv 186 | self.N = N 187 | 188 | def to(self, device): 189 | self.P = self.P.to(device) 190 | self.D_inv = self.D_inv.to(device) 191 | return self 192 | 193 | def sample(self, batch_dims=()): 194 | return self.P @ ( 195 | torch.sqrt(self.D_inv)[:, None] 196 | * torch.randn(*batch_dims, self.N, 3, device=self.P.device) 197 | ) 198 | 199 | 200 | def adjust_oxygen_pos(atom_37: torch.Tensor, pos_is_known=None) -> torch.Tensor: 201 | """ 202 | Imputes the position of the oxygen atom on the backbone by using adjacent frame information. 203 | Specifically, we say that the oxygen atom is in the plane created by the Calpha and C from the 204 | current frame and the nitrogen of the next frame. The oxygen is then placed c_o_bond_length Angstrom 205 | away from the C in the current frame in the direction away from the Ca-C-N triangle. 206 | 207 | For cases where the next frame is not available, for example we are at the C-terminus or the 208 | next frame is not available in the data then we place the oxygen in the same plane as the 209 | N-Ca-C of the current frame and pointing in the same direction as the average of the 210 | Ca->C and Ca->N vectors. 211 | 212 | Args: 213 | atom_37 (torch.Tensor): (N, 37, 3) tensor of positions of the backbone atoms in atom_37 ordering 214 | which is ['N', 'CA', 'C', 'CB', 'O', ...] 215 | pos_is_known (torch.Tensor): (N,) mask for known residues. 216 | """ 217 | 218 | N = atom_37.shape[0] 219 | assert atom_37.shape == (N, 37, 3) 220 | 221 | # Get vectors to Carbonly from Carbon alpha and N of next residue. 
(N-1, 3) 222 | # Note that the (N,) ordering is from N-terminal to C-terminal. 223 | 224 | # Calpha to carbonyl both in the current frame. 225 | calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / ( 226 | torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7 227 | ) 228 | # For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0. 229 | # The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change. 230 | 231 | # Nitrogen of the next frame to carbonyl of the current frame. 232 | nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / ( 233 | torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7 234 | ) 235 | 236 | carbonyl_to_oxygen: torch.Tensor = ( 237 | calpha_to_carbonyl + nitrogen_to_carbonyl 238 | ) # (N-1, 3) 239 | carbonyl_to_oxygen = carbonyl_to_oxygen / ( 240 | torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7 241 | ) 242 | 243 | atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23 244 | 245 | # Now we deal with frames for which there is no next frame available. 246 | 247 | # Calpha to carbonyl both in the current frame. (N, 3) 248 | calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / ( 249 | torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7 250 | ) 251 | # Calpha to nitrogen both in the current frame. (N, 3) 252 | calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / ( 253 | torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7 254 | ) 255 | carbonyl_to_oxygen_term: torch.Tensor = ( 256 | calpha_to_carbonyl_term + calpha_to_nitrogen_term 257 | ) # (N, 3) 258 | carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / ( 259 | torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7 260 | ) 261 | 262 | # Create a mask that is 1 when the next residue is not available either 263 | # due to this frame being the C-terminus or the next residue is not 264 | # known due to pos_is_known being false. 
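    # (Sketch of the indexing below: negating pos_is_known flags unknown residues;
    #  appending a True and dropping the first entry shifts the flag so that entry i
    #  refers to residue i + 1, and the final residue is always flagged as having
    #  no next frame.)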
265 | 266 | if pos_is_known is None: 267 | pos_is_known = torch.ones( 268 | (atom_37.shape[0],), dtype=torch.int64, device=atom_37.device 269 | ) 270 | 271 | next_res_gone: torch.Tensor = ~pos_is_known.bool() # (N,) 272 | next_res_gone = torch.cat( 273 | [next_res_gone, torch.ones((1,), device=pos_is_known.device).bool()], dim=0 274 | ) # (N+1, ) 275 | next_res_gone = next_res_gone[1:] # (N,) 276 | 277 | atom_37[next_res_gone, 4, :] = ( 278 | atom_37[next_res_gone, 2, :] + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23 279 | ) 280 | 281 | return atom_37 282 | 283 | 284 | def trans_to_atom37(trans): 285 | B, L, _ = trans.shape 286 | atom37 = trans.new_zeros(B, L, 37, 3) 287 | atom37[:, :, 1] = trans 288 | return atom37 289 | 290 | 291 | def transrot_to_atom37(rigids): 292 | atom37_traj = [] 293 | B, L = rigids.shape 294 | 295 | atom37 = compute_backbone(rigids, torch.zeros((B, L, 2), device=rigids.device))[0] 296 | for i in range(B): 297 | atom37[i] = adjust_oxygen_pos(atom37[i], None) 298 | return atom37 299 | 300 | 301 | def rigid_transform_3D(A, B, verbose=False): 302 | # Transforms A to look like B 303 | # https://github.com/nghiaho12/rigid_transform_3D 304 | assert A.shape == B.shape 305 | A = A.T 306 | B = B.T 307 | 308 | num_rows, num_cols = A.shape 309 | if num_rows != 3: 310 | raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") 311 | 312 | num_rows, num_cols = B.shape 313 | if num_rows != 3: 314 | raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") 315 | 316 | # find mean column wise 317 | centroid_A = np.mean(A, axis=1) 318 | centroid_B = np.mean(B, axis=1) 319 | 320 | # ensure centroids are 3x1 321 | centroid_A = centroid_A.reshape(-1, 1) 322 | centroid_B = centroid_B.reshape(-1, 1) 323 | 324 | # subtract mean 325 | Am = A - centroid_A 326 | Bm = B - centroid_B 327 | 328 | H = Am @ np.transpose(Bm) 329 | 330 | # sanity check 331 | # if linalg.matrix_rank(H) < 3: 332 | # raise ValueError("rank of H = {}, expecting 3".format(linalg.matrix_rank(H))) 333 | 334 | # find rotation 335 | U, S, Vt = np.linalg.svd(H) 336 | R = Vt.T @ U.T 337 | 338 | # special reflection case 339 | reflection_detected = False 340 | if np.linalg.det(R) < 0: 341 | if verbose: 342 | print("det(R) < R, reflection detected!, correcting for it ...") 343 | Vt[2, :] *= -1 344 | R = Vt.T @ U.T 345 | reflection_detected = True 346 | 347 | t = -R @ centroid_A + centroid_B 348 | optimal_A = R @ A + t 349 | 350 | return optimal_A.T, R, t, reflection_detected 351 | 352 | 353 | def get_aligned_rmsd(pos_1, pos_2): 354 | aligned_pos_1 = rigid_transform_3D(pos_1, pos_2)[0] 355 | return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1)) 356 | 357 | 358 | def create_full_prot( 359 | atom37: np.ndarray, 360 | aatype=None, 361 | b_factors=None, 362 | ): 363 | assert atom37.ndim == 3 364 | assert atom37.shape[-1] == 3 365 | assert atom37.shape[-2] == 37 366 | n = atom37.shape[0] 367 | residue_index = np.arange(n) 368 | atom37_mask = (np.sum(np.abs(atom37), axis=-1) > 1e-7) & ( 369 | np.sum(np.abs(atom37[:, 1:2]), axis=-1) > 1e-7 370 | ) 371 | if b_factors is None: 372 | b_factors = np.zeros([n, 37]) 373 | if aatype is None: 374 | aatype = np.zeros(n, dtype=int) 375 | return protein.Protein( 376 | atom_positions=atom37, 377 | atom_mask=atom37_mask, 378 | aatype=aatype, 379 | residue_index=residue_index, 380 | b_factors=b_factors, 381 | ) 382 | 383 | 384 | def frames_to_pdb(frames, path): 385 | prots = [] 386 | atom37 = transrot_to_atom37(frames).cpu().numpy() 387 | for i, pos in 
enumerate(atom37): 388 | prots.append(create_full_prot(pos)) 389 | with open(path, "w") as f: 390 | f.write(prots_to_pdb(prots)) 391 | 392 | 393 | def trans_to_pdb(trans, path): 394 | prots = [] 395 | atom37 = trans_to_atom37(trans).cpu().numpy() 396 | for i, pos in enumerate(atom37): 397 | prots.append(create_full_prot(pos)) 398 | with open(path, "w") as f: 399 | f.write(prots_to_pdb(prots)) 400 | 401 | 402 | def atom37_to_pdb(atom37, path, mask=None): 403 | prots = [] 404 | for i, pos in enumerate(atom37): 405 | if mask is not None: 406 | pos = pos[mask[i]] 407 | prots.append(create_full_prot(pos)) 408 | with open(path, "w") as f: 409 | f.write(prots_to_pdb(prots)) 410 | 411 | 412 | def prots_to_pdb(prots): 413 | ss = "" 414 | for i, prot in enumerate(prots): 415 | ss += f"MODEL {i}\n" 416 | prot = protein.to_pdb(prot) 417 | ss += "\n".join(prot.split("\n")[1:-2]) 418 | ss += "\nENDMDL\n" 419 | return ss 420 | 421 | 422 | def compute_lddt(pos1, pos2, mask, cutoff=15.0, eps=1e-10, symmetric=False): 423 | dmat1 = torch.sqrt( 424 | eps + torch.sum((pos1[..., None, :] - pos1[..., None, :, :]) ** 2, axis=-1) 425 | ) 426 | dmat2 = torch.sqrt( 427 | eps + torch.sum((pos2[..., None, :] - pos2[..., None, :, :]) ** 2, axis=-1) 428 | ) 429 | if symmetric: 430 | dists_to_score = (dmat1 < cutoff) | (dmat2 < cutoff) 431 | else: 432 | dists_to_score = dmat1 < cutoff 433 | dists_to_score = ( 434 | dists_to_score 435 | * mask.unsqueeze(-2) 436 | * mask.unsqueeze(-1) 437 | * (1.0 - torch.eye(mask.shape[-1]).to(mask)) 438 | ) 439 | dist_l1 = torch.abs(dmat1 - dmat2) 440 | score = ( 441 | (dist_l1[..., None] < torch.tensor([0.5, 1.0, 2.0, 4.0]).to(pos1)) 442 | .float() 443 | .mean(-1) 444 | ) 445 | score = (dists_to_score * score).sum((-1, -2)) / dists_to_score.sum((-1, -2)) 446 | 447 | return score 448 | 449 | 450 | def compute_fape( 451 | pred_frames, 452 | target_frames, 453 | frames_mask, 454 | pred_positions, 455 | target_positions, 456 | positions_mask, 457 | length_scale, 458 | l1_clamp_distance=None, 459 | thresh=None, 460 | eps=1e-8, 461 | ) -> torch.Tensor: 462 | """ 463 | Computes FAPE loss. 464 | 465 | Args: 466 | pred_frames: 467 | [*, N_frames] Rigid object of predicted frames 468 | target_frames: 469 | [*, N_frames] Rigid object of ground truth frames 470 | frames_mask: 471 | [*, N_frames] binary mask for the frames 472 | pred_positions: 473 | [*, N_pts, 3] predicted atom positions 474 | target_positions: 475 | [*, N_pts, 3] ground truth positions 476 | positions_mask: 477 | [*, N_pts] positions mask 478 | length_scale: 479 | Length scale by which the loss is divided 480 | pair_mask: 481 | [*, N_frames, N_pts] mask to use for 482 | separating intra- from inter-chain losses. 
483 | l1_clamp_distance: 484 | Cutoff above which distance errors are disregarded 485 | eps: 486 | Small value used to regularize denominators 487 | Returns: 488 | [*] loss tensor 489 | """ 490 | 491 | # [*, N_frames, N_pts, 3] 492 | local_pred_pos = pred_frames.invert()[..., None].apply( 493 | pred_positions[..., None, :, :], 494 | ) 495 | local_target_pos = target_frames.invert()[..., None].apply( 496 | target_positions[..., None, :, :], 497 | ) 498 | 499 | error_dist = torch.sqrt( 500 | torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps 501 | ) 502 | 503 | if l1_clamp_distance is not None: 504 | error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance) 505 | 506 | normed_error = error_dist / length_scale 507 | 508 | if thresh is not None: 509 | thresh_mask = torch.sqrt(torch.sum(local_target_pos**2, dim=-1)) < thresh 510 | mask = thresh_mask * frames_mask[..., None] * positions_mask[..., None, :] 511 | 512 | normed_error = normed_error * mask 513 | normed_error = torch.sum(normed_error, dim=(-1, -2)) 514 | normed_error = normed_error / (eps + torch.sum(mask, dim=(-1, -2))) 515 | 516 | else: 517 | normed_error = normed_error * frames_mask[..., None] 518 | normed_error = normed_error * positions_mask[..., None, :] 519 | 520 | normed_error = torch.sum(normed_error, dim=-1) 521 | normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None] 522 | normed_error = torch.sum(normed_error, dim=-1) 523 | normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1)) 524 | 525 | return normed_error 526 | -------------------------------------------------------------------------------- /proteinblobs/visualize.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
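#
# Geometry note (an illustrative, numpy-only sketch; not used by the code below):
# draw_ellipse() turns a 3x3 blob covariance C into an ellipsoid by
# eigendecomposition, with the eigenvectors as the orientation and the square
# roots of the eigenvalues as the semi-axes:
#
#     ws, vs = np.linalg.eig(C)      # eigenvalues ws, eigenvectors vs (columns)
#     a1, a2, a3 = np.sqrt(ws)       # semi-axis lengths
#     # a surface point at angles (u, v) is
#     # vs @ [a1*cos(u)*cos(v), a2*cos(u)*sin(v), a3*sin(u)] + r0
#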
10 | 11 | 12 | from pymol import cmd 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | 16 | matplotlib.use('Agg') 17 | import numpy as np 18 | import plotly.graph_objects as go 19 | import torch 20 | import os 21 | from pathlib import Path 22 | from pymol.cgo import BEGIN, COLOR, TRIANGLES, VERTEX, NORMAL, END 23 | 24 | import math 25 | import numpy as np 26 | 27 | def vertex(a1, a2, a3, u, v, M, r0): 28 | #https://github.com/vlasenkov/pymol-iellipsoid/blob/master/iellipsoid.py 29 | vrtx = M.dot(np.array([ 30 | a1 * math.cos(u) * math.cos(v), 31 | a2 * math.cos(u) * math.sin(v), 32 | a3 * math.sin(u) 33 | ])) 34 | nrml = M.dot(np.array([ 35 | math.cos(u) * math.cos(v) / a1, 36 | math.cos(u) * math.sin(v) / a2, 37 | math.sin(u) / a3 38 | ])) 39 | return vrtx + r0, nrml 40 | 41 | 42 | def draw_ellipse(r0, C, color=[0.5, 0.5, 0.5], transparency=0.0, name='', resolution=12): 43 | #https://github.com/vlasenkov/pymol-iellipsoid/blob/master/iellipsoid.py 44 | ws, vs = np.linalg.eig(C) 45 | # M = np.linalg.inv(vs) 46 | M = vs 47 | a1, a2, a3 = np.sqrt(ws) 48 | u_segs = resolution 49 | v_segs = resolution 50 | mesh = [BEGIN, TRIANGLES, COLOR] 51 | mesh.extend(color) 52 | dU = math.pi / u_segs 53 | dV = 2 * math.pi / v_segs 54 | U = -math.pi / 2 55 | for Y in range(0, u_segs): 56 | V = math.pi 57 | for X in range(0, v_segs): 58 | 59 | (x1, y1, z1), (n1x, n1y, n1z) = vertex(a1, a2, a3, 60 | U, V, M, r0) 61 | (x2, y2, z2), (n2x, n2y, n2z) = vertex(a1, a2, a3, 62 | U + dU, V, M, r0) 63 | (x3, y3, z3), (n3x, n3y, n3z) = vertex(a1, a2, a3, 64 | U + dU, V + dV, 65 | M, r0) 66 | (x4, y4, z4), (n4x, n4y, n4z) = vertex(a1, a2, a3, 67 | U, V + dV, M, r0) 68 | 69 | mesh.extend([NORMAL, n1x, n1y, n1z, VERTEX, x1, y1, z1]) 70 | mesh.extend([NORMAL, n2x, n2y, n2z, VERTEX, x2, y2, z2]) 71 | mesh.extend([NORMAL, n4x, n4y, n4z, VERTEX, x4, y4, z4]) 72 | mesh.extend([NORMAL, n2x, n2y, n2z, VERTEX, x2, y2, z2]) 73 | mesh.extend([NORMAL, n3x, n3y, n3z, VERTEX, x3, y3, z3]) 74 | mesh.extend([NORMAL, n4x, n4y, n4z, VERTEX, x4, y4, z4]) 75 | 76 | V += dV 77 | U += dU 78 | mesh.append(END) 79 | cmd.load_cgo(mesh, name) 80 | cmd.set("cgo_transparency", transparency, name) 81 | 82 | 83 | def visualize_blobs(tmp_paths, means, covars, save_path=None, ax=None): 84 | cmd.reinitialize() 85 | for mean, covar in zip(means, covars): 86 | draw_ellipse(mean, 5 * covar, transparency=0.5) 87 | for tmp_path in tmp_paths: 88 | cmd.load(tmp_path) 89 | os.remove(tmp_path) 90 | cmd.orient(Path(tmp_paths[-1]).stem) 91 | #cmd.spectrum() 92 | cmd.bg_color('white') 93 | cmd.set('ray_trace_mode', 0) 94 | cmd.set('depth_cue', 'off') 95 | cmd.set('ray_shadows', 'off') 96 | 97 | if save_path is not None: 98 | cmd.save(save_path) 99 | if ax is not None: 100 | cmd.png('/tmp/tmp.png', 640, 640) 101 | img = plt.imread('/tmp/tmp.png') 102 | os.remove('/tmp/tmp.png') 103 | 104 | if ax is None: 105 | plt.imshow(img) 106 | plt.axis('off') # Optional: Hide axis ticks and labels 107 | plt.show() 108 | else: 109 | ax.imshow(img) 110 | ax.axis('off') 111 | 112 | def plot_point_cloud(point_clouds, path=None): 113 | # Takes a list of point cloud tensors and plots them 114 | if not isinstance(point_clouds, list): 115 | point_clouds = [point_clouds] 116 | 117 | colors = ['red', 'blue', 'green', 'yellow', 'orange'] # List of colors for each point cloud 118 | traces = [] # List to store individual traces for each point cloud 119 | 120 | for i, point_cloud in enumerate(point_clouds): 121 | if isinstance(point_cloud, np.ndarray): 122 | pass 123 | elif 
isinstance(point_cloud, torch.Tensor): 124 | point_cloud = point_cloud.numpy() 125 | 126 | x_data = point_cloud[:, 0] 127 | y_data = point_cloud[:, 1] 128 | z_data = point_cloud[:, 2] 129 | 130 | # Create a trace for each point cloud with a different color 131 | trace = go.Scatter3d( 132 | x=x_data, 133 | y=y_data, 134 | z=z_data, 135 | mode='markers', 136 | marker=dict( 137 | size=5, 138 | opacity=0.8, 139 | color=colors[i % len(colors)] # Assign color based on the index of the point cloud 140 | ), 141 | name=f"Point Cloud {i + 1}" 142 | ) 143 | traces.append(trace) 144 | 145 | # Create the layout 146 | layout = go.Layout( 147 | scene=dict( 148 | aspectmode='data' 149 | ) 150 | ) 151 | 152 | # Create the figure and add the traces to it 153 | fig = go.Figure(data=traces, layout=layout) 154 | 155 | if path is None: 156 | fig.show() 157 | else: 158 | fig.write_html(path) -------------------------------------------------------------------------------- /proteinblobs/wrapper.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | from .logger import get_logger 12 | logger = get_logger(__name__) 13 | import pytorch_lightning as pl 14 | import numpy as np 15 | import torch, time, wandb 16 | from collections import defaultdict 17 | from openfold.utils.lr_schedulers import AlphaFoldLRScheduler 18 | import os 19 | import pandas as pd 20 | 21 | def gather_log(log, world_size): 22 | if world_size == 1: 23 | return log 24 | log_list = [None] * world_size 25 | torch.distributed.all_gather_object(log_list, log) 26 | log = {key: sum([l[key] for l in log_list], []) for key in log} 27 | return log 28 | 29 | 30 | def get_log_mean(log): 31 | out = {} 32 | for key in log: 33 | try: 34 | out[key] = np.nanmean(log[key]) 35 | except: 36 | pass 37 | return out 38 | 39 | class Wrapper(pl.LightningModule): 40 | 41 | def __init__(self, args): 42 | super().__init__() 43 | self.save_hyperparameters() 44 | self.args = args 45 | self._log = defaultdict(list) 46 | self.last_log_time = time.time() 47 | self.iter_step = 0 48 | 49 | def log(self, key, data, extend=False): 50 | if isinstance(data, torch.Tensor): 51 | data = data.mean().item() 52 | log = self._log 53 | if extend: 54 | if self.stage == 'train' or self.args.validate: 55 | log["iter_" + key].extend(data) 56 | log[self.stage + "_" + key].extend(data) 57 | else: 58 | if self.stage == 'train' or self.args.validate: 59 | log["iter_" + key].append(data) 60 | log[self.stage + "_" + key].append(data) 61 | 62 | def load_ema_weights(self): 63 | logger.info('Loading EMA weights') 64 | clone_param = lambda t: t.detach().clone() 65 | self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) 66 | self.model.load_state_dict(self.ema.state_dict()["params"]) 67 | 68 | def restore_cached_weights(self): 69 | logger.info('Restoring cached weights') 70 | self.model.load_state_dict(self.cached_weights) 71 | self.cached_weights = None 72 | 73 | def 
on_before_zero_grad(self, *args, **kwargs): 74 | if self.args.ema: 75 | self.ema.update(self.model) 76 | 77 | def training_step(self, batch, batch_idx): 78 | self.stage = 'train' 79 | if self.args.ema: 80 | if (self.ema.device != self.device): 81 | self.ema.to(self.device) 82 | return self.general_step(batch) 83 | 84 | def validation_step(self, batch, batch_idx): 85 | self.stage = 'val' 86 | if self.args.ema: 87 | if (self.ema.device != self.device): 88 | self.ema.to(self.device) 89 | if (self.cached_weights is None): 90 | self.load_ema_weights() 91 | 92 | self.general_step(batch) 93 | self.validation_step_extra(batch, batch_idx) 94 | if self.args.validate and self.iter_step % self.args.print_freq == 0: 95 | self.print_log() 96 | 97 | def general_step(self, batch): 98 | pass 99 | 100 | def validation_step_extra(self, batch, batch_idx): 101 | pass 102 | 103 | def on_train_epoch_end(self): 104 | self.print_log(prefix='train', save=False) 105 | 106 | def on_validation_epoch_end(self): 107 | if self.args.ema: 108 | self.restore_cached_weights() 109 | self.on_validation_epoch_end_extra() 110 | self.print_log(prefix='val', save=self.args.save_val) 111 | 112 | def on_validation_epoch_end_extra(self): 113 | pass 114 | 115 | def on_before_optimizer_step(self, optimizer): 116 | if (self.trainer.global_step + 1) % self.args.print_freq == 0: 117 | self.print_log() 118 | 119 | if self.args.check_grad: 120 | for name, p in self.model.named_parameters(): 121 | if p.grad is None: 122 | logger.warning(f"Param {name} has no grad") 123 | 124 | def on_load_checkpoint(self, checkpoint): 125 | logger.info('Loading EMA state dict') 126 | if self.args.ema: 127 | ema = checkpoint["ema"] 128 | self.ema.load_state_dict(ema) 129 | 130 | def on_save_checkpoint(self, checkpoint): 131 | if self.args.ema: 132 | if self.cached_weights is not None: 133 | self.restore_cached_weights() 134 | checkpoint["ema"] = self.ema.state_dict() 135 | 136 | def print_log(self, prefix='iter', save=False, extra_logs=None): 137 | log = self._log 138 | log = {key: log[key] for key in log if f"{prefix}_" in key} 139 | log = gather_log(log, self.trainer.world_size) 140 | mean_log = get_log_mean(log) 141 | 142 | mean_log.update({ 143 | 'epoch': self.trainer.current_epoch, 144 | 'trainer_step': self.trainer.global_step + int(prefix == 'iter'), 145 | 'iter_step': self.iter_step, 146 | f'{prefix}_count': len(log[next(iter(log))]), 147 | 148 | }) 149 | if extra_logs: 150 | mean_log.update(extra_logs) 151 | try: 152 | for param_group in self.optimizers().optimizer.param_groups: 153 | mean_log['lr'] = param_group['lr'] 154 | except: 155 | pass 156 | 157 | if self.trainer.is_global_zero: 158 | logger.info(str(mean_log)) 159 | if self.args.wandb: 160 | wandb.log(mean_log) 161 | if save: 162 | path = os.path.join(os.environ["MODEL_DIR"],f"{prefix}_{self.trainer.current_epoch}.csv") 163 | max_len = max([len(v) for k,v in log.items()]) 164 | pd.DataFrame({k:v for k,v in log.items() if len(v) == max_len}).to_csv(path) 165 | for key in list(log.keys()): 166 | if f"{prefix}_" in key: 167 | del self._log[key] 168 | 169 | def configure_optimizers(self): 170 | cls = torch.optim.AdamW if self.args.adamW else torch.optim.Adam 171 | optimizer = cls( 172 | filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.args.lr, 173 | ) 174 | 175 | if self.args.lr_scheduler: 176 | lr_scheduler = AlphaFoldLRScheduler(optimizer, max_lr=self.args.lr) 177 | return { 178 | "optimizer": optimizer, 179 | "lr_scheduler": { 180 | "scheduler": lr_scheduler, 181 | 
"interval": "step", 182 | "name": "AlphaFoldLRScheduler", 183 | } 184 | } 185 | else: 186 | return optimizer 187 | 188 | -------------------------------------------------------------------------------- /run_multiflow.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | from omegaconf import OmegaConf 12 | import os, torch 13 | import pytorch_lightning as pl 14 | import numpy as np 15 | from proteinblobs.multiflow_wrapper import MultiflowWrapper 16 | from openfold.np import protein 17 | from proteinblobs.utils import upgrade_state_dict, create_full_prot 18 | import argparse 19 | 20 | args = argparse.Namespace() 21 | args.__dict__.update({ 22 | 'multiflow_yaml': "../multiflow/weights/config.yaml" 23 | }) 24 | model = MultiflowWrapper(args) 25 | ckpt_path = "../multiflow/weights/last.ckpt" 26 | ckpt = torch.load(ckpt_path, map_location='cpu') 27 | model.load_state_dict(ckpt['state_dict']) 28 | model.eval().cuda() 29 | 30 | out, _ = model.inference(num_batch=2, sample_length=128) 31 | 32 | for i, prot in enumerate(out): 33 | prot = create_full_prot(prot) 34 | 35 | with open(f'test{i}.pdb', 'w') as f: 36 | f.write(protein.to_pdb(prot)) 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--ckpt', default='workdir/tune_SC_elli_warmup/epoch=12-step=59683.ckpt') 15 | parser.add_argument('--num_blobs', type=int, default=5) 16 | parser.add_argument('--num_prots', type=int, default=5) 17 | parser.add_argument('--nu', type=int, default=10) # vary the blob anisotropy [10, 20, 50, 100] 18 | parser.add_argument('--psi', type=int, default=5) # FIX this to give proteins with O(200) residues 19 | parser.add_argument('--sigma', type=int, default=8) # vary this compactness parameter 20 | parser.add_argument('--helix_frac', type=float, default=0.5) 21 | parser.add_argument('--num_residues', type=int, default=160) 22 | parser.add_argument('--seed', type=int, default=137) 23 | parser.add_argument('--num_timesteps', type=int, default=500) 24 | parser.add_argument('--inference_rot_scaling', type=float, default=10) 25 | parser.add_argument('--multiflow', action='store_true') 26 | parser.add_argument('--outdir', default='outpdb/default') 27 | parser.add_argument('--guidance', type=float, default=None) 28 | args = parser.parse_args() 29 | 30 | from pymol import cmd 31 | import torch 32 | from proteinblobs.parsing import parse_train_args 33 | from proteinblobs.multiflow_wrapper import MultiflowWrapper 34 | from proteinblobs import blobs 35 | from proteinblobs.utils import atom37_to_pdb 36 | import os, contextlib 37 | import numpy as np 38 | import tqdm 39 | 40 | @contextlib.contextmanager 41 | def temp_seed(seed): 42 | state = np.random.get_state() 43 | np.random.seed(seed) 44 | torch_state = torch.seed() 45 | torch.manual_seed(seed) 46 | try: 47 | yield 48 | finally: 49 | np.random.set_state(state) 50 | torch.manual_seed(seed) 51 | 52 | 53 | np.random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | os.makedirs(args.outdir, exist_ok=True) 56 | blobss = [] 57 | for _ in tqdm.tqdm(range(args.num_prots), desc='sampling dataset'): 58 | pos, covar = blobs.sample_blobs(args.num_blobs, nu=args.nu, psi=(1/args.nu)*args.psi**2*np.eye(3), sigma=args.sigma) 59 | is_helix = np.random.rand(args.num_blobs) < args.helix_frac 60 | volume = np.linalg.det(covar)**0.5 61 | 62 | counts = np.where( 63 | is_helix, 64 | blobs.alpha_slope * volume + blobs.alpha_intercept, 65 | blobs.beta_slope * volume + blobs.beta_intercept, 66 | ).astype(int) 67 | 68 | dssp = np.where(is_helix, 1, 2) 69 | blobss.append((pos, covar, counts, dssp)) 70 | 71 | 72 | if args.multiflow: 73 | 74 | args_ = argparse.Namespace() 75 | args_.__dict__.update({ 76 | 'finetune': True, 77 | 'multiflow_yaml': "multiflow_config.yaml", 78 | 'num_timesteps': args.num_timesteps, 79 | 'self_condition': True, 80 | 'blob_attention': False, 81 | 'extra_attn_layer': False, 82 | 'freeze_weights': True, 83 | 'pretrained_mf_path': 'weights/last.ckpt', 84 | 'inference_rot_scaling': args.inference_rot_scaling, 85 | }) 86 | 87 | model = MultiflowWrapper(args_).eval().cuda() 88 | else: 89 | ckpt = torch.load(args.ckpt) 90 | 91 | args_ = ckpt['hyper_parameters']['args'] 92 | if args.guidance is not None: #--guidance --sc_guidance_mode separate --interpolate_sc --interpolate_sc_weight 0.8 --guidance_weight 0.8 --seq_guidance_weight 0.8 93 | args_.__dict__.update({ 94 | 'guidance': True, 95 | 'sc_guidance_mode': 'separate', 96 | 'num_timesteps': args.num_timesteps, 97 | 'interpolate_sc': True, 98 | 'pretrained_mf_path': 'weights/last.ckpt', 99 | 'interpolate_sc_weight': args.guidance, 100 | 'guidance_weight': args.guidance, 101 | 'seq_guidance_weight': args.guidance, 
102 | }) 103 | model = MultiflowWrapper(args_) 104 | model.load_state_dict(ckpt['state_dict']) 105 | model = model.eval().cuda() 106 | 107 | if args.guidance is not None: 108 | with temp_seed(args.seed): 109 | model.on_validation_epoch_start() 110 | 111 | for i, (pos, covar, counts, dssp) in tqdm.tqdm(enumerate(blobss), desc='running inference'): 112 | 113 | batch = {} 114 | if args.multiflow: 115 | batch['res_mask'] = np.ones((1, args.num_residues), dtype=np.float32) 116 | batch['grounding_pos'] = batch['grounding_feat'] = batch['grounding_mask'] = np.zeros(1) 117 | else: 118 | batch['res_mask'] = np.ones((1, counts.sum()), dtype=np.float32) 119 | batch['grounding_pos'] = pos[None].astype(np.float32) 120 | batch['grounding_feat'] = np.zeros((1, args.num_blobs, 11), dtype=np.float32) 121 | batch['grounding_feat'][:,:,0] = dssp # all helix for now 122 | batch['grounding_feat'][:,:,1] = counts 123 | batch['grounding_feat'][:,:,2:] = covar.reshape(1, args.num_blobs, 9) 124 | batch['grounding_mask'] = np.ones((1, args.num_blobs), dtype=np.float32) 125 | 126 | 127 | batch = {k: torch.from_numpy(v).cuda() for k, v in batch.items()} 128 | batch['latents'] = 0 129 | if not args.multiflow: 130 | np.savez(f'{args.outdir}/blobs_{i}.npz', pos=pos, covar=covar, counts=counts, dssp=dssp) 131 | with torch.no_grad(): 132 | samples, sequences = model.inference(batch) 133 | atom37_to_pdb(samples, f'{args.outdir}/blobs_{i}.pdb') 134 | np.save(f'{args.outdir}/blobs_{i}_sequence.npy', sequences[0]) 135 | -------------------------------------------------------------------------------- /scripts/ca_only_helicity.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | import tqdm 13 | import argparse 14 | import numpy as np 15 | 16 | import biotite.structure.io.pdb as pdb 17 | import biotite.structure as struc 18 | import numpy as np 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dir', required=True) 23 | parser.add_argument('--outdir', type=str, default=None) 24 | args = parser.parse_args() 25 | 26 | 27 | 28 | 29 | if args.outdir is None: 30 | args.outdir = args.dir 31 | 32 | 33 | def calculate_secondary_structure(pdb_file): 34 | """ 35 | Calculate secondary structures using Biotite. 
36 | """ 37 | file = pdb.PDBFile.read(pdb_file) 38 | array = file.get_structure() 39 | sse = struc.annotate_sse(array[0]) 40 | 41 | helices = (sse == 'a').sum() # Helices 42 | strands = (sse == 'b').sum() # Strands 43 | coils = (sse == 'c').sum() # Coils 44 | total = len(sse) 45 | 46 | helix_percentage = helices / total if total > 0 else 0 47 | strand_percentage = strands / total if total > 0 else 0 48 | coil_percentage = coils / total if total > 0 else 0 49 | 50 | return helix_percentage, strand_percentage, coil_percentage 51 | 52 | def preprocess_pdb_files(pdb_files): 53 | """ 54 | Ensure every PDB file listed in the specified text file has valid occupancy and B-factor values otherwise will be filled with a default value 55 | """ 56 | for pdb_path in pdb_files: 57 | pdb_path = str(pdb_path) 58 | corrected_content = [] 59 | with open(pdb_path, 'r') as file: 60 | for line in file: 61 | if line.startswith(("ATOM", "HETATM")): 62 | # Occupancy is located from columns 55 to 60 In Case an Error is thrown 63 | occupancy_value = line[54:60].strip() 64 | if not occupancy_value: 65 | line = line[:54] + '1.00' + line[60:] # Insert default occupancy In Case an Error is thrown 66 | 67 | # B-factor is located from columns 61 to 66 In Case an Error is thrown 68 | b_factor_value = line[60:66].strip() 69 | if not b_factor_value: 70 | line = line[:60] + '0.00' + line[66:] # Insert default B-factor In Case an Error is thrown 71 | 72 | corrected_content.append(line) 73 | 74 | with open(pdb_path, 'w') as file: 75 | file.writelines(corrected_content) 76 | 77 | files = os.listdir(args.dir) 78 | files = sorted([f for f in files if '.pdb' in f]) 79 | 80 | helicity = [] 81 | for i, file in tqdm.tqdm(enumerate(files)): 82 | preprocess_pdb_files([f"{args.dir}/{file}"]) 83 | helix_percent, _, _ = calculate_secondary_structure(f"{args.dir}/{file}") 84 | helicity.append(helix_percent) 85 | old_dict = dict(np.load(f"{args.outdir}/res.npz")) 86 | old_dict['helicity'] = np.array(helicity) 87 | np.savez(f"{args.outdir}/res.npz", **old_dict) 88 | 89 | 90 | -------------------------------------------------------------------------------- /scripts/create_multiflowdb.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
10 | 11 | from proteinblobs.parsing import parse_train_args 12 | 13 | args = parse_train_args() 14 | 15 | import torch 16 | from omegaconf import OmegaConf 17 | import numpy as np 18 | from proteinblobs.multiflow.flow_model import FlowModel 19 | from proteinblobs.multiflow.data.interpolant import Interpolant 20 | from proteinblobs.multiflow.data import utils as du 21 | from proteinblobs.utils import atom37_to_pdb 22 | from proteinblobs.dataset import seq_collate, GenieDBDataset 23 | 24 | import copy 25 | from tqdm import tqdm 26 | 27 | cfg = OmegaConf.load("multiflow_config.yaml") 28 | cfg.model.edge_features.self_condition = ( 29 | args.self_condition 30 | ) # only the cfg.interpolant.self_condition is used anywhere in the code 31 | cfg.interpolant.self_condition = args.self_condition 32 | 33 | print("init model") 34 | args_uncond = copy.deepcopy(args) 35 | args_uncond.freeze_weights = True 36 | args_uncond.extra_attn_layer = False 37 | args_uncond.blob_attention = False 38 | model_uncond = FlowModel(cfg.model, args_uncond) 39 | ckpt = torch.load(args.pretrained_mf_path, map_location="cuda") 40 | model_uncond.load_state_dict( 41 | {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()} 42 | ) 43 | model_uncond = model_uncond.to("cuda") 44 | 45 | print("init dataset") 46 | dataset = GenieDBDataset(args) 47 | dataset[0] 48 | loader = torch.utils.data.DataLoader( 49 | dataset, 50 | batch_size=args.batch_size, 51 | num_workers=args.num_workers, 52 | collate_fn=seq_collate, 53 | shuffle=True, 54 | ) 55 | 56 | interpolant = Interpolant(cfg.interpolant, args) 57 | interpolant.set_device("cuda") 58 | for batch in tqdm(loader): 59 | true_bb_pos = None 60 | num_batch, sample_length = batch["mask"].shape 61 | prot_traj, model_traj = interpolant.sample( 62 | batch["mask"].cuda(), 63 | model_uncond, 64 | num_timesteps=args.num_timesteps, 65 | separate_t=cfg.inference.interpolant.codesign_separate_t, 66 | ) 67 | diffuse_mask = torch.ones(1, sample_length) 68 | atom37_traj = [x[0] for x in prot_traj] 69 | atom37_model_traj = [x[0] for x in model_traj] 70 | 71 | bb_trajs = du.to_numpy(torch.stack(atom37_traj, dim=0).transpose(0, 1)) 72 | noisy_traj_length = bb_trajs.shape[1] 73 | assert bb_trajs.shape == (num_batch, noisy_traj_length, sample_length, 37, 3) 74 | 75 | model_trajs = du.to_numpy(torch.stack(atom37_model_traj, dim=0).transpose(0, 1)) 76 | clean_traj_length = model_trajs.shape[1] 77 | assert model_trajs.shape == (num_batch, clean_traj_length, sample_length, 37, 3) 78 | 79 | aa_traj = [x[1] for x in prot_traj] 80 | clean_aa_traj = [x[1] for x in model_traj] 81 | 82 | aa_trajs = du.to_numpy(torch.stack(aa_traj, dim=0).transpose(0, 1).long()) 83 | assert aa_trajs.shape == (num_batch, noisy_traj_length, sample_length) 84 | for i in range(aa_trajs.shape[0]): 85 | for j in range(aa_trajs.shape[2]): 86 | if aa_trajs[i, -1, j] == du.MASK_TOKEN_INDEX: 87 | print("WARNING mask in predicted AA") 88 | aa_trajs[i, -1, j] = 0 89 | clean_aa_trajs = du.to_numpy( 90 | torch.stack(clean_aa_traj, dim=0).transpose(0, 1).long() 91 | ) 92 | assert clean_aa_trajs.shape == (num_batch, clean_traj_length, sample_length) 93 | 94 | samples = bb_trajs[:, -1] 95 | for i, sample in enumerate(samples): 96 | sample_path = f"/lustre/fsw/portfolios/nvr/projects/nvr_lpr_compgenai/hstaerk_bjing/multiflow_output/{batch['name'][i].replace('.pdb', '')}_{np.random.randint(100000)}.pdb" 97 | atom37_to_pdb(sample[None], sample_path) 98 | -------------------------------------------------------------------------------- 
/scripts/evaluate_alignment.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import os 12 | from proteinblobs import blobs as pb 13 | import esm 14 | from openfold.np import protein 15 | import tqdm 16 | import argparse 17 | import pydssp 18 | import pandas as pd 19 | import numpy as np 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--dir", required=True) 23 | args = parser.parse_args() 24 | 25 | files = os.listdir(args.dir) 26 | files = sorted([f for f in files if ".pdb" in f]) 27 | 28 | df = [] 29 | for i, pdb in tqdm.tqdm(enumerate(files)): 30 | prot = protein.from_pdb_string(open(f"{args.dir}/{pdb}").read()) 31 | blobs = dict(np.load(f"{args.dir}/{pdb.replace('pdb', 'npz')}")) 32 | 33 | pos = np.concatenate( 34 | [prot.atom_positions[:, :3, :], prot.atom_positions[:, 4:5, :]], axis=1 35 | ) # (L, 4, 3) 36 | dssp = pydssp.assign(pos, out_type="index") 37 | pos = pos[:, 1] 38 | blobs = [ 39 | { 40 | "pos": blobs["pos"][i], 41 | "covar": blobs["covar"][i], 42 | "dssp": blobs["dssp"][i], 43 | "count": blobs["counts"][i], 44 | } 45 | for i in range(len(blobs["dssp"])) 46 | ] 47 | df.append( 48 | { 49 | "name": pdb, 50 | "coverage": pb.blob_coverage(pos, dssp, blobs, structured_only=True), 51 | "misplacement": pb.blob_misplacement( 52 | pos, dssp, blobs, structured_only=True 53 | ), 54 | "accuracy": pb.blob_accuracy(pos, dssp, blobs, structured_only=True), 55 | "likelihood": pb.blob_likelihood(pos, dssp, blobs, structured_only=True), 56 | "soft_accuracy": pb.soft_blob_accuracy( 57 | pos, dssp, blobs, structured_only=True 58 | ), 59 | "reblob_jsd": pb.reblob_jsd(pos, dssp, blobs, use_dssp=True), 60 | "shannon_6.5": pb.shannon_complexity(pos, dssp, 6.5), 61 | } 62 | ) 63 | df = pd.DataFrame(df) 64 | df.set_index("name").to_csv(f"{args.dir}/alignment.csv") 65 | print(df) 66 | for colname in df.columns.tolist(): 67 | if colname != "name": 68 | print(f"{colname}: {df[colname].mean()}") 69 | -------------------------------------------------------------------------------- /scripts/evaluate_designability.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
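# evaluate_designability.py -- designability evaluation with redesigned sequences: for every
# .pdb file in --dir, ProteinMPNN (run_pmpnn) designs 8 sequences per structure, ESMFold
# refolds each design, and the aligned C-alpha RMSD to the generated backbone is recorded.
# Results are written to <dir>/designability.json. The ProteinMPNN path below is
# machine-specific and will likely need to be edited. Example invocation (assuming samples
# were written to outpdb/default):
#   python scripts/evaluate_designability.py --dir outpdb/default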
10 | 11 | 12 | import os, torch 13 | from biotite.sequence.io import fasta 14 | from proteinblobs.designability_utils import run_pmpnn, get_aligned_rmsd 15 | 16 | import esm 17 | from openfold.np import protein, residue_constants 18 | from proteinblobs.utils import atom37_to_pdb 19 | import json 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--dir", required=True) 24 | parser.add_argument("--filter_missing_res", action="store_true", default=False) 25 | parser.add_argument("--ca_only", action="store_true", default=False) 26 | args = parser.parse_args() 27 | 28 | tmpdir = f"/tmp/{os.getpid()}" 29 | os.makedirs(tmpdir, exist_ok=True) 30 | 31 | print("Loading ESMFold model for designability evaluation") 32 | esmf_model = esm.pretrained.esmfold_v1().eval() 33 | esmf_model = esmf_model.to("cuda") 34 | 35 | 36 | def designability(path, seqs_per_struct=8): 37 | sample_path = os.path.join(tmpdir, f"sample.pdb") 38 | if args.filter_missing_res: 39 | prot = protein.from_pdb_string(open(path, "r").read()) 40 | ca_idx = residue_constants.atom_order["CA"] 41 | c_idx = residue_constants.atom_order["C"] 42 | n_idx = residue_constants.atom_order["N"] 43 | present_mask = prot.atom_mask[:, [ca_idx, c_idx, n_idx]].all(-1) 44 | pos = prot.atom_positions[present_mask] 45 | atom37_to_pdb(pos[None], sample_path) 46 | else: 47 | os.system(f"cp {path} {tmpdir}/sample.pdb") 48 | 49 | run_pmpnn( 50 | pdb_dir=tmpdir, 51 | num_seqs=seqs_per_struct, 52 | pmpnn_path="/data/rbg/users/hstark/ProteinMPNN", 53 | ca_only=args.ca_only, 54 | ) 55 | mpnn_fasta_path = os.path.join( 56 | tmpdir, "seqs", os.path.basename(sample_path).replace(".pdb", ".fa") 57 | ) 58 | fasta_seqs = fasta.FastaFile.read(mpnn_fasta_path) 59 | 60 | seqs = [v for k, v in fasta_seqs.items()] 61 | seqs = seqs[ 62 | 1: 63 | ] # remove the first sequence, which is the input sequence (and AAAAAAA... if there is no input sequence to the structure) 64 | rmsds = [] 65 | for j, seq in enumerate(seqs): 66 | seq = seq.replace("X", "A") 67 | 68 | with torch.no_grad(): 69 | output = esmf_model.infer(seq) 70 | 71 | with open(sample_path) as f: 72 | prot = protein.from_pdb_string(f.read()) 73 | out_ca_pos = output["positions"][-1].squeeze()[:, 2].cpu().numpy() 74 | rmsd = get_aligned_rmsd(prot.atom_positions[:, 1], out_ca_pos) 75 | rmsds.append(rmsd) 76 | return {"rmsd": rmsds} 77 | 78 | 79 | files = os.listdir(args.dir) 80 | files = [f for f in files if ".pdb" in f] 81 | out = {} 82 | for f in files: 83 | out[f] = designability(f"{args.dir}/{f}") 84 | with open(f"{args.dir}/designability.json", "w") as f: 85 | f.write(json.dumps(out, indent=4)) 86 | -------------------------------------------------------------------------------- /scripts/evaluate_designability_codesign.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
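# evaluate_designability_codesign.py -- designability evaluation for co-design: instead of
# redesigning sequences with ProteinMPNN, it folds the sequence generated alongside each
# structure (the matching *_sequence.npy written by sample.py) with ESMFold and records the
# aligned C-alpha RMSD to the generated backbone. Results are written to
# <dir>/designability_codesign.json.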
10 | 11 | 12 | import os, torch 13 | from biotite.sequence.io import fasta 14 | import esm 15 | import numpy as np 16 | from openfold.np import protein, residue_constants 17 | from proteinblobs.designability_utils import get_aligned_rmsd 18 | from proteinblobs.utils import atom37_to_pdb 19 | import json 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--dir", required=True) 24 | parser.add_argument("--filter_missing_res", action="store_true", default=False) 25 | parser.add_argument("--ca_only", action="store_true", default=False) 26 | args = parser.parse_args() 27 | 28 | tmpdir = f"/tmp/{os.getpid()}" 29 | os.makedirs(tmpdir, exist_ok=True) 30 | 31 | print("Loading ESMFold model for designability evaluation") 32 | esmf_model = esm.pretrained.esmfold_v1().eval() 33 | esmf_model = esmf_model.to("cuda") 34 | 35 | 36 | def designability(path): 37 | sample_path = os.path.join(tmpdir, f"sample.pdb") 38 | if args.filter_missing_res: 39 | prot = protein.from_pdb_string(open(path, "r").read()) 40 | ca_idx = residue_constants.atom_order["CA"] 41 | c_idx = residue_constants.atom_order["C"] 42 | n_idx = residue_constants.atom_order["N"] 43 | present_mask = prot.atom_mask[:, [ca_idx, c_idx, n_idx]].all(-1) 44 | pos = prot.atom_positions[present_mask] 45 | atom37_to_pdb(pos[None], sample_path) 46 | else: 47 | os.system(f"cp {path} {tmpdir}/sample.pdb") 48 | 49 | seq_numeric = np.load(path.replace(".pdb", "_sequence.npy")) 50 | seq = "".join([residue_constants.restypes_with_x[r] for r in seq_numeric]) 51 | seq = seq.replace("X", "A") 52 | 53 | with torch.no_grad(): 54 | output = esmf_model.infer(seq) 55 | 56 | with open(sample_path) as f: 57 | prot = protein.from_pdb_string(f.read()) 58 | out_ca_pos = output["positions"][-1].squeeze()[:, 2].cpu().numpy() 59 | rmsd = get_aligned_rmsd(prot.atom_positions[:, 1], out_ca_pos) 60 | 61 | return {"rmsd": [rmsd]} 62 | 63 | 64 | files = os.listdir(args.dir) 65 | files = [f for f in files if ".pdb" in f] 66 | out = {} 67 | for f in files: 68 | out[f] = designability(f"{args.dir}/{f}") 69 | with open(f"{args.dir}/designability_codesign.json", "w") as f: 70 | f.write(json.dumps(out, indent=4)) 71 | -------------------------------------------------------------------------------- /scripts/preprocess_clusters.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
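# preprocess_clusters.py -- precomputes residue clusters ("blobs") for the processed Multiflow
# training pickles in --pkl_dir. Each residue is featurized by its C-alpha position, a DSSP
# one-hot scaled by --dssp_weight, and its sequence index scaled by --seq_weight; residues are
# then grouped into roughly length/--res_per_cluster clusters by spectral clustering on the
# Gaussian affinity W_ij = exp(-d_ij**2 / (2 * r**2)) with r = 2. The resulting per-protein
# cluster centers, labels, and per-cluster DSSP counts are pickled to --target_dir.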
10 | import argparse
11 | import pickle
12 | from sklearn.cluster import SpectralClustering
13 | 
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--num_workers", type=int, default=1)
16 | parser.add_argument("--worker_id", type=int, default=0)
17 | parser.add_argument("--res_per_cluster", type=int, default=20)
18 | parser.add_argument("--dssp_weight", type=float, default=3)
19 | parser.add_argument("--seq_weight", type=float, default=0.1)
20 | parser.add_argument(
21 |     "--pkl_dir", type=str, default="data/multiflow/train_set/processed_pdb"
22 | )
23 | parser.add_argument(
24 |     "--target_dir", type=str, default="data/multiflow/train_set/processed_pdb_clusters"
25 | )
26 | parser.add_argument("--flat_dir", action="store_true")
27 | 
28 | 
29 | args = parser.parse_args()
30 | 
31 | import math
32 | import os
33 | import tqdm
34 | import numpy as np
35 | import pydssp
36 | 
37 | 
38 | if args.flat_dir:
39 |     files = os.listdir(args.pkl_dir)
40 | else:
41 |     paths = os.listdir(args.pkl_dir)
42 |     files = []
43 |     for path in paths:
44 |         files.extend(os.listdir(f"{args.pkl_dir}/{path}"))
45 | print(len(files))
46 | 
47 | 
48 | def do_job(file):
49 |     if args.flat_dir:
50 |         file_path = f"{args.pkl_dir}/{file}"
51 |     else:
52 |         file_path = f"{args.pkl_dir}/{file[1:3]}/{file}"
53 |     try:
54 |         with open(file_path, "rb") as f:
55 |             prot = pickle.load(f)
56 |     except:
57 |         print("Failure", file)
58 |         return None
59 | 
60 |     mask = prot["bb_mask"].astype(bool)
61 |     atom37 = prot["atom_positions"][mask]
62 | 
63 |     bb_pos = np.concatenate([atom37[:, :3, :], atom37[:, 4:5, :]], axis=1)  # (L, 4, 3)
64 | 
65 |     try:
66 |         dssp = pydssp.assign(
67 |             bb_pos, out_type="index"
68 |         )  # 0: loop, 1: alpha-helix, 2: beta-strand
69 |     except:
70 |         print("Failure", file, bb_pos.shape)
71 |         return None
72 | 
73 |     pos = atom37[:, 1, :]
74 |     n_clusters = math.ceil(len(pos) / args.res_per_cluster)
75 | 
76 |     dssp_one_hot = dssp[:, None] == np.arange(3)
77 |     feat = np.concatenate(
78 |         [
79 |             pos,
80 |             dssp_one_hot * args.dssp_weight,
81 |             np.arange(len(dssp))[:, None] * args.seq_weight,
82 |         ],
83 |         axis=1,
84 |     )
85 | 
86 |     distmat = np.square(feat[None] - feat[:, None]).sum(-1) ** 0.5
87 |     r = 2
88 |     W = -(distmat**2) / (2 * r**2)
89 |     W = np.exp(W)
90 |     n_clusters = math.ceil(len(dssp) / args.res_per_cluster)
91 |     labels = (
92 |         SpectralClustering(n_clusters=n_clusters, affinity="precomputed").fit(W).labels_
93 |     )
94 | 
95 |     dssp_count = np.zeros((n_clusters, 3))
96 |     np.add.at(dssp_count, labels, dssp_one_hot)
97 | 
98 |     centers = np.zeros((n_clusters, 3))
99 |     np.add.at(centers, labels, pos)  # np.mean has no .at method; accumulate coordinate sums per cluster
100 |     centers /= np.maximum(np.bincount(labels, minlength=n_clusters), 1)[:, None]  # divide by cluster sizes to get centroids
101 | 
102 |     clustering = {
103 |         "centers": centers,
104 |         "labels": labels,
105 |         "dssp_count": dssp_count,
106 |     }
107 | 
108 |     return clustering
109 | 
110 | 
111 | out = {}
112 | for job in tqdm.tqdm(files):
113 |     out[job] = do_job(job)
114 | 
115 | with open(args.target_dir, "wb") as f:
116 |     f.write(pickle.dumps(out))
117 | 
--------------------------------------------------------------------------------
/scripts/unpack_geniepdb.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto.
Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | 11 | import numpy as np 12 | import argparse, tqdm 13 | from openfold.np import protein 14 | import openfold.np.residue_constants as rc 15 | import pandas as pd 16 | from multiprocessing import Pool 17 | import pydssp 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--indir', type=str, default='../afdb') 21 | parser.add_argument('--outdir', type=str, default='./afdb_npz') 22 | parser.add_argument('--outcsv', type=str, default='./afdb.csv') 23 | parser.add_argument('--num_workers', type=int, default=15) 24 | args = parser.parse_args() 25 | 26 | 27 | def main(): 28 | pdbs = list(map(lambda x: x.strip(), open('index.txt'))) 29 | if args.num_workers > 1: 30 | p = Pool(args.num_workers) 31 | p.__enter__() 32 | __map__ = p.imap 33 | else: 34 | __map__ = map 35 | infos = list(tqdm.tqdm(__map__(unpack_mmcif, pdbs), total=len(pdbs))) 36 | if args.num_workers > 1: 37 | p.__exit__(None, None, None) 38 | df = pd.DataFrame(infos).set_index('name') 39 | df.to_csv(args.outcsv) 40 | 41 | def unpack_mmcif(name): 42 | path = f"{args.indir}/{name}.pdb" 43 | 44 | with open(path, 'r') as f: 45 | prot = protein.from_pdb_string(f.read()) 46 | 47 | atom37 = prot.atom_positions 48 | # take N, CA, C, and O. The order in atom37 is N, CA, C, CB, O ... (see atom_types in residue_constants.py) 49 | bb_pos = np.concatenate([atom37[ :, :3, :], atom37[ :, 4:5, :]], axis=1) # (L, 4, 3) 50 | dssp = pydssp.assign(bb_pos, out_type='index') # 0: loop, 1: alpha-helix, 2: beta-strand 51 | 52 | data = { 53 | 'atom_positions': prot.atom_positions, 54 | 'aatype': prot.aatype, 55 | 'atom_mask': prot.atom_mask, 56 | 'b_factors': prot.b_factors, 57 | 'dssp': dssp, 58 | } 59 | np.savez(f"{args.outdir}/{name}.npz", **data) 60 | 61 | out = { 62 | 'name': name, 63 | 'seqres': "".join([rc.restypes[x] for x in prot.aatype]), 64 | 'length': len(prot.aatype), 65 | 'plddt': 100 - prot.b_factors.mean(), 66 | 'loop_frac': (dssp == 0).mean(), 67 | 'helix_frac': (dssp == 1).mean(), 68 | 'sheet_frac': (dssp == 2).mean(), 69 | } 70 | 71 | return out 72 | 73 | if __name__ == "__main__": 74 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 
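# train.py -- training entry point. Builds the GenieDB or Multiflow dataset (--dataset),
# wraps the flow model in MultiflowWrapper, and trains with PyTorch Lightning (DDP),
# checkpointing into $MODEL_DIR every --ckpt_freq epochs. Optional wandb logging is enabled
# with --wandb; --validate runs validation only. All flags are defined by parse_train_args in
# proteinblobs/parsing.py. Example invocation (a sketch; exact flag names beyond those
# referenced in this file are defined in parsing.py):
#   MODEL_DIR=workdir/my_run python train.py --dataset genie --run_name my_run --wandb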
10 | 11 | from proteinblobs.parsing import parse_train_args 12 | args = parse_train_args() 13 | from proteinblobs.logger import get_logger 14 | logger = get_logger(__name__) 15 | from proteinblobs.multiflow_wrapper import MultiflowWrapper 16 | 17 | 18 | from omegaconf import OmegaConf 19 | import torch, wandb, os 20 | from proteinblobs.dataset import SeqCollate 21 | from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary 22 | from proteinblobs.dataset import GenieDBDataset, MultiflowDataset 23 | import pytorch_lightning as pl 24 | from pytorch_lightning.utilities import rank_zero_only 25 | import hashlib 26 | import numpy as np 27 | 28 | torch.set_float32_matmul_precision('medium') 29 | 30 | torch.manual_seed(args.seed) 31 | np.random.seed(args.seed) 32 | 33 | @rank_zero_only 34 | def init_wandb(): 35 | wandb.init( 36 | entity="no-graining-mit", 37 | settings=wandb.Settings(start_method="fork"), 38 | project="proteinblobs", 39 | name=args.run_name, 40 | id=hashlib.md5(str(args).encode("utf-8")).hexdigest(), 41 | resume='allow', # https://docs.wandb.ai/ref/python/init ----> "allow": if id is set with init(id="UNIQUE_ID") or WANDB_RUN_ID="UNIQUE_ID" and it is identical to a previous run, wandb will automatically resume the run with that id. 42 | config=args, 43 | ) 44 | 45 | if args.wandb: 46 | init_wandb() 47 | 48 | args.__dict__.update({ 49 | 'multiflow_yaml': "multiflow_config.yaml" 50 | }) 51 | cfg = OmegaConf.load(args.multiflow_yaml) 52 | 53 | val_len = 1000 54 | if args.dataset == 'genie': 55 | dataset = GenieDBDataset(args) 56 | trainset, valset = torch.utils.data.random_split(dataset, [len(dataset) - val_len, val_len]) 57 | elif args.dataset == 'multiflow': 58 | trainset = MultiflowDataset(args, cfg.pdb_dataset) 59 | valset = MultiflowDataset(args, cfg.pdb_post2021_dataset) 60 | trainset[0] 61 | valset[0] 62 | print('len trainset', len(trainset)) 63 | print('len valset', len(valset)) 64 | 65 | 66 | 67 | train_loader = torch.utils.data.DataLoader( 68 | trainset, 69 | batch_size=args.batch_size, 70 | num_workers=args.num_workers, 71 | collate_fn=SeqCollate(args), 72 | shuffle=True, 73 | ) 74 | val_loader = torch.utils.data.DataLoader( 75 | valset, 76 | batch_size=args.batch_size, 77 | collate_fn=SeqCollate(args), 78 | num_workers=args.num_workers, 79 | shuffle=False 80 | ) 81 | 82 | model = MultiflowWrapper(args, cfg) 83 | 84 | trainer = pl.Trainer( 85 | accelerator="gpu" if torch.cuda.is_available() else 'auto', 86 | strategy='ddp', 87 | max_epochs=args.epochs, 88 | limit_train_batches=args.train_batches or 1.0, 89 | limit_val_batches=0.0 if args.no_validate else (args.val_batches or 1.0), 90 | num_sanity_val_steps=0, 91 | enable_progress_bar=not args.wandb, 92 | gradient_clip_val=args.grad_clip, 93 | default_root_dir=os.environ["MODEL_DIR"], 94 | callbacks=[ 95 | ModelCheckpoint( 96 | dirpath=os.environ["MODEL_DIR"], 97 | save_top_k=-1, 98 | every_n_epochs=args.ckpt_freq, 99 | ), 100 | ModelSummary(max_depth=2), 101 | ], 102 | accumulate_grad_batches=args.accumulate_grad, 103 | val_check_interval=args.val_freq, 104 | check_val_every_n_epoch=args.val_epoch_freq, 105 | logger=False 106 | ) 107 | 108 | 109 | if args.ckpt is not None: 110 | # if there is an hpc checkpoint in the current workdir, then we use that to resume instead of the file in the args.ckpt 111 | if any(['hpc' in file for file in os.listdir(os.environ['MODEL_DIR'])]): 112 | ckpt_path = 'hpc' 113 | else: 114 | ckpt_path = args.ckpt 115 | else: 116 | ckpt_path = None 117 | 118 | if args.validate: 119 | 
trainer.validate(model, val_loader, ckpt_path=ckpt_path) 120 | else: 121 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path) 122 | --------------------------------------------------------------------------------