├── .gitignore ├── README.md ├── config ├── best.yaml ├── train_coords_only.yaml ├── train_no_position.yaml ├── train_no_self_cond.yaml ├── train_rotations_linear.yaml ├── train_rotations_quadratic.yaml ├── train_translations_linear.yaml ├── train_translations_logarithmic.yaml └── train_translations_sigmoid.yaml ├── envs ├── environment.yml └── pip_requirements.txt ├── figures └── 3ulu_ensemble.png ├── loopgen ├── __init__.py ├── __main__.py ├── data.py ├── distributions.py ├── graph.py ├── model │ ├── __init__.py │ ├── datamodule.py │ ├── generate.py │ ├── metrics.py │ ├── model.py │ ├── network.py │ ├── settings.py │ ├── train.py │ ├── types.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── diffusion.py │ └── gvp.py ├── structure.py ├── utils.py └── visualisation.py ├── pytest.ini ├── setup.py ├── setup_env.sh └── tests ├── __init__.py ├── nn ├── __init__.py └── test_gvp.py ├── test_data.py ├── test_distributions.py ├── test_graph.py ├── test_structure.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | .idea/ 4 | envs/git_access_token* 5 | .DS_Store/ 6 | .DS_Store 7 | .pytest_cache 8 | data 9 | .githooks 10 | 11 | *.log 12 | 13 | 14 | lightning_logs/ 15 | mlruns/ 16 | notebooks/ 17 | 18 | slurm* 19 | *.out 20 | *.pkl 21 | 22 | *.egg-info/ 23 | build/ 24 | 25 | notebooks -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LoopGen: De novo design of antibody CDR loops with SE(3) diffusion models 2 | 3 | LoopGen is a python package providing functionality for diffusion models for CDR binding loop design. 4 | Read more in our [paper](https://arxiv.org/abs/2310.07051). 
5 | 6 | Our model performs diffusion over the space of residue orientations and positions, generating 7 | diverse loop conformations that fit the geometry of the target epitope. Currently, the model does 8 | not generate sequences for the CDR loops - only backbone conformations. We hope to extend our model 9 | to generate sequences in the future. 10 | 11 | ![figure](figures/3ulu_ensemble.png) 12 | 13 | ## Setting up 14 | 15 | --- 16 | 17 | Follow the steps below to set up the virtual environment for this project. 18 | 19 | If you have not installed miniconda on your computer, follow the instructions here to 20 | [install miniconda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html). 21 | 22 | Once you have miniconda, you need to install `mamba` in the base environment: 23 | 24 | ``` 25 | conda install mamba -n base -c conda-forge 26 | ``` 27 | 28 | Then all you need to do is clone the repository and move into it: 29 | 30 | ```bash 31 | git clone https://gitlab.developers.cam.ac.uk/ch/sormanni/loopgen.git 32 | cd loopgen 33 | ``` 34 | 35 | And then run the setup script: 36 | ```bash 37 | bash setup_env.sh 38 | ``` 39 | 40 | And finally activate the conda environment: 41 | 42 | ```bash 43 | conda activate loopgen 44 | ``` 45 | 46 | And you are ready to go! 47 | 48 | ## Usage 49 | 50 | The basic structure of the command-line interface is as follows: 51 | 52 | ``` 53 | loopgen <model> <command> /path/to/data/file [options] 54 | ``` 55 | 56 | Where `<model>` can be either `frames` (diffusion over SE3) or `coords` (diffusion over R3). 57 | For `<command>`, users can select either `train` (train a model) or `generate` (generate from a model). 
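The two-level command structure described above (a model, then a command, then a data path) can be sketched with nested `argparse` subparsers. This is a minimal illustration under stated assumptions, not LoopGen's actual parser: the `build_parser` helper, the `dest` names `model` and `command`, and the `data_path` positional are hypothetical names chosen for the example.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Builds a toy parser of the form `loopgen <model> <command> <data_path>`."""
    parser = argparse.ArgumentParser(prog="loopgen")

    # First level: which model to use (the dest name "model" is an assumption).
    model_parsers = parser.add_subparsers(dest="model", required=True)

    for model_name in ("frames", "coords"):
        model_parser = model_parsers.add_parser(model_name)

        # Second level: which command to run for that model.
        command_parsers = model_parser.add_subparsers(dest="command", required=True)

        for command_name in ("train", "generate"):
            command_parser = command_parsers.add_parser(command_name)
            # Hypothetical positional argument standing in for the data file.
            command_parser.add_argument("data_path")

    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["frames", "generate", "/path/to/pdb/file"])
print(args.model, args.command, args.data_path)
```

Marking both subparser levels as `required=True` makes the toy parser exit with a usage error when the model or command is omitted.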
58 | 59 | To generate binding loop structures for an input epitope, you can use the `generate` command: 60 | 61 | ``` 62 | loopgen frames generate /path/to/pdb/file --checkpoint /path/to/weights.ckpt --config config/best.yaml 63 | ``` 64 | 65 | To ensure reasonable performance, we recommend using a reduced PDB file containing only a subset of residues 66 | to be targeted by the generated CDR loops. Note that the CDR will be generated with centre of mass at the 67 | origin (0, 0, 0) of the coordinate system in the PDB file, so the epitope should be transformed so that it 68 | is at an appropriate distance and orientation relative to the CDR. We recommend placing the epitope centre of mass 10-12 69 | angstroms from the origin. Users can also input an HDF5 file (see below) for generation. 70 | 71 | Users can also pass the flags `--permute_epitopes`, `--scramble_epitopes`, `--translate_cdrs` 72 | to run generation under different perturbations of the input epitope. We use these perturbations to evaluate 73 | the dependence of generated structures on different features of the target protein. See 74 | [the paper](https://arxiv.org/abs/2310.07051) for more details. 75 | 76 | ## Training 77 | 78 | To train a frame diffusion model, run: 79 | 80 | ``` 81 | loopgen frames train /path/to/hdf5/file --splits /path/to/json/file --config /path/to/yaml/file 82 | ``` 83 | 84 | You can see all the relevant options for each command by running `loopgen --help`. 85 | 86 | You can access our datasets (including generated structures), training splits, 87 | and trained weights [here](https://drive.google.com/drive/folders/1cxJV5MnMBTl8VjqkfIo4EsRCSDLHWh1B?usp=drive_link). 88 | 89 | We use `mlflow` for our logging and experiment tracking, which is installed as part of the virtual environment. 
To use 90 | `mlflow` interactively, run the following from the virtual environment on the machine where you're training: 91 | 92 | `mlflow ui --host 0.0.0.0 --port <port>` 93 | 94 | This deploys an mlflow server on port `<port>`. Then use the browser on your local machine to access: 95 | 96 | `<host>:<port>` 97 | 98 | ### Dataset format 99 | 100 | Training relies on CDR/epitope structural information. We have found that 101 | these data types are best stored using `hdf5` format. The basic structure we choose 102 | to use is as follows: 103 | 104 | ``` 105 | - 106 | - 107 | - "receptor" 108 | - "N_coords" (N, 3) array of nitrogen coordinates for N receptor residues 109 | - "CA_coords" (N, 3) array of CA coordinates for N receptor residues 110 | - "C_coords" (N, 3) array of C coordinates for N receptor residues 111 | - "CB_coords" (N, 3) array of CB coordinates for N receptor residues (can be anything for glycines) 112 | - "sequence" (N,) array of integers 0-19, based on sorted 3-letter AA codes 113 | - "ligand" 114 | - "N_coords" (M, 3) array of nitrogen coordinates for M ligand residues 115 | - "CA_coords" (M, 3) array of CA coordinates for M ligand residues 116 | - "C_coords" (M, 3) array of C coordinates for M ligand residues 117 | - "CB_coords" (M, 3) array of CB coordinates for M ligand residues (can be anything for glycines) 118 | - "sequence" (M,) array of integers 0-19, based on sorted 3-letter AA codes 119 | ``` 120 | 121 | However, our pipeline can handle any format so long as the key/value structure 122 | from `receptor`/`ligand` and below is consistent. In our case, `receptor` refers to the 123 | epitope and `ligand` refers to the CDR loop. 124 | 125 | ### Config 126 | 127 | The config file is a YAML file containing all the hyperparameters for training. 128 | The possible options (for both `frames` and `coords` models) are as follows: 129 | 130 | - `learning_rate`: Learning rate used for training (default: `1e-4`) 131 | - `batch_size`: Batch size used for training. 
The default of 128 requires about 30 GB of GPU memory. (default: `128`) 132 | - `self_conditioning_rate`: Rate at which self-conditioning is applied during training. Implementation 133 | is the same as RFdiffusion (default: `0.5`) 134 | - `num_time_steps`: Number of time steps to use for the diffusion discretisation (default: `100`). 135 | - `min_trans_beta`: Minimum value of the coefficients used to calculate variances for the noised translations (default: `1e-4`) 136 | - `max_trans_beta`: Maximum value of the coefficients used to calculate variances for the noised translations (default: `20.0`) 137 | - `translation_beta_schedule`: Schedule to use for the coefficients used to calculate variances for the noised translations. (default: `linear`) 138 | - `time_step_encoding_channels`: Number of features to use for the sinusoidal time step encoding. (default: `5`) 139 | - `use_cdr_positional_encoding`: Whether to use a positional encoding for ligand (CDR) residues. This should 140 | only be used if ligands are linear polypeptide fragments, and it assumes the coordinate arrays in the input 141 | HDF5 file are N-to-C ordered. (default: `True`) 142 | - `positional_encoding_channels`: Number of features to use for the positional encoding. (default: `5`) 143 | - `hidden_scalar_channels`: Dimensionality of the hidden layer node scalar features in the GVP-GNN. (default: `128`) 144 | - `hidden_vector_channels`: Dimensionality of the hidden layer node vector features in the GVP-GNN. (default: `64`) 145 | - `hidden_edge_scalar_channels`: Dimensionality of the hidden layer edge scalar features in the GVP-GNN. (default: `64`) 146 | - `hidden_edge_vector_channels`: Dimensionality of the hidden layer edge vector features in the GVP-GNN. (default: `32`) 147 | - `num_layers`: Number of message passing layers in the GVP-GNN. (default: `3`) 148 | - `dropout`: Dropout rate to use in the GVP-GNN. (default: `0.2`) 149 | - `aggr`: Message aggregation method to use in the GVP-GNN. 
(default: `sum`) 150 | 151 | For the `frames` models, there are a few additional parameters: 152 | - `rotation_beta_schedule`: Schedule to use for the coefficients used to calculate variances for the noised rotations. (default: `logarithmic`) 153 | - `min_rot_beta`: Minimum value of the coefficients used to calculate variances for the noised rotations (default: `0.1`) 154 | - `max_rot_beta`: Maximum value of the coefficients used to calculate variances for the noised rotations (default: `1.5`) 155 | - `igso3_support_n`: Number of terms used to discretise angles of rotation when sampling from IGSO3. (default: `2000`) 156 | - `igso3_expansion_n`: Number of terms used to approximate the infinite sum in the IGSO3 density function (default: `2000`) 157 | 158 | All of these parameters can be passed in an input config YAML file to the `train` command. When using `generate` 159 | with a trained model, the same config file should be passed as an argument. 160 | -------------------------------------------------------------------------------- /config/best.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "translations_quadratic" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 # only needed if training 11 | rotation_beta_schedule: "logarithmic" 12 | translation_beta_schedule: "quadratic" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.5 24 | add_pad_cdr_features: True 25 | 26 | # Output Paths 27 | out_dir: "translations_quadratic" 28 | -------------------------------------------------------------------------------- 
/config/train_coords_only.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "coords_only" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 11 | translation_beta_schedule: "linear" 12 | num_time_steps: 100 13 | min_trans_beta: 0.0001 14 | max_trans_beta: 20 15 | weight_loss_by_norm: True 16 | 17 | # DataModule Parameters 18 | fix_cdr_centre: True 19 | self_conditioning_rate: 0.5 20 | add_pad_cdr_features: True 21 | 22 | # Output Paths 23 | out_dir: "./coords_only" 24 | -------------------------------------------------------------------------------- /config/train_no_position.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "position_ablation" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 # only needed if training 11 | rotation_beta_schedule: "logarithmic" 12 | translation_beta_schedule: "linear" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_cdr_positional_encoding: False 20 | use_igso3_cache: False 21 | 22 | # DataModule Parameters 23 | fix_cdr_centre: True 24 | self_conditioning_rate: 0.5 25 | add_pad_cdr_features: True 26 | 27 | #Output Paths 28 | out_dir: "./position_ablation" 29 | -------------------------------------------------------------------------------- /config/train_no_self_cond.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "no_self_conditioning" 5 | 
checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 11 | rotation_beta_schedule: "logarithmic" 12 | translation_beta_schedule: "linear" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.0 24 | add_pad_cdr_features: False 25 | 26 | # Output Paths 27 | out_dir: "./no_self_conditioning" 28 | -------------------------------------------------------------------------------- /config/train_rotations_linear.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "rotation_linear" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 # only needed if training 11 | rotation_beta_schedule: "linear" 12 | translation_beta_schedule: "linear" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.5 24 | add_pad_cdr_features: True 25 | 26 | # Output Paths 27 | out_dir: "./rotation_linear" 28 | -------------------------------------------------------------------------------- /config/train_rotations_quadratic.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "rotation_quadratic" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 # only needed 
if training 11 | rotation_beta_schedule: "quadratic" 12 | translation_beta_schedule: "linear" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.5 24 | add_pad_cdr_features: True 25 | 26 | # Output Paths 27 | out_dir: "./rotation_quadratic" 28 | -------------------------------------------------------------------------------- /config/train_translations_linear.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 50 4 | run_name: "default" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 11 | rotation_beta_schedule: "logarithmic" 12 | translation_beta_schedule: "linear" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.5 24 | add_pad_cdr_features: True 25 | 26 | # Output Paths 27 | out_dir: "./default" 28 | -------------------------------------------------------------------------------- /config/train_translations_logarithmic.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "translations_logarithmic" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 11 | rotation_beta_schedule: "logarithmic" 12 | translation_beta_schedule: "logarithmic" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 
16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.5 24 | add_pad_cdr_features: True 25 | 26 | # Output Paths 27 | out_dir: "translations_logarithmic" 28 | -------------------------------------------------------------------------------- /config/train_translations_sigmoid.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Parameters 2 | experiment_name: "Loopgen training" 3 | steps_per_log: 5 4 | run_name: "translations_sigmoid" 5 | checkpoint_metric: "validation_loss" 6 | batch_size: 128 7 | max_epochs: 250 8 | 9 | # Model Parameters 10 | learning_rate: 0.0001 # only needed if training 11 | rotation_beta_schedule: "logarithmic" 12 | translation_beta_schedule: "sigmoid" 13 | num_time_steps: 100 14 | min_rot_beta: 0.1 15 | max_rot_beta: 1.5 16 | min_trans_beta: 0.0001 17 | max_trans_beta: 20 18 | weight_loss_by_norm: True 19 | use_igso3_cache: False 20 | 21 | # DataModule Parameters 22 | fix_cdr_centre: True 23 | self_conditioning_rate: 0.5 24 | add_pad_cdr_features: True 25 | 26 | #Output Paths 27 | out_dir: "translations_sigmoid" 28 | -------------------------------------------------------------------------------- /envs/environment.yml: -------------------------------------------------------------------------------- 1 | name: loopgen 2 | channels: 3 | - conda-forge 4 | - anaconda 5 | - pytorch 6 | - defaults 7 | - bioconda 8 | dependencies: 9 | - python=3.11 10 | - biopython=1.81 11 | - scipy 12 | - pyyaml=6.0 13 | - scikit-learn=1.3.0 14 | - pandas=2.0.3 15 | - mlflow=2.5.0 16 | - numpy 17 | - matplotlib=3.7.2 18 | - seaborn=0.12.2 19 | - plotly=5.16.1 20 | - jupyter 21 | - pip 22 | - einops=0.6.1 23 | - pytest=7.4.0 24 | - urllib3=1.26.7 25 | 26 | -------------------------------------------------------------------------------- /envs/pip_requirements.txt: 
-------------------------------------------------------------------------------- 1 | h5py==3.9.0 2 | torch-geometric==2.3.1 3 | pytorch-lightning==2.0.1 4 | e3nn==0.5.1 -------------------------------------------------------------------------------- /figures/3ulu_ensemble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mgreenig/loopgen/e2726c8f24e84fdfb824f3616205a3d6b0d9703b/figures/3ulu_ensemble.png -------------------------------------------------------------------------------- /loopgen/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import ( 2 | ScalarFeatureStructureData, 3 | VectorFeatureStructureData, 4 | ScalarFeatureComplexData, 5 | VectorFeatureComplexData, 6 | ) 7 | from .structure import ( 8 | OrientationFrames, 9 | Structure, 10 | LinearStructure, 11 | AminoAcid3, 12 | AminoAcid1, 13 | ) 14 | from .utils import * 15 | from .data import ReceptorLigandDataset 16 | 17 | from . import nn 18 | from . import model 19 | -------------------------------------------------------------------------------- /loopgen/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line interface for loopgen. 3 | """ 4 | 5 | import argparse 6 | import warnings 7 | 8 | from functools import partial 9 | 10 | from loopgen.model import CDRCoordinateDiffusionModel, CDRFrameDiffusionModel 11 | from loopgen.model.train import add_train_args, train_from_args 12 | from loopgen.model.generate import add_generate_args, generate_from_args 13 | 14 | 15 | USAGE = """ 16 | %(prog)s [options] 17 | 18 | LoopGen: De novo design of peptide CDR binding loops with SE(3) diffusion models. 19 | 20 | Currently, loopgen supports two models: 21 | - frames: Diffusion over the SE(3), i.e. a 3D rotation and translation for each residue. 22 | - coords: Diffusion over R3, i.e. a 3D translation for each residue. 
23 | 24 | Each model supports two commands: 25 | - train: Trains a new or saved model on a dataset in HDF5 format. 26 | - generate: Generates new structures using a saved model. 27 | 28 | """ 29 | 30 | 31 | def main(): 32 | """ 33 | Runs the loopgen program based on the provided commands. 34 | """ 35 | 36 | # filter pytorch lightning warning about number of workers in dataloader 37 | warnings.filterwarnings("ignore", ".*does not have many workers.*") 38 | 39 | parser = argparse.ArgumentParser( 40 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 41 | usage=USAGE, 42 | ) 43 | 44 | subparser = parser.add_subparsers(title="commands") 45 | 46 | frames_parser = subparser.add_parser( 47 | "frames", 48 | description="Loop diffusion over SE(3), modelling a 3D rotation (orientation) " 49 | "and translation (CA coordinate) for each residue.", 50 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 51 | ) 52 | 53 | frames_subparser = frames_parser.add_subparsers(title="commands") 54 | 55 | coords_parser = subparser.add_parser( 56 | "coords", 57 | description="Loop diffusion over R3, modelling a 3D translation (CA coordinate) for each residue.", 58 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 59 | ) 60 | 61 | coords_subparser = coords_parser.add_subparsers(title="commands") 62 | 63 | for subp, model in zip( 64 | [frames_subparser, coords_subparser], 65 | [CDRFrameDiffusionModel, CDRCoordinateDiffusionModel], 66 | ): 67 | # Adds the parsers for `train` and `generate` to the parser for each model 68 | train_parser = subp.add_parser( 69 | "train", 70 | description="Train a model specified by a config YAML file.", 71 | usage="train [options]", 72 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 73 | ) 74 | 75 | add_train_args(train_parser) 76 | 77 | train_parser.set_defaults(func=partial(train_from_args, model_class=model)) 78 | 79 | gen_parser = subp.add_parser( 80 | "generate", 81 | description="Generate new structures using a trained 
model.", 82 | usage="generate [options]", 83 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 84 | ) 85 | 86 | add_generate_args(gen_parser) 87 | 88 | gen_parser.set_defaults(func=partial(generate_from_args, model_class=model)) 89 | 90 | args = parser.parse_args() 91 | args.func(args) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /loopgen/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the base dataset class for receptor/ligand structure pairs 3 | and the base datamodule class that organises structure pairs into 4 | train/test/validation splits. 5 | """ 6 | 7 | from __future__ import annotations 8 | from typing import ( 9 | Any, 10 | List, 11 | Tuple, 12 | TypedDict, 13 | Sequence, 14 | Optional, 15 | Hashable, 16 | Union, 17 | Dict, 18 | Callable, 19 | Generator, 20 | Type, 21 | Set, 22 | ) 23 | from abc import ABC, abstractmethod 24 | import json 25 | 26 | import torch 27 | from torch.utils.data import Dataset 28 | from pytorch_lightning import LightningDataModule 29 | from sklearn.model_selection import train_test_split as sklearn_split 30 | import numpy as np 31 | import h5py 32 | from h5py import Group as h5pyGroup 33 | from h5py import File as h5pyFile 34 | 35 | from .structure import Structure, LinearStructure 36 | 37 | 38 | class StructureDict(TypedDict): 39 | 40 | """ 41 | Type for a dictionary-like object containing keys for all the relevant information needed to 42 | create a Structure. 
43 | """ 44 | 45 | sequence: np.ndarray # sequence represented as integers 0-19 in alphabetical order of 3-letter AA codes 46 | N_coords: np.ndarray 47 | CA_coords: np.ndarray 48 | C_coords: np.ndarray 49 | CB_coords: np.ndarray 50 | 51 | 52 | def structure_dict_to_structure( 53 | frag_dict: StructureDict, 54 | device: torch.device, 55 | structure_class: Type[Structure] = Structure, 56 | ) -> Union[Structure, LinearStructure]: 57 | """ 58 | Creates a structure object (with the class `structure_class`) from a 59 | StructureDict object. 60 | """ 61 | 62 | return structure_class( 63 | N_coords=torch.as_tensor( 64 | frag_dict["N_coords"][:], dtype=torch.float32, device=device 65 | ), 66 | CA_coords=torch.as_tensor( 67 | frag_dict["CA_coords"][:], dtype=torch.float32, device=device 68 | ), 69 | C_coords=torch.as_tensor( 70 | frag_dict["C_coords"][:], dtype=torch.float32, device=device 71 | ), 72 | CB_coords=torch.as_tensor( 73 | frag_dict["CB_coords"][:], dtype=torch.float32, device=device 74 | ), 75 | sequence=torch.as_tensor( 76 | frag_dict["sequence"][:], dtype=torch.int64, device=device 77 | ), 78 | ) 79 | 80 | 81 | class ReceptorLigandPair(TypedDict): 82 | """Type for a dictionary-like object containing a pair of examples with a name.""" 83 | 84 | receptor: StructureDict 85 | ligand: StructureDict 86 | name: str 87 | 88 | 89 | def hdf5_group_generator( 90 | group: Union[h5pyGroup, h5pyFile], 91 | predicate: Callable[[Union[h5pyGroup, h5pyFile]], bool], 92 | ) -> Generator: 93 | """ 94 | Recursively traverses through an HDF5 structure via a 95 | depth-first search, evaluating a predicate at each group. 96 | When a group that yields a value of True is 97 | identified, yields the group. 
98 | """ 99 | if isinstance(group, h5py.Dataset): 100 | return 101 | 102 | if predicate(group) is True: 103 | yield group 104 | else: 105 | for key in group: 106 | yield from hdf5_group_generator(group[key], predicate) 107 | 108 | 109 | class ReceptorLigandDataset(Dataset): 110 | 111 | """ 112 | Class for storing a dataset of receptor/ligand structure pairs. 113 | 114 | The class stores a list of receptor/ligand fragment pairs, 115 | where each pair is a mapping-like ReceptorLigandPair object 116 | that contains a StructureDict for the receptor and the ligand, 117 | each of which contains keys for sequence (stored as integers), 118 | and N, CA, C, and CB coordinates. 119 | 120 | Most commonly, instances of this class are initialised via the `from_hdf5_file()` class method. 121 | 122 | Indexing an instance of the class returns a 3-tuple, the first element of 123 | which is the name of the complex, the second element of which is the 124 | receptor Structure, and the final element of which is the ligand Structure. 
125 | """ 126 | 127 | def __init__( 128 | self, 129 | structure_pairs: List[ReceptorLigandPair], 130 | device: torch.device, 131 | receptor_structure_class: Type[Structure] = Structure, 132 | ligand_structure_class: Type[Structure] = Structure, 133 | ): 134 | self._structure_pairs = structure_pairs 135 | self._structure_pairs_by_name = {pair["name"]: pair for pair in structure_pairs} 136 | self.device = device 137 | 138 | self.receptor_structure_class = receptor_structure_class 139 | self.ligand_structure_class = ligand_structure_class 140 | 141 | @property 142 | def structure_pairs(self) -> List[ReceptorLigandPair]: 143 | """The underlying receptor/ligand structure pairs.""" 144 | return self._structure_pairs 145 | 146 | @property 147 | def structure_pairs_by_name(self) -> Dict[str, ReceptorLigandPair]: 148 | """The underlying receptor/ligand structure pairs, stored in a dictionary indexed by name.""" 149 | return self._structure_pairs_by_name 150 | 151 | def __len__(self) -> int: 152 | return len(self._structure_pairs) 153 | 154 | def __contains__(self, name: str) -> bool: 155 | """Checks whether a structure pair with the given name exists in the dataset.""" 156 | return name in self._structure_pairs_by_name 157 | 158 | def pair_to_structures( 159 | self, frag_pair: ReceptorLigandPair 160 | ) -> Tuple[str, Structure, LinearStructure]: 161 | """ 162 | Converts a fragment pair into a 3-tuple, the first element of which is the name 163 | of the fragment pair, the second element of which is the antigen Structure, and 164 | the third element of which is the CDR Structure. 
165 | """ 166 | 167 | name = frag_pair["name"] 168 | 169 | receptor_dict = frag_pair["receptor"] 170 | receptor_structure = structure_dict_to_structure( 171 | receptor_dict, self.device, self.receptor_structure_class 172 | ) 173 | 174 | ligand_dict = frag_pair["ligand"] 175 | ligand_structure = structure_dict_to_structure( 176 | ligand_dict, self.device, self.ligand_structure_class 177 | ) 178 | 179 | return name, receptor_structure, ligand_structure 180 | 181 | def __getitem__(self, idx: int) -> Tuple[str, Structure, LinearStructure]: 182 | frag_pair = self._structure_pairs[idx] 183 | return self.pair_to_structures(frag_pair) 184 | 185 | def __repr__(self): 186 | return f"{self.__class__.__name__}(len={len(self)})" 187 | 188 | def sample(self, num: int = 1) -> List[Tuple[str, Structure, LinearStructure]]: 189 | """Samples some number of receptor/ligand complexes from the dataset.""" 190 | samples = np.random.choice(len(self._structure_pairs), num) 191 | return [self[idx] for idx in samples] 192 | 193 | @classmethod 194 | def from_hdf5_file( 195 | cls, 196 | hdf5_file: Union[h5py.File, str], 197 | device: torch.device = torch.device("cpu"), 198 | receptor_key: str = "receptor", 199 | ligand_key: str = "ligand", 200 | receptor_structure_class: Type[Structure] = Structure, 201 | ligand_structure_class: Type[Structure] = Structure, 202 | ): 203 | """ 204 | Takes in a hdf5 file containing keys for fragment pairs stored 205 | under keys for proteins, and simply returns a class instance initialised with all the pairs. 206 | 207 | First searches for the receptor key in groups in file - once a group is found 208 | with the receptor key, the group is checked for the ligand key. If the ligand key 209 | is not in the group, the group is descended until the ligand key is found. 210 | If ligand CDR keys are found within subgroups of the receptor group, they 211 | are all added to the dataset. 
212 | 213 | :param hdf5_file: The hdf5 file containing the fragment pairs, passed either as a string filepath 214 | or as an h5py.File object. 215 | :param device: The device on which to store the structures. 216 | :param receptor_key: The key under which receptor data is stored. The keys "N_coords", "CA_coords", 217 | "C_coords", "CB_coords", and "sequence" must be present under this key. (default: "receptor") 218 | :param ligand_key: The key under which ligand data are stored. The keys "N_coords", "CA_coords", 219 | "C_coords", "CB_coords", and "sequence" must be present under this key. (default: "ligand") 220 | :param receptor_structure_class: The class to be used to represent a receptor structure. This can 221 | be any subclass of Structure, and can be changed after initialisation. (default: Structure) 222 | :param ligand_structure_class: The class to be used to represent a ligand structure. This can 223 | be any subclass of Structure, and can be changed after initialisation. (default: Structure) 224 | :returns: Dataset with receptor/ligand pairs read from the file. 
225 | """ 226 | 227 | if isinstance(hdf5_file, str): 228 | hdf5_file = h5py.File(hdf5_file) 229 | 230 | structure_pairs = [] 231 | group_has_receptor = lambda grp: receptor_key in grp 232 | group_has_ligand = lambda grp: ligand_key in grp 233 | 234 | # loop through groups with the receptor key 235 | for receptor_group in hdf5_group_generator(hdf5_file, group_has_receptor): 236 | # if ligand in the receptor group, make a receptor/ligand pair 237 | if ligand_key in receptor_group: 238 | pair_dict = ReceptorLigandPair( 239 | name=receptor_group.name, 240 | receptor=receptor_group[receptor_key], 241 | ligand=receptor_group[ligand_key], 242 | ) 243 | structure_pairs.append(pair_dict) 244 | # otherwise look for the CDR key in the lower levels of the group 245 | else: 246 | for ligand_group in hdf5_group_generator( 247 | receptor_group, group_has_ligand 248 | ): 249 | pair_dict = ReceptorLigandPair( 250 | name=ligand_group.name, 251 | receptor=receptor_group[receptor_key], 252 | ligand=ligand_group[ligand_key], 253 | ) 254 | structure_pairs.append(pair_dict) 255 | 256 | return cls( 257 | structure_pairs, device, receptor_structure_class, ligand_structure_class 258 | ) 259 | 260 | def subset_by_name(self, names: Set[str]): 261 | """ 262 | Returns a new dataset with the subset of structure pairs whose names appear 263 | in the input set of names. 264 | 265 | :param names: A set of names of structure pairs to keep. These are checked with the string 266 | stored under the key "name" in each ReceptorLigandPair object in structure_pairs. 267 | :returns: A new ReceptorLigandDataset object with the subset of structure pairs. 
268 | """ 269 | 270 | name_set = set(names) 271 | kept_pairs = [] 272 | for pair in self._structure_pairs: 273 | if pair["name"] in name_set: 274 | kept_pairs.append(pair) 275 | 276 | return self.__class__(kept_pairs, self.device, self.receptor_structure_class, self.ligand_structure_class) 277 | 278 | def train_test_split( 279 | self, 280 | train_prop: float = 0.8, 281 | random_state: int = 123, 282 | by: Optional[Sequence[Hashable]] = None, 283 | *args: Any, 284 | **kwargs: Any, 285 | ) -> Tuple[ReceptorLigandDataset, ReceptorLigandDataset]: 286 | """ 287 | Performs a train/test split with a proportion 288 | `train_prop` of samples kept in the training set, and 289 | the remainder allocated to the test set. Users can 290 | specify the `by` argument to provide a sequence 291 | of labels - the same length as the dataset - 292 | which should be used to determine the train/test split, 293 | so that no label appears in both the test and the train set. 294 | If `by` is specified, `train_prop` will refer to the proportion 295 | of **unique** labels allocated to the train set. 296 | """ 297 | 298 | if train_prop == 1.0: 299 | train_frags = self._structure_pairs 300 | test_frags = [] 301 | elif train_prop == 0.0: 302 | train_frags = [] 303 | test_frags = self._structure_pairs 304 | # split by labels if provided 305 | elif by is not None: 306 | if len(by) != len(self): 307 | raise ValueError( 308 | f"Length of by ({len(by)}) does not match length of dataset ({len(self)})." 
309 | ) 310 | 311 | labels = np.unique(by) 312 | train_labels, test_labels = sklearn_split( 313 | labels, train_size=train_prop, random_state=random_state 314 | ) 315 | train_label_set = set(train_labels) 316 | test_label_set = set(test_labels) 317 | 318 | train_frags = [] 319 | test_frags = [] 320 | for lab, frag_pair in zip(by, self._structure_pairs): 321 | if lab in train_label_set: 322 | train_frags.append(frag_pair) 323 | elif lab in test_label_set: 324 | test_frags.append(frag_pair) 325 | 326 | # otherwise just split the list of fragment pairs 327 | else: 328 | train_frags, test_frags = sklearn_split( 329 | self._structure_pairs, 330 | train_size=train_prop, 331 | random_state=random_state, 332 | ) 333 | 334 | # make class instances for both the train and test datasets 335 | train_dataset = self.__class__(train_frags, device=self.device, *args, **kwargs) 336 | test_dataset = self.__class__(test_frags, device=self.device, *args, **kwargs) 337 | 338 | return train_dataset, test_dataset 339 | 340 | 341 | def load_splits_file( 342 | filepath: str, dataset: ReceptorLigandDataset 343 | ) -> Dict[str, Set[str]]: 344 | """ 345 | Validates a JSON file containing a dictionary of train/test/validation splits against a dataset, 346 | ensuring that all the required keys are provided, that names do not appear in multiple sets, and 347 | that all names provided in the file are present in the dataset. Raises an error if these conditions are not met. 348 | 349 | Returns a dictionary of the splits (taken from the file) if all conditions are met, 350 | with the lists of names converted into sets. 351 | 352 | :param filepath: Filepath to a JSON file containing a dictionary of train/test/validation splits. 353 | The file must have the key "train", and optionally keys "test" and "validation" for the test/validation sets. 354 | :param dataset: The dataset against which to validate the splits. 355 | :returns: A dictionary of the splits (sets of names) if all conditions are met. 
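Example (a minimal sketch, not part of the package): the JSON layout described above and the disjointness checks can be illustrated with the standard library alone. The complex names and the helper `to_split_sets` are hypothetical, and the membership check against the dataset is omitted here.

```python
import json

# Hypothetical file contents: "train" is required, "test"/"validation" are optional.
splits_json = '{"train": ["/1abc/H3", "/2xyz/L1"], "test": ["/3def/H3"]}'


def to_split_sets(split_dict):
    # Hypothetical helper: convert each list of names to a set,
    # with missing optional keys becoming empty sets.
    if "train" not in split_dict:
        raise KeyError("JSON file must contain a key 'train' for the training set.")
    sets = {key: set(split_dict.get(key, [])) for key in ("train", "test", "validation")}
    # No name may appear in more than one split.
    keys = list(sets)
    for i, first in enumerate(keys):
        for second in keys[i + 1:]:
            overlap = sets[first] & sets[second]
            if overlap:
                raise ValueError(
                    f"Names {sorted(overlap)} appear in both the {first} and {second} sets."
                )
    return sets


splits = to_split_sets(json.loads(splits_json))
```

The resulting dictionary has the same shape as the return value documented here: three keys, each mapping to a set of names.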
356 | """ 357 | split_dict = json.load(open(filepath, "r")) 358 | 359 | if "train" not in split_dict: 360 | raise KeyError("JSON file must contain a key 'train' for the training set.") 361 | 362 | if not isinstance(split_dict["train"], list): 363 | raise TypeError( 364 | "The item stored under the key 'train' must be a list of names." 365 | ) 366 | 367 | train_names = set(split_dict["train"]) 368 | for name in train_names: 369 | if name not in dataset: 370 | raise ValueError( 371 | f"Name {name} in training set not found in dataset {dataset}." 372 | ) 373 | 374 | if "test" in split_dict: 375 | test_names = set(split_dict["test"]) 376 | if not isinstance(split_dict["test"], list): 377 | raise TypeError( 378 | "The item stored under the key 'test' must be a list of names." 379 | ) 380 | for name in split_dict["test"]: 381 | if name not in dataset: 382 | raise ValueError( 383 | f"Name {name} in test set not found in dataset {dataset}." 384 | ) 385 | if name in train_names: 386 | raise ValueError( 387 | f"Name {name} appears in both the training and test sets." 388 | ) 389 | else: 390 | test_names = set() 391 | 392 | if "validation" in split_dict: 393 | validation_names = set(split_dict["validation"]) 394 | if not isinstance(split_dict["validation"], list): 395 | raise TypeError( 396 | "The item stored under the key 'validation' must be a list of names." 397 | ) 398 | for name in split_dict["validation"]: 399 | if name not in dataset: 400 | raise ValueError( 401 | f"Name {name} in validation set not found in dataset {dataset}." 402 | ) 403 | if name in train_names: 404 | raise ValueError( 405 | f"Name {name} appears in both the training and validation sets." 406 | ) 407 | if name in test_names: 408 | raise ValueError( 409 | f"Name {name} appears in both the test and validation sets." 
410 | ) 411 | else: 412 | validation_names = set() 413 | 414 | split_dict = { 415 | "train": train_names, 416 | "test": test_names, 417 | "validation": validation_names, 418 | } 419 | return split_dict 420 | 421 | 422 | def load_generated_structures( 423 | hdf5_filepath: str, device: torch.device = torch.device("cpu") 424 | ) -> List[Tuple[str, Structure, LinearStructure, Tuple[LinearStructure, ...]]]: 425 | """ 426 | Loads generated structures from an HDF5 file. The HDF5 file should be the output 427 | of a generation run (i.e. running loopgen generate [args]), 428 | and should contain the following keys stored under every structure pair 429 | group: 430 | - receptor 431 | - ligand 432 | - keys prefixed with generated_ (one per generated structure) 433 | """ 434 | 435 | file = h5py.File(hdf5_filepath) 436 | all_structures = [] 437 | for group in hdf5_group_generator( 438 | file, lambda g: "receptor" in g and "ligand" in g 439 | ): 440 | receptor = group["receptor"] 441 | receptor_structure = structure_dict_to_structure(receptor, device) 442 | 443 | ligand = group["ligand"] 444 | ligand_structure = structure_dict_to_structure(ligand, device, LinearStructure) 445 | 446 | generated = tuple(group[key] for key in group if key.startswith("generated_")) 447 | generated_structures = tuple( 448 | structure_dict_to_structure(g, device, LinearStructure) for g in generated 449 | ) 450 | 451 | all_structures.append( 452 | (group.name, receptor_structure, ligand_structure, generated_structures) 453 | ) 454 | 455 | return all_structures 456 | -------------------------------------------------------------------------------- /loopgen/distributions.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Any 2 | 3 | import numpy as np 4 | import torch 5 | from torch.distributions.multivariate_normal import MultivariateNormal 6 | from math import pi 7 | from abc import ABC, abstractmethod 8 | from e3nn.o3 import 
axis_angle_to_matrix 9 | 10 | from .utils import is_positive_semidefinite, is_symmetric 11 | 12 | 13 | SampleSize = Union[int, Tuple[int, ...]] 14 | 15 | 16 | class Distribution(ABC): 17 | 18 | """ 19 | Abstract class for representing a probability distribution, with methods `pdf()` and `sample()`. 20 | """ 21 | 22 | @abstractmethod 23 | def __init__(self, *params: Any, **kw_params: Any): 24 | """ 25 | Initialises an instance of the distribution with some parameters. 26 | """ 27 | 28 | pass 29 | 30 | @abstractmethod 31 | def pdf(self, value: Any): 32 | """ 33 | Calculates the probability density for some value(s) under the parametrised distribution. 34 | """ 35 | 36 | pass 37 | 38 | @abstractmethod 39 | def sample( 40 | self, size: SampleSize, device: torch.device = torch.device("cpu") 41 | ) -> torch.Tensor: 42 | """ 43 | Samples some values from the distribution with the specified size, returning a tensor. 44 | """ 45 | 46 | pass 47 | 48 | 49 | class Gaussian3DDistribution(Distribution): 50 | 51 | """ 52 | Standard multivariate gaussian distribution for 3 dimensions, 53 | using the PyTorch implementation `MultivariateNormal` 54 | but with the `Distribution` interface defined here. 
55 | """ 56 | 57 | def __init__(self, mean: torch.Tensor, var: Union[torch.Tensor, float] = 1.0): 58 | if isinstance(mean, torch.Tensor) and mean.shape == (3,): 59 | self._mean = mean 60 | else: 61 | raise ValueError( 62 | "Argument mean should be a tensor of shape (3,)" 63 | ) 64 | 65 | if isinstance(var, float): 66 | self._cov_matrix = torch.zeros((3, 3), device=mean.device) 67 | self._cov_matrix.fill_diagonal_(var) 68 | self._prec_matrix = torch.linalg.inv(self._cov_matrix)  # element-wise 1 / cov would give inf off the diagonal 69 | self._is_isotropic = True 70 | elif isinstance(var, torch.Tensor) and var.shape == (3, 3): 71 | assert is_positive_semidefinite( 72 | var 73 | ), "Covariance must be positive semidefinite" 74 | assert is_symmetric(var), "Covariance matrix must be symmetric" 75 | self._cov_matrix = var 76 | self._prec_matrix = torch.linalg.inv(self._cov_matrix) 77 | self._is_isotropic = False 78 | else: 79 | raise ValueError( 80 | "Argument var should either be a float or a tensor of shape (3, 3)" 81 | ) 82 | 83 | self._distribution = MultivariateNormal( 84 | loc=self._mean, covariance_matrix=self._cov_matrix 85 | ) 86 | 87 | @property 88 | def mean(self) -> torch.Tensor: 89 | """The mean of the distribution.""" 90 | return self._mean 91 | 92 | @property 93 | def cov_matrix(self) -> torch.Tensor: 94 | """The covariance matrix of the distribution.""" 95 | return self._cov_matrix 96 | 97 | @property 98 | def prec_matrix(self) -> torch.Tensor: 99 | """The precision matrix of the distribution.""" 100 | return self._prec_matrix 101 | 102 | @property 103 | def is_isotropic(self) -> bool: 104 | """ 105 | Whether the distribution has a scalar variance 106 | (i.e. a diagonal covariance matrix with constant value) or not. 
107 | """ 108 | return self._is_isotropic 109 | 110 | def pdf(self, value: torch.Tensor) -> torch.Tensor: 111 | """The probability density function of the distribution.""" 112 | return torch.exp(self._distribution.log_prob(value)) 113 | 114 | def sample( 115 | self, size: SampleSize, device: torch.device = torch.device("cpu") 116 | ) -> torch.Tensor: 117 | """Samples from the distribution.""" 118 | if isinstance(size, int): 119 | size = (size,) 120 | 121 | return self._distribution.sample(size).to(device) 122 | 123 | def __repr__(self): 124 | return ( 125 | f"{self.__class__.__name__}(" 126 | f"mean={self._mean}, " 127 | f"cov={self._cov_matrix})" 128 | ) 129 | 130 | 131 | class IGSO3Distribution(Distribution): 132 | 133 | """ 134 | Contains code for the isotropic gaussian distribution on SO(3), which can be used to sample rotations. 135 | 136 | The distribution takes a single variance parameter `var` that determines the width of the distribution 137 | "around" the identity matrix. Sampling is performed via inverse-transform sampling using an interpolated 138 | approximation of the inverse CDF. Just call the `sample()` method, which takes in a shape (either an int or 139 | a tuple of ints) as its only argument and generates a set of random rotations. 
140 | """ 141 | 142 | def __init__( 143 | self, 144 | var: float = 1.0, 145 | support_n: int = 5000, 146 | expansion_n: int = 5000, 147 | as_matrices: bool = True, 148 | uniform: bool = False, 149 | ): 150 | assert var > 0, "Variance must be greater than 0" 151 | assert ( 152 | support_n > 0 and expansion_n > 0 153 | ), "Length of support and expansion must be greater than 0" 154 | 155 | self._var = var 156 | self._pi = pi 157 | self._support_n = support_n 158 | self._expansion_n = expansion_n 159 | 160 | self.as_matrices = as_matrices 161 | self.uniform = uniform 162 | 163 | self._support = np.linspace(0, pi, num=self._support_n + 1)[1:] 164 | self._densities = self.pdf(self._support) 165 | self._scores = self.score(self._support) 166 | self._cumulative_densities = np.cumsum( 167 | (self._support[1] - self._support[0]) * self._densities, axis=0 168 | ) 169 | 170 | def __repr__(self): 171 | return ( 172 | f"{self.__class__.__name__}(" 173 | f"var={self._var:.3f}, " 174 | f"support_n={self._support_n}, " 175 | f"expansion_n={self._expansion_n})" 176 | ) 177 | 178 | @property 179 | def var(self) -> float: 180 | """The variance of the distribution.""" 181 | return self._var 182 | 183 | @property 184 | def support(self) -> np.ndarray: 185 | """The support of the distribution, i.e. the range of values over which densities are calculated.""" 186 | return self._support 187 | 188 | @property 189 | def densities(self) -> np.ndarray: 190 | """The densities of the distribution over the support.""" 191 | return self._densities 192 | 193 | @property 194 | def scores(self) -> np.ndarray: 195 | """ 196 | The scalar-valued scores of the distribution over the support. To get the actual score - which 197 | is a 3-D vector in the tangent space of SO(3) - multiply the scalar values by the axis of rotation. 
198 | """ 199 | return self._scores 200 | 201 | @property 202 | def cumulative_densities(self) -> np.ndarray: 203 | """The cumulative density for each value in the support.""" 204 | return self._cumulative_densities 205 | 206 | def inf_sum(self, angle: Union[float, np.ndarray]) -> np.ndarray: 207 | """Infinite sum in the IGSO3 distribution.""" 208 | if isinstance(angle, float) or isinstance(angle, int): 209 | angle = np.array([angle], dtype=np.float64) 210 | 211 | expansion_steps = np.arange(self._expansion_n)[None, :] 212 | 213 | return np.sum( 214 | ((2 * expansion_steps) + 1) 215 | * np.exp(-expansion_steps * (expansion_steps + 1) * self._var) 216 | * ( 217 | np.sin((expansion_steps + (1 / 2)) * angle[:, None]) 218 | / np.sin(angle[:, None] / 2) 219 | ), 220 | axis=-1, 221 | ) 222 | 223 | def pdf(self, angle: Union[float, np.ndarray]) -> np.ndarray: 224 | """ 225 | Gives the probability density for some angle(s) under the parameterised IGSO3 distribution with some 226 | specified number of terms to expand the infinite series (see https://arxiv.org/pdf/2210.01776.pdf). 227 | """ 228 | 229 | density = (1 - np.cos(angle)) / self._pi 230 | 231 | if not self.uniform: 232 | density *= self.inf_sum(angle) 233 | 234 | return density 235 | 236 | def cdf(self, angle: Union[float, np.ndarray]) -> np.ndarray: 237 | """ 238 | Gives the cumulative density for some angle(s) under the parameterised IGSO3 distribution 239 | (see https://arxiv.org/pdf/2210.01776.pdf). 
240 | """ 241 | 242 | if isinstance(angle, float) or isinstance(angle, int): 243 | angle = np.array([angle], dtype=np.float64) 244 | 245 | densities = np.resize( 246 | self._densities[None, :], (angle.shape[0], self._densities.shape[0]) 247 | ) 248 | support = np.resize( 249 | self._support[None, :], (angle.shape[0], self._support.shape[0]) 250 | ) 251 | 252 | angle_support_index = np.argmin(np.abs(angle[:, None] - support), axis=-1) 253 | 254 | angle_support_values = support[np.arange(support.shape[0]), angle_support_index] 255 | zeroed_densities = np.where( 256 | support > angle_support_values[:, None], 0.0, densities 257 | ) 258 | 259 | return np.trapz(zeroed_densities, x=support) 260 | 261 | def inv_cdf(self, cumulative_density: Union[float, np.ndarray]) -> np.ndarray: 262 | """ 263 | Inverse of the cumulative density function, taking a cumulative density value as 264 | input and returning the corresponding angle on the distribution's support. 265 | """ 266 | assert np.all(cumulative_density >= 0), "The cumulative density must be >= 0" 267 | assert np.all(cumulative_density <= 1), "The cumulative density must be <= 1" 268 | 269 | return np.interp(cumulative_density, self._cumulative_densities, self._support) 270 | 271 | def score(self, angle: Union[float, np.ndarray], eps: float = 1e-12) -> np.ndarray: 272 | """ 273 | Gets the gradient of the log PDF at a given angle or array of angles. 274 | Specifically this computes d log f(w)/dw via df(w)/dw * 1/f(w) (quotient rule). 275 | The argument `eps` is for numerical stability, and is added to the divisor. 
276 | """ 277 | if isinstance(angle, float) or isinstance(angle, int): 278 | angle = np.array([angle], dtype=np.float64) 279 | 280 | expansion_steps = np.arange(self._expansion_n)[None, :] 281 | a = expansion_steps + 0.5 282 | 283 | angle_expanded = angle[:, None] 284 | cos_half_angle = np.cos(angle_expanded / 2) 285 | cos_a_angle = np.cos(a * angle_expanded) 286 | sin_a_angle = np.sin(a * angle_expanded) 287 | sin_half_angle = np.sin(angle_expanded / 2) 288 | 289 | inf_sum_constant_terms = ((2 * expansion_steps) + 1) * np.exp( 290 | -expansion_steps * (expansion_steps + 1) * self._var 291 | ) 292 | 293 | inf_sum = np.sum( 294 | inf_sum_constant_terms * (sin_a_angle / sin_half_angle), 295 | axis=-1, 296 | ) 297 | inf_sum_derivative = np.sum( 298 | inf_sum_constant_terms 299 | * ( 300 | ((a * cos_a_angle) / sin_half_angle) 301 | - ((cos_half_angle * sin_a_angle) / (2 * sin_half_angle**2)) 302 | ), 303 | axis=-1, 304 | ) 305 | 306 | return inf_sum_derivative / (inf_sum + eps) 307 | 308 | @staticmethod 309 | def sample_axis(size: SampleSize) -> np.ndarray: 310 | """ 311 | Uniformly samples a random axis for rotation. 312 | 313 | Generates 3 variables from Gaussian (0,1), then normalizes. Method proven in 314 | Marsaglia, 1972. https://mathworld.wolfram.com/SpherePointPicking.html 315 | """ 316 | 317 | if size == 1: 318 | size = tuple() 319 | elif isinstance(size, int): 320 | size = (size,) 321 | 322 | vec = np.random.normal(size=size + (3,)) 323 | vec /= np.linalg.norm(vec, axis=-1, keepdims=True) 324 | 325 | return vec 326 | 327 | def sample_angle(self, size: SampleSize) -> np.ndarray: 328 | """ 329 | Samples a random angle for rotation according to Eq. 
5 in https://openreview.net/forum?id=BY88eBbkpe5 330 | """ 331 | 332 | if isinstance(size, tuple): 333 | cdfs = np.random.rand(*size) 334 | else: 335 | cdfs = np.random.rand(size) 336 | 337 | angle = self.inv_cdf(cdfs) 338 | 339 | return angle 340 | 341 | def sample( 342 | self, size: SampleSize, device: torch.device = torch.device("cpu") 343 | ) -> torch.Tensor: 344 | """ 345 | Samples one or more rotations according to `size`. If `as_matrices` is True, returns rotation matrices of shape `size + (3, 3)` 346 | (or (3, 3) if `size` is 1); otherwise returns rotation vectors (axis * angle) of shape `size + (3,)`. 347 | """ 348 | 349 | axes = torch.as_tensor( 350 | self.sample_axis(size), dtype=torch.float32, device=device 351 | ) 352 | angles = torch.as_tensor( 353 | self.sample_angle(size), dtype=torch.float32, device=device 354 | ) 355 | 356 | if self.as_matrices: 357 | return axis_angle_to_matrix(axes, angles).squeeze(0) 358 | 359 | return (axes * angles.unsqueeze(-1)).squeeze(0) 360 | -------------------------------------------------------------------------------- /loopgen/model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a sub-package for code used to build specific deep learning models. 
3 | """ 4 | from typing import Type, Literal, Optional, Tuple, Union, Dict, Set 5 | 6 | import logging as lg 7 | 8 | import torch 9 | from pytorch_lightning import LightningDataModule, LightningModule 10 | from pytorch_lightning.utilities.seed import isolate_rng 11 | 12 | from .settings import ModelSettings, TrainSettings 13 | from .network import GVPR3ScorePredictor, GVPSE3ScorePredictor 14 | from .datamodule import CDRFrameDataModule, CDRCoordinateDataModule 15 | from .model import CDRFrameDiffusionModel, CDRCoordinateDiffusionModel 16 | from .types import ParamDictionary 17 | 18 | from ..data import ReceptorLigandDataset, load_splits_file 19 | 20 | 21 | def load_trained_model( 22 | model_class: Union[Type[CDRCoordinateDiffusionModel], Type[CDRFrameDiffusionModel]], 23 | checkpoint_path: str, 24 | settings_path: str, 25 | dataset_path: Optional[str] = None, 26 | splits_path: Optional[str] = None, 27 | device: torch.device = torch.device("cpu"), 28 | strict: bool = False, 29 | ) -> Tuple[LightningDataModule, LightningModule]: 30 | """ 31 | Loads a trained model and its datamodule from a checkpoint and settings file. 32 | 33 | :param model_class: The model class that was trained. 34 | :param checkpoint_path: Path to the checkpoint file. 35 | :param settings_path: Path to the settings YAML file. 36 | :param dataset_path: The path to the HDF5 file that was used for training - this is optional, but if 37 | provided, the dataset will be passed to the returned datamodule. 38 | :param splits_path: Optional path to a JSON file containing the train/test/validation splits as string names 39 | of the instances in the dataset, stored under the respective keys "train", "test", and "validation". This is only used if `dataset_path` is also provided. 40 | :param device: The device to load the model on. 41 | :param strict: Whether to load the model in strict mode, i.e. throw an error if the parameters 42 | in the checkpoint file do not match the parameters in the model class. 43 | :returns: A tuple of the datamodule and the trained model, in that order. 
44 | """ 45 | if dataset_path is not None: 46 | dataset = ReceptorLigandDataset.from_hdf5_file(dataset_path, device=device) 47 | else: 48 | dataset = None 49 | 50 | if splits_path is not None and dataset is not None: 51 | splits = load_splits_file(splits_path, dataset) 52 | else: 53 | splits = None 54 | 55 | settings = TrainSettings.from_yaml(settings_path) 56 | param_dict = settings.distribute_model_params(model_class) 57 | datamodule = model_class.datamodule_class( 58 | dataset=dataset, 59 | splits=splits, 60 | **param_dict[model_class.datamodule_class.__name__], 61 | ) 62 | example_batch = datamodule.generate_example() 63 | network = model_class.network_class( 64 | example_batch, **param_dict[model_class.network_class.__name__] 65 | ) 66 | model = model_class.load_from_checkpoint( 67 | checkpoint_path, 68 | network=network, 69 | map_location=device, 70 | strict=strict, 71 | **param_dict[model_class.__name__], 72 | ) 73 | return datamodule, model 74 | 75 | 76 | def setup_model( 77 | dataset: ReceptorLigandDataset, 78 | splits: Optional[Dict[str, Set[str]]], 79 | param_dict: ParamDictionary, 80 | model_class: Union[Type[CDRCoordinateDiffusionModel], Type[CDRFrameDiffusionModel]], 81 | checkpoint_path: Optional[str] = None, 82 | ) -> Tuple[LightningDataModule, LightningModule]: 83 | """ 84 | Using a dataset and a param dictionary storing arguments 85 | for the model's datamodule class, network class, and the model class itself, 86 | generate a datamodule and model instance. 87 | 88 | :param dataset: The dataset to be used for training/evaluation. 89 | :param splits: Optional dictionary containing the train/test/validation splits as string names 90 | of the instances in the dataset, stored under the respective keys "train", "test", and "validation". 91 | :param param_dict: A dictionary of parameters for the model. 92 | :param model_class: The model class to be used. 93 | :param checkpoint_path: A path to a PyTorch checkpoint file containing model weights. 
94 | :returns: A tuple of the datamodule and model instances. 95 | """ 96 | 97 | datamodule = model_class.datamodule_class( 98 | dataset, 99 | splits, 100 | **param_dict[model_class.datamodule_class.__name__], 101 | ) 102 | 103 | with isolate_rng(): 104 | example_batch = datamodule.generate_example() 105 | 106 | network = model_class.network_class( 107 | example_batch, **param_dict[model_class.network_class.__name__] 108 | ) 109 | 110 | if checkpoint_path is None: 111 | model = model_class(network=network, **param_dict[model_class.__name__]) 112 | else: 113 | lg.info(f"Loading model from checkpoint {checkpoint_path}...") 114 | model = model_class.load_from_checkpoint( 115 | checkpoint_path, 116 | network=network, 117 | **param_dict[model_class.__name__], 118 | strict=False, 119 | ) 120 | 121 | return datamodule, model 122 | -------------------------------------------------------------------------------- /loopgen/model/datamodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the datamodule class that loads epitope and CDR structures 3 | for LoopGen diffusion models. 
4 | """ 5 | 6 | from abc import ABC 7 | from typing import Optional, Sequence, Tuple, Callable, Dict, List, Union, Set 8 | 9 | import torch 10 | from e3nn.o3 import rand_matrix 11 | from pytorch_lightning import LightningDataModule 12 | from torch_geometric.data import HeteroData 13 | 14 | from .types import CDRFramesBatch 15 | 16 | from .utils import ( 17 | pad_cdr_features, 18 | get_cdr_feature, 19 | replace_cdr_features, 20 | replace_epitope_features, 21 | get_cdr_epitope_subgraphs, 22 | sinusoidal_encoding, 23 | ) 24 | from .types import ProteinGraph, VectorFeatureGraph 25 | 26 | from ..data import ReceptorLigandDataset 27 | from ..structure import Structure, LinearStructure, OrientationFrames 28 | from ..graph import VectorFeatureStructureData, VectorFeatureComplexData 29 | 30 | 31 | def add_time_step_encoding( 32 | graph: ProteinGraph, time_step_encoding: torch.Tensor 33 | ) -> ProteinGraph: 34 | """ 35 | Adds an input encoding of the time step to the node features `x` in a heterogeneous graph 36 | for all node types. The argument `time_step_encoding` is expected to be a tensor of shape (channels,) or (1, channels), which is expanded across all nodes. 37 | This modifies the graph in-place and also returns it. 
38 | """ 39 | cdr_graph, epitope_graph = get_cdr_epitope_subgraphs(graph) 40 | cdr_node_features = cdr_graph.x 41 | cdr_node_features_with_time = torch.cat( 42 | [ 43 | cdr_node_features, 44 | time_step_encoding.expand(cdr_node_features.shape[0], -1), 45 | ], 46 | dim=-1, 47 | ) 48 | graph = replace_cdr_features(graph, cdr_node_features_with_time) 49 | 50 | if epitope_graph is not None: 51 | epitope_node_features = epitope_graph.x 52 | epitope_node_features_with_time = torch.cat( 53 | [ 54 | epitope_node_features, 55 | time_step_encoding.expand(epitope_node_features.shape[0], -1), 56 | ], 57 | dim=-1, 58 | ) 59 | graph = replace_epitope_features(graph, epitope_node_features_with_time) 60 | 61 | return graph 62 | 63 | 64 | def add_cdr_positional_encoding(graph: HeteroData, num_channels: int) -> None: 65 | """ 66 | Adds a sinusoidal encoding of each CDR's sequence position - 67 | assumed to be the position of the CDR's features in the node features `x` in a heterogeneous graph. 68 | This modifies the input graph in-place. 
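Example (a minimal sketch in plain Python, not the tensor code used here): the per-residue positions are derived from the batch offsets `ptr` by counting from zero within each graph. The helper name `positions_from_ptr` is hypothetical, and the sinusoidal encoding itself is left out because `sinusoidal_encoding` is defined elsewhere.

```python
def positions_from_ptr(ptr):
    # Number of CDR residues in each graph of the batch (the analogue of torch.diff).
    num_per_graph = [end - start for start, end in zip(ptr[:-1], ptr[1:])]
    max_num = max(num_per_graph)
    positions = []
    for num in num_per_graph:
        # Mirrors the masked arange: keep only positions below this graph's size.
        positions.extend(p for p in range(max_num) if p < num)
    return positions


# Two CDRs in one batch, with 3 and 2 residues respectively.
example_positions = positions_from_ptr([0, 3, 5])
```

Here `example_positions` is `[0, 1, 2, 0, 1]`, so each CDR restarts its position count at zero.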
69 | """ 70 | cdr_features = get_cdr_feature(graph, "x") 71 | cdr_ptr = get_cdr_feature(graph, "ptr") 72 | 73 | num_per_graph = torch.diff(cdr_ptr) 74 | num_graphs = len(num_per_graph) 75 | max_num_per_batch = torch.max(num_per_graph).item() 76 | 77 | all_positions = ( 78 | torch.arange(max_num_per_batch, device=cdr_features.device) 79 | .unsqueeze(0) 80 | .expand(num_graphs, -1) 81 | ) 82 | 83 | cdr_positions = all_positions[all_positions < num_per_graph.unsqueeze(-1)] 84 | cdr_positional_encoding_channels = torch.arange( 85 | num_channels, device=cdr_features.device 86 | ) 87 | positional_encodings = sinusoidal_encoding( 88 | cdr_positions, cdr_positional_encoding_channels 89 | ) 90 | 91 | new_cdr_features = torch.cat([cdr_features, positional_encodings], dim=-1) 92 | replace_cdr_features(graph, new_cdr_features) 93 | 94 | 95 | def create_frame_graph( 96 | cdr_frames: OrientationFrames, 97 | epitope: Optional[Structure], 98 | time_step_encoding: torch.Tensor, 99 | use_cdr_positional_encoding: bool = True, 100 | num_pos_encoding_channels: int = 5, 101 | add_pad_cdr_features: bool = False, 102 | num_pad_cdr_features: int = 0, 103 | num_pad_cdr_vec_features: int = 0, 104 | pad_feature_value: float = 0.0, 105 | ) -> VectorFeatureGraph: 106 | """ 107 | Creates a graph representation of an epitope Structure and CDR frames, 108 | concatenating a sinusoidal time step encoding to the node features 109 | and optionally a sinusoidal sequence positional encoding of the CDR sequence position. 
110 | """ 111 | 112 | if epitope is not None: 113 | graph = VectorFeatureComplexData.from_structures(epitope, cdr_frames) 114 | else: 115 | graph = VectorFeatureStructureData.from_structure(cdr_frames) 116 | 117 | if use_cdr_positional_encoding: 118 | add_cdr_positional_encoding(graph, num_pos_encoding_channels) 119 | 120 | add_time_step_encoding(graph, time_step_encoding) 121 | 122 | if add_pad_cdr_features is True: 123 | if num_pad_cdr_features > 0: 124 | graph = pad_cdr_features(graph, num_pad_cdr_features, pad_feature_value) 125 | if num_pad_cdr_vec_features > 0: 126 | graph = pad_cdr_features( 127 | graph, 128 | num_pad_cdr_vec_features, 129 | pad_feature_value, 130 | pad_dim=-2, 131 | feature_attr_name="vector_x", 132 | ) 133 | 134 | return graph 135 | 136 | 137 | def create_coord_graph( 138 | cdr_frames: OrientationFrames, 139 | epitope: Optional[Structure], 140 | time_step_encoding: torch.Tensor, 141 | use_cdr_positional_encoding: bool = True, 142 | num_pos_encoding_channels: int = 5, 143 | add_pad_cdr_features: bool = False, 144 | num_pad_cdr_features: int = 0, 145 | num_pad_cdr_vec_features: int = 0, 146 | pad_feature_value: float = 0.0, 147 | ) -> VectorFeatureGraph: 148 | """ 149 | Creates a graph representation from an epitope Structure and CDR coordinates 150 | (stored within OrientationFrames), concatenating a sinusoidal time step encoding to the node features 151 | and optionally a sinusoidal sequence positional encoding of the CDR sequence position. 152 | Removing frame information involves swapping the CDR vector features for a vector of zeros. 
153 | """ 154 | 155 | graph = create_frame_graph( 156 | cdr_frames, 157 | epitope, 158 | time_step_encoding, 159 | use_cdr_positional_encoding, 160 | num_pos_encoding_channels, 161 | add_pad_cdr_features, 162 | num_pad_cdr_features, 163 | num_pad_cdr_vec_features, 164 | pad_feature_value, 165 | ) 166 | 167 | # replace the vector features with a single vector feature of zeros 168 | if epitope is not None: 169 | graph["ligand"].vector_x = torch.zeros( 170 | (cdr_frames.num_residues, 1, 3), device=cdr_frames.translations.device 171 | ) 172 | else: 173 | graph.vector_x = torch.zeros( 174 | (cdr_frames.num_residues, 1, 3), device=cdr_frames.translations.device 175 | ) 176 | 177 | if num_pad_cdr_vec_features > 0: 178 | graph = pad_cdr_features( 179 | graph, 180 | num_pad_cdr_vec_features, 181 | pad_feature_value, 182 | pad_dim=-2, 183 | feature_attr_name="vector_x", 184 | ) 185 | 186 | return graph 187 | 188 | 189 | class CDRDiffusionDataModule(LightningDataModule, ABC): 190 | """ 191 | Lightning data module for loading complexes as epitope structures and CDR frames. 192 | """ 193 | 194 | # callable that creates a graph from CDR frames, an epitope Structure, and a time step encoding 195 | create_graph: Callable 196 | 197 | def __init__( 198 | self, 199 | dataset: Optional[ReceptorLigandDataset] = None, 200 | splits: Optional[Dict[str, Set[str]]] = None, 201 | self_conditioning_rate: float = 0.5, 202 | pad_feature_value: float = 0.0, 203 | time_step_encoding_channels: int = 5, 204 | use_cdr_positional_encoding: bool = True, 205 | positional_encoding_channels: int = 5, 206 | batch_size: int = 128, 207 | ): 208 | """ 209 | :param dataset: Optional ReceptorLigandDataset for loading antigen/CDR Structure objects, 210 | which yields 3-tuples of the form (name, antigen Structure, cdr Structure) 211 | when indexed. If this is None, the datamodule throws an error when setup() is called. 
212 | :param splits: Optional dictionary mapping keys ("train", "validation", "test") to sets of 213 | antigen/CDR complex names. If this is None, the full dataset is used for the train, validation and test sets. 214 | :param self_conditioning_rate: The rate at which samples for training "self-conditioning" 215 | will be used. This is the rate at which the model will make a prediction 216 | on samples drawn from q(x_{t+1} | x_t) and use that information 217 | to condition its predictions for q(x_t | x_0). This is only passed as an argument 218 | here to check whether pad features are needed. 219 | :param pad_feature_value: The feature value for added pad features (if self-conditioning 220 | is used). 221 | :param time_step_encoding_channels: Number of channels to use for the sinusoidal time step 222 | encoding, which is concatenated onto the node features for both the epitope and CDR. 223 | :param use_cdr_positional_encoding: Whether to use a sequence positional encoding for CDR 224 | residues. If True, a sinusoidal positional encoding of each residue's sequence position 225 | is concatenated to each CDR residue feature. Note this is only used for the CDR residues, 226 | since it is not guaranteed that the epitope will be a linear sequence. 227 | :param positional_encoding_channels: Number of channels to use for the sinusoidal positional 228 | encoding for CDR residues, which is concatenated onto the node features for the CDR only. 229 | :param batch_size: Number of antigen/CDR complexes per batch.
230 | """ 231 | 232 | super().__init__() 233 | 234 | self._dataset = dataset 235 | self._splits = splits 236 | 237 | self._train_dataset = None 238 | self._validation_dataset = None 239 | self._test_dataset = None 240 | 241 | self._batch_size = batch_size 242 | 243 | if self._dataset is not None: 244 | self._device = self._dataset.device 245 | else: 246 | self._device = torch.device("cpu") 247 | 248 | self._using_self_conditioning = self_conditioning_rate > 0 249 | 250 | self._pad_feature_value = pad_feature_value 251 | self._time_step_encoding_channels = time_step_encoding_channels 252 | self._use_cdr_positional_encoding = use_cdr_positional_encoding 253 | self._positional_encoding_channels = positional_encoding_channels 254 | 255 | def setup(self, stage: str): 256 | """Performs the train/validation/test split using the splits dictionary passed to the constructor. If no splits were given, the full dataset is used for all three sets.""" 257 | if self._dataset is None: 258 | raise ValueError( 259 | "No dataset provided to constructor, setup failed."
260 | ) 261 | 262 | if self._splits is None: 263 | self._train_dataset = self._dataset 264 | self._validation_dataset = self._dataset 265 | self._test_dataset = self._dataset 266 | else: 267 | self._train_dataset = self._dataset.subset_by_name(self._splits["train"]) 268 | self._validation_dataset = self._dataset.subset_by_name( 269 | self._splits["validation"] 270 | ) 271 | self._test_dataset = self._dataset.subset_by_name(self._splits["test"]) 272 | 273 | @property 274 | def dataset(self) -> ReceptorLigandDataset: 275 | """The complete underlying dataset.""" 276 | return self._dataset 277 | 278 | @property 279 | def train_dataset(self) -> Optional[ReceptorLigandDataset]: 280 | """The training dataset.""" 281 | return self._train_dataset 282 | 283 | @property 284 | def validation_dataset(self) -> Optional[ReceptorLigandDataset]: 285 | """The validation dataset.""" 286 | return self._validation_dataset 287 | 288 | @property 289 | def test_dataset(self) -> Optional[ReceptorLigandDataset]: 290 | """The test dataset.""" 291 | return self._test_dataset 292 | 293 | @property 294 | def batch_size(self): 295 | """The batch size.""" 296 | return self._batch_size 297 | 298 | def generate_example(self) -> ProteinGraph: 299 | """ 300 | Generates an example graph from random data. 
301 | """ 302 | num_res_per_batch = 5 303 | num_residues = self._batch_size * num_res_per_batch 304 | 305 | cdr_rotations = rand_matrix(num_residues, device=self._device) 306 | cdr_translations = torch.randn((num_residues, 3), device=self._device) 307 | cdr_batch = torch.arange( 308 | self._batch_size, device=self._device 309 | ).repeat_interleave(num_res_per_batch) 310 | cdr_frames = OrientationFrames(cdr_rotations, cdr_translations, batch=cdr_batch) 311 | 312 | dummy_time_step_encoding = torch.zeros( 313 | self._time_step_encoding_channels, device=cdr_frames.translations.device 314 | ) 315 | 316 | epitope_N = torch.randn((num_residues, 3), device=self._device) 317 | epitope_CA = torch.randn((num_residues, 3), device=self._device) 318 | epitope_C = torch.randn((num_residues, 3), device=self._device) 319 | epitope_CB = torch.randn((num_residues, 3), device=self._device) 320 | epitope_sequence = torch.randint(20, size=(num_residues,), device=self._device) 321 | epitope_batch = torch.arange( 322 | self._batch_size, device=self._device 323 | ).repeat_interleave(num_res_per_batch) 324 | epitope_structure = Structure( 325 | epitope_N, 326 | epitope_CA, 327 | epitope_C, 328 | epitope_CB, 329 | epitope_sequence, 330 | batch=epitope_batch, 331 | ) 332 | 333 | example_graph = self.create_graph( 334 | cdr_frames, 335 | epitope_structure, 336 | time_step_encoding=dummy_time_step_encoding, 337 | use_cdr_positional_encoding=self._use_cdr_positional_encoding, 338 | num_pos_encoding_channels=self._positional_encoding_channels, 339 | add_pad_cdr_features=self.add_pad_cdr_features, 340 | num_pad_cdr_features=self.num_pad_cdr_features, 341 | num_pad_cdr_vec_features=self.num_pad_cdr_vec_features, 342 | pad_feature_value=self._pad_feature_value, 343 | ) 344 | 345 | return example_graph 346 | 347 | def train_dataloader(self) -> torch.utils.data.DataLoader: 348 | """The train dataloader using the train dataset.""" 349 | return torch.utils.data.DataLoader( 350 | self._train_dataset, 351 
| shuffle=True, 352 | collate_fn=self.collate, 353 | batch_size=self._batch_size, 354 | ) 355 | 356 | def val_dataloader(self) -> torch.utils.data.DataLoader: 357 | """ 358 | The validation dataloader using the validation dataset. This is 359 | typically used for hyperparameter searches or early stopping. 360 | """ 361 | return torch.utils.data.DataLoader( 362 | self._validation_dataset, 363 | collate_fn=self.collate, 364 | shuffle=False, 365 | batch_size=self._batch_size, 366 | ) 367 | 368 | def test_dataloader(self) -> torch.utils.data.DataLoader: 369 | """ 370 | The test dataloader using the test dataset. This is typically used 371 | for final model evaluation. 372 | """ 373 | 374 | return torch.utils.data.DataLoader( 375 | self._test_dataset, 376 | collate_fn=self.collate, 377 | shuffle=False, 378 | batch_size=self._batch_size, 379 | ) 380 | 381 | @staticmethod 382 | def fix_complex_by_cdr( 383 | epitope: Structure, 384 | cdr: OrientationFrames, 385 | ) -> Tuple[Structure, Union[OrientationFrames, LinearStructure]]: 386 | """ 387 | Translates the coordinates of the CDR and epitope structures so that each CDR 388 | is centered, subtracting the CDR centroid from both the CDR and its epitope. Returns two new objects 389 | (one epitope Structure and one CDR OrientationFrames) with translated coordinates. 390 | """ 391 | fixed_cdr, cdr_centroids = cdr.center() 392 | 393 | if not epitope.has_batch: 394 | epitope_batch = torch.zeros( 395 | len(epitope), device=cdr_centroids.device, dtype=torch.int64 396 | ) 397 | else: 398 | epitope_batch = epitope.batch 399 | 400 | fixed_epitope = epitope.translate(-cdr_centroids[epitope_batch]) 401 | 402 | return fixed_epitope, fixed_cdr 403 | 404 | def collate( 405 | self, structures: List[Tuple[str, Structure, LinearStructure]] 406 | ) -> CDRFramesBatch: 407 | """ 408 | Combines a list of tuples of the form (name, antigen Structure, cdr Structure) into a 409 | batch, with the CDR represented only as a set of orientation frames.
The batch is then 410 | translated so that each CDR is centered (see `fix_complex_by_cdr`). No time step is sampled 411 | and no noise is added in this method. 412 | 413 | Returns a tuple consisting of: 414 | 1. Tuple of CDR names, one for each in the batch 415 | 2. Batch of epitope structures 416 | 3. Batch of CDR orientation frames 417 | """ 418 | names, epitopes, cdrs = map(tuple, zip(*structures)) 419 | 420 | epitope = Structure.combine(epitopes) 421 | cdr_frames = OrientationFrames.combine([cdr.orientation_frames for cdr in cdrs]) 422 | 423 | epitope, cdr_frames = self.fix_complex_by_cdr(epitope, cdr_frames) 424 | 425 | return names, epitope, cdr_frames 426 | 427 | @property 428 | def add_pad_cdr_features(self) -> bool: 429 | """Whether to add pad features (default False).""" 430 | return False 431 | 432 | @property 433 | def num_pad_cdr_features(self) -> int: 434 | """Number of pad CDR features (default 0).""" 435 | return 0 436 | 437 | @property 438 | def num_pad_cdr_vec_features(self) -> int: 439 | """Number of pad CDR vector features (default 0).""" 440 | return 0 441 | 442 | 443 | class CDRCoordinateDataModule(CDRDiffusionDataModule): 444 | """ 445 | A datamodule that loads epitope/CDR structures for coordinate diffusion. 446 | """ 447 | 448 | create_graph = staticmethod(create_coord_graph) 449 | 450 | 451 | class CDRFrameDataModule(CDRDiffusionDataModule): 452 | """ 453 | A datamodule that loads epitope/CDR structures for frame diffusion.
454 | """ 455 | 456 | create_graph = staticmethod(create_frame_graph) 457 | 458 | @property 459 | def add_pad_cdr_features(self) -> bool: 460 | """Add pad CDR features for frame diffusion if using self conditions.""" 461 | return self._using_self_conditioning 462 | 463 | @property 464 | def num_pad_cdr_features(self) -> int: 465 | """Number of pad CDR features for frame diffusion.""" 466 | return 3 if self._using_self_conditioning else 0 467 | 468 | @property 469 | def num_pad_cdr_vec_features(self) -> int: 470 | """Number of pad CDR vector features for frame diffusion.""" 471 | return 1 if self._using_self_conditioning else 0 472 | -------------------------------------------------------------------------------- /loopgen/model/generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generates new structures for some input epitopes using a LoopGen model. 3 | 4 | If input is a PDB file, treats each model as a separate epitope and generates CDR loop structures for each. 5 | These generated structures are saved as different models in an output PDB file. 6 | 7 | If input is an HDF5 file, generates CDR loop structures for each epitope in the file (see README.md for 8 | file format details). The generated structures are then stored in an output HDF5 file with the same 9 | format as the input file, with keys of the form "generated_{i}" (i labels which generated structure is present) 10 | added to each structure pair group (at the same level as the keys "receptor" and "ligand"). 11 | 12 | Generated structures are saved as different models in an output PDB file. 13 | """ 14 | 15 | from typing import List, Tuple, Union, Type 16 | import logging as lg 17 | import os 18 | import argparse 19 | from pathlib import Path 20 | 21 | import torch 22 | import numpy as np 23 | import h5py 24 | import pytorch_lightning as pl 25 | import pandas as pd 26 | from Bio.PDB import PDBParser 27 | 28 | from . 
import setup_model 29 | from .settings import ModelSettings 30 | from .datamodule import CDRFrameDataModule, CDRCoordinateDataModule 31 | from .model import CDRFrameDiffusionModel, CDRCoordinateDiffusionModel 32 | from .utils import permute_cdrs, permute_epitopes, translate_cdrs_away 33 | from .metrics import ( 34 | get_rmsd, 35 | get_clash_loss, 36 | get_epitope_cdr_clashes, 37 | get_bond_length_loss, 38 | get_bond_angle_loss, 39 | mean_pairwise_rmsd, 40 | pca, 41 | ) 42 | from .types import CDRFramesBatch 43 | from ..structure import Structure, LinearStructure 44 | from ..data import ( 45 | ReceptorLigandDataset, 46 | ReceptorLigandPair, 47 | StructureDict, 48 | load_splits_file, 49 | ) 50 | from ..utils import get_device 51 | 52 | 53 | def add_generate_args(parser: argparse.ArgumentParser) -> None: 54 | """ 55 | Adds command line arguments for generation to a parser. 56 | """ 57 | 58 | parser.add_argument( 59 | "data_path", 60 | type=str, 61 | help="Path to the HDF5 data file or a PDB file containing epitope data to be used for generation.", 62 | ) 63 | parser.add_argument( 64 | "--checkpoint", 65 | required=True, 66 | help="Path to a PyTorch checkpoint file containing model weights.", 67 | ) 68 | parser.add_argument( 69 | "--config", 70 | type=str, 71 | required=True, 72 | help="Path to YAML file containing the settings.", 73 | ) 74 | parser.add_argument( 75 | "--splits", 76 | type=str, 77 | default=None, 78 | help="Path to a JSON file containing the names of instances stored under " 79 | "the keys 'train', 'validation', and 'test'. Only used if an HDF5 file is given as a data set. " 80 | "If provided, structures will only be generated for the test set. 
" 81 | "If not provided, structures will be generated " 82 | "for all instances in the input HDF5 file.", 83 | ) 84 | parser.add_argument( 85 | "-n", 86 | type=int, 87 | default=10, 88 | help="Number of structures to generate for each input epitope.", 89 | ) 90 | parser.add_argument( 91 | "-l", 92 | "--length", 93 | type=int, 94 | default=10, 95 | help="Length of generated CDRs (only used if input file is a PDB file).", 96 | ) 97 | parser.add_argument( 98 | "--seed", 99 | type=int, 100 | default=123, 101 | help="Random seed for reproducibility.", 102 | ) 103 | parser.add_argument( 104 | "--noise_scale", 105 | type=float, 106 | default=0.2, 107 | help="Scale of the noise used in the reverse process. Higher values " 108 | "generate more diverse samples at the cost of lower quality. " 109 | "We have found that only values <0.5 generate valid structures.", 110 | ) 111 | parser.add_argument( 112 | "--out_dir", 113 | default=".", 114 | help="Directory in which results (predictions) will be saved.", 115 | ) 116 | parser.add_argument( 117 | "--permute_epitopes", 118 | action="store_true", 119 | help="Whether to permute epitopes during the generation. 
This randomly swaps epitopes between " 120 | "complexes, aligning the principal components of the swapped epitope to the original epitope.", 121 | ) 122 | parser.add_argument( 123 | "--scramble_epitopes", 124 | action="store_true", 125 | help="Whether to scramble each epitope's sequence, randomly swapping residue identities.", 126 | ) 127 | parser.add_argument( 128 | "--translate_cdrs", 129 | action="store_true", 130 | help="Whether to translate each CDR 20 angstroms away from the epitope.", 131 | ) 132 | parser.add_argument( 133 | "--device", 134 | default=None, 135 | choices=["cpu", "gpu"], 136 | help="Which device to use (defaults to None, in which case the " 137 | "GPU is used if available, and if not the CPU is used).", 138 | ) 139 | 140 | 141 | def generate( 142 | datamodule: Union[CDRFrameDataModule, CDRCoordinateDataModule], 143 | model: Union[CDRFrameDiffusionModel, CDRCoordinateDiffusionModel], 144 | n: int, 145 | use_epitope_permutation: bool, 146 | use_epitope_scrambling: bool, 147 | use_cdr_translation: bool, 148 | noise_scale: float, 149 | seed: int, 150 | ) -> List[Tuple[str, Structure, LinearStructure, Tuple[LinearStructure, ...]]]: 151 | """ 152 | Generates samples from a model. 153 | 154 | :param datamodule: The datamodule for collating data. 155 | :param model: The model 156 | :param n: Number of CDR loops to generate for each epitope. 157 | :param use_epitope_permutation: Whether to permute epitopes in the dataset, swapping each CDR's 158 | epitope with a random one and aligning the new epitope to the old epitope's principal components. 159 | :param use_epitope_scrambling: Whether to scramble the epitope's sequence. 160 | :param use_cdr_translation: Whether to translate each CDR away from the epitope. 161 | :param noise_scale: Scale of the noise used in the reverse process. Higher values 162 | generate more diverse samples at the cost of lower quality. 163 | :param seed: Random seed. 
164 | :returns: A list of tuples containing the name of the complex, the epitope, the ground truth CDR, 165 | and a tuple of generated CDR structures. 166 | """ 167 | 168 | def gen_batch(batch: CDRFramesBatch) -> List[Tuple[LinearStructure, ...]]: 169 | """Generation for a single batch, returning a list of generated CDR structures for each batch element.""" 170 | 171 | with torch.no_grad(): 172 | output = model.generate(batch, noise_scale=noise_scale, n=n) 173 | output_structures = [ 174 | LinearStructure.from_frames(f).detach().to(torch.device("cpu")) 175 | for f in output[-1] 176 | ] 177 | output_structures_split = [s.split() for s in output_structures] 178 | 179 | gen_cdr_structures = list(map(tuple, zip(*output_structures_split))) 180 | return gen_cdr_structures 181 | 182 | pl.seed_everything(seed, workers=True) 183 | 184 | datamodule.setup("generate") 185 | 186 | all_instances = [ 187 | datamodule.test_dataset[i] for i in range(len(datamodule.test_dataset)) 188 | ] 189 | 190 | if use_epitope_permutation: 191 | all_instances = permute_epitopes(all_instances) 192 | 193 | if use_epitope_scrambling: 194 | all_instances = [ 195 | (name, ep.scramble_sequence(), cdr) for name, ep, cdr in all_instances 196 | ] 197 | 198 | if use_cdr_translation: 199 | all_instances = translate_cdrs_away(all_instances) 200 | 201 | outputs = [] 202 | for i in range(0, len(datamodule.test_dataset), datamodule.batch_size): 203 | start = i 204 | end = min(i + datamodule.batch_size, len(datamodule.test_dataset)) 205 | batch_instances = all_instances[start:end] 206 | 207 | batch = datamodule.collate(batch_instances) 208 | names, epitope, _ = batch 209 | 210 | epitopes = epitope.split() 211 | cdrs = [cdr.center(return_centre=False) for _, _, cdr in batch_instances] 212 | 213 | gen_cdr_structures = gen_batch(batch) 214 | 215 | for name, ep, cdr, pred_cdrs in zip(names, epitopes, cdrs, gen_cdr_structures): 216 | outputs.append( 217 | ( 218 | name, 219 | ep, 220 | cdr, 221 | pred_cdrs, 222 | ) 
223 | ) 224 | 225 | return outputs 226 | 227 | 228 | def generate_from_args( 229 | args: argparse.Namespace, model_class: Type[pl.LightningModule] 230 | ) -> None: 231 | """Runs `generate()` using command line arguments.""" 232 | device, accelerator = get_device(args.device) 233 | 234 | settings = ModelSettings.from_yaml(args.config) 235 | 236 | is_pdb = args.data_path.endswith(".pdb") 237 | 238 | if is_pdb and ( 239 | args.permute_epitopes or args.scramble_epitopes or args.translate_cdrs 240 | ): 241 | lg.warning( 242 | "Epitope permutation/scrambling and CDR translating is not available for PDB input, " 243 | "switch to HDF5 format instead." 244 | ) 245 | 246 | lg.basicConfig(format="%(asctime)s %(levelname)-8s: %(message)s") 247 | 248 | if not os.path.exists(args.out_dir): 249 | lg.info( 250 | f"Specified output directory {args.out_dir} does not exist, creating..." 251 | ) 252 | os.mkdir(args.out_dir) 253 | 254 | if is_pdb: 255 | pdb_id = Path(args.data_path).stem 256 | parser = PDBParser() 257 | structure = parser.get_structure(pdb_id, args.data_path) 258 | 259 | structure_pairs = [] 260 | for i, model in enumerate(structure.get_models(), start=1): 261 | residues = list(model.get_residues()) 262 | 263 | epitope_structure = Structure.from_pdb_residues(residues) 264 | epitope_dict = StructureDict( 265 | N_coords=epitope_structure.N_coords.cpu().numpy(), 266 | CA_coords=epitope_structure.CA_coords.cpu().numpy(), 267 | C_coords=epitope_structure.C_coords.cpu().numpy(), 268 | CB_coords=epitope_structure.CB_coords.cpu().numpy(), 269 | sequence=epitope_structure.sequence.cpu().numpy(), 270 | ) 271 | 272 | cdr_coords = np.zeros((args.length, 3)) 273 | cdr_sequence = np.zeros((args.length,), dtype=int) 274 | cdr_dict = StructureDict( 275 | N_coords=cdr_coords, 276 | CA_coords=cdr_coords, 277 | C_coords=cdr_coords, 278 | CB_coords=cdr_coords, 279 | sequence=cdr_sequence, 280 | ) 281 | 282 | structure_pair = ReceptorLigandPair( 283 | name=f"{pdb_id}_{i}", 
receptor=epitope_dict, ligand=cdr_dict 284 | ) 285 | 286 | structure_pairs.append(structure_pair) 287 | 288 | dataset = ReceptorLigandDataset(structure_pairs, device=device) 289 | splits = None 290 | hdf5_file = None 291 | out_hdf5_file = None 292 | 293 | else: 294 | hdf5_file = h5py.File(args.data_path) 295 | dataset = ReceptorLigandDataset.from_hdf5_file(args.data_path, device=device) 296 | splits = load_splits_file(args.splits, dataset) 297 | 298 | base_filepath = f"{Path(args.data_path).stem}_generated" 299 | 300 | if args.permute_epitopes: 301 | base_filepath = f"{base_filepath}_permuted" 302 | 303 | if args.scramble_epitopes: 304 | base_filepath = f"{base_filepath}_scrambled" 305 | 306 | if args.translate_cdrs: 307 | base_filepath = f"{base_filepath}_translated" 308 | 309 | out_hdf5_filepath = os.path.join(args.out_dir, f"{base_filepath}.hdf5") 310 | i = 1 # version suffix, incremented until an unused file name in out_dir is found 311 | while os.path.exists(out_hdf5_filepath): 312 | out_hdf5_filepath = os.path.join(args.out_dir, f"{base_filepath}_v{i}.hdf5") 313 | i += 1 314 | out_hdf5_file = h5py.File(out_hdf5_filepath, "w") 315 | 316 | param_dict = settings.distribute_model_params(model_class) 317 | 318 | datamodule, model = setup_model( 319 | dataset, splits, param_dict, model_class, args.checkpoint 320 | ) 321 | 322 | model = model.to(device) 323 | model.eval() 324 | 325 | outputs = generate( 326 | datamodule, 327 | model, 328 | args.n, 329 | args.permute_epitopes, 330 | args.scramble_epitopes, 331 | args.translate_cdrs, 332 | args.noise_scale, 333 | args.seed, 334 | ) 335 | 336 | for name, epitope, cdr, pred_cdrs in outputs: 337 | if is_pdb: 338 | pred_cdr_batch = LinearStructure.combine(pred_cdrs) 339 | pred_cdr_batch.write_to_pdb( 340 | os.path.join( 341 | args.out_dir, f"{name.lstrip('/').replace('/', '_')}_generated.pdb" 342 | ) 343 | ) 344 | else: 345 | group = out_hdf5_file.create_group(name) 346 | epitope.write_to_hdf5(group.create_group("receptor")) 347 | cdr.write_to_hdf5(group.create_group("ligand")) 348 | for i, pred_cdr in enumerate(pred_cdrs,
start=1): 349 | pred_cdr_group = out_hdf5_file[name].create_group(f"generated_{i}") 350 | pred_cdr.write_to_hdf5(pred_cdr_group) 351 | 352 | if not is_pdb: 353 | out_hdf5_file.close() 354 | hdf5_file.close() 355 | -------------------------------------------------------------------------------- /loopgen/model/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics for evaluating models. 3 | """ 4 | 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | import einops 9 | from torch_scatter import scatter_sum, scatter_mean 10 | 11 | from ..structure import ( 12 | Structure, 13 | LinearStructure, 14 | BondLengths, 15 | BondAngles, 16 | BondAngleStdDevs, 17 | BondLengthStdDevs, 18 | BondAngleCosineStdDevs, 19 | AtomVanDerWaalRadii, 20 | ) 21 | 22 | from ..utils import get_distance_matrix 23 | 24 | # these are calculated using an analytical solution assuming a normal distribution of angles 25 | # with the given mean mu and standard deviation sigma: E[cos(angle)] = exp(-sigma^2 / 2) * cos(mu) 26 | N_CA_C_ANGLE_COS_MEAN = torch.exp( 27 | -torch.as_tensor(BondAngleStdDevs["N_CA_C"].value ** 2) / 2 28 | ) * torch.cos(torch.as_tensor(BondAngles["N_CA_C"].value)) 29 | CA_C_N_ANGLE_COS_MEAN = torch.exp( 30 | -torch.as_tensor(BondAngleStdDevs["CA_C_N"].value ** 2) / 2 31 | ) * torch.cos(torch.as_tensor(BondAngles["CA_C_N"].value)) 32 | C_N_CA_ANGLE_COS_MEAN = torch.exp( 33 | -torch.as_tensor(BondAngleStdDevs["C_N_CA"].value ** 2) / 2 34 | ) * torch.cos(torch.as_tensor(BondAngles["C_N_CA"].value)) 35 | 36 | N_CA_C_ANGLE_COS_STD = torch.as_tensor(BondAngleCosineStdDevs["N_CA_C"].value) 37 | CA_C_N_ANGLE_COS_STD = torch.as_tensor(BondAngleCosineStdDevs["CA_C_N"].value) 38 | C_N_CA_ANGLE_COS_STD = torch.as_tensor(BondAngleCosineStdDevs["C_N_CA"].value) 39 | 40 | 41 | def get_clash_loss(structure: LinearStructure) -> torch.Tensor: 42 | """ 43 | Calculates loss penalising steric clashes, calculated using Van Der Waals radii.
44 | Specifically any non-covalently bonded atoms whose Van der Waals radii overlap 45 | are deemed a clash. Implemented exactly the same way as in AlphaFold2 46 | (https://www.nature.com/articles/s41586-021-03819-2). 47 | 48 | :param structure: The LinearStructure with N residues for which the clash loss will be calculated. 49 | The structure must be linear so that the covalent bond structure can be inferred. 50 | :returns: Tensor of shape (N,) containing the clash loss for each residue. 51 | """ 52 | N_N_lit_dist = 2.0 * AtomVanDerWaalRadii["N"].value 53 | C_C_lit_dist = 2.0 * AtomVanDerWaalRadii["C"].value 54 | C_N_lit_dist = AtomVanDerWaalRadii["C"].value + AtomVanDerWaalRadii["N"].value 55 | 56 | N_dist = get_distance_matrix( 57 | structure.N_coords, batch=structure.batch, pad_value=torch.inf 58 | ) 59 | CA_dist = get_distance_matrix( 60 | structure.CA_coords, batch=structure.batch, pad_value=torch.inf 61 | ) 62 | C_dist = get_distance_matrix( 63 | structure.C_coords, batch=structure.batch, pad_value=torch.inf 64 | ) 65 | 66 | N_CA_dist = get_distance_matrix( 67 | structure.N_coords, 68 | structure.CA_coords, 69 | batch=structure.batch, 70 | pad_value=torch.inf, 71 | ) 72 | N_C_dist = get_distance_matrix( 73 | structure.N_coords, 74 | structure.C_coords.roll(1, dims=-2), 75 | batch=structure.batch, 76 | pad_value=torch.inf, 77 | ) 78 | CA_C_dist = get_distance_matrix( 79 | structure.CA_coords, 80 | structure.C_coords, 81 | batch=structure.batch, 82 | pad_value=torch.inf, 83 | ) 84 | 85 | # fill diagonals so that covalently bonded atoms are not penalised for being within VDW radius 86 | diag_mask_matrix = ( 87 | torch.eye(N_dist.shape[-2], N_dist.shape[-2], device=structure.CA_coords.device) 88 | .bool() 89 | .expand(N_dist.shape[:-2] + (-1, -1)) 90 | ) 91 | 92 | pad_mask = torch.all(CA_dist == torch.inf, dim=-1) 93 | N_dist = torch.where(diag_mask_matrix, torch.inf, N_dist) 94 | CA_dist = torch.where(diag_mask_matrix, torch.inf, CA_dist) 95 | C_dist = 
torch.where(diag_mask_matrix, torch.inf, C_dist) 96 | N_CA_dist = torch.where(diag_mask_matrix, torch.inf, N_CA_dist) 97 | N_C_dist = torch.where(diag_mask_matrix, torch.inf, N_C_dist) 98 | CA_C_dist = torch.where(diag_mask_matrix, torch.inf, CA_C_dist) 99 | 100 | N_clash_loss = torch.clamp_min(N_N_lit_dist - 1.5 - N_dist, 0.0) / 2 101 | CA_clash_loss = torch.clamp_min(C_C_lit_dist - 1.5 - CA_dist, 0.0) / 2 102 | C_clash_loss = torch.clamp_min(C_C_lit_dist - 1.5 - C_dist, 0.0) / 2 103 | N_CA_clash_loss = torch.clamp_min(C_N_lit_dist - 1.5 - N_CA_dist, 0.0) 104 | N_C_clash_loss = torch.clamp_min(C_N_lit_dist - 1.5 - N_C_dist, 0.0) 105 | CA_C_clash_loss = torch.clamp_min(C_C_lit_dist - 1.5 - CA_C_dist, 0.0) 106 | 107 | total_clash_loss = torch.sum( 108 | N_clash_loss 109 | + CA_clash_loss 110 | + C_clash_loss 111 | + N_CA_clash_loss 112 | + N_C_clash_loss 113 | + CA_C_clash_loss, 114 | dim=-1, 115 | ) 116 | 117 | total_clash_loss = total_clash_loss[~pad_mask] 118 | 119 | return total_clash_loss 120 | 121 | 122 | def get_bond_angle_loss(structure: LinearStructure, num_stds: int = 12) -> torch.Tensor: 123 | """ 124 | Calculates a bond angle loss term, depending on deviations 125 | between predicted backbone bond angles and their literature values. 126 | Specifically this uses a flat-bottomed loss that only takes values >0 127 | if the cosine of the bond angle is outside of the mean +/- num_stds * std. 128 | Implemented the same way as in AlphaFold2 (https://www.nature.com/articles/s41586-021-03819-2). 129 | 130 | :param structure: The LinearStructure with N residues to calculate the bond angle loss for. 131 | :param num_stds: The number of standard deviations to use for the flat-bottomed loss. Default 132 | is 12, which is the value used in AlphaFold2. 133 | :returns: Tensor of shape (N,) containing the bond angle loss for each residue.
134 | """ 135 | N_CA_vectors = torch.nn.functional.normalize( 136 | structure.N_coords - structure.CA_coords 137 | ) 138 | C_CA_vectors = torch.nn.functional.normalize( 139 | structure.C_coords - structure.CA_coords 140 | ) 141 | # Roll the N coords, so that the coords are lined up correctly, and then cut the last element 142 | C_N_vectors = torch.nn.functional.normalize( 143 | structure.C_coords - structure.N_coords.roll(-1, dims=-2) 144 | ) 145 | 146 | cos_N_CA_C_bond_angles = torch.sum(N_CA_vectors * C_CA_vectors, dim=-1) 147 | cos_CA_C_N_bond_angles = torch.sum(C_CA_vectors * C_N_vectors, dim=-1) 148 | cos_C_N_CA_bond_angles = torch.sum( 149 | C_N_vectors * -N_CA_vectors.roll(-1, dims=-2), dim=-1 150 | ) 151 | 152 | if structure.has_batch: 153 | cos_CA_C_N_bond_angles = cos_CA_C_N_bond_angles.index_fill( 154 | 0, structure.ptr[1:] - 1, CA_C_N_ANGLE_COS_MEAN 155 | ) 156 | cos_C_N_CA_bond_angles = cos_C_N_CA_bond_angles.index_fill( 157 | 0, structure.ptr[1:] - 1, C_N_CA_ANGLE_COS_MEAN 158 | ) 159 | else: 160 | cos_CA_C_N_bond_angles[-1] = CA_C_N_ANGLE_COS_MEAN 161 | cos_C_N_CA_bond_angles[-1] = C_N_CA_ANGLE_COS_MEAN 162 | 163 | N_CA_C_angle_loss = ( 164 | torch.clamp_min( 165 | torch.abs(cos_N_CA_C_bond_angles - N_CA_C_ANGLE_COS_MEAN) 166 | - num_stds * N_CA_C_ANGLE_COS_STD, 167 | 0.0, 168 | ) 169 | / cos_N_CA_C_bond_angles.shape[-1] 170 | ) 171 | 172 | CA_C_N_angle_loss = ( 173 | torch.clamp_min( 174 | torch.abs(cos_CA_C_N_bond_angles - CA_C_N_ANGLE_COS_MEAN) 175 | - num_stds * CA_C_N_ANGLE_COS_STD, 176 | 0.0, 177 | ) 178 | / cos_CA_C_N_bond_angles.shape[-1] 179 | ) 180 | 181 | C_N_CA_angle_loss = ( 182 | torch.clamp_min( 183 | torch.abs(cos_C_N_CA_bond_angles - C_N_CA_ANGLE_COS_MEAN) 184 | - num_stds * C_N_CA_ANGLE_COS_STD, 185 | 0.0, 186 | ) 187 | / cos_C_N_CA_bond_angles.shape[-1] 188 | ) 189 | 190 | return N_CA_C_angle_loss + CA_C_N_angle_loss + C_N_CA_angle_loss 191 | 192 | 193 | def get_bond_length_loss( 194 | structure: LinearStructure, num_stds: int = 12 
195 | ) -> torch.Tensor: 196 | """ 197 | Calculates a bond length loss term, depending on deviations 198 | between predicted backbone bond lengths and their literature values. 199 | Specifically this uses a flat-bottomed loss that only takes values >0 200 | if the bond length is outside of the mean +/- num_stds * std. 201 | Implemented the same way as in AlphaFold2 (https://www.nature.com/articles/s41586-021-03819-2). 202 | 203 | :param structure: The LinearStructure with N residues to calculate the bond length loss for. 204 | :param num_stds: The number of standard deviations to use for the flat-bottomed loss. Default 205 | is 12, which is the value used in AlphaFold2. 206 | :returns: Tensor of shape (N,) containing the bond length loss for each residue. 207 | """ 208 | N_CA_bond_lengths = torch.linalg.norm( 209 | structure.N_coords - structure.CA_coords, dim=-1 210 | ) 211 | CA_C_bond_lengths = torch.linalg.norm( 212 | structure.CA_coords - structure.C_coords, dim=-1 213 | ) 214 | # Roll the N coords so that each C is paired with the N of the next residue; the wrapped-around last element of each chain is overwritten with the literature value below 215 | C_N_bond_lengths = torch.linalg.norm( 216 | structure.C_coords - structure.N_coords.roll(-1, dims=-2), 217 | dim=-1, 218 | ) 219 | 220 | if structure.has_batch: 221 | C_N_bond_lengths = C_N_bond_lengths.index_fill( 222 | 0, structure.ptr[1:] - 1, BondLengths["C_N"].value 223 | ) 224 | else: 225 | C_N_bond_lengths[-1] = BondLengths["C_N"].value 226 | 227 | N_CA_length_loss = ( 228 | torch.clamp_min( 229 | torch.abs(N_CA_bond_lengths - BondLengths["N_CA"].value) 230 | - num_stds * BondLengthStdDevs["N_CA"].value, 231 | 0.0, 232 | ) 233 | / N_CA_bond_lengths.shape[0] 234 | ) 235 | 236 | CA_C_length_loss = ( 237 | torch.clamp_min( 238 | torch.abs(CA_C_bond_lengths - BondLengths["CA_C"].value) 239 | - num_stds * BondLengthStdDevs["CA_C"].value, 240 | 0.0, 241 | ) 242 | / CA_C_bond_lengths.shape[0] 243 | ) 244 | 245 | C_N_length_loss = ( 246 | torch.clamp_min( 247 | torch.abs(C_N_bond_lengths -
BondLengths["C_N"].value) 248 | - num_stds * BondLengthStdDevs["C_N"].value, 249 | 0.0, 250 | ) 251 | / C_N_bond_lengths.shape[0] 252 | ) 253 | 254 | return N_CA_length_loss + CA_C_length_loss + C_N_length_loss 255 | 256 | 257 | def get_violations( 258 | structure: LinearStructure, 259 | ) -> torch.Tensor: 260 | """ 261 | Identifies whether each structure in the batch has any structural violation, 262 | i.e. a non-zero value for the sum of the bond length, bond angle, and clash loss terms. 263 | Returns a binary float tensor with a 1 for each structure with a violation, 264 | and a 0 for each structure without a violation. 265 | 266 | :param structure: A LinearStructure object of N residues containing the predicted coordinates. 267 | :returns: Tensor of shape (B,), where B is the number of structures in the batch, containing a 1 for each structure with a violation, and a 0 for each 268 | structure without a violation. 269 | """ 270 | if structure.has_batch: 271 | batch = structure.batch 272 | else: 273 | batch = torch.zeros( 274 | len(structure), device=structure.CA_coords.device, dtype=torch.long 275 | ) 276 | 277 | bond_len_loss = get_bond_length_loss(structure) 278 | bond_ang_loss = get_bond_angle_loss(structure) 279 | cl_loss = get_clash_loss(structure) 280 | 281 | loss_per_structure = scatter_sum( 282 | bond_len_loss + bond_ang_loss + cl_loss, 283 | batch, 284 | dim=0, 285 | ) 286 | 287 | violations = (loss_per_structure > 0).float() 288 | 289 | return violations 290 | 291 | 292 | def get_epitope_cdr_clashes( 293 | cdr: Structure, epitope: Structure, threshold: float = 3.5 294 | ) -> torch.Tensor: 295 | """ 296 | Returns a binary tensor indicating whether each CDR residue clashes with 297 | any epitope residue with the same batch assignment, i.e. whether its CA atom lies within `threshold` of an epitope CA atom. 298 | """ 299 | if not cdr.has_batch == epitope.has_batch: 300 | raise ValueError( 301 | "CDR and epitope must either both have batches or neither have batches."
302 | ) 303 | 304 | cdr_epitope_distances = get_distance_matrix( 305 | cdr.CA_coords, 306 | epitope.CA_coords, 307 | batch=cdr.batch, 308 | other_batch=epitope.batch, 309 | pad_value=torch.inf, 310 | ) 311 | 312 | clashing = torch.any( 313 | cdr_epitope_distances.flatten(end_dim=1) < threshold, dim=-1 314 | ).float() 315 | 316 | return clashing 317 | 318 | 319 | def get_rmsd( 320 | coords_1: torch.Tensor, 321 | coords_2: torch.Tensor, 322 | batch: Optional[torch.Tensor] = None, 323 | dim: int = 0, 324 | ) -> torch.Tensor: 325 | """ 326 | Calculates root mean squared deviation for two tensors of coordinates. Note 327 | that this does not perform the typical centering or Kabsch rotation to align 328 | the two sets of coordinates before calculating RMSD. 329 | 330 | :param coords_1: One tensor of coordinates. 331 | :param coords_2: Another tensor of coordinates. 332 | :param batch: Optional batch tensor specifying the batch to which each coordinate belongs. 333 | If this is provided, the returned Tensor will have multiple 334 | elements, i.e. one RMSD for each batch element (default: None). 335 | :param dim: Dimension over which the mean squared deviation is calculated (default: 0). 336 | :returns: Float tensor containing the RMSD value(s). 337 | """ 338 | sq_distances = torch.sum((coords_1 - coords_2) ** 2, dim=-1) 339 | 340 | if batch is not None: 341 | mean_sq_distance = scatter_mean(sq_distances, batch, dim=dim) 342 | else: 343 | mean_sq_distance = torch.mean(sq_distances, dim=dim) 344 | 345 | root_mean_sq_dev = torch.sqrt(mean_sq_distance) 346 | 347 | return root_mean_sq_dev 348 | 349 | 350 | def mean_pairwise_rmsd( 351 | pred_coords: list[torch.Tensor], other_coords: list[torch.Tensor] 352 | ) -> torch.Tensor: 353 | """ 354 | Gets the mean of all pairwise RMSDs between two sets of 355 | predicted coordinates. They must be the same length. 
356 | """ 357 | 358 | coords_stacked = torch.stack(pred_coords) 359 | coords_stacked -= torch.mean(coords_stacked, dim=1, keepdim=True) 360 | other_coords_stacked = torch.stack(other_coords) 361 | other_coords_stacked -= torch.mean(other_coords_stacked, dim=1, keepdim=True) 362 | 363 | coords_1 = einops.rearrange(coords_stacked, "b r d -> 1 b r d") 364 | coords_2 = einops.rearrange(other_coords_stacked, "b r d -> b 1 r d") 365 | rmsds = torch.sum((coords_1 - coords_2) ** 2, dim=-1).mean(dim=-1).sqrt() 366 | 367 | # if the two sets of coordinates are the same, the RMSD matrix is symmetric with a zero 368 | # diagonal, so take only the strict upper triangle 369 | if torch.allclose(coords_stacked, other_coords_stacked): 370 | idx1, idx2 = torch.triu_indices(rmsds.shape[0], rmsds.shape[1], offset=1) 371 | rmsds = rmsds[idx1, idx2] 372 | 373 | return torch.mean(rmsds) 374 | 375 | 376 | def pca(coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 377 | """ 378 | Performs a PCA on a set of coordinates. Returns the principal components 379 | and a diagonal matrix of their associated singular values. 380 | 381 | :param coords: Tensor of shape (N, 3) containing the coordinates to perform PCA on. 382 | :returns: Tuple of tensors containing the principal components and a diagonal matrix of their singular values. 383 | """ 384 | U, S, V = torch.pca_lowrank(coords) 385 | return V, torch.diag(S) 386 | -------------------------------------------------------------------------------- /loopgen/model/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains score prediction networks for the CDR backbone diffusion model, 3 | which predict 3-D score vectors (gradient of the forward process log density) 4 | for each node in a graph. 
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Union, List, Optional, Any, Dict 10 | 11 | import torch 12 | from torch import nn 13 | from torch_geometric.data import Data, HeteroData 14 | from torch_geometric.nn import Aggregation, to_hetero 15 | 16 | from einops.layers.torch import Rearrange 17 | 18 | from .types import Score, VectorFeatureGraph, ProteinGraph 19 | from ..nn import GVPN, GVPAttentionTypes 20 | 21 | 22 | class GVPR3ScorePredictor(nn.Module): 23 | """ 24 | Module that uses a geometric vector perceptron to approximate 25 | scores (gradient of log density) in an R3 diffusion model. 26 | 27 | By default, the model outputs a single SO(3)-equivariant 3-D vector 28 | for the score. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | example_batch: ProteinGraph, 34 | out_vector_channels: int = 1, 35 | hidden_scalar_channels: int = 128, 36 | hidden_vector_channels: int = 64, 37 | hidden_edge_scalar_channels: int = 64, 38 | hidden_edge_vector_channels: int = 32, 39 | num_layers: int = 3, 40 | dropout: float = 0.2, 41 | aggr: Union[str, List, Aggregation] = "sum", 42 | attention_type: Optional[GVPAttentionTypes] = None, 43 | use_scaling_gate: bool = True, 44 | share_params: bool = True, 45 | vector_dim_size: int = 3, 46 | **kwargs: Any, 47 | ): 48 | """ 49 | :param example_batch: Example graph batch object (Data/HeteroData) to infer input feature dimensions from. 50 | :param out_vector_channels: Number of output vector features. Defaults to 1 for the single 3-D 51 | SO(3)-equivariant translation score. 52 | :param hidden_scalar_channels: Dimensionality of the hidden scalar features. 53 | :param hidden_vector_channels: Dimensionality of the hidden vector features. 54 | :param hidden_edge_scalar_channels: Dimensionality of the hidden scalar edge features. 55 | :param hidden_edge_vector_channels: Dimensionality of the hidden vector edge features. 56 | :param num_layers: Number of message passing layers. 
57 | :param dropout: Dropout rate. 58 | :param aggr: Aggregation scheme used in message passing. 59 | :param attention_type: Type of attention layer used in the graph network (GVPN). 60 | :param use_scaling_gate: Uses a final layer after message passing to output a positive scalar, 61 | which is used to scale the output vector features. This is potentially 62 | useful to allow the network to learn norm information about its outputs. 63 | :param share_params: Whether to share parameters across message passing layers. 64 | :param vector_dim_size: Dimensionality of the vectors used as vector features 65 | (not the number of vector features). 66 | :param kwargs: Additional keyword arguments passed to the graph network (GVPN). 67 | """ 68 | 69 | super().__init__() 70 | 71 | self._hidden_scalar_channels = hidden_scalar_channels 72 | self._hidden_vector_channels = hidden_vector_channels 73 | self._out_vector_channels = out_vector_channels 74 | self._use_scaling_gate = bool(use_scaling_gate) 75 | 76 | # get feature dimensions depending on whether graph is heterogeneous or homogeneous 77 | if isinstance(example_batch, HeteroData): 78 | node_storage = example_batch["ligand"] 79 | in_scalar_channels = node_storage.x.shape[-1] 80 | in_vector_channels = node_storage.vector_x.shape[-2] 81 | edge_storage = example_batch[("ligand", "ligand")] 82 | in_edge_scalar_channels = edge_storage.edge_attr.shape[-1] 83 | in_edge_vector_channels = edge_storage.vector_edge_attr.shape[-2] 84 | else: 85 | in_scalar_channels = example_batch.x.shape[-1] 86 | in_vector_channels = example_batch.vector_x.shape[-2] 87 | in_edge_scalar_channels = example_batch.edge_attr.shape[-1] 88 | in_edge_vector_channels = example_batch.vector_edge_attr.shape[-2] 89 | 90 | graph_network = GVPN( 91 | in_scalar_channels, 92 | hidden_scalar_channels, 93 | in_vector_channels, 94 | hidden_vector_channels, 95 | in_edge_scalar_channels, 96 | hidden_edge_scalar_channels, 97 | in_edge_vector_channels, 98 | 
hidden_edge_vector_channels, 99 | num_layers=num_layers, 100 | dropout=dropout, 101 | aggr=aggr, 102 | attention_type=attention_type, 103 | share_params=share_params, 104 | vector_dim_size=vector_dim_size, 105 | **kwargs, 106 | ) 107 | 108 | self._graph_network = graph_network 109 | self._vector_pred_layer = None 110 | self._vector_scaling_gate = None 111 | 112 | self._create_prediction_layers() 113 | 114 | self._is_hetero = False 115 | 116 | if isinstance(example_batch, HeteroData): 117 | self.as_hetero(example_batch, aggr) 118 | 119 | def as_hetero( 120 | self, 121 | example_graph: HeteroData, 122 | aggr: str = "sum", 123 | **kwargs: Any, 124 | ): 125 | """ 126 | Modifies the underlying GVPN network to be 127 | compatible with a heterogeneous graph, with 128 | separate message passing parameters for each edge type. 129 | """ 130 | 131 | # validate the input and skip the conversion if the network is already heterogeneous 132 | if not isinstance(example_graph, HeteroData): 133 | raise TypeError( 134 | "Argument 'example_graph' should be a torch_geometric HeteroData object." 135 | ) 136 | 137 | if self._is_hetero: 138 | return 139 | 140 | self._vector_pred_layer = to_hetero( 141 | self._vector_pred_layer, example_graph.metadata(), aggr=aggr 142 | ) 143 | 144 | if self._use_scaling_gate: 145 | self._vector_scaling_gate = to_hetero( 146 | self._vector_scaling_gate, example_graph.metadata(), aggr=aggr 147 | ) 148 | 149 | self._graph_network = self._graph_network.to_hetero(example_graph, aggr) 150 | self._is_hetero = True 151 | 152 | @property 153 | def is_hetero(self): 154 | """Whether the network is heterogeneous.""" 155 | return self._is_hetero 156 | 157 | def _create_prediction_layers(self): 158 | """Generates the prediction layers to be used to generate outputs after message passing.""" 159 | 160 | # add a final prediction layer for SO(3)-equivariant vector features. 161 | vector_pred_layer = nn.Sequential( 162 | Rearrange("... n d -> ... 
d n"), 163 | nn.Linear( 164 | self._hidden_vector_channels, self._out_vector_channels, bias=False 165 | ), 166 | Rearrange("... d n -> ... n d"), 167 | ) 168 | 169 | self._vector_pred_layer = vector_pred_layer 170 | 171 | if self._use_scaling_gate: 172 | # the features used to calculate the scaling factors are the scalar features 173 | # concatenated with the norms of the vector features 174 | num_gate_channels = ( 175 | self._hidden_scalar_channels + self._hidden_vector_channels 176 | ) 177 | self._vector_scaling_gate = nn.Sequential( 178 | nn.Linear(num_gate_channels, num_gate_channels), 179 | nn.ReLU(), 180 | nn.Linear(num_gate_channels, 1), 181 | nn.Softplus(), 182 | ) 183 | 184 | def _predict_scores( 185 | self, 186 | scalar_features: Union[torch.Tensor, Dict[str, torch.Tensor]], 187 | vector_features: Union[torch.Tensor, Dict[str, torch.Tensor]], 188 | ) -> Union[Score, Dict[str, Score]]: 189 | """ 190 | Gets the translation scores from the scalar features and vector features 191 | outputted after message passing. 192 | 193 | :param scalar_features: Scalar features outputted after message passing. 194 | :param vector_features: Vector features outputted after message passing. 195 | :returns: Tensor containing the predicted translation scores. 
196 | """ 197 | translation_scores = self._vector_pred_layer(vector_features) 198 | 199 | if self._use_scaling_gate: 200 | if self._is_hetero: 201 | gate_features = { 202 | node_type: torch.cat( 203 | [ 204 | scalar_features[node_type], 205 | torch.linalg.norm(vector_features[node_type], dim=-1), 206 | ], 207 | dim=-1, 208 | ) 209 | for node_type in scalar_features 210 | } 211 | else: 212 | gate_features = torch.cat( 213 | [ 214 | scalar_features, 215 | torch.linalg.norm(vector_features, dim=-1), 216 | ], 217 | dim=-1, 218 | ) 219 | 220 | translation_scale_factors = self._vector_scaling_gate(gate_features) 221 | 222 | if self._is_hetero: 223 | translation_scores = { 224 | node_type: v_x * translation_scale_factors[node_type].unsqueeze(-1) 225 | for node_type, v_x in translation_scores.items() 226 | } 227 | else: 228 | translation_scores *= translation_scale_factors.unsqueeze(-1) 229 | 230 | return translation_scores 231 | 232 | def forward( 233 | self, graph: VectorFeatureGraph 234 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 235 | """ 236 | Returns an SO(3)-equivariant (N, 3) tensor containing the translation score for each node in the graph. 237 | 238 | :param graph: Vector feature graph representation of a protein or protein complex. 239 | :returns: Tensor of translation score vectors. 
240 | """ 241 | if self._is_hetero: 242 | x = graph.x_dict 243 | vector_x = graph.vector_x_dict 244 | edge_index = graph.edge_index_dict 245 | edge_attr = graph.edge_attr_dict 246 | vector_edge_attr = graph.vector_edge_attr_dict 247 | orientations = graph.orientations_dict 248 | else: 249 | x = graph.x 250 | vector_x = graph.vector_x 251 | edge_index = graph.edge_index 252 | edge_attr = graph.edge_attr 253 | vector_edge_attr = graph.vector_edge_attr 254 | orientations = graph.orientations 255 | 256 | node_scalar_embeddings, node_vector_embeddings = self._graph_network( 257 | x, 258 | vector_x, 259 | edge_index, 260 | edge_attr, 261 | vector_edge_attr, 262 | orientations, 263 | ) 264 | 265 | return self._predict_scores(node_scalar_embeddings, node_vector_embeddings) 266 | 267 | 268 | class GVPSE3ScorePredictor(GVPR3ScorePredictor): 269 | 270 | """ 271 | Module that uses a geometric vector perceptron to approximate 272 | scores (gradient of log density) in an SE(3) diffusion model. 273 | 274 | This class is an extension of GVPR3ScorePredictor (which by default 275 | outputs a single SO(3)-equivariant 3-D vector for the translation score). 276 | This class outputs the same translation score, but also outputs a 277 | single SE(3)-invariant 3-D vector for the rotation score. 
278 | """ 279 | 280 | def __init__( 281 | self, 282 | example_batch: ProteinGraph, 283 | out_channels: int = 3, 284 | out_vector_channels: int = 1, 285 | hidden_scalar_channels: int = 128, 286 | hidden_vector_channels: int = 64, 287 | hidden_edge_scalar_channels: int = 64, 288 | hidden_edge_vector_channels: int = 32, 289 | num_layers: int = 3, 290 | dropout: float = 0.2, 291 | aggr: Union[str, List, Aggregation] = "sum", 292 | attention_type: Optional[GVPAttentionTypes] = None, 293 | use_scaling_gate: bool = True, 294 | share_params: bool = True, 295 | vector_dim_size: int = 3, 296 | **kwargs: Any, 297 | ): 298 | """ 299 | :param example_batch: Example graph batch object (Data/HeteroData) to infer input feature dimensions from. 300 | :param out_channels: Number of output scalar features. Defaults to 3 for the single SE(3)-invariant 301 | rotation score. 302 | :param out_vector_channels: Number of output vector features. Defaults to 1 for the single 3-D 303 | SO(3)-equivariant translation score. 304 | :param hidden_scalar_channels: Dimensionality of the hidden scalar features. 305 | :param hidden_vector_channels: Dimensionality of the hidden vector features. 306 | :param hidden_edge_scalar_channels: Dimensionality of the hidden scalar edge features. 307 | :param hidden_edge_vector_channels: Dimensionality of the hidden vector edge features. 308 | :param num_layers: Number of message passing layers. 309 | :param dropout: Dropout rate. 310 | :param aggr: Aggregation scheme used in message passing. 311 | :param attention_type: Type of attention layer used in the graph network (GVPN). 312 | :param use_scaling_gate: Uses two final layers after message passing to output a positive scalar each, 313 | which are used to scale the output scalar and vector features (separately). This is potentially 314 | useful to allow the network to learn norm information about its outputs. 315 | :param share_params: Whether to share parameters across message passing layers. 
316 | :param vector_dim_size: Dimensionality of the vectors used as vector features 317 | (not the number of vector features). 318 | :param kwargs: Additional keyword arguments passed to the graph network (GVPN). 319 | """ 320 | 321 | self._out_channels = out_channels 322 | self._scalar_pred_layer = None 323 | self._scalar_scaling_gate = None 324 | 325 | super().__init__( 326 | example_batch, 327 | out_vector_channels, 328 | hidden_scalar_channels, 329 | hidden_vector_channels, 330 | hidden_edge_scalar_channels, 331 | hidden_edge_vector_channels, 332 | num_layers, 333 | dropout, 334 | aggr, 335 | attention_type, 336 | use_scaling_gate, 337 | share_params, 338 | vector_dim_size, 339 | **kwargs, 340 | ) 341 | 342 | def as_hetero( 343 | self, 344 | example_graph: HeteroData, 345 | aggr: str = "sum", 346 | **kwargs: Any, 347 | ): 348 | """ 349 | Modifies the underlying GVPN network to be 350 | compatible with a heterogeneous graph, with 351 | separate message passing parameters for each edge type. 
352 | """ 353 | 354 | super().as_hetero(example_graph, aggr, **kwargs) 355 | 356 | self._scalar_pred_layer = to_hetero( 357 | self._scalar_pred_layer, example_graph.metadata(), aggr=aggr 358 | ) 359 | if self._use_scaling_gate: 360 | self._scalar_scaling_gate = to_hetero( 361 | self._scalar_scaling_gate, example_graph.metadata(), aggr=aggr 362 | ) 363 | 364 | def _create_prediction_layers(self): 365 | """Generates the prediction layers to be used to generate predictions after message passing.""" 366 | super()._create_prediction_layers() 367 | # add a final prediction layer for SE(3)-invariant scalar features 368 | scalar_pred_layer = nn.Sequential( 369 | nn.Linear(self._hidden_scalar_channels, self._hidden_scalar_channels), 370 | nn.ReLU(), 371 | nn.Linear(self._hidden_scalar_channels, self._out_channels), 372 | ) 373 | self._scalar_pred_layer = scalar_pred_layer 374 | 375 | if self._use_scaling_gate: 376 | # the features used to calculate the scaling factors are the scalar features 377 | # concatenated with the norms of the vector features 378 | num_gate_channels = ( 379 | self._hidden_scalar_channels + self._hidden_vector_channels 380 | ) 381 | self._scalar_scaling_gate = nn.Sequential( 382 | nn.Linear(num_gate_channels, num_gate_channels), 383 | nn.ReLU(), 384 | nn.Linear(num_gate_channels, 1), 385 | nn.Softplus(), 386 | ) 387 | 388 | def _predict_scores( 389 | self, 390 | scalar_features: Union[torch.Tensor, Dict[str, torch.Tensor]], 391 | vector_features: Union[torch.Tensor, Dict[str, torch.Tensor]], 392 | ) -> Union[Score, Dict[str, Score]]: 393 | """ 394 | Gets the rotation and translation scores from the scalar features and vector features 395 | outputted after message passing. 396 | 397 | :param scalar_features: Scalar features outputted after message passing. 398 | :param vector_features: Vector features outputted after message passing. 399 | :returns: 2-tuple of tensors containing the predicted rotation and translation scores. 
400 | """ 401 | 402 | rotation_scores = self._scalar_pred_layer(scalar_features) 403 | translation_scores = self._vector_pred_layer(vector_features) 404 | 405 | if self._use_scaling_gate: 406 | if self._is_hetero: 407 | gate_features = { 408 | node_type: torch.cat( 409 | [ 410 | scalar_features[node_type], 411 | torch.linalg.norm(vector_features[node_type], dim=-1), 412 | ], 413 | dim=-1, 414 | ) 415 | for node_type in scalar_features 416 | } 417 | else: 418 | gate_features = torch.cat( 419 | [ 420 | scalar_features, 421 | torch.linalg.norm(vector_features, dim=-1), 422 | ], 423 | dim=-1, 424 | ) 425 | 426 | rotation_scale_factors = self._scalar_scaling_gate(gate_features) 427 | translation_scale_factors = self._vector_scaling_gate(gate_features) 428 | 429 | if self._is_hetero: 430 | rotation_scores = { 431 | node_type: x * rotation_scale_factors[node_type] 432 | for node_type, x in rotation_scores.items() 433 | } 434 | translation_scores = { 435 | node_type: v_x * translation_scale_factors[node_type].unsqueeze(-1) 436 | for node_type, v_x in translation_scores.items() 437 | } 438 | else: 439 | rotation_scores *= rotation_scale_factors 440 | translation_scores *= translation_scale_factors.unsqueeze(-1) 441 | 442 | return rotation_scores, translation_scores 443 | -------------------------------------------------------------------------------- /loopgen/model/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the required settings needed to specify a deep learning model. 
3 | """ 4 | 5 | from __future__ import annotations 6 | from typing import Optional, Callable, Dict, Any, Type, Literal 7 | from datetime import date 8 | 9 | import yaml 10 | 11 | from inspect import signature, Parameter 12 | from collections import defaultdict 13 | from pytorch_lightning import LightningModule 14 | from pytorch_lightning.loggers import MLFlowLogger 15 | 16 | from .types import ParamDictionary 17 | from ..graph import EdgeIndexMethods 18 | 19 | 20 | class ModelSettings: 21 | """ 22 | Base class for an object specifying the settings for a deep learning model. 23 | At the moment these are loaded from YAML files but other format parsers can be 24 | added in the future. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | experiment_name: str = "Experiment", 30 | run_name: str = "run", 31 | steps_per_log: int = 100, 32 | **kwargs: Any, # additional parameters specific to particular models 33 | ): 34 | self.experiment_name = experiment_name 35 | self.steps_per_log = steps_per_log 36 | 37 | # run_name is protected so that we can define a setter method in child classes 38 | # that modifies other attributes if the run name is changed 39 | self._run_name = run_name 40 | 41 | for attr, value in kwargs.items(): 42 | setattr(self, attr, value) 43 | 44 | self._params = kwargs 45 | self._param_names = set(self._params) 46 | 47 | @classmethod 48 | def from_yaml(cls, path: str): 49 | """Loads the settings from a YAML file.""" 50 | with open(path) as file: 51 | settings = yaml.safe_load(file) 52 | cls._check_settings(settings) 53 | return cls(**settings) 54 | 55 | @property 56 | def run_name(self) -> str: 57 | """The name of the current run.""" 58 | return self._run_name 59 | 60 | @run_name.setter 61 | def run_name(self, value: str): 62 | """Sets the run name.""" 63 | self._run_name = value 64 | 65 | def distribute_params(self, *callables: Callable) -> ParamDictionary: 66 | """ 67 | For input callables, distribute the parameters in the `params` 68 | attribute by 
searching for their name (key) in each callable's signature. 69 | 70 | Returns a dictionary mapping each callable name to a dictionary of keyword 71 | arguments to be passed to that callable. 72 | """ 73 | cl_param_dict = defaultdict(dict) 74 | for cl in callables: 75 | cl_signature = signature(cl) 76 | cl_params = set(cl_signature.parameters) 77 | cl_params_in_settings = self._param_names.intersection(cl_params) 78 | cl_provided_params = {} 79 | for p in cl_params_in_settings: 80 | cl_provided_params[p] = self._params[p] 81 | 82 | cl_param_dict[cl.__name__] = cl_provided_params 83 | 84 | return cl_param_dict 85 | 86 | def distribute_model_params( 87 | self, model_class: Type[LightningModule] 88 | ) -> ParamDictionary: 89 | """ 90 | Distributes the params for a DL model between the model itself, 91 | the model's datamodule class, and the model's network class. 92 | All classes (model, datamodule, network) have parameters 93 | distributed to the constructor. 94 | """ 95 | 96 | param_dict = self.distribute_params( 97 | model_class, 98 | model_class.datamodule_class, 99 | model_class.network_class, 100 | ) 101 | 102 | # collect the keyword arguments for the edge index method used by the datamodule 103 | datamodule_params = signature(model_class.datamodule_class).parameters 104 | if "edge_method" in datamodule_params: 105 | datamodule_name = model_class.datamodule_class.__name__ 106 | if "edge_method" in param_dict[datamodule_name]: 107 | edge_index_method = param_dict[datamodule_name]["edge_method"] 108 | edge_index_class = EdgeIndexMethods[edge_index_method].value 109 | elif datamodule_params["edge_method"].default != Parameter.empty: 110 | default_method = datamodule_params["edge_method"].default 111 | edge_index_class = EdgeIndexMethods[default_method].value 112 | else: 113 | raise ValueError( 114 | "Edge method not provided and no default method found." 
115 | ) 116 | 117 | edge_index_param_dict = self.distribute_params(edge_index_class) 118 | param_dict[datamodule_name]["edge_kwargs"] = edge_index_param_dict[ 119 | edge_index_class.__name__ 120 | ] 121 | 122 | return param_dict 123 | 124 | @classmethod 125 | def _check_settings(cls, settings: Dict[str, Any]) -> None: 126 | """ 127 | Checks a settings dictionary and raises an error if any of 128 | the required settings are missing. 129 | """ 130 | 131 | # check that all required arguments are provided 132 | constructor_sig = signature(cls.__init__) 133 | required_slots = { 134 | param_name 135 | for param_name, param in constructor_sig.parameters.items() 136 | if param.default == Parameter.empty and param_name not in {"kwargs", "self"} 137 | } 138 | 139 | if not required_slots.issubset(settings): 140 | raise ValueError( 141 | f"One or more of the required fields: {required_slots} was not found in the input." 142 | ) 143 | 144 | 145 | class TrainSettings(ModelSettings): 146 | 147 | """ 148 | Extends ModelSettings with additional parameters specific to training. 
149 | """ 150 | 151 | def __init__( 152 | self, 153 | experiment_name: str = "Experiment", 154 | run_name: str = "run", 155 | steps_per_log: int = 100, 156 | checkpoint_outfile: Optional[str] = None, 157 | checkpoint_metric: str = "validation_loss", 158 | checkpoint_mode: Literal["min", "max"] = "min", 159 | save_top_k: int = 1, 160 | **kwargs: Any, # additional parameters specific to particular models, saved under self._params 161 | ): 162 | super().__init__(experiment_name, run_name, steps_per_log, **kwargs) 163 | 164 | self.checkpoint_metric = checkpoint_metric 165 | 166 | self._checkpoint_outfile_provided = checkpoint_outfile is not None 167 | 168 | if self._checkpoint_outfile_provided: 169 | self.checkpoint_outfile = checkpoint_outfile 170 | else: 171 | self.checkpoint_outfile = ( 172 | f"{self._run_name}-{date.today()}-" 173 | f"{{epoch:02d}}-{{{checkpoint_metric}:.2f}}" 174 | ) 175 | 176 | self.checkpoint_mode = checkpoint_mode 177 | self.save_top_k = save_top_k 178 | 179 | @ModelSettings.run_name.setter 180 | def run_name(self, value: str): 181 | """ 182 | Sets the run name and changes the checkpoint outfile name 183 | accordingly if it was not provided to the constructor. 184 | """ 185 | self._run_name = value 186 | if not self._checkpoint_outfile_provided: 187 | self.checkpoint_outfile = ( 188 | f"{self._run_name}-{date.today()}-" 189 | f"{{epoch:02d}}-{{{self.checkpoint_metric}:.2f}}" 190 | ) 191 | 192 | def get_mlflow_logger( 193 | exp_name: str, run_name: str, settings: ModelSettings 194 | ) -> MLFlowLogger: 195 | """ 196 | Returns an MLFlow logger object from pytorch lightning with 197 | the settings logged. 
198 | """ 199 | mlflow_logger = MLFlowLogger(experiment_name=exp_name, run_name=run_name) 200 | mlflow_logger.log_hyperparams( 201 | {param: value for param, value in settings.__dict__.items() if param[0] != "_"} 202 | ) 203 | 204 | return mlflow_logger 205 | -------------------------------------------------------------------------------- /loopgen/model/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a LoopGen model. 3 | """ 4 | 5 | from typing import Type, Literal, Optional, Dict, Set 6 | import logging as lg 7 | import os 8 | import argparse 9 | import sys 10 | 11 | import h5py 12 | import pytorch_lightning as pl 13 | from pytorch_lightning.callbacks import ModelCheckpoint 14 | 15 | from .settings import TrainSettings, get_mlflow_logger 16 | from . import setup_model 17 | 18 | from .. import ReceptorLigandDataset, get_device 19 | from ..data import load_splits_file 20 | 21 | 22 | def add_train_args(parser: argparse.ArgumentParser) -> None: 23 | """ 24 | Adds command line arguments for training to a parser. 25 | """ 26 | 27 | parser.add_argument( 28 | "data_path", 29 | type=str, 30 | help="Path to the HDF5 data file to be used for training/testing.", 31 | ) 32 | parser.add_argument( 33 | "--splits", 34 | type=str, 35 | required=True, 36 | help="Path to a JSON file containing the names of instances stored under " 37 | "the keys 'train', 'validation', and 'test'.", 38 | ) 39 | parser.add_argument( 40 | "--config", 41 | type=str, 42 | required=True, 43 | help="Path to YAML file containing the settings.", 44 | ) 45 | parser.add_argument( 46 | "--checkpoint", 47 | default=None, 48 | help="Path to a PyTorch checkpoint file containing model weights.", 49 | ) 50 | parser.add_argument( 51 | "--restore_full_state", 52 | action="store_true", 53 | help="Whether to restore the full training state from the provided checkpoint. 
Only " 54 | "applicable if --checkpoint is passed.", 55 | ) 56 | parser.add_argument( 57 | "--out_dir", 58 | default=".", 59 | help="Directory in which results (predictions) will be saved.", 60 | ) 61 | parser.add_argument( 62 | "--run_name", 63 | default=None, 64 | help="Name of the MLFlow run under which the current run data will be saved.", 65 | ) 66 | parser.add_argument( 67 | "-t", "--test", action="store_true", help="Whether to run as a test run" 68 | ) 69 | parser.add_argument( 70 | "--device", 71 | default=None, 72 | choices=["cpu", "gpu"], 73 | help="Which device to use (defaults to None, in which case the " 74 | "GPU is used if available, and if not the CPU is used).", 75 | ) 76 | 77 | 78 | def train( 79 | dataset: ReceptorLigandDataset, 80 | splits: Dict[str, Set[str]], 81 | settings: TrainSettings, 82 | model_class: Type[pl.LightningModule], 83 | out_dir: str, 84 | checkpoint: Optional[str], 85 | restore_full_state: bool, 86 | accelerator: Literal["cpu", "gpu"], 87 | test: bool = False, 88 | ) -> None: 89 | """ 90 | Trains a model specified by a settings YAML file. 91 | 92 | :param dataset: The dataset to use for training. 93 | :param splits: Dictionary containing the train/test/validation splits as string names 94 | of the instances in the dataset, stored under the respective keys "train", "test", and "validation". 95 | :param settings: The settings for the model. 96 | :param model_class: The model class to use. 97 | :param out_dir: The directory in which to save the results. 98 | :param checkpoint: The path to a PyTorch checkpoint file containing model weights. 99 | :param restore_full_state: Whether to restore the full state of the Trainer (including 100 | optimizer state, schedulers, etc.) from the provided checkpoint. Only applicable if 101 | checkpoint is passed. 102 | :param accelerator: Which accelerator to use for training 103 | (either "cpu" or "gpu"). 
104 | :param test: Whether to run as a test run. 105 | """ 106 | 107 | pl.seed_everything(123, workers=True) 108 | 109 | param_dict = settings.distribute_model_params(model_class) 110 | 111 | if "test_results_filepath" not in param_dict[model_class.__name__]: 112 | param_dict[model_class.__name__]["test_results_filepath"] = os.path.join( 113 | out_dir, "test_results.csv" 114 | ) 115 | 116 | lg.basicConfig(format="%(asctime)s %(levelname)-8s: %(message)s") 117 | 118 | if not os.path.exists(out_dir): 119 | lg.info(f"Specified output directory {out_dir} does not exist, creating...") 120 | os.mkdir(out_dir) 121 | 122 | num_train_samples = len(splits["train"]) 123 | num_test_samples = len(splits["test"]) 124 | num_validation_samples = len(splits["validation"]) 125 | total_num_samples = num_train_samples + num_test_samples + num_validation_samples 126 | 127 | lg.info( 128 | f"Received train/test/validation splits of " 129 | f"{num_train_samples / total_num_samples:.2f}/" 130 | f"{num_test_samples / total_num_samples:.2f}/" 131 | f"{num_validation_samples / total_num_samples:.2f}" 132 | ) 133 | 134 | datamodule, model = setup_model( 135 | dataset, splits, param_dict, model_class, checkpoint 136 | ) 137 | 138 | exp_name = settings.experiment_name + " test" if test else settings.experiment_name 139 | run_name = settings.run_name + " test" if test else settings.run_name 140 | mlflow_logger = get_mlflow_logger(exp_name, run_name, settings) 141 | 142 | checkpoint_path = settings.checkpoint_outfile 143 | 144 | i = 1 145 | while os.path.exists(checkpoint_path): 146 | checkpoint_path = f"{settings.checkpoint_outfile}-v{i}" 147 | i += 1 148 | 149 | lightning_args = settings.distribute_params(pl.Trainer, ModelCheckpoint) 150 | checkpoint_args = lightning_args["ModelCheckpoint"] 151 | 152 | checkpoint_callback = ModelCheckpoint( 153 | dirpath=out_dir, 154 | monitor=settings.checkpoint_metric, 155 | filename=checkpoint_path, 156 | **checkpoint_args, 157 | ) 158 | 159 | trainer_args = 
lightning_args["Trainer"] 160 | trainer_args["fast_dev_run"] = test 161 | 162 | if "accelerator" not in trainer_args: 163 | trainer_args["accelerator"] = accelerator 164 | if "enable_progress_bar" not in trainer_args: 165 | trainer_args["enable_progress_bar"] = False 166 | 167 | trainer = pl.Trainer( 168 | logger=mlflow_logger, 169 | callbacks=[checkpoint_callback], 170 | log_every_n_steps=settings.steps_per_log, 171 | devices=1, 172 | **trainer_args, 173 | ) 174 | 175 | try: 176 | if restore_full_state: 177 | trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint) 178 | else: 179 | trainer.fit(model, datamodule=datamodule) 180 | except RuntimeError as e: 181 | lg.error( 182 | f"RuntimeError: {e} Saving model checkpoint at error to: {checkpoint_path}" 183 | ) 184 | sys.exit(1) 185 | 186 | model = model_class.load_from_checkpoint( 187 | checkpoint_callback.best_model_path, 188 | network=model.network, 189 | **param_dict[model_class.__name__], 190 | ) 191 | 192 | trainer.test(model, datamodule=datamodule) 193 | 194 | 195 | def train_from_args( 196 | args: argparse.Namespace, model_class: Type[pl.LightningModule] 197 | ) -> None: 198 | """Runs `train()` using command line arguments.""" 199 | device, accelerator = get_device(args.device) 200 | 201 | with h5py.File(args.data_path) as hdf5_file: 202 | dataset = ReceptorLigandDataset.from_hdf5_file(hdf5_file, device=device) 203 | splits = load_splits_file(args.splits, dataset) 204 | settings = TrainSettings.from_yaml(args.config) 205 | 206 | out_dir = args.out_dir 207 | checkpoint = args.checkpoint 208 | restore_full_state = args.restore_full_state 209 | test = args.test 210 | 211 | if hasattr(settings, "out_dir"): 212 | if args.out_dir != ".": 213 | lg.warning( 214 | "Out directory found in settings file. 
Command line will supersede settings.", 215 | ) 216 | else: 217 | out_dir = settings.out_dir 218 | if hasattr(settings, "run_name"): 219 | if args.run_name: 220 | lg.warning( 221 | "Run name also found in settings file. Command line will supersede settings.", 222 | ) 223 | settings.run_name = args.run_name 224 | if hasattr(settings, "checkpoint"): 225 | if args.checkpoint: 226 | lg.warning( 227 | "Checkpoint found in settings file. " 228 | "Command line will supersede settings and load from checkpoint.", 229 | ) 230 | else: 231 | checkpoint = settings.checkpoint 232 | if hasattr(settings, "test"): 233 | if args.test: 234 | lg.warning( 235 | "Test status found in settings file. Command line will supersede settings.", 236 | ) 237 | else: 238 | test = settings.test 239 | 240 | train( 241 | dataset, 242 | splits, 243 | settings, 244 | model_class, 245 | out_dir, 246 | checkpoint, 247 | restore_full_state, 248 | accelerator, 249 | test, 250 | ) 251 | -------------------------------------------------------------------------------- /loopgen/model/types.py: -------------------------------------------------------------------------------- 1 | """ 2 | Types used across many or all models. 
3 | """ 4 | 5 | from typing import Any, Dict, DefaultDict, Union, Tuple, Optional, Sequence 6 | 7 | import torch 8 | 9 | from ..structure import Structure, OrientationFrames 10 | from ..graph import ( 11 | ComplexData, 12 | StructureData, 13 | VectorFeatureComplexData, 14 | VectorFeatureStructureData, 15 | ScalarFeatureComplexData, 16 | ScalarFeatureStructureData, 17 | ) 18 | 19 | # all the possible types of CDR graph (either in complex or as a solo structure) 20 | ProteinGraph = Union[ComplexData, StructureData] 21 | VectorFeatureGraph = Union[VectorFeatureComplexData, VectorFeatureStructureData] 22 | ScalarFeatureGraph = Union[ScalarFeatureComplexData, ScalarFeatureStructureData] 23 | 24 | # dictionary for storing parameters for different classes 25 | ParamDictionary = DefaultDict[str, Dict[str, Any]] 26 | 27 | # The output of the CDRFrameDataModule collate() (used for training), consisting of: 28 | # 1. tuple of IDs, one associated with each CDR in the batch 29 | # 2. epitope structure 30 | # 3. ground truth CDR orientation frames 31 | CDRFramesBatch = Tuple[Tuple[str, ...], Structure, OrientationFrames] 32 | 33 | Score = Union[torch.Tensor, Sequence[torch.Tensor]] 34 | 35 | # The output of a forward process, consisting of 36 | # 1. tensor(s) of scores 37 | # 2. the noised graph sampled from q(x_t | x_0) 38 | # 3. an optional self-conditioning noised graph sampled from q(x_{t+1} | x_t) 39 | ForwardProcessOutput = Tuple[ 40 | Score, 41 | VectorFeatureGraph, 42 | Optional[VectorFeatureGraph], 43 | ] 44 | -------------------------------------------------------------------------------- /loopgen/model/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some useful utility functions, many of which are related 3 | to modifying or interacting with CDR graphs. 
4 | """ 5 | 6 | from typing import Optional, Tuple, Any, Union, List 7 | 8 | from functools import partial 9 | 10 | import torch 11 | import math 12 | import numpy as np 13 | from torch_geometric.data import HeteroData 14 | 15 | from .types import ProteinGraph 16 | from ..utils import node_type_subgraph, combine_coords 17 | from ..structure import Structure, OrientationFrames 18 | from ..graph import StructureData 19 | 20 | 21 | def get_cdr_epitope_subgraphs( 22 | graph: ProteinGraph, 23 | ) -> Tuple[StructureData, Optional[StructureData]]: 24 | """ 25 | Gets the CDR and epitope subgraphs of a StructureData or ComplexData object. 26 | If the graph is heterogeneous, gets the node type subgraphs 27 | for "ligand" (CDR) and "receptor" (epitope) and converts 28 | each into a homogeneous form. If the graph is homogeneous, it is 29 | checked for a `node_type` attribute, and if it is present, 30 | it is assumed that epitope nodes are 31 | `node_type == 0` and CDR nodes are `node_type == 1`. 32 | If `node_type` is not present, assumes the graph is just 33 | a PeptideGraph of an unbound CDR (epitope is None). 34 | """ 35 | if isinstance(graph, HeteroData): 36 | cdr = graph.node_type_subgraph(["ligand"]).to_homogeneous() 37 | epitope = graph.node_type_subgraph(["receptor"]).to_homogeneous() 38 | elif hasattr(graph, "node_type"): 39 | cdr = node_type_subgraph(graph, node_type=1) 40 | epitope = node_type_subgraph(graph, node_type=0) 41 | else: 42 | cdr = graph 43 | epitope = None 44 | 45 | return cdr, epitope 46 | 47 | 48 | def _get_feature( 49 | graph: ProteinGraph, feature_attr_name: str, key: str, node_type: int 50 | ) -> Any: 51 | """ 52 | Gets the feature tensor for a specified node-level attribute 53 | in a StructureData/ComplexData graph object. 54 | 55 | If the underlying object is a ComplexData object, looks up the attribute under `key`. 
56 | If the underlying object is a StructureData object with a node_type attribute, returns the 57 | feature tensor for nodes with `node_type == node_type`. If the underlying 58 | object is a StructureData object with no attribute node_type, returns 59 | the whole feature tensor. 60 | """ 61 | 62 | if isinstance(graph, HeteroData): 63 | cdr_features = getattr(graph[key], feature_attr_name) 64 | elif hasattr(graph, "node_type"): 65 | cdr_mask = graph.node_type == node_type 66 | cdr_features = getattr(graph, feature_attr_name)[cdr_mask] 67 | else: 68 | cdr_features = getattr(graph, feature_attr_name) 69 | 70 | return cdr_features 71 | 72 | 73 | get_node_feature_docstring = """ 74 | Gets the {} feature tensor for a specified node-level attribute 75 | in a StructureData/ComplexData graph object. 76 | 77 | :param graph: The graph object to get the features from. 78 | :param feature_attr_name: The name of the attribute storing the feature tensor. 79 | """ 80 | 81 | get_cdr_feature = partial(_get_feature, key="ligand", node_type=1) 82 | get_cdr_feature.__doc__ = get_node_feature_docstring.format("CDR") 83 | get_epitope_feature = partial(_get_feature, key="receptor", node_type=0) 84 | get_epitope_feature.__doc__ = get_node_feature_docstring.format("epitope") 85 | 86 | 87 | def _replace_features( 88 | graph: ProteinGraph, 89 | replacement_features: Any, 90 | key: str, 91 | node_type: int, 92 | inplace: bool = True, 93 | feature_attr_name: str = "x", 94 | ) -> ProteinGraph: 95 | """ 96 | Replaces features for residues in a StructureData/ComplexData graph object. 97 | 98 | If the underlying object is a ComplexData object, replaces the feature attribute under `key`. 99 | If the underlying object is a StructureData object with a node_type attribute, replaces the 100 | features of nodes with `node_type == node_type`. If the underlying 101 | object is a StructureData object with no attribute node_type, replaces 102 | the whole feature tensor. 
103 | """ 104 | if not inplace: 105 | graph = graph.clone() 106 | 107 | if isinstance(graph, HeteroData): 108 | setattr(graph[key], feature_attr_name, replacement_features) 109 | elif hasattr(graph, "node_type"): 110 | cdr_mask = graph.node_type == node_type 111 | features = getattr(graph, feature_attr_name) 112 | features[cdr_mask] = replacement_features 113 | else: 114 | setattr(graph, feature_attr_name, replacement_features) 115 | 116 | return graph 117 | 118 | 119 | replace_features_docstring = """ 120 | Replaces features for {} residues in a StructureData/ComplexData graph object. 121 | 122 | :param graph: The graph object to modify. 123 | :param replacement_features: The replacement features. 124 | :param inplace: Whether to modify the graph in-place or return a copy. (default: True) 125 | :param feature_attr_name: The name of the attribute to modify. (default: "x") 126 | """ 127 | 128 | 129 | replace_cdr_features = partial(_replace_features, key="ligand", node_type=1) 130 | replace_cdr_features.__doc__ = replace_features_docstring.format("CDR") 131 | replace_epitope_features = partial(_replace_features, key="receptor", node_type=0) 132 | replace_epitope_features.__doc__ = replace_features_docstring.format("epitope") 133 | 134 | 135 | def _pad_features( 136 | graph: ProteinGraph, 137 | num_pads: int, 138 | pad_value: float, 139 | key: str, 140 | inplace: bool = True, 141 | pad_dim: int = -1, 142 | feature_attr_name: str = "x", 143 | ) -> ProteinGraph: 144 | """ 145 | Pads node features in a graph by adding `num_pads` new padded features, 146 | each with a value of `pad_value`. Used in recycling to add initial recurrent 147 | features to the input graph (i.e. where the model's own predictions are provided 148 | as features in subsequent iterations). 149 | 150 | If the underlying object is a ComplexData object, pads the feature attribute under `key`. 151 | If the underlying object is a StructureData object, pads the whole underlying 152 | feature tensor. 
153 | """ 154 | 155 | if not inplace: 156 | graph = graph.clone() 157 | 158 | graph_is_hetero = isinstance(graph, HeteroData) 159 | 160 | if graph_is_hetero: 161 | features_to_pad = getattr(graph[key], feature_attr_name) 162 | else: 163 | features_to_pad = getattr(graph, feature_attr_name) 164 | 165 | if pad_dim < 0: 166 | num_dims_before_pad = -1 - pad_dim 167 | else: 168 | num_dims_before_pad = len(features_to_pad.shape) - pad_dim - 1 169 | 170 | pad = ((0, 0) * num_dims_before_pad) + (0, num_pads) 171 | 172 | padded_features = torch.nn.functional.pad(features_to_pad, pad=pad, value=pad_value) 173 | 174 | if graph_is_hetero: 175 | setattr(graph[key], feature_attr_name, padded_features) 176 | else: 177 | setattr(graph, feature_attr_name, padded_features) 178 | 179 | return graph 180 | 181 | 182 | pad_features_docstring = """ 183 | Pads {} features in a graph by adding `num_pads` new padded features, 184 | each with a value of `pad_value`. 185 | 186 | :param graph: A ProteinGraph object. 187 | :param num_pads: The number of new padded features to add. 188 | :param pad_value: The value to use for the new padded features. 189 | :param inplace: Whether to modify the graph in place. (default: True) 190 | :param pad_dim: The dimension along which to pad the features. (default: -1) 191 | :param feature_attr_name: The name of the feature attribute to pad. (default: "x") 192 | """ 193 | 194 | pad_cdr_features = partial(_pad_features, key="ligand") 195 | pad_cdr_features.__doc__ = pad_features_docstring.format("CDR") 196 | pad_epitope_features = partial(_pad_features, key="receptor") 197 | pad_epitope_features.__doc__ = pad_features_docstring.format("epitope") 198 | 199 | 200 | def _update_structure( 201 | graph: ProteinGraph, 202 | new_structure: Union[OrientationFrames, Structure], 203 | key: str, 204 | inplace: bool = True, 205 | ) -> ProteinGraph: 206 | """ 207 | Updates the underlying structure in a graph object. 
208 | 209 | If the underlying object is a ComplexData object, updates the Structure under `key`. 210 | If the underlying object is a StructureData object, updates the whole underlying 211 | Structure. 212 | """ 213 | 214 | if not inplace: 215 | graph = graph.clone() 216 | 217 | not_hetero = not isinstance(graph, HeteroData) 218 | no_node_type = not hasattr(graph, "node_type") 219 | 220 | if not_hetero: 221 | if no_node_type: 222 | graph.update_structure(new_structure) 223 | return graph 224 | graph.to_heterogeneous() 225 | 226 | graph.update_structure(new_structure, key=key) 227 | 228 | if not_hetero: 229 | graph.to_homogeneous() 230 | 231 | return graph 232 | 233 | 234 | update_structure_docstring = """ 235 | Updates the underlying {} structure in a graph object. 236 | 237 | :param graph: A ProteinGraph object. 238 | :param new_structure: The new structure to use. 239 | :param inplace: Whether to modify the graph in place. (default: True) 240 | """ 241 | 242 | update_cdr_structure = partial(_update_structure, key="ligand") 243 | update_cdr_structure.__doc__ = update_structure_docstring.format("CDR") 244 | update_epitope_structure = partial(_update_structure, key="receptor") 245 | update_epitope_structure.__doc__ = update_structure_docstring.format("epitope") 246 | 247 | 248 | def sinusoidal_encoding(value: torch.Tensor, channels: torch.Tensor, base: int = 100): 249 | """ 250 | Sinusoidal encoding of a value for some number of channels, as in the original 251 | Transformer (https://arxiv.org/abs/1706.03762). 252 | 253 | :param value: Rank 1 tensor containing N values to be encoded. 254 | :param channels: Channels over which the encoding is generated. This should be 255 | a rank 1 tensor of M positive integers, increasing by 1. 256 | :param base: Value used to scale the sin/cos function outputs in the final encoding. 257 | The larger this number is, the less variability there will be between encodings for different values. 
258 | The original Transformer paper used 10000 here but we prefer a smaller number to introduce 259 | more variation between encodings for shorter sequences. 260 | :returns: Tensor of shape (N, M) containing the M-dimensional encoding for 261 | each of the N input values. 262 | """ 263 | encoding = torch.where( 264 | channels % 2 == 0, 265 | torch.sin( 266 | value.unsqueeze(-1) / (base ** (2 * channels / channels[-1])).unsqueeze(0) 267 | ), 268 | torch.cos( 269 | value.unsqueeze(-1) / (base ** (2 * channels / channels[-1])).unsqueeze(0) 270 | ), 271 | ) 272 | return encoding 273 | 274 | 275 | def axis_angle_to_matrix(axis, angle_radians): 276 | """ 277 | Converts a rotation from an axis-angle representation to a rotation matrix. 278 | 279 | Parameters: 280 | - axis: torch.Tensor of shape (3,) - Arbitrary axis of rotation. 281 | - angle_radians: float - Rotation angle in radians. 282 | 283 | Returns: 284 | - rotation_matrix: torch.Tensor of shape (3, 3) - Tensor containing the rotation matrix. 
285 | """ 286 | 287 | axis = axis / torch.norm(axis) 288 | 289 | # Rodrigues' rotation formula 290 | cos_theta = math.cos(angle_radians) 291 | sin_theta = math.sin(angle_radians) 292 | cross_product_matrix = torch.tensor( 293 | [[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]] 294 | ) 295 | 296 | rotation_matrix = ( 297 | torch.eye(3) 298 | + sin_theta * cross_product_matrix 299 | + (1 - cos_theta) * torch.matmul(cross_product_matrix, cross_product_matrix) 300 | ) 301 | 302 | return rotation_matrix 303 | 304 | 305 | def compute_axis(t1: torch.Tensor, t2: torch.Tensor): 306 | return torch.cross(t1, t2) 307 | 308 | 309 | def compute_angle(t1: torch.Tensor, t2: torch.Tensor): 310 | return torch.arccos(torch.dot(t1, t2) / (torch.norm(t1) * torch.norm(t2))) 311 | 312 | 313 | def permute_epitopes( 314 | structures: List[Tuple[str, Structure, Union[Structure, OrientationFrames]]], 315 | align: bool = True, 316 | ) -> List[Tuple[str, Structure, Union[Structure, OrientationFrames]]]: 317 | """ 318 | For a given list of structure tuples of the form (names, epitope, cdr), 319 | permutes the epitopes in the batch. 320 | 321 | :param structures: A list of structure tuples of the form (names, epitope, cdr). 322 | :param align: Whether to align the permuted epitope to the original epitope by aligning 323 | their difference vectors. 324 | :returns: List of the same form as the input list of structures, but with the epitopes permuted. 
325 | """ 326 | names, epitopes, cdrs = map(tuple, zip(*structures)) 327 | permutation = np.random.permutation(len(epitopes)) 328 | epitopes_permuted = [epitopes[i].clone() for i in permutation] 329 | 330 | if align: 331 | cdrs_aligned = [] 332 | epitopes_permuted_aligned = [] 333 | for i in range(len(epitopes_permuted)): 334 | epitope = epitopes[i] 335 | cdr = cdrs[i] 336 | cdr_centered, cdr_center = cdr.center() 337 | _, epitope_center = epitope.center() 338 | 339 | perm_epitope_centered, _ = epitopes_permuted[i].center() 340 | 341 | ref_vector_WT_epitope = cdr_center - epitope_center 342 | 343 | permuted_receptor_aligned = perm_epitope_centered.translate( 344 | -ref_vector_WT_epitope 345 | ) 346 | epitopes_permuted_aligned.append(permuted_receptor_aligned) 347 | cdrs_aligned.append(cdr_centered) 348 | 349 | epitopes_permuted = epitopes_permuted_aligned 350 | cdrs = cdrs_aligned 351 | 352 | structures_permuted = list(zip(names, epitopes_permuted, cdrs)) 353 | 354 | return structures_permuted 355 | 356 | 357 | def permute_cdrs( 358 | structures: List[ 359 | Tuple[ 360 | str, 361 | Union[Structure, OrientationFrames], 362 | Union[Structure, OrientationFrames], 363 | ] 364 | ], 365 | match_by_length: bool = True, 366 | ) -> List[Tuple[str, Structure, Structure]]: 367 | """ 368 | For a given list of structure tuples of the form (names, epitope, cdr), 369 | permutes the CDRs in the batch. 370 | 371 | :param structures: A list of structure tuples of the form (names, epitope, cdr). 372 | :param match_by_length: Whether to only swap CDRs with other CDRs of the same length 373 | when permuting. If False, CDRs are permuted across the whole batch 374 | regardless of length. 375 | :returns: List of the same form as the input list of structures, but with the CDRs permuted. 
376 | """ 377 | names, epitopes, cdrs = map(tuple, zip(*structures)) 378 | 379 | structures_permuted = [] 380 | if match_by_length: 381 | individual_cdr_lengths = np.array([len(cdr) for cdr in cdrs]) 382 | cdr_lengths, cdr_len_counts = np.unique( 383 | individual_cdr_lengths, return_counts=True 384 | ) 385 | singleton_lengths = set(cdr_lengths[cdr_len_counts == 1])  # lengths that occur only once 386 | 387 | index_mapping = {} 388 | for cdr_len in cdr_lengths: 389 | if cdr_len in singleton_lengths: 390 | continue 391 | indices = np.nonzero(individual_cdr_lengths == cdr_len)[0] 392 | perm_indices = indices[np.random.permutation(len(indices))] 393 | index_mapping.update(dict(zip(indices, perm_indices))) 394 | 395 | for i in range(len(structures)): 396 | if i in index_mapping: 397 | permuted_cdr = cdrs[index_mapping[i]] 398 | structures_permuted.append((names[i], epitopes[i], permuted_cdr)) 399 | 400 | else: 401 | structures_permuted = [ 402 | (name, ep, cdrs[i]) 403 | for i, (name, ep, _) in zip( 404 | np.random.permutation(len(structures)), structures 405 | ) 406 | ] 407 | 408 | return structures_permuted 409 | 410 | 411 | def translate_cdrs_away( 412 | structures: List[ 413 | Tuple[ 414 | str, 415 | Union[Structure, OrientationFrames], 416 | Union[Structure, OrientationFrames], 417 | ] 418 | ], 419 | distance: float = 20.0, 420 | ) -> List[ 421 | Tuple[str, Union[Structure, OrientationFrames], Union[Structure, OrientationFrames]] 422 | ]: 423 | """ 424 | For a given list of structure tuples of the form (names, epitope, cdr), 425 | translates the CDRs away from the epitope centre of mass. 426 | 427 | :param structures: A list of structure tuples of the form (names, epitope, cdr). 428 | :param distance: The distance to translate the CDRs away from the epitope centre of mass. 429 | :returns: List of the same form as the input list of structures, but with the CDRs translated 430 | in the opposite direction from the epitope. 
431 | """ 432 | 433 | names, epitopes, cdrs = map(tuple, zip(*structures)) 434 | cdrs_translated = [] 435 | 436 | for i in range(len(cdrs)): 437 | epitope = epitopes[i] 438 | cdr = cdrs[i] 439 | 440 | _, cdr_center = cdr.center() 441 | _, epitope_center = epitope.center() 442 | displacement = torch.nn.functional.normalize( 443 | cdr_center - epitope_center, dim=-1 444 | ) 445 | 446 | cdr_translated = cdr.translate(displacement * distance) 447 | cdrs_translated.append(cdr_translated) 448 | 449 | structures_translated = list(zip(names, epitopes, cdrs_translated)) 450 | 451 | return structures_translated 452 | -------------------------------------------------------------------------------- /loopgen/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .gvp import ( 2 | GVPN, 3 | GVPMessagePassing, 4 | GeometricVectorPerceptron, 5 | GVPAttentionTypes 6 | ) 7 | from .diffusion import ( 8 | Gaussian3DForwardProcess, 9 | IGSO3ForwardProcess, 10 | R3ReverseProcess, 11 | SO3ReverseProcess 12 | ) -------------------------------------------------------------------------------- /loopgen/visualisation.py: -------------------------------------------------------------------------------- 1 | from typing import OrderedDict, Optional, Tuple, List 2 | 3 | import plotly.graph_objects as go 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import collections 9 | import torch 10 | import seaborn as sns 11 | 12 | from sklearn.metrics import confusion_matrix 13 | 14 | from .structure import Structure, LinearStructure, AminoAcid3 15 | from .utils import combine_coords, get_covalent_bonds 16 | 17 | 18 | def _extend_contig(covalent_bond_matrix: torch.Tensor, indices: list): 19 | """ 20 | Extends a list of residue indices if any residues are covalently bonded to the 21 | residue indexed by the final element of `indices`. 
22 | """ 23 | next_residue_bonds = covalent_bond_matrix[:, indices[-1]] == 1.0 24 | if torch.any(next_residue_bonds): 25 | next_index = torch.nonzero(next_residue_bonds).squeeze(-1)[0].item() 26 | indices.append(next_index) 27 | return _extend_contig(covalent_bond_matrix, indices) 28 | else: 29 | return indices 30 | 31 | 32 | def get_contiguous_regions( 33 | covalent_bond_matrix: torch.Tensor, 34 | ) -> OrderedDict[int, list[int]]: 35 | """ 36 | Gets an integer-indexed dictionary of lists of indices corresponding 37 | to covalently-bonded atoms, obtained from an input binary covalent bond matrix. 38 | The output dictionary has the form {region_num: region_residue_indices}. 39 | """ 40 | contiguous_regions = collections.OrderedDict() 41 | N_bonded_indices, C_bonded_indices = torch.nonzero( 42 | covalent_bond_matrix, as_tuple=True 43 | ) 44 | current_N_index, current_C_index = ( 45 | N_bonded_indices[0].item(), 46 | C_bonded_indices[0].item(), 47 | ) 48 | region_num = 0 49 | while ( 50 | current_N_index < covalent_bond_matrix.shape[-2] - 1 51 | and current_C_index < covalent_bond_matrix.shape[-1] 52 | ): 53 | contig = _extend_contig( 54 | covalent_bond_matrix, indices=[current_C_index, current_N_index] 55 | ) 56 | contiguous_regions[region_num] = contig 57 | region_num += 1 58 | current_N_index = N_bonded_indices[N_bonded_indices > contig[-1]] 59 | current_C_index = C_bonded_indices[C_bonded_indices > contig[-1]] 60 | if len(current_N_index) > 0 and len(current_C_index) > 0: 61 | current_N_index = current_N_index[0].item() 62 | current_C_index = current_C_index[0].item() 63 | else: 64 | break 65 | 66 | return contiguous_regions 67 | 68 | 69 | def plot_structure( 70 | structure: Structure, 71 | figure: Optional[go.Figure] = None, 72 | backbone_colour: str = "black", 73 | sidechain_colour: str = "yellow", 74 | ): 75 | """Generates a 3-D plot of a Structure object.""" 76 | if isinstance(structure, LinearStructure): 77 | contiguous_regions = {0: 
list(range(len(structure)))} 78 | else: 79 | covalent_bond_matrix = get_covalent_bonds( 80 | structure.N_coords.unsqueeze(-2), structure.C_coords.unsqueeze(-3) 81 | ) 82 | contiguous_regions = get_contiguous_regions(covalent_bond_matrix) 83 | 84 | residue_nums_by_contig = {} 85 | all_contig_nums = [] 86 | for contig, residues in contiguous_regions.items(): 87 | all_contig_nums.extend([contig] * len(residues)) 88 | for res_num in residues: 89 | residue_nums_by_contig[res_num] = contig 90 | 91 | all_coords = pd.DataFrame( 92 | combine_coords( 93 | structure.N_coords, 94 | structure.CA_coords, 95 | structure.C_coords, 96 | structure.CB_coords, 97 | ) 98 | .cpu() 99 | .detach() 100 | .numpy(), 101 | columns=["x", "y", "z"], 102 | ) 103 | 104 | all_coords["contig"] = sorted( 105 | np.repeat(all_contig_nums, 4), key=residue_nums_by_contig.get 106 | ) 107 | 108 | backbone_coords = pd.DataFrame( 109 | combine_coords(structure.N_coords, structure.CA_coords, structure.C_coords) 110 | .cpu() 111 | .detach() 112 | .numpy(), 113 | columns=["x", "y", "z"], 114 | ) 115 | 116 | backbone_coords["contig"] = sorted( 117 | np.repeat(all_contig_nums, 3), key=residue_nums_by_contig.get 118 | ) 119 | 120 | sidechain_coords = pd.DataFrame( 121 | combine_coords(structure.CA_coords, structure.CB_coords).cpu().detach().numpy(), 122 | columns=["x", "y", "z"], 123 | ) 124 | sidechain_coords["contig"] = sorted( 125 | np.repeat(all_contig_nums, 2), key=residue_nums_by_contig.get 126 | ) 127 | 128 | atom_scatter_plots = [ 129 | go.Scatter3d( 130 | x=df["x"], 131 | y=df["y"], 132 | z=df["z"], 133 | marker=dict(color=backbone_colour), 134 | mode="markers", 135 | name=frag_type, 136 | ) 137 | for frag_type, df in all_coords.groupby("contig") 138 | ] 139 | 140 | backbone_line_plots = [ 141 | go.Scatter3d( 142 | x=df["x"], 143 | y=df["y"], 144 | z=df["z"], 145 | line=dict(color=backbone_colour), 146 | mode="lines", 147 | showlegend=False, 148 | ) 149 | for frag_type, df in 
backbone_coords.groupby("contig") 150 | ] 151 | 152 | sidechain_line_plots = [ 153 | go.Scatter3d( 154 | x=sidechain_coords.iloc[(i - 1) : (i + 1)]["x"], 155 | y=sidechain_coords.iloc[(i - 1) : (i + 1)]["y"], 156 | z=sidechain_coords.iloc[(i - 1) : (i + 1)]["z"], 157 | line=dict(color=sidechain_colour), 158 | mode="lines", 159 | showlegend=False, 160 | ) 161 | for i in range(1, len(sidechain_coords), 2) 162 | ] 163 | 164 | if figure is not None: 165 | figure.add_traces( 166 | data=atom_scatter_plots + backbone_line_plots + sidechain_line_plots 167 | ) 168 | else: 169 | figure = go.Figure( 170 | data=atom_scatter_plots + backbone_line_plots + sidechain_line_plots 171 | ) 172 | figure.update_layout(height=600, width=1000, autosize=False) 173 | 174 | return figure 175 | 176 | 177 | def create_color_map( 178 | minval: float, 179 | maxval: float, 180 | colors=["0.8", "0"], 181 | masked_vals_color=None, 182 | return_sm=False, 183 | set_under=None, 184 | set_over=None, 185 | ) -> matplotlib.colors.LinearSegmentedColormap: 186 | """Creates a colorblind-friendly color map for plotting.""" 187 | mymap = matplotlib.colors.LinearSegmentedColormap.from_list( 188 | "mycolors", colors, N=256 189 | ) 190 | if set_under is None: 191 | set_under = colors[0] 192 | if set_over is None: 193 | set_over = colors[-1] 194 | 195 | mymap.set_under(set_under) 196 | mymap.set_over(set_over) 197 | 198 | if masked_vals_color is not None: 199 | mymap.set_bad(masked_vals_color) 200 | 201 | if return_sm: 202 | sm = plt.cm.ScalarMappable( 203 | cmap=mymap, norm=plt.Normalize(vmin=minval, vmax=maxval) 204 | ) 205 | sm._A = [] 206 | return mymap, sm 207 | 208 | return mymap 209 | 210 | 211 | def plot_color_map_simple(colormap, npoints: int = 256) -> None: 212 | if isinstance(colormap, str): 213 | colormap = matplotlib.cm.get_cmap(colormap) 214 | 215 | f = plt.figure(figsize=(2, 8)) 216 | ax = f.gca() 217 | ax.scatter( 218 | [0] * npoints, 219 | np.linspace(0, 1, npoints), 220 | 
color=colormap(np.linspace(0, 1, npoints)), 221 | ) 222 | ax.set_xticklabels([]) 223 | plt.show(block=False) 224 | 225 | 226 | readable_map_for_chars = create_color_map( 227 | 1, 0, colors=["white", "Gold", "SpringGreen", "DeepSkyBlue"] 228 | ) 229 | readable_map_for_chars_centered = create_color_map( 230 | 1, 0, colors=["Gold", "Moccasin", "white", "SpringGreen", "DeepSkyBlue"] 231 | ) 232 | readable_map_for_chars_vmin = create_color_map( 233 | 0, 234 | 1, 235 | colors=[ 236 | "AntiqueWhite", 237 | "Bisque", 238 | "Gold", 239 | "GoldenRod", 240 | "GreenYellow", 241 | "SpringGreen", 242 | "LightSkyBlue", 243 | "DeepSkyBlue", 244 | ], 245 | set_under="white", 246 | set_over="blue", 247 | ) 248 | 249 | 250 | def aa_conf_mat_heatmap( 251 | true_labs: pd.Series, 252 | pred_labs: pd.Series, 253 | normalize: str = "pred", 254 | cmap: Optional[str] = None, 255 | annot: bool = True, 256 | threshold: float = 0.10, 257 | figsize: Tuple[int, int] = (22, 16), 258 | xlab: str = "Predicted residues", 259 | ylab: str = "True residues", 260 | ): 261 | """ 262 | Plots a confusion matrix heatmap for a set of true labels/predicted amino acids. 263 | Expects amino acids encoded as integers 0-19, sorted in alphabetical order of 3 letter codes. 264 | Used to visualise the performance of sequence prediction models. 
265 | """ 266 | 267 | labs = np.arange(20) 268 | 269 | conf_mat = confusion_matrix(true_labs, pred_labs, labels=labs, normalize=normalize) 270 | 271 | fig = plt.figure(figsize=figsize) 272 | gs = fig.add_gridspec(2, 3, height_ratios=[1.5, 5], width_ratios=[1.5, 4, 0.25]) 273 | dummy1_ax = fig.add_subplot(gs[0, 0]) 274 | bar1_ax = fig.add_subplot(gs[0, 1]) 275 | dummy2_ax = fig.add_subplot(gs[0, 2]) 276 | bar2_ax = fig.add_subplot(gs[1, 0]) 277 | hm_ax = fig.add_subplot(gs[1, 1], sharex=bar1_ax, sharey=bar2_ax) 278 | cmap_ax = fig.add_subplot(gs[1, 2]) 279 | 280 | if normalize == "true": 281 | cbar_title = "Row\nproportion\n" 282 | elif normalize == "pred": 283 | cbar_title = "Column\nproportion\n" 284 | else: 285 | cbar_title = "" 286 | 287 | cmap_ax.set_title(cbar_title, fontsize=18) 288 | cmap_ax.tick_params(labelsize=16) 289 | 290 | dummy1_ax.axis("off") 291 | dummy2_ax.axis("off") 292 | 293 | if normalize is not None: 294 | vmin = 0 295 | vmax = 1 296 | else: 297 | vmin = None 298 | vmax = None 299 | 300 | if cmap is None: 301 | cmap = readable_map_for_chars 302 | 303 | sns.heatmap( 304 | conf_mat, 305 | xticklabels=AminoAcid3.__members__, 306 | yticklabels=AminoAcid3.__members__, 307 | ax=hm_ax, 308 | cmap=cmap, 309 | annot=annot, 310 | annot_kws={"fontsize": 16}, 311 | cbar_ax=cmap_ax, 312 | linewidths=2, 313 | vmin=vmin, 314 | vmax=vmax, 315 | ) 316 | hm_ax.tick_params(axis="both", labelrotation=0, left=True) 317 | 318 | hm_ax.xaxis.tick_top() 319 | hm_ax.xaxis.set_label_position("top") 320 | for t in hm_ax.texts: 321 | if float(t.get_text()) >= threshold: 322 | t.set_text(t.get_text()) 323 | else: 324 | t.set_text("") 325 | 326 | hm_ax.tick_params(labelsize=14) 327 | bar1_ax.tick_params(labelsize=14) 328 | bar2_ax.tick_params(labelsize=14) 329 | 330 | true_lab_counts = true_labs.value_counts(normalize=True).sort_index() 331 | true_lab_counts = true_lab_counts.reindex(labs, fill_value=0.0) 332 | y_tick_pos = [i + 0.5 for i in range(20)] 333 | bar2_ax.barh( 
334 | y=y_tick_pos, width=true_lab_counts.values, align="center", edgecolor="black" 335 | ) 336 | bar2_ax.set_xlim((0, max(true_lab_counts) + 0.05)) 337 | bar2_ax.invert_xaxis() 338 | bar2_ax.set_xlabel("\nFrequency", fontsize=20) 339 | bar2_ax.set_ylabel(f"{ylab}\n", fontsize=20) 340 | bar2_ax.tick_params(axis="y", left=True) 341 | 342 | pred_lab_counts = pred_labs.value_counts(normalize=True).sort_index() 343 | pred_lab_counts = pred_lab_counts.reindex(labs, fill_value=0.0) 344 | x_tick_pos = [i + 0.5 for i in range(20)] 345 | bar1_ax.bar( 346 | x=x_tick_pos, height=pred_lab_counts.values, align="center", edgecolor="black" 347 | ) 348 | bar1_ax.xaxis.tick_top() 349 | bar1_ax.xaxis.set_label_position("top") 350 | bar1_ax.set_ylabel("Frequency\n", fontsize=20) 351 | bar1_ax.set_xlabel(f"{xlab}\n", fontsize=20) 352 | 353 | return fig 354 | 355 | 356 | def ramachandran_plot( 357 | structures: List[LinearStructure], ax: Optional[plt.Axes] = None, **kwargs 358 | ): 359 | """ 360 | Makes a ramachandran plot using a list of structures. If `ax` is not specified, 361 | a new figure is created. Additional keyword arguments are passed to `plt.plot`. 362 | 363 | :param structures: A list of structures to plot. 364 | :param ax: The matplotlib axes on which to plot. 365 | :param kwargs: Additional keyword arguments to pass to `plt.plot`. 366 | :return: The matplotlib axes on which the plot was made. 
367 | """ 368 | 369 | phi_arr = [] 370 | omega_arr = [] 371 | psi_arr = [] 372 | 373 | for structure in structures: 374 | angles = structure.get_backbone_dihedrals() 375 | phi = angles[1:, 0] 376 | omega = angles[:-1, 1] 377 | psi = angles[:-1, 2] 378 | 379 | phi_arr.extend(phi.numpy().tolist()) 380 | omega_arr.extend(omega.numpy().tolist()) 381 | psi_arr.extend(psi.numpy().tolist()) 382 | 383 | if ax is None: 384 | _, ax = plt.subplots(figsize=(8, 8)) 385 | 386 | ax.plot(phi_arr, psi_arr, ".", **kwargs) 387 | 388 | ax.set_xlim(-np.pi, np.pi) 389 | ax.set_ylim(-np.pi, np.pi) 390 | 391 | ax.set_aspect("equal") 392 | 393 | ax.axhline(color="black", linewidth=0.8) # Add horizontal axis 394 | ax.axvline(color="black", linewidth=0.8) # Add vertical axis 395 | 396 | ax.grid(True, linestyle="--", alpha=0.2) # Add grid lines 397 | return ax 398 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = ignore::DeprecationWarning -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | if __name__ == "__main__": 6 | setup( 7 | name="loopgen", 8 | version="0.0.1", 9 | packages=find_packages(), 10 | author="Matt Greenig", 11 | author_email="mg989@cam.ac.uk", 12 | description="LoopGen: De novo design of peptide CDR binding loops with SE(3) diffusion models.", 13 | include_package_data=True, 14 | entry_points={"console_scripts": ["loopgen=loopgen.__main__:main"]}, 15 | ) 16 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | env_name="loopgen" 4 | 5 | # check OS 6 | unameOut="$(uname -s)"
7 | case "${unameOut}" in 8 | Linux*) OS=Linux;; 9 | Darwin*) OS=Mac;; 10 | *) OS="Other" 11 | esac 12 | 13 | # set up conda environment 14 | conda create -n $env_name 15 | eval "$(conda shell.bash hook)" 16 | conda activate $env_name 17 | 18 | # if on a mac and not on an ARM chip change the env config 19 | if [ $OS == "Mac" ] && [ $(uname -p) != "arm" ]; then 20 | conda config --env --set subdir osx-64 21 | fi 22 | 23 | conda install nomkl 24 | 25 | mamba env update -n $env_name -f envs/environment.yml 26 | 27 | # make sure to install compatible torch packages if on Linux to allow GPU usage 28 | if [ $OS == "Linux" ]; then 29 | pip install https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp311-cp311-linux_x86_64.whl 30 | pip install https://data.pyg.org/whl/torch-2.0.0%2Bcu117/torch_scatter-2.1.1%2Bpt20cu117-cp311-cp311-linux_x86_64.whl 31 | pip install https://data.pyg.org/whl/torch-2.0.0%2Bcu117/torch_sparse-0.6.17%2Bpt20cu117-cp311-cp311-linux_x86_64.whl 32 | pip install https://data.pyg.org/whl/torch-2.0.0%2Bcu117/torch_cluster-1.6.1%2Bpt20cu117-cp311-cp311-linux_x86_64.whl 33 | elif [ $OS == "Mac" ]; then 34 | pip install https://download.pytorch.org/whl/cpu/torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl 35 | pip install numpy==1.25.2 --force-reinstall 36 | pip install scipy==1.11.1 --force-reinstall 37 | pip install https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_sparse-0.6.17-cp311-cp311-macosx_10_9_universal2.whl 38 | pip install https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_scatter-2.1.1-cp311-cp311-macosx_10_9_universal2.whl 39 | pip install https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_cluster-1.6.1-cp311-cp311-macosx_10_9_universal2.whl 40 | else 41 | pip install torch==2.0.1 42 | pip install torch-sparse==0.6.17 43 | pip install torch-scatter==2.1.1 44 | pip install torch-cluster==1.6.1 45 | fi 46 | 47 | # install pip requirements 48 | pip install -r envs/pip_requirements.txt 49 | 50 | # install the package itself 51 | pip install . 
52 | 53 | # set lib path for C++ libraries 54 | env_path=$(conda info --base)/envs/$env_name 55 | 56 | activate_env_vars=$env_path/etc/conda/activate.d/env_vars.sh 57 | echo "export LD_LIBRARY_PATH=$env_path/lib:$LD_LIBRARY_PATH" > $activate_env_vars 58 | 59 | deactivate_env_vars=$env_path/etc/conda/deactivate.d/env_vars.sh 60 | echo "unset LD_LIBRARY_PATH" > $deactivate_env_vars -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mgreenig/loopgen/e2726c8f24e84fdfb824f3616205a3d6b0d9703b/tests/__init__.py -------------------------------------------------------------------------------- /tests/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mgreenig/loopgen/e2726c8f24e84fdfb824f3616205a3d6b0d9703b/tests/nn/__init__.py -------------------------------------------------------------------------------- /tests/nn/test_gvp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests the Geometric Vector Perceptron module. 
3 | """ 4 | 5 | from typing import Tuple 6 | import torch 7 | import pytest 8 | from torch_scatter import scatter_sum 9 | from e3nn.o3 import rand_matrix 10 | 11 | from loopgen.nn.gvp import ( 12 | combine_gvp_features, 13 | separate_gvp_features, 14 | GeometricVectorPerceptron, 15 | GVPDropout, 16 | GVPLayerNorm, 17 | GVPMessage, 18 | GVPUpdate, 19 | GVPAttentionTypes, 20 | GVPAttention, 21 | GVPAttentionStrategySelector, 22 | GVPMessagePassing, 23 | ) 24 | from loopgen.utils import is_rotation_equivariant 25 | 26 | # Some parameters for the testing data 27 | NUM_NODES = 100 28 | BATCH_SIZE = 10 29 | NUM_SCALAR_FEATURES = 20 30 | NUM_VECTOR_FEATURES = 5 31 | NUM_COORDS = 1 32 | VECTOR_DIM_SIZE = 3 33 | COORD_DIM_SIZE = 3 34 | NUM_EDGES = 500 35 | NUM_EDGE_SCALAR_FEATURES = 10 36 | NUM_EDGE_VECTOR_FEATURES = 4 37 | 38 | assert NUM_NODES % BATCH_SIZE == 0, "NUM_NODES must be divisible by BATCH_SIZE" 39 | 40 | 41 | @pytest.fixture 42 | def gvp_data() -> Tuple[torch.Tensor, torch.Tensor]: 43 | """Generates some scalar and vector features for a GVP.""" 44 | torch.manual_seed(123) 45 | scalar_features = torch.randn((NUM_NODES, NUM_SCALAR_FEATURES)) 46 | vector_features = torch.randn((NUM_NODES, NUM_VECTOR_FEATURES, VECTOR_DIM_SIZE)) 47 | return scalar_features, vector_features 48 | 49 | 50 | @pytest.fixture 51 | def gvp_edge_data() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 52 | """Generates an edge index, scalar edge features, and vector edge features for a GVP-GNN.""" 53 | torch.manual_seed(123) 54 | edge_index = torch.randint(0, NUM_NODES, (2, NUM_EDGES)) 55 | edge_scalar_features = torch.randn((NUM_EDGES, NUM_EDGE_SCALAR_FEATURES)) 56 | edge_vector_features = torch.randn( 57 | (NUM_EDGES, NUM_EDGE_VECTOR_FEATURES, VECTOR_DIM_SIZE) 58 | ) 59 | return edge_index, edge_scalar_features, edge_vector_features 60 | 61 | 62 | @pytest.fixture 63 | def orientations() -> torch.Tensor: 64 | """Generates some orientation (rotation) matrices.""" 65 | 
torch.manual_seed(123) 66 | return rand_matrix(NUM_NODES) 67 | 68 | 69 | def test_combine_gvp_features(gvp_data): 70 | """ 71 | Tests the combine_gvp_features() function, which combines separate scalar and vector features 72 | into a single tensor, where the vector features are flattened. 73 | """ 74 | scalar_features, vector_features = gvp_data 75 | combined_features = combine_gvp_features(scalar_features, vector_features) 76 | assert combined_features.shape == ( 77 | NUM_NODES, 78 | NUM_SCALAR_FEATURES + (NUM_VECTOR_FEATURES * VECTOR_DIM_SIZE), 79 | ), "Wrong shape for combined features" 80 | 81 | 82 | def test_separate_gvp_features(gvp_data): 83 | """ 84 | Tests the separate_gvp_features() function, which separates a flattened tensor of scalar/vector features 85 | into two separate tensors. 86 | """ 87 | scalar_features, vector_features = gvp_data 88 | combined_features = combine_gvp_features(scalar_features, vector_features) 89 | output_scalar_features, output_vector_features = separate_gvp_features( 90 | combined_features, NUM_VECTOR_FEATURES, VECTOR_DIM_SIZE 91 | ) 92 | assert torch.allclose( 93 | scalar_features, output_scalar_features 94 | ), "Scalar features do not match" 95 | assert torch.allclose( 96 | vector_features, output_vector_features 97 | ), "Vector features do not match" 98 | 99 | 100 | def test_gvp(gvp_data): 101 | """ 102 | Tests the GeometricVectorPerceptron module, which is a single layer that processes 103 | scalar and vector features, ensuring that output vector features are rotation equivariant. 
104 | """ 105 | scalar_features, vector_features = gvp_data 106 | out_scalar_channels = 10 107 | out_vector_channels = 5 108 | gvp_layer = GeometricVectorPerceptron( 109 | NUM_SCALAR_FEATURES, 110 | out_scalar_channels, 111 | NUM_VECTOR_FEATURES, 112 | out_vector_channels, 113 | ) 114 | 115 | # Test forward pass 116 | output_scalar_features, output_vector_features = gvp_layer(gvp_data) 117 | assert output_scalar_features.shape == ( 118 | NUM_NODES, 119 | out_scalar_channels, 120 | ), "Wrong shape for output scalar features" 121 | assert output_vector_features.shape == ( 122 | NUM_NODES, 123 | out_vector_channels, 124 | VECTOR_DIM_SIZE, 125 | ), "Wrong shape for output vector features" 126 | 127 | # Test equivariance using a GVP function that takes only 128 | # vector inputs (scalar features are fixed) and returns only vector outputs 129 | gvp_fn = lambda vs: gvp_layer((scalar_features, vs))[1] 130 | 131 | assert is_rotation_equivariant( 132 | gvp_fn, vector_features 133 | ), "GVP should be rotation equivariant" 134 | 135 | # Test scalars/vectors interacting by swapping features and seeing if the other type of output features change 136 | torch.manual_seed(42) 137 | new_scalar_features = torch.randn((NUM_NODES, NUM_SCALAR_FEATURES)) 138 | new_vector_features = torch.randn((NUM_NODES, NUM_VECTOR_FEATURES, VECTOR_DIM_SIZE)) 139 | 140 | new_output_scalar_features, _ = gvp_layer((scalar_features, new_vector_features)) 141 | assert not torch.allclose( 142 | output_scalar_features, new_output_scalar_features, atol=1e-3 143 | ), "Scalar features should change when vector features change" 144 | 145 | _, new_output_vector_features = gvp_layer((new_scalar_features, vector_features)) 146 | assert not torch.allclose( 147 | output_vector_features, new_output_vector_features, atol=1e-3 148 | ), "Vector features should change when scalar features change" 149 | 150 | 151 | def test_gvp_dropout(gvp_data): 152 | """ 153 | Tests the GVPDropout module, which performs normal dropout on 
scalar features 154 | and zeros entire 3-D vector features at once. 155 | """ 156 | scalar_features, vector_features = gvp_data 157 | 158 | torch.manual_seed(123) 159 | gvp_dropout = GVPDropout(0.5, NUM_VECTOR_FEATURES) 160 | 161 | # Sanity check that no input features are zero 162 | assert not torch.any( 163 | scalar_features == 0.0 164 | ).item(), "No scalar features should be zero" 165 | assert not torch.any( 166 | vector_features == 0.0 167 | ).item(), "No vector features should be zero" 168 | 169 | # Test forward pass 170 | output_scalar_features, output_vector_features = gvp_dropout(gvp_data) 171 | assert torch.any( 172 | output_scalar_features == 0.0 173 | ), "Some scalar features should be zero after dropout" 174 | assert torch.any( 175 | output_vector_features == 0.0 176 | ), "Some vector features should be zero after dropout" 177 | 178 | # Test that entire vector channels are dropped out at once 179 | assert torch.all( 180 | torch.any(output_vector_features == 0.0, dim=-1) 181 | == torch.all(output_vector_features == 0.0, dim=-1) 182 | ).item(), "Entire vector channels should be dropped out at once" 183 | 184 | 185 | def test_gvp_layer_norm(gvp_data): 186 | """ 187 | Tests the GVPLayerNorm module, which performs typical layer normalization on scalar features 188 | and normalises vector features so that the norm of the vector feature matrix is equal to the 189 | square root of the number of vector features. 
190 | """ 191 | gvp_layer_norm = GVPLayerNorm(NUM_SCALAR_FEATURES, NUM_VECTOR_FEATURES) 192 | 193 | # Test forward pass 194 | output_scalar_features, output_vector_features = gvp_layer_norm(gvp_data) 195 | 196 | assert torch.allclose( 197 | torch.mean(output_scalar_features, dim=-1), torch.zeros(NUM_NODES), atol=1e-5 198 | ), "Scalar feature outputs of LayerNorm should be zero-mean" 199 | 200 | assert torch.allclose( 201 | torch.std(output_scalar_features, unbiased=False, dim=-1), 202 | torch.ones(NUM_NODES), 203 | atol=1e-5, 204 | ), "Scalar feature outputs of LayerNorm should be unit standard deviation" 205 | 206 | assert torch.allclose( 207 | torch.linalg.norm(output_vector_features, dim=(-2, -1)), 208 | torch.sqrt(torch.as_tensor(NUM_VECTOR_FEATURES)), 209 | atol=1e-5, 210 | ), "Vector feature outputs of LayerNorm should have norm equal to sqrt(num_vector_features)" 211 | 212 | 213 | def test_gvp_message(gvp_data): 214 | """ 215 | Tests the GVPMessage module, which transforms a message using sequential GVPs. 
216 | """ 217 | 218 | scalar_features, vector_features = gvp_data 219 | num_gvps = 3 220 | gvp_message = GVPMessage(NUM_SCALAR_FEATURES, NUM_VECTOR_FEATURES, num_gvps) 221 | 222 | assert ( 223 | len(gvp_message._message_gvp_layers) == num_gvps 224 | ), "Wrong number of message GVP layers" 225 | 226 | out_scalar_features, out_vector_features = gvp_message( 227 | scalar_features, vector_features 228 | ) 229 | 230 | assert ( 231 | out_scalar_features.shape == scalar_features.shape 232 | ), "GVPMessage should not change shape of scalar inputs" 233 | assert ( 234 | out_vector_features.shape == vector_features.shape 235 | ), "GVPMessage should not change shape of vector inputs" 236 | 237 | message_fn = lambda v_x: gvp_message(scalar_features, v_x)[1] 238 | assert is_rotation_equivariant( 239 | message_fn, vector_features 240 | ), "GVPMessage should be rotation equivariant" 241 | 242 | 243 | def test_gvp_update(gvp_data): 244 | """ 245 | Tests the GVPUpdate module, which updates node state features from an incoming message using sequential GVPs. 
246 | """ 247 | 248 | scalar_features, vector_features = gvp_data 249 | num_output_scalar_features = NUM_SCALAR_FEATURES + 1 250 | num_output_vector_features = NUM_VECTOR_FEATURES + 1 251 | state_scalar_features = torch.nn.functional.pad(scalar_features, (0, 1)) 252 | state_vector_features = torch.nn.functional.pad(vector_features, (0, 0, 0, 1)) 253 | num_gvps = 3 254 | dropout = 0.2 255 | gvp_update = GVPUpdate( 256 | num_output_scalar_features, 257 | NUM_SCALAR_FEATURES, 258 | num_output_vector_features, 259 | NUM_VECTOR_FEATURES, 260 | num_gvps, 261 | dropout, 262 | ) 263 | gvp_update.eval() 264 | 265 | assert ( 266 | len(gvp_update._update_gvp_layers) == num_gvps 267 | ), "Wrong number of update GVP layers" 268 | 269 | out_scalar_features, out_vector_features = gvp_update( 270 | state_scalar_features, scalar_features, state_vector_features, vector_features 271 | ) 272 | 273 | assert ( 274 | out_scalar_features.shape == state_scalar_features.shape 275 | ), "GVPUpdate did not produce the correct number of output scalar features" 276 | assert ( 277 | out_vector_features.shape == state_vector_features.shape 278 | ), "GVPUpdate did not produce the correct number of output vector features" 279 | 280 | message_fn = lambda state_v_x, v_x: gvp_update( 281 | state_scalar_features, scalar_features, state_v_x, v_x 282 | )[1] 283 | assert is_rotation_equivariant( 284 | message_fn, state_vector_features, vector_features 285 | ), "GVPUpdate should be rotation equivariant" 286 | 287 | 288 | def test_gvp_attention(gvp_data, orientations): 289 | """ 290 | Tests the GVPAttentionStrategySelector - which selects a GVPAttention strategy for an input string - 291 | and the GVPAttention modules, which calculate attention coefficients based on input scalar and vector features. 
292 | """ 293 | 294 | scalar_features, vector_features = gvp_data 295 | batch = torch.arange(NUM_NODES // BATCH_SIZE).repeat_interleave(BATCH_SIZE) 296 | num_heads = 5 297 | 298 | selector = GVPAttentionStrategySelector( 299 | NUM_SCALAR_FEATURES, 300 | NUM_VECTOR_FEATURES, 301 | NUM_SCALAR_FEATURES, 302 | NUM_VECTOR_FEATURES, 303 | num_heads=num_heads, 304 | ) 305 | for attn_type in GVPAttentionTypes.__args__: 306 | attn_module = selector.get_layer(attn_type) 307 | assert isinstance( 308 | attn_module, GVPAttention 309 | ), "GVPAttentionStrategySelector should return a GVPAttention module" 310 | attn_weights = attn_module( 311 | scalar_features, 312 | vector_features, 313 | orientations, 314 | scalar_features, 315 | vector_features, 316 | orientations, 317 | batch, 318 | ) 319 | assert attn_weights.shape == ( 320 | NUM_NODES, 321 | num_heads, 322 | ), "Wrong shape for attention weights" 323 | 324 | summed_weights = scatter_sum(attn_weights, batch, dim=0) 325 | assert torch.allclose( 326 | summed_weights, torch.ones_like(summed_weights) 327 | ), "Attention weights should sum to 1 over the softmax index" 328 | 329 | # Test equivariance using a GVP function that takes only 330 | # vector inputs (scalar features are fixed) and returns only vector outputs 331 | attn_fn = lambda v, o: attn_module( 332 | scalar_features, v, o, scalar_features, v, o, batch 333 | ) 334 | assert is_rotation_equivariant( 335 | attn_fn, vector_features, orientations, test_invariant=True 336 | ), "Attention weights should be rotation invariant" 337 | 338 | # check that passing an invalid string raises a ValueError 339 | with pytest.raises(ValueError): 340 | selector.get_layer("invalid") 341 | 342 | 343 | class TestGVPMessagePassing: 344 | """Tests the GVPMessagePassing module.""" 345 | 346 | # some default message passing parameters 347 | aggr = "sum" 348 | num_message_gvps = 3 349 | num_update_gvps = 2 350 | dropout = 0.2 351 | num_heads = 1 352 | 353 | message_passing_layer = 
GVPMessagePassing( 354 | NUM_SCALAR_FEATURES, 355 | NUM_VECTOR_FEATURES, 356 | NUM_EDGE_SCALAR_FEATURES, 357 | NUM_EDGE_VECTOR_FEATURES, 358 | aggr, 359 | num_message_gvps, 360 | num_update_gvps, 361 | dropout, 362 | num_heads=num_heads, 363 | vector_dim_size=VECTOR_DIM_SIZE, 364 | ) 365 | message_passing_layer.eval() 366 | 367 | message_passing_layer_attn = GVPMessagePassing( 368 | NUM_SCALAR_FEATURES, 369 | NUM_VECTOR_FEATURES, 370 | NUM_EDGE_SCALAR_FEATURES, 371 | NUM_EDGE_VECTOR_FEATURES, 372 | aggr, 373 | num_message_gvps, 374 | num_update_gvps, 375 | dropout, 376 | attention_type="flatten", 377 | num_heads=num_heads, 378 | vector_dim_size=VECTOR_DIM_SIZE, 379 | ) 380 | message_passing_layer_attn.eval() 381 | 382 | message_out_channels = ( 383 | NUM_SCALAR_FEATURES + (NUM_VECTOR_FEATURES * VECTOR_DIM_SIZE) 384 | ) * 2 + (NUM_EDGE_SCALAR_FEATURES + (NUM_EDGE_VECTOR_FEATURES * VECTOR_DIM_SIZE)) 385 | 386 | def test_init(self): 387 | """Tests the constructor.""" 388 | 389 | assert ( 390 | len(self.message_passing_layer._message_layer._message_gvp_layers) 391 | == self.num_message_gvps 392 | ), "Wrong number of message GVP layers" 393 | assert ( 394 | len(self.message_passing_layer._node_update_layer._update_gvp_layers) 395 | == self.num_update_gvps 396 | ), "Wrong number of update GVP layers" 397 | 398 | @staticmethod 399 | def get_features( 400 | node_features: Tuple[torch.Tensor, torch.Tensor], 401 | edge_features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 402 | ): 403 | """Converts the GVP data into the features used by the message passing module.""" 404 | scalar_features, vector_features = node_features 405 | edge_index, edge_scalar_features, edge_vector_features = edge_features 406 | 407 | x = torch.cat([scalar_features, vector_features.flatten(start_dim=-2)], dim=-1) 408 | edge_attr = torch.cat( 409 | [edge_scalar_features, edge_vector_features.flatten(start_dim=-2)], 410 | dim=-1, 411 | ) 412 | 413 | return x, edge_attr 414 | 415 | def 
test_forward(self, gvp_data, gvp_edge_data, orientations): 416 | """Tests the forward pass and its constituent operations (message, aggregate and update).""" 417 | 418 | scalar_features = gvp_data[0] 419 | vector_features = gvp_data[1] 420 | 421 | edge_index, edge_scalar_features, edge_vector_features = gvp_edge_data 422 | 423 | x, edge_attr = self.get_features(gvp_data, gvp_edge_data) 424 | x_i = x[edge_index[1]] 425 | x_j = x[edge_index[0]] 426 | 427 | message = self.message_passing_layer.message( 428 | x_i, 429 | x_j, 430 | edge_attr, 431 | orientations[edge_index[1]].flatten(start_dim=-2), 432 | orientations[edge_index[0]].flatten(start_dim=-2), 433 | edge_index[1], 434 | ) 435 | attn_message = self.message_passing_layer_attn.message( 436 | x_i, 437 | x_j, 438 | edge_attr, 439 | orientations[edge_index[1]].flatten(start_dim=-2), 440 | orientations[edge_index[0]].flatten(start_dim=-2), 441 | edge_index[1], 442 | ) 443 | 444 | assert message.shape == ( 445 | NUM_EDGES, 446 | self.message_out_channels, 447 | ), "Message shape is not correct" 448 | assert attn_message.shape == ( 449 | NUM_EDGES, 450 | self.message_out_channels, 451 | ), "Message shape is not correct" 452 | assert not torch.any( 453 | message.isnan() 454 | ).item(), "Output of message() contains NaNs" 455 | assert not torch.any( 456 | attn_message.isnan() 457 | ).item(), "Output of message() contains NaNs" 458 | 459 | aggr_message = self.message_passing_layer.aggregate(message, edge_index[1]) 460 | aggr_attn_message = self.message_passing_layer_attn.aggregate( 461 | attn_message, edge_index[1] 462 | ) 463 | 464 | updated = self.message_passing_layer.update(aggr_message, x) 465 | updated_attn = self.message_passing_layer_attn.update(aggr_attn_message, x) 466 | 467 | assert updated.shape == x.shape, "Updated shape is not correct" 468 | assert ( 469 | updated_attn.shape == x.shape 470 | ), "Updated shape is not correct (with attention)" 471 | assert not torch.any(updated.isnan()).item(), "Output of 
update() contains NaNs" 472 | assert not torch.any( 473 | updated_attn.isnan() 474 | ).item(), "Output of update() contains NaNs" 475 | 476 | def vector_forward_pass(layer, v_x, e_v_x, ors): 477 | """Performs a forward pass of the message passing layer as a function of the vector features.""" 478 | node_features, edge_features = self.get_features( 479 | (scalar_features, v_x), (edge_index, edge_scalar_features, e_v_x) 480 | ) 481 | orientation_features = ors.flatten(start_dim=-2) 482 | output = layer( 483 | node_features, edge_index, edge_features, orientation_features 484 | ) 485 | _, v_features = separate_gvp_features( 486 | output, NUM_VECTOR_FEATURES, VECTOR_DIM_SIZE 487 | ) 488 | return v_features 489 | 490 | # Test equivariance of the forward pass 491 | assert is_rotation_equivariant( 492 | lambda v_x, e_v_x, ors: vector_forward_pass( 493 | self.message_passing_layer, 494 | v_x, 495 | e_v_x, 496 | ors, 497 | ), 498 | vector_features, 499 | edge_vector_features, 500 | orientations, 501 | ), "Message passing should be rotation equivariant" 502 | 503 | # attention-based model should also be equivariant, 504 | # but lower sensitivity should be used due to numerical instability 505 | # with the softmax in the attention 506 | assert is_rotation_equivariant( 507 | lambda v_x, e_v_x, ors: vector_forward_pass( 508 | self.message_passing_layer_attn, 509 | v_x, 510 | e_v_x, 511 | ors, 512 | ), 513 | vector_features, 514 | edge_vector_features, 515 | orientations, 516 | atol=1e-2, 517 | ), "Message passing should be rotation equivariant" 518 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | """ Tests the dataset class.""" 2 | 3 | import pytest 4 | import numpy as np 5 | import os 6 | import torch 7 | from loopgen import ReceptorLigandDataset, Structure 8 | from loopgen.data import BaseDataModule 9 | 10 | 11 | @pytest.fixture 12 | def 
data_path() -> str: 13 | """Returns the path to the test data set.""" 14 | return os.path.join(os.path.dirname(__file__), "data/cdrs_20.hdf5") 15 | 16 | 17 | @pytest.fixture 18 | def dataset(data_path) -> ReceptorLigandDataset: 19 | """Returns the actual data set object for the test data set.""" 20 | return ReceptorLigandDataset.from_hdf5_file(data_path) 21 | 22 | 23 | @pytest.fixture 24 | def datamodule(dataset) -> BaseDataModule: 25 | """Returns a non-abstract version of the BaseCDRDataModule class.""" 26 | datamodule_class = BaseDataModule 27 | datamodule_class.__abstractmethods__ = set() 28 | return datamodule_class(dataset) 29 | 30 | 31 | @pytest.fixture 32 | def datamodule_with_cdr_types(dataset) -> BaseDataModule: 33 | """Returns a non-abstract version of the BaseCDRDataModule class with cdr types.""" 34 | datamodule_class = BaseDataModule 35 | datamodule_class.__abstractmethods__ = set() 36 | cdr_types = { 37 | pair["name"]: str(int(i % 2 == 0)) 38 | for i, pair in enumerate(dataset.structure_pairs) 39 | } 40 | cdr_type_weights = {"0": 1, "1": 3} 41 | return datamodule_class( 42 | dataset, 43 | cdr_type_dict=cdr_types, 44 | cdr_type_weights=cdr_type_weights, 45 | name_to_id_fn=lambda x: x, 46 | ) 47 | 48 | 49 | class TestReceptorLigandDataset: 50 | def test_from_hdf5_file(self, data_path): 51 | dataset = ReceptorLigandDataset.from_hdf5_file(data_path) 52 | 53 | for pair_dict in dataset.structure_pairs: 54 | assert "name" in pair_dict, "Structure pair should have a name" 55 | assert "antigen" in pair_dict, "Structure pair should have an antigen" 56 | assert "cdr" in pair_dict, "Structure pair should have a cdr" 57 | 58 | def test_len(self, dataset): 59 | assert len(dataset) == len( 60 | dataset.structure_pairs 61 | ), "Length of dataset should be same as length of the list of structure pairs" 62 | 63 | def test_getitem(self, dataset): 64 | for i in range(20): 65 | name, epitope, cdr = dataset[i] 66 | assert isinstance(name, str), "Name should be a string" 67 
| assert isinstance(epitope, Structure), "Epitope should be a Structure" 68 | assert isinstance(cdr, Structure), "CDR should be a Structure" 69 | 70 | assert dataset.structure_pairs[i]["name"] == name, "Name should match" 71 | 72 | assert np.allclose( 73 | dataset.structure_pairs[i]["antigen"]["N_coords"][:], 74 | epitope.N_coords.cpu().numpy(), 75 | ), "N coords of returned antigen should match those in dataset file" 76 | 77 | assert np.allclose( 78 | dataset.structure_pairs[i]["antigen"]["CA_coords"][:], 79 | epitope.CA_coords.cpu().numpy(), 80 | ), "CA coords of returned antigen should match those in dataset file" 81 | 82 | assert np.allclose( 83 | dataset.structure_pairs[i]["antigen"]["C_coords"][:], 84 | epitope.C_coords.cpu().numpy(), 85 | ), "C coords of returned antigen should match those in dataset file" 86 | 87 | assert np.allclose( 88 | dataset.structure_pairs[i]["antigen"]["CB_coords"][:], 89 | epitope.CB_coords.cpu().numpy(), 90 | ), "CB coords of returned antigen should match those in dataset file" 91 | 92 | assert np.allclose( 93 | dataset.structure_pairs[i]["antigen"]["sequence"][:], 94 | epitope.sequence.cpu().numpy(), 95 | ), "Sequence of returned antigen should match those in dataset file" 96 | 97 | assert np.allclose( 98 | dataset.structure_pairs[i]["cdr"]["N_coords"][:], 99 | cdr.N_coords.cpu().numpy(), 100 | ), "N coords of returned cdr should match those in dataset file" 101 | 102 | assert np.allclose( 103 | dataset.structure_pairs[i]["cdr"]["CA_coords"][:], 104 | cdr.CA_coords.cpu().numpy(), 105 | ), "CA coords of returned cdr should match those in dataset file" 106 | 107 | assert np.allclose( 108 | dataset.structure_pairs[i]["cdr"]["C_coords"][:], 109 | cdr.C_coords.cpu().numpy(), 110 | ), "C coords of returned cdr should match those in dataset file" 111 | 112 | assert np.allclose( 113 | dataset.structure_pairs[i]["cdr"]["CB_coords"][:], 114 | cdr.CB_coords.cpu().numpy(), 115 | ), "CB coords of returned cdr should match that in dataset 
file" 116 | 117 | assert np.allclose( 118 | dataset.structure_pairs[i]["cdr"]["sequence"][:], 119 | cdr.sequence.cpu().numpy(), 120 | ), "Sequence of returned cdr should match those in dataset file" 121 | 122 | def test_train_test_split(self, dataset): 123 | """ 124 | Tests the train_test_split() function, which performs a random train/test 125 | split of the dataset. 126 | """ 127 | train_p = 0.8 128 | train_1, test_1 = dataset.train_test_split(train_prop=train_p, random_state=123) 129 | train_2, test_2 = dataset.train_test_split(train_prop=train_p, random_state=123) 130 | 131 | assert len(train_1) == int( 132 | len(dataset) * train_p 133 | ), "Train set should be train_p * dataset set size" 134 | 135 | assert ( 136 | train_1.structure_pairs == train_2.structure_pairs 137 | ), "Train sets should be identical when using same random state" 138 | assert ( 139 | test_1.structure_pairs == test_2.structure_pairs 140 | ), "Test sets should be identical when using same random state" 141 | 142 | train_names = [pair["name"] for pair in train_1.structure_pairs] 143 | test_names = [pair["name"] for pair in test_1.structure_pairs] 144 | 145 | assert ( 146 | len(set(train_names).intersection(set(test_names))) == 0 147 | ), "Train and test sets should not have any names in common" 148 | 149 | 150 | class TestBaseCDRDataModule: 151 | """ 152 | Tests the base class for CDR data modules. 153 | This is actually an abstract class but we instantiate it here 154 | to test some of its general functionality. 155 | """ 156 | 157 | def test_setup(self, datamodule): 158 | """ 159 | Tests the setup() function, which is called by the pytorch lightning trainer 160 | before training begins. 
161 | """ 162 | # check if train/test/val datasets are None before setup() 163 | assert ( 164 | datamodule.train_dataset is None 165 | ), "Train dataset should be None before setup()" 166 | assert ( 167 | datamodule.test_dataset is None 168 | ), "Test dataset should be None before setup()" 169 | assert ( 170 | datamodule.validation_dataset is None 171 | ), "Val dataset should be None before setup()" 172 | 173 | datamodule.setup("fit") 174 | 175 | assert ( 176 | datamodule.train_dataset is not None 177 | ), "Train dataset should not be None after setup()" 178 | assert ( 179 | datamodule.test_dataset is not None 180 | ), "Test dataset should not be None after setup()" 181 | assert ( 182 | datamodule.validation_dataset is not None 183 | ), "Val dataset should not be None after setup()" 184 | 185 | assert len(datamodule.train_dataset) == int( 186 | datamodule._train_prop * len(datamodule.dataset) 187 | ), "Train dataset should be train_prop * dataset size" 188 | assert len(datamodule.test_dataset) == int( 189 | datamodule._test_prop * len(datamodule.dataset) 190 | ), "Test dataset should be test_prop * dataset size" 191 | assert len(datamodule.validation_dataset) == int( 192 | datamodule._val_prop * len(datamodule.dataset) 193 | ), "Val dataset should be val_prop * dataset size" 194 | 195 | train_names = [ 196 | pair["name"] for pair in datamodule.train_dataset.structure_pairs 197 | ] 198 | test_names = [pair["name"] for pair in datamodule.test_dataset.structure_pairs] 199 | val_names = [ 200 | pair["name"] for pair in datamodule.validation_dataset.structure_pairs 201 | ] 202 | 203 | assert ( 204 | len(set(train_names).intersection(set(test_names))) == 0 205 | ), "Train and test sets should not have any names in common" 206 | assert ( 207 | len(set(train_names).intersection(set(val_names))) == 0 208 | ), "Train and val sets should not have any names in common" 209 | assert ( 210 | len(set(test_names).intersection(set(val_names))) == 0 211 | ), "Test and val sets should 
not have any names in common" 212 | 213 | train_pdb_ids = [datamodule._name_to_pdb_id_fn(name) for name in train_names] 214 | test_pdb_ids = [datamodule._name_to_pdb_id_fn(name) for name in test_names] 215 | val_pdb_ids = [datamodule._name_to_pdb_id_fn(name) for name in val_names] 216 | 217 | assert ( 218 | len(set(train_pdb_ids).intersection(set(test_pdb_ids))) == 0 219 | ), "Train and test sets should not have any PDB IDs in common" 220 | assert ( 221 | len(set(train_pdb_ids).intersection(set(val_pdb_ids))) == 0 222 | ), "Train and val sets should not have any PDB IDs in common" 223 | assert ( 224 | len(set(test_pdb_ids).intersection(set(val_pdb_ids))) == 0 225 | ), "Test and val sets should not have any PDB IDs in common" 226 | 227 | def test_get_cdr_sampler(self, datamodule, datamodule_with_cdr_types): 228 | """ 229 | Tests the _get_cdr_sampler() function, a protected 230 | method that returns a WeightedRandomSampler if CDR types and their 231 | corresponding sampling weights are provided to the datamodule, and None otherwise. 
232 | """ 233 | weighted_sampler = datamodule_with_cdr_types._get_cdr_sampler( 234 | datamodule_with_cdr_types.dataset 235 | ) 236 | 237 | assert isinstance( 238 | weighted_sampler, torch.utils.data.WeightedRandomSampler 239 | ), "CDR sampler should be a WeightedRandomSampler when CDR types are provided" 240 | 241 | sampler = datamodule._get_cdr_sampler(datamodule.dataset) 242 | 243 | assert ( 244 | sampler is None 245 | ), "CDR sampler should be None when CDR types are not provided" 246 | -------------------------------------------------------------------------------- /tests/test_distributions.py: -------------------------------------------------------------------------------- 1 | """ Tests the various distributions in loopgen.distributions.""" 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from math import pi 7 | from loopgen.distributions import IGSO3Distribution, Gaussian3DDistribution 8 | from loopgen.utils import is_positive_semidefinite, is_symmetric 9 | 10 | 11 | class TestGaussian3DDistribution: 12 | 13 | """ 14 | Tests the Gaussian3DDistribution class. 15 | """ 16 | 17 | mean = torch.tensor([0.0, 0.0, 0.0]) 18 | variance = 1.0 19 | non_iso_cov_matrix = torch.ones(9).reshape(3, 3) 20 | non_iso_cov_matrix += torch.eye(3) * 2 21 | non_iso_cov_matrix += non_iso_cov_matrix.T.clone() 22 | 23 | distribution = Gaussian3DDistribution(mean, variance) 24 | non_iso_distribution = Gaussian3DDistribution(mean, non_iso_cov_matrix) 25 | 26 | def test_mean(self): 27 | assert torch.all( 28 | torch.eq(self.distribution.mean, self.mean) 29 | ).item(), "Inputted mean and distribution mean are not equal" 30 | 31 | def test_cov_matrix(self): 32 | """ 33 | Check that the stored covariance and precision matrices - for both scalar-variance 34 | and full-covariance inputs - are positive semi-definite and symmetric. 
35 | """ 36 | 37 | for distr in [self.distribution, self.non_iso_distribution]: 38 | assert is_symmetric(distr.cov_matrix), "Covariance matrix is not symmetric" 39 | assert is_positive_semidefinite( 40 | distr.cov_matrix 41 | ), "Covariance matrix is not positive semi-definite" 42 | 43 | assert is_symmetric(distr.prec_matrix), "Precision matrix is not symmetric" 44 | assert is_positive_semidefinite( 45 | distr.prec_matrix 46 | ), "Precision matrix is not positive semidefinite" 47 | 48 | prec_matrix = torch.linalg.inv(distr.cov_matrix) 49 | assert torch.allclose( 50 | distr.prec_matrix, prec_matrix 51 | ), "Precision matrix is not the inverse of the covariance matrix" 52 | 53 | def test_pdf(self): 54 | sample = torch.as_tensor([0.1, 0.2, 0.3]) 55 | pdf_value = self.distribution.pdf(sample) 56 | expected_pdf_value = torch.exp(self.distribution._distribution.log_prob(sample)) 57 | assert np.allclose( 58 | pdf_value.item(), expected_pdf_value.item() 59 | ), "PDF values do not match true values" 60 | 61 | def test_is_isotropic(self): 62 | assert self.distribution.is_isotropic, "Distribution should be isotropic" 63 | assert ( 64 | not self.non_iso_distribution.is_isotropic 65 | ), "Distribution should not be isotropic" 66 | 67 | 68 | class TestIGSO3Distribution: 69 | 70 | """ 71 | Tests the IGSO3 distribution. 72 | """ 73 | 74 | default = IGSO3Distribution() 75 | 76 | def test_init(self): 77 | with pytest.raises(AssertionError): 78 | IGSO3Distribution(-1) 79 | with pytest.raises(AssertionError): 80 | IGSO3Distribution(0) 81 | with pytest.raises(AssertionError): 82 | IGSO3Distribution(support_n=-1) 83 | with pytest.raises(AssertionError): 84 | IGSO3Distribution(support_n=0) 85 | with pytest.raises(AssertionError): 86 | IGSO3Distribution(expansion_n=-1) 87 | with pytest.raises(AssertionError): 88 | IGSO3Distribution(expansion_n=0) 89 | 90 | def test_sample(self): 91 | """ 92 | Rotation matrix must be valid. 
IGSO(3) has been visually confirmed 93 | as described at the top of the class 94 | """ 95 | 96 | rot_mat = self.default.sample(1) 97 | 98 | assert rot_mat.shape == ( 99 | 3, 100 | 3, 101 | ), "Sampled rotation matrix should have shape (3, 3)" 102 | 103 | det = np.linalg.det(rot_mat) 104 | assert np.isclose( 105 | det, 1.0, atol=1e-5 106 | ), "Sampled rotation matrix does not have a determinant of 1" 107 | 108 | inv = np.linalg.inv(rot_mat) 109 | transpose = np.transpose(rot_mat) 110 | assert np.all( 111 | np.isclose(inv, transpose, atol=1e-5) 112 | ), "Sampled rotation matrix is not orthogonal" 113 | 114 | # test sampling for different values of size 115 | sizes = [2, 20, (3, 4)] 116 | for size in sizes: 117 | rot_mat = self.default.sample(size) 118 | if isinstance(size, int): 119 | size = (size,) 120 | assert rot_mat.shape == size + ( 121 | 3, 122 | 3, 123 | ), "Sampled rotation matrix should have shape size + (3, 3)" 124 | 125 | def test_pdf(self): 126 | pdf_default = self.default._densities 127 | assert np.allclose( 128 | pdf_default[0], 0.0, atol=1e-6 129 | ), "0 should have 0 probability" 130 | 131 | new = IGSO3Distribution(support_n=5) 132 | assert len(new._densities) == 5, "pdf should only have size of support_n" 133 | 134 | def test_cdf(self): 135 | assert np.allclose( 136 | self.default.cdf(0), 0.0, atol=1e-3 137 | ), "cumulative density at 0 should be 0" 138 | assert np.allclose( 139 | self.default.cdf(pi), 1.0, atol=1e-3 140 | ), "cumulative density at pi should be 1" 141 | 142 | def test_inv_cdf(self): 143 | with pytest.raises(AssertionError): 144 | self.default.inv_cdf(-1) 145 | 146 | # self.assertAlmostEqual(default.inv_cdf(0),0, delta = 1e-6, msg = "inv cdf of 0 should be 0") 147 | # self.assertAlmostEqual(default.inv_cdf(1e-9),0, delta = 1e-3, msg = "inv cdf of 1e-9 should be 0") 148 | assert np.allclose( 149 | self.default.inv_cdf(1), pi, atol=1e-3 150 | ), "inv cdf of 1 should be pi" 151 | 152 | def test_score(self): 153 | score = 
self.default.score(1.0) 154 | eps = 1e-12 155 | # estimate score with finite differences 156 | estimated_score = ( 157 | np.log(self.default.inf_sum(1.0 + eps)) - np.log(self.default.inf_sum(1.0)) 158 | ) / eps 159 | assert np.allclose( 160 | score, estimated_score, atol=1e-2 161 | ), "Score estimated with finite differences should be within 0.01 of calculated score." 162 | 163 | def test_sample_axis(self): 164 | """ 165 | Checks that the sampled axis is normalized. 166 | """ 167 | vec = self.default.sample_axis(1) 168 | norm = np.linalg.norm(vec) 169 | assert np.allclose(norm, 1.0), "Sampled axis norm is not 1" 170 | 171 | def test_sample_angle(self): 172 | """ 173 | sample() combines an axis and an angle, so this only checks the angle range. 174 | """ 175 | angle = self.default.sample_angle(1) 176 | assert 0 <= angle <= pi, "Sampled angle is not in [0, pi]" 177 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ Tests the utility functions in loopgen.utils.""" 2 | 3 | import pytest 4 | import torch 5 | from torch_geometric.data import Data 6 | 7 | from loopgen.utils import ( 8 | is_equivariant, 9 | is_rotation_equivariant, 10 | is_translation_equivariant, 11 | so3_log_map, 12 | so3_exp_map, 13 | so3_hat, 14 | so3_hat_inv, 15 | node_type_subgraph, 16 | ) 17 | 18 | NUM_COORDS = 10 19 | 20 | 21 | @pytest.fixture 22 | def coords() -> torch.Tensor: 23 | """Returns a tensor of coordinates.""" 24 | torch.manual_seed(123) 25 | return torch.randn(NUM_COORDS, 3) 26 | 27 | 28 | def test_is_equivariant(coords): 29 | """ 30 | Tests the is_equivariant() function, which determines if an 31 | input function is equivariant with respect to another function.
32 | """ 33 | # assert that identity is equivariant with respect to addition and multiplication 34 | assert is_equivariant( 35 | lambda x: x, lambda x: x + 1, coords 36 | ), "Identity should be equivariant with respect to addition" 37 | assert is_equivariant( 38 | lambda x: x, lambda x: x * 2, coords 39 | ), "Identity should be equivariant with respect to multiplication" 40 | 41 | # check multiple arguments work 42 | assert is_equivariant( 43 | lambda x, y: x + y, lambda x: x * 2, coords, coords 44 | ), "Addition of two inputs should be equivariant with respect to multiplication" 45 | 46 | # check that squaring is not equivariant with respect to addition 47 | assert not is_equivariant( 48 | lambda x: x**2, lambda x: x + 1, coords 49 | ), "Squaring should not be equivariant with respect to addition" 50 | 51 | 52 | def test_is_rotation_equivariant(coords): 53 | """ 54 | Tests the is_rotation_equivariant() function, 55 | which determines if an input function is equivariant to 3D rotation. 56 | """ 57 | assert is_rotation_equivariant( 58 | lambda x, y: x + y, coords, coords 59 | ), "Addition of two vectors should be rotation equivariant" 60 | assert not is_rotation_equivariant( 61 | lambda x: x + 1, coords 62 | ), "Addition by a constant should not be rotation equivariant" 63 | assert is_rotation_equivariant( 64 | lambda x: x * 2, coords 65 | ), "Multiplication should be rotation equivariant" 66 | assert not is_rotation_equivariant( 67 | lambda x: x * torch.as_tensor([1.0, 2.0, 3.0]), coords 68 | ), "Coordinate-wise multiplication should not be rotation equivariant" 69 | 70 | 71 | def test_is_translation_equivariant(coords): 72 | """ 73 | Tests the is_translation_equivariant() function, 74 | which determines if an input function is equivariant to 3D translation.
75 | """ 76 | assert not is_translation_equivariant( 77 | lambda x, y: x + y, coords, coords 78 | ), "Addition of two vectors should not be translation equivariant" 79 | assert is_translation_equivariant( 80 | lambda x: x + 1, coords 81 | ), "Addition by a constant should be translation equivariant" 82 | assert is_translation_equivariant( 83 | lambda x, y: (x + y) / 2, coords, coords 84 | ), "Mean should be translation equivariant" 85 | 86 | 87 | def test_so3_exp_map(): 88 | """ 89 | Tests the so3_exp_map() function, which maps a rotation vector to its 90 | corresponding rotation matrix. 91 | """ 92 | # test that the output shape is correct 93 | torch.manual_seed(123) 94 | num_rots = 10 95 | rot_vecs = torch.randn(num_rots, 3) 96 | rot_mats = so3_exp_map(rot_vecs) 97 | assert rot_mats.shape == ( 98 | 10, 99 | 3, 100 | 3, 101 | ), "Output shape should be (..., 3, 3)" 102 | 103 | assert torch.allclose( 104 | torch.det(rot_mats), torch.ones(num_rots) 105 | ), "Determinant should be 1" 106 | assert torch.allclose( 107 | torch.linalg.inv(rot_mats), torch.transpose(rot_mats, -1, -2) 108 | ), "Inverse should be transpose" 109 | 110 | # test that the exp map undoes log map 111 | assert torch.allclose( 112 | so3_log_map(rot_mats), rot_vecs 113 | ), "Exp map should be the inverse of log map" 114 | 115 | 116 | def test_so3_log_map(): 117 | """ 118 | Tests the so3_log_map() function, which maps a rotation matrix to its 119 | corresponding rotation vector. 
120 | """ 121 | 122 | # test that the output shape is correct 123 | torch.manual_seed(123) 124 | num_rots = 10 125 | rot_mat = torch.randn(num_rots, 3, 3) 126 | assert so3_log_map(rot_mat).shape == (num_rots, 3), "Output shape should be (..., 3)" 127 | 128 | # test that identity matrix maps to zero vector 129 | assert torch.all( 130 | so3_log_map(torch.eye(3)) == torch.zeros(3) 131 | ), "Identity matrix should map to zero vector" 132 | 133 | # test that the log map undoes exp map 134 | rot_vec = torch.tensor([0.1, 0.2, 0.3]) 135 | rot_mat = so3_exp_map(rot_vec) 136 | assert torch.allclose( 137 | so3_log_map(rot_mat), rot_vec 138 | ), "Log map should be the inverse of exp map" 139 | 140 | 141 | def test_so3_hat_inv(): 142 | """ 143 | Tests the so3_hat_inv() function, which maps a rotation vector to its 144 | corresponding skew-symmetric matrix. 145 | """ 146 | 147 | torch.manual_seed(123) 148 | num_rots = 10 149 | rot_vec = torch.randn(num_rots, 3) 150 | skew_sym = so3_hat_inv(rot_vec) 151 | 152 | assert skew_sym.shape == (num_rots, 3, 3), "Output shape should be (..., 3, 3)" 153 | assert torch.allclose( 154 | skew_sym, -torch.transpose(skew_sym, -1, -2) 155 | ), "Output matrix should be skew-symmetric" 156 | 157 | 158 | def test_so3_hat(): 159 | """ 160 | Tests the so3_hat() function, which converts from a skew-symmetric matrix 161 | into a rotation vector. 162 | """ 163 | 164 | torch.manual_seed(123) 165 | num_rots = 10 166 | rot_vec = torch.randn(num_rots, 3) 167 | skew_sym = so3_hat_inv(rot_vec) 168 | output_rot_vec = so3_hat(skew_sym) 169 | 170 | assert torch.allclose( 171 | rot_vec, output_rot_vec 172 | ), "Calling hat() on the output of hat_inv() should be the identity" 173 | 174 | 175 | def test_node_type_subgraph(): 176 | """ 177 | Tests the node_type_subgraph() function, which returns a subgraph 178 | containing only the specified node types.
179 | """ 180 | num_nodes = 10 181 | num_features = 5 182 | node_type_0_features = torch.zeros(num_nodes // 2, num_features) 183 | node_type_1_features = torch.ones(num_nodes // 2, num_features) 184 | node_features = torch.cat( 185 | [ 186 | node_type_0_features, 187 | node_type_1_features, 188 | ], 189 | dim=0, 190 | ) 191 | edge_index = torch.cat( 192 | [ 193 | torch.arange(num_nodes).repeat_interleave(num_nodes), 194 | torch.arange(num_nodes).repeat(num_nodes), 195 | ], 196 | dim=0, 197 | ) 198 | node_types = torch.cat( 199 | [torch.zeros(num_nodes // 2), torch.ones(num_nodes // 2)], dim=0 200 | ).to(torch.long) 201 | graph = Data(x=node_features, node_type=node_types, edge_index=edge_index) 202 | 203 | node_type_0_subgraph = node_type_subgraph(graph, 0) 204 | node_type_1_subgraph = node_type_subgraph(graph, 1) 205 | 206 | assert torch.allclose( 207 | node_type_0_subgraph.x, node_type_0_features 208 | ), "Node type 0 subgraph should only contain node type 0 features" 209 | 210 | assert torch.allclose( 211 | node_type_1_subgraph.x, node_type_1_features 212 | ), "Node type 1 subgraph should only contain node type 1 features" 213 | --------------------------------------------------------------------------------
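The SO(3) tests in tests/test_utils.py above rely on three properties: the vector-to-skew-matrix map and its inverse round-trip, the exponential map yields a rotation matrix, and the log map undoes the exponential map. The sketch below illustrates those properties in pure Python. It is not the code in `loopgen.utils`; the names `hat`, `vee`, `exp_map` and `log_map` are hypothetical and follow the common textbook convention in which `hat` maps a vector to a matrix, which may differ from the naming the tests above exercise.

```python
import math


def hat(v):
    """Map a 3-vector to its skew-symmetric cross-product matrix (hypothetical helper)."""
    x, y, z = v
    return [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]


def vee(m):
    """Inverse of hat(): read the 3-vector back out of a skew-symmetric matrix."""
    return [m[2][1], m[0][2], m[1][0]]


def matmul(a, b):
    """Plain 3x3 matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def exp_map(v):
    """Rodrigues' formula: R = I + (sin t / t) K + ((1 - cos t) / t^2) K^2, with t = |v|."""
    theta = math.sqrt(sum(c * c for c in v))
    k = hat(v)
    k2 = matmul(k, k)
    if theta < 1e-8:
        a, b = 1.0, 0.5  # small-angle limits of the two coefficients
    else:
        a = math.sin(theta) / theta
        b = (1.0 - math.cos(theta)) / theta**2
    return [
        [(1.0 if i == j else 0.0) + a * k[i][j] + b * k2[i][j] for j in range(3)]
        for i in range(3)
    ]


def log_map(r):
    """Recover the rotation vector from R, valid for rotation angles in [0, pi)."""
    trace = r[0][0] + r[1][1] + r[2][2]
    theta = math.acos(max(-1.0, min(1.0, (trace - 1.0) / 2.0)))
    # The antisymmetric part (R - R^T) / 2 equals (sin t / t) * hat(v).
    skew = [[(r[i][j] - r[j][i]) / 2.0 for j in range(3)] for i in range(3)]
    w = vee(skew)
    if theta < 1e-8:
        return w
    scale = theta / math.sin(theta)
    return [scale * c for c in w]


if __name__ == "__main__":
    v = [0.1, 0.2, 0.3]
    print(vee(hat(v)))          # round trip through the skew-symmetric matrix
    print(log_map(exp_map(v)))  # round trip through the rotation matrix
```

One point worth keeping in mind when writing such tests: the log/exp round trip only returns the input when the rotation angle (the norm of the vector) is below pi. In this sketch a vector such as `[4.0, 0.0, 0.0]` comes back as the equivalent rotation with angle `4 - 2*pi` about the same axis, so round-trip tests on randomly drawn vectors can fail for large norms even when both maps are correct.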